diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..9d622b9 --- /dev/null +++ b/.clang-format @@ -0,0 +1,8 @@ +# Run the following command to reformat a file: +# clang-format -i -style=Google +# Or use clang-format-diff to only reformat the changed lines: +# https://clang.llvm.org/docs/ClangFormat.html +BasedOnStyle: Google +DerivePointerAlignment: false +ColumnLimit: 100 +PointerAlignment: Left diff --git a/.github/ISSUE_TEMPLATE/bug-report.md b/.github/ISSUE_TEMPLATE/bug-report.md new file mode 100644 index 0000000..87bbc4a --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug-report.md @@ -0,0 +1,43 @@ +--- +name: "🐛 Bug Report" +about: Submit a bug report to help us improve MLC-LLM +title: '[Bug] ' +labels: ['bug'] +assignees: '' + +--- + +## 🐛 Bug + + + +## To Reproduce + +Steps to reproduce the behavior: + +1. +1. +1. + + + +## Expected behavior + + + +## Environment + + - Platform (e.g. WebGPU/Vulkan/IOS/Android/CUDA): + - Operating system (e.g. Ubuntu/Windows/MacOS/...): + - Device (e.g. iPhone 12 Pro, PC+RTX 3090, ...) + - How you installed MLC-LLM (`conda`, source): + - How you installed TVM-Unity (`pip`, source): + - Python version (e.g. 3.10): + - GPU driver version (if applicable): + - CUDA/cuDNN version (if applicable): + - TVM Unity Hash Tag (`python -c "import tvm; print('\n'.join(f'{k}: {v}' for k, v in tvm.support.libinfo().items()))"`, applicable if you compile models): + - Any other relevant information: + +## Additional context + + diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000..6e88ecf --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,9 @@ +blank_issues_enabled: false + +contact_links: + - name: Check the MLC-LLM Documentation + url: https://llm.mlc.ai/docs/ + about: Our documentation might provide answers to your questions. + - name: Chat on Discord + url: https://discord.gg/9Xpy2HGBuD + about: Join the Discord Server to live chat with the community. diff --git a/.github/ISSUE_TEMPLATE/documentation.md b/.github/ISSUE_TEMPLATE/documentation.md new file mode 100644 index 0000000..58aba64 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/documentation.md @@ -0,0 +1,17 @@ +--- +name: "\U0001F4DA Documentation" +about: Report an issue related to https://llm.mlc.ai/docs/ +title: '[Doc] ' +labels: ['documentation'] +assignees: '' + +--- + +## 📚 Documentation + +### Suggestion + + +### Bug +- Link to the buggy documentation/tutorial: +- Description of the bug: diff --git a/.github/ISSUE_TEMPLATE/feature-request.md b/.github/ISSUE_TEMPLATE/feature-request.md new file mode 100644 index 0000000..5d92d35 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature-request.md @@ -0,0 +1,23 @@ +--- +name: "\U0001F680 Feature Request" +about: Submit a proposal/request for a new MLC-LLM feature, or an enhancement on existing features. +title: '[Feature Request] ' +labels: ['feature request'] +assignees: '' + +--- + +## 🚀 Feature + + +## Motivation + + + +## Alternatives + + + +## Additional context + + diff --git a/.github/ISSUE_TEMPLATE/general.md b/.github/ISSUE_TEMPLATE/general.md new file mode 100644 index 0000000..f441937 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/general.md @@ -0,0 +1,13 @@ +--- +name: "❓ General Questions" +about: General questions you have about MLC-LLM. 
+title: '[Question] ' +labels: ['question'] +assignees: '' + +--- + +## ❓ General Questions + + + diff --git a/.github/ISSUE_TEMPLATE/model-request.md b/.github/ISSUE_TEMPLATE/model-request.md new file mode 100644 index 0000000..fb48ce5 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/model-request.md @@ -0,0 +1,17 @@ +--- +name: "️️⚙️ Model Request" +about: Request a new model in MLC-LLM +title: '[Model Request] ' +labels: ['new-models'] +assignees: '' + +--- + +## ⚙️ Request New Models + +- Link to an existing implementation (e.g. Hugging Face/Github): +- Is this model architecture supported by MLC-LLM? (the list of [supported models](https://llm.mlc.ai/docs/prebuilt_models.html)) + +## Additional context + + diff --git a/.github/ISSUE_TEMPLATE/speed-report.md b/.github/ISSUE_TEMPLATE/speed-report.md new file mode 100644 index 0000000..d84a41b --- /dev/null +++ b/.github/ISSUE_TEMPLATE/speed-report.md @@ -0,0 +1,24 @@ +--- +name: " 🏎️ Speed Report" +about: Submit a speed report of an model running in MLC-LLM +title: '[Speed] ' +labels: ['performance'] +assignees: '' + +--- + +# 🏎️ Speed Report + + + +- The model code: + + +- The model configuration (e.g. quantization mode, running data type, etc.): +- Device (e.g. MacBook Pro M2, PC+RTX 3080): +- OS (if applicable): +- Encode speed (Token/s): +- Decode speed (Token/s): +- Memory usage (if applicable): + + \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/tracking.md b/.github/ISSUE_TEMPLATE/tracking.md new file mode 100644 index 0000000..d84b745 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/tracking.md @@ -0,0 +1,40 @@ +--- +name: "Tracking" +about: A tracking issue that tracks ongoing item in the project +title: '[Tracking] ' +labels: ['status: tracking'] +assignees: '' + +--- + + + + +## Overview + + + + +## Action Items + + +- [ ] + + +## Links to Related Issues and PRs + + + diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml new file mode 100644 index 0000000..644df9c --- /dev/null +++ b/.github/workflows/documentation.yml @@ -0,0 +1,39 @@ +name: Build Docs + +on: + push: + branches: + - main + +jobs: + test_linux: + name: Deploy Docs + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + with: + submodules: recursive + + - name: Configuring build Environment + run: | + sudo apt-get update + python -m pip install -U pip wheel + + - name: Setup Ruby + uses: ruby/setup-ruby@v1 + with: + ruby-version: '3.0' + + - name: Installing dependencies + run: | + python -m pip install -r docs/requirements.txt + gem install jekyll jekyll-remote-theme + + - name: Deploying on GitHub Pages + if: github.ref == 'refs/heads/main' + run: | + git remote set-url origin https://x-access-token:${{ secrets.MLC_GITHUB_TOKEN }}@github.com/$GITHUB_REPOSITORY + git config --global user.email "mlc-gh-actions-bot@nomail" + git config --global user.name "mlc-gh-actions-bot" + ./scripts/gh_deploy_site.sh diff --git a/.github/workflows/update-relax.yml b/.github/workflows/update-relax.yml new file mode 100644 index 0000000..ccd5dcb --- /dev/null +++ b/.github/workflows/update-relax.yml @@ -0,0 +1,32 @@ +name: 'Relax Submodule Sync' + +on: + workflow_dispatch: + +jobs: + sync: + name: 'Relax Submodule Sync' + runs-on: ubuntu-latest + + defaults: + run: + shell: bash + + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + submodules: true + + - name: Git Sumbodule Update + run: | + git submodule update --remote 3rdparty/tvm + + - name: Commit update + env: + GITHUB_TOKEN: ${{ secrets.MLC_GITHUB_TOKEN }} 
+ run: | + git config --global user.name 'Git bot' + git config --global user.email 'bot@noreply.github.com' + git remote set-url origin https://$GITHUB_TOKEN@github.com/mlc-ai/mlc-llm + git commit -am "Auto updated submodule references" && git push || echo "No changes to commit" diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f9454e5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,324 @@ +tmp/ +dist/ +params/ +debug/ +*.bak +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +.DS_Store + +*.S +# C extensions +*.so + +build/ + +*.ll +.npm +# Distribution / packaging +.Python +env/ +build/ +build-*/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +.conda/ +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Generated by python/gen_requirements.py +python/requirements/*.txt + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/_staging/ + +# PyBuilder +target/ +/target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject +*~ +*.pyc +*~ +config.mk +config.cmake +Win32 +*.dir +perf +*.wasm +.emscripten + +## IOS +DerivedData/ + +## Java +*.class +jvm/*/target/ +jvm/*/*/target/ +jvm/native/*/generated +jvm/native/src/main/native/org_apache_tvm_native_c_api.h +*.worksheet +*.idea +*.iml +*.classpath +*.project +*.settings +*/node_modules/ + +## Various settings +*.pbxuser +!default.pbxuser +*.mode1v3 +!default.mode1v3 +*.mode2v3 +!default.mode2v3 +*.perspectivev3 +!default.perspectivev3 +xcuserdata/ +.pkl_memoize_* + +.emscripten* +.m2 + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Compiled Object files +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +## Other +*.moved-aside +*.xccheckout +*.xcscmblueprint +.DS_Store +tags +cscope* +*.lock + +# vim temporary files +*.swp +*.swo + +# TVM generated code +perf +.bash_history +# *.json +*.params +*.ro +*.onnx +*.h5 +synset.txt +cat.jpg +cat.png +docs.tgz +cat.png +*.mlmodel +tvm_u.* +tvm_t.* +# Mac OS X +.DS_Store + +# Jetbrain +.idea +.ipython +.jupyter +.nv +.pylint.d +.python_history +.pytest_cache +.local +cmake-build-debug + +# Visual Studio +.vs + +# Visual Studio Code +.vscode + +# tmp file +.nfs* + +# keys +*.pem +*.p12 +*.pfx +*.cer +*.crt +*.der + +# patch sentinel +patched.txt + +# Python type checking +.mypy_cache/ +.pyre/ + +# pipenv files +Pipfile +Pipfile.lock + +# conda package artifacts +conda/Dockerfile.cuda* +conda/pkg +.node_repl_history +# nix files +.envrc +*.nix + +# Docker files +.sudo_as_admin_successful + +# Downloaded models/datasets +.tvm_test_data +.dgl +.caffe2 + +# Local docs build +_docs/ +jvm/target +.config/configstore/ +.ci-py-scripts/ + +# Generated Hexagon files +src/runtime/hexagon/rpc/hexagon_rpc.h +src/runtime/hexagon/rpc/hexagon_rpc_skel.c +src/runtime/hexagon/rpc/hexagon_rpc_stub.c + +# Local tvm-site checkout +tvm-site/ + +# Generated docs files +gallery/how_to/work_with_microtvm/micro_tvmc.py + +# Test sample data files +!tests/python/ci/sample_prs/*.json + +# Used in CI to communicate between Python and Jenkins +.docker-image-names/ + +# Printed TIR code on disk +*.tir + +# GDB history file +.gdb_history + +dist diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..10ef4b2 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,12 @@ +[submodule "3rdparty/argparse"] + path = 3rdparty/argparse + url = https://github.com/p-ranav/argparse +[submodule "3rdparty/tokenizers-cpp"] + path = 3rdparty/tokenizers-cpp + url = https://github.com/mlc-ai/tokenizers-cpp +[submodule "3rdparty/googletest"] + path = 3rdparty/googletest + url = https://github.com/google/googletest.git +[submodule "3rdparty/tvm"] + path = 3rdparty/tvm + url = https://github.com/mlc-ai/relax.git diff --git a/0001-Add-events-timing-to-MLCChat.patch b/0001-Add-events-timing-to-MLCChat.patch new file mode 100644 index 0000000..01aad40 --- /dev/null +++ b/0001-Add-events-timing-to-MLCChat.patch @@ -0,0 +1,145 @@ +From c9950224e21153b59d0e610ab06bd5bfedf98a26 Mon Sep 17 00:00:00 2001 +From: Stefanos Laskaridis +Date: Sun, 3 Mar 2024 17:18:06 +0000 +Subject: [PATCH] Add events timing to MLCChat++ + +--- + 
python/mlc_chat/chat_module.py | 24 +++++++++++++++++++++++- + python/mlc_chat/cli/chat.py | 7 +++++++ + python/mlc_chat/interface/chat.py | 6 +++++- + 3 files changed, 35 insertions(+), 2 deletions(-) + +diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py +index 62ca0135..79de7756 100644 +--- a/python/mlc_chat/chat_module.py ++++ b/python/mlc_chat/chat_module.py +@@ -7,6 +7,7 @@ import json + import os + import subprocess + import sys ++import time + import warnings + from dataclasses import asdict, dataclass, fields + from enum import Enum +@@ -719,6 +720,9 @@ class ChatModule: # pylint: disable=too-many-instance-attributes + device_type = self.device.device_type + device_id = self.device.device_id + ++ self.energy_events = {} ++ self.generate_counter = 0 ++ + # 1. Populate chat module and their functions + fcreate_chat_mod = tvm.get_global_func("mlc.llm_chat_create") + assert fcreate_chat_mod is not None +@@ -844,23 +848,35 @@ class ChatModule: # pylint: disable=too-many-instance-attributes + num_return_sequences = generation_config.n + return_str = False + +- for _ in range(num_return_sequences): ++ for idx in range(num_return_sequences): + if stateless: + self.reset_chat() ++ self.energy_events[f"chat.{self.generate_counter}.{idx}.prefill.start"] = time.time_ns() + self._prefill(prompt, generation_config=generation_config) ++ self.energy_events[f"chat.{self.generate_counter}.{idx}.prefill.end"] = time.time_ns() + + if not progress_callback: ++ decode_counter = 0 + while not self._stopped(): ++ self.energy_events[f"chat.{self.generate_counter}.{idx}.decode.{decode_counter}.start"] = time.time_ns() + self._decode(generation_config=generation_config) ++ self.energy_events[f"chat.{self.generate_counter}.{idx}.decode.{decode_counter}.end"] = time.time_ns() ++ self.energy_events[f"chat.{self.generate_counter}.{idx}.get_message.start"] = time.time_ns() + new_msg = self._get_message() ++ self.energy_events[f"chat.{self.generate_counter}.{idx}.get_message.end"] = time.time_ns() + new_msgs.append(new_msg) + else: + # apply callback with a rate of callback_interval + i, new_msg = 0, "" ++ decode_counter = 0 + while not self._stopped(): ++ self.energy_events[f"chat.{self.generate_counter}.{idx}.decode.{decode_counter}.start"] = time.time_ns() + self._decode(generation_config=generation_config) ++ self.energy_events[f"chat.{self.generate_counter}.{idx}.decode.{decode_counter}.end"] = time.time_ns() + if i % progress_callback.callback_interval == 0 or self._stopped(): ++ self.energy_events[f"chat.{self.generate_counter}.{idx}.get_message.start"] = time.time_ns() + new_msg = self._get_message() ++ self.energy_events[f"chat.{self.generate_counter}.{idx}.get_message.end"] = time.time_ns() + progress_callback(new_msg) + i += 1 + progress_callback(stopped=True) +@@ -999,11 +1015,15 @@ class ChatModule: # pylint: disable=too-many-instance-attributes + app_config_json: str + The partial config that is used to partially override the model configuration. 
+ """ ++ self.energy_events[f"load_model.start"] = time.time_ns() + self._reload_func(lib, model_path, app_config_json) ++ self.energy_events[f"load_model.end"] = time.time_ns() + + def _unload(self): + r"""Unload the chat module and clear memory of all loaded models.""" ++ self.energy_events[f"unload_model.start"] = time.time_ns() + self._unload_func() ++ self.energy_events[f"unload_model.end"] = time.time_ns() + + def _prefill( + self, +@@ -1209,4 +1229,6 @@ class ChatModule: # pylint: disable=too-many-instance-attributes + + def _process_system_prompts(self): + r"""Pre-process by prefilling the system prompts, running prior to any user input.""" ++ self.energy_events["prompt.system.start"] = time.time_ns() + self._process_system_prompts_func() ++ self.energy_events["prompt.system.end"] = time.time_ns() +diff --git a/python/mlc_chat/cli/chat.py b/python/mlc_chat/cli/chat.py +index 7ec6efb2..96edef2d 100644 +--- a/python/mlc_chat/cli/chat.py ++++ b/python/mlc_chat/cli/chat.py +@@ -37,6 +37,12 @@ def main(argv): + default=None, + help=HELP["model_lib_path"] + ' (default: "%(default)s")', + ) ++ parser.add_argument( ++ "--energy-events", ++ type=str, ++ default="energy_events.txt", ++ help="Energy events file to use for energy profiling (default: energy_events.txt)" ++ ) + parsed = parser.parse_args(argv) + chat( + model=parsed.model, +@@ -44,4 +50,5 @@ def main(argv): + opt=parsed.opt, + overrides=parsed.overrides, + model_lib_path=parsed.model_lib_path, ++ energy_events_filename=parsed.energy_events, + ) +diff --git a/python/mlc_chat/interface/chat.py b/python/mlc_chat/interface/chat.py +index cd473f79..3d23df40 100644 +--- a/python/mlc_chat/interface/chat.py ++++ b/python/mlc_chat/interface/chat.py +@@ -122,6 +122,7 @@ def chat( + opt: str, + overrides: ChatConfigOverride, + model_lib_path: Optional[str], ++ energy_events_filename: str, + ): + """chat with a model.""" + # Set up chat config and generate config +@@ -146,9 +147,12 @@ def chat( + if prompt[:6] == "/reset": + cm.reset_chat() + elif prompt[:5] == "/exit": ++ with open(energy_events_filename, 'w', encoding='utf-8') as f: ++ for event_key, event_value in cm.energy_events.items(): ++ f.write(f"{event_key} {event_value}\n") + break + elif prompt[:6] == "/stats": +- print(cm.stats(), flush=True) ++ print(cm.stats(verbose=True), flush=True) + elif prompt[:4] == "/set": + gen_config_overrides = GenerationConfigOverride.from_str(prompt.split()[1]) + generate_config = gen_config_overrides.apply(generate_config) +-- +2.43.0 + diff --git a/0002-Parse-stats-into-json-format.patch b/0002-Parse-stats-into-json-format.patch new file mode 100644 index 0000000..78d638b --- /dev/null +++ b/0002-Parse-stats-into-json-format.patch @@ -0,0 +1,60 @@ +From 9b3d6f2fc6e84ec29afe501611708da488950081 Mon Sep 17 00:00:00 2001 +From: Stefanos Laskaridis +Date: Mon, 4 Mar 2024 16:56:11 +0000 +Subject: [PATCH] Parse stats into json format + +--- + python/mlc_chat/interface/chat.py | 31 ++++++++++++++++++++++++++++++- + 1 file changed, 30 insertions(+), 1 deletion(-) + +diff --git a/python/mlc_chat/interface/chat.py b/python/mlc_chat/interface/chat.py +index 3d23df40..0df8bb15 100644 +--- a/python/mlc_chat/interface/chat.py ++++ b/python/mlc_chat/interface/chat.py +@@ -1,5 +1,7 @@ + """Python entrypoint of chat.""" + import dataclasses ++import re ++import json + from typing import List, Optional, Union + + from prompt_toolkit import prompt as get_prompt # pylint: disable=import-error +@@ -152,7 +154,34 @@ def chat( + f.write(f"{event_key} {event_value}\n") 
+ break + elif prompt[:6] == "/stats": +- print(cm.stats(verbose=True), flush=True) ++ # print(cm.stats(verbose=True), flush=True) ++ # ----------- prefill ----------- ++ # throughput: 87.899 tok/s ++ # total tokens: 10 tok ++ # total time: 0.114 s ++ # ------------ decode ------------ ++ # throughput: 54.603 tok/s ++ # total tokens: 18 tok ++ # total time: 0.330 s ++ # Parse the above metrics into json format ++ stats = cm.stats(verbose=True) ++ if stats.startswith("{"): # This is already handled by the backend ++ print(stats, flush=True) ++ else: # This is in case the backend has not been changed ++ stats = stats.strip().split("\n") ++ float_re = re.compile(r"\d+\.\d+") ++ int_re = re.compile(r"\d+") ++ stats_dict = {} ++ try: ++ for i in range(0, len(stats), 4): ++ stats_dict[stats[i].strip('-').strip()] = { ++ "throughput": f"{float(re.findall(float_re, stats[i + 1])[0])} tok/s", ++ "total_tokens": f"{int(re.findall(int_re, stats[i + 2])[0])} tok", ++ "total_time": f"{float(re.findall(float_re, stats[i + 3])[0])} s", ++ } ++ print(json.dumps(stats_dict, indent=4), flush=True) ++ except IndexError: ++ print(stats, flush=True) + elif prompt[:4] == "/set": + gen_config_overrides = GenerationConfigOverride.from_str(prompt.split()[1]) + generate_config = gen_config_overrides.apply(generate_config) +-- +2.43.0 + diff --git a/3rdparty/argparse b/3rdparty/argparse new file mode 160000 index 0000000..557948f --- /dev/null +++ b/3rdparty/argparse @@ -0,0 +1 @@ +Subproject commit 557948f1236db9e27089959de837cc23de6c6bbd diff --git a/3rdparty/googletest b/3rdparty/googletest new file mode 160000 index 0000000..4580469 --- /dev/null +++ b/3rdparty/googletest @@ -0,0 +1 @@ +Subproject commit 45804691223635953f311cf31a10c632553bbfc3 diff --git a/3rdparty/tokenizers-cpp b/3rdparty/tokenizers-cpp new file mode 160000 index 0000000..27dbe17 --- /dev/null +++ b/3rdparty/tokenizers-cpp @@ -0,0 +1 @@ +Subproject commit 27dbe17d7268801ec720569167af905c88d3db50 diff --git a/3rdparty/tvm b/3rdparty/tvm new file mode 160000 index 0000000..59c3556 --- /dev/null +++ b/3rdparty/tvm @@ -0,0 +1 @@ +Subproject commit 59c3556043abdc88f3ed98e07aa6176ac9a3f0cd diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..81ded0a --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,177 @@ +cmake_minimum_required(VERSION 3.18) +project(mlc_llm C CXX) + +include(CheckCXXCompilerFlag) +if(MSVC) + set(CMAKE_CXX_FLAGS "/fp:fast ${CMAKE_CXX_FLAGS}") +else() + set(CMAKE_CXX_FLAGS "-ffast-math ${CMAKE_CXX_FLAGS}") +endif() + +if(EXISTS ${CMAKE_BINARY_DIR}/config.cmake) + include(${CMAKE_BINARY_DIR}/config.cmake) +else() + if(EXISTS ${CMAKE_SOURCE_DIR}/config.cmake) + include(${CMAKE_SOURCE_DIR}/config.cmake) + endif() +endif() + +if(NOT CMAKE_BUILD_TYPE) + set( + CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "Build type" FORCE + ) + message(STATUS "Setting default build type to " ${CMAKE_BUILD_TYPE}) +endif(NOT CMAKE_BUILD_TYPE) + +option(MLC_HIDE_PRIVATE_SYMBOLS "Hide private symbols" ON) + +if (MLC_LLM_INSTALL_STATIC_LIB) + set(BUILD_STATIC_RUNTIME ON) +endif() + +set(MLC_VISIBILITY_FLAG "") +if (MLC_HIDE_PRIVATE_SYMBOLS) + set(HIDE_PRIVATE_SYMBOLS ON) + if (NOT MSVC) + set(MLC_VISIBILITY_FLAG "-fvisibility=hidden") + endif() + message(STATUS "Hide private symbols") +endif() + +option(BUILD_CPP_TEST "Build cpp unittests" OFF) + +set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +# tvm runtime config: minimize runtime components +set(USE_RPC OFF) +set(USE_MICRO OFF) 
+#set(USE_GRAPH_EXECUTOR OFF) +set(USE_GRAPH_EXECUTOR_DEBUG OFF) +set(USE_AOT_EXECUTOR OFF) +set(USE_PROFILER OFF) +set(USE_GTEST OFF) +set(USE_LIBBACKTRACE OFF) +#set(BUILD_DUMMY_LIBTVM ON) +if (NOT DEFINED TVM_HOME) + set(TVM_HOME 3rdparty/tvm) +endif (NOT DEFINED TVM_HOME) +message(STATUS "TVM_HOME: ${TVM_HOME}") +add_subdirectory(${TVM_HOME} tvm EXCLUDE_FROM_ALL) + +set(MLC_LLM_RUNTIME_LINKER_LIB "") +set(TOKENZIER_CPP_PATH 3rdparty/tokenizers-cpp) +add_subdirectory(${TOKENZIER_CPP_PATH} tokenizers EXCLUDE_FROM_ALL) + +if (DEFINED BENCHMARK_PER_LAYER) + message(STATUS "BENCHMARK_PER_LAYER: ${BENCHMARK_PER_LAYER}") +endif() + + +tvm_file_glob(GLOB_RECURSE MLC_LLM_SRCS cpp/*.cc) +add_library(mlc_llm_objs OBJECT ${MLC_LLM_SRCS}) + +set( + MLC_LLM_INCLUDES + ${TVM_HOME}/include + ${TVM_HOME}/3rdparty/dlpack/include + ${TVM_HOME}/3rdparty/dmlc-core/include + ${TVM_HOME}/3rdparty/picojson +) + +if (BENCHMARK_PER_LAYER) + set(MLC_LLM_COMPILE_DEFS ${MLC_LLM_COMPILE_DEFS} DMLC_USE_LOGGING_LIBRARY= BENCHMARK_PER_LAYER=${BENCHMARK_PER_LAYER}) +else() + set(MLC_LLM_COMPILE_DEFS ${MLC_LLM_COMPILE_DEFS} DMLC_USE_LOGGING_LIBRARY=) +endif() +set(MLC_LLM_COMPILE_DEFS ${MLC_LLM_COMPILE_DEFS} __STDC_FORMAT_MACROS=1) +set(MLC_LLM_COMPILE_DEFS ${MLC_LLM_COMPILE_DEFS} PICOJSON_USE_INT64) + +target_include_directories(mlc_llm_objs PRIVATE ${MLC_LLM_INCLUDES}) +target_compile_definitions(mlc_llm_objs PRIVATE ${MLC_LLM_COMPILE_DEFS}) +target_include_directories(mlc_llm_objs PRIVATE ${TOKENZIER_CPP_PATH}/include) +target_compile_definitions(mlc_llm_objs PRIVATE -DMLC_LLM_EXPORTS) + +add_library(mlc_llm SHARED $) +add_library(mlc_llm_static STATIC $) +add_dependencies(mlc_llm_static tokenizers_cpp sentencepiece-static tokenizers_c tvm_runtime) +set_target_properties(mlc_llm_static PROPERTIES OUTPUT_NAME mlc_llm) + +target_link_libraries(mlc_llm PUBLIC tvm_runtime) +target_link_libraries(mlc_llm PRIVATE tokenizers_cpp) + +find_library(FLASH_ATTN_LIBRARY flash_attn) + +if (FLASH_ATTN_LIBRARY STREQUAL "FLASH_ATTN_LIBRARY-NOTFOUND") + message(WARNING "Cannot find libflash_attn. 
The model must not have been built with --use-flash-attn-mqa option.") +else () + target_link_libraries(mlc_llm PUBLIC -Wl,--no-as-needed ${FLASH_ATTN_LIBRARY}) +endif() + +if(CMAKE_BUILD_TYPE STREQUAL "Debug") + target_compile_definitions(mlc_llm PRIVATE "TVM_LOG_DEBUG") + target_compile_definitions(mlc_llm_objs PRIVATE "TVM_LOG_DEBUG") + target_compile_definitions(mlc_llm_static PRIVATE "TVM_LOG_DEBUG") +endif() + +if (BUILD_CPP_TEST) + message(STATUS "Building cpp unittests") + add_subdirectory(3rdparty/googletest) + file(GLOB_RECURSE MLC_LLM_TEST_SRCS ${PROJECT_SOURCE_DIR}/tests/cpp/*unittest.cc) + add_executable(mlc_llm_cpp_tests ${MLC_LLM_TEST_SRCS}) + target_include_directories(mlc_llm_cpp_tests PRIVATE ${MLC_LLM_INCLUDES}) + target_include_directories(mlc_llm_cpp_tests PRIVATE ${PROJECT_SOURCE_DIR}/cpp) + target_include_directories(mlc_llm_cpp_tests PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) + target_link_libraries(mlc_llm_cpp_tests PUBLIC mlc_llm gtest gtest_main) +endif(BUILD_CPP_TEST) + +if (CMAKE_SYSTEM_NAME STREQUAL "Android") + target_link_libraries(mlc_llm PRIVATE log) + target_link_libraries(tokenizers_cpp PRIVATE log) +endif() + +add_library(mlc_llm_module SHARED $) +target_link_libraries(mlc_llm_module PUBLIC tvm) +target_link_libraries(mlc_llm_module PRIVATE tokenizers_cpp) + + +set_property(TARGET mlc_llm_module APPEND PROPERTY LINK_OPTIONS "${MLC_VISIBILITY_FLAG}") +set_property(TARGET mlc_llm APPEND PROPERTY LINK_OPTIONS "${MLC_VISIBILITY_FLAG}") + +find_program(CARGO_EXECUTABLE cargo) + +if(NOT CARGO_EXECUTABLE) + message(FATAL_ERROR "Cargo is not found! Please install cargo.") +endif() + +# when this option is on, +# we install all static lib deps into lib +if (MLC_LLM_INSTALL_STATIC_LIB) + install(TARGETS + mlc_llm_static + tokenizers_cpp + sentencepiece-static + tvm_runtime + LIBRARY DESTINATION lib${LIB_SUFFIX} + ) + # tokenizers need special handling as it builds from rust + if(MSVC) + install(FILES ${CMAKE_CURRENT_BINARY_DIR}/tokenizers/libtokenizers_c.lib + DESTINATION lib${LIB_SUFFIX} + ) + else() + install(FILES ${CMAKE_CURRENT_BINARY_DIR}/tokenizers/libtokenizers_c.a + DESTINATION lib${LIB_SUFFIX} + ) + endif() +else() + install(TARGETS tvm_runtime mlc_llm mlc_llm_module + mlc_llm_static + tokenizers_cpp + sentencepiece-static + RUNTIME_DEPENDENCY_SET tokenizers_c + RUNTIME DESTINATION bin + LIBRARY DESTINATION lib${LIB_SUFFIX} + ) +endif() diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md new file mode 100644 index 0000000..3f70fac --- /dev/null +++ b/CONTRIBUTORS.md @@ -0,0 +1,6 @@ +MLC LLM Contributors +==================== + + +## List of Contributors +- [Full List of Contributors](https://github.com/mlc-ai/mlc-llm/graphs/contributors) diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..586f143 --- /dev/null +++ b/LICENSE @@ -0,0 +1,213 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + +/* + * Modifications Copyright (c) 2024 Brave Software + * + * The MLC-LLM repository has been forked from https://github.com/mlc-ai/mlc-llm. + * Our modifications relate to: + * - enabling per operation profiling in the backend + * - streamlining the build process + * - changing the mobile application to streamline automation + * - integrating event and operation tracing + */ \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..9bea5cc --- /dev/null +++ b/README.md @@ -0,0 +1,220 @@ +[discord-url]: https://discord.gg/9Xpy2HGBuD + +# MLC LLM + +[Documentation](https://llm.mlc.ai/docs) | [Blog](https://blog.mlc.ai/) | [Discord][discord-url] + +**M**achine **L**earning **C**ompilation for **L**arge **L**anguage **M**odels (MLC LLM) is a high-performance universal deployment solution that allows native deployment of any large language models with native APIs with compiler acceleration. The mission of this project is to enable everyone to develop, optimize and deploy AI models natively on everyone's devices with ML compilation techniques. + +**Universal deployment.** MLC LLM supports the following platforms and hardware: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
|              | AMD GPU                        | NVIDIA GPU            | Apple GPU | Intel GPU       |
|--------------|--------------------------------|-----------------------|-----------|-----------------|
| Linux / Win  | ✅ Vulkan, ROCm                | ✅ Vulkan, CUDA       | N/A       | ✅ Vulkan       |
| macOS        | ✅ Metal (dGPU)                | N/A                   | ✅ Metal  | ✅ Metal (iGPU) |
| Web Browser  | ✅ WebGPU and WASM             |                       |           |                 |
| iOS / iPadOS | ✅ Metal on Apple A-series GPU |                       |           |                 |
| Android      | ✅ OpenCL on Adreno GPU        | ✅ OpenCL on Mali GPU |           |                 |
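
To check which of the backends listed above are reachable from your local TVM Unity build, here is a small illustrative sketch (the backend names follow TVM's standard device strings; this is not part of the diff):

```python
# Illustrative check: which TVM device backends exist on this machine.
import tvm

for name in ["cuda", "rocm", "vulkan", "metal", "opencl"]:
    dev = tvm.device(name, 0)
    print(f"{name:7s} available: {dev.exist}")
```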
**Scalable.** MLC LLM scales universally on NVIDIA and AMD GPUs, across cloud and gaming GPUs. Below we showcase single-batch decoding performance with a prefill length of 1 and a decode length of 256.

Performance of 4-bit CodeLlama-34B and Llama2-70B on two NVIDIA RTX 4090 and two AMD Radeon 7900 XTX:

*(performance chart)*
Scaling of fp16 and 4-bit CodeLlama-34B and Llama2-70B on A100-80G-PCIe and A10G-24G-PCIe, up to 8 GPUs:

*(scaling chart)*
+ +## News + +* [10/18/2023] [[Post]](https://blog.mlc.ai/2023/10/19/Scalable-Language-Model-Inference-on-Multiple-NVDIA-AMD-GPUs) Scalable multi-GPU support for CUDA and ROCm are official. +* [09/02/2023] Prebuilt ROCm 5.7 and CUDA 12.2 package is [available](https://llm.mlc.ai/docs/install/tvm.html#option-1-prebuilt-package). +* [08/25/2023] CodeLlama support is up. +* [08/14/2023] [[Post]](https://blog.mlc.ai/2023/08/09/GPU-Accelerated-LLM-on-Orange-Pi) Mali GPU support is up on Orange Pi. +* [08/09/2023] [[Post]](https://blog.mlc.ai/2023/08/09/Making-AMD-GPUs-competitive-for-LLM-inference) ROCm backend is mature to use. +* [08/02/2023] [Dockerfile](https://github.com/mlc-ai/llm-perf-bench/) is released for CUDA performance benchmarking. +* [07/19/2023] Support for Llama2-7B/13B/70B is up. +* [05/22/2023] [[Post]](https://blog.mlc.ai/2023/05/22/bringing-open-large-language-models-to-consumer-devices) RedPajama support is up. +* [05/08/2023] [[Post]](https://blog.mlc.ai/2023/05/08/bringing-hardware-accelerated-language-models-to-android-devices) MLC LLM is now available on Android. +* [05/01/2023] [[Post]](https://blog.mlc.ai/2023/05/01/bringing-accelerated-llm-to-consumer-hardware) MLC LLM is released with Metal, Vulkan and CUDA backends. +* [04/14/2023] [WebLLM](https://github.com/mlc-ai/web-llm) is released prior to MLC LLM with WebGPU and WebAssembly backend. + +## Getting Started + +Please visit our [documentation](https://llm.mlc.ai/docs/index.html#getting-started) for detailed instructions. + +## Model Support + +MLC LLM supports a wide range of model architectures and variants. We have the following prebuilts which you can +use off-the-shelf. Visit [Prebuilt Models](https://llm.mlc.ai/docs/prebuilt_models.html) to see the full list, and [Compile Models via MLC](https://llm.mlc.ai/docs/compilation/compile_models.html) to see how to use models not on this list. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
| Architecture | Prebuilt Model Variants |
|--------------|-------------------------|
| Llama        | Llama-2, Code Llama, Vicuna, WizardLM, WizardMath, OpenOrca Platypus2, FlagAlpha Llama-2 Chinese, georgesung Llama-2 Uncensored |
| GPT-NeoX     | RedPajama |
| GPT-J        | |
| RWKV         | RWKV-raven |
| MiniGPT      | |
| GPTBigCode   | WizardCoder |
| ChatGLM      | |
| StableLM     | |
| Mistral      | |
| Phi          | |
+ +## Universal Deployment APIs + +MLC LLM provides multiple sets of APIs across platforms and environments. These include +* [Python API](https://llm.mlc.ai/docs/deploy/python.html) +* [OpenAI-compatible Rest-API](https://llm.mlc.ai/docs/deploy/rest.html) +* [C++ API](https://llm.mlc.ai/docs/deploy/cli.html) +* [JavaScript API](https://llm.mlc.ai/docs/deploy/javascript.html) and [Web LLM](https://github.com/mlc-ai/web-llm) +* [Swift API for iOS App](https://llm.mlc.ai/docs/deploy/ios.html) +* [Java API and Android App](https://llm.mlc.ai/docs/deploy/android.html) + +## Citation + +Please consider citing our project if you find it useful: + +```bibtex +@software{mlc-llm, + author = {MLC team}, + title = {{MLC-LLM}}, + url = {https://github.com/mlc-ai/mlc-llm}, + year = {2023} +} +``` + +The underlying techniques of MLC LLM include: + +
+ References (Click to expand) + + ```bibtex + @inproceedings{tensorir, + author = {Feng, Siyuan and Hou, Bohan and Jin, Hongyi and Lin, Wuwei and Shao, Junru and Lai, Ruihang and Ye, Zihao and Zheng, Lianmin and Yu, Cody Hao and Yu, Yong and Chen, Tianqi}, + title = {TensorIR: An Abstraction for Automatic Tensorized Program Optimization}, + year = {2023}, + isbn = {9781450399166}, + publisher = {Association for Computing Machinery}, + address = {New York, NY, USA}, + url = {https://doi.org/10.1145/3575693.3576933}, + doi = {10.1145/3575693.3576933}, + booktitle = {Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 2}, + pages = {804–817}, + numpages = {14}, + keywords = {Tensor Computation, Machine Learning Compiler, Deep Neural Network}, + location = {Vancouver, BC, Canada}, + series = {ASPLOS 2023} + } + + @inproceedings{metaschedule, + author = {Shao, Junru and Zhou, Xiyou and Feng, Siyuan and Hou, Bohan and Lai, Ruihang and Jin, Hongyi and Lin, Wuwei and Masuda, Masahiro and Yu, Cody Hao and Chen, Tianqi}, + booktitle = {Advances in Neural Information Processing Systems}, + editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. Oh}, + pages = {35783--35796}, + publisher = {Curran Associates, Inc.}, + title = {Tensor Program Optimization with Probabilistic Programs}, + url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/e894eafae43e68b4c8dfdacf742bcbf3-Paper-Conference.pdf}, + volume = {35}, + year = {2022} + } + + @inproceedings{tvm, + author = {Tianqi Chen and Thierry Moreau and Ziheng Jiang and Lianmin Zheng and Eddie Yan and Haichen Shen and Meghan Cowan and Leyuan Wang and Yuwei Hu and Luis Ceze and Carlos Guestrin and Arvind Krishnamurthy}, + title = {{TVM}: An Automated {End-to-End} Optimizing Compiler for Deep Learning}, + booktitle = {13th USENIX Symposium on Operating Systems Design and Implementation (OSDI 18)}, + year = {2018}, + isbn = {978-1-939133-08-3}, + address = {Carlsbad, CA}, + pages = {578--594}, + url = {https://www.usenix.org/conference/osdi18/presentation/chen}, + publisher = {USENIX Association}, + month = oct, + } + ``` +
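
The patches in this fork add event timing to the Python `ChatModule`: prefill/decode timestamps are captured with `time.time_ns()` into `ChatModule.energy_events`, and the CLI writes them out as `<event-key> <nanoseconds>` lines on `/exit`. The sketch below is illustrative only (not part of the diff): it drives the Python API and dumps those events; the model id is a placeholder.

```python
# Illustrative sketch: chat through the mlc_chat Python API and dump the
# energy-timing events recorded by the patched ChatModule.
from mlc_chat import ChatModule

# Placeholder model id; substitute a model you have compiled or downloaded.
cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1")

# generate() runs prefill/decode; the patched module records
# "chat.<n>.<idx>.prefill|decode.*" timestamps into cm.energy_events.
print(cm.generate(prompt="What is the meaning of life?"))

# Prefill/decode throughput statistics (same text the CLI prints for /stats).
print(cm.stats())

# Persist events in the same "<key> <ns>" format the patched CLI writes on /exit.
with open("energy_events.txt", "w", encoding="utf-8") as f:
    for key, value in cm.energy_events.items():
        f.write(f"{key} {value}\n")
```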
+ +## Links + +- You might want to check out our online public [Machine Learning Compilation course](https://mlc.ai) for a systematic +walkthrough of our approaches. +- [WebLLM](https://webllm.mlc.ai/) is a companion project using MLC LLM's WebGPU and WebAssembly backend. +- [WebStableDiffusion](https://websd.mlc.ai/) is a companion project for diffusion models with the WebGPU backend. + diff --git a/android/.gitignore b/android/.gitignore new file mode 100644 index 0000000..002b05d --- /dev/null +++ b/android/.gitignore @@ -0,0 +1,19 @@ +app/src/main/jni/*.h +app/src/main/jni/*.cc +app/src/main/obj + +*.iml +.gradle +/local.properties +/.idea/caches +/.idea/libraries +/.idea/modules.xml +/.idea/workspace.xml +/.idea/navEditor.xml +/.idea/assetWizardSettings.xml +.DS_Store +/build +/captures +.externalNativeBuild +.cxx +local.properties diff --git a/android/README.md b/android/README.md new file mode 100644 index 0000000..502eb53 --- /dev/null +++ b/android/README.md @@ -0,0 +1,3 @@ +# MLC-LLM Android + +[Documentation page](https://llm.mlc.ai/docs/deploy/android.html) diff --git a/android/app/.gitignore b/android/app/.gitignore new file mode 100644 index 0000000..558f311 --- /dev/null +++ b/android/app/.gitignore @@ -0,0 +1,2 @@ +/build +/src/main/libs \ No newline at end of file diff --git a/android/app/build.gradle b/android/app/build.gradle new file mode 100644 index 0000000..debbb90 --- /dev/null +++ b/android/app/build.gradle @@ -0,0 +1,73 @@ +plugins { + id 'com.android.application' + id 'org.jetbrains.kotlin.android' +} + +android { + namespace 'ai.mlc.mlcchat' + compileSdk 34 + + defaultConfig { + applicationId "ai.mlc.mlcchat32" + minSdk 26 + targetSdk 33 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + vectorDrawables { + useSupportLibrary true + } + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } + kotlinOptions { + jvmTarget = '1.8' + } + buildFeatures { + compose true + } + composeOptions { + kotlinCompilerExtensionVersion '1.4.3' + } + packagingOptions { + resources { + excludes += '/META-INF/{AL2.0,LGPL2.1}' + } + } +} + +dependencies { + implementation project(":library") + implementation 'androidx.core:core-ktx:1.10.1' + implementation 'androidx.lifecycle:lifecycle-runtime-ktx:2.6.1' + implementation 'androidx.activity:activity-compose:1.7.1' + implementation platform('androidx.compose:compose-bom:2022.10.00') + implementation 'androidx.lifecycle:lifecycle-viewmodel-compose:2.6.1' + implementation 'androidx.compose.ui:ui' + implementation 'androidx.compose.ui:ui-graphics' + implementation 'androidx.compose.ui:ui-tooling-preview' + implementation 'androidx.compose.material3:material3:1.1.0' + implementation 'androidx.compose.material:material-icons-extended' + implementation 'androidx.appcompat:appcompat:1.6.1' + implementation 'androidx.navigation:navigation-compose:2.5.3' + implementation 'com.google.code.gson:gson:2.10.1' + implementation fileTree(dir: 'src/main/libs', include: ['*.aar', '*.jar'], exclude: []) + testImplementation 'junit:junit:4.13.2' + androidTestImplementation 'androidx.test.ext:junit:1.1.5' + androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1' + androidTestImplementation platform('androidx.compose:compose-bom:2022.10.00') + androidTestImplementation 
'androidx.compose.ui:ui-test-junit4' + debugImplementation 'androidx.compose.ui:ui-tooling' + debugImplementation 'androidx.compose.ui:ui-test-manifest' + +} \ No newline at end of file diff --git a/android/app/proguard-rules.pro b/android/app/proguard-rules.pro new file mode 100644 index 0000000..481bb43 --- /dev/null +++ b/android/app/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile \ No newline at end of file diff --git a/android/app/src/main/AndroidManifest.xml b/android/app/src/main/AndroidManifest.xml new file mode 100644 index 0000000..e25d837 --- /dev/null +++ b/android/app/src/main/AndroidManifest.xml @@ -0,0 +1,42 @@ + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/android/app/src/main/ic_launcher-playstore.png b/android/app/src/main/ic_launcher-playstore.png new file mode 100644 index 0000000..3c16fd6 Binary files /dev/null and b/android/app/src/main/ic_launcher-playstore.png differ diff --git a/android/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt b/android/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt new file mode 100644 index 0000000..78c19e4 --- /dev/null +++ b/android/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt @@ -0,0 +1,912 @@ +package ai.mlc.mlcchat + +import ai.mlc.mlcllm.ChatModule +import android.app.Application +import android.content.ClipData +import android.content.ClipboardManager +import android.content.Context +import android.nfc.Tag +import android.os.Environment +import android.util.Log +import android.widget.Toast +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.toMutableStateList +import androidx.lifecycle.AndroidViewModel +import androidx.lifecycle.viewModelScope +import com.google.gson.Gson +import com.google.gson.annotations.SerializedName +import kotlinx.coroutines.launch +import java.io.File +import java.io.FileOutputStream +import java.lang.Thread.sleep +import java.net.URL +import java.nio.channels.Channels +import org.json.JSONObject +import java.util.Date +import java.util.UUID +import java.util.concurrent.Executors +import kotlin.concurrent.thread + + +class AppViewModel(application: Application) : AndroidViewModel(application) { + val modelList = emptyList().toMutableStateList() + val chatState = ChatState() + val modelSampleList = emptyList().toMutableStateList() + private var showAlert = mutableStateOf(false) + private var alertMessage = mutableStateOf("") + private var appConfig = AppConfig( + emptyList().toMutableList(), + emptyList().toMutableList() + ) + private val application = getApplication() + private val appDirFile = application.getExternalFilesDir("") + private val gson = Gson() + private val modelIdSet = emptySet().toMutableSet() + + companion object { + const val AppConfigFilename = "app-config.json" + const val 
ModelConfigFilename = "mlc-chat-config.json" + const val ParamsConfigFilename = "ndarray-cache.json" + const val ModelUrlSuffix = "resolve/main/" + } + + init { + loadAppConfig() + } + + fun isShowingAlert(): Boolean { + return showAlert.value + } + + fun errorMessage(): String { + return alertMessage.value + } + + fun dismissAlert() { + require(showAlert.value) + showAlert.value = false + } + + fun copyError() { + require(showAlert.value) + val clipboard = + application.getSystemService(Context.CLIPBOARD_SERVICE) as ClipboardManager + clipboard.setPrimaryClip(ClipData.newPlainText("MLCChat", errorMessage())) + } + + private fun issueAlert(error: String) { + showAlert.value = true + alertMessage.value = error + } + + fun requestDeleteModel(modelId: String) { + deleteModel(modelId) + issueAlert("Model: $modelId has been deleted") + } + + + private fun loadAppConfig() { + val appConfigFile = File(appDirFile, AppConfigFilename) + val jsonString: String = if (!appConfigFile.exists()) { + application.assets.open(AppConfigFilename).bufferedReader().use { it.readText() } + } else { + appConfigFile.readText() + } + appConfig = gson.fromJson(jsonString, AppConfig::class.java) + appConfig.modelLibs = emptyList().toMutableList() + modelList.clear() + modelIdSet.clear() + modelSampleList.clear() + for (modelRecord in appConfig.modelList) { + appConfig.modelLibs.add(modelRecord.modelLib) + val modelDirFile = File(appDirFile, modelRecord.modelId) + val modelConfigFile = File(modelDirFile, ModelConfigFilename) + if (modelConfigFile.exists()) { + val modelConfigString = modelConfigFile.readText() + val modelConfig = gson.fromJson(modelConfigString, ModelConfig::class.java) + modelConfig.modelId = modelRecord.modelId + modelConfig.modelLib = modelRecord.modelLib + modelConfig.estimatedVramBytes = modelRecord.estimatedVramBytes + addModelConfig(modelConfig, modelRecord.modelUrl, true) + } else { + downloadModelConfig( + if (modelRecord.modelUrl.endsWith("/")) modelRecord.modelUrl else "${modelRecord.modelUrl}/", + modelRecord, + true + ) + } + } + } + + private fun readInputFile(): Array>? 
{ + val inputFile = File(appDirFile, "input.json") + val jsonString: String = if (!inputFile.exists()) { + application.assets.open("input.json").bufferedReader().use { it.readText() } + } else { + inputFile.readText() + } + return gson.fromJson(jsonString, Array>::class.java) + } + + private fun updateAppConfig(action: () -> Unit) { + action() + val jsonString = gson.toJson(appConfig) + val appConfigFile = File(appDirFile, AppConfigFilename) + appConfigFile.writeText(jsonString) + } + + private fun addModelConfig(modelConfig: ModelConfig, modelUrl: String, isBuiltin: Boolean) { + require(!modelIdSet.contains(modelConfig.modelId)) + modelIdSet.add(modelConfig.modelId) + modelList.add( + ModelState( + modelConfig, + modelUrl + if (modelUrl.endsWith("/")) "" else "/", + File(appDirFile, modelConfig.modelId) + ) + ) + if (!isBuiltin) { + updateAppConfig { + appConfig.modelList.add( + ModelRecord( + modelUrl, + modelConfig.modelId, + modelConfig.estimatedVramBytes, + modelConfig.modelLib + ) + ) + } + } + } + + private fun deleteModel(modelId: String) { + val modelDirFile = File(appDirFile, modelId) + modelDirFile.deleteRecursively() + require(!modelDirFile.exists()) + modelIdSet.remove(modelId) + modelList.removeIf { modelState -> modelState.modelConfig.modelId == modelId } + updateAppConfig { + appConfig.modelList.removeIf { modelRecord -> modelRecord.modelId == modelId } + } + } + + private fun isModelConfigAllowed(modelConfig: ModelConfig): Boolean { + if (appConfig.modelLibs.contains(modelConfig.modelLib)) return true + viewModelScope.launch { + issueAlert("Model lib ${modelConfig.modelLib} is not supported.") + } + return false + } + + + private fun downloadModelConfig( + modelUrl: String, + modelRecord: ModelRecord, + isBuiltin: Boolean + ) { + thread(start = true) { + try { + val url = URL("${modelUrl}${ModelUrlSuffix}${ModelConfigFilename}") + val tempId = UUID.randomUUID().toString() + val tempFile = File( + application.getExternalFilesDir(Environment.DIRECTORY_DOWNLOADS), + tempId + ) + url.openStream().use { + Channels.newChannel(it).use { src -> + FileOutputStream(tempFile).use { fileOutputStream -> + fileOutputStream.channel.transferFrom(src, 0, Long.MAX_VALUE) + } + } + } + require(tempFile.exists()) + viewModelScope.launch { + try { + val modelConfigString = tempFile.readText() + val modelConfig = gson.fromJson(modelConfigString, ModelConfig::class.java) + modelConfig.modelId = modelRecord.modelId + modelConfig.modelLib = modelRecord.modelLib + modelConfig.estimatedVramBytes = modelRecord.estimatedVramBytes + if (modelIdSet.contains(modelConfig.modelId)) { + tempFile.delete() + issueAlert("${modelConfig.modelId} has been used, please consider another local ID") + return@launch + } + if (!isModelConfigAllowed(modelConfig)) { + tempFile.delete() + return@launch + } + val modelDirFile = File(appDirFile, modelConfig.modelId) + val modelConfigFile = File(modelDirFile, ModelConfigFilename) + tempFile.copyTo(modelConfigFile, overwrite = true) + tempFile.delete() + require(modelConfigFile.exists()) + addModelConfig(modelConfig, modelUrl, isBuiltin) + } catch (e: Exception) { + viewModelScope.launch { + Toast.makeText( + application, + "Add model failed: ${e.localizedMessage}", + Toast.LENGTH_LONG + ).show() + Log.e("mlc-llm", e.localizedMessage) + } + } + } + } catch (e: Exception) { + viewModelScope.launch { + Toast.makeText( + application, + "Download model config failed: ${e.localizedMessage}", + Toast.LENGTH_LONG + ).show() + Log.e("mlc-llm", e.localizedMessage) + } + } + + } + } + 
+
+    inner class ModelState(
+        val modelConfig: ModelConfig,
+        private val modelUrl: String,
+        private val modelDirFile: File
+    ) {
+        var modelInitState = mutableStateOf(ModelInitState.Initializing)
+        private var paramsConfig = ParamsConfig(emptyList())
+        val progress = mutableStateOf(0)
+        val total = mutableStateOf(1)
+        val id: UUID = UUID.randomUUID()
+        private val remainingTasks = emptySet<DownloadTask>().toMutableSet()
+        private val downloadingTasks = emptySet<DownloadTask>().toMutableSet()
+        private val maxDownloadTasks = 3
+        private val gson = Gson()
+
+        init {
+            switchToInitializing()
+        }
+
+        private fun switchToInitializing() {
+            val paramsConfigFile = File(modelDirFile, ParamsConfigFilename)
+            if (paramsConfigFile.exists()) {
+                loadParamsConfig()
+                switchToIndexing()
+            } else {
+                downloadParamsConfig()
+            }
+        }
+
+        private fun loadParamsConfig() {
+            val paramsConfigFile = File(modelDirFile, ParamsConfigFilename)
+            require(paramsConfigFile.exists())
+            val jsonString = paramsConfigFile.readText()
+            paramsConfig = gson.fromJson(jsonString, ParamsConfig::class.java)
+        }
+
+        private fun downloadParamsConfig() {
+            thread(start = true) {
+                val url = URL("${modelUrl}${ModelUrlSuffix}${ParamsConfigFilename}")
+                val tempId = UUID.randomUUID().toString()
+                val tempFile = File(modelDirFile, tempId)
+                url.openStream().use {
+                    Channels.newChannel(it).use { src ->
+                        FileOutputStream(tempFile).use { fileOutputStream ->
+                            fileOutputStream.channel.transferFrom(src, 0, Long.MAX_VALUE)
+                        }
+                    }
+                }
+                require(tempFile.exists())
+                val paramsConfigFile = File(modelDirFile, ParamsConfigFilename)
+                tempFile.renameTo(paramsConfigFile)
+                require(paramsConfigFile.exists())
+                viewModelScope.launch {
+                    loadParamsConfig()
+                    switchToIndexing()
+                }
+            }
+        }
+
+        fun handleStart() {
+            switchToDownloading()
+        }
+
+        fun handlePause() {
+            switchToPausing()
+        }
+
+        fun handleClear() {
+            require(
+                modelInitState.value == ModelInitState.Downloading ||
+                        modelInitState.value == ModelInitState.Paused ||
+                        modelInitState.value == ModelInitState.Finished
+            )
+            switchToClearing()
+        }
+
+        private fun switchToClearing() {
+            if (modelInitState.value == ModelInitState.Paused) {
+                modelInitState.value = ModelInitState.Clearing
+                clear()
+            } else if (modelInitState.value == ModelInitState.Finished) {
+                modelInitState.value = ModelInitState.Clearing
+                if (chatState.modelName.value == modelConfig.modelId) {
+                    chatState.requestTerminateChat { clear() }
+                } else {
+                    clear()
+                }
+            } else {
+                modelInitState.value = ModelInitState.Clearing
+            }
+        }
+
+        fun handleDelete() {
+            require(
+                modelInitState.value == ModelInitState.Downloading ||
+                        modelInitState.value == ModelInitState.Paused ||
+                        modelInitState.value == ModelInitState.Finished
+            )
+            switchToDeleting()
+        }
+
+        private fun switchToDeleting() {
+            if (modelInitState.value == ModelInitState.Paused) {
+                modelInitState.value = ModelInitState.Deleting
+                delete()
+            } else if (modelInitState.value == ModelInitState.Finished) {
+                modelInitState.value = ModelInitState.Deleting
+                if (chatState.modelName.value == modelConfig.modelId) {
+                    chatState.requestTerminateChat { delete() }
+                } else {
+                    delete()
+                }
+            } else {
+                modelInitState.value = ModelInitState.Deleting
+            }
+        }
+
+        private fun switchToIndexing() {
+            modelInitState.value = ModelInitState.Indexing
+            progress.value = 0
+            total.value = modelConfig.tokenizerFiles.size + paramsConfig.paramsRecords.size
+            for (tokenizerFilename in modelConfig.tokenizerFiles) {
+                val file = File(modelDirFile, tokenizerFilename)
+                if (file.exists()) {
+                    ++progress.value
+                } else {
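+                    // Tokenizer file is not on disk yet: queue a DownloadTask for it. The URL is
+                    // resolved against the model repo via ModelUrlSuffix (defined elsewhere in this
+                    // file; for Hugging Face-style repos this is typically a "resolve/main/"-style path).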
remainingTasks.add( + DownloadTask( + URL("${modelUrl}${ModelUrlSuffix}${tokenizerFilename}"), + file + ) + ) + } + } + for (paramsRecord in paramsConfig.paramsRecords) { + val file = File(modelDirFile, paramsRecord.dataPath) + if (file.exists()) { + ++progress.value + } else { + remainingTasks.add( + DownloadTask( + URL("${modelUrl}${ModelUrlSuffix}${paramsRecord.dataPath}"), + file + ) + ) + } + } + if (progress.value < total.value) { + switchToPaused() + } else { + switchToFinished() + } + } + + private fun switchToDownloading() { + modelInitState.value = ModelInitState.Downloading + for (downloadTask in remainingTasks) { + if (downloadingTasks.size < maxDownloadTasks) { + handleNewDownload(downloadTask) + } else { + return + } + } + } + + private fun handleNewDownload(downloadTask: DownloadTask) { + require(modelInitState.value == ModelInitState.Downloading) + require(!downloadingTasks.contains(downloadTask)) + downloadingTasks.add(downloadTask) + thread(start = true) { + val tempId = UUID.randomUUID().toString() + val tempFile = File(modelDirFile, tempId) + downloadTask.url.openStream().use { + Channels.newChannel(it).use { src -> + FileOutputStream(tempFile).use { fileOutputStream -> + fileOutputStream.channel.transferFrom(src, 0, Long.MAX_VALUE) + } + } + } + require(tempFile.exists()) + tempFile.renameTo(downloadTask.file) + require(downloadTask.file.exists()) + viewModelScope.launch { + handleFinishDownload(downloadTask) + } + } + } + + private fun handleNextDownload() { + require(modelInitState.value == ModelInitState.Downloading) + for (downloadTask in remainingTasks) { + if (!downloadingTasks.contains(downloadTask)) { + handleNewDownload(downloadTask) + break + } + } + } + + private fun handleFinishDownload(downloadTask: DownloadTask) { + remainingTasks.remove(downloadTask) + downloadingTasks.remove(downloadTask) + ++progress.value + require( + modelInitState.value == ModelInitState.Downloading || + modelInitState.value == ModelInitState.Pausing || + modelInitState.value == ModelInitState.Clearing || + modelInitState.value == ModelInitState.Deleting + ) + if (modelInitState.value == ModelInitState.Downloading) { + if (remainingTasks.isEmpty()) { + if (downloadingTasks.isEmpty()) { + switchToFinished() + } + } else { + handleNextDownload() + } + } else if (modelInitState.value == ModelInitState.Pausing) { + if (downloadingTasks.isEmpty()) { + switchToPaused() + } + } else if (modelInitState.value == ModelInitState.Clearing) { + if (downloadingTasks.isEmpty()) { + clear() + } + } else if (modelInitState.value == ModelInitState.Deleting) { + if (downloadingTasks.isEmpty()) { + delete() + } + } + } + + private fun clear() { + val files = modelDirFile.listFiles { dir, name -> + !(dir == modelDirFile && name == ModelConfigFilename) + } + require(files != null) + for (file in files) { + file.deleteRecursively() + require(!file.exists()) + } + val modelConfigFile = File(modelDirFile, ModelConfigFilename) + require(modelConfigFile.exists()) + switchToIndexing() + } + + private fun delete() { + modelDirFile.deleteRecursively() + require(!modelDirFile.exists()) + requestDeleteModel(modelConfig.modelId) + } + + private fun switchToPausing() { + modelInitState.value = ModelInitState.Pausing + } + + private fun switchToPaused() { + modelInitState.value = ModelInitState.Paused + } + + + private fun switchToFinished() { + modelInitState.value = ModelInitState.Finished + } + + fun startChat() { + chatState.requestReloadChat( + modelConfig, + modelDirFile.absolutePath, + ) + } + + } + + inner 
class ChatState {
+        val messages = emptyList<MessageData>().toMutableStateList()
+        val report = mutableStateOf("")
+        val modelName = mutableStateOf("")
+        lateinit var modelLoadTime: TimeRecord
+
+        private var modelChatState = mutableStateOf(ModelChatState.Ready)
+            @Synchronized get
+            @Synchronized set
+        private val backend = ChatModule()
+        private var modelLib = ""
+        private var modelPath = ""
+        private val executorService = Executors.newSingleThreadExecutor()
+
+        private fun mainResetChat() {
+            executorService.submit {
+                callBackend { backend.resetChat() }
+                viewModelScope.launch {
+                    clearHistory()
+                    switchToReady()
+                }
+            }
+        }
+
+        private fun clearHistory() {
+            messages.clear()
+            report.value = ""
+        }
+
+        private fun switchToResetting() {
+            modelChatState.value = ModelChatState.Resetting
+        }
+
+        private fun switchToGenerating() {
+            modelChatState.value = ModelChatState.Generating
+        }
+
+        private fun switchToReloading() {
+            modelChatState.value = ModelChatState.Reloading
+        }
+
+        private fun switchToReady() {
+            modelChatState.value = ModelChatState.Ready
+        }
+
+        private fun switchToFailed() {
+            modelChatState.value = ModelChatState.Failed
+        }
+
+        private fun callBackend(callback: () -> Unit): Boolean {
+            try {
+                callback()
+            } catch (e: Exception) {
+                viewModelScope.launch {
+                    val stackTrace = e.stackTraceToString()
+                    val errorMessage = e.localizedMessage
+                    appendMessage(
+                        MessageRole.Bot,
+                        "MLCChat failed\n\nStack trace:\n$stackTrace\n\nError message:\n$errorMessage"
+                    )
+                    switchToFailed()
+                }
+                return false
+            }
+            return true
+        }
+
+        fun requestResetChat() {
+            require(interruptable())
+            interruptChat(
+                prologue = {
+                    switchToResetting()
+                },
+                epilogue = {
+                    mainResetChat()
+                }
+            )
+        }
+
+        private fun interruptChat(prologue: () -> Unit, epilogue: () -> Unit) {
+            // prologue runs before interruption
+            // epilogue runs after interruption
+            require(interruptable())
+            if (modelChatState.value == ModelChatState.Ready) {
+                prologue()
+                epilogue()
+            } else if (modelChatState.value == ModelChatState.Generating) {
+                prologue()
+                executorService.submit {
+                    viewModelScope.launch { epilogue() }
+                }
+            } else {
+                require(false)
+            }
+        }
+
+        fun requestTerminateChat(callback: () -> Unit) {
+            require(interruptable())
+            interruptChat(
+                prologue = {
+                    switchToTerminating()
+                },
+                epilogue = {
+                    mainTerminateChat(callback)
+                }
+            )
+        }
+
+        private fun mainTerminateChat(callback: () -> Unit) {
+            executorService.submit {
+                callBackend { backend.unload() }
+                viewModelScope.launch {
+                    clearHistory()
+                    switchToReady()
+                    callback()
+                }
+            }
+        }
+
+        private fun switchToTerminating() {
+            modelChatState.value = ModelChatState.Terminating
+        }
+
+        fun requestReloadChat(modelConfig: ModelConfig, modelPath: String) {
+            if (this.modelName.value == modelConfig.modelId && this.modelLib == modelConfig.modelLib && this.modelPath == modelPath) {
+                return
+            }
+            require(interruptable())
+            interruptChat(
+                prologue = {
+                    switchToReloading()
+                },
+                epilogue = {
+                    mainReloadChat(modelConfig, modelPath)
+                }
+            )
+        }
+
+        private fun mainReloadChat(modelConfig: ModelConfig, modelPath: String) {
+            val timeStart = Date()
+            clearHistory()
+            this.modelName.value = modelConfig.modelId
+            this.modelLib = modelConfig.modelLib
+            this.modelPath = modelPath
+            executorService.submit {
+                viewModelScope.launch {
+                    Toast.makeText(application, "Initialize...", Toast.LENGTH_SHORT).show()
+                }
+                if (!callBackend {
+                        backend.unload()
+                        backend.reload(
+                            modelConfig.modelLib,
+                            modelPath
+                        )
+                    }) return@submit
+                viewModelScope.launch {
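+                    // Reload succeeded: hop back to the main thread to notify the user and mark the
+                    // chat Ready; the elapsed time is then stored as modelLoadTime for the
+                    // measurement records produced by requestAutomation().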
+                    Toast.makeText(application, "Ready to chat", Toast.LENGTH_SHORT).show()
+                    switchToReady()
+                }
+
+                val duration = Date().time - timeStart.time
+                this.modelLoadTime = TimeRecord(timeStart, duration)
+            }
+        }
+
+        fun requestGenerate(prompt: String) {
+            require(chatable())
+            switchToGenerating()
+            executorService.submit {
+                appendMessage(MessageRole.User, prompt)
+                appendMessage(MessageRole.Bot, "")
+                if (!callBackend { backend.prefill(prompt) }) return@submit
+                while (!backend.stopped()) {
+                    if (!callBackend {
+                            backend.decode()
+                            val newText = backend.message
+                            viewModelScope.launch { updateMessage(MessageRole.Bot, newText) }
+                        }) return@submit
+                    if (modelChatState.value != ModelChatState.Generating) return@submit
+                }
+                val runtimeStats = backend.runtimeStatsText()
+                viewModelScope.launch {
+                    report.value = runtimeStats
+                    if (modelChatState.value == ModelChatState.Generating) switchToReady()
+                }
+            }
+        }
+
+        private fun appendMessage(role: MessageRole, text: String) {
+            messages.add(MessageData(role, text))
+        }
+
+        private fun updateMessage(role: MessageRole, text: String) {
+            messages[messages.size - 1] = MessageData(role, text)
+        }
+
+        fun chatable(): Boolean {
+            return modelChatState.value == ModelChatState.Ready
+        }
+
+        fun interruptable(): Boolean {
+            return modelChatState.value == ModelChatState.Ready ||
+                    modelChatState.value == ModelChatState.Generating ||
+                    modelChatState.value == ModelChatState.Failed
+        }
+
+        fun requestAutomation(measurementFilename: String) {
+            val conversationsRecordManager = ConversationsRecordManager()
+            val conversations = readInputFile()
+
+            if (conversations == null) {
+                Log.e("AppViewModel", "Couldn't load input.json")
+                return
+            }
+
+            require(chatable())
+            switchToGenerating()
+
+            executorService.submit {
+
+                // per conversation
+                conversations.forEachIndexed { c_idx, conversation ->
+
+                    val conversationRecord = ConversationRecord(this.modelName.value, this.modelLoadTime)
+
+                    conversation.forEachIndexed { q_idx, question ->
+
+                        appendMessage(MessageRole.User, "{$c_idx}_{$q_idx}: $question")
+                        //Log.d("AppViewModel", "Prompt: $question")
+
+                        val timeStart = Date()
+
+                        if (!callBackend { backend.prefill(question) }) return@submit
+                        while (!backend.stopped()) {
+                            if (!callBackend {
+                                    backend.decode()
+                                    //val newText = backend.message
+                                }) return@submit
+                            if (modelChatState.value != ModelChatState.Generating) return@submit
+                        }
+
+                        val runtimeStatsText = backend.verboseRuntimeStatsText()
+                        val jsonResult = parseJSON(runtimeStatsText)
+
+                        val originalSessionTokens = -1
+                        var inputTokens = -1
+                        var outputTokens = -1
+
+                        jsonResult?.let {
+                            inputTokens = it["prefill"]?.get("total tokens")?.split(" ")?.first()?.toInt()!!
+                            outputTokens = it["decode"]?.get("total tokens")?.split(" ")?.first()?.toInt()!!
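+                            // NOTE: this assumes verboseRuntimeStatsText() returns JSON whose
+                            // "prefill" and "decode" sections contain a "total tokens" entry of the
+                            // form "128 tok"; only the leading integer is kept here.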
+                        }
+
+                        //Log.d("AppViewModel", "Answer: " + backend.message)
+
+                        val duration = Date().time - timeStart.time
+                        val questionRecord = QuestionRecord(
+                            TimeRecord(timeStart, duration),
+                            question,
+                            backend.message,
+                            originalSessionTokens,
+                            inputTokens,
+                            outputTokens,
+                            runtimeStatsText
+                        )
+                        conversationRecord.questionRecords.add(questionRecord)
+                        val message = backend.message
+                        appendMessage(MessageRole.Bot, "{$c_idx}_{$q_idx}: $message")
+                        sleep(5000)
+                    }
+
+                    // Save energy events for particular session
+                    val file = File(Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DOCUMENTS).toString() + File.separator + "melt_measurements",
+                        "${measurementFilename}_conv$c_idx.csv"
+                    )
+
+                    if (!file.parentFile.exists()) {
+                        file.parentFile.mkdirs()
+                    }
+
+                    backend.saveEnergyEventsToCSV(file)
+
+                    // add metrics
+                    conversationsRecordManager.addConversationRecord(conversationRecord)
+
+                    // clear context
+                    backend.resetChat()
+                    backend.resetEnergyEvents()
+
+                    appendMessage(MessageRole.Bot, "--sleep--")
+                    sleep(60 * 1000)
+                }
+
+                conversationsRecordManager.saveToFile(measurementFilename)
+
+                // Notify BladeRunner that task is complete
+                val client = RestAwaitLib("192.168.1.42", 5100)
+                val response = client.continueExecution()
+                Log.d("AppViewModel", response)
+            }
+        }
+    }
+
+    fun parseJSON(jsonString: String): Map<String, Map<String, String>>? {
+        return try {
+            val jsonObject = JSONObject(jsonString)
+            val result = mutableMapOf<String, Map<String, String>>()
+
+            jsonObject.keys().forEach { key ->
+                val innerJson = jsonObject.getJSONObject(key)
+                val innerMap = mutableMapOf<String, String>()
+
+                innerJson.keys().forEach { innerKey ->
+                    innerMap[innerKey] = innerJson.getString(innerKey)
+                }
+
+                result[key] = innerMap
+            }
+
+            result
+        } catch (e: Exception) {
+            println("Error parsing JSON: $e")
+            null
+        }
+    }
+}
+
+enum class ModelInitState {
+    Initializing,
+    Indexing,
+    Paused,
+    Downloading,
+    Pausing,
+    Clearing,
+    Deleting,
+    Finished
+}
+
+enum class ModelChatState {
+    Generating,
+    Resetting,
+    Reloading,
+    Terminating,
+    Ready,
+    Failed
+}
+
+enum class MessageRole {
+    Bot,
+    User
+}
+
+data class DownloadTask(val url: URL, val file: File)
+
+data class MessageData(val role: MessageRole, val text: String, val id: UUID = UUID.randomUUID())
+
+data class AppConfig(
+    @SerializedName("model_libs") var modelLibs: MutableList<String>,
+    @SerializedName("model_list") val modelList: MutableList<ModelRecord>,
+)
+
+data class ModelRecord(
+    @SerializedName("model_url") val modelUrl: String,
+    @SerializedName("model_id") val modelId: String,
+    @SerializedName("estimated_vram_bytes") val estimatedVramBytes: Long?,
+    @SerializedName("model_lib") val modelLib: String
+)
+
+data class ModelConfig(
+    @SerializedName("model_lib") var modelLib: String,
+    @SerializedName("model_id") var modelId: String,
+    @SerializedName("estimated_vram_bytes") var estimatedVramBytes: Long?,
+    @SerializedName("tokenizer_files") val tokenizerFiles: List<String>,
+    @SerializedName("context_window_size") val contextWindowSize: Int,
+    @SerializedName("prefill_chunk_size") val prefillChunkSize: Int,
+)
+
+data class ParamsRecord(
+    @SerializedName("dataPath") val dataPath: String
+)
+
+data class ParamsConfig(
+    @SerializedName("records") val paramsRecords: List<ParamsRecord>
+)
\ No newline at end of file
diff --git a/android/app/src/main/java/ai/mlc/mlcchat/ChatView.kt b/android/app/src/main/java/ai/mlc/mlcchat/ChatView.kt
new file mode 100644
index 0000000..61351f1
--- /dev/null
+++ b/android/app/src/main/java/ai/mlc/mlcchat/ChatView.kt
@@ -0,0 +1,236 @@
+package ai.mlc.mlcchat
+
+import
androidx.compose.foundation.background +import androidx.compose.foundation.gestures.detectTapGestures +import androidx.compose.foundation.layout.* +import androidx.compose.foundation.lazy.LazyColumn +import androidx.compose.foundation.lazy.items +import androidx.compose.foundation.lazy.rememberLazyListState +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.foundation.text.selection.SelectionContainer +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.ArrowBack +import androidx.compose.material.icons.filled.Replay +import androidx.compose.material.icons.filled.Send +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.runtime.saveable.rememberSaveable +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.input.pointer.pointerInput +import androidx.compose.ui.platform.LocalFocusManager +import androidx.compose.ui.text.style.TextAlign +import androidx.compose.ui.unit.dp +import androidx.navigation.NavController +import kotlinx.coroutines.launch + +@ExperimentalMaterial3Api +@Composable +fun ChatView( + navController: NavController, chatState: AppViewModel.ChatState +) { + val localFocusManager = LocalFocusManager.current + var isDialogOpen by remember { mutableStateOf(false) } + var filenameInput by remember { mutableStateOf("") } + + Scaffold(topBar = { + TopAppBar( + title = { + Text( + text = "MLCChat: " + chatState.modelName.value.split("-")[0], + color = MaterialTheme.colorScheme.onPrimary + ) + }, + colors = TopAppBarDefaults.topAppBarColors(containerColor = MaterialTheme.colorScheme.primary), + navigationIcon = { + IconButton( + onClick = { navController.popBackStack() }, + enabled = chatState.interruptable() + ) { + Icon( + imageVector = Icons.Filled.ArrowBack, + contentDescription = "back home page", + tint = MaterialTheme.colorScheme.onPrimary + ) + } + }, + actions = { + IconButton( + onClick = { + //chatState.requestAutomation() + isDialogOpen = true + }, + enabled = chatState.interruptable() + ) { + Icon( + imageVector = Icons.Filled.Replay, + contentDescription = "reset the chat", + tint = MaterialTheme.colorScheme.onPrimary + ) + } + }) + }, modifier = Modifier.pointerInput(Unit) { + detectTapGestures(onTap = { + localFocusManager.clearFocus() + }) + }) { paddingValues -> + Column( + modifier = Modifier + .fillMaxSize() + .padding(paddingValues) + .padding(horizontal = 10.dp) + ) { + val lazyColumnListState = rememberLazyListState() + val coroutineScope = rememberCoroutineScope() + Text( + text = chatState.report.value, + textAlign = TextAlign.Center, + modifier = Modifier + .fillMaxWidth() + .wrapContentHeight() + .padding(top = 5.dp) + ) + Divider(thickness = 1.dp, modifier = Modifier.padding(vertical = 5.dp)) + LazyColumn( + modifier = Modifier.weight(9f), + verticalArrangement = Arrangement.spacedBy(5.dp, alignment = Alignment.Bottom), + state = lazyColumnListState + ) { + coroutineScope.launch { + lazyColumnListState.animateScrollToItem(chatState.messages.size) + } + items( + items = chatState.messages, + key = { message -> message.id }, + ) { message -> + MessageView(messageData = message) + } + item { + // place holder item for scrolling to the bottom + } + } + Divider(thickness = 1.dp, modifier = Modifier.padding(top = 5.dp)) + SendMessageView(chatState = chatState) + } + } + + if (isDialogOpen) { + AlertDialog( + onDismissRequest = { isDialogOpen = false }, + title = { Text("Enter filename prefix") }, 
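+            // This dialog collects a filename prefix for an automated measurement run: the confirm
+            // button below passes it to chatState.requestAutomation(), which writes its CSV/JSON
+            // results under Documents/melt_measurements on shared external storage.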
+ text = { + OutlinedTextField( + value = filenameInput, + onValueChange = { filenameInput = it }, + label = { Text("Filename prefix") } + ) + }, + confirmButton = { + Button( + onClick = { + if (filenameInput.isNotBlank()) { + chatState.requestAutomation(filenameInput) + isDialogOpen = false + filenameInput = "" + } + } + ) { + Text("Run") + } + }, + dismissButton = { + Button( + onClick = { isDialogOpen = false } + ) { + Text("Cancel") + } + } + ) + } +} + +@Composable +fun MessageView(messageData: MessageData) { + SelectionContainer { + if (messageData.role == MessageRole.Bot) { + Row( + horizontalArrangement = Arrangement.Start, + modifier = Modifier.fillMaxWidth() + ) { + Text( + text = messageData.text, + textAlign = TextAlign.Left, + color = MaterialTheme.colorScheme.onSecondaryContainer, + modifier = Modifier + .wrapContentWidth() + .background( + color = MaterialTheme.colorScheme.secondaryContainer, + shape = RoundedCornerShape(5.dp) + ) + .padding(5.dp) + .widthIn(max = 300.dp) + ) + + } + } else { + Row( + horizontalArrangement = Arrangement.End, + modifier = Modifier.fillMaxWidth() + ) { + Text( + text = messageData.text, + textAlign = TextAlign.Right, + color = MaterialTheme.colorScheme.onPrimaryContainer, + modifier = Modifier + .wrapContentWidth() + .background( + color = MaterialTheme.colorScheme.primaryContainer, + shape = RoundedCornerShape(5.dp) + ) + .padding(5.dp) + .widthIn(max = 300.dp) + ) + + } + } + } +} + +@ExperimentalMaterial3Api +@Composable +fun SendMessageView(chatState: AppViewModel.ChatState) { + val localFocusManager = LocalFocusManager.current + Row( + horizontalArrangement = Arrangement.spacedBy(5.dp), + verticalAlignment = Alignment.CenterVertically, + modifier = Modifier + .height(IntrinsicSize.Max) + .fillMaxWidth() + .padding(bottom = 5.dp) + ) { + var text by rememberSaveable { mutableStateOf("") } + OutlinedTextField( + value = text, + onValueChange = { text = it }, + label = { Text(text = "Input") }, + modifier = Modifier + .weight(9f), + ) + IconButton( + onClick = { + localFocusManager.clearFocus() + chatState.requestGenerate(text) + text = "" + }, + modifier = Modifier + .aspectRatio(1f) + .weight(1f), + enabled = (text != "" && chatState.chatable()) + ) { + Icon( + imageVector = Icons.Filled.Send, + contentDescription = "send message", + ) + } + } +} diff --git a/android/app/src/main/java/ai/mlc/mlcchat/ConversationsRecordManager.kt b/android/app/src/main/java/ai/mlc/mlcchat/ConversationsRecordManager.kt new file mode 100644 index 0000000..4cf28bb --- /dev/null +++ b/android/app/src/main/java/ai/mlc/mlcchat/ConversationsRecordManager.kt @@ -0,0 +1,89 @@ +package ai.mlc.mlcchat + +import android.os.Environment +import android.util.Log +import org.json.JSONArray +import org.json.JSONObject +import java.io.File +import java.io.IOException +import java.util.* + +data class ConversationRecord( + val modelName: String, + val modelLoadTime: TimeRecord, + val questionRecords: MutableList = mutableListOf() +) + +data class QuestionRecord( + val time: TimeRecord, + val input: String, + val output: String, + val original_session_tokens: Int, + val input_tokens: Int, + val output_tokens: Int, + val runtimeStats: String +) + +data class TimeRecord( + val start: Date, + val duration: Long +) + +class ConversationsRecordManager { + private val conversations: ArrayList = ArrayList() + + fun addConversationRecord(conversation: ConversationRecord) { + conversations.add(conversation) + } + + fun saveToFile(fileName: String) { + val file = 
File(Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DOCUMENTS).toString() + File.separator + "melt_measurements", + "$fileName.json" + ) + + if (!file.parentFile.exists()) { + file.parentFile.mkdirs() + } + + try { + val jsonArray = JSONArray() + for (session in conversations) { + val sessionObject = JSONObject() + sessionObject.put("modelName", session.modelName) + + val modelLoadTime = JSONObject() + val startEpoch = session.modelLoadTime.start.time / 1000.0 + modelLoadTime.put("start", startEpoch) + modelLoadTime.put("duration", session.modelLoadTime.duration / 1000.0) + sessionObject.put("modelLoadTime", modelLoadTime) + + val questionRecordsArray = JSONArray() + for (questionRecord in session.questionRecords) { + val chatRecordObject = JSONObject() + val timeRecordObject = JSONObject() + + val questionStartEpoch = questionRecord.time.start.time / 1000.0 + timeRecordObject.put("start", questionStartEpoch) + timeRecordObject.put("duration", questionRecord.time.duration / 1000.0) + + chatRecordObject.put("time", timeRecordObject) + chatRecordObject.put("input", questionRecord.input) + chatRecordObject.put("output", questionRecord.output) + chatRecordObject.put("original_session_tokens", questionRecord.original_session_tokens) + chatRecordObject.put("input_tokens", questionRecord.input_tokens) + chatRecordObject.put("output_tokens", questionRecord.output_tokens) + chatRecordObject.put("runtimeStats", questionRecord.runtimeStats) + + questionRecordsArray.put(chatRecordObject) + } + sessionObject.put("questionRecords", questionRecordsArray) + + jsonArray.put(sessionObject) + } + file.writeText(jsonArray.toString(4)) + Log.d("SessionRecordManager", "JSON data successfully saved at: " + file.absolutePath) + } catch (e: IOException) { + Log.e("SessionRecordManager", "Failed to write JSON data: ${e.localizedMessage}") + } + } +} diff --git a/android/app/src/main/java/ai/mlc/mlcchat/MainActivity.kt b/android/app/src/main/java/ai/mlc/mlcchat/MainActivity.kt new file mode 100644 index 0000000..c586869 --- /dev/null +++ b/android/app/src/main/java/ai/mlc/mlcchat/MainActivity.kt @@ -0,0 +1,29 @@ +package ai.mlc.mlcchat + +import ai.mlc.mlcchat.ui.theme.MLCChatTheme +import android.os.Bundle +import androidx.activity.ComponentActivity +import androidx.activity.compose.setContent +import androidx.compose.foundation.layout.fillMaxSize +import androidx.compose.material3.ExperimentalMaterial3Api +import androidx.compose.material3.Surface +import androidx.compose.ui.Modifier + + +class MainActivity : ComponentActivity() { + + @ExperimentalMaterial3Api + override fun onCreate(savedInstanceState: Bundle?) 
{ + super.onCreate(savedInstanceState) + setContent { + Surface( + modifier = Modifier + .fillMaxSize() + ) { + MLCChatTheme { + NavView() + } + } + } + } +} \ No newline at end of file diff --git a/android/app/src/main/java/ai/mlc/mlcchat/NavView.kt b/android/app/src/main/java/ai/mlc/mlcchat/NavView.kt new file mode 100644 index 0000000..fe897ce --- /dev/null +++ b/android/app/src/main/java/ai/mlc/mlcchat/NavView.kt @@ -0,0 +1,18 @@ +package ai.mlc.mlcchat + +import androidx.compose.material3.ExperimentalMaterial3Api +import androidx.compose.runtime.Composable +import androidx.lifecycle.viewmodel.compose.viewModel +import androidx.navigation.compose.NavHost +import androidx.navigation.compose.composable +import androidx.navigation.compose.rememberNavController + +@ExperimentalMaterial3Api +@Composable +fun NavView(appViewModel: AppViewModel = viewModel()) { + val navController = rememberNavController() + NavHost(navController = navController, startDestination = "home") { + composable("home") { StartView(navController, appViewModel) } + composable("chat") { ChatView(navController, appViewModel.chatState) } + } +} \ No newline at end of file diff --git a/android/app/src/main/java/ai/mlc/mlcchat/RestAwaitLib.kt b/android/app/src/main/java/ai/mlc/mlcchat/RestAwaitLib.kt new file mode 100644 index 0000000..409b72f --- /dev/null +++ b/android/app/src/main/java/ai/mlc/mlcchat/RestAwaitLib.kt @@ -0,0 +1,25 @@ +package ai.mlc.mlcchat + +import java.io.BufferedReader +import java.io.InputStreamReader +import java.net.HttpURLConnection +import java.net.URL + +class RestAwaitLib(private val host: String, private val port: Int) { + + fun continueExecution(): String { + val url = URL("http://$host:$port/continue") + val connection = url.openConnection() as HttpURLConnection + connection.requestMethod = "GET" + + val responseCode = connection.responseCode + if (responseCode == HttpURLConnection.HTTP_OK) { + val reader = BufferedReader(InputStreamReader(connection.inputStream)) + val response = reader.readText() + reader.close() + return response + } else { + throw RuntimeException("GET request failed with response code: $responseCode") + } + } +} diff --git a/android/app/src/main/java/ai/mlc/mlcchat/StartView.kt b/android/app/src/main/java/ai/mlc/mlcchat/StartView.kt new file mode 100644 index 0000000..a58129e --- /dev/null +++ b/android/app/src/main/java/ai/mlc/mlcchat/StartView.kt @@ -0,0 +1,251 @@ +package ai.mlc.mlcchat + +import androidx.compose.foundation.gestures.detectTapGestures +import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.aspectRatio +import androidx.compose.foundation.layout.fillMaxSize +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.height +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.width +import androidx.compose.foundation.layout.wrapContentHeight +import androidx.compose.foundation.lazy.LazyColumn +import androidx.compose.foundation.lazy.items +import androidx.compose.foundation.text.selection.SelectionContainer +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.outlined.Chat +import androidx.compose.material.icons.outlined.Delete +import androidx.compose.material.icons.outlined.Download +import androidx.compose.material.icons.outlined.Pause +import androidx.compose.material.icons.outlined.Schedule +import 
androidx.compose.material3.AlertDialog +import androidx.compose.material3.Divider +import androidx.compose.material3.ExperimentalMaterial3Api +import androidx.compose.material3.Icon +import androidx.compose.material3.IconButton +import androidx.compose.material3.LinearProgressIndicator +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.OutlinedTextField +import androidx.compose.material3.Scaffold +import androidx.compose.material3.Text +import androidx.compose.material3.TextButton +import androidx.compose.material3.TopAppBar +import androidx.compose.material3.TopAppBarDefaults +import androidx.compose.runtime.Composable +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.saveable.rememberSaveable +import androidx.compose.runtime.setValue +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.input.pointer.pointerInput +import androidx.compose.ui.platform.LocalFocusManager +import androidx.compose.ui.text.style.TextAlign +import androidx.compose.ui.unit.dp +import androidx.navigation.NavController + + +@ExperimentalMaterial3Api +@Composable +fun StartView( + navController: NavController, + appViewModel: AppViewModel +) { + val localFocusManager = LocalFocusManager.current + Scaffold( + topBar = { + TopAppBar( + title = { Text(text = "MLCChat", color = MaterialTheme.colorScheme.onPrimary) }, + colors = TopAppBarDefaults.topAppBarColors(containerColor = MaterialTheme.colorScheme.primary) + ) + }, + modifier = Modifier.pointerInput(Unit) { + detectTapGestures(onTap = { + localFocusManager.clearFocus() + }) + } + ) + { paddingValues -> + Column( + modifier = Modifier + .fillMaxSize() + .padding(paddingValues) + .padding(horizontal = 10.dp) + ) { + Text(text = "Model List", modifier = Modifier.padding(top = 10.dp)) + LazyColumn() { + items(items = appViewModel.modelList, + key = { modelState -> modelState.id } + ) { modelState -> + ModelView( + navController = navController, + modelState = modelState, + appViewModel = appViewModel + ) + } + } + } + if (appViewModel.isShowingAlert()) { + AlertDialog( + onDismissRequest = { appViewModel.dismissAlert() }, + onConfirmation = { appViewModel.copyError() }, + error = appViewModel.errorMessage() + ) + } + } +} + +@ExperimentalMaterial3Api +@Composable +fun AlertDialog( + onDismissRequest: () -> Unit, + onConfirmation: () -> Unit, + error: String, +) { + AlertDialog( + title = { Text(text = "Error") }, + text = { Text(text = error) }, + onDismissRequest = { onDismissRequest() }, + confirmButton = { + TextButton(onClick = { onConfirmation() }) { Text("Copy") } + }, + dismissButton = { + TextButton(onClick = { onDismissRequest() }) { Text("Dismiss") } + } + ) +} + +@Composable +fun ModelView( + navController: NavController, + modelState: AppViewModel.ModelState, + appViewModel: AppViewModel +) { + var isDeletingModel by rememberSaveable { mutableStateOf(false) } + Column( + verticalArrangement = Arrangement.SpaceBetween, + modifier = Modifier + .wrapContentHeight() + ) { + Row( + horizontalArrangement = Arrangement.spacedBy(5.dp), + verticalAlignment = Alignment.CenterVertically, + modifier = Modifier + .fillMaxWidth() + .wrapContentHeight() + ) { + Text( + text = modelState.modelConfig.modelId, + textAlign = TextAlign.Left, + modifier = Modifier + .wrapContentHeight() + .weight(8f) + ) + Divider( + modifier = Modifier + .height(20.dp) + .width(1.dp) + ) + if (modelState.modelInitState.value == 
ModelInitState.Paused) { + IconButton( + onClick = { modelState.handleStart() }, modifier = Modifier + .aspectRatio(1f) + .weight(1f) + ) { + Icon( + imageVector = Icons.Outlined.Download, + contentDescription = "start downloading", + ) + } + + } else if (modelState.modelInitState.value == ModelInitState.Downloading) { + IconButton( + onClick = { modelState.handlePause() }, modifier = Modifier + .aspectRatio(1f) + .weight(1f) + ) { + Icon( + imageVector = Icons.Outlined.Pause, + contentDescription = "pause downloading", + ) + } + } else if (modelState.modelInitState.value == ModelInitState.Finished) { + IconButton( + onClick = { + modelState.startChat() + navController.navigate("chat") + }, + enabled = appViewModel.chatState.interruptable(), + modifier = Modifier + .aspectRatio(1f) + .weight(1f) + ) { + Icon( + imageVector = Icons.Outlined.Chat, + contentDescription = "start chatting", + ) + } + } else { + IconButton( + enabled = false, onClick = {}, modifier = Modifier + .aspectRatio(1f) + .weight(1f) + ) { + Icon( + imageVector = Icons.Outlined.Schedule, + contentDescription = "pending", + ) + } + } + if (modelState.modelInitState.value == ModelInitState.Downloading || + modelState.modelInitState.value == ModelInitState.Paused || + modelState.modelInitState.value == ModelInitState.Finished + ) { + IconButton( + onClick = { isDeletingModel = true }, + modifier = Modifier + .aspectRatio(1f) + .weight(1f) + ) { + Icon( + imageVector = Icons.Outlined.Delete, + contentDescription = "start downloading", + tint = MaterialTheme.colorScheme.error + ) + } + } + } + LinearProgressIndicator( + progress = modelState.progress.value.toFloat() / modelState.total.value, + modifier = Modifier.fillMaxWidth() + ) + if (isDeletingModel) { + Row( + horizontalArrangement = Arrangement.End, + verticalAlignment = Alignment.CenterVertically, + modifier = Modifier + .fillMaxWidth() + .wrapContentHeight() + ) { + TextButton(onClick = { isDeletingModel = false }) { + Text(text = "cancel") + } + TextButton(onClick = { + isDeletingModel = false + modelState.handleClear() + }) { + Text(text = "clear data", color = MaterialTheme.colorScheme.error) + } + TextButton(onClick = { + isDeletingModel = false + modelState.handleDelete() + }) { + Text(text = "delete model", color = MaterialTheme.colorScheme.error) + } + } + } + } +} + diff --git a/android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Color.kt b/android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Color.kt new file mode 100644 index 0000000..75a3557 --- /dev/null +++ b/android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Color.kt @@ -0,0 +1,44 @@ +package ai.mlc.mlcchat.ui.theme + +import androidx.compose.ui.graphics.Color + +val Blue10 = Color(0xFF000F5E) +val Blue20 = Color(0xFF001E92) +val Blue30 = Color(0xFF002ECC) +val Blue40 = Color(0xFF1546F6) +val Blue80 = Color(0xFFB8C3FF) +val Blue90 = Color(0xFFDDE1FF) + +val DarkBlue10 = Color(0xFF00036B) +val DarkBlue20 = Color(0xFF000BA6) +val DarkBlue30 = Color(0xFF1026D3) +val DarkBlue40 = Color(0xFF3648EA) +val DarkBlue80 = Color(0xFFBBC2FF) +val DarkBlue90 = Color(0xFFDEE0FF) + +val Yellow10 = Color(0xFF261900) +val Yellow20 = Color(0xFF402D00) +val Yellow30 = Color(0xFF5C4200) +val Yellow40 = Color(0xFF7A5900) +val Yellow80 = Color(0xFFFABD1B) +val Yellow90 = Color(0xFFFFDE9C) + +val Red10 = Color(0xFF410001) +val Red20 = Color(0xFF680003) +val Red30 = Color(0xFF930006) +val Red40 = Color(0xFFBA1B1B) +val Red80 = Color(0xFFFFB4A9) +val Red90 = Color(0xFFFFDAD4) + +val Grey10 = Color(0xFF191C1D) +val Grey20 = 
Color(0xFF2D3132) +val Grey80 = Color(0xFFC4C7C7) +val Grey90 = Color(0xFFE0E3E3) +val Grey95 = Color(0xFFEFF1F1) +val Grey99 = Color(0xFFFBFDFD) + +val BlueGrey30 = Color(0xFF45464F) +val BlueGrey50 = Color(0xFF767680) +val BlueGrey60 = Color(0xFF90909A) +val BlueGrey80 = Color(0xFFC6C5D0) +val BlueGrey90 = Color(0xFFE2E1EC) \ No newline at end of file diff --git a/android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Theme.kt b/android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Theme.kt new file mode 100644 index 0000000..cbc6156 --- /dev/null +++ b/android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Theme.kt @@ -0,0 +1,107 @@ +package ai.mlc.mlcchat.ui.theme + +import android.app.Activity +import android.os.Build +import androidx.compose.foundation.isSystemInDarkTheme +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.darkColorScheme +import androidx.compose.material3.dynamicDarkColorScheme +import androidx.compose.material3.dynamicLightColorScheme +import androidx.compose.material3.lightColorScheme +import androidx.compose.runtime.Composable +import androidx.compose.runtime.SideEffect +import androidx.compose.ui.graphics.Color +import androidx.compose.ui.graphics.toArgb +import androidx.compose.ui.platform.LocalContext +import androidx.compose.ui.platform.LocalView +import androidx.core.view.WindowCompat + +private val DarkColorScheme = darkColorScheme( + primary = Blue80, + onPrimary = Blue20, + primaryContainer = Blue30, + onPrimaryContainer = Blue90, + inversePrimary = Blue40, + secondary = DarkBlue80, + onSecondary = DarkBlue20, + secondaryContainer = DarkBlue30, + onSecondaryContainer = DarkBlue90, + tertiary = Yellow80, + onTertiary = Yellow20, + tertiaryContainer = Yellow30, + onTertiaryContainer = Yellow90, + error = Red80, + onError = Red20, + errorContainer = Red30, + onErrorContainer = Red90, + background = Grey10, + onBackground = Grey90, + surface = Grey10, + onSurface = Grey80, + inverseSurface = Grey90, + inverseOnSurface = Grey20, + surfaceVariant = BlueGrey30, + onSurfaceVariant = BlueGrey80, + outline = BlueGrey60 +) + +private val LightColorScheme = lightColorScheme( + primary = Blue40, + onPrimary = Color.White, + primaryContainer = Blue90, + onPrimaryContainer = Blue10, + inversePrimary = Blue80, + secondary = DarkBlue40, + onSecondary = Color.White, + secondaryContainer = DarkBlue90, + onSecondaryContainer = DarkBlue10, + tertiary = Yellow40, + onTertiary = Color.White, + tertiaryContainer = Yellow90, + onTertiaryContainer = Yellow10, + error = Red40, + onError = Color.White, + errorContainer = Red90, + onErrorContainer = Red10, + background = Grey99, + onBackground = Grey10, + surface = Grey99, + onSurface = Grey10, + inverseSurface = Grey20, + inverseOnSurface = Grey95, + surfaceVariant = BlueGrey90, + onSurfaceVariant = BlueGrey30, + outline = BlueGrey50 +) + +@Composable +fun MLCChatTheme( + darkTheme: Boolean = isSystemInDarkTheme(), + // Dynamic color is available on Android 12+ + dynamicColor: Boolean = true, + content: @Composable () -> Unit +) { + val colorScheme = when { + dynamicColor && Build.VERSION.SDK_INT >= Build.VERSION_CODES.S -> { + val context = LocalContext.current + if (darkTheme) dynamicDarkColorScheme(context) else dynamicLightColorScheme(context) + } + + darkTheme -> DarkColorScheme + else -> LightColorScheme + } + val view = LocalView.current + if (!view.isInEditMode) { + SideEffect { + val window = (view.context as Activity).window + window.statusBarColor = colorScheme.primary.toArgb() + 
WindowCompat.getInsetsController(window, view).isAppearanceLightStatusBars = darkTheme + } + } + + MaterialTheme( + colorScheme = colorScheme, + typography = Typography, + content = content + ) +} \ No newline at end of file diff --git a/android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Type.kt b/android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Type.kt new file mode 100644 index 0000000..30e70c2 --- /dev/null +++ b/android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Type.kt @@ -0,0 +1,34 @@ +package ai.mlc.mlcchat.ui.theme + +import androidx.compose.material3.Typography +import androidx.compose.ui.text.TextStyle +import androidx.compose.ui.text.font.FontFamily +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.unit.sp + +// Set of Material typography styles to start with +val Typography = Typography( + bodyLarge = TextStyle( + fontFamily = FontFamily.Default, + fontWeight = FontWeight.Normal, + fontSize = 16.sp, + lineHeight = 24.sp, + letterSpacing = 0.5.sp + ) + /* Other default text styles to override + titleLarge = TextStyle( + fontFamily = FontFamily.Default, + fontWeight = FontWeight.Normal, + fontSize = 22.sp, + lineHeight = 28.sp, + letterSpacing = 0.sp + ), + labelSmall = TextStyle( + fontFamily = FontFamily.Default, + fontWeight = FontWeight.Medium, + fontSize = 11.sp, + lineHeight = 16.sp, + letterSpacing = 0.5.sp + ) + */ +) \ No newline at end of file diff --git a/android/app/src/main/res/drawable/ic_android_black_24dp.xml b/android/app/src/main/res/drawable/ic_android_black_24dp.xml new file mode 100644 index 0000000..fe51230 --- /dev/null +++ b/android/app/src/main/res/drawable/ic_android_black_24dp.xml @@ -0,0 +1,5 @@ + + + diff --git a/android/app/src/main/res/drawable/mlc_logo_108.xml b/android/app/src/main/res/drawable/mlc_logo_108.xml new file mode 100644 index 0000000..d5307e0 --- /dev/null +++ b/android/app/src/main/res/drawable/mlc_logo_108.xml @@ -0,0 +1,11 @@ + + + diff --git a/android/app/src/main/res/values/colors.xml b/android/app/src/main/res/values/colors.xml new file mode 100644 index 0000000..f8c6127 --- /dev/null +++ b/android/app/src/main/res/values/colors.xml @@ -0,0 +1,10 @@ + + + #FFBB86FC + #FF6200EE + #FF3700B3 + #FF03DAC5 + #FF018786 + #FF000000 + #FFFFFFFF + \ No newline at end of file diff --git a/android/app/src/main/res/values/strings.xml b/android/app/src/main/res/values/strings.xml new file mode 100644 index 0000000..5a127eb --- /dev/null +++ b/android/app/src/main/res/values/strings.xml @@ -0,0 +1,3 @@ + + MLCChat++ + \ No newline at end of file diff --git a/android/app/src/main/res/values/themes.xml b/android/app/src/main/res/values/themes.xml new file mode 100644 index 0000000..a16e9d4 --- /dev/null +++ b/android/app/src/main/res/values/themes.xml @@ -0,0 +1,6 @@ + + + + + + + + + + + + + + + + + + + diff --git a/docs/_static/img/project-structure.svg b/docs/_static/img/project-structure.svg new file mode 100644 index 0000000..e4ad7db --- /dev/null +++ b/docs/_static/img/project-structure.svg @@ -0,0 +1,1189 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/community/faq.rst b/docs/community/faq.rst new file mode 100644 index 0000000..3913dd9 --- /dev/null +++ b/docs/community/faq.rst @@ -0,0 +1,16 @@ +.. _FAQ: + +Frequently Asked Questions +========================== + +This is a list of Frequently Asked Questions (FAQ) about the MLC-LLM. Feel free to suggest new entries! + +... How can I customize the temperature, and repetition penalty of models? + Please check our :doc:`/get_started/mlc_chat_config` tutorial. + +... What's the quantization algorithm MLC-LLM using? + Please check our :doc:`/compilation/configure_quantization` tutorial. + +... Why do I encounter an error ``free(): invalid pointer, Aborted (core dumped)`` at the end of model compilation? + This happens if you compiled TVM-Unity from source and didn't hide LLVM symbols in cmake configurations. + Please follow our instructions in :ref:`Building TVM Unity from Source ` tutorial to compile TVM-Unity which hides LLVM symbols, or use our pre-built MLC-LLM :doc:`pip wheels <../install/mlc_llm>`. diff --git a/docs/community/guideline.rst b/docs/community/guideline.rst new file mode 100644 index 0000000..33e8982 --- /dev/null +++ b/docs/community/guideline.rst @@ -0,0 +1,125 @@ +.. _community_guide: + +Community Guideline +=================== + +.. contents:: + :depth: 2 + :local: + +Welcome to the MLC-LLM community! Just like you, all of us are in awe of the immense power of large language models. +Our goal for MLC-LLM is to foster a project that is driven by an open-source community, working together to democratize +this technology and make it accessible across various devices. We are thrilled to have you as part of our +community and eagerly anticipate your valuable contributions. + + +.. _community_discussion: + +Participate in Community Discussions +------------------------------------ + +We encourage open discussions. If you encounter a bug or have a feature request, please file an issue in MLC-LLM's +GitHub `issue tracker `__. You are encouraged to tag the issue with labels +such as "bug," "feature request," or "iOS" so that the relevant developers can quickly notice your concern. + +Additionally, we have set up a `discord server `__ for online discussions. +While we encourage participation in the Discord server, we also recommend creating a GitHub issue even if the +topic has been discussed there. This ensures that the discussion is archived and searchable for future reference. + +Before submitting an issue, we kindly ask you to check our :doc:`/community/faq` to see if your question has already been answered. + +.. _contribute-to-mlc-llm: + +Contribute to MLC-LLM +--------------------- + +.. _fork-and-create-pull-requests: + +Fork and Create Pull Requests +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Ready to contribute to MLC-LLM? Awesome! We are excited to see you are ready to contribute your code. +The standard way to make changes to MLC-LLM code base is through creating a `pull-request `__, +and we will review your code and merge it to the code base when it is ready. + +The first step to becoming a developer is to `fork `__ the repository to your own +github account, you will notice a repository under ``https://github.com/username/mlc-llm`` where ``username`` is your github user name. + +You can clone your fork to your local machine and commit changes, or edit the contents of your fork (in the case you are just fixing typos) +on GitHub directly. 
Once your update is complete, you can click the ``contribute`` button and open a pull request to the main repository.
+
+.. _contribute-new-models:
+
+Contribute New Models to MLC-LLM
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+* If you have compiled a model using our :doc:`/compilation/compile_models` tutorial for an existing model architecture, please upload your models to the internet (e.g., Hugging Face) by following the :ref:`distribute-compiled-models` tutorial. Once you have done that, you can create a pull request to add an entry in the :doc:`/prebuilt_models` page. Additionally, you have the option to `create a speed report issue `__ to track the speed and memory consumption of your model. You don't need to test it on all devices; let the community collaborate on building it together!
+
+* If you add a new model variant to MLC-LLM by following our :doc:`/compilation/define_new_models` tutorial,
+  please create a pull request to add your model architecture (currently model architectures are placed under the
+  `relax_models `__ folder).
+
+.. _coding-styles:
+
+Coding Styles
+^^^^^^^^^^^^^
+
+For Python code, we generally follow the `PEP8 style guide `__.
+Python comments and docstrings follow the `NumPy style `__.
+To make things easy, you can use `black `__ to automatically format your Python code.
+
+.. code:: bash
+
+   pip install black
+   black your_python_file.py
+
+For C++ code, we generally follow the `Google C++ style guide `__.
+C++ comments should be `Doxygen compatible `__.
+For your convenience, you can use `clang-format `__ to automatically format your C++ code.
+
+.. code:: bash
+
+   clang-format -i your_cpp_file.cpp
+
+.. _general-development-process:
+
+General Development Process
+---------------------------
+
+Everyone in the community is welcome to send patches, documents, and propose new directions for the project.
+The key guideline here is to enable everyone in the community to get involved and participate in decisions and development.
+We encourage public discussion in different channels, so that everyone in the community can participate
+and stay informed about ongoing developments.
+
+Code reviews are one of the key ways to ensure the quality of the code. High-quality code reviews prevent long-term
+technical debt and are crucial to the success of the project. A pull request needs to be reviewed before it gets merged.
+A committer who has expertise in the corresponding area will moderate the pull request and merge the code when
+it is ready. The corresponding committer can request multiple reviewers who are familiar with the area of the code.
+We encourage contributors to request code reviews themselves and to help review each other's code -- remember, everyone
+is volunteering their time to the community, and a high-quality code review costs as much effort as the actual code
+contribution; you can get your code reviewed quickly if you do others the same favor.
+
+The community should strive to reach consensus on technical decisions through discussion. We expect committers to
+moderate technical discussions diplomatically and to provide suggestions with clear technical reasoning when necessary.
+
+
+.. _roles-committers:
+
+Committers
+^^^^^^^^^^
+
+Committers are individuals who are granted write access to the project. A committer is usually responsible for
+a certain area or several areas of the code, where they oversee the code review process.
+The area of contribution can take all forms, including code contributions and code reviews, documents, education, and outreach. +The review of pull requests will be assigned to the committers who recently contribute to the area this PR belongs to. +Committers are essential for a high quality and healthy project. The community actively looks for new committers +from contributors. Each existing committer can nominate new committers to MLC projects. + +.. _roles-contributors: + +Contributors +^^^^^^^^^^^^ +We also welcome contributors if you are not ready to be a committer yet. Everyone who contributes to +the project (in the form of code, bugfix, documentation, tutorials, etc) is a contributor. +We maintain a `page `__ to acknowledge contributors, +please let us know if you contribute to the project and if your name is not included in the list. diff --git a/docs/compilation/compile_models.rst b/docs/compilation/compile_models.rst new file mode 100644 index 0000000..e9a3d63 --- /dev/null +++ b/docs/compilation/compile_models.rst @@ -0,0 +1,1056 @@ +.. _compile-model-libraries: + +Compile Model Libraries +======================= + +To run a model with MLC LLM in any platform, you need: + +1. **Model weights** converted to MLC format (e.g. `RedPajama-INCITE-Chat-3B-v1-MLC + `_.) +2. **Model library** that comprises the inference logic (see repo `binary-mlc-llm-libs `__). + +If you are simply adding a model variant, follow :ref:`convert-weights-via-MLC` suffices. + +This page describes how to compile a model library with MLC LLM. Model compilation optimizes +the model inference for a given platform, allowing users bring their own new model +architecture, use different quantization modes, and customize the overall model +optimization flow. + +We compile ``RedPajama-INCITE-Chat-3B-v1`` with ``q4f16_1`` as an example for all platforms. + +.. note:: + Before you proceed, make sure you followed :ref:`install-tvm-unity`, a required + backend to compile models with MLC LLM. + + Please also follow the instructions in :ref:`deploy-cli` / :ref:`deploy-python` to obtain + the CLI app / Python API that can be used to chat with the compiled model. + Finally, we strongly recommend you to read :ref:`project-overview` first to get + familiarized with the high-level terminologies. + +.. contents:: Table of Contents + :depth: 1 + :local: + +0. Verify Installation +---------------------- + +**Step 1. Verify mlc_chat** + +We use the python package ``mlc_chat`` to compile models. This can be installed by +following :ref:`install-mlc-packages`, either by building from source, or by +installing the prebuilt package. Verify ``mlc_chat`` installation in command line via: + +.. code:: bash + + $ mlc_chat --help + # You should see help information with this line + usage: MLC LLM Command Line Interface. [-h] {compile,convert_weight,gen_config} + +.. note:: + If it runs into error ``command not found: mlc_chat``, try ``python -m mlc_chat --help``. + +**Step 2. Verify TVM** + +To compile models, you also need to follow :ref:`install-tvm-unity`. +Here we verify ``tvm`` quickly with command line (for full verification, see :ref:`tvm-unity-validate`): + +.. code:: bash + + $ python -c "import tvm; print(tvm.__file__)" + /some-path/lib/python3.11/site-packages/tvm/__init__.py + +1. Clone from HF and convert_weight +----------------------------------- + +This replicates :ref:`convert-weights-via-MLC`, see that page for more details. + +You can be under the mlc-llm repo, or your own working directory. 
Note that all platforms +can share the same compiled/quantized weights. + +.. code:: shell + + # Create directory + mkdir -p dist/models && cd dist/models + # Clone HF weights + git lfs install + git clone https://huggingface.co/togethercomputer/RedPajama-INCITE-Chat-3B-v1 + cd ../.. + # Convert weight + mlc_chat convert_weight ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ + --quantization q4f16_1 \ + -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC + +2. Generate mlc-chat-config and compile +--------------------------------------- + +A model library is specified by: + + - The model architecture (e.g. ``llama-2``, ``gpt-neox``) + - Quantization (e.g. ``q4f16_1``, ``q0f32``) + - Metadata (e.g. ``context_window_size``, ``sliding_window_size``, ``prefill-chunk-size``), which affects memory planning + - Platform (e.g. ``cuda``, ``webgpu``, ``iOS``) + +All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_config``. + +.. code:: shell + + # Create output directory for the model library compiled + mkdir dist/libs + +.. tabs:: + + .. group-tab:: Linux - CUDA + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ + --quantization q4f16_1 --conv-template redpajama_chat \ + -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ + --device cuda -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so + + + .. group-tab:: Metal + + For M-chip Mac: + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ + --quantization q4f16_1 --conv-template redpajama_chat \ + -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ + --device metal -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal.so + + Cross-Compiling for Intel Mac on M-chip Mac: + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ + --quantization q4f16_1 --conv-template redpajama_chat \ + -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ + --device metal:x86-64 -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal_x86_64.dylib + + For Intel Mac: + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ + --quantization q4f16_1 --conv-template redpajama_chat \ + -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ + --device metal -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal_x86_64.dylib + + + .. group-tab:: Vulkan + + For Linux: + + .. code:: shell + + # 1. 
gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ + --quantization q4f16_1 --conv-template redpajama_chat \ + -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ + --device vulkan -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-vulkan.so + + For Windows: + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ + --quantization q4f16_1 --conv-template redpajama_chat \ + -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ + --device vulkan -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-vulkan.dll + + .. group-tab:: iOS/iPadOS + + You need a Mac to compile models for it. + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ --quantization q4f16_1 \ + --conv-template redpajama_chat --context-window-size 768 \ + -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ + --device iphone -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-iphone.tar + + .. note:: + If it runs into error + + .. code:: text + + Compilation error: + xcrun: error: unable to find utility "metal", not a developer tool or in PATH + xcrun: error: unable to find utility "metallib", not a developer tool or in PATH + + , please check and make sure you have Command Line Tools for Xcode installed correctly. + You can use ``xcrun metal`` to validate: when it prints ``metal: error: no input files``, it means the Command Line Tools for Xcode is installed and can be found, and you can proceed with the model compiling. + + .. group-tab:: Android + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ --quantization q4f16_1 \ + --conv-template redpajama_chat --context-window-size 768 \ + -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ + --device android -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-android.tar + + .. group-tab:: WebGPU + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ + --quantization q4f16_1 --conv-template redpajama_chat \ + -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ + --device webgpu -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-webgpu.wasm + + .. note:: + To compile for webgpu, you need to build from source when installing ``mlc_chat``. Besides, you also need to follow :ref:`install-web-build`. + Otherwise, it would run into error + + .. 
code:: text + + RuntimeError: Cannot find libraries: wasm_runtime.bc + + .. note:: + For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill_chunk_size 1024`` or lower ``context_window_size`` to decrease memory usage. + Otherwise, you may run into issues like: + + .. code:: text + + TypeError: Failed to execute 'createBuffer' on 'GPUDevice': Failed to read the 'size' property from + 'GPUBufferDescriptor': Value is outside the 'unsigned long long' value range. + +.. note:: + + For the ``conv-template``, `conv_template.cc `__ + contains a full list of conversation templates that MLC provides. If the model you are adding + requires a new conversation template, you would need to add your own. + Follow `this PR `__ as an example. + However, adding your own template would require you :ref:`build mlc_chat from source ` + in order for it to be recognized by the runtime. + + For more details, please see :ref:`configure-mlc-chat-json`. + +3. Verify output and chat +------------------------- + +By executing the compile command above, we generate the model weights, model lib, and a chat config. +We can check the output with the commands below: + +.. tabs:: + + .. group-tab:: Linux - CUDA + + .. code:: shell + + ~/mlc-llm > ls dist/libs + RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so # ===> the model library + + ~/mlc-llm > ls dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC + mlc-chat-config.json # ===> the chat config + ndarray-cache.json # ===> the model weight info + params_shard_0.bin # ===> the model weights + params_shard_1.bin + ... + tokenizer.json # ===> the tokenizer files + tokenizer_config.json + + We can now chat with the model using the command line interface (CLI) app or the Python API. + + .. code:: shell + + python + >>> from mlc_chat import ChatModule + >>> cm = ChatModule(model="./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", \ + model_lib_path="./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so") + >>> cm.generate("hi") + 'Hi! How can I assist you today?' + + .. group-tab:: Metal + + .. code:: shell + + ~/mlc-llm > ls dist/libs + RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal.so # ===> the model library (will be -metal_x86_64.dylib for Intel Mac) + + ~/mlc-llm > ls dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC + mlc-chat-config.json # ===> the chat config + ndarray-cache.json # ===> the model weight info + params_shard_0.bin # ===> the model weights + params_shard_1.bin + ... + tokenizer.json # ===> the tokenizer files + tokenizer_config.json + + We can now chat with the model using the command line interface (CLI) app or the Python API. + + .. code:: shell + + python + >>> from mlc_chat import ChatModule + >>> cm = ChatModule(model="./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", \ + model_lib_path="./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal.so") + >>> cm.generate("hi") + 'Hi! How can I assist you today?' + + + .. group-tab:: Vulkan + + .. code:: shell + + ~/mlc-llm > ls dist/libs + RedPajama-INCITE-Chat-3B-v1-q4f16_1-vulkan.so # ===> the model library (will be .dll for Windows) + + ~/mlc-llm > ls dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC + mlc-chat-config.json # ===> the chat config + ndarray-cache.json # ===> the model weight info + params_shard_0.bin # ===> the model weights + params_shard_1.bin + ... + tokenizer.json # ===> the tokenizer files + tokenizer_config.json + + We can now chat with the model using the command line interface (CLI) app or the Python API. + + .. 
code:: shell + + python + >>> from mlc_chat import ChatModule + >>> cm = ChatModule(model="./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", \ + model_lib_path="./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-vulkan.so", device="vulkan") + >>> cm.generate("hi") + 'Hi! How can I assist you today?' + + .. group-tab:: iOS/iPadOS + + .. code:: shell + + ~/mlc-llm > ls dist/libs + RedPajama-INCITE-Chat-3B-v1-q4f16_1-iphone.tar # ===> the model library + + ~/mlc-llm > ls dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC + mlc-chat-config.json # ===> the chat config + ndarray-cache.json # ===> the model weight info + params_shard_0.bin # ===> the model weights + params_shard_1.bin + ... + tokenizer.json # ===> the tokenizer files + tokenizer_config.json + + The model lib ``dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-iphone.tar`` + will be packaged as a static library into the iOS app. Checkout :ref:`deploy-ios` for more details. + + .. group-tab:: Android + + .. code:: shell + + ~/mlc-llm > ls dist/libs + RedPajama-INCITE-Chat-3B-v1-q4f16_1-android.tar # ===> the model library + + ~/mlc-llm > ls dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC + mlc-chat-config.json # ===> the chat config + ndarray-cache.json # ===> the model weight info + params_shard_0.bin # ===> the model weights + params_shard_1.bin + ... + tokenizer.json # ===> the tokenizer files + tokenizer_config.json + + The model lib ``dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-android.tar`` + will be packaged as a static library into the android app. Checkout :ref:`deploy-android` for more details. + + .. group-tab:: WebGPU + + .. code:: shell + + ~/mlc-llm > ls dist/libs + RedPajama-INCITE-Chat-3B-v1-q4f16_1-webgpu.wasm # ===> the model library + + ~/mlc-llm > ls dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC + mlc-chat-config.json # ===> the chat config + ndarray-cache.json # ===> the model weight info + params_shard_0.bin # ===> the model weights + params_shard_1.bin + ... + tokenizer.json # ===> the tokenizer files + tokenizer_config.json + + To use this in WebGPU runtime, checkout :ref:`webllm-runtime`. + +Compile Commands for More Models +-------------------------------- + +This section lists compile commands for more models that you can try out. Note that this can be easily +generalized to any model variant, as long as mlc-llm supports the architecture. + +.. tabs:: + + .. tab:: Model: Llama-2-7B + + Please `request for access `_ to the Llama-2 weights from Meta first. + After granted access, first create directory ``dist/models`` and download the model to the directory. + For example, you can run the following code: + + .. code:: shell + + mkdir -p dist/models && cd dist/models + git lfs install + git clone https://huggingface.co/meta-llama/Llama-2-7b-chat-hf + cd ../.. + + Then convert the HF weights into MLC-compatible weights. Note that all platforms + can share the same compiled/quantized weights. + + .. code:: shell + + mlc_chat convert_weight ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC + + Afterwards, run the following command to generate mlc config and compile the model. + + .. code:: shell + + # Create output directory for the model library compiled + mkdir dist/libs + + .. tabs:: + + .. tab:: Target: CUDA + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ + --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ + # 2. 
compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ + --device cuda -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-cuda.so + + .. tab:: Metal + + For M-chip Mac: + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ + --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ + --device metal -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-metal.so + + Cross-Compiling for Intel Mac on M-chip Mac: + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ + --quantization q4f16_1 --conv-template redpajama_chat \ + -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ + --device metal:x86-64 -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal_x86_64.dylib + + For Intel Mac: + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ + --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ + --device metal -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-metal_x86_64.dylib + + .. tab:: Vulkan + + For Linux: + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ + --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ + --device vulkan -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-vulkan.so + + For Windows: + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ + --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ + --device vulkan -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-vulkan.dll + + .. tab:: WebGPU + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ + --context-window-size 2048 --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ + --device webgpu -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-webgpu.wasm + + .. note:: + To compile for webgpu, you need to build from source when installing ``mlc_chat``. Besides, you also need to follow :ref:`install-web-build`. + Otherwise, it would run into error + + .. 
code:: text + + RuntimeError: Cannot find libraries: wasm_runtime.bc + + .. tab:: iPhone/iPad + + You need a Mac to compile models for it. + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ + --conv-template llama-2 --context-window-size 768 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ + --device iphone -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-iphone.tar + + .. tab:: Android + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ + --conv-template llama-2 --context-window-size 768 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ + --device android -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-android.tar + + .. tab:: Mistral-7B-Instruct-v0.2 + + Note that Mistral uses sliding window attention (SWA). Thus, instead of specifying + ``context-window-size``, we specify ``sliding-window-size``. + + First create directory ``dist/models`` and download the model to the directory. + For example, you can run the following code: + + .. code:: shell + + mkdir -p dist/models && cd dist/models + git lfs install + git clone https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2 + cd ../.. + + Then convert the HF weights into MLC-compatible weights. Note that all platforms + can share the same compiled/quantized weights. + + .. code:: shell + + mlc_chat convert_weight ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ + -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC + + Afterwards, run the following command to generate mlc config and compile the model. + + .. code:: shell + + # Create output directory for the model library compiled + mkdir dist/libs + + .. tabs:: + + .. tab:: Target: CUDA + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ + --conv-template mistral_default -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ + --device cuda -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-cuda.so + + .. tab:: Metal + + For M-chip Mac: + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ + --conv-template mistral_default -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ + --device metal -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-metal.so + + + For Intel Mac: + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ + --conv-template mistral_default -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ + # 2. 
compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ + --device metal -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-metal_x86_64.dylib + + .. tab:: Vulkan + + For Linux: + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ + --conv-template mistral_default -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ + --device vulkan -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-vulkan.so + + For Windows: + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ + --conv-template mistral_default -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ + --device vulkan -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-vulkan.dll + + .. tab:: WebGPU + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ + --prefill-chunk-size 1024 --conv-template mistral_default \ + -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ + --device webgpu -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-webgpu.wasm + + .. note:: + To compile for webgpu, you need to build from source when installing ``mlc_chat``. Besides, you also need to follow :ref:`install-web-build`. + Otherwise, it would run into error + + .. code:: text + + RuntimeError: Cannot find libraries: wasm_runtime.bc + + .. note:: + For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill_chunk_size 1024`` or lower ``context_window_size`` to decrease memory usage. + Otherwise, you may run into issues like: + + .. code:: text + + TypeError: Failed to execute 'createBuffer' on 'GPUDevice': Failed to read the 'size' property from + 'GPUBufferDescriptor': Value is outside the 'unsigned long long' value range. + + .. tab:: iPhone/iPad + + You need a Mac to compile models for it. + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ + --conv-template mistral_default --sliding-window-size 1024 --prefill-chunk-size 128 \ + -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ + --device iphone -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-iphone.tar + + .. tab:: Android + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ + --conv-template mistral_default --sliding-window-size 1024 --prefill-chunk-size 128 -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ + # 2. 
compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ + --device android -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-android.tar + + .. tab:: Other models + + First create directory ``dist/models`` and download the model to the directory. + For example, you can run the following code: + + .. code:: shell + + mkdir -p dist/models && cd dist/models + git lfs install + git clone https://huggingface.co/DISTRIBUTOR/HF_MODEL + cd ../.. + + Then convert the HF weights into MLC-compatible weights. Note that all platforms + can share the same compiled/quantized weights. + + .. code:: shell + + mlc_chat convert_weight ./dist/models/HF_MODEL/ --quantization q4f16_1 -o dist/OUTPUT-MLC + + Afterwards, run the following command to generate mlc config and compile the model. + + .. code:: shell + + # Create output directory for the model library compiled + mkdir dist/libs + + .. tabs:: + + .. tab:: Target: CUDA + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device cuda -o dist/libs/OUTPUT-cuda.so + + .. tab:: Metal + + For M-chip Mac: + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device metal -o dist/libs/OUTPUT-metal.so + + + For Intel Mac: + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device metal -o dist/libs/OUTPUT-metal_x86_64.dylib + + .. tab:: Vulkan + + For Linux: + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device vulkan -o dist/libs/OUTPUT-vulkan.so + + For Windows: + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device vulkan -o dist/libs/OUTPUT-vulkan.dll + + .. tab:: WebGPU + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device webgpu -o dist/libs/OUTPUT-webgpu.wasm + + .. 
note:: + To compile for webgpu, you need to build from source when installing ``mlc_chat``. Besides, you also need to follow :ref:`install-web-build`. + Otherwise, it would run into error + + .. code:: text + + RuntimeError: Cannot find libraries: wasm_runtime.bc + + .. note:: + For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill_chunk_size 1024`` or lower ``context_window_size`` to decrease memory usage. + Otherwise, you may run into issues like: + + .. code:: text + + TypeError: Failed to execute 'createBuffer' on 'GPUDevice': Failed to read the 'size' property from + 'GPUBufferDescriptor': Value is outside the 'unsigned long long' value range. + + .. tab:: iPhone/iPad + + You need a Mac to compile models for it. + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE \ + --context-window-size 768 -o dist/OUTPUT-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device iphone -o dist/libs/OUTPUT-iphone.tar + + .. tab:: Android + + .. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE \ + --context-window-size 768 -o dist/OUTPUT-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device android -o dist/libs/OUTPUT-android.tar + +For each model and each backend, the above only provides the most recommended build command (which is the most optimized). +You can also try with different argument values (e.g., different quantization modes, context window size, etc.), +whose build results affect runtime memory requirement, and it is possible that they may not run as +fast and robustly as the provided one when running the model. + +.. note:: + Uing 3-bit quantization usually can be overly aggressive and only works for limited settings. + If you encounter issues where the compiled model does not perform as expected, + consider utilizing a higher number of bits for quantization (e.g., 4-bit quantization). + +If you are interested in distributing the model besides local execution, please checkout :ref:`distribute-compiled-models`. + + +.. _compile-command-specification: + +Compile Command Specification +----------------------------- + +As you have seen in the section above, the model compilation is split into three steps: convert weights, generate +``mlc-chat-config.json``, and compile the model. This section describes the list of options that can be used +during compilation. + +1. Convert Weight +^^^^^^^^^^^^^^^^^ + +Weight conversion command follows the pattern below: + +.. code:: text + + mlc_chat convert_weight \ + CONFIG \ + --quantization QUANTIZATION_MODE \ + [--model-type MODEL_TYPE] \ + [--device DEVICE] \ + [--source SOURCE] \ + [--source-format SOURCE_FORMAT] \ + --output OUTPUT + +Note that ``CONFIG`` is a positional argument. Arguments wrapped with ``[ ]`` are optional. + +--CONFIG It can be one of the following: + + 1. Path to a HuggingFace model directory that contains a ``config.json`` or + 2. Path to ``config.json`` in HuggingFace format, or + 3. The name of a pre-defined model architecture. 
+ + A ``config.json`` file in HuggingFace format defines the model architecture, including the vocabulary + size, the number of layers, the hidden size, number of attention heads, etc. + Example: https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json. + + A HuggingFace directory often contains a ``config.json`` which defines the model architecture, + the non-quantized model weights in PyTorch or SafeTensor format, tokenizer configurations, + as well as an optional ``generation_config.json`` provides additional default configuration for + text generation. + Example: https://huggingface.co/codellama/CodeLlama-7b-hf/tree/main. + + For existing pre-defined model architecture, see ``MODEL_PRESETS`` + `here `_. + +--quantization QUANTIZATION_MODE The quantization mode we use to compile. + + See :ref:`quantization_mode` for more information. + Available options are: ``q0f16``, ``q0f32``, ``q3f16_1``, ``q4f16_1``, ``q4f32_1``, and + ``q4f16_awq``. + + We encourage you to use 4-bit quantization, as the text generated by 3-bit + quantized models may have bad quality depending on the model. + +--model-type MODEL_TYPE Model architecture such as "llama". If not set, it is inferred from ``config.json``. + +--device DEVICE The device used to do quantization such as "cuda" or "cuda:0". Will detect from + local available GPUs if not specified. + +--source SOURCE The path to original model weight, infer from ``config`` if missing. + +--source-format SOURCE_FORMAT The format of source model weight, infer from ``config`` if missing. + +--output OUTPUT The output directory to save the quantized model weight. + Will create ``params_shard_*.bin`` and ```ndarray-cache.json``` in this directory. + +2. Generate MLC Chat Config +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In order to compile a model, we first need to generate the ``mlc-chat-config.json``. This file contains specifications +like ``context-window-size`` and ``sliding-window-size``, among others that can alter the model compiled. We also process +tokenizers in this step. + +Config generation command follows the pattern below: + +.. code:: text + + mlc_chat gen_config \ + CONFIG \ + --quantization QUANTIZATION_MODE \ + [--model-type MODEL_TYPE] \ + --conv-template CONV_TEMPLATE \ + [--context-window-size CONTEXT_WINDOW_SIZE] \ + [--sliding-window-size SLIDING_WINDOW_SIZE] \ + [--prefill-chunk-size PREFILL_CHUNK_SIZE] \ + [--tensor-parallel-shard TENSOR_PARALLEL_SHARDS] \ + --output OUTPUT + +Note that ``CONFIG`` is a positional argument. Arguments wrapped with ``[ ]`` are optional. + +--CONFIG It can be one of the following: + + 1. Path to a HuggingFace model directory that contains a ``config.json`` or + 2. Path to ``config.json`` in HuggingFace format, or + 3. The name of a pre-defined model architecture. + + A ``config.json`` file in HuggingFace format defines the model architecture, including the vocabulary + size, the number of layers, the hidden size, number of attention heads, etc. + Example: https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json. + + A HuggingFace directory often contains a ``config.json`` which defines the model architecture, + the non-quantized model weights in PyTorch or SafeTensor format, tokenizer configurations, + as well as an optional ``generation_config.json`` provides additional default configuration for + text generation. + Example: https://huggingface.co/codellama/CodeLlama-7b-hf/tree/main. + + For existing pre-defined model architecture, see ``MODEL_PRESETS`` + `here `_. 
+ +--quantization QUANTIZATION_MODE The quantization mode we use to compile. + + See :ref:`quantization_mode` for more information. + Available options are: ``q0f16``, ``q0f32``, ``q3f16_1``, ``q4f16_1``, ``q4f32_1``, and + ``q4f16_awq``. + + We encourage you to use 4-bit quantization, as the text generated by 3-bit + quantized models may have bad quality depending on the model. + +--model-type MODEL_TYPE Model architecture such as "llama". If not set, it is inferred from ``config.json``. + +--conv-template CONV_TEMPLATE Conversation template. It depends on how the model is tuned. Use "LM" for vanilla base model + For existing pre-defined templates, see ``CONV_TEMPLATES`` + `here `_. + +--context-window-size CONTEXT_WINDOW_SIZE Option to provide the maximum sequence length supported by the model. + This is usually explicitly shown as context length or context window in the model card. + If this option is not set explicitly, by default, + it will be determined by ``context_window_size`` or ``max_position_embeddings`` in ``config.json``, + and the latter is usually inaccurate for some models. + +--sliding-window-size SLIDING_WINDOW (Experimental) The sliding window size in sliding window attention (SWA). + This optional field overrides the ``sliding_window`` in ``config.json`` for + those models that use SWA. Currently only useful when compiling mistral-based models. + This flag subjects to future refactoring. + +--prefill-chunk-size PREFILL_CHUNK_SIZE (Experimental) The chunk size during prefilling. By default, + the chunk size is the same as ``context_window_size`` or ``sliding_window_size``. + This flag subjects to future refactoring. + +--tensor-parallel-shard TENSOR_PARALLEL_SHARDS Number of shards to split the model into in tensor parallelism multi-gpu inference. + +--output OUTPUT The output directory for generated configurations, including `mlc-chat-config.json` and tokenizer configuration. + +3. Compile Model Library +^^^^^^^^^^^^^^^^^^^^^^^^ + +After generating ``mlc-chat-config.json``, we can compile the model into a model library (files ending in ``.so``, ``.tar``, etc. that contains +the inference logic of a model). + +Model compilation command follows the pattern below: + +.. code:: text + + mlc_chat compile \ + MODEL \ + [--quantization QUANTIZATION_MODE] \ + [--model-type MODEL_TYPE] \ + [--device DEVICE] \ + [--host HOST] \ + [--opt OPT] \ + [--system-lib-prefix SYSTEM_LIB_PREFIX] \ + --output OUTPUT \ + [--overrides OVERRIDES] + +Note that ``MODEL`` is a positional argument. Arguments wrapped with ``[ ]`` are optional. + +--MODEL A path to ``mlc-chat-config.json``, or an MLC model directory that contains ``mlc-chat-config.json``. + +--quantization QUANTIZATION_MODE The quantization mode we use to compile. If unprovided, will infer from ``MODEL``. + + See :ref:`quantization_mode` for more information. + Available options are: ``q0f16``, ``q0f32``, ``q3f16_1``, ``q4f16_1``, ``q4f32_1``, and + ``q4f16_awq``. + + We encourage you to use 4-bit quantization, as the text generated by 3-bit + quantized models may have bad quality depending on the model. + +--model-type MODEL_TYPE Model architecture such as "llama". If not set, it is inferred from ``mlc-chat-config.json``. + +--device DEVICE The GPU device to compile the model to. If not set, it is inferred from GPUs available locally. + +--host HOST The host LLVM triple to compile the model to. If not set, it is inferred from the local CPU and OS. 
+ Examples of the LLVM triple: + + 1) iPhones: arm64-apple-ios; + 2) ARM64 Android phones: aarch64-linux-android; + 3) WebAssembly: wasm32-unknown-unknown-wasm; + 4) Windows: x86_64-pc-windows-msvc; + 5) ARM macOS: arm64-apple-darwin. + +--opt OPT Optimization flags. MLC LLM maintains a predefined set of optimization flags, + denoted as ``O0``, ``O1``, ``O2``, ``O3``, where ``O0`` means no optimization, ``O2`` + means majority of them, and ``O3`` represents extreme optimization that could + potentially break the system. + + Meanwhile, optimization flags could be explicitly specified via details knobs, e.g. + ``--opt="cutlass_attn=1;cutlass_norm=0;cublas_gemm=0;cudagraph=0"``. + +--system-lib-prefix SYSTEM_LIB_PREFIX Adding a prefix to all symbols exported. Similar to ``objcopy --prefix-symbols``. + This is useful when compiling multiple models into a single library to avoid symbol + conflicts. Different from objcopy, this takes no effect for shared library. + + +--output OUTPUT The path to the output file. The suffix determines if the output file is a shared library or + objects. Available suffixes: + + 1) Linux: .so (shared), .tar (objects); + 2) macOS: .dylib (shared), .tar (objects); + 3) Windows: .dll (shared), .tar (objects); + 4) Android, iOS: .tar (objects); + 5) Web: .wasm (web assembly). + +--overrides OVERRIDES Model configuration override. Configurations to override ``mlc-chat-config.json``. Supports + ``context_window_size``, ``prefill_chunk_size``, ``sliding_window``, ``max_batch_size`` and + ``tensor_parallel_shards``. Meanwhile, model config could be explicitly specified via details + knobs, e.g. ``--overrides "context_window_size=1024;prefill_chunk_size=128"``. diff --git a/docs/compilation/configure_quantization.rst b/docs/compilation/configure_quantization.rst new file mode 100644 index 0000000..d66f841 --- /dev/null +++ b/docs/compilation/configure_quantization.rst @@ -0,0 +1,22 @@ +🚧 Configure Quantization +========================= + +Quantization Algorithm +---------------------- + +The default quantization algorithm used in MLC-LLM is grouping quantization method discussed in the papers `The case for 4-bit precision: k-bit Inference Scaling Laws `__ and `LUT-GEMM: Quantized Matrix Multiplication based on LUTs for Efficient Inference in Large-Scale Generative Language Models `__. + +.. _quantization_mode: + +Quantization Mode +----------------- + +In MLC-LLM we use a short code that indicates the quantization mode to use. + +The format of the code is ``qAfB(_id)``, where ``A`` represents the number +of bits for storing weights and ``B`` represents the number of bits for storing activations. +The ``_id`` is an integer identifier to distinguish different quantization algorithms (e.g. symmetric, non-symmetric, AWQ, etc). + +Currently, available options are: ``q0f16``, ``q0f32``, ``q3f16_1``, ``q4f16_1``, ``q4f32_1``, and ``q4f16_awq`` (not stable). + +More details to come. \ No newline at end of file diff --git a/docs/compilation/convert_weights.rst b/docs/compilation/convert_weights.rst new file mode 100644 index 0000000..6b39cf8 --- /dev/null +++ b/docs/compilation/convert_weights.rst @@ -0,0 +1,183 @@ +.. _convert-weights-via-MLC: + +Convert Weights via MLC +======================= + +To run a model with MLC LLM in any platform, you need: + +1. **Model weights** converted to MLC format (e.g. `RedPajama-INCITE-Chat-3B-v1-MLC + `_.) +2. **Model library** that comprises the inference logic (see repo `binary-mlc-llm-libs `__). 
+
+In many cases, we only need to convert weights and reuse an existing model library.
+This page demonstrates adding a model variant with ``mlc_chat convert_weight``, which
+takes a Hugging Face model as input and converts/quantizes it into MLC-compatible weights.
+
+Specifically, we add RedPajama-INCITE-**Instruct**-3B-v1, while MLC already
+provides a model library for RedPajama-INCITE-**Chat**-3B-v1, which we can reuse.
+
+This can be extended to, e.g.:
+
+- Add ``OpenHermes-Mistral`` when MLC already supports Mistral
+- Add ``Llama-2-uncensored`` when MLC already supports Llama-2
+
+.. note::
+    Before you proceed, make sure you have followed :ref:`install-tvm-unity`, a required
+    backend to compile models with MLC LLM.
+
+    Please also follow the instructions in :ref:`deploy-cli` / :ref:`deploy-python` to obtain
+    the CLI app / Python API that can be used to chat with the compiled model.
+    Finally, we strongly recommend you read :ref:`project-overview` first to get
+    familiarized with the high-level terminologies.
+
+.. contents:: Table of Contents
+   :depth: 1
+   :local:
+
+.. _verify_installation_for_compile:
+
+0. Verify installation
+----------------------
+
+**Step 1. Verify mlc_chat**
+
+We use the Python package ``mlc_chat`` to compile models. It can be installed by
+following :ref:`install-mlc-packages`, either by building from source or by
+installing the prebuilt package. Verify the ``mlc_chat`` installation on the command line via:
+
+.. code:: bash
+
+    $ mlc_chat --help
+    # You should see help information with this line
+    usage: MLC LLM Command Line Interface. [-h] {compile,convert_weight,gen_config}
+
+.. note::
+    If you run into the error ``command not found: mlc_chat``, try ``python -m mlc_chat --help``.
+
+**Step 2. Verify TVM**
+
+To compile models, you also need to follow :ref:`install-tvm-unity`.
+Here we quickly verify ``tvm`` from the command line (for full verification, see :ref:`tvm-unity-validate`):
+
+.. code:: bash
+
+    $ python -c "import tvm; print(tvm.__file__)"
+    /some-path/lib/python3.11/site-packages/tvm/__init__.py
+
+
+1. Clone from HF and convert_weight
+-----------------------------------
+
+You can be under the mlc-llm repo, or your own working directory. Note that all platforms
+can share the same compiled/quantized weights. See :ref:`compile-command-specification`
+for the specification of ``convert_weight``.
+
+.. code:: shell
+
+    # Create directory
+    mkdir -p dist/models && cd dist/models
+    # Clone HF weights
+    git lfs install
+    git clone https://huggingface.co/togethercomputer/RedPajama-INCITE-Instruct-3B-v1
+    cd ../..
+    # Convert weight
+    mlc_chat convert_weight ./dist/models/RedPajama-INCITE-Instruct-3B-v1/ \
+        --quantization q4f16_1 \
+        -o dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC
+
+.. _generate_mlc_chat_config:
+
+2. Generate MLC Chat Config
+---------------------------
+
+Use ``mlc_chat gen_config`` to generate ``mlc-chat-config.json`` and process tokenizers.
+See :ref:`compile-command-specification` for the specification of ``gen_config``.
+
+.. code:: shell
+
+    mlc_chat gen_config ./dist/models/RedPajama-INCITE-Instruct-3B-v1/ \
+        --quantization q4f16_1 --conv-template redpajama_chat \
+        -o dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC/
+
+
+.. note::
+    The file ``mlc-chat-config.json`` is crucial in both model compilation
+    and runtime chatting. Here we only care about the latter case.
+ + You can **optionally** customize + ``dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC/mlc-chat-config.json`` (checkout :ref:`configure-mlc-chat-json` for more detailed instructions). + You can also simply use the default configuration. + + `conv_template.cc `__ + contains a full list of conversation templates that MLC provides. If the model you are adding + requires a new conversation template, you would need to add your own. + Follow `this PR `__ as an example. However, + adding your own template would require you :ref:`build mlc_chat from source ` in order for it + to be recognized by the runtime. + +By now, you should have the following files. + +.. code:: shell + + ~/mlc-llm > ls dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC + mlc-chat-config.json # ===> the chat config + ndarray-cache.json # ===> the model weight info + params_shard_0.bin # ===> the model weights + params_shard_1.bin + ... + tokenizer.json # ===> the tokenizer files + tokenizer_config.json + +.. _distribute-compiled-models: + +(Optional) 3. Upload weights to HF +---------------------------------- + +Optionally, you can upload what we have to huggingface. + +.. code:: shell + + # First, please create a repository on Hugging Face. + # With the repository created, run + git lfs install + git clone https://huggingface.co/my-huggingface-account/my-redpajama3b-weight-huggingface-repo + cd my-redpajama3b-weight-huggingface-repo + cp path/to/mlc-llm/dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC/* . + git add . && git commit -m "Add redpajama-3b instruct model weights" + git push origin main + +This would result in something like `RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC +`_, but +for **Instruct** instead of **Chat**. + +Good job, you have successfully distributed the model you compiled. +Next, we will talk about how we can consume the model weights in applications. + +Download the Distributed Models and Run in Python +------------------------------------------------- + +Running the distributed models are similar to running prebuilt model weights and libraries in :ref:`Model Prebuilts`. + +.. code:: shell + + # Clone prebuilt libs so we can reuse them: + mkdir -p dist/ + git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt_libs + + # Or download the model library (only needed if we do not reuse the model lib): + cd dist/prebuilt_libs + wget url-to-my-model-lib + cd ../.. + + # Download the model weights + cd dist + git clone https://huggingface.co/my-huggingface-account/my-redpajama3b-weight-huggingface-repo RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC + cd .. + + # Run the model in Python; note that we reuse `-Chat` model library + python + >>> from mlc_chat import ChatModule + >>> cm = ChatModule(model="dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC", \ + model_lib_path="dist/prebuilt_libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so") # Adjust based on backend + >>> cm.generate("hi") + 'Hi! How can I assist you today?' diff --git a/docs/compilation/define_new_models.rst b/docs/compilation/define_new_models.rst new file mode 100644 index 0000000..4c73864 --- /dev/null +++ b/docs/compilation/define_new_models.rst @@ -0,0 +1,25 @@ +Define New Model Architectures +============================== + +This page guides you how to add a new model architecture in MLC. 
+
+This notebook (runnable in Colab) contains all the necessary information for adding a model in
+MLC LLM:
+https://github.com/mlc-ai/notebooks/blob/main/mlc-llm/tutorial_add_new_model_architecture_in_tvm_nn_module.ipynb
+
+In the notebook, we leverage ``tvm.nn.module`` to define a model in MLC LLM. We also use ``JIT``
+(just-in-time compilation) to debug the implementation.
+
+You can also refer to the PRs below for specific examples of adding a model architecture in MLC LLM:
+
+- `GPTNeoX PR `_
+- `GPT-2 PR `_
+- `Mistral PR `_
+
+.. note::
+
+    As mentioned in :ref:`Model Prebuilts`, when adding a model variant that has
+    its architecture already supported in mlc-llm, you **only need to convert weights**
+    (e.g. adding ``CodeLlama`` when MLC supports ``llama-2``; adding ``OpenHermes Mistral``
+    when MLC supports ``mistral``). On the other hand, a new model architecture
+    (or inference logic) requires more work (following the tutorial above).
\ No newline at end of file
diff --git a/docs/compilation/get-vicuna-weight.rst b/docs/compilation/get-vicuna-weight.rst
new file mode 100644
index 0000000..2ea4ba5
--- /dev/null
+++ b/docs/compilation/get-vicuna-weight.rst
@@ -0,0 +1,68 @@
+Getting Vicuna Weights
+======================
+
+.. contents:: Table of Contents
+   :local:
+   :depth: 2
+
+`Vicuna `_ is an open-source chatbot trained by fine-tuning `LLaMA `_ on `ShareGPT `_ data.
+
+Please note that the official Vicuna weights are delta weights applied to the LLaMA weights in order to comply with the LLaMA license. Users are responsible for applying these delta weights themselves.
+
+In this tutorial, we will show how to apply the delta weights to the LLaMA weights to get the Vicuna weights.
+
+Install FastChat
+----------------
+
+FastChat offers convenient utility functions for applying the delta to LLaMA weights. You can easily install it using pip.
+
+.. code-block:: bash
+
+   pip install fschat
+
+Download HuggingFace LLaMA Weights
+----------------------------------
+
+The HuggingFace LLaMA weights are hosted using Git-LFS. Therefore, it is necessary to install Git-LFS first (you can skip this step if git-lfs is already installed).
+
+.. code-block:: bash
+
+   conda install git-lfs
+   git lfs install
+
+Then download the weights (both the LLaMA weights and the Vicuna delta weights):
+
+.. code-block:: bash
+
+   git clone https://huggingface.co/decapoda-research/llama-7b-hf
+   git clone https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
+
+
+There is a name misalignment issue between the LLaMA weights and the Vicuna delta weights.
+Please follow these steps to modify the content of the "config.json" file:
+
+.. code-block:: bash
+
+   sed -i 's/LLaMAForCausalLM/LlamaForCausalLM/g' llama-7b-hf/config.json
+
+Then use ``fschat`` to apply the delta to the LLaMA weights:
+
+.. code-block:: bash
+
+   python3 -m fastchat.model.apply_delta \
+       --base-model-path llama-7b-hf \
+       --target-model-path vicuna-7b-v1.1 \
+       --delta-path vicuna-7b-delta-v1.1
+
+You will get the Vicuna weights in the ``vicuna-7b-v1.1`` folder, which can be used as input to MLC-LLM for model compilation.
+
+
+(Optional) Move Vicuna Weights to dist folder
+---------------------------------------------
+
+The default model path of MLC-LLM is the ``dist`` folder. Therefore, it is recommended to move the Vicuna weights into the ``dist`` folder.
+
+..
code-block:: bash + + mkdir -p dist/models + mv vicuna-7b-v1.1 dist/models/vicuna-7b-v1.1 diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..0f7ed19 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- +import os +import sys + +import tlcpack_sphinx_addon + +# -- General configuration ------------------------------------------------ + +sys.path.insert(0, os.path.abspath("../python")) +sys.path.insert(0, os.path.abspath("../")) +autodoc_mock_imports = ["torch"] +# do not load mlc-llm.so in docs +os.environ["SKIP_LOADING_MLCLLM_SO"] = "1" + +# General information about the project. +project = "mlc-llm" +author = "MLC LLM Contributors" +copyright = "2023, %s" % author + +# Version information. + +version = "0.1.0" +release = "0.1.0" + +extensions = [ + "sphinx_tabs.tabs", + "sphinx_toolbox.collapse", + "sphinxcontrib.httpdomain", + "sphinx.ext.autodoc", + "sphinx.ext.napoleon", + "sphinx_reredirects", +] + +redirects = {"get_started/try_out": "../index.html#getting-started"} + +source_suffix = [".rst"] + +language = "en" + +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = "sphinx" + +# A list of ignored prefixes for module index sorting. +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = False + +# -- Options for HTML output ---------------------------------------------- + +# The theme is set by the make target +import sphinx_rtd_theme + +html_theme = "sphinx_rtd_theme" +html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] + +templates_path = [] + +html_static_path = [] + +footer_copyright = "© 2023 MLC LLM" +footer_note = " " + +html_logo = "_static/img/mlc-logo-with-text-landscape.svg" + +html_theme_options = { + "logo_only": True, +} + +header_links = [ + ("Home", "https://llm.mlc.ai/"), + ("Github", "https://github.com/mlc-ai/mlc-llm"), + ("Discord Server", "https://discord.gg/9Xpy2HGBuD"), +] + +header_dropdown = { + "name": "Other Resources", + "items": [ + ("MLC Course", "https://mlc.ai/"), + ("MLC Blog", "https://blog.mlc.ai/"), + ("Web LLM", "https://webllm.mlc.ai/"), + ], +} + +html_context = { + "footer_copyright": footer_copyright, + "footer_note": footer_note, + "header_links": header_links, + "header_dropdown": header_dropdown, + "display_github": True, + "github_user": "mlc-ai", + "github_repo": "mlc-llm", + "github_version": "main/docs/", + "theme_vcs_pageview_mode": "edit", + # "header_logo": "/path/to/logo", + # "header_logo_link": "", + # "version_selecter": "", +} + + +# add additional overrides +templates_path += [tlcpack_sphinx_addon.get_templates_path()] +html_static_path += [tlcpack_sphinx_addon.get_static_path()] diff --git a/docs/deploy/android.rst b/docs/deploy/android.rst new file mode 100644 index 0000000..7bcda64 --- /dev/null +++ b/docs/deploy/android.rst @@ -0,0 +1,187 @@ +.. _deploy-android: + +Android App +=========== + +.. contents:: Table of Contents + :local: + :depth: 2 + +Demo App +-------- + +The demo APK below is built for Samsung S23 with Snapdragon 8 Gen 2 chip. + +.. image:: https://seeklogo.com/images/D/download-android-apk-badge-logo-D074C6882B-seeklogo.com.png + :width: 135 + :target: https://github.com/mlc-ai/binary-mlc-llm-libs/releases/download/Android/mlc-chat.apk + +Prerequisite +------------ + +**Rust** (`install `__) is needed to cross-compile HuggingFace tokenizers to Android. Make sure rustc, cargo, and rustup are available in ``$PATH``. 
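+
+As a quick sanity check (a minimal sketch; the exact version numbers printed do not matter), each of the
+commands below should print a version string if the Rust toolchain was installed correctly:
+
+.. code-block:: bash
+
+    # Each command should report a version; a "command not found" error means
+    # the Rust toolchain is not installed or not on $PATH yet.
+    rustc --version
+    cargo --version
+    rustup --version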
+ +**Android Studio** (`install `__) with NDK and CMake. To install NDK and CMake, in the Android Studio welcome page, click "Projects → SDK Manager → SDK Tools". Set up the following environment variables: + +- ``ANDROID_NDK`` so that ``$ANDROID_NDK/build/cmake/android.toolchain.cmake`` is available. +- ``TVM_NDK_CC`` that points to NDK's clang compiler. + +.. code-block:: bash + + # Example on macOS + ANDROID_NDK: $HOME/Library/Android/sdk/ndk/25.2.9519653 + TVM_NDK_CC: $ANDROID_NDK/toolchains/llvm/prebuilt/darwin-x86_64/bin/aarch64-linux-android24-clang + # Example on Windows + ANDROID_NDK: $HOME/Library/Android/sdk/ndk/25.2.9519653 + TVM_NDK_CC: $ANDROID_NDK/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android24-clang + +**JDK**, such as OpenJDK >= 17, to compile Java bindings of TVM Unity runtime. It could be installed via Homebrew on macOS, apt on Ubuntu or other package managers. Set up the following environment variable: + +- ``JAVA_HOME`` so that Java is available in ``$JAVA_HOME/bin/java``. + +Please ensure that the JDK versions for Android Studio and JAVA_HOME are the same. We recommended setting the `JAVA_HOME` to the JDK bundled with Android Studio. e.g. `export JAVA_HOME=/Applications/Android\ Studio.app/Contents/jbr/Contents/Home` for macOS. + +**TVM Unity runtime** is placed under `3rdparty/tvm `__ in MLC LLM, so there is no need to install anything extra. Set up the following environment variable: + +- ``TVM_HOME`` so that its headers are available under ``$TVM_HOME/include/tvm/runtime``. + +(Optional) **TVM Unity compiler** Python package (:ref:`install ` or :ref:`build from source `). It is *NOT* required if models are prebuilt, but to compile PyTorch models from HuggingFace in the following section, the compiler is a must-dependency. + +.. note:: + ❗ Whenever using Python, it is highly recommended to use **conda** to manage an isolated Python environment to avoid missing dependencies, incompatible versions, and package conflicts. + +Check if **environment variable** are properly set as the last check. One way to ensure this is to place them in ``$HOME/.zshrc``, ``$HOME/.bashrc`` or environment management tools. + +.. code-block:: bash + + source $HOME/.cargo/env # Rust + export ANDROID_NDK=... # Android NDK toolchain + export TVM_NDK_CC=... # Android NDK clang + export JAVA_HOME=... # Java + export TVM_HOME=... # TVM Unity runtime + +Compile PyTorch Models from HuggingFace +--------------------------------------- + +To deploy models on Android with reasonable performance, one has to cross-compile to and fully utilize mobile GPUs using TVM Unity. MLC provides a few pre-compiled models, or one could compile the models on their own. + +**Cloning MLC LLM from GitHub**. Download MLC LLM via the following command: + +.. code-block:: bash + + git clone --recursive https://github.com/mlc-ai/mlc-llm/ + ^^^^^^^^^^^ + cd ./mlc-llm/ + +.. note:: + ❗ The ``--recursive`` flag is necessary to download submodules like `3rdparty/tvm `__. If you see any file missing during compilation, please double check if git submodules are properly cloned. + +**Download the PyTorch model** using Git Large File Storage (LFS), and by default, under ``./dist/models/``: + +.. code-block:: bash + + MODEL_NAME=Llama-2-7b-chat-hf + QUANTIZATION=q4f16_1 + + git lfs install + git clone https://huggingface.co/meta-llama/$MODEL_NAME \ + ./dist/models/ + +**Compile Android-capable models**. 
Install TVM Unity compiler as a Python package, and then compile the model for android using the following commands: + +.. code-block:: bash + + # convert weights + mlc_chat convert_weight ./dist/models/$MODEL_NAME/ --quantization $QUANTIZATION -o dist/$MODEL_NAME-$QUANTIZATION-MLC/ + + # create mlc-chat-config.json + mlc_chat gen_config ./dist/models/$MODEL_NAME/ --quantization $QUANTIZATION \ + --conv-template llama-2 --context-window-size 768 -o dist/${MODEL_NAME}-${QUANTIZATION}-MLC/ + + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/${MODEL_NAME}-${QUANTIZATION}-MLC/mlc-chat-config.json \ + --device android -o ./dist/${MODEL_NAME}-${QUANTIZATION}-MLC/${MODEL_NAME}-${QUANTIZATION}-android.tar + +This generates the directory ``./dist/$MODEL_NAME-$QUANTIZATION-MLC`` which contains the necessary components to run the model, as explained below. + +.. note:: + ❗ To run 7B models like llama-2-7B, Mistral-7B, it is recommended to use smaller values of parameter ``--context-window-size`` (``--sliding-window-size`` and ``--prefill-chunk-size`` for sliding window attention) to reduce the memory footprint of the model. Default configurations for certains models can be found under the Android tab in the `Compile Models `_ section. + +**Expected output format**. By default models are placed under ``./dist/${MODEL_NAME}-${QUANTIZATION}-MLC``, and the result consists of 3 major components: + +- Runtime configuration: It configures conversation templates including system prompts, repetition penalty, sampling including temperature and top-p probability, maximum sequence length, etc. It is usually named as ``mlc-chat-config.json`` alongside with tokenizer configurations. +- Model lib: The compiled library that uses mobile GPU. It is usually named as ``${MODEL_NAME}-${QUANTIZATION}-android.tar``, for example, ``Llama-2-7b-chat-hf-q4f16_1-android.tar``. +- Model weights: the model weights are sharded as ``params_shard_*.bin`` and the metadata is stored in ``ndarray-cache.json`` + +Create Android Project using Compiled Models +-------------------------------------------- + +The source code for MLC LLM is available under ``android/``, including scripts to build dependencies. Enter the directory first: + +.. code-block:: bash + + cd ./android/library + +**Build necessary dependencies.** Configure the list of models the app comes with using the JSON file ``app-config.json`` which contains two properties `model_list` and `model_lib_path_for_prepare_libs` ``model_lib_path_for_prepare_libs`` contains list of model library paths under `./dist/` that will be bundled with the apk. The ``model_list`` property contains data for models that are not bundled with the apk, but downloaded from the internet at run-time. Each model defined in `model_list` contain the following fields: + +``model_url`` + (Required) URL to the repo containing the weights. + +``model_id`` + (Required) Unique local identifier to identify the model. + +``model_lib`` + (Required) Matches the system-lib-prefix, generally set during ``mlc_chat compile`` which can be specified using + ``--system-lib-prefix`` argument. By default, it is set to ``"${model_type}_${quantization}"`` e.g. ``gpt_neox_q4f16_1`` for the RedPajama-INCITE-Chat-3B-v1 model. If the ``--system-lib-prefix`` argument is manually specified during ``mlc_chat compile``, the ``model_lib`` field should be updated accordingly. + +``estimated_vram_bytes`` + (Optional) Estimated requirements of VRAM to run the model. 
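+
+Putting the fields above together, a single ``model_list`` entry might look like the sketch below.
+The repository URL, model id, and VRAM estimate are illustrative placeholders rather than values
+shipped with the app; ``gpt_neox_q4f16_1`` is the default library prefix mentioned above for
+RedPajama-INCITE-Chat-3B-v1.
+
+.. code-block:: json
+
+    {
+      "model_list": [
+        {
+          "model_url": "https://huggingface.co/my-hf-account/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC",
+          "model_id": "RedPajama-INCITE-Chat-3B-v1-q4f16_1",
+          "model_lib": "gpt_neox_q4f16_1",
+          "estimated_vram_bytes": 2960000000
+        }
+      ]
+    }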
+ +To change the configuration, edit ``app-config.json``: + +.. code-block:: bash + + vim ./src/main/assets/app-config.json + +Then bundle the android library ``${MODEL_NAME}-${QUANTIZATION}-android.tar`` compiled from ``mlc_chat compile`` in the previous steps, with TVM Unity's Java runtime by running the commands below: + +.. code-block:: bash + + ./prepare_libs.sh + +which generates the two files below: + +.. code-block:: bash + + >>> find ./build/output -type f + ./build/output/arm64-v8a/libtvm4j_runtime_packed.so + ./build/output/tvm4j_core.jar + +The model execution logic in mobile GPUs is incorporated into ``libtvm4j_runtime_packed.so``, while ``tvm4j_core.jar`` is a lightweight (~60 kb) `Java binding `_ to it. + +**Build the Android app**. Open folder ``./android`` as an Android Studio Project. Connect your Android device to your machine. In the menu bar of Android Studio, click "Build → Make Project". Once the build is finished, click "Run → Run 'app'" and you will see the app launched on your phone. + +.. note:: + ❗ This app cannot be run in an emulator and thus a physical phone is required, because MLC LLM needs an actual mobile GPU to meaningfully run at an accelerated speed. + +Incorporate Model Weights +------------------------- + +Instructions have been provided to build an Android App with MLC LLM in previous sections, but it requires run-time weight downloading from HuggingFace, as configured in `app-config.json` in previous steps under `model_url`. However, it could be desirable to bundle weights together into the app to avoid downloading over the network. In this section, we provide a simple ADB-based walkthrough that hopefully helps with further development. + +**Generating APK**. Enter Android Studio, and click "Build → Generate Signed Bundle/APK" to build an APK for release. If it is the first time you generate an APK, you will need to create a key according to `the official guide from Android `_. This APK will be placed under ``android/app/release/app-release.apk``. + +**Install ADB and USB debugging**. Enable "USB debugging" in the developer mode in your phone settings. In SDK manager, install `Android SDK Platform-Tools `_. Add the path to platform-tool path to the environment variable ``PATH``. Run the following commands, and if ADB is installed correctly, your phone will appear as a device: + +.. code-block:: bash + + adb devices + +**Install the APK and weights to your phone**. Run the commands below replacing ``${MODEL_NAME}`` and ``${QUANTIZATION}`` with the actual model name (e.g. Llama-2-7b-chat-hf) and quantization format (e.g. q4f16_1). + +.. code-block:: bash + + adb install android/app/release/app-release.apk + adb push dist/${MODEL_NAME}-${QUANTIZATION}-MLC /data/local/tmp/${MODEL_NAME}-${QUANTIZATION}/ + adb shell "mkdir -p /storage/emulated/0/Android/data/ai.mlc.mlcchat/files/" + adb shell "mv /data/local/tmp/${MODEL_NAME}-${QUANTIZATION} /storage/emulated/0/Android/data/ai.mlc.mlcchat/files/" diff --git a/docs/deploy/cli.rst b/docs/deploy/cli.rst new file mode 100644 index 0000000..83a2a9d --- /dev/null +++ b/docs/deploy/cli.rst @@ -0,0 +1,106 @@ +.. _deploy-cli: + +CLI +=============== + +MLCChat CLI is the command line tool to run MLC-compiled LLMs out of the box. + +.. contents:: Table of Contents + :local: + :depth: 2 + +Option 1. Conda Prebuilt +~~~~~~~~~~~~~~~~~~~~~~~~ + +The prebuilt package supports Metal on macOS and Vulkan on Linux and Windows, and can be installed via Conda one-liner. + +To use other GPU runtimes, e.g. 
CUDA, please instead :ref:`build it from source `. + +.. code:: shell + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly mlc-ai-nightly + mlc_chat chat -h + +.. note:: + The prebuilt package supports **Metal** on macOS and **Vulkan** on Linux and Windows. It is possible to use other GPU runtimes such as **CUDA** by compiling MLCChat CLI from the source. + + +Option 2. Build MLC Runtime from Source +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +We also provide options to build mlc runtime libraries and ``mlc_chat`` from source. +This step is useful if the prebuilt is unavailable on your platform, or if you would like to build a runtime +that supports other GPU runtime than the prebuilt version. We can build a customized version +of mlc chat runtime. You only need to do this if you choose not to use the prebuilt. + +First, make sure you install TVM unity (following the instruction in :ref:`install-tvm-unity`). +Then please follow the instructions in :ref:`mlcchat_build_from_source` to build the necessary libraries. + +.. `|` adds a blank line + +| + +Run Models through MLCChat CLI +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Once ``mlc_chat`` is installed, you are able to run any MLC-compiled model on the command line. + +To run a model with MLC LLM in any platform, you can either: + +- Use off-the-shelf model prebuilts from the MLC Huggingface repo (see :ref:`Model Prebuilts` for details). +- Use locally compiled model weights and libraries following :doc:`the model compilation page `. + +**Option 1: Use model prebuilts** + +To run ``mlc_chat``, you can specify the Huggingface MLC prebuilt model repo path with the prefix ``HF://``. +For example, to run the MLC Llama 2 7B Q4F16_1 model (`Repo link `_), +simply use ``HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC``. The model weights and library will be downloaded +automatically from Huggingface. + +.. code:: shell + + mlc_chat chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC --device "cuda:0" --overrides context_window_size=1024 + +.. code:: shell + + You can use the following special commands: + /help print the special commands + /exit quit the cli + /stats print out the latest stats (token/sec) + /reset restart a fresh chat + /set [overrides] override settings in the generation config. For example, + `/set temperature=0.5;max_gen_len=100;stop=end,stop` + Note: Separate stop words in the `stop` option with commas (,). + Multi-line input: Use escape+enter to start a new line. + + [INST]: What's the meaning of life + [/INST]: + Ah, a question that has puzzled philosophers and theologians for centuries! The meaning + of life is a deeply personal and subjective topic, and there are many different + perspectives on what it might be. However, here are some possible answers that have been + proposed by various thinkers and cultures: + ... + + +**Option 2: Use locally compiled model weights and libraries** + +For models other than the prebuilt ones we provided: + +1. If the model is a variant to an existing model library (e.g. ``WizardMathV1.1`` and ``OpenHermes`` are variants of ``Mistral``), + follow :ref:`convert-weights-via-MLC` to convert the weights and reuse existing model libraries. +2. Otherwise, follow :ref:`compile-model-libraries` to compile both the model library and weights. 
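+
+As a concrete illustration of case 1, the sketch below converts the weights of ``WizardMath-7B-V1.1`` (a ``Mistral`` variant) so that an existing ``Mistral`` model library can be reused at chat time. The repository name and conversation template mirror the WizardMath example that appears later in this document; treat the paths as placeholders for your own variant.
+
+.. code:: shell
+
+   # Clone the original Hugging Face weights
+   mkdir -p dist/models && cd dist/models
+   git lfs install
+   git clone https://huggingface.co/WizardLM/WizardMath-7B-V1.1
+   cd ../..
+   # Convert and quantize the weights
+   mlc_chat convert_weight ./dist/models/WizardMath-7B-V1.1/ \
+       --quantization q4f16_1 \
+       -o dist/WizardMath-7B-V1.1-q4f16_1-MLC
+   # Generate mlc-chat-config.json and process tokenizers
+   mlc_chat gen_config ./dist/models/WizardMath-7B-V1.1/ \
+       --quantization q4f16_1 --conv-template wizard_coder_or_math \
+       -o dist/WizardMath-7B-V1.1-q4f16_1-MLC/
+
+The resulting ``dist/WizardMath-7B-V1.1-q4f16_1-MLC`` directory can then be passed to ``mlc_chat chat`` together with a prebuilt ``Mistral`` model library via ``--model-lib-path``, as described below.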
+ +Once you have the model locally compiled with a model library and model weights, to run ``mlc_chat``, simply + +- Specify the path to ``mlc-chat-config.json`` and the converted model weights to ``--model`` +- Specify the path to the compiled model library (e.g. a .so file) to ``--model-lib-path`` + +.. code:: shell + + mlc_chat chat dist/Llama-2-7b-chat-hf-q4f16_1-MLC \ + --device "cuda:0" --overrides context_window_size=1024 \ + --model-lib-path dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-vulkan.so + # CUDA on Linux: dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so + # Metal on macOS: dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-metal.so + # Same rule applies for other platforms diff --git a/docs/deploy/ios.rst b/docs/deploy/ios.rst new file mode 100644 index 0000000..0d3b4f6 --- /dev/null +++ b/docs/deploy/ios.rst @@ -0,0 +1,491 @@ +.. _deploy-ios: + +iOS App and Swift API +===================== + +.. contents:: Table of Contents + :local: + :depth: 2 + +The MLC LLM iOS app can be installed in two ways: through the pre-built package or by building from the source. +If you are an iOS user looking to try out the models, the pre-built package is recommended. If you are a +developer seeking to integrate new features into the package, building the iOS package from the source is required. + +Use Pre-built iOS App +--------------------- +The MLC Chat app is now available in App Store at no cost. You can download and explore it by simply clicking the button below: + + .. image:: https://developer.apple.com/assets/elements/badges/download-on-the-app-store.svg + :width: 135 + :target: https://apps.apple.com/us/app/mlc-chat/id6448482937 + + +Build iOS App from Source +------------------------- + +This section shows how we can build the app from the source. + +Step 1. Install Build Dependencies +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +First and foremost, please clone the `MLC LLM GitHub repository `_. + +Please follow :doc:`/install/tvm` to install TVM Unity. +Note that we **do not** have to run `build.py` since we can use prebuilt weights. +We only need TVM Unity's utility to combine the libraries (`local-id-iphone.tar`) into a single library. + +We also need to have the following build dependencies: + +* CMake >= 3.24, +* Git and Git-LFS, +* `Rust and Cargo `_, which are required by Hugging Face's tokenizer. + + +Step 2. Download Prebuilt Weights and Library +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +You also need to obtain a copy of the MLC-LLM source code +by cloning the `MLC LLM GitHub repository `_. +To simplify the build, we will use prebuilt model +weights and libraries here. Run the following command +in the root directory of the MLC-LLM. + +.. code:: bash + + mkdir -p dist/prebuilt + git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt/lib + + cd dist/prebuilt + git lfs install + git clone https://huggingface.co/mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC + cd ../.. + +Validate that the files and directories exist: + +.. code:: bash + + >>> ls -l ./dist/prebuilt/lib/*/*-iphone.tar + ./dist/prebuilt/lib/RedPajama-INCITE-Chat-3B-v1/RedPajama-INCITE-Chat-3B-v1-q4f16_1-iphone.tar + ./dist/prebuilt/lib/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q3f16_1-iphone.tar + ... + + >>> ls -l ./dist/prebuilt/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC + # chat config: + mlc-chat-config.json + # model weights: + ndarray-cache.json + params_shard_*.bin + ... + + +Step 3. 
Build Auxiliary Components +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +**Tokenizer and runtime** + +In addition to the model itself, a lightweight runtime and tokenizer are +required to actually run the LLM. You can build and organize these +components by following these steps: + +.. code:: bash + + git submodule update --init --recursive + cd ./ios + ./prepare_libs.sh + +This will create a ``./build`` folder that contains the following files. +Please make sure all the following files exist in ``./build/``. + +.. code:: bash + + >>> ls ./build/lib/ + libmlc_llm.a # A lightweight interface to interact with LLM, tokenizer, and TVM Unity runtime + libmodel_iphone.a # The compiled model lib + libsentencepiece.a # SentencePiece tokenizer + libtokenizers_cpp.a # Huggingface tokenizer + libtvm_runtime.a # TVM Unity runtime + +**Add prepackage model** + +We can also *optionally* add prepackage weights into the app, +run the following command under the ``./ios`` directory: + +.. code:: bash + + cd ./ios + open ./prepare_params.sh # make sure builtin_list only contains "RedPajama-INCITE-Chat-3B-v1-q4f16_1" + ./prepare_params.sh + +The outcome should be as follows: + +.. code:: bash + + >>> ls ./dist/ + RedPajama-INCITE-Chat-3B-v1-q4f16_1 + +Step 4. Build iOS App +^^^^^^^^^^^^^^^^^^^^^ + +Open ``./ios/MLCChat.xcodeproj`` using Xcode. Note that you will need an +Apple Developer Account to use Xcode, and you may be prompted to use +your own developer team credential and product bundle identifier. + +Ensure that all the necessary dependencies and configurations are +correctly set up in the Xcode project. + +Once you have made the necessary changes, build the iOS app using Xcode. +If you have an Apple Silicon Mac, you can select target "My Mac (designed for iPad)" +to run on your Mac. You can also directly run it on your iPad or iPhone. + +.. image:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/xcode-build.jpg + :align: center + :width: 60% + +| + +Customize the App +----------------- + +We can customize the iOS app in several ways. +`MLCChat/app-config.json `_ +controls the list of local and remote models to be packaged into the app, given a local path or a URL respectively. Only models in ``model_list`` will have their libraries brought into the app when running `./prepare_libs` to package them into ``libmodel_iphone.a``. Each model defined in `app-config.json` contain the following fields: + +``model_path`` + (Required if local model) Name of the local folder containing the weights. + +``model_url`` + (Required if remote model) URL to the repo containing the weights. + +``model_id`` + (Required) Unique local identifier to identify the model. + +``model_lib`` + (Required) Matches the system-lib-prefix, generally set during ``mlc_chat compile`` which can be specified using + ``--system-lib-prefix`` argument. By default, it is set to ``"${model_type}_${quantization}"`` e.g. ``gpt_neox_q4f16_1`` + for the RedPajama-INCITE-Chat-3B-v1 model. If the ``--system-lib-prefix`` argument is manually specified during + ``mlc_chat compile``, the ``model_lib`` field should be updated accordingly. + +``required_vram_bytes`` + (Required) Estimated requirements of VRAM to run the model. + +``model_lib_path_for_prepare_libs`` + (Required) List of paths to the model libraries in the app (respective ``.tar`` file in the ``binary-mlc-llm-libs`` + repo, relative path in the ``dist`` artifact folder or full path to the library). 
Only used while running + ``prepare_libs.sh`` to determine which model library to use during runtime. Useful when selecting a library with + different settings (e.g. ``prefill_chunk_size``, ``context_window_size``, and ``sliding_window_size``). + +Additionally, the app prepackages the models under ``./ios/dist``. +This built-in list can be controlled by editing ``prepare_params.sh``. +You can package new prebuilt models or compiled models by changing the above fields and then repeating the steps above. + + +Bring Your Own Model Variant +---------------------------- + +In cases where the model you are adding is simply a variant of an existing +model, we only need to convert weights and reuse existing model library. For instance: + +- Adding ``NeuralHermes`` when MLC already supports the ``Mistral`` architecture + + +In this section, we walk you through adding ``NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC`` to the MLC iOS app. +According to the model's ``config.json`` on `its Huggingface repo `_, +it reuses the Mistral model architecture. + +.. note:: + + This section largely replicates :ref:`convert-weights-via-MLC`. + See that page for more details. Note that the weights are shared across + all platforms in MLC. + +**Step 1 Clone from HF and convert_weight** + +You can be under the mlc-llm repo, or your own working directory. Note that all platforms +can share the same compiled/quantized weights. See :ref:`compile-command-specification` +for specification of ``convert_weight``. + +.. code:: shell + + # Create directory + mkdir -p dist/models && cd dist/models + # Clone HF weights + git lfs install + git clone https://huggingface.co/mlabonne/NeuralHermes-2.5-Mistral-7B + cd ../.. + # Convert weight + mlc_chat convert_weight ./dist/models/NeuralHermes-2.5-Mistral-7B/ \ + --quantization q4f16_1 \ + -o dist/NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC + +**Step 2 Generate MLC Chat Config** + +Use ``mlc_chat gen_config`` to generate ``mlc-chat-config.json`` and process tokenizers. +See :ref:`compile-command-specification` for specification of ``gen_config``. + +.. code:: shell + + mlc_chat gen_config ./dist/models/NeuralHermes-2.5-Mistral-7B/ \ + --quantization q3f16_1 --conv-template neural_hermes_mistral \ + -o dist/NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC + +For the ``conv-template``, `conv_template.cc `__ +contains a full list of conversation templates that MLC provides. + +If the model you are adding requires a new conversation template, you would need to add your own. +Follow `this PR `__ as an example. +We look up the template to use with the ``conv_template`` field in ``mlc-chat-config.json``. + +For more details, please see :ref:`configure-mlc-chat-json`. + +**Step 3 Upload weights to HF** + +.. code:: shell + + # First, please create a repository on Hugging Face. + # With the repository created, run + git lfs install + git clone https://huggingface.co/my-huggingface-account/my-mistral-weight-huggingface-repo + cd my-mistral-weight-huggingface-repo + cp path/to/mlc-llm/dist/NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC/* . + git add . && git commit -m "Add mistral model weights" + git push origin main + +After successfully following all steps, you should end up with a Huggingface repo similar to +`NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC `__, +which includes the converted/quantized weights, the ``mlc-chat-config.json``, and tokenizer files. + + +**Step 4 Register as a ModelRecord** + +Finally, we modify the code snippet for +`app-config.json `__ +pasted above. 
+ +We simply specify the Huggingface link as ``model_url``, while reusing the ``model_lib`` for +``Mistral-7B``. + +.. code:: javascript + + "model_list": [ + // Other records here omitted... + { + // Substitute model_url with the one you created `my-huggingface-account/my-mistral-weight-huggingface-repo` + "model_url": "https://huggingface.co/mlc-ai/NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC", + "model_id": "Mistral-7B-Instruct-v0.2-q3f16_1", + "model_lib": "mistral_q3f16_1", + "model_lib_path": "lib/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q3f16_1-iphone.tar", + "estimated_vram_bytes": 3316000000 + } + ] + + +Now, the app will use the ``NeuralHermes-Mistral`` model you just added. + + +Bring Your Own Model Library +---------------------------- + +A model library is specified by: + + - The model architecture (e.g. ``mistral``, ``phi-msft``) + - Quantization Scheme (e.g. ``q3f16_1``, ``q0f32``) + - Metadata (e.g. ``context_window_size``, ``sliding_window_size``, ``prefill_chunk_size``), which affects memory planning + - Platform (e.g. ``cuda``, ``webgpu``, ``iphone``, ``android``) + +In cases where the model you want to run is not compatible with the provided MLC +prebuilt model libraries (e.g. having a different quantization, a different +metadata spec, or even a different model architecture), you need to build your +own model library. + +In this section, we walk you through adding ``phi-2`` to the iOS app. + +This section largely replicates :ref:`compile-model-libraries`. See that page for +more details, specifically the ``iOS`` option. + +**Step 0. Install dependencies** + +To compile model libraries for iOS, you need to :ref:`build mlc_chat from source `. + +**Step 1. Clone from HF and convert_weight** + +You can be under the mlc-llm repo, or your own working directory. Note that all platforms +can share the same compiled/quantized weights. + +.. code:: shell + + # Create directory + mkdir -p dist/models && cd dist/models + # Clone HF weights + git lfs install + git clone https://huggingface.co/microsoft/phi-2 + cd ../.. + # Convert weight + mlc_chat convert_weight ./dist/models/phi-2/ \ + --quantization q4f16_1 \ + -o dist/phi-2-q4f16_1-MLC + +**Step 2. Generate mlc-chat-config and compile** + +A model library is specified by: + + - The model architecture (e.g. ``mistral``, ``phi-msft``) + - Quantization Scheme (e.g. ``q3f16_1``, ``q0f32``) + - Metadata (e.g. ``context_window_size``, ``sliding_window_size``, ``prefill_chunk_size``), which affects memory planning + - Platform (e.g. ``cuda``, ``webgpu``, ``iphone``, ``android``) + +All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_config``. + +.. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/phi-2/ \ + --quantization q4f16_1 --conv-template phi-2 \ + -o dist/phi-2-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/phi-2-q4f16_1-MLC/mlc-chat-config.json \ + --device iphone -o dist/libs/phi-2-q4f16_1-iphone.tar + +.. note:: + When compiling larger models like ``Llama-2-7B``, you may want to add a lower chunk size + while prefilling prompts ``--prefill_chunk_size 128`` or even lower ``context_window_size``\ + to decrease memory usage. Otherwise, during runtime, you may run out of memory. + + +**Step 3. Distribute model library and model weights** + +After following the steps above, you should end up with: + +.. 
code:: shell + + ~/mlc-llm > ls dist/libs + phi-2-q4f16_1-iphone.tar # ===> the model library + + ~/mlc-llm > ls dist/phi-2-q4f16_1-MLC + mlc-chat-config.json # ===> the chat config + ndarray-cache.json # ===> the model weight info + params_shard_0.bin # ===> the model weights + params_shard_1.bin + ... + tokenizer.json # ===> the tokenizer files + tokenizer_config.json + +Upload the ``phi-2-q4f16_1-iphone.tar`` to a github repository (for us, +it is in `binary-mlc-llm-libs `__). Then +upload the weights ``phi-2-q4f16_1-MLC`` to a Huggingface repo: + +.. code:: shell + + # First, please create a repository on Hugging Face. + # With the repository created, run + git lfs install + git clone https://huggingface.co/my-huggingface-account/my-phi-weight-huggingface-repo + cd my-phi-weight-huggingface-repo + cp path/to/mlc-llm/dist/phi-2-q4f16_1-MLC/* . + git add . && git commit -m "Add phi-2 model weights" + git push origin main + +This would result in something like `phi-2-q4f16_1-MLC +`_. + + +**Step 4. Calculate estimated VRAM usage** + +Given the compiled library, it is possible to calculate an upper bound for the VRAM +usage during runtime. This useful to better understand if a model is able to fit particular +hardware. We can calculate this estimate using the following command: + +.. code:: shell + + ~/mlc-llm > python -m mlc_chat.cli.model_metadata ./dist/libs/phi-2-q4f16_1-iphone.tar \ + > --memory-only --mlc-chat-config ./dist/phi-2-q4f16_1-MLC/mlc-chat-config.json + INFO model_metadata.py:90: Total memory usage: 3042.96 MB (Parameters: 1492.45 MB. KVCache: 640.00 MB. Temporary buffer: 910.51 MB) + INFO model_metadata.py:99: To reduce memory usage, tweak `prefill_chunk_size`, `context_window_size` and `sliding_window_size` + + +**Step 5. Register as a ModelRecord** + +Finally, we update the code snippet for +`app-config.json `__ +pasted above. + +We simply specify the Huggingface link as ``model_url``, while using the new ``model_lib`` for +``phi-2``. Regarding the field ``estimated_vram_bytes``, we can use the output of the last step +rounded up to MB. + +.. code:: javascript + + "model_list": [ + // Other records here omitted... + { + // Substitute model_url with the one you created `my-huggingface-account/my-phi-weight-huggingface-repo` + "model_url": "https://huggingface.co/mlc-ai/phi-2-q4f16_1-MLC", + "model_id": "phi-2-q4f16_1", + "model_lib": "phi_msft_q4f16_1", + "model_lib_path": "lib/phi-2/phi-2-q4f16_1-iphone.tar", + "estimated_vram_bytes": 3043000000 + } + ] + + +Now, the app will use the ``phi-2`` model library you just added. + + +Build Apps with MLC Swift API +----------------------------- + +We also provide a Swift package that you can use to build +your own app. The package is located under `ios/MLCSwift`. + +- First make sure you have run the same steps listed + in the previous section. This will give us the necessary libraries + under ``/path/to/ios/build/lib``. +- Then you can add ``ios/MLCSwift`` package to your app in Xcode. + Under "Frameworks, Libraries, and Embedded Content", click add package dependencies + and add local package that points to ``ios/MLCSwift``. +- Finally, we need to add the libraries dependencies. Under build settings: + + - Add library search path ``/path/to/ios/build/lib``. + - Add the following items to "other linker flags". + + .. code:: + + -Wl,-all_load + -lmodel_iphone + -lmlc_llm -ltvm_runtime + -ltokenizers_cpp + -lsentencepiece + -ltokenizers_c + + +You can then import the `MLCSwift` package into your app. 
+The following code shows an illustrative example of how to use the chat module. + +.. code:: swift + + import MLCSwift + + let threadWorker = ThreadWorker() + let chat = ChatModule() + + threadWorker.push { + let modelLib = "model-lib-name" + let modelPath = "/path/to/model/weights" + let input = "What is the capital of Canada?" + chat.reload(modelLib, modelPath: modelPath) + + chat.prefill(input) + while (!chat.stopped()) { + displayReply(chat.getMessage()) + chat.decode() + } + } + +.. note:: + + Because the chat module makes heavy use of GPU and thread-local + resources, it needs to run on a dedicated background thread. + Therefore, **avoid using** `DispatchQueue`, which can cause context switching to + different threads and segfaults due to thread-safety issues. + Use the `ThreadWorker` class to launch all the jobs related + to the chat module. You can check out the source code of + the MLCChat app for a complete example. diff --git a/docs/deploy/javascript.rst b/docs/deploy/javascript.rst new file mode 100644 index 0000000..06a1d3f --- /dev/null +++ b/docs/deploy/javascript.rst @@ -0,0 +1,360 @@ +.. _webllm-runtime: + +WebLLM and Javascript API +========================= + +.. contents:: Table of Contents + :local: + :depth: 2 + +`WebLLM `_ is an MLC chat web runtime +that allows you to build chat applications directly in the browser, leveraging +`WebGPU `_ and providing users a natural layer of abstraction. + +Try out the Prebuilt Webpage +---------------------------- + +To get started, you can try out `WebLLM prebuilt webpage `__. + +A WebGPU-compatible browser and a local GPU are needed to run WebLLM. +You can download the latest Google Chrome and use `WebGPU Report `__ +to verify the functionality of WebGPU on your browser. + + +Use WebLLM NPM Package +---------------------- + +WebLLM is available as an `npm package `_. +The source code is available in `the WebLLM repo `_, +where you can make your own modifications and build from source. + +Note that the `WebLLM prebuilt webpage `__ above +is powered by the WebLLM npm package, specifically with the code in +the `simple-chat `__ example. + +Each of the model in the `WebLLM prebuilt webpage `__ +is registered as an instance of ``ModelRecord``. Looking at the most straightforward example +`get-started `__, +we see the code snippet: + +.. code:: typescript + + const myAppConfig: AppConfig = { + model_list: [ + { + "model_url": "https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f32_1-MLC/resolve/main/", + "local_id": "Llama-2-7b-chat-hf-q4f32_1", + "model_lib_url": "https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f32_1-ctx4k_cs1k-webgpu.wasm", + }, + { + "model_url": "https://huggingface.co/mlc-ai/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/resolve/main/", + "local_id": "Mistral-7B-Instruct-v0.2-q4f16_1", + "model_lib_url": "https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-sw4k_cs1k-webgpu.wasm", + "required_features": ["shader-f16"], + }, + // Add your own models here... + ] + } + const selectedModel = "Llama-2-7b-chat-hf-q4f32_1" + // const selectedModel = "Mistral-7B-Instruct-v0.1-q4f16_1" + await chat.reload(selectedModel, undefined, myAppConfig); + +Just like any other platforms, to run a model with on WebLLM, you need: + +1. **Model weights** converted to MLC format (e.g. `Llama-2-7b-hf-q4f32_1-MLC + `_.): downloaded through ``model_url`` +2. 
**Model library** that comprises the inference logic (see repo `binary-mlc-llm-libs `__): downloaded through ``model_lib_url``. + +Verify Installation for Adding Models +------------------------------------- + +In sections below, we walk you through two examples of adding models to WebLLM. Before proceeding, +please verify installation of ``mlc_chat`` and ``tvm``: + +**Step 1. Verify mlc_chat** + +We use the python package ``mlc_chat`` to compile models. This can be installed by +following :ref:`install-mlc-packages`, either by building from source, or by +installing the prebuilt package. Verify ``mlc_chat`` installation in command line via: + +.. code:: bash + + $ mlc_chat --help + # You should see help information with this line + usage: MLC LLM Command Line Interface. [-h] {compile,convert_weight,gen_config} + +.. note:: + If it runs into error ``command not found: mlc_chat``, try ``python -m mlc_chat --help``. + +**Step 2. Verify TVM** + +To compile models, you also need to follow :ref:`install-tvm-unity`. +Here we verify ``tvm`` quickly with command line (for full verification, see :ref:`tvm-unity-validate`): + +.. code:: bash + + $ python -c "import tvm; print(tvm.__file__)" + /some-path/lib/python3.11/site-packages/tvm/__init__.py + + +.. _webllm-add-model-variant: + +Bring Your Own Model Variant +---------------------------- + +In cases where the model you are adding is simply a variant of an existing +model, we only need to convert weights and reuse existing model library. For instance: + +- Adding ``OpenMistral`` when MLC supports ``Mistral`` +- Adding ``Llama2-uncensored`` when MLC supports ``Llama2`` + + +In this section, we walk you through adding ``WizardMath-7B-V1.1-q4f16_1`` to the +`get-started `__ example. +According to the model's ``config.json`` on `its Huggingface repo `_, +it reuses the Mistral model architecture. + +.. note:: + + This section largely replicates :ref:`convert-weights-via-MLC`. + See that page for more details. Note that the weights are shared across + all platforms in MLC. + +**Step 1 Clone from HF and convert_weight** + +You can be under the mlc-llm repo, or your own working directory. Note that all platforms +can share the same compiled/quantized weights. See :ref:`compile-command-specification` +for specification of ``convert_weight``. + +.. code:: shell + + # Create directory + mkdir -p dist/models && cd dist/models + # Clone HF weights + git lfs install + git clone https://huggingface.co/WizardLM/WizardMath-7B-V1.1 + cd ../.. + # Convert weight + mlc_chat convert_weight ./dist/models/WizardMath-7B-V1.1/ \ + --quantization q4f16_1 \ + -o dist/WizardMath-7B-V1.1-q4f16_1-MLC + +**Step 2 Generate MLC Chat Config** + +Use ``mlc_chat gen_config`` to generate ``mlc-chat-config.json`` and process tokenizers. +See :ref:`compile-command-specification` for specification of ``gen_config``. + +.. code:: shell + + mlc_chat gen_config ./dist/models/WizardMath-7B-V1.1/ \ + --quantization q4f16_1 --conv-template wizard_coder_or_math \ + -o dist/WizardMath-7B-V1.1-q4f16_1-MLC/ + +For the ``conv-template``, `conv_template.cc `__ +contains a full list of conversation templates that MLC provides. + +If the model you are adding requires a new conversation template, you would need to add your own. +Follow `this PR `__ as an example. Besides, you also need to add the new template to ``/path/to/web-llm/src/conversation.ts``. +We look up the template to use with the ``conv_template`` field in ``mlc-chat-config.json``. 
+ +For more details, please see :ref:`configure-mlc-chat-json`. + +.. note:: + + If you added your conversation template in ``src/conversation.ts``, you need to build WebLLM + from source following the instruction in + `the WebLLM repo's README `_. + + Alternatively, you could use the ``"custom"`` conversation template so that you can pass in + your own ``ConvTemplateConfig`` in runtime without having to build the package from source. + +**Step 3 Upload weights to HF** + +.. code:: shell + + # First, please create a repository on Hugging Face. + # With the repository created, run + git lfs install + git clone https://huggingface.co/my-huggingface-account/my-wizardMath-weight-huggingface-repo + cd my-wizardMath-weight-huggingface-repo + cp path/to/mlc-llm/dist/WizardMath-7B-V1.1-q4f16_1-MLC/* . + git add . && git commit -m "Add wizardMath model weights" + git push origin main + +After successfully following all steps, you should end up with a Huggingface repo similar to +`WizardMath-7B-V1.1-q4f16_1-MLC `__, +which includes the converted/quantized weights, the ``mlc-chat-config.json``, and tokenizer files. + + +**Step 4 Register as a ModelRecord** + +Finally, we modify the code snippet for +`get-started `__ +pasted above. + +We simply specify the Huggingface link as ``model_url``, while reusing the ``model_lib_url`` for +``Mistral-7B``. Note that we need the suffix to be ``/resolve/main/``. + +.. code:: typescript + + const myAppConfig: AppConfig = { + model_list: [ + // Other records here omitted... + { + // Substitute model_url with the one you created `my-huggingface-account/my-wizardMath-weight-huggingface-repo` + "model_url": "https://huggingface.co/mlc-ai/WizardMath-7B-V1.1-q4f16_1-MLC/resolve/main/", + "local_id": "WizardMath-7B-V1.1-q4f16_1", + "model_lib_url": "https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-sw4k_cs1k-webgpu.wasm", + "required_features": ["shader-f16"], + }, + ] + } + + const selectedModel = "WizardMath-7B-V1.1-q4f16_1" + await chat.reload(selectedModel, undefined, myAppConfig); + +Now, running the ``get-started`` example will use the ``WizardMath`` model you just added. +See `get-started's README `__ +on how to run it. + + +Bring Your Own Model Library +---------------------------- + +A model library is specified by: + + - The model architecture (e.g. ``llama-2``, ``gpt-neox``) + - Quantization (e.g. ``q4f16_1``, ``q0f32``) + - Metadata (e.g. ``context_window_size``, ``sliding_window_size``, ``prefill-chunk-size``), which affects memory planning + - Platform (e.g. ``cuda``, ``webgpu``, ``iOS``) + +In cases where the model you want to run is not compatible with the provided MLC +prebuilt model libraries (e.g. having a different quantization, a different +metadata spec, or even a different model architecture), you need to build your +own model library. + +In this section, we walk you through adding ``RedPajama-INCITE-Chat-3B-v1`` to the +`get-started `__ example. + +This section largely replicates :ref:`compile-model-libraries`. See that page for +more details, specifically the ``WebGPU`` option. + +**Step 0. Install dependencies** + +To compile model libraries for webgpu, you need to :ref:`build mlc_chat from source `. +Besides, you also need to follow :ref:`install-web-build`. Otherwise, it would run into error: + +.. code:: text + + RuntimeError: Cannot find libraries: wasm_runtime.bc + +**Step 1. 
Clone from HF and convert_weight** + +You can be under the mlc-llm repo, or your own working directory. Note that all platforms +can share the same compiled/quantized weights. + +.. code:: shell + + # Create directory + mkdir -p dist/models && cd dist/models + # Clone HF weights + git lfs install + git clone https://huggingface.co/togethercomputer/RedPajama-INCITE-Chat-3B-v1 + cd ../.. + # Convert weight + mlc_chat convert_weight ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ + --quantization q4f16_1 \ + -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC + +**Step 2. Generate mlc-chat-config and compile** + +A model library is specified by: + + - The model architecture (e.g. ``llama-2``, ``gpt-neox``) + - Quantization (e.g. ``q4f16_1``, ``q0f32``) + - Metadata (e.g. ``context_window_size``, ``sliding_window_size``, ``prefill-chunk-size``), which affects memory planning + - Platform (e.g. ``cuda``, ``webgpu``, ``iOS``) + +All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_config``. + +.. code:: shell + + # 1. gen_config: generate mlc-chat-config.json and process tokenizers + mlc_chat gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ + --quantization q4f16_1 --conv-template redpajama_chat \ + -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ + # 2. compile: compile model library with specification in mlc-chat-config.json + mlc_chat compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ + --device webgpu -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-webgpu.wasm + +.. note:: + When compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill_chunk_size 1024`` or + lower ``context_window_size`` to decrease memory usage. Otherwise, during runtime, + you may run into issues like: + + .. code:: text + + TypeError: Failed to execute 'createBuffer' on 'GPUDevice': Failed to read the 'size' property from + 'GPUBufferDescriptor': Value is outside the 'unsigned long long' value range. + + +**Step 3. Distribute model library and model weights** + +After following the steps above, you should end up with: + +.. code:: shell + + ~/mlc-llm > ls dist/libs + RedPajama-INCITE-Chat-3B-v1-q4f16_1-webgpu.wasm # ===> the model library + + ~/mlc-llm > ls dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC + mlc-chat-config.json # ===> the chat config + ndarray-cache.json # ===> the model weight info + params_shard_0.bin # ===> the model weights + params_shard_1.bin + ... + tokenizer.json # ===> the tokenizer files + tokenizer_config.json + +Upload the ``RedPajama-INCITE-Chat-3B-v1-q4f16_1-webgpu.wasm`` to a github repository (for us, +it is in `binary-mlc-llm-libs `__). Then +upload the ``RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC`` to a Huggingface repo: + +.. code:: shell + + # First, please create a repository on Hugging Face. + # With the repository created, run + git lfs install + git clone https://huggingface.co/my-huggingface-account/my-redpajama3b-weight-huggingface-repo + cd my-redpajama3b-weight-huggingface-repo + cp path/to/mlc-llm/dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC/* . + git add . && git commit -m "Add redpajama-3b instruct model weights" + git push origin main + +This would result in something like `RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC +`_. + +**Step 4. Register as a ModelRecord** + +Finally, we are able to run the model we added in WebLLM's `get-started `__: + +.. code:: typescript + + const myAppConfig: AppConfig = { + model_list: [ + // Other records here omitted... 
+ { + "model_url": "https://huggingface.co/my-hf-account/my-redpajama3b-weight-huggingface-repo/resolve/main/", + "local_id": "RedPajama-INCITE-Instruct-3B-v1", + "model_lib_url": "https://raw.githubusercontent.com/my-gh-account/my-repo/main/RedPajama-INCITE-Chat-3B-v1-q4f16_1-webgpu.wasm", + "required_features": ["shader-f16"], + }, + ] + } + + const selectedModel = "RedPajama-INCITE-Instruct-3B-v1" + await chat.reload(selectedModel, undefined, myAppConfig); + +Now, running the ``get-started`` example will use the ``RedPajama`` model you just added. +See `get-started's README `__ +on how to run it. \ No newline at end of file diff --git a/docs/deploy/python.rst b/docs/deploy/python.rst new file mode 100644 index 0000000..3dd1b67 --- /dev/null +++ b/docs/deploy/python.rst @@ -0,0 +1,363 @@ +.. _deploy-python: + +Python API +========== + +.. contents:: Table of Contents + :local: + :depth: 2 + +We expose Python API for the MLC-Chat for easy integration into other Python projects. + +The Python API is a part of the MLC-Chat package, which we have prepared pre-built pip wheels via +the :doc:`installation page <../install/mlc_llm>`. + +Instead of following this page, you could also checkout the following tutorials in +Python notebook (all runnable in Colab): + +- `Getting Started with MLC-LLM `_: + how to quickly download prebuilt models and chat with it +- `Raw Text Generation with MLC-LLM `_: + how to perform raw text generation with MLC-LLM in Python + +.. These notebooks are not up-to-date with SLM yet +.. - `Compiling Llama-2 with MLC-LLM `_: +.. how to use Python APIs to compile models with the MLC-LLM workflow +.. - `Extensions to More Model Variants `_: +.. how to use Python APIs to compile and chat with any model variant you'd like + + +Verify Installation +------------------- + +.. code:: bash + + python -c "from mlc_chat import ChatModule; print(ChatModule)" + +You are expected to see the information about the :class:`mlc_chat.ChatModule` class. + +If the command above results in error, follow :ref:`install-mlc-packages` (either install the prebuilt pip wheels +or :ref:`mlcchat_build_from_source`). + +Run MLC Models w/ Python +------------------------ + +To run a model with MLC LLM in any platform/runtime, you need: + +1. **Model weights** converted to MLC format (e.g. `RedPajama-INCITE-Chat-3B-v1-MLC + `_.) +2. **Model library** that comprises the inference logic (see repo `binary-mlc-llm-libs `__). + +There are two ways to obtain the model weights and libraries: + +1. Compile your own model weights and libraries following :doc:`the model compilation page `. +2. Use off-the-shelf `prebuilt models weights `__ and + `prebuilt model libraries `__ (see :ref:`Model Prebuilts` for details). + +We use off-the-shelf prebuilt models in this page. However, same steps apply if you want to run +the models you compiled yourself. + +**Step 1: Download prebuilt model weights and libraries** + +Skip this step if you have already obtained the model weights and libraries. + +.. code:: shell + + # Activate your conda environment + conda install -c conda-forge git-lfs + + # Download pre-conveted weights + git lfs install && mkdir dist/ + git clone https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC \ + dist/Llama-2-7b-chat-hf-q4f16_1-MLC + + # Download pre-compiled model library + git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt_libs + + +**Step 2: Run the model in Python** + +Use the conda environment you used to install ``mlc_chat``. 
+From the ``mlc-llm`` directory, you can create a Python +file ``sample_mlc_chat.py`` and paste the following lines: + +.. code:: python + + from mlc_chat import ChatModule + from mlc_chat.callback import StreamToStdout + + # Create a ChatModule instance + cm = ChatModule( + model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", + model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" + # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so + # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so + # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} + ) + + # You can change to other models that you downloaded + # Model variants of the same architecture can reuse the same model library + # Here WizardMath reuses Mistral's model library + # cm = ChatModule( + # model="dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC", # or "dist/WizardMath-7B-V1.1-q4f16_1-MLC" + # model_lib_path="dist/prebuilt_libs/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-cuda.so" + # ) + + # Generate a response for a given prompt + output = cm.generate( + prompt="What is the meaning of life?", + progress_callback=StreamToStdout(callback_interval=2), + ) + + # Print prefill and decode performance statistics + print(f"Statistics: {cm.stats()}\n") + + output = cm.generate( + prompt="How many points did you list out?", + progress_callback=StreamToStdout(callback_interval=2), + ) + + # Reset the chat module by + # cm.reset_chat() + + +Now run the Python file to start the chat + +.. code:: bash + + python sample_mlc_chat.py + + +.. collapse:: See output + + .. code:: + + Using model folder: ./dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1 + Using mlc chat config: ./dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1/mlc-chat-config.json + Using library model: ./dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so + + Thank you for your question! The meaning of life is a complex and subjective topic that has been debated by philosophers, theologians, scientists, and many others for centuries. There is no one definitive answer to this question, as it can vary depending on a person's beliefs, values, experiences, and perspectives. + + However, here are some possible ways to approach the question: + + 1. Religious or spiritual beliefs: Many people believe that the meaning of life is to fulfill a divine or spiritual purpose, whether that be to follow a set of moral guidelines, to achieve spiritual enlightenment, or to fulfill a particular destiny. + 2. Personal growth and development: Some people believe that the meaning of life is to learn, grow, and evolve as individuals, to develop one's talents and abilities, and to become the best version of oneself. + 3. Relationships and connections: Others believe that the meaning of life is to form meaningful connections and relationships with others, to love and be loved, and to build a supportive and fulfilling social network. + 4. Contribution and impact: Some people believe that the meaning of life is to make a positive impact on the world, to contribute to society in a meaningful way, and to leave a lasting legacy. + 5. Simple pleasures and enjoyment: Finally, some people believe that the meaning of life is to simply enjoy the present moment, to find pleasure and happiness in the simple things in life, and to appreciate the beauty and wonder of the world around us. + + Ultimately, the meaning of life is a deeply personal and subjective question, and each person must find their own answer based on their own beliefs, values, and experiences. 
+ + Statistics: prefill: 3477.5 tok/s, decode: 153.6 tok/s + + I listed out 5 possible ways to approach the question of the meaning of life. + +| + +**Running other models** + +Checkout the :doc:`/prebuilt_models` page to run other pre-compiled models. + +For models other than the prebuilt ones we provided: + +1. If the model is a variant to an existing model library (e.g. ``WizardMathV1.1`` and ``OpenHermes`` are variants of ``Mistral`` as + shown in the code snippet), follow :ref:`convert-weights-via-MLC` to convert the weights and reuse existing model libraries. +2. Otherwise, follow :ref:`compile-model-libraries` to compile both the model library and weights. + + +Configure MLCChat in Python +--------------------------- +If you have checked out :ref:`Configure MLCChat in JSON`, you would know +that you could configure MLCChat through various fields such as ``temperature``. We provide the +option of overriding any field you'd like in Python, so that you do not need to manually edit +``mlc-chat-config.json``. + +Since there are two concepts -- `MLCChat Configuration` and `Conversation Configuration` -- we correspondingly +provide two dataclasses :class:`mlc_chat.ChatConfig` and :class:`mlc_chat.ConvConfig`. + +We provide an example below. + +.. code:: python + + from mlc_chat import ChatModule, ChatConfig, ConvConfig + from mlc_chat.callback import StreamToStdout + + # Using a `ConvConfig`, we modify `system`, a field in the conversation template + # `system` refers to the prompt encoded before starting the chat + conv_config = ConvConfig(system='Please show as much happiness as you can when talking to me.') + + # We then include the `ConvConfig` instance in `ChatConfig` while overriding `max_gen_len` + # Note that `conv_config` is an optional subfield of `chat_config` + chat_config = ChatConfig(max_gen_len=256, conv_config=conv_config) + + # Using the `chat_config` we created, instantiate a `ChatModule` + cm = ChatModule( + chat_config=chat_config, + model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", + model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" + # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so + # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so + # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} + ) + + output = cm.generate( + prompt="What is one plus one?", + progress_callback=StreamToStdout(callback_interval=2), + ) + + # You could also pass in a `ConvConfig` instance to `reset_chat()` + conv_config = ConvConfig(system='Please show as much sadness as you can when talking to me.') + chat_config = ChatConfig(max_gen_len=128, conv_config=conv_config) + cm.reset_chat(chat_config) + + output = cm.generate( + prompt="What is one plus one?", + progress_callback=StreamToStdout(callback_interval=2), + ) + + +.. collapse:: See output + + .. code:: + + Using model folder: ./dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1 + Using mlc chat config: ./dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1/mlc-chat-config.json + Using library model: ./dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so + + Oh, wow, *excitedly* one plus one? *grinning* Well, let me see... *counting on fingers* One plus one is... *eureka* Two! + ... + + *Sobs* Oh, the tragedy of it all... *sobs* One plus one... *chokes back tears* It's... *gulps* it's... *breaks down in tears* TWO! + ... + +| + +.. note:: + You do not need to specify the entire ``ChatConfig`` or ``ConvConfig``. 
Instead, we will first + load all the fields defined in ``mlc-chat-config.json``, a file required when instantiating + a :class:`mlc_chat.ChatModule`. Then, we will load in the optional ``ChatConfig`` you provide, overriding the + fields specified. + + It is also worth noting that ``ConvConfig`` itself is overriding the original conversation template + specified by the field ``conv_template`` in the chat configuration. Learn more about it in + :ref:`Configure MLCChat in JSON`. + +Raw Text Generation in Python +----------------------------- + +Raw text generation allows the user to have more flexibility over his prompts, +without being forced to create a new conversational template, making prompt customization easier. +This serves other demands for APIs to handle LLM generation without the usual system prompts and other items. + +We provide an example below. + +.. code:: python + + from mlc_chat import ChatModule, ChatConfig, ConvConfig + from mlc_chat.callback import StreamToStdout + + # Use a `ConvConfig` to define the generation settings + # Since the "LM" template only supports raw text generation, + # System prompts will not be executed even if provided + conv_config = ConvConfig(stop_tokens=[2,], add_bos=True, stop_str="[INST]") + + # Note that `conv_config` is an optional subfield of `chat_config` + # The "LM" template serves the basic purposes of raw text generation + chat_config = ChatConfig(conv_config=conv_config, conv_template="LM") + + # Using the `chat_config` we created, instantiate a `ChatModule` + cm = ChatModule( + chat_config=chat_config, + model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", + model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" + # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so + # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so + # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} + ) + # To make the model follow conversations a chat structure should be provided + # This allows users to build their own prompts without building a new template + system_prompt = "<>\nYou are a helpful, respectful and honest assistant.\n<>\n\n" + inst_prompt = "What is mother nature?" + + # Concatenate system and instruction prompts, and add instruction tags + output = cm.generate( + prompt=f"[INST] {system_prompt+inst_prompt} [/INST]", + progress_callback=StreamToStdout(callback_interval=2), + ) + + # The LM template has no memory, so it will be reset every single generation + # In this case the model will just follow normal text completion + # because there isn't a chat structure + output = cm.generate( + prompt="Life is a quality that distinguishes", + progress_callback=StreamToStdout(callback_interval=2), + ) + +.. note:: + The ``LM`` is a template without memory, which means that every execution will be cleared. + Additionally, system prompts will not be run when instantiating a `mlc_chat.ChatModule`, + unless explicitly given inside the prompt. + +Stream Iterator in Python +------------------------- + +Stream Iterator gives users an option to stream generated text to the function that the API is called from, +instead of streaming to stdout, which could be a necessity when building services on top of MLC Chat. + +We provide an example below. + +.. 
code:: python + + from mlc_chat import ChatModule + from mlc_chat.callback import StreamIterator + + # Create a ChatModule instance + cm = ChatModule( + model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", + model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" + # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so + # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so + # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} + ) + + # Stream to an Iterator + from threading import Thread + + stream = StreamIterator(callback_interval=2) + generation_thread = Thread( + target=cm.generate, + kwargs={"prompt": "What is the meaning of life?", "progress_callback": stream}, + ) + generation_thread.start() + + output = "" + for delta_message in stream: + output += delta_message + + generation_thread.join() + + +API Reference +------------- + +User can initiate a chat module by creating :class:`mlc_chat.ChatModule` class, which is a wrapper of the MLC-Chat model. +The :class:`mlc_chat.ChatModule` class provides the following methods: + +.. currentmodule:: mlc_chat + +.. autoclass:: ChatModule + :members: + :exclude-members: evaluate + :undoc-members: + :show-inheritance: + + .. automethod:: __init__ + +.. autoclass:: ChatConfig + :members: + +.. autoclass:: ConvConfig + :members: + +.. autoclass:: GenerationConfig + :members: diff --git a/docs/deploy/rest.rst b/docs/deploy/rest.rst new file mode 100644 index 0000000..d12029a --- /dev/null +++ b/docs/deploy/rest.rst @@ -0,0 +1,394 @@ +Rest API +======== + +.. contents:: Table of Contents + :local: + :depth: 2 + +We provide `REST API `_ +for a user to interact with MLC-Chat in their own programs. + +Install MLC-Chat Package +------------------------ + +The REST API is a part of the MLC-Chat package, which we have prepared pre-built :doc:`pip wheels <../install/mlc_llm>`. + +Verify Installation +^^^^^^^^^^^^^^^^^^^ + +.. code:: bash + + python -m mlc_chat.rest --help + +You are expected to see the help information of the REST API. + +.. _mlcchat_package_build_from_source: + +Optional: Build from Source +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If the prebuilt is unavailable on your platform, or you would like to build a runtime +that supports other GPU runtime than the prebuilt version. We can build a customized version +of mlc chat runtime. You only need to do this if you choose not to use the prebuilt. + +First, make sure you install TVM unity (following the instruction in :ref:`install-tvm-unity`). +You can choose to only pip install `mlc-ai-nightly` that comes with the tvm unity but skip `mlc-chat-nightly`. +Then please follow the instructions in :ref:`mlcchat_build_from_source` to build the necessary libraries. + +You can now use ``mlc_chat`` package by including the `python` directory to ``PYTHONPATH`` environment variable. + +.. code:: bash + + PYTHONPATH=python python -m mlc_chat.rest --help + +Launch the Server +----------------- + +To launch the REST server for MLC-Chat, run the following command in your terminal. + +.. code:: bash + + python -m mlc_chat.rest --model MODEL [--lib-path LIB_PATH] [--device DEVICE] [--host HOST] [--port PORT] + +--model The model folder after compiling with MLC-LLM build process. The parameter + can either be the model name with its quantization scheme + (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a full path to the model + folder. In the former case, we will use the provided name to search + for the model folder over possible paths. 
+--lib-path An optional field to specify the full path to the model library file to use (e.g. a ``.so`` file). +--device The description of the device to run on. User should provide a string in the + form of 'device_name:device_id' or 'device_name', where 'device_name' is one of + 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto' (automatically detect the + local device), and 'device_id' is the device id to run on. The default value is ``auto``, + with the device id set to 0 for default. +--host The host at which the server should be started, defaults to ``127.0.0.1``. +--port The port on which the server should be started, defaults to ``8000``. + +You can access ``http://127.0.0.1:PORT/docs`` (replace ``PORT`` with the port number you specified) to see the list of +supported endpoints. + +API Endpoints +------------- + +The REST API provides the following endpoints: + +.. http:get:: /v1/completions + +------------------------------------------------ + + Get a completion from MLC-Chat using a prompt. + +**Request body** + +**model**: *str* (required) + The model folder after compiling with MLC-LLM build process. The parameter + can either be the model name with its quantization scheme + (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a full path to the model + folder. In the former case, we will use the provided name to search + for the model folder over possible paths. +**prompt**: *str* (required) + A list of chat messages. The last message should be from the user. +**stream**: *bool* (optional) + Whether to stream the response. If ``True``, the response will be streamed + as the model generates the response. If ``False``, the response will be + returned after the model finishes generating the response. +**temperature**: *float* (optional) + The temperature applied to logits before sampling. The default value is + ``0.7``. A higher temperature encourages more diverse outputs, while a + lower temperature produces more deterministic outputs. +**top_p**: *float* (optional) + This parameter determines the set of tokens from which we sample during + decoding. The default value is set to ``0.95``. At each step, we select + tokens from the minimal set that has a cumulative probability exceeding + the ``top_p`` parameter. + + For additional information on top-p sampling, please refer to this blog + post: https://huggingface.co/blog/how-to-generate#top-p-nucleus-sampling. +**repetition_penalty**: *float* (optional) + The repetition penalty controls the likelihood of the model generating + repeated texts. The default value is set to ``1.0``, indicating that no + repetition penalty is applied. Increasing the value reduces the + likelihood of repeat text generation. However, setting a high + ``repetition_penalty`` may result in the model generating meaningless + texts. The ideal choice of repetition penalty may vary among models. + + For more details on how repetition penalty controls text generation, please + check out the CTRL paper (https://arxiv.org/pdf/1909.05858.pdf). +**presence_penalty**: *float* (optional) + Positive values penalize new tokens if they are already present in the text so far, + decreasing the model's likelihood to repeat tokens. +**frequency_penalty**: *float* (optional) + Positive values penalize new tokens based on their existing frequency in the text so far, + decreasing the model's likelihood to repeat tokens. +**mean_gen_len**: *int* (optional) + The approximated average number of generated tokens in each round. 
Used + to determine whether the maximum window size would be exceeded. +**max_gen_len**: *int* (optional) + This parameter determines the maximum length of the generated text. If it is + not set, the model will generate text until it encounters a stop token. + +------------------------------------------------ + +**Returns** + If ``stream`` is set to ``False``, the response will be a ``CompletionResponse`` object. + If ``stream`` is set to ``True``, the response will be a stream of ``CompletionStreamResponse`` objects. + + +.. http:get:: /v1/chat/completions + +------------------------------------------------ + + Get a response from MLC-Chat using a prompt, either with or without streaming. + +**Request body** + +**model**: *str* (required) + The model folder after compiling with MLC-LLM build process. The parameter + can either be the model name with its quantization scheme + (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a full path to the model + folder. In the former case, we will use the provided name to search + for the model folder over possible paths. +**messages**: *list[ChatMessage]* (required) + A list of chat messages. The last message should be from the user. +**stream**: *bool* (optional) + Whether to stream the response. If ``True``, the response will be streamed + as the model generates the response. If ``False``, the response will be + returned after the model finishes generating the response. +**temperature**: *float* (optional) + The temperature applied to logits before sampling. The default value is + ``0.7``. A higher temperature encourages more diverse outputs, while a + lower temperature produces more deterministic outputs. +**top_p**: *float* (optional) + This parameter determines the set of tokens from which we sample during + decoding. The default value is set to ``0.95``. At each step, we select + tokens from the minimal set that has a cumulative probability exceeding + the ``top_p`` parameter. + + For additional information on top-p sampling, please refer to this blog + post: https://huggingface.co/blog/how-to-generate#top-p-nucleus-sampling. +**repetition_penalty**: *float* (optional) + The repetition penalty controls the likelihood of the model generating + repeated texts. The default value is set to ``1.0``, indicating that no + repetition penalty is applied. Increasing the value reduces the + likelihood of repeat text generation. However, setting a high + ``repetition_penalty`` may result in the model generating meaningless + texts. The ideal choice of repetition penalty may vary among models. + + For more details on how repetition penalty controls text generation, please + check out the CTRL paper (https://arxiv.org/pdf/1909.05858.pdf). +**presence_penalty**: *float* (optional) + Positive values penalize new tokens if they are already present in the text so far, + decreasing the model's likelihood to repeat tokens. +**frequency_penalty**: *float* (optional) + Positive values penalize new tokens based on their existing frequency in the text so far, + decreasing the model's likelihood to repeat tokens. +**mean_gen_len**: *int* (optional) + The approximated average number of generated tokens in each round. Used + to determine whether the maximum window size would be exceeded. +**max_gen_len**: *int* (optional) + This parameter determines the maximum length of the generated text. If it is + not set, the model will generate text until it encounters a stop token. +**n**: *int* (optional) + This parameter determines the number of text samples to generate. 
The default + value is ``1``. Note that this parameter is only used when ``stream`` is set to + ``False``. +**stop**: *str* or *list[str]* (optional) + When ``stop`` is encountered, the model will stop generating output. + It can be a string or a list of strings. If it is a list of strings, the model + will stop generating output when any of the strings in the list is encountered. + Note that this parameter does not override the default stop string of the model. + +------------------------------------------------ + +**Returns** + If ``stream`` is set to ``False``, the response will be a ``ChatCompletionResponse`` object. + If ``stream`` is set to ``True``, the response will be a stream of ``ChatCompletionStreamResponse`` objects. + +.. http:get:: /chat/reset + + Reset the chat. + +.. http:get:: /stats + + Get the latest runtime stats (encode/decode speed). + +.. http:get:: /verbose_stats + + Get the verbose runtime stats (encode/decode speed, total runtime). + + +Request Objects +--------------- + +**ChatMessage** + +**role**: *str* (required) + The role(author) of the message. It can be either ``user`` or ``assistant``. +**content**: *str* (required) + The content of the message. +**name**: *str* (optional) + The name of the author of the message. + +Response Objects +---------------- + +**CompletionResponse** + +**id**: *str* + The id of the completion. +**object**: *str* + The object name ``text.completion``. +**created**: *int* + The time when the completion is created. +**choices**: *list[CompletionResponseChoice]* + A list of choices generated by the model. +**usage**: *UsageInfo* or *None* + The usage information of the model. + +------------------------------------------------ + +**CompletionResponseChoice** + +**index**: *int* + The index of the choice. +**text**: *str* + The message generated by the model. +**finish_reason**: *str* + The reason why the model finishes generating the message. It can be either + ``stop`` or ``length``. + + +------------------------------------------------ + +**CompletionStreamResponse** + +**id**: *str* + The id of the completion. +**object**: *str* + The object name ``text.completion.chunk``. +**created**: *int* + The time when the completion is created. +**choices**: *list[ChatCompletionResponseStreamhoice]* + A list of choices generated by the model. + +------------------------------------------------ + +**ChatCompletionResponseStreamChoice** + +**index**: *int* + The index of the choice. +**text**: *str* + The message generated by the model. +**finish_reason**: *str* + The reason why the model finishes generating the message. It can be either + ``stop`` or ``length``. + +------------------------------------------------ + +**ChatCompletionResponse** + +**id**: *str* + The id of the completion. +**object**: *str* + The object name ``chat.completion``. +**created**: *int* + The time when the completion is created. +**choices**: *list[ChatCompletionResponseChoice]* + A list of choices generated by the model. +**usage**: *UsageInfo* or *None* + The usage information of the model. + +------------------------------------------------ + +**ChatCompletionResponseChoice** + +**index**: *int* + The index of the choice. +**message**: *ChatMessage* + The message generated by the model. +**finish_reason**: *str* + The reason why the model finishes generating the message. It can be either + ``stop`` or ``length``. + +------------------------------------------------ + +**ChatCompletionStreamResponse** + +**id**: *str* + The id of the completion. 
+**object**: *str* + The object name ``chat.completion.chunk``. +**created**: *int* + The time when the completion is created. +**choices**: *list[ChatCompletionResponseStreamhoice]* + A list of choices generated by the model. + +------------------------------------------------ + +**ChatCompletionResponseStreamChoice** + +**index**: *int* + The index of the choice. +**delta**: *DeltaMessage* + The delta message generated by the model. +**finish_reason**: *str* + The reason why the model finishes generating the message. It can be either + ``stop`` or ``length``. + +------------------------------------------------ + + +**DeltaMessage** + +**role**: *str* + The role(author) of the message. It can be either ``user`` or ``assistant``. +**content**: *str* + The content of the message. + +------------------------------------------------ + + +Use REST API in your own program +-------------------------------- + +Once you have launched the REST server, you can use the REST API in your own program. Below is an example of using REST API to interact with MLC-Chat in Python (suppose the server is running on ``http://127.0.0.1:8000/``): + +.. code:: bash + + import requests + import json + + # Get a response using a prompt without streaming + payload = { + "model": "vicuna-v1-7b", + "messages": [{"role": "user", "content": "Write a haiku"}], + "stream": False + } + r = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload) + print(f"Without streaming:\n{r.json()['choices'][0]['message']['content']}\n") + + # Reset the chat + r = requests.post("http://127.0.0.1:8000/chat/reset", json=payload) + print(f"Reset chat: {str(r)}\n") + + # Get a response using a prompt with streaming + payload = { + "model": "vicuna-v1-7b", + "messages": [{"role": "user", "content": "Write a haiku"}], + "stream": True + } + with requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload, stream=True) as r: + print(f"With streaming:") + for chunk in r: + content = json.loads(chunk[6:-2])["choices"][0]["delta"].get("content", "") + print(f"{content}", end="", flush=True) + print("\n") + + # Get the latest runtime stats + r = requests.get("http://127.0.0.1:8000/stats") + print(f"Runtime stats: {r.json()}\n") + +Please check `example folder `__ for more examples using REST API. + +.. note:: + The REST API is a uniform interface that supports multiple languages. You can also utilize the REST API in languages other than Python. diff --git a/docs/get_started/mlc_chat_config.rst b/docs/get_started/mlc_chat_config.rst new file mode 100644 index 0000000..c583c16 --- /dev/null +++ b/docs/get_started/mlc_chat_config.rst @@ -0,0 +1,254 @@ +.. _configure-mlc-chat-json: + +Configure MLCChat in JSON +========================= + +``mlc-chat-config.json`` is required for both compile-time and runtime, hence serving two purposes: + +1. Specify how we compile a model (shown in :ref:`compile-model-libraries`), and +2. Specify conversation behavior in runtime. + +**This page focuses on the second purpose.** We explain the components of a chat +configuration and how to customize them by modifying the file. Additionally, +the runtimes also provide APIs to optionally override some of the configurations. + +In runtime, this file is stored under the directory of each compiled model +(e.g. `RedPajama chat config `__). + + +.. _struct-mlc-chat-conv: + +Structure of MLCChat Configuration +---------------------------------- + +Below is the ``mlc-chat-config.json`` file corresponding to Llama2 model: + +.. 
code:: json + + // mlc-chat-config.json + { + // 1. Metadata used to specify how to compile a model + "model_type": "llama", + "quantization": "q4f16_1", + "version": "0.1.0", + "model_config": { + "hidden_size": 4096, + "intermediate_size": 11008, + // more fields here... + }, + "vocab_size": 32000, + "context_window_size": 4096, + "sliding_window_size": -1, + "prefill_chunk_size": 4096, + "tensor_parallel_shards": 1, + + // 2. Tokenizer-related fields + "pad_token_id": 0, + "bos_token_id": 1, + "eos_token_id": 2, + "tokenizer_files": [ + "tokenizer.model", + "tokenizer.json", + "tokenizer_config.json" + ] + + // 3. Chat related fields that affect runtime behavior + "mean_gen_len": 128, + "max_gen_len": 512, + "shift_fill_factor": 0.3, + "temperature": 0.6, + "repetition_penalty": 1.0, + "top_p": 0.9, + "conv_template": "llama-2", + } + +.. note:: + Fields in the first part of ``mlc-chat-config.json`` (e.g. ``context-window-size``) + is only for compile-time. Changing them during runtime may lead to unexpected behavior. + +**As shown above, the file is divided into three parts. We focus on the third part, which +can be customized to change the behavior of the model.** + +``conv_template`` + The name of the conversation template that this chat uses. For more information, please refer to :ref:`conversation structure `. + +``temperature`` + The temperature applied to logits before sampling. The default value is ``0.7``. A higher temperature encourages more diverse outputs, while a lower temperature produces more deterministic outputs. + +``repetition_penalty`` + The repetition penalty controls the likelihood of the model generating repeated texts. The default value is set to ``1.0``, indicating that no repetition penalty is applied. Increasing the value reduces the likelihood of repeat text generation. However, setting a high ``repetition_penalty`` may result in the model generating meaningless texts. The ideal choice of repetition penalty may vary among models. + + For more details on how repetition penalty controls text generation, please check out the `CTRL paper `_. + +``top_p`` + This parameter determines the set of tokens from which we sample during decoding. The default value is set to ``0.95``. At each step, we select tokens from the minimal set that has a cumulative probability exceeding the ``top_p`` parameter. + + For additional information on top-p sampling, please refer to this `blog post `_. + +``mean_gen_len`` + The approximated average number of generated tokens in each round. Used to determine whether the maximum window size would be exceeded. + +``max_gen_len`` + This parameter determines the maximum length of the generated text. If it is not set, the model will generate text until it encounters a stop token. + +``shift_fill_factor`` + The fraction of maximum window size to shift when it is exceeded. + +.. _struct-conv: + +Conversation Structure +^^^^^^^^^^^^^^^^^^^^^^ + +There are three options of loading conversation configurations: + +1. Load from pre-defined conversation templates. +2. Load from JSON format conversation configuration. +3. First load from pre-defined conversation templates, then override some fields with JSON format conversation configuration. + +.. 
_load-predefined-conv-template: + +Load from Pre-defined Conversation Templates +-------------------------------------------- + +MLC-LLM provided a set of pre-defined conversation templates, which you can directly use by specifying the template name in ``conv_template`` field in the ``mlc-chat-config.json``, below is a list (not complete) of supported conversation templates: + +- ``llama-2`` +- ``vicuna_v1.1`` +- ``redpajama_chat`` +- ``rwkv`` +- ``dolly`` +- ... + +Please refer to `conv_template.cc `_ for the full list of supported templates and their implementations. + +.. _load-json-conv-config: + +Load from JSON Conversation Configuration +----------------------------------------- + +Below is a generic structure of a JSON conversation configuration (we use vicuna as an example): + +.. code:: json + + // mlc-chat-config.json + { + // ... + "conv_config": { + "seps": [ + " ", + "<\/s>" + ], + "stop_tokens": [ + 2 + ], + "offset": 0, + "separator_style": 0, + "messages": [], + "stop_str": "<\/s>", + "roles": [ + "USER", + "ASSISTANT" + ], + "role_msg_sep": ": ", + "role_empty_sep": ": ", + "system": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.", + "add_bos": true, + "name": "vicuna_v1.1" + } + } + +``roles`` + An array that describes the role names of the user and the model. These names are specific to the model being used. +``system`` + The prompt encoded before starting the chat. It can be customized to a user-defined prompt. +``add_bos`` + Determines whether a beginning-of-string (bos) token should be added before the input tokens. +``stop_str`` + When the ``stop_str`` is encountered, the model will stop generating output. +``stop_tokens`` + A list of token IDs that act as stop tokens. +``seps`` + An array of strings indicating the separators to be used after a user message and a model message respectively. +``messages`` + The chat history represented as an array of string pairs in the following format: ``[[role_0, msg_0], [role_1, msg_1], ...]`` +``offset`` + The offset used to begin the chat from the chat history. When ``offset`` is not ``0``, ``messages[0:offset-1]`` will be encoded. +``separator_style`` + Specifies whether we are in chat-bot mode (``0``) or pure LM prompt mode (``1``). +``role_msg_sep`` + A string indicating the separator between a role and a message. +``role_empty_sep`` + A string indicating the separator to append to a role when there is no message yet. + + +When the value of ``separator_style`` is set to 0 (or ``kSepRoleMsg``), each round of conversation follows the format: + +.. code:: text + + {role[0]}{separator_style}{user_input}{sep[0]} + {role[1]}{separator_style}{model_output}{sep[1]} + +Here, ``{user_input}`` represents the input provided by the user, and ``{model_output}`` represents the output generated by the model. + +On the other hand, if the value of ``separator_style`` is set to 1 (or ``kLM``), the model is not aware of the chat history and generates the response immediately after the user input prompt: + + +.. code:: text + + {user_prompt}{model_output} + + +.. _customize-conv-template: + +Customize Conversation Template +------------------------------- + +In the ``mlc-chat-config.json`` file, you have the option to specify both ``conv_template`` and ``conv_config``. MLC-LLM will first load the predefined template with the name specified in ``conv_template`` and then override some of the configurations specified in ``conv_config``. 
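+
+The same override can also be applied at runtime through the Python API instead of editing the JSON file.
+Below is a minimal sketch of this approach; it assumes the prebuilt ``mlc_chat`` package, uses placeholder
+model and library paths, and relies on the ``ConvConfig`` and ``ChatConfig`` classes described in the
+Python API reference.
+
+.. code:: python
+
+   from mlc_chat import ChatConfig, ChatModule, ConvConfig
+
+   # Override only the system prompt of the predefined "vicuna_v1.1" template.
+   conv_config = ConvConfig(
+       system="You are not Vicuna, your name is Guanaco, now let's chat!"
+   )
+   chat_config = ChatConfig(conv_template="vicuna_v1.1", conv_config=conv_config)
+
+   cm = ChatModule(
+       model="dist/vicuna-v1-7b-q4f16_1-MLC",  # placeholder path to the model weights
+       model_lib_path="dist/prebuilt_libs/vicuna-v1-7b-q4f16_1-vulkan.so",  # placeholder path to the model library
+       chat_config=chat_config,
+   )
+   print(cm.generate(prompt="What is your name?"))
+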
It's important to note that the configurations in ``conv_config`` don't need to be complete, allowing for partial updates. + +.. _example_replace_system_prompt: + +Example 1: Replace System Prompt +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If you're tired of the default system prompt, here's an example of how you can replace it: + +.. code:: json + + // mlc-chat-config.json + { + // ... + "conv_template": "vicuna_v1.1", + "conv_config": { + "system": "You are not Vicuna, your name is Guanaco, now let's chat!" + } + } + + +The next time you run ``mlc_chat`` CLI, you will start a chat with Vicuna using a new system prompt. + +.. _example_resume_chat_history: + +Example 2: Resume from Chat History +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The following example demonstrates how to chat with Vicuna and resume from a chat history: + +.. code:: json + + // mlc-chat-config.json + { + // ... + "conv_template": "vicuna_v1.1", + "conv_config": { + "messages": [ + ["USER", "Suppose we already have projects llama, alpaca and vicuna, what do you think would be a great name for the next project?"], + ["ASSISTANT", "Based on the previous projects, a possible name for the next project could be \"cervidae\" which is the scientific name for deer family. This name reflects the collaboration and teamwork involved in the development of the project, and also nods to the previous projects that have been developed by the team."], + ["USER", "I like cervidae, but the name is too long!"], + ["ASSISTANT", "In that case, a shorter and catchier name for the next project could be \"DeerRun\" which plays on the idea of the project being fast and efficient, just like a deer running through the woods. This name is memorable and easy to pronounce, making it a good choice for a project name."] + ], + "offset": 4 + } + } + + +The next time you start ``mlc_chat`` CLI, or use Python API, you will initiate a chat with Vicuna and resume from the provided chat history. diff --git a/docs/get_started/project_overview.rst b/docs/get_started/project_overview.rst new file mode 100644 index 0000000..2b6ff74 --- /dev/null +++ b/docs/get_started/project_overview.rst @@ -0,0 +1,88 @@ +.. _project-overview: + +Project Overview +================ + +This page introduces high-level project concepts to help us use and customize MLC LLM. +The MLC-LLM project consists of three distinct submodules: model definition, model compilation, and runtimes. + +.. figure:: /_static/img/project-structure.svg + :width: 600 + :align: center + :alt: Project Structure + + Three independent submodules in MLC LLM + +**➀ Model definition in Python.** MLC offers a variety of pre-defined architectures, such as Llama (e.g., Llama2, Vicuna, OpenLlama, Wizard), GPT-NeoX (e.g., RedPajama, Dolly), RNNs (e.g., RWKV), and GPT-J (e.g., MOSS). Model developers could solely define the model in pure Python, without having to touch code generation and runtime. + +**➁ Model compilation in Python.** Models are compiled by :doc:`TVM Unity ` compiler, where the compilation is configured in pure Python. MLC LLM quantizes and exports the Python-based model to a model library and quantized model weights. Quantization and optimization algorithms can be developed in pure Python to compress and accelerate LLMs for specific usecases. + +**➂ Platform-native runtimes.** Variants of MLCChat are provided on each platform: **C++** for command line, **Javascript** for web, **Swift** for iOS, and **Java** for Android, configurable with a JSON chat config. 
App developers only need to familiarize themselves with the platform-native runtimes to integrate MLC-compiled LLMs into their projects.
+
+.. _terminologies:
+
+Terminologies
+-------------
+
+It is helpful to familiarize ourselves with the basic terminology used in MLC chat applications. Below are the
+three things you need to run a model with MLC.
+
+- **model lib**: The model library refers to the executable libraries that enable
+ the execution of a specific model architecture. On Linux and M-chip macOS, these libraries have the suffix
+ ``.so``; on Intel macOS, the suffix is ``.dylib``; on Windows, the library file ends with ``.dll``;
+ in web browsers, the library suffix is ``.wasm``. (see `binary-mlc-llm-libs `__).
+
+- **model weights**: The model weights are stored in a folder that contains the quantized neural network weights
+ of the language models as well as the tokenizer configurations. (e.g. `Llama-2-7b-chat-hf-q4f16_1-MLC `__)
+
+- **chat config**: The chat configuration includes settings that allow customization of parameters such as temperature and system prompt.
+ The default chat config usually resides in the same directory as the model weights. (e.g. see ``Llama-2-7b-chat-hf-q4f16_1``'s
+ `mlc-chat-config.json `__)
+
+Model Preparation
+-----------------
+
+
+There are several ways to prepare the model weights and model lib.
+
+- :ref:`Model Prebuilts` contains models that can be used directly.
+- You can also :doc:`run model compilation ` for model weight variants of the supported architectures.
+- Finally, you can incorporate a new model architecture/inference logic following :doc:`Define New Models `.
+
+A default chat config usually comes with the model weight directory. You can further customize
+the system prompt, temperature, and other options by modifying the JSON file.
+MLC chat runtimes also provide APIs to override these options during model reload.
+Please refer to :doc:`/get_started/mlc_chat_config` for more details.
+
+
+Runtime Flow Overview
+---------------------
+
+Once the model weights, model library, and chat configuration are prepared, an MLC chat runtime can consume them as an engine to drive a chat application.
+The diagram below shows a typical workflow for an MLC chat application.
+
+.. image:: https://raw.githubusercontent.com/mlc-ai/web-data/a05d4598bae6eb5a3133652d5cc0323ced3b0e17/images/mlc-llm/tutorials/mlc-llm-flow-slm.svg
+ :width: 90%
+ :align: center
+
+On the right side of the figure, you can see pseudo-code illustrating the structure of an MLC chat API during the execution of a chat app.
+Typically, there is a ``ChatModule`` that manages the model. We instantiate the chat app with two files: the model weights (which include an ``mlc-chat-config.json``)
+and the model library. We also have an optional chat configuration, which allows for overriding settings such as the system prompt and temperature.
+
+All MLC runtimes, including iOS, Web, CLI, and others, use these three elements.
+All the runtimes can read the same model weight folder. The packaging of the model libraries may vary depending on the runtime.
+For the CLI, the model libraries are stored in a DLL directory.
+iOS and Android include pre-packaged model libraries within the app due to dynamic loading restrictions.
+WebLLM utilizes URLs of local or Internet-hosted WebAssembly (Wasm) files.
+
+What to Do Next
+---------------
+
+Thank you for reading and learning about the high-level concepts.
+Moving next, feel free to check out documents on the left navigation panel and +learn about topics you are interested in. + +- :doc:`/get_started/mlc_chat_config` shows how to configure specific chat behavior. +- Build and Deploy App section contains guides to build apps + and platform-specific MLC chat runtimes. +- Compile models section provides guidelines to convert model weights and produce model libs. diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..15ad6ca --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,233 @@ +👋 Welcome to MLC LLM +===================== + +`Discord `_ | `GitHub `_ + +Machine Learning Compilation for Large Language Models (MLC LLM) is a high-performance universal deployment solution that allows native deployment of any large language models with native APIs with compiler acceleration. The mission of this project is to enable everyone to develop, optimize and deploy AI models natively on everyone's devices with ML compilation techniques. + +.. _get_started: + +Getting Started +--------------- + +To begin with, try out MLC LLM support for int4-quantized Llama2 7B. +It is recommended to have at least 6GB free VRAM to run it. + +.. tabs:: + + .. tab:: Python + + **Install MLC Chat Python**. :doc:`MLC LLM ` is available via pip. + It is always recommended to install it in an isolated conda virtual environment. + + **Download pre-quantized weights**. The commands below download the int4-quantized Llama2-7B from HuggingFace: + + .. code:: bash + + git lfs install && mkdir dist/ + git clone https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC \ + dist/Llama-2-7b-chat-hf-q4f16_1-MLC + + **Download pre-compiled model library**. The pre-compiled model library is available as below: + + .. code:: bash + + git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt_libs + + **Run in Python.** The following Python script showcases the Python API of MLC LLM and its stream capability: + + .. code:: python + + from mlc_chat import ChatModule + from mlc_chat.callback import StreamToStdout + + cm = ChatModule( + model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", + model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" + # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so + # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so + # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} + ) + cm.generate(prompt="What is the meaning of life?", progress_callback=StreamToStdout(callback_interval=2)) + + **Colab walkthrough.** A Jupyter notebook on `Colab `_ + is provided with detailed walkthrough of the Python API. + + **Documentation and tutorial.** Python API reference and its tutorials are `available online `_. + + .. figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/python-api.jpg + :width: 600 + :align: center + + MLC LLM Python API + + .. tab:: Command Line + + **Install MLC Chat CLI.** MLC Chat CLI is available via conda using the command below. + It is always recommended to install it in an isolated conda virtual environment. + For Windows/Linux users, make sure to have latest :ref:`Vulkan driver ` installed. + + .. code:: bash + + conda create -n mlc-chat-venv -c mlc-ai -c conda-forge mlc-chat-cli-nightly + conda activate mlc-chat-venv + + **Download pre-quantized weights**. The comamnds below download the int4-quantized Llama2-7B from HuggingFace: + + .. 
code:: bash + + git lfs install && mkdir dist/ + git clone https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC \ + dist/Llama-2-7b-chat-hf-q4f16_1-MLC + + **Download pre-compiled model library**. The pre-compiled model library is available as below: + + .. code:: bash + + git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt_libs + + **Run in command line**. + + .. code:: bash + + mlc_chat chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + + .. figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/Llama2-macOS.gif + :width: 500 + :align: center + + MLC LLM on CLI + + .. note:: + The MLC Chat CLI package is only built with Vulkan (Windows/Linux) and Metal (macOS). + To use other GPU backends such as CUDA and ROCm, please use the prebuilt Python package or build from source. + + .. tab:: Web Browser + + `WebLLM `__. MLC LLM generates performant code for WebGPU and WebAssembly, + so that LLMs can be run locally in a web browser without server resources. + + **Download pre-quantized weights**. This step is self-contained in WebLLM. + + **Download pre-compiled model library**. WebLLM automatically downloads WebGPU code to execute. + + **Check browser compatibility**. The latest Google Chrome provides WebGPU runtime and `WebGPU Report `__ as a useful tool to verify WebGPU capabilities of your browser. + + .. figure:: https://blog.mlc.ai/img/redpajama/web.gif + :width: 300 + :align: center + + MLC LLM on Web + + .. tab:: iOS + + **Install MLC Chat iOS**. It is available on AppStore: + + .. image:: https://developer.apple.com/assets/elements/badges/download-on-the-app-store.svg + :width: 135 + :target: https://apps.apple.com/us/app/mlc-chat/id6448482937 + + | + + **Requirement**. Llama2-7B model needs an iOS device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. + + **Tutorial and source code**. The source code of the iOS app is fully `open source `__, + and a :doc:`tutorial ` is included in documentation. + + .. figure:: https://blog.mlc.ai/img/redpajama/ios.gif + :width: 300 + :align: center + + MLC Chat on iOS + + .. tab:: Android + + **Install MLC Chat Android**. A prebuilt is available as an APK: + + .. image:: https://seeklogo.com/images/D/download-android-apk-badge-logo-D074C6882B-seeklogo.com.png + :width: 135 + :target: https://github.com/mlc-ai/binary-mlc-llm-libs/raw/main/mlc-chat.apk + + | + + **Requirement**. Llama2-7B model needs a device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. + The demo is tested on + + - Samsung S23 with Snapdragon 8 Gen 2 chip + - Redmi Note 12 Pro with Snapdragon 685 + - Google Pixel phones + + **Tutorial and source code**. The source code of the android app is fully `open source `__, + and a :doc:`tutorial ` is included in documentation. + + .. figure:: https://blog.mlc.ai/img/android/android-recording.gif + :width: 300 + :align: center + + MLC LLM on Android + + +.. toctree:: + :maxdepth: 1 + :caption: Get Started + :hidden: + + get_started/project_overview.rst + get_started/mlc_chat_config.rst + +.. toctree:: + :maxdepth: 1 + :caption: Build and Deploy Apps + :hidden: + + deploy/javascript.rst + deploy/rest.rst + deploy/cli.rst + deploy/python.rst + deploy/ios.rst + deploy/android.rst + +.. toctree:: + :maxdepth: 1 + :caption: Compile Models + :hidden: + + compilation/convert_weights.rst + compilation/compile_models.rst + compilation/define_new_models.rst + compilation/configure_quantization.rst + +.. 
toctree:: + :maxdepth: 1 + :caption: Model Prebuilts + :hidden: + + prebuilt_models.rst + prebuilt_models_deprecated.rst + +.. toctree:: + :maxdepth: 1 + :caption: Dependency Installation + :hidden: + + install/tvm.rst + install/mlc_llm.rst + install/conda.rst + install/gpu.rst + install/emcc.rst + +.. toctree:: + :maxdepth: 1 + :caption: Community + :hidden: + + community/guideline.rst + community/faq.rst + + +.. toctree:: + :maxdepth: 1 + :caption: Privacy + :hidden: + + privacy.rst diff --git a/docs/install/conda.rst b/docs/install/conda.rst new file mode 100644 index 0000000..305c0f6 --- /dev/null +++ b/docs/install/conda.rst @@ -0,0 +1,66 @@ +Install Conda +============= + +MLC LLM does not depend on, but generally recommends conda as a generic dependency manager, primarily because it creates unified cross-platform experience to make windows/Linux/macOS development equally easy. Moreover, conda is python-friendly and provides all the python packages needed for MLC LLM, such as numpy. + +.. contents:: Table of Contents + :depth: 2 + + +Install Miniconda +----------------- + +**Use installer.** Miniconda, a minimal distribution of conda, comes with out-of-box installer across Windows/macOS/Linux. Please refer to its `official website `_ link for detailed instructions. + +**Set libmamba as the dependency solver.** The default dependency solver in conda could be slow in certain scenarios, and it is always recommended to upgrade it to libmamba, a faster solver. + +.. code-block:: bash + :caption: Set libmamba as the default solver + + # update conda + conda update --yes -n base -c defaults conda + # install `conda-libmamba-solver` + conda install --yes -n base conda-libmamba-solver + # set it as the default solver + conda config --set solver libmamba + +.. note:: + Conda is a generic dependency manager, which is not necessarily related to any Python distributions. + In fact, some of our tutorials recommends to use conda to install cmake, git and rust for its unified experience across OS platforms. + + +Validate installation +--------------------- + +**Step 1. Check conda-arch mismatch.** Nowadays macOS runs on two different architectures: arm64 and x86_64, which could particularly lead to many misuses in MLC LLM, where the error message hints about "architecture mismatch". Use the following command to make sure particular conda architecture is installed accordingly: + +.. code-block:: bash + :caption: Check conda architecture + + >>> conda info | grep platform + # for arm mac + platform : osx-arm64 + # for x86 mac + platform : osx-64 + +**Step 2. Check conda virtual environment.** If you have installed python in your conda virtual environment, make sure conda, Python and pip are all from this environment: + +.. code-block:: bash + :caption: Check conda virtual environment (macOS, Linux) + + >>> echo $CONDA_PREFIX + /.../miniconda3/envs/mlc-doc-venv + >>> which python + /.../miniconda3/envs/mlc-doc-venv/bin/python + >>> which pip + /.../miniconda3/envs/mlc-doc-venv/bin/pip + +.. code-block:: bat + :caption: Check conda virtual environment (Windows) + + >>> echo $Env:CONDA_PREFIX + \...\miniconda3\envs\mlc-doc-venv + >>> Get-Command python.exe + \...\miniconda3\envs\mlc-doc-venv\bin\python.exe + >>> Get-Command pip.exe + \...\miniconda3\envs\mlc-doc-venv\bin\pip.exe diff --git a/docs/install/emcc.rst b/docs/install/emcc.rst new file mode 100644 index 0000000..9320be4 --- /dev/null +++ b/docs/install/emcc.rst @@ -0,0 +1,64 @@ +.. 
_install-web-build: + +Install Wasm Build Environment +============================== + +This page describes the steps to setup build environment for WebAssembly and WebGPU builds. + +Step 1: Install EMSDK +--------------------- + +Emscripten is an LLVM-based compiler that compiles C/C++ source code to WebAssembly. +We need to install emscripten for webgpu build. + +- Please follow the installation instruction `here `__ + to install the latest emsdk. +- Source path/to/emsdk_env.sh so emcc is reachable from PATH and the command emcc works. + +Validate that emcc is accessible in shell + +.. code:: bash + + emcc --version + +Step 2: Set TVM_HOME +-------------------- + +We need to set a path to a tvm source in order to build tvm runtime. +Note that you do not need to build tvm unity from the source. The source here is only used to build the web runtime component. +Set environment variable in your shell startup profile in to point to ``3rdparty/tvm`` + +.. code:: bash + + export TVM_HOME=/path/to/3rdparty/tvm + + +Step 3: Prepare Wasm Runtime +---------------------------- + +First, we need to obtain a copy of the mlc-llm source code for the setup script + +.. code:: bash + + git clone https://github.com/mlc-ai/mlc-llm.git --recursive + cd mlc-llm + +Now we can prepare wasm runtime using the script in mlc-llm repo + +.. code:: bash + + ./scripts/prep_emcc_deps.sh + +We can then validate the outcome + +.. code:: bash + + >>> echo ${TVM_HOME} + + /path/set/in/step2 + + >>> ls -l ${TVM_HOME}/web/dist/wasm/*.bc + + tvmjs_support.bc + wasm_runtime.bc + webgpu_runtime.bc diff --git a/docs/install/gpu.rst b/docs/install/gpu.rst new file mode 100644 index 0000000..608c238 --- /dev/null +++ b/docs/install/gpu.rst @@ -0,0 +1,201 @@ +GPU Drivers and SDKs +==================== + +.. contents:: Table of Contents + :depth: 2 + +MLC LLM is a universal deployment solution that allows efficient CPU/GPU code generation without AutoTVM-based performance tuning. This section focuses on generic GPU environment setup and troubleshooting. + +CUDA +---- + +CUDA is required to compile and run models with CUDA backend. + +Installation +^^^^^^^^^^^^ + +If you have a NVIDIA GPU and you want to use models compiled with CUDA +backend, you should install CUDA, which can be downloaded from +`here `__. + +Validate Installation +^^^^^^^^^^^^^^^^^^^^^ + +To verify you have correctly installed CUDA runtime and NVIDIA driver, run ``nvidia-smi`` in command line and see if you can get the GPU information. + +ROCm +---- + +ROCm is required to compile and run models with ROCm backend. + +Installation +^^^^^^^^^^^^ + +Right now MLC LLM only supports ROCm 5.6. +If you have AMD GPU and you want to use models compiled with ROCm +backend, you should install ROCm 5.6 from `here `__. + +Validate Installation +^^^^^^^^^^^^^^^^^^^^^ + +To verify you have correctly installed ROCm 5.6, run ``rocm-smi`` in command line. +If you see the list of AMD devices printed out in a table, it means the ROCm is correctly installed. + +.. _vulkan_driver: + +Vulkan Driver +------------- + +Installation +^^^^^^^^^^^^ + +To run pre-trained models (e.g. pulled from MLC-AI's Hugging Face repository) compiled with Vulkan backend, you are expected to install Vulkan driver on your machine. + +Please check `this +page `__ and find the +Vulkan driver according to your GPU vendor. + +AMD Radeon and Radeon PRO +######################### + +For AMD Radeon and Radeon PRO users, please download AMD's drivers from official website (`Linux `__ / `Windows `__). 
+For Linux users, after you installed the ``amdgpu-install`` package, you can follow the instructions in its `documentation `__ to install +the driver. We recommend you installing ROCr OpenCL and PRO Vulkan (proprietary) for best performance, which can be done by running the following command: + +.. code:: bash + + amdgpu-install --usecase=graphics,opencl --opencl=rocr --vulkan=pro --no-32 + +Validate Installation +^^^^^^^^^^^^^^^^^^^^^ + +To verify whether Vulkan installation is successful or not, you are encouraged to install ``vulkaninfo``, below are the instructions to install ``vulkaninfo`` on different platforms: + +.. tabs :: + + .. code-tab :: bash Ubuntu/Debian + + sudo apt-get update + sudo apt-get install vulkan-tools + + .. code-tab :: bash Windows + + # It comes with your GPU driver + + .. code-tab :: bash Fedora + + sudo dnf install vulkan-tools + + .. code-tab :: bash Arch Linux + + sudo pacman -S vulkan-tools + # Arch Linux has maintained an awesome wiki page for Vulkan which you can refer to for troubleshooting: https://wiki.archlinux.org/title/Vulkan + + .. code-tab :: bash Other Distributions + + # Please install Vulkan SDK for your platform + # https://vulkan.lunarg.com/sdk/home + + +After installation, you can run ``vulkaninfo`` in command line and see if you can get the GPU information. + +.. note:: + WSL support for Windows is work-in-progress at the moment. Please do not use WSL on Windows to run Vulkan. + +Vulkan SDK +---------- + +Vulkan SDK is required for compiling models to Vulkan backend. To build TVM Unity compiler from source, you will need to install Vulkan SDK as a dependency, but our :doc:`pre-built wheels <../install/mlc_llm>` already ships with Vulkan SDK. + +Check Vulkan SDK installation guide according to your platform: + +.. tabs :: + + .. tab :: Windows + + `Getting Started with the Windows Tarball Vulkan SDK `__ + + .. tab :: Linux + + For Ubuntu user, please check + `Getting Started with the Ubuntu Vulkan SDK `__ + + For other Linux distributions, please check + `Getting Started with the Linux Tarball Vulkan SDK `__ + + .. tab :: Mac + + `Getting Started with the macOS Vulkan SDK `__ + +Please refer to installation and setup page for next steps to build TVM-Unity from source. + +OpenCL SDK +---------- + +OpenCL SDK is only required when you want to build your own models for OpenCL backend. Please refer to `OpenCL's Github Repository `__ for installation guide of OpenCL-SDK. + +Orange Pi 5 (RK3588 based SBC) +------------------------------ + +OpenCL SDK and Mali GPU driver is required to compile and run models for OpenCL backend. + +Installation +^^^^^^^^^^^^ + +* Download and install the Ubuntu 22.04 for your board from `here `__ + +* Download and install ``libmali-g610.so`` + +.. code-block:: bash + + cd /usr/lib && sudo wget https://github.com/JeffyCN/mirrors/raw/libmali/lib/aarch64-linux-gnu/libmali-valhall-g610-g6p0-x11-wayland-gbm.so + +* Check if file ``mali_csffw.bin`` exist under path ``/lib/firmware``, if not download it with command: + +.. code-block:: bash + + cd /lib/firmware && sudo wget https://github.com/JeffyCN/mirrors/raw/libmali/firmware/g610/mali_csffw.bin + +* Download OpenCL ICD loader and manually add libmali to ICD + +.. code-block:: bash + + sudo apt update + sudo apt install mesa-opencl-icd + sudo mkdir -p /etc/OpenCL/vendors + echo "/usr/lib/libmali-valhall-g610-g6p0-x11-wayland-gbm.so" | sudo tee /etc/OpenCL/vendors/mali.icd + +* Download and install ``libOpenCL`` + +.. 
code-block:: bash + + sudo apt install ocl-icd-opencl-dev + +* Download and install dependencies for Mali OpenCL + +.. code-block:: bash + + sudo apt install libxcb-dri2-0 libxcb-dri3-0 libwayland-client0 libwayland-server0 libx11-xcb1 + +* Download and install clinfo to check if OpenCL successfully installed + +.. code-block:: bash + + sudo apt install clinfo + +Validate Installation +^^^^^^^^^^^^^^^^^^^^^ + +To verify you have correctly installed OpenCL runtime and Mali GPU driver, run ``clinfo`` in command line and see if you can get the GPU information. +You are expect to see the following information: + +.. code-block:: bash + + $ clinfo + arm_release_ver: g13p0-01eac0, rk_so_ver: 3 + Number of platforms 2 + Platform Name ARM Platform + Platform Vendor ARM + Platform Version OpenCL 2.1 v1.g6p0-01eac0.2819f9d4dbe0b5a2f89c835d8484f9cd + Platform Profile FULL_PROFILE + ... diff --git a/docs/install/mlc_llm.rst b/docs/install/mlc_llm.rst new file mode 100644 index 0000000..004ee15 --- /dev/null +++ b/docs/install/mlc_llm.rst @@ -0,0 +1,240 @@ +.. _install-mlc-packages: + +Install MLC LLM Python Package +============================== + +.. contents:: Table of Contents + :local: + :depth: 2 + +MLC LLM Python Package can be installed directly from a prebuilt developer package, or built from source. + +Option 1. Prebuilt Package +-------------------------- + +We provide nightly built pip wheels for MLC-LLM via pip. +Select your operating system/compute platform and run the command in your terminal: + +.. note:: + ❗ Whenever using Python, it is highly recommended to use **conda** to manage an isolated Python environment to avoid missing dependencies, incompatible versions, and package conflicts. + +.. tabs:: + + .. tab:: Linux + + .. tabs:: + + .. tab:: CPU + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly mlc-ai-nightly + + .. tab:: CUDA 11.7 + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-cu117 mlc-ai-nightly-cu117 + + .. tab:: CUDA 11.8 + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-cu118 mlc-ai-nightly-cu118 + + .. tab:: CUDA 12.1 + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-cu121 mlc-ai-nightly-cu121 + + .. tab:: CUDA 12.2 + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-cu122 mlc-ai-nightly-cu122 + + .. tab:: ROCm 5.6 + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-rocm56 mlc-ai-nightly-rocm56 + + .. tab:: ROCm 5.7 + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-rocm57 mlc-ai-nightly-rocm57 + + .. tab:: Vulkan + + Supported in all Linux packages. + + .. note:: + + If encountering issues with GLIBC not found, please install the latest glibc in conda: + + .. code-block:: bash + + conda install -c conda-forge libgcc-ng + + Besides, we would recommend using Python 3.11; so if you are creating a new environment, + you could use the following command: + + .. code-block:: bash + + conda create --name mlc-prebuilt python=3.11 + + .. tab:: macOS + + .. tabs:: + + .. 
tab:: CPU + Metal + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly mlc-ai-nightly + + .. note:: + + Always check if conda is installed properly in macOS using the command below: + + .. code-block:: bash + + conda info | grep platform + + It should return "osx-64" for Mac with Intel chip, and "osx-arm64" for Mac with Apple chip. + + .. tab:: Windows + + .. tabs:: + + .. tab:: CPU + Vulkan + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly mlc-ai-nightly + + .. note:: + If encountering the error below: + + .. code-block:: bash + + FileNotFoundError: Could not find module 'path\to\site-packages\tvm\tvm.dll' (or one of its dependencies). Try using the full path with constructor syntax. + + It is likely `zstd`, a dependency to LLVM, was missing. Please use the command below to get it installed: + + .. code-block:: bash + + conda install zstd + + +Then you can verify installation in command line: + +.. code-block:: bash + + python -c "import mlc_chat; print(mlc_chat)" + # Prints out: + +| + +.. _mlcchat_build_from_source: + +Option 2. Build from Source +--------------------------- + +We also provide options to build mlc runtime libraries ``mlc_chat`` from source. +This step is useful when you want to make modification or obtain a specific version of mlc runtime. + + +**Step 1. Set up build dependency.** To build from source, you need to ensure that the following build dependencies are satisfied: + +* CMake >= 3.24 +* Git +* `Rust and Cargo `_, required by Hugging Face's tokenizer +* One of the GPU runtimes: + + * CUDA >= 11.8 (NVIDIA GPUs) + * Metal (Apple GPUs) + * Vulkan (NVIDIA, AMD, Intel GPUs) + +.. code-block:: bash + :caption: Set up build dependencies in Conda + + # make sure to start with a fresh environment + conda env remove -n mlc-chat-venv + # create the conda environment with build dependency + conda create -n mlc-chat-venv -c conda-forge \ + "cmake>=3.24" \ + rust \ + git \ + python=3.11 + # enter the build environment + conda activate mlc-chat-venv + +.. note:: + For runtime, :doc:`TVM Unity ` compiler is not a dependency for MLCChat CLI or Python API. Only TVM's runtime is required, which is automatically included in `3rdparty/tvm `_. + However, if you would like to compile your own models, you need to follow :doc:`TVM Unity `. + +**Step 2. Configure and build.** A standard git-based workflow is recommended to download MLC LLM, after which you can specify build requirements with our lightweight config generation tool: + +.. code-block:: bash + :caption: Configure and build + + # clone from GitHub + git clone --recursive https://github.com/mlc-ai/mlc-llm.git && cd mlc-llm/ + # create build directory + mkdir -p build && cd build + # generate build configuration + python3 ../cmake/gen_cmake_config.py + # build mlc_llm libraries + cmake .. && cmake --build . --parallel $(nproc) && cd .. + +.. note:: + If you are using CUDA and your compute capability is above 80, then it is require to build with + ``set(USE_FLASHINFER ON)``. Otherwise, you may run into ``Cannot find PackedFunc`` issue during + runtime. + + To check your CUDA compute capability, you can use ``nvidia-smi --query-gpu=compute_cap --format=csv``. + +**Step 3. Install via Python.** We recommend that you install ``mlc_chat`` as a Python package, giving you +access to ``mlc_chat.compile``, ``mlc_chat.ChatModule``, and the CLI. 
+There are two ways to do so: + + .. tabs :: + + .. code-tab :: bash Install via environment variable + + export PYTHONPATH=/path-to-mlc-llm/python:$PYTHONPATH + + .. code-tab :: bash Install via pip local project + + conda activate your-own-env + which python # make sure python is installed, expected output: path_to_conda/envs/your-own-env/bin/python + cd /path-to-mlc-llm/python + pip install -e . + +**Step 4. Validate installation.** You may validate if MLC libarires and mlc_chat CLI is compiled successfully using the following command: + +.. code-block:: bash + :caption: Validate installation + + # expected to see `libmlc_llm.so` and `libtvm_runtime.so` + ls -l ./build/ + # expected to see help message + mlc_chat chat -h + +Finally, you can verify installation in command line. You should see the path you used to build from source with: + +.. code:: bash + + python -c "import mlc_chat; print(mlc_chat)" diff --git a/docs/install/tvm.rst b/docs/install/tvm.rst new file mode 100644 index 0000000..f5cb460 --- /dev/null +++ b/docs/install/tvm.rst @@ -0,0 +1,313 @@ +.. _install-tvm-unity: + +Install TVM Unity Compiler +========================== + +.. contents:: Table of Contents + :local: + :depth: 2 + +`TVM Unity `__, the latest development in Apache TVM, is required to build MLC LLM. Its features include: + +- High-performance CPU/GPU code generation instantly without tuning; +- Dynamic shape and symbolic shape tracking by design; +- Supporting both inference and training; +- Productive python-first compiler implementation. As a concrete example, MLC LLM compilation is implemented in pure python using its API. + +TVM Unity can be installed directly from a prebuilt developer package, or built from source. + +.. _tvm-unity-prebuilt-package: + +Option 1. Prebuilt Package +-------------------------- + +A nightly prebuilt Python package of Apache TVM Unity is provided. + +.. note:: + ❗ Whenever using Python, it is highly recommended to use **conda** to manage an isolated Python environment to avoid missing dependencies, incompatible versions, and package conflicts. + +.. tabs:: + + .. tab:: Linux + + .. tabs:: + + .. tab:: CPU + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly + + .. tab:: CUDA 11.7 + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu117 + + .. tab:: CUDA 11.8 + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu118 + + .. tab:: CUDA 12.1 + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu121 + + .. tab:: CUDA 12.2 + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu122 + + .. tab:: ROCm 5.6 + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-rocm56 + + .. tab:: ROCm 5.7 + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-rocm57 + + .. tab:: Vulkan + + Supported in all Linux packages. + + .. note:: + + If encountering issues with GLIBC not found, please install the latest glibc in conda: + + .. code-block:: bash + + conda install -c conda-forge libgcc-ng + + .. tab:: macOS + + .. tabs:: + + .. 
tab:: CPU + Metal + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly + + .. note:: + + Always check if conda is installed properly in macOS using the command below: + + .. code-block:: bash + + conda info | grep platform + + It should return "osx-64" for Mac with Intel chip, and "osx-arm64" for Mac with Apple chip. + + .. tab:: Windows + + .. tabs:: + + .. tab:: CPU + Vulkan + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly + + .. note:: + If encountering the error below: + + .. code-block:: bash + + FileNotFoundError: Could not find module 'path\to\site-packages\tvm\tvm.dll' (or one of its dependencies). Try using the full path with constructor syntax. + + It is likely `zstd`, a dependency to LLVM, was missing. Please use the command below to get it installed: + + .. code-block:: bash + + conda install zstd + +.. _tvm-unity-build-from-source: + +Option 2. Build from Source +--------------------------- + +While it is generally recommended to always use the prebuilt TVM Unity, if you require more customization, you may need to build it from source. **NOTE.** this should only be attempted if you are familiar with the intricacies of C++, CMake, LLVM, Python, and other related systems. + +.. collapse:: Details + + **Step 1. Set up build dependency.** To build from source, you need to ensure that the following build dependencies are met: + + - CMake >= 3.24 + - LLVM >= 15 + - Git + - (Optional) CUDA >= 11.8 (targeting NVIDIA GPUs) + - (Optional) Metal (targeting Apple GPUs such as M1 and M2) + - (Optional) Vulkan (targeting NVIDIA, AMD, Intel and mobile GPUs) + - (Optional) OpenCL (targeting NVIDIA, AMD, Intel and mobile GPUs) + + .. note:: + - To target NVIDIA GPUs, either CUDA or Vulkan is required (CUDA is recommended); + - For AMD and Intel GPUs, Vulkan is necessary; + - When targeting Apple (macOS, iOS, iPadOS), Metal is a mandatory dependency; + - Some Android devices only support OpenCL, but most of them support Vulkan. + + To easiest way to manage dependency is via conda, which maintains a set of toolchains including LLVM across platforms. To create the environment of those build dependencies, one may simply use: + + .. code-block:: bash + :caption: Set up build dependencies in conda + + # make sure to start with a fresh environment + conda env remove -n tvm-build-venv + # create the conda environment with build dependency + conda create -n tvm-build-venv -c conda-forge \ + "llvmdev>=15" \ + "cmake>=3.24" \ + git + # enter the build environment + conda activate tvm-build-venv + + **Step 2. Configure and build.** Standard git-based workflow are recommended to download Apache TVM Unity, and then specify build requirements in ``config.cmake``: + + .. code-block:: bash + :caption: Download TVM Unity from GitHub + + # clone from GitHub + git clone --recursive git@github.com:mlc-ai/relax.git tvm-unity && cd tvm-unity + # create the build directory + rm -rf build && mkdir build && cd build + # specify build requirements in `config.cmake` + cp ../cmake/config.cmake . + + .. note:: + We are temporarily using `mlc-ai/relax `_ instead, which comes with several temporary outstanding changes that we will upstream to Apache TVM's `unity branch `_. + + We want to specifically tweak the following flags by appending them to the end of the configuration file: + + .. 
code-block:: bash
+ :caption: Configure build in ``config.cmake``
+
+ # controls default compilation flags
+ echo "set(CMAKE_BUILD_TYPE RelWithDebInfo)" >> config.cmake
+ # LLVM is a must dependency
+ echo "set(USE_LLVM \"llvm-config --ignore-libllvm --link-static\")" >> config.cmake
+ echo "set(HIDE_PRIVATE_SYMBOLS ON)" >> config.cmake
+ # GPU SDKs, turn on if needed
+ echo "set(USE_CUDA OFF)" >> config.cmake
+ echo "set(USE_METAL OFF)" >> config.cmake
+ echo "set(USE_VULKAN OFF)" >> config.cmake
+ echo "set(USE_OPENCL OFF)" >> config.cmake
+ # FlashInfer related, requires CUDA w/ compute capability 80;86;89;90
+ echo "set(USE_FLASHINFER OFF)" >> config.cmake
+ echo "set(FLASHINFER_CUDA_ARCHITECTURES YOUR_CUDA_COMPUTE_CAPABILITY_HERE)" >> config.cmake
+ echo "set(CMAKE_CUDA_ARCHITECTURES YOUR_CUDA_COMPUTE_CAPABILITY_HERE)" >> config.cmake
+
+ .. note::
+ ``HIDE_PRIVATE_SYMBOLS`` is a configuration option that enables the ``-fvisibility=hidden`` flag. This flag helps prevent potential symbol conflicts between TVM and PyTorch. These conflicts arise due to the frameworks shipping LLVMs of different versions.
+
+ `CMAKE_BUILD_TYPE `_ controls the default compilation flags:
+
+ - ``Debug`` sets ``-O0 -g``
+ - ``RelWithDebInfo`` sets ``-O2 -g -DNDEBUG`` (recommended)
+ - ``Release`` sets ``-O3 -DNDEBUG``
+
+ .. note::
+ If you are using CUDA and your compute capability is above 80, then it is required to build with
+ ``set(USE_FLASHINFER ON)``. Otherwise, you may run into a ``Cannot find PackedFunc`` issue during
+ runtime.
+
+ To check your CUDA compute capability, you can use ``nvidia-smi --query-gpu=compute_cap --format=csv``.
+
+ Once ``config.cmake`` is edited accordingly, kick off the build with the commands below:
+
+ .. code-block:: bash
+ :caption: Build ``libtvm`` using cmake
+
+ cmake .. && cmake --build . --parallel $(nproc)
+
+ A successful build should produce ``libtvm`` and ``libtvm_runtime`` under the ``/path-tvm-unity/build/`` directory.
+
+ After leaving the build environment ``tvm-build-venv``, there are two ways to install the successful build into your environment:
+
+ .. tabs ::
+
+ .. code-tab :: bash Install via environment variable
+
+ export PYTHONPATH=/path-to-tvm-unity/python:$PYTHONPATH
+
+ .. code-tab :: bash Install via pip local project
+
+ conda activate your-own-env
+ conda install python # make sure python is installed
+ cd /path-to-tvm-unity/python
+ pip install -e .
+
+.. `|` adds a blank line
+
+|
+
+.. _tvm-unity-validate:
+
+Validate TVM Installation
+-------------------------
+
+Using a compiler infrastructure with multiple language bindings could be error-prone.
+Therefore, it is highly recommended to validate the TVM Unity installation before use.
+
+**Step 1. Locate TVM Python package.** The following command can help confirm that TVM is properly installed as a python package and provide the location of the TVM python package:
+
+.. code-block:: bash
+
+ >>> python -c "import tvm; print(tvm.__file__)"
+ /some-path/lib/python3.11/site-packages/tvm/__init__.py
+
+**Step 2. Confirm which TVM library is used.** When maintaining multiple builds or installations of TVM, it becomes important to double check if the python package is using the proper ``libtvm`` with the following command:
+
+.. code-block:: bash
+
+ >>> python -c "import tvm; print(tvm._ffi.base._LIB)"
+
+
+**Step 3. Reflect TVM build option.** Sometimes when a downstream application fails, the cause is likely a wrong TVM commit or wrong build flags.
To find it out, the following commands will be helpful: + +.. code-block:: bash + + >>> python -c "import tvm; print('\n'.join(f'{k}: {v}' for k, v in tvm.support.libinfo().items()))" + ... # Omitted less relevant options + GIT_COMMIT_HASH: 4f6289590252a1cf45a4dc37bce55a25043b8338 + HIDE_PRIVATE_SYMBOLS: ON + USE_LLVM: llvm-config --link-static + LLVM_VERSION: 15.0.7 + USE_VULKAN: OFF + USE_CUDA: OFF + CUDA_VERSION: NOT-FOUND + USE_OPENCL: OFF + USE_METAL: ON + USE_ROCM: OFF + +.. note:: + ``GIT_COMMIT_HASH`` indicates the exact commit of the TVM build, and it can be found on GitHub via ``https://github.com/mlc-ai/relax/commit/$GIT_COMMIT_HASH``. + +**Step 4. Check device detection.** Sometimes it could be helpful to understand if TVM could detect your device at all with the following commands: + +.. code-block:: bash + + >>> python -c "import tvm; print(tvm.metal().exist)" + True # or False + >>> python -c "import tvm; print(tvm.cuda().exist)" + False # or True + >>> python -c "import tvm; print(tvm.vulkan().exist)" + False # or True + +Please note that the commands above verify the presence of an actual device on the local machine for the TVM runtime (not the compiler) to execute properly. However, TVM compiler can perform compilation tasks without requiring a physical device. As long as the necessary toolchain, such as NVCC, is available, TVM supports cross-compilation even in the absence of an actual device. diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..32bb245 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/prebuilt_models.rst b/docs/prebuilt_models.rst new file mode 100644 index 0000000..6d848d5 --- /dev/null +++ b/docs/prebuilt_models.rst @@ -0,0 +1,773 @@ +.. _Model Prebuilts: + +Model Prebuilts +================== + +.. contents:: Table of Contents + :depth: 3 + :local: + +.. _model-prebuilts-overview: + +Overview +-------- + +MLC-LLM is a universal solution for deploying different language models. Any models that can be described in `TVM Relax `__ +(a general representation for Neural Networks and can be imported from models written in PyTorch) can be recognized by MLC-LLM and thus deployed to different backends with the +help of :doc:`TVM Unity `. + +There are two ways to run a model on MLC-LLM (this page focuses on the second one): + +1. Compile your own models following :doc:`the model compilation page `. +2. Use off-the-shelf prebuilt models following this current page. + +In order to run a specific model on MLC-LLM, you need: + +**1. A model library:** a binary file containing the end-to-end functionality to inference a model (e.g. ``Llama-2-7b-chat-hf-q4f16_1-cuda.so``). +See the full list of all precompiled model libraries `here `__. 
+ +**2. Compiled weights:** a folder containing multiple files that store the compiled and quantized weights of a model +(e.g. https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC). See the full list of all precompiled weights `here `__. + +In this page, we first quickly go over :ref:`how to use prebuilts ` for different platforms, +then track what current :ref:`prebuilt models we provide `. + + +.. _using-model-prebuilts: + +Using Prebuilt Models for Different Platforms +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +We quickly go over how to use prebuilt models for each platform. You can find detailed instruction on each platform's corresponding page. + +.. _using-prebuilt-models-cli: + +**Prebuilt Models on CLI / Python** + +For more, please see :doc:`the CLI page `, and the :doc:`the Python page `. + +.. collapse:: Click to show details + + First create the conda environment if you have not done so. + + .. code:: shell + + conda create -n mlc-chat-venv -c mlc-ai -c conda-forge mlc-chat-cli-nightly + conda activate mlc-chat-venv + conda install git git-lfs + git lfs install + + Download the prebuilt model libraries from github. + + .. code:: shell + + mkdir dist/ + git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt_libs + + Run the model with CLI: + + .. code:: shell + + mlc_chat chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + + + To run the model with Python API, see :doc:`the Python page ` (all other downloading steps are the same as CLI). + + +.. for a blank line + +| + +.. _using-prebuilt-models-ios: + +**Prebuilt Models on iOS** + +For more, please see :doc:`the iOS page `. + +.. collapse:: Click to show details + + The `iOS app `_ has builtin RedPajama-3B and Mistral-7B-Instruct-v0.2 support. + + All prebuilt models with an entry in ``iOS`` in the :ref:`model library table ` are supported by iOS. Namely, we have: + + .. list-table:: Prebuilt Models for iOS + :widths: 15 15 15 15 + :header-rows: 1 + + * - Model Code + - Model Series + - Quantization Mode + - MLC HuggingFace Weights Repo + * - `Mistral-7B-Instruct-v0.2-q3f16_1` + - `Mistral `__ + - * Weight storage data type: int3 + * Running data type: float16 + * Symmetric quantization + - `link `__ + * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` + - `RedPajama `__ + - * Weight storage data type: int4 + * Running data type: float16 + * Symmetric quantization + - `link `__ + * - `phi-2-q4f16_1` + - `Microsoft Phi-2 `__ + - * Weight storage data type: int4 + * Running data type: float16 + * Symmetric quantization + - `link `__ +.. for a blank line + +| + +.. _prebuilt-models-android: + +**Prebuilt Models on Android** + +For more, please see :doc:`the Android page `. + +.. collapse:: Click to show details + + The apk for demo Android app includes the following models. To add more, check out the Android page. + + .. list-table:: Prebuilt Models for Android + :widths: 15 15 15 15 + :header-rows: 1 + + * - Model code + - Model Series + - Quantization Mode + - Hugging Face repo + * - `Llama-2-7b-q4f16_1` + - `Llama `__ + - * Weight storage data type: int4 + * Running data type: float16 + * Symmetric quantization + - `link `__ + * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` + - `RedPajama `__ + - * Weight storage data type: int4 + * Running data type: float16 + * Symmetric quantization + - `link `__ +.. for a blank line + +| + +.. 
_supported-model-architectures: + +Level 1: Supported Model Architectures (The All-In-One Table) +------------------------------------------------------------- + +For each model architecture (e.g. Llama), there are multiple variants (e.g. CodeLlama, WizardLM). The variants share the same code for inference and only differ in their weights. In other words, running CodeLlama and WizardLM can use the same model library file (specified in Level 2 tables), but different precompiled weights (specified in Level 3 tables). Note that we have not provided prebuilt weights for all model variants. + +Each entry below hyperlinks to the corresponding level 2 and level 3 tables. + +MLC-LLM supports the following model architectures: + +.. list-table:: Supported Model Architectures + :widths: 10 10 15 15 + :header-rows: 1 + + * - Model Architecture + - Support + - Available MLC Prebuilts + - Unavailable in MLC Prebuilts + * - `LLaMA `__ + - * :ref:`Prebuilt Model Library ` + * `MLC Implementation `__ + - * :ref:`Llama-2-chat ` + - * `Code Llama `__ + * `Vicuna `__ + * `WizardLM `__ + * `WizardCoder (new) `__ + * `OpenOrca Platypus2 `__ + * `FlagAlpha Llama-2 Chinese `__ + * `georgesung Llama-2 Uncensored `__ + * `Alpaca `__ + * `Guanaco `__ + * `OpenLLaMA `__ + * `Gorilla `__ + * `YuLan-Chat `__ + * - `Mistral `__ + - * :ref:`Prebuilt Model Library ` + * `MLC Implementation `__ + - * :ref:`Mistral-7B-Instruct-v0.2 ` + * :ref:`NeuralHermes-2.5-Mistral-7B ` + * :ref:`OpenHermes-2.5-Mistral-7B ` + * :ref:`WizardMath-7B-V1.1 ` + - + * - `GPT-NeoX `__ + - * :ref:`Prebuilt Model Library ` + * `MLC Implementation `__ + - * :ref:`RedPajama ` + - * `Dolly `__ + * `Pythia `__ + * `StableCode `__ + * - `GPTBigCode `__ + - * :ref:`Prebuilt Model Library ` + * `MLC Implementation `__ + - + - * `StarCoder `__ + * `SantaCoder `__ + * `WizardCoder (old) `__ + * - `Phi `__ + - * :ref:`Prebuilt Model Library ` + * `MLC Implementation `__ + - * :ref:`Phi-1_5 ` + * :ref:`Phi-2 ` + - + * - `GPT2 `__ + - * :ref:`Prebuilt Model Library ` + * `MLC Implementation `__ + - * :ref:`GPT2 ` + - + +If the model variant you are interested in uses one of these model architectures we support, +(but we have not provided the prebuilt weights yet), you can check out +:doc:`/compilation/convert_weights` on how to convert the weights. +Afterwards, you may follow :ref:`distribute-compiled-models` to upload your prebuilt +weights to hugging face, and submit a PR that adds an entry to this page, +contributing to the community. + +For models structured in an architecture we have not supported yet, you could: + +- Either `create a [Model Request] issue `__ which + automatically shows up on our `Model Request Tracking Board `__. + +- Or follow our tutorial :doc:`Define New Models `, which introduces how to bring a new model architecture to MLC-LLM. + + +.. _model-library-tables: + +Level 2: Model Library Tables (Precompiled Binary Files) +-------------------------------------------------------- + +As mentioned earlier, each model architecture corresponds to a different model library file. That is, you cannot use the same model library file to run ``RedPajama`` and ``Llama-2``. However, you can use the same ``Llama`` model library file to run ``Llama-2``, ``WizardLM``, ``CodeLlama``, etc, but just with different weight files (from tables in Level 3). + +Each table below demonstrates the pre-compiled model library files for each model architecture. This is categorized by: + +- **Size**: each size of model has its own distinct model library file (e.g. 
7B or 13B number of parameters) + +- **Platform**: the backend that the model library is intended to be run on (e.g. CUDA, ROCm, iphone, etc.) + +- **Quantization scheme**: the model library file also differs due to the quantization scheme used. For more on this, please see the :doc:`quantization page ` + (e.g. ``q3f16_1`` vs. ``q4f16_1``). + +Each entry links to the specific model library file found in `this github repo `__. + +If the model library you found is not available as a prebuilt, you can compile it yourself by following :doc:`the model compilation page `, +and submit a PR to the repo `binary-mlc-llm-libs `__ afterwards. + +.. _llama_library_table: + +Llama +^^^^^ +.. list-table:: Llama + :widths: 8 8 8 8 8 8 8 8 8 8 8 + :header-rows: 1 + :stub-columns: 1 + + * - + - CUDA + - ROCm + - Vulkan + + (Linux) + - Vulkan + + (Windows) + - Metal + + (M Chip) + - Metal + + (Intel) + - iOS + - Android + - webgpu + - mali + * - 7B + - `q4f16_1 `__ + + `q4f32_1 `__ + - + - `q4f16_1 `__ + + `q4f32_1 `__ + - + - `q4f16_1 `__ + + `q4f32_1 `__ + - + - + - `q4f16_1 `__ + + `q4f32_1 `__ + - `q4f16_1 `__ + + `q4f32_1 `__ + - + * - 13B + - `q4f16_1 `__ + - + - `q4f16_1 `__ + - + - `q4f16_1 `__ + - + - + - + - `q4f16_1 `__ + - + * - 34B + - + - + - + - + - + - + - + - + - + - + * - 70B + - `q4f16_1 `__ + - + - `q4f16_1 `__ + - + - `q4f16_1 `__ + - + - + - + - `q4f16_1 `__ + - + +.. _mistral_library_table: + +Mistral +^^^^^^^ +.. list-table:: Mistral + :widths: 8 8 8 8 8 8 8 8 8 8 8 + :header-rows: 1 + :stub-columns: 1 + + * - + - CUDA + - ROCm + - Vulkan + + (Linux) + - Vulkan + + (Windows) + - Metal + + (M Chip) + - Metal + + (Intel) + - iOS + - Android + - webgpu + - mali + * - 7B + - `q4f16_1 `__ + - + - `q4f16_1 `__ + - + - `q4f16_1 `__ + - + - `q3f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - + + +.. _gpt_neox_library_table: + +GPT-NeoX (RedPajama-INCITE) +^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. list-table:: GPT-NeoX (RedPajama-INCITE) + :widths: 8 8 8 8 8 8 8 8 8 8 8 + :header-rows: 1 + :stub-columns: 1 + + * - + - CUDA + - ROCm + - Vulkan + + (Linux) + - Vulkan + + (Windows) + - Metal + + (M Chip) + - Metal + + (Intel) + - iOS + - Android + - webgpu + - mali + * - 3B + - `q4f16_1 `__ + + `q4f32_1 `__ + - + - `q4f16_1 `__ + + `q4f32_1 `__ + - + - `q4f16_1 `__ + + `q4f32_1 `__ + - + - `q4f16_1 `__ + - `q4f16_1 `__ + + `q4f32_1 `__ + - `q4f16_1 `__ + + `q4f32_1 `__ + - + +.. _gpt_big_code_library_table: + +GPTBigCode +^^^^^^^^^^ + +.. list-table:: GPTBigCode + :widths: 8 8 8 8 8 8 8 8 8 8 8 + :header-rows: 1 + :stub-columns: 1 + + * - + - CUDA + - ROCm + - Vulkan + + (Linux) + - Vulkan + + (Windows) + - Metal + + (M Chip) + - Metal + + (Intel) + - iOS + - Android + - webgpu + - mali + * - 15B + - + - + - + - + - + - + - + - + - + - + +.. _phi_library_table: + +Phi +^^^ +.. list-table:: Phi + :widths: 8 8 8 8 8 8 8 8 8 8 8 + :header-rows: 1 + :stub-columns: 1 + + * - + - CUDA + - ROCm + - Vulkan + + (Linux) + - Vulkan + + (Windows) + - Metal + + (M Chip) + - Metal + + (Intel) + - iOS + - Android + - webgpu + - mali + * - Phi-2 + + (2.7B) + - `q0f16 `__ + + `q4f16_1 `__ + - + - `q0f16 `__ + + `q4f16_1 `__ + - + - `q0f16 `__ + + `q4f16_1 `__ + - + - + - + - `q0f16 `__ + + `q4f16_1 `__ + - + * - Phi-1.5 + + (1.3B) + - `q0f16 `__ + + `q4f16_1 `__ + - + - `q0f16 `__ + + `q4f16_1 `__ + - + - `q0f16 `__ + + `q4f16_1 `__ + - + - + - + - `q0f16 `__ + + `q4f16_1 `__ + - + +.. _gpt2_library_table: + +GPT2 +^^^^ +.. 
list-table:: GPT2 + :widths: 8 8 8 8 8 8 8 8 8 8 8 + :header-rows: 1 + :stub-columns: 1 + + * - + - CUDA + - ROCm + - Vulkan + + (Linux) + - Vulkan + + (Windows) + - Metal + + (M Chip) + - Metal + + (Intel) + - iOS + - Android + - webgpu + - mali + * - GPT2 + + (124M) + - `q0f16 `__ + - + - `q0f16 `__ + - + - `q0f16 `__ + - + - + - + - `q0f16 `__ + - + * - GPT2-med + + (355M) + - `q0f16 `__ + - + - `q0f16 `__ + - + - `q0f16 `__ + - + - + - + - `q0f16 `__ + - + +.. _model-variant-tables: + +Level 3: Model Variant Tables (Precompiled Weights) +--------------------------------------------------- + +Finally, for each model variant, we provide the precompiled weights we uploaded to hugging face. + +Each precompiled weight is categorized by its model size (e.g. 7B vs. 13B) and the quantization scheme (e.g. ``q3f16_1`` vs. ``q4f16_1``). We note that the weights are **platform-agnostic**. + +Each model variant also loads its conversation configuration from a pre-defined :ref:`conversation template`. Note that multiple model variants can share a common conversation template. + +Some of these files are uploaded by our community contributors--thank you! + +.. _llama2_variant_table: + +`Llama-2 `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``llama-2`` + +.. list-table:: Llama-2 + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 7B + - * `q4f16_1 (Chat) `__ + * `q4f32_1 (Chat) `__ + + * - 13B + - * `q4f16_1 `__ + + * - 70B + - * `q4f16_1 `__ + +.. _mistralinstruct_variant_table: + +`Mistral `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``mistral_default`` + +.. list-table:: Mistral + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 7B + - * `q3f16_1 (Instruct) `__ + * `q4f16_1 (Instruct) `__ + +.. _neuralhermes_variant_table: + +`NeuralHermes-2.5-Mistral `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``neural_hermes_mistral`` + +.. list-table:: Neural Hermes + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 7B + - * `q4f16_1 `__ + +.. _openhermes_variant_table: + +`OpenHermes-2-Mistral `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``open_hermes_mistral`` + +.. list-table:: Open Hermes + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 7B + - * `q4f16_1 `__ + + + +.. _wizardmathv1.1_variant_table: + +`WizardMath V1.1 `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``wizard_coder_or_math`` + +.. list-table:: WizardMath + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 7B + - * `q4f16_1 `__ + + +.. _red_pajama_variant_table: + +`RedPajama `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``redpajama_chat`` + +.. list-table:: Red Pajama + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 3B + - * `q4f16_1 (Chat) `__ + * `q4f32_1 (Chat) `__ + + +.. _phi_variant_table: + +`Phi `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``phi-2`` + +.. list-table:: Phi + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - Phi-2 (2.7B) + - * `q0f16 `__ + * `q4f16_1 `__ + * - Phi-1.5 (1.3B) + - * `q0f16 `__ + * `q4f16_1 `__ + + +.. 
_gpt2_variant_table: + +`GPT2 `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``gpt2`` + +.. list-table:: GPT2 + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - GPT2 (124M) + - * `q0f16 `__ + * - GPT2-medium (355M) + - * `q0f16 `__ + + +------------------ + + +.. _contribute-models-to-mlc-llm: + +Contribute Models to MLC-LLM +---------------------------- + +Ready to contribute your compiled models/new model architectures? Awesome! Please check :ref:`contribute-new-models` on how to contribute new models to MLC-LLM. diff --git a/docs/prebuilt_models_deprecated.rst b/docs/prebuilt_models_deprecated.rst new file mode 100644 index 0000000..c18f3f3 --- /dev/null +++ b/docs/prebuilt_models_deprecated.rst @@ -0,0 +1,845 @@ +Model Prebuilts from Old Flow (Deprecated) +========================================== + +**This page records the model libraries weights compiled under the old workflow (non-SLM).** + +**We will remove this page soon.** + +.. contents:: Table of Contents + :depth: 3 + :local: + +Overview +-------- + +MLC-LLM is a universal solution for deploying different language models. Any models that can be described in `TVM Relax `__ +(a general representation for Neural Networks and can be imported from models written in PyTorch) can be recognized by MLC-LLM and thus deployed to different backends with the +help of :doc:`TVM Unity `. + +There are two ways to run a model on MLC-LLM: + +1. Compile your own models following :doc:`the model compilation page `. +2. Use off-the-shelf prebuilts models following this current page. + +This page focuses on the second option: + +- Documenting :ref:`how to use prebuilts ` for various platforms, and +- Tracking what current :ref:`prebuilt models we provide `. + +Prerequisite: Model Libraries and Compiled Weights +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In order to run a specific model on MLC-LLM, you need: + +**1. A model library:** a binary file containing the end-to-end functionality to inference a model (e.g. ``Llama-2-7b-chat-hf-q4f16_1-cuda.so``). See the full list of all precompiled model libraries `here `__. + +**2. Compiled weights:** a folder containing multiple files that store the compiled and quantized weights of a model (e.g. https://huggingface.co/mlc-ai/mlc-chat-Llama-2-7b-chat-hf-q4f16_1). See the full list of all precompiled weights `here `__. + +.. _deprecated-using-model-prebuilts: + +Using Prebuilt Models for Different Platforms +--------------------------------------------- + +We quickly go over how to use prebuilt models for each platform. You can find detailed instruction on each platform's corresponding page. + +.. _deprecated-using-prebuilt-models-cli: + + +Prebuilt Models on CLI / Python +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For more, please see :doc:`the CLI page `, and the :doc:`the Python page `. + +.. collapse:: Click to show details + + First create the conda environment if you have not done so. + + .. code:: shell + + conda create -n mlc-chat-venv -c mlc-ai -c conda-forge mlc-chat-cli-nightly + conda activate mlc-chat-venv + conda install git git-lfs + git lfs install + + Download the prebuilt model libraries from github. + + .. code:: shell + + mkdir -p dist/prebuilt + git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt/lib + + Download the prebuilt model weights from hugging face for the model variant you want. + + .. 
code:: shell + + # Say we want to run rwkv-raven-7b-q8f16_0 + cd dist/prebuilt + git clone https://huggingface.co/mlc-ai/mlc-chat-rwkv-raven-7b-q8f16_0 + cd ../.. + + # The format being: + # cd dist/prebuilt + # git clone https://huggingface.co/mlc-ai/mlc-chat-[model-code] + # cd ../.. + # mlc_chat_cli --model [model-code] + + Run the model with CLI: + + .. code:: shell + + # For CLI + mlc_chat_cli --model rwkv-raven-7b-q8f16_0 + + To run the model with Python API, see :doc:`the Python page ` (all other downloading steps are the same as CLI). + + +.. for a blank line + +| + +.. _deprecated-using-prebuilt-models-ios: + +Prebuilt Models on iOS +^^^^^^^^^^^^^^^^^^^^^^ + +For more, please see :doc:`the iOS page `. + +.. collapse:: Click to show details + + The `iOS app `_ has builtin RedPajama-3B and Llama-2-7b support. + + All prebuilt models with an entry in ``iOS`` in the :ref:`model library table ` are supported by iOS. Namely, we have: + + .. list-table:: Prebuilt model libraries integrated in the iOS app + :widths: 15 15 15 + :header-rows: 1 + + * - Model library name + - Model Family + - Quantization Mode + * - `Llama-2-7b-chat-hf-q3f16_1` + - LLaMA + - * Weight storage data type: int3 + * Running data type: float16 + * Symmetric quantization + * - `vicuna-v1-7b-q3f16_0` + - LLaMA + - * Weight storage data type: int3 + * Running data type: float16 + * Symmetric quantization + * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` + - GPT-NeoX + - * Weight storage data type: int4 + * Running data type: float16 + * Symmetric quantization + + As for prebuilt model weights, the ones we have integrated into app are listed below: + + .. list-table:: Tested prebuilt model weights for iOS + :widths: 15 15 15 15 + :header-rows: 1 + + * - Model code + - Model Series + - Quantization Mode + - Hugging Face repo + * - `Llama-2-7b-q3f16_1` + - `Llama `__ + - * Weight storage data type: int3 + * Running data type: float16 + * Symmetric quantization + - `link `__ + * - `vicuna-v1-7b-q3f16_0` + - `Vicuna `__ + - * Weight storage data type: int3 + * Running data type: float16 + * Symmetric quantization + - `link `__ + * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` + - `RedPajama `__ + - * Weight storage data type: int4 + * Running data type: float16 + * Symmetric quantization + - `link `__ + + To run a model variant you compiled on your own, you can directly reuse the above + integrated prebuilt model libraries, as long as the model shares the + architecture and is compiled with the same quantization mode. + For example, if you compile `OpenLLaMA-7B `_ + with quantization mode ``q3f16_0``, then you can run the compiled OpenLLaMA model on iPhone + without rebuilding the iOS app by reusing the `vicuna-v1-7b-q3f16_0` model library. + Then you can upload the compiled weights to hugging face so that you can download + the weights in the app as shown below (for more on uploading to hugging face, + please check :ref:`distribute-compiled-models`). + + To add a model to the iOS app, follow the steps below: + + .. tabs:: + + .. tab:: Step 1 + + Open "MLCChat" app, click "Add model variant". + + .. image:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/iPhone-custom-1.png + :align: center + :width: 30% + + .. tab:: Step 2 + + Paste the repository URL of the model built on your own, and click "Add". + + You can refer to the link in the image as an example. + + .. 
image:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/iPhone-custom-2.png + :align: center + :width: 30% + + .. tab:: Step 3 + + After adding the model, you can download your model from the URL by clicking the download button. + + .. image:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/iPhone-custom-3.png + :align: center + :width: 30% + + .. tab:: Step 4 + + When the download is finished, click into the model and enjoy. + + .. image:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/iPhone-custom-4.png + :align: center + :width: 30% + +.. for a blank line + +| + +.. _deprecated-prebuilt-models-android: + +Prebuilt Models on Android +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For more, please see :doc:`the Android page `. + +.. collapse:: Click to show details + + The apk for demo Android app includes the following models. To add more, check out the Android page. + + .. list-table:: Prebuilt Models for Android + :widths: 15 15 15 15 + :header-rows: 1 + + * - Model code + - Model Series + - Quantization Mode + - Hugging Face repo + * - `Llama-2-7b-q4f16_1` + - `Llama `__ + - * Weight storage data type: int4 + * Running data type: float16 + * Symmetric quantization + - `link `__ + * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` + - `RedPajama `__ + - * Weight storage data type: int4 + * Running data type: float16 + * Symmetric quantization + - `link `__ +.. for a blank line + +| + +.. _deprecated-supported-model-architectures: + +Level 1: Supported Model Architectures (The All-In-One Table) +------------------------------------------------------------- + +For each model architecture (e.g. Llama), there are multiple variants (e.g. CodeLlama, WizardLM). The variants share the same code for inference and only differ in their weights. In other words, running CodeLlama and WizardLM can use the same model library file (specified in Level 2 tables), but different precompiled weights (specified in Level 3 tables). Note that we have not provided prebuilt weights for all model variants. + +Each entry below hyperlinks to the corresponding level 2 and level 3 tables. + +MLC-LLM supports the following model architectures: + +.. 
list-table:: Supported Model Architectures + :widths: 10 10 15 15 + :header-rows: 1 + + * - Model Architecture + - Support + - Available MLC Prebuilts + - Unavailable in MLC Prebuilts + * - `LLaMA `__ + - * :ref:`Prebuilt Model Library ` + * `MLC Implementation `__ + - * :ref:`Llama-2 ` + * :ref:`Code Llama ` + * :ref:`Vicuna ` + * :ref:`WizardLM ` + * :ref:`WizardMath ` + * :ref:`OpenOrca Platypus2 ` + * :ref:`FlagAlpha Llama-2 Chinese ` + * :ref:`georgesung Llama-2 Uncensored ` + - * `Alpaca `__ + * `Guanaco `__ + * `OpenLLaMA `__ + * `Gorilla `__ + * `YuLan-Chat `__ + * `WizardCoder (new) `__ + * - `GPT-NeoX `__ + - * :ref:`Prebuilt Model Library ` + * `MLC Implementation `__ + - * :ref:`RedPajama ` + - * `Dolly `__ + * `Pythia `__ + * `StableCode `__ + * - `GPT-J `__ + - * Prebuilt not compiled yet + * `MLC Implementation `__ + - + - * `MOSS `__ + * - `RWKV `__ + - * :ref:`Prebuilt Model Library ` + * `MLC Implementation `__ + - * :ref:`RWKV-raven ` + - + * - `MiniGPT `__ + - * Prebuilt not compiled yet + * `MLC Implementation `__ + - + - * `MiniGPT-4 `__ + * - `GPTBigCode `__ + - * :ref:`Prebuilt Model Library ` + * `MLC Implementation `__ + - * :ref:`WizardCoder (old) ` + - * `StarCoder `__ + * `SantaCoder `__ + * - `ChatGLM `__ + - * Prebuilt not compiled yet + * `MLC Implementation `__ + - + - * `ChatGLM2 `__ + * `CodeGeeX2 `__ + * - `StableLM `__ + - * Prebuilt not compiled yet + * `MLC Implementation `__ + - + - * `StableLM `__ + +If the model variant you are interested in uses one of these model architectures we support, +(but we have not provided the prebuilt weights yet), you can check out +:doc:`/compilation/convert_weights` and :doc:`/compilation/compile_models` on how to compile your own models. +Afterwards, you may follow :ref:`distribute-compiled-models` to upload your prebuilt +weights to hugging face, and submit a PR that adds an entry to this page, +contributing to the community. + +For models structured in an architecture we have not supported yet, you could: + +- Either `create a [Model Request] issue `__ which automatically shows up on our `Model Request Tracking Board `__. + +- Or follow our tutorial :doc:`Define New Models `, which introduces how to bring a new model architecture to MLC-LLM. + + +.. _deprecated-model-library-tables: + +Level 2: Model Library Tables (Precompiled Binary Files) +-------------------------------------------------------- + +As mentioned earlier, each model architecture corresponds to a different model library file. That is, you cannot use the same model library file to run ``RedPajama`` and ``Llama-2``. However, you can use the same ``Llama`` model library file to run ``Llama-2``, ``WizardLM``, ``CodeLlama``, etc, but just with different weight files (from tables in Level 3). + +Each table below demonstrates the pre-compiled model library files for each model architecture. This is categorized by: + +- **Size**: each size of model has its own distinct model library file (e.g. 7B or 13B number of parameters) + +- **Platform**: the backend that the model library is intended to be run on (e.g. CUDA, ROCm, iphone, etc.) + +- **Quantization scheme**: the model library file also differs due to the quantization scheme used. For more on this, please see the :doc:`model compilation page ` (e.g. ``q3f16_1`` vs. ``q4f16_1``) + +Each entry links to the specific model library file found in `this github repo `__. + +.. _deprecated-llama_library_table: + +Llama +^^^^^ +.. 
list-table:: Llama + :widths: 8 8 8 8 8 8 8 8 8 8 + :header-rows: 1 + :stub-columns: 1 + + * - + - CUDA + - ROCm + - Vulkan + + (Linux) + - Vulkan + + (Windows) + - Metal + + (M1/M2) + - Metal + + (Intel) + - iOS + - webgpu + - mali + * - 7B + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q3f16_1 `__ + - `q4f16_1 `__ + + `q4f32_1 `__ + - `q4f16_1 `__ + * - 13B + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - + - `q4f16_1 `__ + + `q4f32_1 `__ + - `q4f16_1 `__ + * - 34B + - `q4f16_1 `__ + - + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - + - + - + - + * - 70B + - + - + - + - + - `q3f16_1 `__ + + `q4f16_1 `__ + - + - + - `q4f16_1 `__ + - + +.. _deprecated-gpt_neox_library_table: + +GPT-NeoX (RedPajama-INCITE) +^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. list-table:: GPT-NeoX (RedPajama-INCITE) + :widths: 8 8 8 8 8 8 8 8 8 8 + :header-rows: 1 + :stub-columns: 1 + + * - + - CUDA + - ROCm + - Vulkan + + (Linux) + - Vulkan + + (Windows) + - Metal + + (M1/M2) + - Metal + + (Intel) + - iOS + - webgpu + - mali + * - 3B + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_0 `__ + + `q4f16_1 `__ + - `q4f16_0 `__ + + `q4f16_1 `__ + - `q4f16_0 `__ + + `q4f16_1 `__ + - `q4f16_0 `__ + + `q4f16_1 `__ + - `q4f16_0 `__ + + `q4f16_1 `__ + - `q4f16_0 `__ + + `q4f16_1 `__ + + `q4f32_0 `__ + + `q4f32_1 `__ + - `q4f16_1 `__ + +.. _deprecated-rwkv_library_table: + +RWKV +^^^^ +.. list-table:: RWKV + :widths: 8 8 8 8 8 8 8 8 8 8 + :header-rows: 1 + :stub-columns: 1 + + * - + - CUDA + - ROCm + - Vulkan + + (Linux) + - Vulkan + + (Windows) + - Metal + + (M1/M2) + - Metal + + (Intel) + - iOS + - webgpu + - mali + * - 1B5 + - + - + - `q8f16_0 `__ + - `q8f16_0 `__ + - `q8f16_0 `__ + - `q8f16_0 `__ + - + - + - + * - 3B + - + - + - `q8f16_0 `__ + - `q8f16_0 `__ + - `q8f16_0 `__ + - `q8f16_0 `__ + - + - + - + * - 7B + - + - + - `q8f16_0 `__ + - `q8f16_0 `__ + - `q8f16_0 `__ + - `q8f16_0 `__ + - + - + - + +.. _deprecated-gpt_big_code_library_table: + +GPTBigCode +^^^^^^^^^^ +Note that these all links to model libraries for WizardCoder (the older version released in Jun. 2023). +However, any GPTBigCode model variants should be able to reuse these (e.g. StarCoder, SantaCoder). + +.. list-table:: GPTBigCode + :widths: 8 8 8 8 8 8 8 8 8 8 + :header-rows: 1 + :stub-columns: 1 + + * - + - CUDA + - ROCm + - Vulkan + + (Linux) + - Vulkan + + (Windows) + - Metal + + (M1/M2) + - Metal + + (Intel) + - iOS + - webgpu + - mali + * - 15B + - `q4f16_1 `__ + + `q4f32_1 `__ + - + - `q4f16_1 `__ + + `q4f32_1 `__ + - `q4f16_1 `__ + + `q4f32_1 `__ + - `q4f16_1 `__ + - + - + - `q4f16_1 `__ + + `q4f32_1 `__ + - + +.. _deprecated-model-variant-tables: + +Level 3: Model Variant Tables (Precompiled Weights) +--------------------------------------------------- + +Finally, for each model variant, we provide the precompiled weights we uploaded to hugging face. + +Each precompiled weight is categorized by its model size (e.g. 7B vs. 13B) and the quantization scheme (e.g. ``q3f16_1`` vs. ``q4f16_1``). We note that the weights are **platform-agnostic**. + +Each model variant also loads its conversation configuration from a pre-defined :ref:`conversation template`. Note that multiple model variants can share a common conversation template. + +Some of these files are uploaded by our community contributors--thank you! + +.. 
_deprecated-llama2_variant_table: + +`Llama-2 `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``llama-2`` + +.. list-table:: Llama-2 + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 7B + - * `q3f16_1 `__ + * `q4f16_1 `__ + * `q4f32_1 `__ + + * - 13B + - * `q4f16_1 `__ + * `q4f32_1 `__ + + * - 70B + - * `q3f16_1 `__ + * `q4f16_1 `__ + +.. _deprecated-code_llama_variant_table: + +`Code Llama `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``codellama_completion`` + +.. list-table:: Code Llama + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 7B + - * `q4f16_1 (Base) `__ + * `q4f16_1 (Instruct) `__ + * `q4f16_1 (Python) `__ + + * - 13B + - * `q4f16_1 (Base) `__ + * `q4f16_1 (Instruct) `__ + * `q4f16_1 (Python) `__ + + * - 34B + - * `q4f16_1 (Base) `__ + * `q4f16_1 (Instruct) `__ + * `q4f16_1 (Python) `__ + + +.. _deprecated-vicuna_variant_table: + +`Vicuna `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``vicuna_v1.1`` + +.. list-table:: Vicuna + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 7B + - * `q3f16_0 `__ + * `q4f32_0 `__ + * `int3 (demo) `__ + * `int4 (demo) `__ + + +.. _deprecated-WizardLM_variant_table: + +`WizardLM `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``vicuna_v1.1`` + +.. list-table:: WizardLM + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 13B + - * `q4f16_1 (V1.2) `__ + * `q4f32_1 (V1.2) `__ + + * - 70B + - * `q3f16_1 (V1.0) `__ + * `q4f16_1 (V1.0) `__ + + +.. _deprecated-wizard_math_variant_table: + +`WizardMath `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``wizard_coder_or_math`` + +.. list-table:: WizardMath + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 7B + - * `q4f16_1 `__ + * `q4f32_1 `__ + * - 13B + - `q4f16_1 `__ + * - 70B + - `q4f16_1 `__ + + +.. _deprecated-open_orca_variant_table: + +`OpenOrca Platypus2 `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``llama-2`` + +.. list-table:: OpenOrca Platypus2 + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 13B + - `q4f16_1 `__ + + +.. _deprecated-flag_alpha_llama2_variant_table: + +`FlagAlpha Llama-2 Chinese `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``llama-2`` + +.. list-table:: FlagAlpha Llama-2 Chinese + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 7B + - * `q4f16_1 `__ + * `q4f32_1 `__ + + +.. _deprecated-llama2_uncensored_variant_table: + +`Llama2 uncensored (georgesung) `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``llama-default`` + +.. list-table:: Llama2 uncensored + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 7B + - * `q4f16_1 `__ + * `q4f32_1 `__ + +.. _deprecated-red_pajama_variant_table: + +`RedPajama `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``LM`` + +.. list-table:: Red Pajama + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 3B + - * `q4f16_0 (Instruct) `__ + * `q4f16_0 (Chat) `__ + * `q4f16_1 (Chat) `__ + * `q4f32_0 (Chat) `__ + + +.. 
_deprecated-rwkv_raven_variant_table: + +`RWKV-raven `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``rwkv`` + +.. list-table:: RWKV-raven + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 1B5 + - `q8f16_0 `__ + + * - 3B + - `q8f16_0 `__ + + * - 7B + - `q8f16_0 `__ + + +.. _deprecated-wizard_coder_variant_table: + +`WizardCoder `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``wizard_coder_or_math`` + +.. list-table:: WizardCoder + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 15B + - `q4f16_1 `__ + +------------------ + + +.. _deprecated-contribute-models-to-mlc-llm: + +Contribute Models to MLC-LLM +---------------------------- + +Ready to contribute your compiled models/new model architectures? Awesome! Please check :ref:`contribute-new-models` on how to contribute new models to MLC-LLM. diff --git a/docs/privacy.rst b/docs/privacy.rst new file mode 100644 index 0000000..cdd3c91 --- /dev/null +++ b/docs/privacy.rst @@ -0,0 +1,5 @@ +MLC Chat App Privacy +==================== + +MLC Chat run all generation locally. +All data stays in users' device and is not collected by the app. diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000..bc020bc --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,10 @@ +sphinx-tabs == 3.4.1 +sphinx-rtd-theme +sphinx == 5.2.3 +sphinx-toolbox == 3.4.0 +tlcpack-sphinx-addon==0.2.2 +sphinxcontrib_httpdomain==1.8.1 +sphinxcontrib-napoleon==0.7 +sphinx-reredirects==0.1.2 +--find-links https://mlc.ai/wheels +mlc-ai-nightly diff --git a/examples/python/benchmark.py b/examples/python/benchmark.py new file mode 100644 index 0000000..7cdbe78 --- /dev/null +++ b/examples/python/benchmark.py @@ -0,0 +1,11 @@ +from mlc_chat import ChatModule + +# From the mlc-llm directory, run +# $ python examples/python/benchmark.py + +# Create a ChatModule instance +cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1") + +output = cm.benchmark_generate("What's the meaning of life?", generate_length=256) +print(f"Generated text:\n{output}\n") +print(f"Statistics: {cm.stats()}") diff --git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py new file mode 100644 index 0000000..a290eb8 --- /dev/null +++ b/examples/python/run_llama_batched_vllm.py @@ -0,0 +1,448 @@ +import argparse +import math +import os +import json +from collections import defaultdict +from typing import List +from dataclasses import dataclass + +import numpy as np + +import tvm +from tvm import relax +from tvm.runtime import disco as di + +import torch +from transformers import AutoTokenizer + +from mlc_llm.relax_model.llama import LlamaConfig +from mlc_llm import utils + + +class KVCache: + def __init__(self, num_blocks, block_size, num_layers, num_heads, head_size, disco_session): + if disco_session: + init_cache_func = disco_session.get_global_func("tvm.contrib.vllm.allocate_kv_cache") + else: + init_cache_func = tvm.get_global_func("tvm.contrib.vllm.allocate_kv_cache") + + self.cache = init_cache_func(head_size, num_layers, num_heads, block_size, num_blocks) + + self.block_tables = defaultdict(list) + self.slot_mappings = defaultdict(list) + self.block_size = block_size + + +class CacheManager: + block_size: int = 16 + + def __init__( + self, num_blocks, num_layers, num_heads, head_size, disco_session=None, sliding_window=None + ): + self.num_blocks = num_blocks + self.free_blocks = list(range(num_blocks)) + 
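# The free list tracks which physical KV-cache blocks (identified by integer index) are + # still unassigned; set_size() below pops blocks from it as requests grow and returns + # them once a request's size is set back to zero. +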
self.kv_cache = KVCache( + num_blocks, self.block_size, num_layers, num_heads, head_size, disco_session + ) + + if sliding_window: + assert sliding_window % self.kv_cache.block_size == 0 + self.block_sliding_window = sliding_window // self.kv_cache.block_size + else: + self.block_sliding_window = None + + def set_size(self, request_ids: List[int], target_sizes: List[int]): + for id, size in zip(request_ids, target_sizes): + num_needed_block = math.ceil(size / self.block_size) + + if self.block_sliding_window: + num_needed_block = min(num_needed_block, self.block_sliding_window) + + if id in self.kv_cache.block_tables and size == 0: + self.free_blocks.extend(self.kv_cache.block_tables[id]) + del self.kv_cache.block_tables[id] + del self.kv_cache.slot_mappings[id] + + elif id in self.kv_cache.block_tables: + # Decoding + if len(self.kv_cache.block_tables[id]) < num_needed_block: + # Need to allocate a new block for this request + assert len(self.kv_cache.block_tables[id]) + 1 == num_needed_block + self.kv_cache.block_tables[id].append(self.free_blocks.pop()) + + pos = size - 1 + block_number = self.kv_cache.block_tables[id][-1] + + if self.block_sliding_window: + block_number = self.kv_cache.block_tables[id][ + (pos // self.block_size) % self.block_sliding_window + ] + else: + block_number = self.kv_cache.block_tables[id][-1] + + block_offset = pos % self.block_size + slot = block_number * self.block_size + block_offset + self.kv_cache.slot_mappings[id].append(slot) + + elif id not in self.kv_cache.block_tables: + assert len(self.free_blocks) >= num_needed_block, "Not enough free blocks." + + for _ in range(num_needed_block): + self.kv_cache.block_tables[id].append(self.free_blocks.pop()) + + for i in range(size): + block_idx = i // self.block_size + + if self.block_sliding_window: + block_idx %= self.block_sliding_window + + block_number = self.kv_cache.block_tables[id][block_idx] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + self.kv_cache.slot_mappings[id].append(slot) + + def get(self): + return self.kv_cache + + +@dataclass +class SequenceGenerationRequest: + request_id: int + token_ids: List[int] + + +@dataclass +class SequenceGenerationResponse: + request_id: int + token_id: int + + +def sample(logits): + logits = torch.from_dlpack(logits) + return torch.argmax(logits, -1).cpu().numpy() + + +def load_params_disco(artifact_path, lib_path, num_shards): + sess = di.ProcessSession(num_workers=num_shards) + devices = range(num_shards) + sess.init_ccl("nccl", *devices) + module = sess.load_vm_module(lib_path) + + loader_create = sess.get_global_func("runtime.disco.ShardLoader") + metadata_path = os.path.join(artifact_path, "params", "ndarray-cache.json") + with open(metadata_path, "r", encoding="utf-8") as f: + ndarray_cache_metadata = f.read() + + loader = loader_create(metadata_path, ndarray_cache_metadata, "", module) + loader_load = sess.get_global_func("runtime.disco.ShardLoaderLoadAll") + params = loader_load(loader) + + return module, params, sess + + +def copy_to_worker_0(sess: di.Session, host_array): + x_array = sess.empty(host_array.shape, host_array.dtype) + sess.copy_to_worker_0(host_array, x_array) + return x_array + + +def get_tvm_model(artifact_path, model, quantization, num_shards, dev): + lib_path = os.path.join(artifact_path, f"{model}-{quantization}-cuda.so") + + if num_shards == 1: + ex = tvm.runtime.load_module(lib_path) + vm = relax.VirtualMachine(ex, dev) + params = utils.load_params(artifact_path, dev) + return 
vm.module, params, None + + return load_params_disco(artifact_path, lib_path, num_shards) + + +def _prepare_inputs( + requests, + all_slot_mappings, + all_block_tables, + sliding_window, + dev, + is_prefill, +): + block_tables = [] + seq_lens = [] + input_ids = [] + slot_mapping = [] + positions = [] + max_num_blocks_per_seq = 0 + indices_within_window = [] + start_idx = 0 + + for request in requests: + request_id = request.request_id + token_ids = request.token_ids + + if is_prefill: + input_ids += token_ids + prompt_len = len(token_ids) + seq_lens.append(prompt_len) + positions += range(prompt_len) + slot_mapping += all_slot_mappings[request_id] + + if sliding_window: + indices_within_window += range( + start_idx + max(0, prompt_len - sliding_window), + start_idx + prompt_len, + ) + start_idx += prompt_len + + else: + input_ids.append(token_ids[-1]) + pos = len(token_ids) - 1 + positions.append(pos) + block_table = all_block_tables[request_id] + max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table)) + block_tables.append(block_table) + slot_mapping.append(all_slot_mappings[request_id][-1]) + + if sliding_window: + seq_lens.append(min(len(token_ids), sliding_window)) + else: + seq_lens.append(len(token_ids)) + + input_ids = tvm.nd.array(np.array(input_ids, dtype="int32"), dev) + positions = tvm.nd.array(np.array(positions, dtype="int32"), dev) + seq_lens = tvm.nd.array(np.array(seq_lens, dtype="int32"), dev) + slot_mapping = tvm.nd.array(np.array(slot_mapping, dtype="int32"), dev) + + if is_prefill and sliding_window: + indices_within_window = tvm.nd.array(np.array(indices_within_window, dtype="int32"), dev) + else: + indices_within_window = None + + if not is_prefill: + + def _pad_to_max(x: List[int], max_len: int) -> List[int]: + return x + [0] * (max_len - len(x)) + + padded_block_tables = [ + _pad_to_max(block_table, max_num_blocks_per_seq) for block_table in block_tables + ] + + block_tables_np = np.vstack(padded_block_tables).astype("int32") + block_tables = tvm.nd.array(np.array(block_tables_np, dtype="int32"), dev) + else: + block_tables = None + + return ( + input_ids, + positions, + seq_lens, + slot_mapping, + indices_within_window, + block_tables, + ) + + +class Model: + def __init__( + self, artifact_path, model_name, quant, vocab_size, num_shards, dev, sliding_window + ): + self.mod, self.params, self.disco_session = get_tvm_model( + artifact_path, model_name, quant, num_shards, dev + ) + self.dev = dev + self.vocab_size = vocab_size + self.sliding_window = sliding_window + + if sliding_window: + self.block_sliding_window = sliding_window // CacheManager.block_size + else: + self.block_sliding_window = None + + def generate( + self, requests: List[SequenceGenerationRequest], cache: KVCache, is_prefill: bool + ) -> List[SequenceGenerationResponse]: + ( + input_ids, + positions, + seq_lens, + slot_mapping, + indices_within_window, + block_tables, + ) = _prepare_inputs( + requests, + cache.slot_mappings, + cache.block_tables, + self.sliding_window, + self.dev, + is_prefill, + ) + + if self.disco_session: + input_ids = copy_to_worker_0(self.disco_session, input_ids) + positions = copy_to_worker_0(self.disco_session, positions) + seq_lens = copy_to_worker_0(self.disco_session, seq_lens) + slot_mapping = copy_to_worker_0(self.disco_session, slot_mapping) + + kv_cache = cache.cache + + if is_prefill: + if self.sliding_window: + if self.disco_session: + indices_within_window = copy_to_worker_0( + self.disco_session, indices_within_window + ) + + out = 
self.mod["prefill"]( + input_ids, + positions, + seq_lens, + kv_cache, + slot_mapping, + indices_within_window, + self.params, + ) + else: + out = self.mod["prefill"]( + input_ids, positions, seq_lens, kv_cache, slot_mapping, self.params + ) + + if self.disco_session: + logits, _ = out.debug_get_from_remote(0) + else: + logits = out[0] # Ignore returned KV cache since it is updated in-place anyway. + else: + if self.disco_session: + block_tables = copy_to_worker_0(self.disco_session, block_tables) + + out = self.mod["decode"]( + input_ids, + positions, + seq_lens, + kv_cache, + slot_mapping, + block_tables, + self.params, + ) + + if self.disco_session: + logits, _ = out.debug_get_from_remote(0) + else: + logits = out[0] + + next_tokens = sample(logits) + + return [ + SequenceGenerationResponse(request.request_id, new_token) + for request, new_token in zip(requests, next_tokens) + ] + + +def parse_args(): + # Example + # python build.py --model vicuna-v1-7b --quantization q4f16_ft --use-cache=0 --max-seq-len 768 --enable-batching --use-vllm-attention + # python examples/python/run_llama_batched_vllm.py --local-id vicuna-v1-7b-q4f16_ft + # + # For Disco: + # python build.py --model vicuna-v1-7b --quantization q0f16 --use-cache=0 --max-seq-len 768 --enable-batching --use-vllm-attention --build-model-only --num-shards 2 + # python build.py --model vicuna-v1-7b --quantization q0f16 --use-cache=0 --max-seq-len 768 --enable-batching --use-vllm-attention --convert-weight-only + # CUDA_VISIBLE_DEVICES=0,1 python examples/python/run_llama_batched_vllm.py --local-id vicuna-v1-7b-q0f16 --num-shards 2 + + args = argparse.ArgumentParser() + args.add_argument("--local-id", type=str, required=True) + args.add_argument("--artifact-path", type=str, default="dist") + args.add_argument("--num-shards", type=int, default=1) + args.add_argument("--num-decode-steps", type=int, default=20) + parsed = args.parse_args() + parsed.model, parsed.quantization = parsed.local_id.rsplit("-", 1) + utils.argparse_postproc_common(parsed) + parsed.artifact_path = os.path.join( + parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}" + ) + return parsed + + +def run(args): + quantization = args.quantization.name + artifact_path = args.artifact_path + model_name = args.model + model_path = f"dist/models/{model_name}" + + dev = tvm.device("cuda", 0) + + with open(os.path.join(model_path, "config.json"), encoding="utf-8") as i_f: + config = LlamaConfig(**json.load(i_f)) + + model = Model( + artifact_path, + model_name, + quantization, + config.vocab_size, + args.num_shards, + dev, + config.sliding_window, + ) + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False) + + num_kv_heads = config.get_num_key_value_heads() // args.num_shards + head_size = config.hidden_size // config.num_attention_heads + num_blocks = 500 + + cache_manager = CacheManager( + num_blocks, + config.num_hidden_layers, + num_kv_heads, + head_size, + model.disco_session, + sliding_window=config.sliding_window, + ) + cache = cache_manager.get() + + model.block_sliding_window = cache_manager.block_sliding_window + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + batched_token_ids = [tokenizer.encode(p) for p in prompts] + prompts_len = [len(ids) for ids in batched_token_ids] + request_ids = list(range(len(prompts))) + target_sizes = [] + requests = [] + + for token_ids, request_id in zip(batched_token_ids, request_ids): + 
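# For each prompt, record its current token count so the cache manager can reserve + # enough KV-cache blocks, and wrap the prompt in a SequenceGenerationRequest for the + # prefill step. +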
request_ids.append(request_id) + target_sizes.append(len(token_ids)) + requests.append(SequenceGenerationRequest(request_id, token_ids)) + + cache_manager.set_size(request_ids, target_sizes) + + out = model.generate(requests, cache, True) + + for _ in range(args.num_decode_steps): + for i, response in enumerate(out): + new_token_id = response.token_id + requests[i].token_ids.append(new_token_id) + target_sizes[i] += 1 + + cache_manager.set_size(request_ids, target_sizes) + + out = model.generate(requests, cache, False) + + output_tokens = [ + tokenizer.convert_ids_to_tokens( + requests[i].token_ids[prompts_len[i] :], skip_special_tokens=True + ) + for i in range(len(requests)) + ] + + generated = [tokenizer.convert_tokens_to_string(tokens) for tokens in output_tokens] + + for p, g in zip(prompts, generated): + print("Prompt = '{}', generated text = '{}'".format(p, g)) + + +if __name__ == "__main__": + run(parse_args()) diff --git a/examples/python/sample_chat_stream.py b/examples/python/sample_chat_stream.py new file mode 100644 index 0000000..980e833 --- /dev/null +++ b/examples/python/sample_chat_stream.py @@ -0,0 +1,30 @@ +from mlc_chat import ChatModule +from mlc_chat.callback import StreamToStdout, StreamIterator + +# From the mlc-llm directory, run +# $ python examples/python/sample_chat_stream.py + +# Create a ChatModule instance +cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1") + +# Stream to Stdout +output = cm.generate( + prompt="What is the meaning of life?", + progress_callback=StreamToStdout(callback_interval=2), +) + +# Stream to an Iterator +from threading import Thread + +stream = StreamIterator(callback_interval=2) +generation_thread = Thread( + target=cm.generate, + kwargs={"prompt": "What is the meaning of life?", "progress_callback": stream}, +) +generation_thread.start() + +output = "" +for delta_message in stream: + output += delta_message + +generation_thread.join() diff --git a/examples/python/sample_mlc_chat.py b/examples/python/sample_mlc_chat.py new file mode 100644 index 0000000..6d20d0c --- /dev/null +++ b/examples/python/sample_mlc_chat.py @@ -0,0 +1,39 @@ +from mlc_chat import ChatModule +from mlc_chat.callback import StreamToStdout + +# From the mlc-llm directory, run +# $ python examples/python/sample_mlc_chat.py + +# Create a ChatModule instance +cm = ChatModule( + model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", + model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" + # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so + # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so + # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} +) + +# You can change to other models that you downloaded +# Model variants of the same architecture can reuse the same model library +# Here WizardMath reuses Mistral's model library +# cm = ChatModule( +# model="dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC", # or "dist/WizardMath-7B-V1.1-q4f16_1-MLC" +# model_lib_path="dist/prebuilt_libs/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-cuda.so" +# ) + +# Generate a response for a given prompt +output = cm.generate( + prompt="What is the meaning of life?", + progress_callback=StreamToStdout(callback_interval=2), +) + +# Print prefill and decode performance statistics +print(f"Statistics: {cm.stats()}\n") + +output = cm.generate( + prompt="How many points did you list out?", + progress_callback=StreamToStdout(callback_interval=2), +) + +# Reset the chat module by +# cm.reset_chat() diff --git 
a/examples/rest/nodejs/README.MD b/examples/rest/nodejs/README.MD new file mode 100755 index 0000000..1d63d54 --- /dev/null +++ b/examples/rest/nodejs/README.MD @@ -0,0 +1,21 @@ +# Node/Javascript/Typescript Access Examples for MLC_CHAT REST APIs + +Please make sure you are running v18.17.x of node (and npm v9.6.7) -- v20.x currently has some compatibility problems with typescript used in the langchain example. + +First install dependencies. + +`npm i` + +Copy `dotenv.exmaple` to `.env`. + +To run JS chat completion (both streaming and non-streaming) example: + +`node sample_client.js` + +To run OpenAI (chat completion streaming and non-streaming, and legacy completion) example: + +`node sample_openai.js` + +To run LangchainJS Typescript example: + +`npm run example` diff --git a/examples/rest/nodejs/dotenv.example b/examples/rest/nodejs/dotenv.example new file mode 100755 index 0000000..5312f49 --- /dev/null +++ b/examples/rest/nodejs/dotenv.example @@ -0,0 +1,2 @@ +OPENAI_API_KEY="none" +OPENAI_API_BASE="http://127.0.0.1:8000/v1" \ No newline at end of file diff --git a/examples/rest/nodejs/package.json b/examples/rest/nodejs/package.json new file mode 100755 index 0000000..2a3ebf2 --- /dev/null +++ b/examples/rest/nodejs/package.json @@ -0,0 +1,40 @@ +{ + "name": "mlc-llm-js-examples", + "version": "1.0.0", + "description": "", + "main": "index.js", + "type": "module", + "license": "AGPL-version-3.0", + "private": false, + "engines": { + "node": ">= 14.0.0", + "npm": ">= 6.0.0" + }, + "homepage": "", + "repository": { + "type": "git", + "url": "" + }, + "bugs": "", + "keywords": [], + "author": { + "name": "", + "email": "", + "url": "" + }, + "contributors": [], + "scripts": { + "example": "ts-node --esm ./sample_langchain.ts" + }, + "dependencies": { + "@types/node": "^20.4.4", + "dotenv": "^16.3.1", + "langchain": "^0.0.117", + "needle": "^3.2.0", + "openai": "^3.3.0", + "typescript": "^5.1.6" + }, + "devDependencies": { + "ts-node": "^10.9.1" + } +} diff --git a/examples/rest/nodejs/sample_client.js b/examples/rest/nodejs/sample_client.js new file mode 100755 index 0000000..9a85072 --- /dev/null +++ b/examples/rest/nodejs/sample_client.js @@ -0,0 +1,74 @@ +import request from 'needle'; + +( async () => { +const color = { + PURPLE : '\x1b[95m', + CYAN : '\x1b[96m', + DARKCYAN : '\x1b[36m', + BLUE : '\x1b[94m', + GREEN : '\x1b[92m', + YELLOW : '\x1b[93m', + RED : '\x1b[91m', + BOLD : '\x1b[1m', + UNDERLINE : '\x1b[4m', + END : '\x1b[0m' +}; + +let payload = { + model : 'vicuna-v1-7b', + messages: [{"role": "user", "content": "Write a haiku"}], + stream: false +}; + +const print = ( str ) => { + process.stdout.write(str); +}; + +const newline = () => { + print('\n'); +} + +newline(); +print(color.BOLD + "Without streaming:" + color.END); +newline(); + +let r = await request("post", "http://127.0.0.1:8000/v1/chat/completions", payload, {json: true}); + +print(color.GREEN + r.body.choices[0].message.content + color.END); +print('\n'); +// Reset the chat +r = await request("post", "http://127.0.0.1:8000/v1/chat/completions", payload, {json: true}); +print(color.BOLD + "Reset chat" + color.END); +newline(); + +// Get a response using a prompt with streaming + +payload = { + "model": "vicuna-v1-7b", + "messages": [{"role": "user", "content": "Write a haiku"}], + "stream": true +} + +print( color.BOLD + "With streaming:" + color.END); +newline(); +r = request.post( "http://127.0.0.1:8000/v1/chat/completions", payload, {json: true}) +.on('readable', function() { + let jsData = ''; + let data 
= ''; + while (data = this.read()) { + const chunk = data.toString().substring(6); + if (chunk.trim() === "[DONE]") break; + jsData = JSON.parse(chunk); + print(color.GREEN + jsData.choices[0].delta.content + color.END); + } +}) +.on('done', async function () { + newline(); + let txtresp = await request("get", "http://127.0.0.1:8000/stats"); + print(color.BOLD + "Runtime stats:" + color.END + txtresp.body); + +}) + +})() + + diff --git a/examples/rest/nodejs/sample_langchain.ts b/examples/rest/nodejs/sample_langchain.ts new file mode 100644 index 0000000..48e849d --- /dev/null +++ b/examples/rest/nodejs/sample_langchain.ts @@ -0,0 +1,75 @@ +import { OpenAI } from "langchain/llms/openai"; +import { BufferWindowMemory } from "langchain/memory"; +import { LLMChain } from "langchain/chains"; +import { PromptTemplate } from "langchain/prompts"; +import {TextLoader } from "langchain/document_loaders/fs/text"; +import { loadQAStuffChain } from "langchain/chains"; + +const color = { + PURPLE : '\x1b[95m', + CYAN : '\x1b[96m', + DARKCYAN : '\x1b[36m', + BLUE : '\x1b[94m', + GREEN : '\x1b[92m', + YELLOW : '\x1b[93m', + RED : '\x1b[91m', + BOLD : '\x1b[1m', + UNDERLINE : '\x1b[4m', + END : '\x1b[0m' +}; + +function print(str: string) { + process.stdout.write(str); +} + +const newline = () => { + print('\n'); +} + + const chat = new OpenAI( { + openAIApiKey: "empty", + temperature: 0 + }, { + basePath: 'http://127.0.0.1:8000/v1' + }); + +// Conversational LLMChain example + const memory = new BufferWindowMemory({ memoryKey: "history", k: 1 }); + + const template = `The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know. + + Current conversation: + {history} + Human: {human_input} + AI:`; + + + const prompt = PromptTemplate.fromTemplate(template); + let chain = new LLMChain({ llm: chat, prompt, memory }); + + let input = "Write a poem about Pittsburgh."; + print(color.BOLD + input + "..." + color.END); + newline(); + let res = await chain.call({ human_input: input }); + newline(); + print(color.GREEN + res.text + color.END); + newline(); + input = "What does it mean?"; + print(color.BOLD + input + "..." 
+ color.END); + newline(); + res = await chain.call({ human_input: input }); + newline(); + print(color.GREEN + res.text + color.END); + newline(); + +// Question and answer stuff chain example with text loader +const loader = new TextLoader('../resources/linux.txt'); +const documents = await loader.load(); +const schain = loadQAStuffChain(chat); +const query = "When was Linux released?"; +newline(); newline(); +print(color.BOLD + "Query: " + color.END + color.BLUE + query + color.END); +newline(); +const result = await schain.call({ input_documents: documents, question: query}); +print(color.BOLD + "Response: " + color.END + color.GREEN + result.text + color.END); + diff --git a/examples/rest/nodejs/sample_openai.js b/examples/rest/nodejs/sample_openai.js new file mode 100755 index 0000000..6e06114 --- /dev/null +++ b/examples/rest/nodejs/sample_openai.js @@ -0,0 +1,80 @@ +import { Configuration, OpenAIApi } from "openai"; +import dotenv from "dotenv"; +dotenv.config(); + +( async () => { + +const configuration = new Configuration({ + apiKey: process.env.OPENAI_API_KEY, + basePath : process.env.OPENAI_API_BASE +}) +const openai = new OpenAIApi(configuration); +let model = "vicuna-v1-7b" + +const color = { + PURPLE : '\x1b[95m', + CYAN : '\x1b[96m', + DARKCYAN : '\x1b[36m', + BLUE : '\x1b[94m', + GREEN : '\x1b[92m', + YELLOW : '\x1b[93m', + RED : '\x1b[91m', + BOLD : '\x1b[1m', + UNDERLINE : '\x1b[4m', + END : '\x1b[0m' +}; + +const print = ( str ) => { + process.stdout.write(str); +}; + +const newline = () => { + print('\n'); +} + +// Chat completion example without streaming +newline(); +print(color.BOLD + "OpenAI chat completion example without streaming:" + color.END); +newline(); + +let completion = await openai.createChatCompletion({ + model: model, + messages: [{"role": "user", "content": "Write a poem about OpenAI"}] +}); + + +print(color.GREEN + completion.data.choices[0].message.content + color.END) +newline(); newline(); + + +// Chat completion example with streaming +// (raw implementation since npm module does not support it yet - it will have support in upcoming 4.x) + +print(color.BOLD + "OpenAI chat completion example with streaming:" + color.END); +newline(); +completion = await openai.createChatCompletion({ + model: model, + messages: [{"role": "user", "content": "Write a poem about OpenAI"}], + stream: true, +}, {responseType: 'stream'}); + +completion.data.on('data', async (data) => { + const chunk = data.toString().substring(6); + // Skip the terminating "[DONE]" event and guard against chunks without delta content, + // mirroring the handling in sample_client.js + if (chunk.trim() === "[DONE]") return; + const parsed = JSON.parse(chunk); + print(color.GREEN + (parsed.choices[0].delta.content || "") + color.END); +}); + +completion.data.on('close', async () => { + newline(); newline(); + + // Completion example + print(color.BOLD + "OpenAI completion example:" + color.END) + newline(); + let res = await openai.createCompletion({ prompt: "Write a poem about OpenAI", model: model}); + print(color.GREEN + res.data.choices[0].text + color.END); + newline(); newline(); + + }); +})() \ No newline at end of file diff --git a/examples/rest/nodejs/tsconfig.json b/examples/rest/nodejs/tsconfig.json new file mode 100755 index 0000000..bc563cb --- /dev/null +++ b/examples/rest/nodejs/tsconfig.json @@ -0,0 +1,14 @@ +{ + "compilerOptions": { + "target": "es2020", /* Set the JavaScript language version for emitted JavaScript and include compatible library declarations. */ + "lib": ["es2020"], /* Specify a set of bundled library declaration files that describe the target runtime environment. */ + "module": "nodenext", /* Specify what module code is generated.
*/ + "rootDir": "src", /* Specify the root folder within your source files. */ + "outDir": "./dist", /* Specify an output folder for all emitted files. */ + "esModuleInterop": true, /* Emit additional JavaScript to ease support for importing CommonJS modules. This enables 'allowSyntheticDefaultImports' for type compatibility. */ + "forceConsistentCasingInFileNames": true, /* Ensure that casing is correct in imports. */ + "strict": true, /* Enable all strict type-checking options. */ + "noImplicitAny": true, /* Enable error reporting for expressions and declarations with an implied 'any' type. */ + "skipLibCheck": true /* Skip type checking all .d.ts files. */ + } +} diff --git a/examples/rest/python/sample_client.py b/examples/rest/python/sample_client.py new file mode 100644 index 0000000..1af1d83 --- /dev/null +++ b/examples/rest/python/sample_client.py @@ -0,0 +1,46 @@ +import requests +import json + +class color: + PURPLE = '\033[95m' + CYAN = '\033[96m' + DARKCYAN = '\033[36m' + BLUE = '\033[94m' + GREEN = '\033[92m' + YELLOW = '\033[93m' + RED = '\033[91m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + END = '\033[0m' + +# Get a response using a prompt without streaming +payload = { + "model": "vicuna-v1-7b", + "messages": [{"role": "user", "content": "Write a haiku"}], + "stream": False +} +r = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload) +print(f"{color.BOLD}Without streaming:{color.END}\n{color.GREEN}{r.json()['choices'][0]['message']['content']}{color.END}\n") + +# Reset the chat +r = requests.post("http://127.0.0.1:8000/chat/reset", json=payload) +print(f"{color.BOLD}Reset chat:{color.END} {str(r)}\n") + +# Get a response using a prompt with streaming +payload = { + "model": "vicuna-v1-7b", + "messages": [{"role": "user", "content": "Write a haiku"}], + "stream": True +} +with requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload, stream=True) as r: + print(f"{color.BOLD}With streaming:{color.END}") + for chunk in r: + if (chunk[6:].decode('utf-8').strip() == '[DONE]'): + break + content = json.loads(chunk[6:])["choices"][0]["delta"].get("content", "") + print(f"{color.GREEN}{content}{color.END}", end="", flush=True) + print("\n") + +# Get the latest runtime stats +r = requests.get("http://127.0.0.1:8000/stats") +print(f"{color.BOLD}Runtime stats:{color.END} {r.json()}\n") diff --git a/examples/rest/python/sample_langchain.py b/examples/rest/python/sample_langchain.py new file mode 100644 index 0000000..cda326f --- /dev/null +++ b/examples/rest/python/sample_langchain.py @@ -0,0 +1,156 @@ +from langchain.chat_models import ChatOpenAI +from langchain import LLMChain, PromptTemplate +from langchain.memory import ConversationBufferWindowMemory +from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler +from langchain.document_loaders import TextLoader, UnstructuredRSTLoader, DirectoryLoader +from langchain.chains.question_answering import load_qa_chain +from langchain.llms import OpenAI +from langchain.text_splitter import CharacterTextSplitter +from langchain.chains import RetrievalQA +from langchain.vectorstores import Chroma + +# Note that Langchain support for embedding documents using MLC is currently blocked on +# https://github.com/langchain-ai/langchain/pull/7815 +# We have subclassed `OpenAIEmbeddings` in the meantime to get around this dependency. 
+from mlc_chat.embeddings.openai import MLCEmbeddings + + + +# First set the following in your environment: +# export OPENAI_API_BASE=http://127.0.0.1:8000/v1 +# export OPENAI_API_KEY=EMPTY + +# Note that Langchain does not currently support Pydantic v2: +# https://github.com/langchain-ai/langchain/issues/6841 +# Please ensure that your `pydantic` version is < 2.0 + +class color: + PURPLE = '\033[95m' + CYAN = '\033[96m' + DARKCYAN = '\033[36m' + BLUE = '\033[94m' + GREEN = '\033[92m' + YELLOW = '\033[93m' + RED = '\033[91m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + END = '\033[0m' + +def llm_chain_example(): + template = """ + {history} + USER: {human_input} + ASSISTANT:""" + + prompt = PromptTemplate( + input_variables=["history", "human_input"], + template=template + ) + + llm_chain = LLMChain( + llm=ChatOpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()]), + prompt=prompt, + verbose=True, + memory=ConversationBufferWindowMemory(human_prefix="USER", ai_prefix="ASSISTANT") + ) + + output = llm_chain.predict(human_input="Write a short poem about Pittsburgh.") + output = llm_chain.predict(human_input="What does the poem mean?") + +def load_qa_chain_example(): + loader = TextLoader('../resources/linux.txt') + documents = loader.load() + chain = load_qa_chain(llm=OpenAI(), chain_type="stuff", verbose=False) + query = "When was Linux released?" + print(f"{color.BOLD}Query:{color.END} {color.BLUE} {query}{color.END}") + print(f"{color.BOLD}Response:{color.END} {color.GREEN}{chain.run(input_documents=documents, question=query)}{color.END}") + +def retrieval_qa_sotu_example(): + prompt_template = """Use only the following pieces of context to answer the question at the end. Don't use any other knowledge. + + {context} + + USER: {question} + ASSISTANT:""" + + PROMPT = PromptTemplate( + template=prompt_template, input_variables=["context", "question"] + ) + + loader = TextLoader('../resources/state_of_the_union.txt') + documents = loader.load() + + text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=100) + texts = text_splitter.split_documents(documents) + # print(texts) + embeddings = MLCEmbeddings(deployment="text-embedding-ada-002", embedding_ctx_length=None) + db = Chroma.from_documents(documents=texts, embedding=embeddings) + retriever = db.as_retriever(search_type="similarity", search_kwargs={"k":2}) + qa = RetrievalQA.from_chain_type( + llm=OpenAI(), + chain_type="stuff", + retriever=retriever, + return_source_documents=True, + chain_type_kwargs={"prompt": PROMPT} + ) + questions = [ + "What is the American Rescue Plan?", + "What did the president say about Ketanji Brown Jackson?", + "Who is mentioned in the speech?", + "To whom is the speech addressed?", + "Tell me more about the Made in America campaign." + ] + + for qn in questions: + print(f"{color.BOLD}QUESTION:{color.END} {qn}") + res = qa({'query': qn}) + print(f"{color.BOLD}RESPONSE:{color.END} {color.GREEN}{res['result']}{color.END}") + print(f"{color.BOLD}SOURCE:{color.END} {color.BLUE}{repr(res['source_documents'][0].page_content)}{color.END}") + print() + +def retrieval_qa_mlc_docs_example(): + prompt_template = """Use only the following pieces of context to answer the question at the end. Don't use any other knowledge. 
+ + {context} + + USER: {question} + ASSISTANT:""" + + PROMPT = PromptTemplate( + template=prompt_template, input_variables=["context", "question"] + ) + + loader = DirectoryLoader("../../../docs", glob='*/*.rst', show_progress=True, loader_cls=UnstructuredRSTLoader, loader_kwargs={"mode": "single"}) + documents = loader.load() + text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100) + texts = text_splitter.split_documents(documents) + embeddings = MLCEmbeddings(deployment="text-embedding-ada-002", embedding_ctx_length=None) + db = Chroma.from_documents(collection_name="abc", documents=texts, embedding=embeddings) + retriever = db.as_retriever(search_type="similarity", search_kwargs={"k":3}) + qa = RetrievalQA.from_chain_type( + llm=OpenAI(), + chain_type="stuff", + retriever=retriever, + return_source_documents=True, + chain_type_kwargs={"prompt": PROMPT} + ) + while True: + qn = input(f"{color.BOLD}QUESTION:{color.END} ") + res = qa({'query': qn}) + print(f"{color.BOLD}RESPONSE:{color.END} {color.GREEN}{res['result']}{color.END}") + print(f"{color.BOLD}SOURCE:{color.END} {color.BLUE}{repr(res['source_documents'][0].page_content)}{color.END}") + print() + + # Some example questions: + # - What is the chat config? + # - What is temperature? + # - What are the REST API endpoints? + # - What are the available quantization options? + + +# Uncomment one of the following lines to try out the corresponding demo: + +# llm_chain_example() +# load_qa_chain_example() +# retrieval_qa_sotu_example() +# retrieval_qa_mlc_docs_example() diff --git a/examples/rest/python/sample_openai.py b/examples/rest/python/sample_openai.py new file mode 100644 index 0000000..1c4acb0 --- /dev/null +++ b/examples/rest/python/sample_openai.py @@ -0,0 +1,43 @@ +import openai + +openai.api_key = "None" +openai.api_base = "http://127.0.0.1:8000/v1" + +model = "vicuna-v1-7b" + +class color: + PURPLE = '\033[95m' + CYAN = '\033[96m' + DARKCYAN = '\033[36m' + BLUE = '\033[94m' + GREEN = '\033[92m' + YELLOW = '\033[93m' + RED = '\033[91m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + END = '\033[0m' + +# Chat completion example without streaming +print(f"{color.BOLD}OpenAI chat completion example without streaming:{color.END}\n") +completion = openai.ChatCompletion.create( + model=model, + messages=[{"role": "user", "content": "Write a poem about OpenAI"}] +) +print(f"{color.GREEN}{completion.choices[0].message.content}{color.END}\n\n") + +# Chat completion example with streaming +print(f"{color.BOLD}OpenAI chat completion example with streaming:{color.END}\n") +res = openai.ChatCompletion.create( + model=model, + messages=[{"role": "user", "content": "Write a poem about OpenAI"}], + stream=True +) +for chunk in res: + content = chunk["choices"][0]["delta"].get("content", "") + print(f"{color.GREEN}{content}{color.END}", end="", flush=True) +print("\n") + +# Completion example +print(f"{color.BOLD}OpenAI completion example:{color.END}\n") +res = openai.Completion.create(prompt="Write a poem about OpenAI", model=model) +print(f"{color.GREEN}{res.choices[0].text}{color.END}\n\n") diff --git a/examples/rest/resources/linux.txt b/examples/rest/resources/linux.txt new file mode 100644 index 0000000..9f09b49 --- /dev/null +++ b/examples/rest/resources/linux.txt @@ -0,0 +1,23 @@ +Linux is a family of open-source Unix-like operating systems based on the Linux kernel, an operating system kernel first released on September 17, 1991, by Linus Torvalds. 
Linux is typically packaged as a Linux distribution, which includes the kernel and supporting system software and libraries, many of which are provided by the GNU Project. Many Linux distributions use the word "Linux" in their name, but the Free Software Foundation uses the name "GNU/Linux" to emphasize the importance of GNU software, causing some controversy. + +Popular Linux distributions include Debian, Fedora Linux, and Ubuntu, the latter of which itself consists of many different distributions and modifications, including Lubuntu and Xubuntu. Commercial distributions include Red Hat Enterprise Linux and SUSE Linux Enterprise. Desktop Linux distributions include a windowing system such as X11 or Wayland, and a desktop environment such as GNOME or KDE Plasma. Distributions intended for servers may omit graphics altogether, or include a solution stack such as LAMP. Because Linux is freely redistributable, anyone may create a distribution for any purpose. + +Linux was originally developed for personal computers based on the Intel x86 architecture, but has since been ported to more platforms than any other operating system. Because of the dominance of the Linux-based Android on smartphones, Linux, including Android, has the largest installed base of all general-purpose operating systems, as of May 2022. Although Linux is, as of November 2022, used by only around 2.6 percent of desktop computers, the Chromebook, which runs the Linux kernel-based ChromeOS, dominates the US K–12 education market and represents nearly 20 percent of sub-$300 notebook sales in the US. Linux is the leading operating system on servers (over 96.4% of the top 1 million web servers' operating systems are Linux), leads other big iron systems such as mainframe computers, and is used on all of the world's 500 fastest supercomputers (since November 2017, having gradually displaced all competitors). + +Linux also runs on embedded systems, i.e. devices whose operating system is typically built into the firmware and is highly tailored to the system. This includes routers, automation controls, smart home devices, video game consoles, televisions (Samsung and LG Smart TVs), automobiles (Tesla, Audi, Mercedes-Benz, Hyundai and Toyota), and spacecraft (Falcon 9 rocket, Dragon crew capsule and the Perseverance rover). + +Linux is one of the most prominent examples of free and open-source software collaboration. The source code may be used, modified and distributed commercially or non-commercially by anyone under the terms of its respective licenses, such as the GNU General Public License (GPL). The Linux kernel, for example, is licensed under the GPLv2, with an exception for system calls that allows code that calls the kernel via system calls not to be licensed under the GPL. + +The Unix operating system was conceived and implemented in 1969, at AT&T's Bell Labs, in the United States by Ken Thompson, Dennis Ritchie, Douglas McIlroy, and Joe Ossanna. First released in 1971, Unix was written entirely in assembly language, as was common practice at the time. In 1973, in a key pioneering approach, it was rewritten in the C programming language by Dennis Ritchie (with the exception of some hardware and I/O routines). The availability of a high-level language implementation of Unix made its porting to different computer platforms easier. + +Due to an earlier antitrust case forbidding it from entering the computer business, AT&T licensed the operating system's source code as a trade secret to anyone who asked. 
As a result, Unix grew quickly and became widely adopted by academic institutions and businesses. In 1984, AT&T divested itself of its regional operating companies, and was released from its obligation not to enter the computer business; freed of that obligation, Bell Labs began selling Unix as a proprietary product, where users were not legally allowed to modify it. + +Onyx Systems began selling early microcomputer-based Unix workstations in 1980. Later, Sun Microsystems, founded as a spin-off of a student project at Stanford University, also began selling Unix-based desktop workstations in 1982. While Sun workstations did not utilize commodity PC hardware, for which Linux was later originally developed, it represented the first successful commercial attempt at distributing a primarily single-user microcomputer that ran a Unix operating system. + +With Unix increasingly "locked in" as a proprietary product, the GNU Project, started in 1983 by Richard Stallman, had the goal of creating a "complete Unix-compatible software system" composed entirely of free software. Work began in 1984. Later, in 1985, Stallman started the Free Software Foundation and wrote the GNU General Public License (GNU GPL) in 1989. By the early 1990s, many of the programs required in an operating system (such as libraries, compilers, text editors, a command-line shell, and a windowing system) were completed, although low-level elements such as device drivers, daemons, and the kernel, called GNU Hurd, were stalled and incomplete. + +MINIX was created by Andrew S. Tanenbaum, a computer science professor, and released in 1987 as a minimal Unix-like operating system targeted at students and others who wanted to learn operating system principles. Although the complete source code of MINIX was freely available, the licensing terms prevented it from being free software until the licensing changed in April 2000. + +Although not released until 1992, due to legal complications, development of 386BSD, from which NetBSD, OpenBSD and FreeBSD descended, predated that of Linux. + +Linus Torvalds has stated on separate occasions that if the GNU kernel or 386BSD had been available at the time (1991), he probably would not have created Linux. \ No newline at end of file diff --git a/examples/rest/resources/state_of_the_union.txt b/examples/rest/resources/state_of_the_union.txt new file mode 100644 index 0000000..d50175d --- /dev/null +++ b/examples/rest/resources/state_of_the_union.txt @@ -0,0 +1,723 @@ +Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet. Justices of the Supreme Court. My fellow Americans. + +Last year COVID-19 kept us apart. This year we are finally together again. + +Tonight, we meet as Democrats Republicans and Independents. But most importantly as Americans. + +With a duty to one another to the American people to the Constitution. + +And with an unwavering resolve that freedom will always triumph over tyranny. + +Six days ago, Russia’s Vladimir Putin sought to shake the foundations of the free world thinking he could make it bend to his menacing ways. But he badly miscalculated. + +He thought he could roll into Ukraine and the world would roll over. Instead he met a wall of strength he never imagined. + +He met the Ukrainian people. + +From President Zelenskyy to every Ukrainian, their fearlessness, their courage, their determination, inspires the world. + +Groups of citizens blocking tanks with their bodies. 
Everyone from students to retirees teachers turned soldiers defending their homeland. + +In this struggle as President Zelenskyy said in his speech to the European Parliament “Light will win over darkness.” The Ukrainian Ambassador to the United States is here tonight. + +Let each of us here tonight in this Chamber send an unmistakable signal to Ukraine and to the world. + +Please rise if you are able and show that, Yes, we the United States of America stand with the Ukrainian people. + +Throughout our history we’ve learned this lesson when dictators do not pay a price for their aggression they cause more chaos. + +They keep moving. + +And the costs and the threats to America and the world keep rising. + +That’s why the NATO Alliance was created to secure peace and stability in Europe after World War 2. + +The United States is a member along with 29 other nations. + +It matters. American diplomacy matters. American resolve matters. + +Putin’s latest attack on Ukraine was premeditated and unprovoked. + +He rejected repeated efforts at diplomacy. + +He thought the West and NATO wouldn’t respond. And he thought he could divide us at home. Putin was wrong. We were ready. Here is what we did. + +We prepared extensively and carefully. + +We spent months building a coalition of other freedom-loving nations from Europe and the Americas to Asia and Africa to confront Putin. + +I spent countless hours unifying our European allies. We shared with the world in advance what we knew Putin was planning and precisely how he would try to falsely justify his aggression. + +We countered Russia’s lies with truth. + +And now that he has acted the free world is holding him accountable. + +Along with twenty-seven members of the European Union including France, Germany, Italy, as well as countries like the United Kingdom, Canada, Japan, Korea, Australia, New Zealand, and many others, even Switzerland. + +We are inflicting pain on Russia and supporting the people of Ukraine. Putin is now isolated from the world more than ever. + +Together with our allies –we are right now enforcing powerful economic sanctions. + +We are cutting off Russia’s largest banks from the international financial system. + +Preventing Russia’s central bank from defending the Russian Ruble making Putin’s $630 Billion “war fund” worthless. + +We are choking off Russia’s access to technology that will sap its economic strength and weaken its military for years to come. + +Tonight I say to the Russian oligarchs and corrupt leaders who have bilked billions of dollars off this violent regime no more. + +The U.S. Department of Justice is assembling a dedicated task force to go after the crimes of Russian oligarchs. + +We are joining with our European allies to find and seize your yachts your luxury apartments your private jets. We are coming for your ill-begotten gains. + +And tonight I am announcing that we will join our allies in closing off American air space to all Russian flights – further isolating Russia – and adding an additional squeeze –on their economy. The Ruble has lost 30% of its value. + +The Russian stock market has lost 40% of its value and trading remains suspended. Russia’s economy is reeling and Putin alone is to blame. + +Together with our allies we are providing support to the Ukrainians in their fight for freedom. Military assistance. Economic assistance. Humanitarian assistance. + +We are giving more than $1 Billion in direct assistance to Ukraine. 
+ +And we will continue to aid the Ukrainian people as they defend their country and to help ease their suffering. + +Let me be clear, our forces are not engaged and will not engage in conflict with Russian forces in Ukraine. + +Our forces are not going to Europe to fight in Ukraine, but to defend our NATO Allies – in the event that Putin decides to keep moving west. + +For that purpose we’ve mobilized American ground forces, air squadrons, and ship deployments to protect NATO countries including Poland, Romania, Latvia, Lithuania, and Estonia. + +As I have made crystal clear the United States and our Allies will defend every inch of territory of NATO countries with the full force of our collective power. + +And we remain clear-eyed. The Ukrainians are fighting back with pure courage. But the next few days weeks, months, will be hard on them. + +Putin has unleashed violence and chaos. But while he may make gains on the battlefield – he will pay a continuing high price over the long run. + +And a proud Ukrainian people, who have known 30 years of independence, have repeatedly shown that they will not tolerate anyone who tries to take their country backwards. + +To all Americans, I will be honest with you, as I’ve always promised. A Russian dictator, invading a foreign country, has costs around the world. + +And I’m taking robust action to make sure the pain of our sanctions is targeted at Russia’s economy. And I will use every tool at our disposal to protect American businesses and consumers. + +Tonight, I can announce that the United States has worked with 30 other countries to release 60 Million barrels of oil from reserves around the world. + +America will lead that effort, releasing 30 Million barrels from our own Strategic Petroleum Reserve. And we stand ready to do more if necessary, unified with our allies. + +These steps will help blunt gas prices here at home. And I know the news about what’s happening can seem alarming. + +But I want you to know that we are going to be okay. + +When the history of this era is written Putin’s war on Ukraine will have left Russia weaker and the rest of the world stronger. + +While it shouldn’t have taken something so terrible for people around the world to see what’s at stake now everyone sees it clearly. + +We see the unity among leaders of nations and a more unified Europe a more unified West. And we see unity among the people who are gathering in cities in large crowds around the world even in Russia to demonstrate their support for Ukraine. + +In the battle between democracy and autocracy, democracies are rising to the moment, and the world is clearly choosing the side of peace and security. + +This is a real test. It’s going to take time. So let us continue to draw inspiration from the iron will of the Ukrainian people. + +To our fellow Ukrainian Americans who forge a deep bond that connects our two nations we stand with you. + +Putin may circle Kyiv with tanks, but he will never gain the hearts and souls of the Ukrainian people. + +He will never extinguish their love of freedom. He will never weaken the resolve of the free world. + +We meet tonight in an America that has lived through two of the hardest years this nation has ever faced. + +The pandemic has been punishing. + +And so many families are living paycheck to paycheck, struggling to keep up with the rising cost of food, gas, housing, and so much more. + +I understand. + +I remember when my Dad had to leave our home in Scranton, Pennsylvania to find work. 
I grew up in a family where if the price of food went up, you felt it. + +That’s why one of the first things I did as President was fight to pass the American Rescue Plan. + +Because people were hurting. We needed to act, and we did. + +Few pieces of legislation have done more in a critical moment in our history to lift us out of crisis. + +It fueled our efforts to vaccinate the nation and combat COVID-19. It delivered immediate economic relief for tens of millions of Americans. + +Helped put food on their table, keep a roof over their heads, and cut the cost of health insurance. + +And as my Dad used to say, it gave people a little breathing room. + +And unlike the $2 Trillion tax cut passed in the previous administration that benefitted the top 1% of Americans, the American Rescue Plan helped working people—and left no one behind. + +And it worked. It created jobs. Lots of jobs. + +In fact—our economy created over 6.5 Million new jobs just last year, more jobs created in one year +than ever before in the history of America. + +Our economy grew at a rate of 5.7% last year, the strongest growth in nearly 40 years, the first step in bringing fundamental change to an economy that hasn’t worked for the working people of this nation for too long. + +For the past 40 years we were told that if we gave tax breaks to those at the very top, the benefits would trickle down to everyone else. + +But that trickle-down theory led to weaker economic growth, lower wages, bigger deficits, and the widest gap between those at the top and everyone else in nearly a century. + +Vice President Harris and I ran for office with a new economic vision for America. + +Invest in America. Educate Americans. Grow the workforce. Build the economy from the bottom up +and the middle out, not from the top down. + +Because we know that when the middle class grows, the poor have a ladder up and the wealthy do very well. + +America used to have the best roads, bridges, and airports on Earth. + +Now our infrastructure is ranked 13th in the world. + +We won’t be able to compete for the jobs of the 21st Century if we don’t fix that. + +That’s why it was so important to pass the Bipartisan Infrastructure Law—the most sweeping investment to rebuild America in history. + +This was a bipartisan effort, and I want to thank the members of both parties who worked to make it happen. + +We’re done talking about infrastructure weeks. + +We’re going to have an infrastructure decade. + +It is going to transform America and put us on a path to win the economic competition of the 21st Century that we face with the rest of the world—particularly with China. + +As I’ve told Xi Jinping, it is never a good bet to bet against the American people. + +We’ll create good jobs for millions of Americans, modernizing roads, airports, ports, and waterways all across America. + +And we’ll do it all to withstand the devastating effects of the climate crisis and promote environmental justice. + +We’ll build a national network of 500,000 electric vehicle charging stations, begin to replace poisonous lead pipes—so every child—and every American—has clean water to drink at home and at school, provide affordable high-speed internet for every American—urban, suburban, rural, and tribal communities. + +4,000 projects have already been announced. + +And tonight, I’m announcing that this year we will start fixing over 65,000 miles of highway and 1,500 bridges in disrepair. 
+ +When we use taxpayer dollars to rebuild America – we are going to Buy American: buy American products to support American jobs. + +The federal government spends about $600 Billion a year to keep the country safe and secure. + +There’s been a law on the books for almost a century +to make sure taxpayers’ dollars support American jobs and businesses. + +Every Administration says they’ll do it, but we are actually doing it. + +We will buy American to make sure everything from the deck of an aircraft carrier to the steel on highway guardrails are made in America. + +But to compete for the best jobs of the future, we also need to level the playing field with China and other competitors. + +That’s why it is so important to pass the Bipartisan Innovation Act sitting in Congress that will make record investments in emerging technologies and American manufacturing. + +Let me give you one example of why it’s so important to pass it. + +If you travel 20 miles east of Columbus, Ohio, you’ll find 1,000 empty acres of land. + +It won’t look like much, but if you stop and look closely, you’ll see a “Field of dreams,” the ground on which America’s future will be built. + +This is where Intel, the American company that helped build Silicon Valley, is going to build its $20 billion semiconductor “mega site”. + +Up to eight state-of-the-art factories in one place. 10,000 new good-paying jobs. + +Some of the most sophisticated manufacturing in the world to make computer chips the size of a fingertip that power the world and our everyday lives. + +Smartphones. The Internet. Technology we have yet to invent. + +But that’s just the beginning. + +Intel’s CEO, Pat Gelsinger, who is here tonight, told me they are ready to increase their investment from +$20 billion to $100 billion. + +That would be one of the biggest investments in manufacturing in American history. + +And all they’re waiting for is for you to pass this bill. + +So let’s not wait any longer. Send it to my desk. I’ll sign it. + +And we will really take off. + +And Intel is not alone. + +There’s something happening in America. + +Just look around and you’ll see an amazing story. + +The rebirth of the pride that comes from stamping products “Made In America.” The revitalization of American manufacturing. + +Companies are choosing to build new factories here, when just a few years ago, they would have built them overseas. + +That’s what is happening. Ford is investing $11 billion to build electric vehicles, creating 11,000 jobs across the country. + +GM is making the largest investment in its history—$7 billion to build electric vehicles, creating 4,000 jobs in Michigan. + +All told, we created 369,000 new manufacturing jobs in America just last year. + +Powered by people I’ve met like JoJo Burgess, from generations of union steelworkers from Pittsburgh, who’s here with us tonight. + +As Ohio Senator Sherrod Brown says, “It’s time to bury the label “Rust Belt.” + +It’s time. + +But with all the bright spots in our economy, record job growth and higher wages, too many families are struggling to keep up with the bills. + +Inflation is robbing them of the gains they might otherwise feel. + +I get it. That’s why my top priority is getting prices under control. + +Look, our economy roared back faster than most predicted, but the pandemic meant that businesses had a hard time hiring enough workers to keep up production in their factories. + +The pandemic also disrupted global supply chains. 
+ +When factories close, it takes longer to make goods and get them from the warehouse to the store, and prices go up. + +Look at cars. + +Last year, there weren’t enough semiconductors to make all the cars that people wanted to buy. + +And guess what, prices of automobiles went up. + +So—we have a choice. + +One way to fight inflation is to drive down wages and make Americans poorer. + +I have a better plan to fight inflation. + +Lower your costs, not your wages. + +Make more cars and semiconductors in America. + +More infrastructure and innovation in America. + +More goods moving faster and cheaper in America. + +More jobs where you can earn a good living in America. + +And instead of relying on foreign supply chains, let’s make it in America. + +Economists call it “increasing the productive capacity of our economy.” + +I call it building a better America. + +My plan to fight inflation will lower your costs and lower the deficit. + +17 Nobel laureates in economics say my plan will ease long-term inflationary pressures. Top business leaders and most Americans support my plan. And here’s the plan: + +First – cut the cost of prescription drugs. Just look at insulin. One in ten Americans has diabetes. In Virginia, I met a 13-year-old boy named Joshua Davis. + +He and his Dad both have Type 1 diabetes, which means they need insulin every day. Insulin costs about $10 a vial to make. + +But drug companies charge families like Joshua and his Dad up to 30 times more. I spoke with Joshua’s mom. + +Imagine what it’s like to look at your child who needs insulin and have no idea how you’re going to pay for it. + +What it does to your dignity, your ability to look your child in the eye, to be the parent you expect to be. + +Joshua is here with us tonight. Yesterday was his birthday. Happy birthday, buddy. + +For Joshua, and for the 200,000 other young people with Type 1 diabetes, let’s cap the cost of insulin at $35 a month so everyone can afford it. + +Drug companies will still do very well. And while we’re at it let Medicare negotiate lower prices for prescription drugs, like the VA already does. + +Look, the American Rescue Plan is helping millions of families on Affordable Care Act plans save $2,400 a year on their health care premiums. Let’s close the coverage gap and make those savings permanent. + +Second – cut energy costs for families an average of $500 a year by combatting climate change. + +Let’s provide investments and tax credits to weatherize your homes and businesses to be energy efficient and you get a tax credit; double America’s clean energy production in solar, wind, and so much more; lower the price of electric vehicles, saving you another $80 a month because you’ll never have to pay at the gas pump again. + +Third – cut the cost of child care. Many families pay up to $14,000 a year for child care per child. + +Middle-class and working families shouldn’t have to pay more than 7% of their income for care of young children. + +My plan will cut the cost in half for most families and help parents, including millions of women, who left the workforce during the pandemic because they couldn’t afford child care, to be able to get back to work. + +My plan doesn’t stop there. It also includes home and long-term care. More affordable housing. And Pre-K for every 3- and 4-year-old. + +All of these will lower costs. + +And under my plan, nobody earning less than $400,000 a year will pay an additional penny in new taxes. Nobody. 
+ +The one thing all Americans agree on is that the tax system is not fair. We have to fix it. + +I’m not looking to punish anyone. But let’s make sure corporations and the wealthiest Americans start paying their fair share. + +Just last year, 55 Fortune 500 corporations earned $40 billion in profits and paid zero dollars in federal income tax. + +That’s simply not fair. That’s why I’ve proposed a 15% minimum tax rate for corporations. + +We got more than 130 countries to agree on a global minimum tax rate so companies can’t get out of paying their taxes at home by shipping jobs and factories overseas. + +That’s why I’ve proposed closing loopholes so the very wealthy don’t pay a lower tax rate than a teacher or a firefighter. + +So that’s my plan. It will grow the economy and lower costs for families. + +So what are we waiting for? Let’s get this done. And while you’re at it, confirm my nominees to the Federal Reserve, which plays a critical role in fighting inflation. + +My plan will not only lower costs to give families a fair shot, it will lower the deficit. + +The previous Administration not only ballooned the deficit with tax cuts for the very wealthy and corporations, it undermined the watchdogs whose job was to keep pandemic relief funds from being wasted. + +But in my administration, the watchdogs have been welcomed back. + +We’re going after the criminals who stole billions in relief money meant for small businesses and millions of Americans. + +And tonight, I’m announcing that the Justice Department will name a chief prosecutor for pandemic fraud. + +By the end of this year, the deficit will be down to less than half what it was before I took office. + +The only president ever to cut the deficit by more than one trillion dollars in a single year. + +Lowering your costs also means demanding more competition. + +I’m a capitalist, but capitalism without competition isn’t capitalism. + +It’s exploitation—and it drives up prices. + +When corporations don’t have to compete, their profits go up, your prices go up, and small businesses and family farmers and ranchers go under. + +We see it happening with ocean carriers moving goods in and out of America. + +During the pandemic, these foreign-owned companies raised prices by as much as 1,000% and made record profits. + +Tonight, I’m announcing a crackdown on these companies overcharging American businesses and consumers. + +And as Wall Street firms take over more nursing homes, quality in those homes has gone down and costs have gone up. + +That ends on my watch. + +Medicare is going to set higher standards for nursing homes and make sure your loved ones get the care they deserve and expect. + +We’ll also cut costs and keep the economy going strong by giving workers a fair shot, provide more training and apprenticeships, hire them based on their skills not degrees. + +Let’s pass the Paycheck Fairness Act and paid leave. + +Raise the minimum wage to $15 an hour and extend the Child Tax Credit, so no one has to raise a family in poverty. + +Let’s increase Pell Grants and increase our historic support of HBCUs, and invest in what Jill—our First Lady who teaches full-time—calls America’s best-kept secret: community colleges. + +And let’s pass the PRO Act when a majority of workers want to form a union—they shouldn’t be stopped. + +When we invest in our workers, when we build the economy from the bottom up and the middle out together, we can do something we haven’t done in a long time: build a better America. 
+ +For more than two years, COVID-19 has impacted every decision in our lives and the life of the nation. + +And I know you’re tired, frustrated, and exhausted. + +But I also know this. + +Because of the progress we’ve made, because of your resilience and the tools we have, tonight I can say +we are moving forward safely, back to more normal routines. + +We’ve reached a new moment in the fight against COVID-19, with severe cases down to a level not seen since last July. + +Just a few days ago, the Centers for Disease Control and Prevention—the CDC—issued new mask guidelines. + +Under these new guidelines, most Americans in most of the country can now be mask free. + +And based on the projections, more of the country will reach that point across the next couple of weeks. + +Thanks to the progress we have made this past year, COVID-19 need no longer control our lives. + +I know some are talking about “living with COVID-19”. Tonight – I say that we will never just accept living with COVID-19. + +We will continue to combat the virus as we do other diseases. And because this is a virus that mutates and spreads, we will stay on guard. + +Here are four common sense steps as we move forward safely. + +First, stay protected with vaccines and treatments. We know how incredibly effective vaccines are. If you’re vaccinated and boosted you have the highest degree of protection. + +We will never give up on vaccinating more Americans. Now, I know parents with kids under 5 are eager to see a vaccine authorized for their children. + +The scientists are working hard to get that done and we’ll be ready with plenty of vaccines when they do. + +We’re also ready with anti-viral treatments. If you get COVID-19, the Pfizer pill reduces your chances of ending up in the hospital by 90%. + +We’ve ordered more of these pills than anyone in the world. And Pfizer is working overtime to get us 1 Million pills this month and more than double that next month. + +And we’re launching the “Test to Treat” initiative so people can get tested at a pharmacy, and if they’re positive, receive antiviral pills on the spot at no cost. + +If you’re immunocompromised or have some other vulnerability, we have treatments and free high-quality masks. + +We’re leaving no one behind or ignoring anyone’s needs as we move forward. + +And on testing, we have made hundreds of millions of tests available for you to order for free. + +Even if you already ordered free tests tonight, I am announcing that you can order more from covidtests.gov starting next week. + +Second – we must prepare for new variants. Over the past year, we’ve gotten much better at detecting new variants. + +If necessary, we’ll be able to deploy new vaccines within 100 days instead of many more months or years. + +And, if Congress provides the funds we need, we’ll have new stockpiles of tests, masks, and pills ready if needed. + +I cannot promise a new variant won’t come. But I can promise you we’ll do everything within our power to be ready if it does. + +Third – we can end the shutdown of schools and businesses. We have the tools we need. + +It’s time for Americans to get back to work and fill our great downtowns again. People working from home can feel safe to begin to return to the office. + +We’re doing that here in the federal government. The vast majority of federal workers will once again work in person. + +Our schools are open. Let’s keep it that way. Our kids need to be in school. 
+ +And with 75% of adult Americans fully vaccinated and hospitalizations down by 77%, most Americans can remove their masks, return to work, stay in the classroom, and move forward safely. + +We achieved this because we provided free vaccines, treatments, tests, and masks. + +Of course, continuing this costs money. + +I will soon send Congress a request. + +The vast majority of Americans have used these tools and may want to again, so I expect Congress to pass it quickly. + +Fourth, we will continue vaccinating the world. + +We’ve sent 475 Million vaccine doses to 112 countries, more than any other nation. + +And we won’t stop. + +We have lost so much to COVID-19. Time with one another. And worst of all, so much loss of life. + +Let’s use this moment to reset. Let’s stop looking at COVID-19 as a partisan dividing line and see it for what it is: A God-awful disease. + +Let’s stop seeing each other as enemies, and start seeing each other for who we really are: Fellow Americans. + +We can’t change how divided we’ve been. But we can change how we move forward—on COVID-19 and other issues we must face together. + +I recently visited the New York City Police Department days after the funerals of Officer Wilbert Mora and his partner, Officer Jason Rivera. + +They were responding to a 9-1-1 call when a man shot and killed them with a stolen gun. + +Officer Mora was 27 years old. + +Officer Rivera was 22. + +Both Dominican Americans who’d grown up on the same streets they later chose to patrol as police officers. + +I spoke with their families and told them that we are forever in debt for their sacrifice, and we will carry on their mission to restore the trust and safety every community deserves. + +I’ve worked on these issues a long time. + +I know what works: Investing in crime preventionand community police officers who’ll walk the beat, who’ll know the neighborhood, and who can restore trust and safety. + +So let’s not abandon our streets. Or choose between safety and equal justice. + +Let’s come together to protect our communities, restore trust, and hold law enforcement accountable. + +That’s why the Justice Department required body cameras, banned chokeholds, and restricted no-knock warrants for its officers. + +That’s why the American Rescue Plan provided $350 Billion that cities, states, and counties can use to hire more police and invest in proven strategies like community violence interruption—trusted messengers breaking the cycle of violence and trauma and giving young people hope. + +We should all agree: The answer is not to Defund the police. The answer is to FUND the police with the resources and training they need to protect our communities. + +I ask Democrats and Republicans alike: Pass my budget and keep our neighborhoods safe. + +And I will keep doing everything in my power to crack down on gun trafficking and ghost guns you can buy online and make at home—they have no serial numbers and can’t be traced. + +And I ask Congress to pass proven measures to reduce gun violence. Pass universal background checks. Why should anyone on a terrorist list be able to purchase a weapon? + +Ban assault weapons and high-capacity magazines. + +Repeal the liability shield that makes gun manufacturers the only industry in America that can’t be sued. + +These laws don’t infringe on the Second Amendment. They save lives. + +The most fundamental right in America is the right to vote – and to have it counted. And it’s under assault. 
+ +In state after state, new laws have been passed, not only to suppress the vote, but to subvert entire elections. + +We cannot let this happen. + +Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. + +Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. + +One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. + +And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence. + +A former top litigator in private practice. A former federal public defender. And from a family of public school educators and police officers. A consensus builder. Since she’s been nominated, she’s received a broad range of support—from the Fraternal Order of Police to former judges appointed by Democrats and Republicans. + +And if we are to advance liberty and justice, we need to secure the Border and fix the immigration system. + +We can do both. At our border, we’ve installed new technology like cutting-edge scanners to better detect drug smuggling. + +We’ve set up joint patrols with Mexico and Guatemala to catch more human traffickers. + +We’re putting in place dedicated immigration judges so families fleeing persecution and violence can have their cases heard faster. + +We’re securing commitments and supporting partners in South and Central America to host more refugees and secure their own borders. + +We can do all this while keeping lit the torch of liberty that has led generations of immigrants to this land—my forefathers and so many of yours. + +Provide a pathway to citizenship for Dreamers, those on temporary status, farm workers, and essential workers. + +Revise our laws so businesses have the workers they need and families don’t wait decades to reunite. + +It’s not only the right thing to do—it’s the economically smart thing to do. + +That’s why immigration reform is supported by everyone from labor unions to religious leaders to the U.S. Chamber of Commerce. + +Let’s get it done once and for all. + +Advancing liberty and justice also requires protecting the rights of women. + +The constitutional right affirmed in Roe v. Wade—standing precedent for half a century—is under attack as never before. + +If we want to go forward—not backward—we must protect access to health care. Preserve a woman’s right to choose. And let’s continue to advance maternal health care in America. + +And for our LGBTQ+ Americans, let’s finally get the bipartisan Equality Act to my desk. The onslaught of state laws targeting transgender Americans and their families is wrong. + +As I said last year, especially to our younger transgender Americans, I will always have your back as your President, so you can be yourself and reach your God-given potential. + +While it often appears that we never agree, that isn’t true. I signed 80 bipartisan bills into law last year. From preventing government shutdowns to protecting Asian-Americans from still-too-common hate crimes to reforming military justice. 
+ +And soon, we’ll strengthen the Violence Against Women Act that I first wrote three decades ago. It is important for us to show the nation that we can come together and do big things. + +So tonight I’m offering a Unity Agenda for the Nation. Four big things we can do together. + +First, beat the opioid epidemic. + +There is so much we can do. Increase funding for prevention, treatment, harm reduction, and recovery. + +Get rid of outdated rules that stop doctors from prescribing treatments. And stop the flow of illicit drugs by working with state and local law enforcement to go after traffickers. + +If you’re suffering from addiction, know you are not alone. I believe in recovery, and I celebrate the 23 million Americans in recovery. + +Second, let’s take on mental health. Especially among our children, whose lives and education have been turned upside down. + +The American Rescue Plan gave schools money to hire teachers and help students make up for lost learning. + +I urge every parent to make sure your school does just that. And we can all play a part—sign up to be a tutor or a mentor. + +Children were also struggling before the pandemic. Bullying, violence, trauma, and the harms of social media. + +As Frances Haugen, who is here with us tonight, has shown, we must hold social media platforms accountable for the national experiment they’re conducting on our children for profit. + +It’s time to strengthen privacy protections, ban targeted advertising to children, demand tech companies stop collecting personal data on our children. + +And let’s get all Americans the mental health services they need. More people they can turn to for help, and full parity between physical and mental health care. + +Third, support our veterans. + +Veterans are the best of us. + +I’ve always believed that we have a sacred obligation to equip all those we send to war and care for them and their families when they come home. + +My administration is providing assistance with job training and housing, and now helping lower-income veterans get VA care debt-free. + +Our troops in Iraq and Afghanistan faced many dangers. + +One was stationed at bases and breathing in toxic smoke from “burn pits” that incinerated wastes of war—medical and hazard material, jet fuel, and more. + +When they came home, many of the world’s fittest and best trained warriors were never the same. + +Headaches. Numbness. Dizziness. + +A cancer that would put them in a flag-draped coffin. + +I know. + +One of those soldiers was my son Major Beau Biden. + +We don’t know for sure if a burn pit was the cause of his brain cancer, or the diseases of so many of our troops. + +But I’m committed to finding out everything we can. + +Committed to military families like Danielle Robinson from Ohio. + +The widow of Sergeant First Class Heath Robinson. + +He was born a soldier. Army National Guard. Combat medic in Kosovo and Iraq. + +Stationed near Baghdad, just yards from burn pits the size of football fields. + +Heath’s widow Danielle is here with us tonight. They loved going to Ohio State football games. He loved building Legos with their daughter. + +But cancer from prolonged exposure to burn pits ravaged Heath’s lungs and body. + +Danielle says Heath was a fighter to the very end. + +He didn’t know how to stop fighting, and neither did she. + +Through her pain she found purpose to demand we do better. + +Tonight, Danielle—we are. + +The VA is pioneering new ways of linking toxic exposures to diseases, already helping more veterans get benefits. 
+ +And tonight, I’m announcing we’re expanding eligibility to veterans suffering from nine respiratory cancers. + +I’m also calling on Congress: pass a law to make sure veterans devastated by toxic exposures in Iraq and Afghanistan finally get the benefits and comprehensive health care they deserve. + +And fourth, let’s end cancer as we know it. + +This is personal to me and Jill, to Kamala, and to so many of you. + +Cancer is the #2 cause of death in America–second only to heart disease. + +Last month, I announced our plan to supercharge +the Cancer Moonshot that President Obama asked me to lead six years ago. + +Our goal is to cut the cancer death rate by at least 50% over the next 25 years, turn more cancers from death sentences into treatable diseases. + +More support for patients and families. + +To get there, I call on Congress to fund ARPA-H, the Advanced Research Projects Agency for Health. + +It’s based on DARPA—the Defense Department project that led to the Internet, GPS, and so much more. + +ARPA-H will have a singular purpose—to drive breakthroughs in cancer, Alzheimer’s, diabetes, and more. + +A unity agenda for the nation. + +We can do this. + +My fellow Americans—tonight , we have gathered in a sacred space—the citadel of our democracy. + +In this Capitol, generation after generation, Americans have debated great questions amid great strife, and have done great things. + +We have fought for freedom, expanded liberty, defeated totalitarianism and terror. + +And built the strongest, freest, and most prosperous nation the world has ever known. + +Now is the hour. + +Our moment of responsibility. + +Our test of resolve and conscience, of history itself. + +It is in this moment that our character is formed. Our purpose is found. Our future is forged. + +Well I know this nation. + +We will meet the test. + +To protect freedom and liberty, to expand fairness and opportunity. + +We will save democracy. + +As hard as these times have been, I am more optimistic about America today than I have been my whole life. + +Because I see the future that is within our grasp. + +Because I know there is simply nothing beyond our capacity. + +We are the only nation on Earth that has always turned every crisis we have faced into an opportunity. + +The only nation that can be defined by a single word: possibilities. + +So on this night, in our 245th year as a nation, I have come to report on the State of the Union. + +And my report is this: the State of the Union is strong—because you, the American people, are strong. + +We are stronger today than we were a year ago. + +And we will be stronger a year from now than we are today. + +Now is our moment to meet and overcome the challenges of our time. + +And we will, as one people. + +One America. + +The United States of America. + +May God bless you all. May God protect our troops. \ No newline at end of file diff --git a/ios/.gitignore b/ios/.gitignore new file mode 100644 index 0000000..31d064c --- /dev/null +++ b/ios/.gitignore @@ -0,0 +1,2 @@ +xuserdata +*~ diff --git a/ios/MLCChat copy-Info.plist b/ios/MLCChat copy-Info.plist new file mode 100644 index 0000000..ff579a6 --- /dev/null +++ b/ios/MLCChat copy-Info.plist @@ -0,0 +1,8 @@ + + + + + UIFileSharingEnabled + + + diff --git a/ios/MLCChat.xcodeproj/project.pbxproj b/ios/MLCChat.xcodeproj/project.pbxproj new file mode 100644 index 0000000..cdf5205 --- /dev/null +++ b/ios/MLCChat.xcodeproj/project.pbxproj @@ -0,0 +1,769 @@ +// !$*UTF8*$! 
+{ + archiveVersion = 1; + classes = { + }; + objectVersion = 56; + objects = { + +/* Begin PBXBuildFile section */ + 1453A4CF2A1354B9001B909F /* StartView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CA2A1354B9001B909F /* StartView.swift */; }; + 1453A4D02A1354B9001B909F /* ModelView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CB2A1354B9001B909F /* ModelView.swift */; }; + 1453A4D12A1354B9001B909F /* AppState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CC2A1354B9001B909F /* AppState.swift */; }; + 1453A4D22A1354B9001B909F /* ModelConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CD2A1354B9001B909F /* ModelConfig.swift */; }; + 1453A4D32A1354B9001B909F /* ModelState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CE2A1354B9001B909F /* ModelState.swift */; }; + A773CC652A5DC98200467BFE /* ImageProcessing.swift in Sources */ = {isa = PBXBuildFile; fileRef = A773CC642A5DC98200467BFE /* ImageProcessing.swift */; }; + AA14F2D42B911A9100308009 /* ImageProcessing.swift in Sources */ = {isa = PBXBuildFile; fileRef = A773CC642A5DC98200467BFE /* ImageProcessing.swift */; }; + AA14F2D52B911A9100308009 /* AppState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CC2A1354B9001B909F /* AppState.swift */; }; + AA14F2D62B911A9100308009 /* MLCChatApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643B229F99A7F004DDAA4 /* MLCChatApp.swift */; }; + AA14F2D72B911A9100308009 /* ChatState.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643C029F99B07004DDAA4 /* ChatState.swift */; }; + AA14F2D82B911A9100308009 /* PerformanceMetrics.swift in Sources */ = {isa = PBXBuildFile; fileRef = CF3673232A9E2A9300E6D5AB /* PerformanceMetrics.swift */; }; + AA14F2D92B911A9100308009 /* ChatView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643C229F99B07004DDAA4 /* ChatView.swift */; }; + AA14F2DA2B911A9100308009 /* ModelState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CE2A1354B9001B909F /* ModelState.swift */; }; + AA14F2DB2B911A9100308009 /* MessageView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643C729F99B34004DDAA4 /* MessageView.swift */; }; + AA14F2DC2B911A9100308009 /* RestAwaitLib.swift in Sources */ = {isa = PBXBuildFile; fileRef = CFEEEF112B6423560086AA32 /* RestAwaitLib.swift */; }; + AA14F2DD2B911A9100308009 /* ModelConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CD2A1354B9001B909F /* ModelConfig.swift */; }; + AA14F2DE2B911A9100308009 /* ParamsConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = AEC27EF92A85C2AC00254E67 /* ParamsConfig.swift */; }; + AA14F2DF2B911A9100308009 /* AppConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = AEC27EFB2A85C3B000254E67 /* AppConfig.swift */; }; + AA14F2E02B911A9100308009 /* Constants.swift in Sources */ = {isa = PBXBuildFile; fileRef = AEC27F012A86337E00254E67 /* Constants.swift */; }; + AA14F2E12B911A9100308009 /* ModelView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CB2A1354B9001B909F /* ModelView.swift */; }; + AA14F2E22B911A9100308009 /* StartView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CA2A1354B9001B909F /* StartView.swift */; }; + AA14F2E42B911A9100308009 /* MLCSwift in Frameworks */ = {isa = PBXBuildFile; productRef = AA14F2D22B911A9100308009 /* MLCSwift */; }; + AA14F2E62B911A9100308009 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C0D643B929F99A80004DDAA4 /* Preview Assets.xcassets */; }; + AA14F2E72B911A9100308009 /* Assets.xcassets in 
Resources */ = {isa = PBXBuildFile; fileRef = C0D643B629F99A80004DDAA4 /* Assets.xcassets */; }; + AA14F2EA2B911A9100308009 /* app-config.json in CopyFiles */ = {isa = PBXBuildFile; fileRef = C09834182A16F4CB00A05B51 /* app-config.json */; }; + AA14F2EB2B911A9100308009 /* dist in CopyFiles */ = {isa = PBXBuildFile; fileRef = C06A74E029F99C9F00BC4BE6 /* dist */; }; + AEC27EFA2A85C2AC00254E67 /* ParamsConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = AEC27EF92A85C2AC00254E67 /* ParamsConfig.swift */; }; + AEC27EFC2A85C3B000254E67 /* AppConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = AEC27EFB2A85C3B000254E67 /* AppConfig.swift */; }; + AEC27F022A86337E00254E67 /* Constants.swift in Sources */ = {isa = PBXBuildFile; fileRef = AEC27F012A86337E00254E67 /* Constants.swift */; }; + C06A74F229F9A78800BC4BE6 /* dist in CopyFiles */ = {isa = PBXBuildFile; fileRef = C06A74E029F99C9F00BC4BE6 /* dist */; }; + C09834192A16F4E000A05B51 /* app-config.json in CopyFiles */ = {isa = PBXBuildFile; fileRef = C09834182A16F4CB00A05B51 /* app-config.json */; }; + C0D643B329F99A7F004DDAA4 /* MLCChatApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643B229F99A7F004DDAA4 /* MLCChatApp.swift */; }; + C0D643B729F99A80004DDAA4 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C0D643B629F99A80004DDAA4 /* Assets.xcassets */; }; + C0D643BA29F99A80004DDAA4 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C0D643B929F99A80004DDAA4 /* Preview Assets.xcassets */; }; + C0D643C429F99B07004DDAA4 /* ChatView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643C229F99B07004DDAA4 /* ChatView.swift */; }; + C0D643C829F99B34004DDAA4 /* MessageView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643C729F99B34004DDAA4 /* MessageView.swift */; }; + C0DDBDF62A39103F00E9D060 /* ChatState.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643C029F99B07004DDAA4 /* ChatState.swift */; }; + C0DDBE0D2A3BCD8000E9D060 /* MLCSwift in Frameworks */ = {isa = PBXBuildFile; productRef = C0DDBE0C2A3BCD8000E9D060 /* MLCSwift */; }; + CF3673242A9E2A9300E6D5AB /* PerformanceMetrics.swift in Sources */ = {isa = PBXBuildFile; fileRef = CF3673232A9E2A9300E6D5AB /* PerformanceMetrics.swift */; }; + CFEEEF122B6423560086AA32 /* RestAwaitLib.swift in Sources */ = {isa = PBXBuildFile; fileRef = CFEEEF112B6423560086AA32 /* RestAwaitLib.swift */; }; +/* End PBXBuildFile section */ + +/* Begin PBXCopyFilesBuildPhase section */ + AA14F2E82B911A9100308009 /* Embed Libraries */ = { + isa = PBXCopyFilesBuildPhase; + buildActionMask = 2147483647; + dstPath = ""; + dstSubfolderSpec = 10; + files = ( + ); + name = "Embed Libraries"; + runOnlyForDeploymentPostprocessing = 0; + }; + AA14F2E92B911A9100308009 /* CopyFiles */ = { + isa = PBXCopyFilesBuildPhase; + buildActionMask = 2147483647; + dstPath = ""; + dstSubfolderSpec = 7; + files = ( + AA14F2EA2B911A9100308009 /* app-config.json in CopyFiles */, + AA14F2EB2B911A9100308009 /* dist in CopyFiles */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; + C06A74F129F9A78000BC4BE6 /* CopyFiles */ = { + isa = PBXCopyFilesBuildPhase; + buildActionMask = 2147483647; + dstPath = ""; + dstSubfolderSpec = 7; + files = ( + C09834192A16F4E000A05B51 /* app-config.json in CopyFiles */, + C06A74F229F9A78800BC4BE6 /* dist in CopyFiles */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; + C0D643CF29F99C5D004DDAA4 /* Embed Libraries */ = { + isa = PBXCopyFilesBuildPhase; + buildActionMask = 2147483647; + dstPath = ""; + dstSubfolderSpec = 
10; + files = ( + ); + name = "Embed Libraries"; + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXCopyFilesBuildPhase section */ + +/* Begin PBXFileReference section */ + 1453A4CA2A1354B9001B909F /* StartView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = StartView.swift; sourceTree = ""; }; + 1453A4CB2A1354B9001B909F /* ModelView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ModelView.swift; sourceTree = ""; }; + 1453A4CC2A1354B9001B909F /* AppState.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = AppState.swift; sourceTree = ""; }; + 1453A4CD2A1354B9001B909F /* ModelConfig.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ModelConfig.swift; sourceTree = ""; }; + 1453A4CE2A1354B9001B909F /* ModelState.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ModelState.swift; sourceTree = ""; }; + A773CC642A5DC98200467BFE /* ImageProcessing.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ImageProcessing.swift; sourceTree = ""; }; + AA14F2EF2B911A9100308009 /* MLCChat_rebased.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = MLCChat_rebased.app; sourceTree = BUILT_PRODUCTS_DIR; }; + AA14F2F02B911A9100308009 /* MLCChat copy-Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; name = "MLCChat copy-Info.plist"; path = "/Users/steve/Documents/brave-projects/LLMs/MELT/frameworks/MLC/mlc-llm/ios/MLCChat copy-Info.plist"; sourceTree = ""; }; + AEC27EF92A85C2AC00254E67 /* ParamsConfig.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ParamsConfig.swift; sourceTree = ""; }; + AEC27EFB2A85C3B000254E67 /* AppConfig.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AppConfig.swift; sourceTree = ""; }; + AEC27F012A86337E00254E67 /* Constants.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Constants.swift; sourceTree = ""; }; + C06A74E029F99C9F00BC4BE6 /* dist */ = {isa = PBXFileReference; lastKnownFileType = folder; path = dist; sourceTree = ""; }; + C06A74E629F9A1DF00BC4BE6 /* MLCChat.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = MLCChat.entitlements; sourceTree = ""; }; + C09834182A16F4CB00A05B51 /* app-config.json */ = {isa = PBXFileReference; lastKnownFileType = text.json; path = "app-config.json"; sourceTree = ""; }; + C0D643AF29F99A7F004DDAA4 /* MLCChat.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = MLCChat.app; sourceTree = BUILT_PRODUCTS_DIR; }; + C0D643B229F99A7F004DDAA4 /* MLCChatApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MLCChatApp.swift; sourceTree = ""; }; + C0D643B629F99A80004DDAA4 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; + C0D643B929F99A80004DDAA4 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = ""; }; + C0D643C029F99B07004DDAA4 /* ChatState.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ChatState.swift; sourceTree = ""; }; + C0D643C229F99B07004DDAA4 /* 
ChatView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ChatView.swift; sourceTree = ""; }; + C0D643C729F99B34004DDAA4 /* MessageView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = MessageView.swift; sourceTree = ""; }; + C0DDBE0B2A3BA6F800E9D060 /* MLCSwift */ = {isa = PBXFileReference; lastKnownFileType = wrapper; path = MLCSwift; sourceTree = ""; }; + CF3673232A9E2A9300E6D5AB /* PerformanceMetrics.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PerformanceMetrics.swift; sourceTree = ""; }; + CFEEEF112B6423560086AA32 /* RestAwaitLib.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = RestAwaitLib.swift; sourceTree = ""; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + AA14F2E32B911A9100308009 /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + AA14F2E42B911A9100308009 /* MLCSwift in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; + C0D643AC29F99A7F004DDAA4 /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + C0DDBE0D2A3BCD8000E9D060 /* MLCSwift in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + AEC27EF82A85C29000254E67 /* Models */ = { + isa = PBXGroup; + children = ( + 1453A4CD2A1354B9001B909F /* ModelConfig.swift */, + AEC27EF92A85C2AC00254E67 /* ParamsConfig.swift */, + AEC27EFB2A85C3B000254E67 /* AppConfig.swift */, + ); + path = Models; + sourceTree = ""; + }; + AEC27EFF2A85EE2800254E67 /* States */ = { + isa = PBXGroup; + children = ( + 1453A4CE2A1354B9001B909F /* ModelState.swift */, + 1453A4CC2A1354B9001B909F /* AppState.swift */, + C0D643C029F99B07004DDAA4 /* ChatState.swift */, + ); + path = States; + sourceTree = ""; + }; + AEC27F002A86306800254E67 /* Views */ = { + isa = PBXGroup; + children = ( + A773CC642A5DC98200467BFE /* ImageProcessing.swift */, + 1453A4CB2A1354B9001B909F /* ModelView.swift */, + 1453A4CA2A1354B9001B909F /* StartView.swift */, + C0D643C729F99B34004DDAA4 /* MessageView.swift */, + C0D643C229F99B07004DDAA4 /* ChatView.swift */, + ); + path = Views; + sourceTree = ""; + }; + AEC27F032A86338800254E67 /* Common */ = { + isa = PBXGroup; + children = ( + AEC27F012A86337E00254E67 /* Constants.swift */, + ); + path = Common; + sourceTree = ""; + }; + C0D643A629F99A7F004DDAA4 = { + isa = PBXGroup; + children = ( + C0DDBDF02A39068900E9D060 /* Packages */, + C06A74E029F99C9F00BC4BE6 /* dist */, + C0D643B129F99A7F004DDAA4 /* MLCChat */, + C0D643B029F99A7F004DDAA4 /* Products */, + C0D643C929F99BDA004DDAA4 /* Frameworks */, + AA14F2F02B911A9100308009 /* MLCChat copy-Info.plist */, + ); + sourceTree = ""; + }; + C0D643B029F99A7F004DDAA4 /* Products */ = { + isa = PBXGroup; + children = ( + C0D643AF29F99A7F004DDAA4 /* MLCChat.app */, + AA14F2EF2B911A9100308009 /* MLCChat_rebased.app */, + ); + name = Products; + sourceTree = ""; + }; + C0D643B129F99A7F004DDAA4 /* MLCChat */ = { + isa = PBXGroup; + children = ( + C09834182A16F4CB00A05B51 /* app-config.json */, + AEC27F032A86338800254E67 /* Common */, + AEC27EF82A85C29000254E67 /* Models */, + AEC27EFF2A85EE2800254E67 /* States */, + AEC27F002A86306800254E67 /* Views */, + C06A74E629F9A1DF00BC4BE6 /* MLCChat.entitlements */, + C0D643B229F99A7F004DDAA4 /* MLCChatApp.swift */, + 
CF3673232A9E2A9300E6D5AB /* PerformanceMetrics.swift */, + CFEEEF112B6423560086AA32 /* RestAwaitLib.swift */, + C0D643B629F99A80004DDAA4 /* Assets.xcassets */, + C0D643B829F99A80004DDAA4 /* Preview Content */, + ); + path = MLCChat; + sourceTree = ""; + }; + C0D643B829F99A80004DDAA4 /* Preview Content */ = { + isa = PBXGroup; + children = ( + C0D643B929F99A80004DDAA4 /* Preview Assets.xcassets */, + ); + path = "Preview Content"; + sourceTree = ""; + }; + C0D643C929F99BDA004DDAA4 /* Frameworks */ = { + isa = PBXGroup; + children = ( + ); + name = Frameworks; + sourceTree = ""; + }; + C0DDBDF02A39068900E9D060 /* Packages */ = { + isa = PBXGroup; + children = ( + C0DDBE0B2A3BA6F800E9D060 /* MLCSwift */, + ); + name = Packages; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + AA14F2D12B911A9100308009 /* MLCChat_rebased */ = { + isa = PBXNativeTarget; + buildConfigurationList = AA14F2EC2B911A9100308009 /* Build configuration list for PBXNativeTarget "MLCChat_rebased" */; + buildPhases = ( + AA14F2D32B911A9100308009 /* Sources */, + AA14F2E32B911A9100308009 /* Frameworks */, + AA14F2E52B911A9100308009 /* Resources */, + AA14F2E82B911A9100308009 /* Embed Libraries */, + AA14F2E92B911A9100308009 /* CopyFiles */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = MLCChat_rebased; + packageProductDependencies = ( + AA14F2D22B911A9100308009 /* MLCSwift */, + ); + productName = MLCChat; + productReference = AA14F2EF2B911A9100308009 /* MLCChat_rebased.app */; + productType = "com.apple.product-type.application"; + }; + C0D643AE29F99A7F004DDAA4 /* MLCChat */ = { + isa = PBXNativeTarget; + buildConfigurationList = C0D643BD29F99A80004DDAA4 /* Build configuration list for PBXNativeTarget "MLCChat" */; + buildPhases = ( + C0D643AB29F99A7F004DDAA4 /* Sources */, + C0D643AC29F99A7F004DDAA4 /* Frameworks */, + C0D643AD29F99A7F004DDAA4 /* Resources */, + C0D643CF29F99C5D004DDAA4 /* Embed Libraries */, + C06A74F129F9A78000BC4BE6 /* CopyFiles */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = MLCChat; + packageProductDependencies = ( + C0DDBE0C2A3BCD8000E9D060 /* MLCSwift */, + ); + productName = MLCChat; + productReference = C0D643AF29F99A7F004DDAA4 /* MLCChat.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + C0D643A729F99A7F004DDAA4 /* Project object */ = { + isa = PBXProject; + attributes = { + BuildIndependentTargetsInParallel = 1; + LastSwiftUpdateCheck = 1430; + LastUpgradeCheck = 1430; + TargetAttributes = { + C0D643AE29F99A7F004DDAA4 = { + CreatedOnToolsVersion = 14.3; + LastSwiftMigration = 1430; + }; + }; + }; + buildConfigurationList = C0D643AA29F99A7F004DDAA4 /* Build configuration list for PBXProject "MLCChat" */; + compatibilityVersion = "Xcode 14.0"; + developmentRegion = en; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = C0D643A629F99A7F004DDAA4; + productRefGroup = C0D643B029F99A7F004DDAA4 /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + C0D643AE29F99A7F004DDAA4 /* MLCChat */, + AA14F2D12B911A9100308009 /* MLCChat_rebased */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + AA14F2E52B911A9100308009 /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + AA14F2E62B911A9100308009 /* Preview Assets.xcassets in Resources */, + AA14F2E72B911A9100308009 /* Assets.xcassets in Resources */, + ); + 
runOnlyForDeploymentPostprocessing = 0; + }; + C0D643AD29F99A7F004DDAA4 /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + C0D643BA29F99A80004DDAA4 /* Preview Assets.xcassets in Resources */, + C0D643B729F99A80004DDAA4 /* Assets.xcassets in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + AA14F2D32B911A9100308009 /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + AA14F2D42B911A9100308009 /* ImageProcessing.swift in Sources */, + AA14F2D52B911A9100308009 /* AppState.swift in Sources */, + AA14F2D62B911A9100308009 /* MLCChatApp.swift in Sources */, + AA14F2D72B911A9100308009 /* ChatState.swift in Sources */, + AA14F2D82B911A9100308009 /* PerformanceMetrics.swift in Sources */, + AA14F2D92B911A9100308009 /* ChatView.swift in Sources */, + AA14F2DA2B911A9100308009 /* ModelState.swift in Sources */, + AA14F2DB2B911A9100308009 /* MessageView.swift in Sources */, + AA14F2DC2B911A9100308009 /* RestAwaitLib.swift in Sources */, + AA14F2DD2B911A9100308009 /* ModelConfig.swift in Sources */, + AA14F2DE2B911A9100308009 /* ParamsConfig.swift in Sources */, + AA14F2DF2B911A9100308009 /* AppConfig.swift in Sources */, + AA14F2E02B911A9100308009 /* Constants.swift in Sources */, + AA14F2E12B911A9100308009 /* ModelView.swift in Sources */, + AA14F2E22B911A9100308009 /* StartView.swift in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; + C0D643AB29F99A7F004DDAA4 /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + A773CC652A5DC98200467BFE /* ImageProcessing.swift in Sources */, + 1453A4D12A1354B9001B909F /* AppState.swift in Sources */, + C0D643B329F99A7F004DDAA4 /* MLCChatApp.swift in Sources */, + C0DDBDF62A39103F00E9D060 /* ChatState.swift in Sources */, + CF3673242A9E2A9300E6D5AB /* PerformanceMetrics.swift in Sources */, + C0D643C429F99B07004DDAA4 /* ChatView.swift in Sources */, + 1453A4D32A1354B9001B909F /* ModelState.swift in Sources */, + C0D643C829F99B34004DDAA4 /* MessageView.swift in Sources */, + CFEEEF122B6423560086AA32 /* RestAwaitLib.swift in Sources */, + 1453A4D22A1354B9001B909F /* ModelConfig.swift in Sources */, + AEC27EFA2A85C2AC00254E67 /* ParamsConfig.swift in Sources */, + AEC27EFC2A85C3B000254E67 /* AppConfig.swift in Sources */, + AEC27F022A86337E00254E67 /* Constants.swift in Sources */, + 1453A4D02A1354B9001B909F /* ModelView.swift in Sources */, + 1453A4CF2A1354B9001B909F /* StartView.swift in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin XCBuildConfiguration section */ + AA14F2ED2B911A9100308009 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; + CLANG_ENABLE_MODULES = YES; + CODE_SIGN_ENTITLEMENTS = MLCChat/MLCChat.entitlements; + CODE_SIGN_IDENTITY = "Apple Development"; + CODE_SIGN_STYLE = Automatic; + CURRENT_PROJECT_VERSION = 4; + DEVELOPMENT_ASSET_PATHS = "\"MLCChat/Preview Content\""; + DEVELOPMENT_TEAM = KL8N8XSYF4; + ENABLE_PREVIEWS = YES; + GENERATE_INFOPLIST_FILE = YES; + "HEADER_SEARCH_PATHS[arch=*]" = ""; + INFOPLIST_FILE = "MLCChat copy-Info.plist"; + INFOPLIST_KEY_CFBundleDisplayName = "MLChat++"; + INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.productivity"; + INFOPLIST_KEY_NSCameraUsageDescription = 
"This app requires usage of camera to function properly."; + INFOPLIST_KEY_NSLocalNetworkUsageDescription = "We require this persmission to notify Blade Runner service that the LLM task is completed."; + INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES; + INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES; + INFOPLIST_KEY_UILaunchScreen_Generation = YES; + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + LIBRARY_SEARCH_PATHS = ( + "$(inherited)", + "$(PROJECT_DIR)/build/lib", + ); + MARKETING_VERSION = 1.3; + OTHER_LDFLAGS = ( + "-Wl,-all_load", + "-lmodel_iphone", + "-lmlc_llm", + "-ltvm_runtime", + "-ltokenizers_cpp", + "-lsentencepiece", + "-ltokenizers_c", + ); + PRODUCT_BUNDLE_IDENTIFIER = com.brave.mlc.Chat33; + PRODUCT_NAME = "$(TARGET_NAME)"; + PROVISIONING_PROFILE_SPECIFIER = ""; + SWIFT_EMIT_LOC_STRINGS = YES; + SWIFT_OBJC_BRIDGING_HEADER = ""; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + SWIFT_VERSION = 5.0; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Debug; + }; + AA14F2EE2B911A9100308009 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; + CLANG_ENABLE_MODULES = YES; + CODE_SIGN_ENTITLEMENTS = MLCChat/MLCChat.entitlements; + CODE_SIGN_IDENTITY = "Apple Development"; + CODE_SIGN_STYLE = Automatic; + CURRENT_PROJECT_VERSION = 4; + DEVELOPMENT_ASSET_PATHS = "\"MLCChat/Preview Content\""; + DEVELOPMENT_TEAM = KL8N8XSYF4; + ENABLE_PREVIEWS = YES; + GENERATE_INFOPLIST_FILE = YES; + "HEADER_SEARCH_PATHS[arch=*]" = ""; + INFOPLIST_FILE = "MLCChat copy-Info.plist"; + INFOPLIST_KEY_CFBundleDisplayName = "MLChat++"; + INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.productivity"; + INFOPLIST_KEY_NSCameraUsageDescription = "This app requires usage of camera to function properly."; + INFOPLIST_KEY_NSLocalNetworkUsageDescription = "We require this persmission to notify Blade Runner service that the LLM task is completed."; + INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES; + INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES; + INFOPLIST_KEY_UILaunchScreen_Generation = YES; + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + LIBRARY_SEARCH_PATHS = ( + "$(inherited)", + "$(PROJECT_DIR)/build/lib", + ); + MARKETING_VERSION = 1.3; + OTHER_LDFLAGS = ( + "-Wl,-all_load", + "-lmodel_iphone", + "-lmlc_llm", + "-ltvm_runtime", + "-ltokenizers_cpp", + "-lsentencepiece", + "-ltokenizers_c", + ); + PRODUCT_BUNDLE_IDENTIFIER = com.brave.mlc.Chat33; + PRODUCT_NAME = "$(TARGET_NAME)"; + PROVISIONING_PROFILE_SPECIFIER = ""; + SWIFT_EMIT_LOC_STRINGS = YES; + SWIFT_OBJC_BRIDGING_HEADER = ""; + SWIFT_VERSION = 5.0; + 
TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Release; + }; + C0D643BB29F99A80004DDAA4 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 16.0; + MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; + MTL_FAST_MATH = YES; + ONLY_ACTIVE_ARCH = YES; + SDKROOT = iphoneos; + SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + }; + name = Debug; + }; + C0D643BC29F99A80004DDAA4 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + 
ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 16.0; + MTL_ENABLE_DEBUG_INFO = NO; + MTL_FAST_MATH = YES; + SDKROOT = iphoneos; + SWIFT_COMPILATION_MODE = wholemodule; + SWIFT_OPTIMIZATION_LEVEL = "-O"; + VALIDATE_PRODUCT = YES; + }; + name = Release; + }; + C0D643BE29F99A80004DDAA4 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; + CLANG_ENABLE_MODULES = YES; + CODE_SIGN_ENTITLEMENTS = MLCChat/MLCChat.entitlements; + CODE_SIGN_IDENTITY = "Apple Development"; + CODE_SIGN_STYLE = Automatic; + CURRENT_PROJECT_VERSION = 6; + DEVELOPMENT_ASSET_PATHS = "\"MLCChat/Preview Content\""; + DEVELOPMENT_TEAM = KL8N8XSYF4; + ENABLE_PREVIEWS = YES; + GENERATE_INFOPLIST_FILE = YES; + "HEADER_SEARCH_PATHS[arch=*]" = ""; + INFOPLIST_FILE = MLCChat/Info.plist; + INFOPLIST_KEY_CFBundleDisplayName = MLChat; + INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.productivity"; + INFOPLIST_KEY_NSCameraUsageDescription = "This app requires access to the camera to function properly."; + INFOPLIST_KEY_NSLocalNetworkUsageDescription = "We require this permission to notify Blade Runner service that the LLM task is completed."; + INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES; + INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES; + INFOPLIST_KEY_UILaunchScreen_Generation = YES; + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + LIBRARY_SEARCH_PATHS = ( + "$(inherited)", + "$(PROJECT_DIR)/build/lib", + ); + MARKETING_VERSION = 1.3; + OTHER_LDFLAGS = ( + "-Wl,-all_load", + "-lmodel_iphone", + "-lmlc_llm", + "-ltvm_runtime", + "-ltokenizers_cpp", + "-lsentencepiece", + "-ltokenizers_c", + ); + PRODUCT_BUNDLE_IDENTIFIER = com.brave.mlc.Chat32; + PRODUCT_NAME = "$(TARGET_NAME)"; + PROVISIONING_PROFILE_SPECIFIER = ""; + SWIFT_EMIT_LOC_STRINGS = YES; + SWIFT_OBJC_BRIDGING_HEADER = ""; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + SWIFT_VERSION = 5.0; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Debug; + }; + C0D643BF29F99A80004DDAA4 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; + CLANG_ENABLE_MODULES = YES; + CODE_SIGN_ENTITLEMENTS = MLCChat/MLCChat.entitlements; + CODE_SIGN_IDENTITY = "Apple Development"; + CODE_SIGN_STYLE = Automatic; + CURRENT_PROJECT_VERSION = 6; + DEVELOPMENT_ASSET_PATHS = "\"MLCChat/Preview Content\""; + DEVELOPMENT_TEAM = KL8N8XSYF4; + ENABLE_PREVIEWS = YES; + GENERATE_INFOPLIST_FILE = YES; + "HEADER_SEARCH_PATHS[arch=*]" = ""; + INFOPLIST_FILE = MLCChat/Info.plist; + INFOPLIST_KEY_CFBundleDisplayName = MLChat; + INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.productivity"; +
INFOPLIST_KEY_NSCameraUsageDescription = "This app requires access to the camera to function properly."; + INFOPLIST_KEY_NSLocalNetworkUsageDescription = "We require this permission to notify Blade Runner service that the LLM task is completed."; + INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES; + INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES; + INFOPLIST_KEY_UILaunchScreen_Generation = YES; + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + LIBRARY_SEARCH_PATHS = ( + "$(inherited)", + "$(PROJECT_DIR)/build/lib", + ); + MARKETING_VERSION = 1.3; + OTHER_LDFLAGS = ( + "-Wl,-all_load", + "-lmodel_iphone", + "-lmlc_llm", + "-ltvm_runtime", + "-ltokenizers_cpp", + "-lsentencepiece", + "-ltokenizers_c", + ); + PRODUCT_BUNDLE_IDENTIFIER = com.brave.mlc.Chat32; + PRODUCT_NAME = "$(TARGET_NAME)"; + PROVISIONING_PROFILE_SPECIFIER = ""; + SWIFT_EMIT_LOC_STRINGS = YES; + SWIFT_OBJC_BRIDGING_HEADER = ""; + SWIFT_VERSION = 5.0; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + AA14F2EC2B911A9100308009 /* Build configuration list for PBXNativeTarget "MLCChat_rebased" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + AA14F2ED2B911A9100308009 /* Debug */, + AA14F2EE2B911A9100308009 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + C0D643AA29F99A7F004DDAA4 /* Build configuration list for PBXProject "MLCChat" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + C0D643BB29F99A80004DDAA4 /* Debug */, + C0D643BC29F99A80004DDAA4 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + C0D643BD29F99A80004DDAA4 /* Build configuration list for PBXNativeTarget "MLCChat" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + C0D643BE29F99A80004DDAA4 /* Debug */, + C0D643BF29F99A80004DDAA4 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + +/* Begin XCSwiftPackageProductDependency section */ + AA14F2D22B911A9100308009 /* MLCSwift */ = { + isa = XCSwiftPackageProductDependency; + productName = MLCSwift; + }; + C0DDBE0C2A3BCD8000E9D060 /* MLCSwift */ = { + isa = XCSwiftPackageProductDependency; + productName = MLCSwift; + }; +/* End XCSwiftPackageProductDependency section */ + }; + rootObject = C0D643A729F99A7F004DDAA4 /* Project object */; +} diff --git a/ios/MLCChat.xcodeproj/project.xcworkspace/contents.xcworkspacedata b/ios/MLCChat.xcodeproj/project.xcworkspace/contents.xcworkspacedata new file mode 100644 index 0000000..919434a --- /dev/null +++ b/ios/MLCChat.xcodeproj/project.xcworkspace/contents.xcworkspacedata @@ -0,0 +1,7 @@ + + + + + diff --git a/ios/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist b/ios/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist new file mode 100644 index 0000000..18d9810 --- /dev/null +++ b/ios/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist @@ -0,0 +1,8 @@ + + + + +
IDEDidComputeMac32BitWarning + + + diff --git a/ios/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings b/ios/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings new file mode 100644 index 0000000..0c67376 --- /dev/null +++ b/ios/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings @@ -0,0 +1,5 @@ + + + + + diff --git a/ios/MLCChat.xcodeproj/xcshareddata/xcschemes/MLCChat.xcscheme b/ios/MLCChat.xcodeproj/xcshareddata/xcschemes/MLCChat.xcscheme new file mode 100644 index 0000000..311123f --- /dev/null +++ b/ios/MLCChat.xcodeproj/xcshareddata/xcschemes/MLCChat.xcscheme @@ -0,0 +1,81 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/ios/MLCChat/Assets.xcassets/AccentColor.colorset/Contents.json b/ios/MLCChat/Assets.xcassets/AccentColor.colorset/Contents.json new file mode 100644 index 0000000..eb87897 --- /dev/null +++ b/ios/MLCChat/Assets.xcassets/AccentColor.colorset/Contents.json @@ -0,0 +1,11 @@ +{ + "colors" : [ + { + "idiom" : "universal" + } + ], + "info" : { + "author" : "xcode", + "version" : 1 + } +} diff --git a/ios/MLCChat/Assets.xcassets/AppIcon.appiconset/Contents.json b/ios/MLCChat/Assets.xcassets/AppIcon.appiconset/Contents.json new file mode 100644 index 0000000..7324dc2 --- /dev/null +++ b/ios/MLCChat/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -0,0 +1,14 @@ +{ + "images" : [ + { + "filename" : "mlc-logo.png", + "idiom" : "universal", + "platform" : "ios", + "size" : "1024x1024" + } + ], + "info" : { + "author" : "xcode", + "version" : 1 + } +} diff --git a/ios/MLCChat/Assets.xcassets/AppIcon.appiconset/mlc-logo.png b/ios/MLCChat/Assets.xcassets/AppIcon.appiconset/mlc-logo.png new file mode 100644 index 0000000..4ae381d Binary files /dev/null and b/ios/MLCChat/Assets.xcassets/AppIcon.appiconset/mlc-logo.png differ diff --git a/ios/MLCChat/Assets.xcassets/Contents.json b/ios/MLCChat/Assets.xcassets/Contents.json new file mode 100644 index 0000000..73c0059 --- /dev/null +++ b/ios/MLCChat/Assets.xcassets/Contents.json @@ -0,0 +1,6 @@ +{ + "info" : { + "author" : "xcode", + "version" : 1 + } +} diff --git a/ios/MLCChat/Common/Constants.swift b/ios/MLCChat/Common/Constants.swift new file mode 100644 index 0000000..cf3a240 --- /dev/null +++ b/ios/MLCChat/Common/Constants.swift @@ -0,0 +1,11 @@ +// +// Constants.swift +// MLCChat +// + +struct Constants { + static let prebuiltModelDir = "dist" + static let appConfigFileName = "app-config.json" + static let modelConfigFileName = "mlc-chat-config.json" + static let paramsConfigFileName = "ndarray-cache.json" +} diff --git a/ios/MLCChat/Info.plist b/ios/MLCChat/Info.plist new file mode 100644 index 0000000..ff579a6 --- /dev/null +++ b/ios/MLCChat/Info.plist @@ -0,0 +1,8 @@ + + + + + UIFileSharingEnabled + + + diff --git a/ios/MLCChat/MLCChat.entitlements b/ios/MLCChat/MLCChat.entitlements new file mode 100644 index 0000000..caa3d58 --- /dev/null +++ b/ios/MLCChat/MLCChat.entitlements @@ -0,0 +1,10 @@ + + + + + com.apple.developer.kernel.extended-virtual-addressing + + com.apple.developer.kernel.increased-memory-limit + + + diff --git a/ios/MLCChat/MLCChatApp.swift b/ios/MLCChat/MLCChatApp.swift new file mode 100644 index 0000000..fcefd6f --- /dev/null +++ b/ios/MLCChat/MLCChatApp.swift @@ -0,0 +1,28 @@ +// +// MLCChatApp.swift +// MLCChat +// +// Created by Tianqi Chen on 4/26/23. 
+// + +import SwiftUI + +@main +struct MLCChatApp: App { + @StateObject private var appState = AppState() + + init() { + UITableView.appearance().separatorStyle = .none + UITableView.appearance().tableFooterView = UIView() + } + + var body: some Scene { + WindowGroup { + StartView() + .environmentObject(appState) + .task { + appState.loadAppConfigAndModels() + } + } + } +} diff --git a/ios/MLCChat/Models/AppConfig.swift b/ios/MLCChat/Models/AppConfig.swift new file mode 100644 index 0000000..69867b0 --- /dev/null +++ b/ios/MLCChat/Models/AppConfig.swift @@ -0,0 +1,28 @@ +// +// AppConfig.swift +// MLCChat +// + +struct AppConfig: Codable { + struct ModelRecord: Codable { + let modelPath: String? + let modelURL: String? + let modelLib: String + let estimatedVRAMReq: Int + let modelID: String + + enum CodingKeys: String, CodingKey { + case modelPath = "model_path" + case modelURL = "model_url" + case modelLib = "model_lib" + case estimatedVRAMReq = "estimated_vram_bytes" + case modelID = "model_id" + } + } + + var modelList: [ModelRecord] + + enum CodingKeys: String, CodingKey { + case modelList = "model_list" + } +} diff --git a/ios/MLCChat/Models/ModelConfig.swift b/ios/MLCChat/Models/ModelConfig.swift new file mode 100644 index 0000000..4ed8819 --- /dev/null +++ b/ios/MLCChat/Models/ModelConfig.swift @@ -0,0 +1,18 @@ +// +// ModelConfig.swift +// MLCChat +// + +struct ModelConfig: Decodable { + let tokenizerFiles: [String] + var modelLib: String? + var modelID: String? + var estimatedVRAMReq: Int? + + enum CodingKeys: String, CodingKey { + case tokenizerFiles = "tokenizer_files" + case modelLib = "model_lib" + case modelID = "model_id" + case estimatedVRAMReq = "estimated_vram_req" + } +} diff --git a/ios/MLCChat/Models/ParamsConfig.swift b/ios/MLCChat/Models/ParamsConfig.swift new file mode 100644 index 0000000..2635afa --- /dev/null +++ b/ios/MLCChat/Models/ParamsConfig.swift @@ -0,0 +1,12 @@ +// +// ParamsConfig.swift +// MLCChat +// + +struct ParamsConfig: Decodable { + struct ParamsRecord: Decodable { + let dataPath: String + } + + let records: [ParamsRecord] +} diff --git a/ios/MLCChat/PerformanceMetrics.swift b/ios/MLCChat/PerformanceMetrics.swift new file mode 100644 index 0000000..eede59b --- /dev/null +++ b/ios/MLCChat/PerformanceMetrics.swift @@ -0,0 +1,61 @@ +// +// PerformanceMetrics.swift +// MLCChat +// +// Created by Kleomenis Katevas on 07/06/2023. +// + +import Foundation + + +struct ConversationRecord: Codable { + let modelName: String + var modelLoadTime: TimeRecord? 
+ var questionRecords: [QuestionRecord] = [] + + init(modelName: String) { + self.modelName = modelName + } +} + +struct QuestionRecord: Codable { + let time: TimeRecord + let input, output: String + let original_session_tokens, input_tokens, output_tokens: Int + let runtimeStats: String +} + +struct TimeRecord: Codable { + let start: Date + let duration: TimeInterval +} + +class ConversationsRecordManager: ObservableObject { + @Published private var conversations: [ConversationRecord] = [] + + func addConversationRecord(_ conversation: ConversationRecord) { + conversations.append(conversation) + } + + func saveToFile(withFileName fileName: String) { + let fileManager = FileManager.default + let directoryURL = fileManager.urls(for: .documentDirectory, in: .userDomainMask)[0] + let fileURL = directoryURL.appendingPathComponent("\(fileName).json") + + let encoder = JSONEncoder() + encoder.outputFormatting = .prettyPrinted + encoder.dateEncodingStrategy = .custom { (date, encoder) in + var container = encoder.singleValueContainer() + let timestamp = date.timeIntervalSince1970 + try container.encode(timestamp) + } + + do { + let data = try encoder.encode(conversations) + try data.write(to: fileURL) + print("Energy measurements JSON file successfully saved at \(fileURL)") + } catch { + print("Failed to write JSON data: \(error.localizedDescription)") + } + } +} diff --git a/ios/MLCChat/Preview Content/Preview Assets.xcassets/Contents.json b/ios/MLCChat/Preview Content/Preview Assets.xcassets/Contents.json new file mode 100644 index 0000000..73c0059 --- /dev/null +++ b/ios/MLCChat/Preview Content/Preview Assets.xcassets/Contents.json @@ -0,0 +1,6 @@ +{ + "info" : { + "author" : "xcode", + "version" : 1 + } +} diff --git a/ios/MLCChat/RestAwaitLib.swift b/ios/MLCChat/RestAwaitLib.swift new file mode 100644 index 0000000..32aabbf --- /dev/null +++ b/ios/MLCChat/RestAwaitLib.swift @@ -0,0 +1,47 @@ +// +// RestAwaitLib.swift +// MLCChat +// +// Created by Kleomenis Katevas on 26/01/2024. +// + +import Foundation +import Network + +class RestAwaitLib { + let host: String + let port: Int + + static func requestPermission() { + // dummy url + let url = URL(string: "http://192.168.1.1:8080")! + let task = URLSession.shared.dataTask(with: url) { data, response, error in + // Nothing + } + + task.resume() + } + + init(host: String, port: Int) { + self.host = host + self.port = port + } + + func continueExecution(completion: @escaping (String?, Error?) -> Void) { + guard let url = URL(string: "http://\(host):\(port)/continue") else { + completion(nil, NSError(domain: "", code: -1, userInfo: [NSLocalizedDescriptionKey: "Invalid URL"])) + return + } + + let task = URLSession.shared.dataTask(with: url) { data, response, error in + guard let data = data, error == nil else { + completion(nil, error) + return + } + let responseString = String(data: data, encoding: .utf8) + completion(responseString, nil) + } + + task.resume() + } +} diff --git a/ios/MLCChat/States/AppState.swift b/ios/MLCChat/States/AppState.swift new file mode 100644 index 0000000..22e16f4 --- /dev/null +++ b/ios/MLCChat/States/AppState.swift @@ -0,0 +1,274 @@ +// +// AppState.swift +// MLCChat +// +// Created by Yaxing Cai on 5/13/23. 
+// + +import Foundation + +final class AppState: ObservableObject { + @Published var models = [ModelState]() + @Published var chatState = ChatState() + + @Published var alertMessage = "" // TODO: Should move out + @Published var alertDisplayed = false // TODO: Should move out + + var conversationsRecordManager = ConversationsRecordManager() + + private var appConfig: AppConfig? + private var modelIDs = Set() + + private let fileManager: FileManager = FileManager.default + private lazy var cacheDirectoryURL: URL = { + fileManager.urls(for: .cachesDirectory, in: .userDomainMask)[0] + }() + + private let jsonDecoder = JSONDecoder() + private let jsonEncoder = JSONEncoder() + + func loadAppConfigAndModels() { + appConfig = loadAppConfig() + // Can't do anything without a valid app config + guard let appConfig else { + return + } + loadModelsConfig(modelList: appConfig.modelList) + } + + func requestDeleteModel(modelID: String) { + // model dir should have been deleted in ModelState + assert(!fileManager.fileExists(atPath: cacheDirectoryURL.appending(path: modelID).path())) + modelIDs.remove(modelID) + models.removeAll(where: {$0.modelConfig.modelID == modelID}) + updateAppConfig { + appConfig?.modelList.removeAll(where: {$0.modelID == modelID}) + } + } +} + +private extension AppState { + func loadAppConfig() -> AppConfig? { + // models in cache to download + var appConfigFileURL = cacheDirectoryURL.appending(path: Constants.appConfigFileName) + if !fileManager.fileExists(atPath: appConfigFileURL.path()) { + appConfigFileURL = Bundle.main.bundleURL.appending(path: Constants.appConfigFileName) + } + assert(fileManager.fileExists(atPath: appConfigFileURL.path())) + + do { + let fileHandle = try FileHandle(forReadingFrom: appConfigFileURL) + let data = fileHandle.readDataToEndOfFile() + + let appConfig = try jsonDecoder.decode(AppConfig.self, from: data) + return appConfig + } catch { + showAlert(message: "Failed to load app config: \(error.localizedDescription)") + return nil + } + } + + func loadModelsConfig(modelList: [AppConfig.ModelRecord]) { + for model in modelList { + if model.modelPath != nil { + // local model + let documentsPath = fileManager.urls(for: .documentDirectory, in: .userDomainMask).first + let modelBaseURL = documentsPath!.appending(path: model.modelID) + let modelDir = modelBaseURL + let modelConfigURL = modelDir.appending(path: Constants.modelConfigFileName) + if fileManager.fileExists(atPath: modelConfigURL.path()) { + if let modelConfig = loadModelConfig( + modelConfigURL: modelConfigURL, + modelLib: model.modelLib, + modelID: model.modelID, + estimatedVRAMReq: model.estimatedVRAMReq + ) { + addModelConfig( + modelConfig: modelConfig, + modelPath: model.modelPath!, + modelURL: nil, + isBuiltin: true + ) + } else { + showAlert(message: "Failed to load prebuilt model: \(model.modelPath!)") + } + } else { + showAlert(message: "Prebuilt mlc-chat-config.json file not found: \(model.modelPath!)") + } + } else if model.modelURL != nil { + // remote model + let modelConfigFileURL = cacheDirectoryURL + .appending(path: model.modelID) + .appending(path: Constants.modelConfigFileName) + if fileManager.fileExists(atPath: modelConfigFileURL.path()) { + if let modelConfig = loadModelConfig( + modelConfigURL: modelConfigFileURL, + modelLib: model.modelLib, + modelID: model.modelID, + estimatedVRAMReq: model.estimatedVRAMReq + ) { + addModelConfig( + modelConfig: modelConfig, + modelPath: nil, + modelURL: URL(string: model.modelURL!), + isBuiltin: true + ) + } + } else { + 
downloadConfig( + modelURL: URL(string: model.modelURL!), + modelLib: model.modelLib, + modelID: model.modelID, + estimatedVRAMReq: model.estimatedVRAMReq, + isBuiltin: true + ) + } + } else { + showAlert(message: "Path or URL should be provided in app config: \(model.modelID)") + } + } + } + + func loadModelConfig(modelConfigURL: URL, modelLib: String, modelID: String, estimatedVRAMReq: Int) -> ModelConfig? { + do { + assert(fileManager.fileExists(atPath: modelConfigURL.path())) + let fileHandle = try FileHandle(forReadingFrom: modelConfigURL) + let data = fileHandle.readDataToEndOfFile() + var modelConfig = try jsonDecoder.decode(ModelConfig.self, from: data) + modelConfig.modelLib = modelLib + modelConfig.modelID = modelID + modelConfig.estimatedVRAMReq = estimatedVRAMReq + return modelConfig + } catch { + showAlert(message: "Failed to resolve model config: \(error.localizedDescription)") + } + return nil + } + + func showAlert(message: String) { + DispatchQueue.main.async { [weak self] in + guard let self = self else { return } + if !self.alertDisplayed { + self.alertMessage = message + self.alertDisplayed = true + } else { + self.alertMessage.append("\n" + message) + } + } + } + + func downloadConfig(modelURL: URL?, modelLib: String, modelID: String, estimatedVRAMReq: Int, isBuiltin: Bool) { + guard let modelConfigURL = modelURL?.appending(path: "resolve").appending(path: "main").appending(path: Constants.modelConfigFileName) else { + return + } + + let downloadTask = URLSession.shared.downloadTask(with: modelConfigURL) { + [weak self] urlOrNil, responseOrNil, errorOrNil in + guard let self else { + return + } + if let error = errorOrNil { + self.showAlert(message: "Failed to download model config: \(error.localizedDescription)") + return + } + guard let fileUrl = urlOrNil else { + self.showAlert(message: "Failed to download model config") + return + } + + // cache temp file to avoid being deleted by system automatically + let tempName = UUID().uuidString + let tempFileURL = self.cacheDirectoryURL.appending(path: tempName) + + do { + try self.fileManager.moveItem(at: fileUrl, to: tempFileURL) + } catch { + self.showAlert(message: "Failed to cache downloaded file: \(error.localizedDescription)") + return + } + + do { + guard let modelConfig = loadModelConfig( + modelConfigURL: tempFileURL, + modelLib: modelLib, + modelID: modelID, + estimatedVRAMReq: estimatedVRAMReq + ) else { + try fileManager.removeItem(at: tempFileURL) + return + } + + if modelIDs.contains(modelConfig.modelID!) { + try fileManager.removeItem(at: tempFileURL) + return + } + + let modelBaseUrl = cacheDirectoryURL.appending(path: modelConfig.modelID!) + try fileManager.createDirectory(at: modelBaseUrl, withIntermediateDirectories: true) + let modelConfigUrl = modelBaseUrl.appending(path: Constants.modelConfigFileName) + try fileManager.moveItem(at: tempFileURL, to: modelConfigUrl) + assert(fileManager.fileExists(atPath: modelConfigUrl.path())) + assert(!fileManager.fileExists(atPath: tempFileURL.path())) + addModelConfig( + modelConfig: modelConfig, + modelPath: nil, + modelURL: modelURL, + isBuiltin: isBuiltin + ) + } catch { + showAlert(message: "Failed to import model: \(error.localizedDescription)") + } + } + downloadTask.resume() + } + + func addModelConfig(modelConfig: ModelConfig, modelPath: String?, modelURL: URL?, isBuiltin: Bool) { + assert(!modelIDs.contains(modelConfig.modelID!)) + modelIDs.insert(modelConfig.modelID!) 
+ let modelBaseURL: URL + + // model_id dir should exist + if modelURL == nil { + // prebuilt model in dist + let documentsPath = fileManager.urls(for: .documentDirectory, in: .userDomainMask).first + modelBaseURL = documentsPath!.appending(path: modelConfig.modelID!) + } else { + // download model in cache + modelBaseURL = cacheDirectoryURL.appending(path: modelConfig.modelID!) + } + assert(fileManager.fileExists(atPath: modelBaseURL.path())) + + // mlc-chat-config.json should exist + let modelConfigURL = modelBaseURL.appending(path: Constants.modelConfigFileName) + assert(fileManager.fileExists(atPath: modelConfigURL.path())) + + let model = ModelState(modelConfig: modelConfig, modelLocalBaseURL: modelBaseURL, startState: self, chatState: chatState) + model.checkModelDownloadState(modelURL: modelURL) + models.append(model) + + if modelURL != nil && !isBuiltin { + updateAppConfig { + appConfig?.modelList.append( + AppConfig.ModelRecord( + modelPath: nil, + modelURL: modelURL!.absoluteString, + modelLib: modelConfig.modelLib!, + estimatedVRAMReq: modelConfig.estimatedVRAMReq!, + modelID: modelConfig.modelID! + ) + ) + } + } + } + + func updateAppConfig(action: () -> Void) { + action() + let appConfigURL = cacheDirectoryURL.appending(path: Constants.appConfigFileName) + do { + let data = try jsonEncoder.encode(appConfig) + try data.write(to: appConfigURL, options: Data.WritingOptions.atomic) + } catch { + print(error.localizedDescription) + } + } +} diff --git a/ios/MLCChat/States/ChatState.swift b/ios/MLCChat/States/ChatState.swift new file mode 100644 index 0000000..54dd720 --- /dev/null +++ b/ios/MLCChat/States/ChatState.swift @@ -0,0 +1,499 @@ +// +// ChatState.swift +// LLMChat +// + +import Foundation +import MLCSwift + +enum MessageRole { + case user + case bot +} + +extension MessageRole { + var isUser: Bool { self == .user } +} + +struct MessageData: Hashable { + let id = UUID() + var role: MessageRole + var message: String +} + +final class ChatState: ObservableObject { + fileprivate enum ModelChatState { + case generating + case resetting + case reloading + case terminating + case ready + case failed + case pendingImageUpload + case processingImage + } + + @Published var messages = [MessageData]() + @Published var infoText = "" + @Published var displayName = "" + @Published var useVision = false + + private let modelChatStateLock = NSLock() + private var modelChatState: ModelChatState = .ready + + private let threadWorker = ThreadWorker() + private let chatModule = ChatModule() + private var modelLib = "" + private var modelPath = "" + var modelID = "" + + var modelLoadTime: TimeRecord? 
+ + init() { + threadWorker.qualityOfService = QualityOfService.userInteractive + threadWorker.start() + + RestAwaitLib.requestPermission() + } + + func getFileURLFromName(_ name: String) -> URL { + let fileManager = FileManager.default + let documentsPath = fileManager.urls(for: .documentDirectory, in: .userDomainMask).first + let destinationURL = documentsPath!.appendingPathComponent(name) + return destinationURL + } + + // read input.json and return [[questions]] (conversation with questions) + func readInputFile() -> [[String]] { + + let url = getFileURLFromName("input.json") + do { + // Load the data from the file into a Data object + let data = try Data(contentsOf: url) + + // Decode the JSON data + let jsonDecoder = JSONDecoder() + let questions = try jsonDecoder.decode([[String]].self, from: data) + + return questions + + } catch { + print("Error reading or decoding file: \(error)") + return [] + } + } + + var isInterruptible: Bool { + return getModelChatState() == .ready + || getModelChatState() == .generating + || getModelChatState() == .failed + || getModelChatState() == .pendingImageUpload + } + + var isChattable: Bool { + return getModelChatState() == .ready + } + + var isUploadable: Bool { + return getModelChatState() == .pendingImageUpload + } + + var isResettable: Bool { + return getModelChatState() == .ready + || getModelChatState() == .generating + } + + func requestResetChat() { + assert(isResettable) + interruptChat(prologue: { + switchToResetting() + }, epilogue: { [weak self] in + self?.mainResetChat() + }) + } + + func requestTerminateChat(callback: @escaping () -> Void) { + assert(isInterruptible) + interruptChat(prologue: { + switchToTerminating() + }, epilogue: { [weak self] in + self?.mainTerminateChat(callback: callback) + }) + } + + func requestReloadChat(modelID: String, modelLib: String, modelPath: String, estimatedVRAMReq: Int, displayName: String) { + if (isCurrentModel(modelID: modelID)) { + return + } + assert(isInterruptible) + interruptChat(prologue: { + switchToReloading() + }, epilogue: { [weak self] in + self?.mainReloadChat(modelID: modelID, + modelLib: modelLib, + modelPath: modelPath, + estimatedVRAMReq: estimatedVRAMReq, + displayName: displayName) + }) + } + + func requestGenerate(prompt: String) { + assert(isChattable) + switchToGenerating() + appendMessage(role: .user, message: prompt) + appendMessage(role: .bot, message: "") + threadWorker.push {[weak self] in + guard let self else { return } + chatModule.prefill(prompt) + while !chatModule.stopped() { + chatModule.decode() + if let newText = chatModule.getMessage() { + DispatchQueue.main.async { + self.updateMessage(role: .bot, message: newText) + } + } + + if getModelChatState() != .generating { + break + } + } + if getModelChatState() == .generating { + if let runtimeStats = chatModule.runtimeStatsText(useVision) { + DispatchQueue.main.async { + self.infoText = runtimeStats + self.switchToReady() + } + } + } + } + } + + func requestAutomation(measurementFilename: String) { + + let conversationsRecordManager = ConversationsRecordManager() + let conversations = readInputFile() + + assert(isChattable) + switchToGenerating() + + threadWorker.push {[self] in + + // per conversation + for (c_idx, conversation) in conversations.enumerated() { + + var conversationRecord = ConversationRecord(modelName: self.displayName) + conversationRecord.modelLoadTime = self.modelLoadTime + + for (q_idx, question) in conversation.enumerated() { + + DispatchQueue.main.async { + self.appendMessage(role: .user, 
message: "\(c_idx)_\(q_idx): \(question)") + } + + //print(question) + + let timeStart = Date() + + chatModule.prefill(question) + + while !chatModule.stopped() { + chatModule.decode() + if getModelChatState() != .generating { + break + } + } + + let runtimeStatsText = chatModule.runtimeStatsText(useVision)! + + let jsonResult = parseJSON(from: runtimeStatsText)! + let original_session_tokens = -1 + let input_tokens = Int(jsonResult["prefill"]!["total tokens"]!.components(separatedBy: " ")[0]) + let output_tokens = Int(jsonResult["decode"]!["total tokens"]!.components(separatedBy: " ")[0]) + + //print(chatModule.getMessage()!) + let questionRecord = QuestionRecord.init(time: TimeRecord(start: timeStart, duration: -timeStart.timeIntervalSinceNow), + input: question, + output: chatModule.getMessage(), + original_session_tokens: original_session_tokens, + input_tokens: input_tokens!, + output_tokens: output_tokens!, + runtimeStats: runtimeStatsText) + conversationRecord.questionRecords.append(questionRecord) + + if let newText = chatModule.getMessage() { + DispatchQueue.main.async { + self.appendMessage(role: .bot, message: "\(c_idx)_\(q_idx): \(newText)") + } + } + + Thread.sleep(forTimeInterval: 5.0) + } + + // Save energy events for particular session + chatModule.saveEnergyEventsToCSV(withFilename: "\(measurementFilename)_conv\(c_idx).csv") + + // add metrics + conversationsRecordManager.addConversationRecord(conversationRecord) + + // clear context + chatModule.resetChat() + chatModule.resetEnergyEvents() + + DispatchQueue.main.async { + self.appendMessage(role: .bot, message: "--sleep--") + } + Thread.sleep(forTimeInterval: 60.0) + } + + // Add session and save + conversationsRecordManager.saveToFile(withFileName: measurementFilename) + + // Notify BladeRunner that task is complete + let restAwaitLib = RestAwaitLib(host: "192.168.1.42", port: 5100) + restAwaitLib.continueExecution { response, error in + if (response != nil) { + print(response!) + } + else { + print(error!) + } + } + + // Exit app + DispatchQueue.main.asyncAfter(deadline: .now() + 1.0) { + UIApplication.shared.perform(#selector(NSXPCConnection.suspend)) + DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) { + exit(0) + } + } + } + } + + func parseJSON(from jsonString: String) -> [String: [String: String]]? { + guard let jsonData = jsonString.data(using: .utf8) else { + print("Error: Cannot create Data from JSON string") + return nil + } + + do { + let jsonObject = try JSONSerialization.jsonObject(with: jsonData, options: []) + + if let dictionary = jsonObject as? [String: [String: String]] { + return dictionary + } else { + print("Error: JSON is not in the expected format [String: [String: String]]") + return nil + } + + } catch { + print("Error parsing JSON: \(error)") + return nil + } + } + + func requestProcessImage(image: UIImage) { + assert(getModelChatState() == .pendingImageUpload) + switchToProcessingImage() + threadWorker.push {[weak self] in + guard let self else { return } + assert(messages.count > 0) + DispatchQueue.main.async { + self.updateMessage(role: .bot, message: "[System] Processing image") + } + // step 1. resize image + let new_image = resizeImage(image: image, width: 112, height: 112) + // step 2. 
prefill image by chatModule.prefillImage() + chatModule.prefillImage(new_image, prevPlaceholder: "", postPlaceholder: " ") + DispatchQueue.main.async { + self.updateMessage(role: .bot, message: "[System] Ready to chat") + self.switchToReady() + } + } + } + + func isCurrentModel(modelID: String) -> Bool { + return self.modelID == modelID + } +} + +private extension ChatState { + func getModelChatState() -> ModelChatState { + modelChatStateLock.lock() + defer { modelChatStateLock.unlock() } + return modelChatState + } + + func setModelChatState(_ newModelChatState: ModelChatState) { + modelChatStateLock.lock() + modelChatState = newModelChatState + modelChatStateLock.unlock() + } + + func appendMessage(role: MessageRole, message: String) { + messages.append(MessageData(role: role, message: message)) + } + + func updateMessage(role: MessageRole, message: String) { + messages[messages.count - 1] = MessageData(role: role, message: message) + } + + func clearHistory() { + messages.removeAll() + infoText = "" + } + + func switchToResetting() { + setModelChatState(.resetting) + } + + func switchToGenerating() { + setModelChatState(.generating) + } + + func switchToReloading() { + setModelChatState(.reloading) + } + + func switchToReady() { + setModelChatState(.ready) + } + + func switchToTerminating() { + setModelChatState(.terminating) + } + + func switchToFailed() { + setModelChatState(.failed) + } + + func switchToPendingImageUpload() { + setModelChatState(.pendingImageUpload) + } + + func switchToProcessingImage() { + setModelChatState(.processingImage) + } + + func interruptChat(prologue: () -> Void, epilogue: @escaping () -> Void) { + assert(isInterruptible) + if getModelChatState() == .ready + || getModelChatState() == .failed + || getModelChatState() == .pendingImageUpload { + prologue() + epilogue() + } else if getModelChatState() == .generating { + prologue() + threadWorker.push { + DispatchQueue.main.async { + epilogue() + } + } + } else { + assert(false) + } + } + + func mainResetChat() { + threadWorker.push {[weak self] in + guard let self else { return } + chatModule.resetChat() + if useVision { + chatModule.resetImageModule() + } + DispatchQueue.main.async { + self.clearHistory() + if self.useVision { + self.appendMessage(role: .bot, message: "[System] Upload an image to chat") + self.switchToPendingImageUpload() + } else { + self.switchToReady() + } + } + } + } + + func mainTerminateChat(callback: @escaping () -> Void) { + threadWorker.push {[weak self] in + guard let self else { return } + if useVision { + chatModule.unloadImageModule() + } + chatModule.unload() + DispatchQueue.main.async { + self.clearHistory() + self.modelID = "" + self.modelLib = "" + self.modelPath = "" + self.displayName = "" + self.useVision = false + self.switchToReady() + callback() + } + } + } + + func mainReloadChat(modelID: String, modelLib: String, modelPath: String, estimatedVRAMReq: Int, displayName: String) { + clearHistory() + let prevUseVision = useVision + self.modelID = modelID + self.modelLib = modelLib + self.modelPath = modelPath + self.displayName = displayName + self.useVision = displayName.hasPrefix("minigpt") + threadWorker.push {[weak self] in + guard let self else { return } + DispatchQueue.main.async { + self.appendMessage(role: .bot, message: "[System] Initalize...") + } + + let modelTimeStart = Date() + + if prevUseVision { + chatModule.unloadImageModule() + } + chatModule.unload() + let vRAM = os_proc_available_memory() + if (vRAM < estimatedVRAMReq) { + let requiredMemory = String 
( + format: "%.1fMB", Double(estimatedVRAMReq) / Double(1 << 20) + ) + let errorMessage = ( + "Sorry, the system cannot provide \(requiredMemory) VRAM as requested to the app, " + + "so we cannot initialize this model on this device." + ) + DispatchQueue.main.sync { + self.messages.append(MessageData(role: MessageRole.bot, message: errorMessage)) + self.switchToFailed() + } + return + } + + if useVision { + // load vicuna model + let dir = (modelPath as NSString).deletingLastPathComponent + let vicunaModelLib = "vicuna-7b-v1.3-q3f16_0" + let vicunaModelPath = dir + "/" + vicunaModelLib + let appConfigJSONData = try? JSONSerialization.data(withJSONObject: ["conv_template": "minigpt"], options: []) + let appConfigJSON = String(data: appConfigJSONData!, encoding: .utf8) + chatModule.reload(vicunaModelLib, modelPath: vicunaModelPath, appConfigJson: appConfigJSON) + // load image model + chatModule.reloadImageModule(modelLib, modelPath: modelPath) + } else { + chatModule.reload(modelLib, modelPath: modelPath, appConfigJson: "") + } + + let modelDuration = -modelTimeStart.timeIntervalSinceNow + self.modelLoadTime = TimeRecord(start: modelTimeStart, duration: modelDuration) + + DispatchQueue.main.async { + if self.useVision { + self.updateMessage(role: .bot, message: "[System] Upload an image to chat") + self.switchToPendingImageUpload() + } else { + self.updateMessage(role: .bot, message: "[System] Ready to chat") + self.switchToReady() + } + } + } + } +} diff --git a/ios/MLCChat/States/ModelState.swift b/ios/MLCChat/States/ModelState.swift new file mode 100644 index 0000000..ed22910 --- /dev/null +++ b/ios/MLCChat/States/ModelState.swift @@ -0,0 +1,414 @@ +// +// ModelState.swift +// MLCChat +// + +import Foundation + +final class ModelState: ObservableObject, Identifiable { + enum ModelDownloadState { + case initializing + case indexing + case paused + case downloading + case pausing + case verifying + case finished + case failed + case clearing + case deleting + } + + fileprivate struct DownloadTask: Hashable { + let remoteURL: URL + let localURL: URL + } + + @Published var modelConfig: ModelConfig + @Published var modelDownloadState: ModelDownloadState = .initializing + @Published var progress: Int = 0 + @Published var total: Int = 1 + + private var modelLocalBaseURL: URL + private var startState: AppState + private var chatState: ChatState + + private let fileManager: FileManager = FileManager.default + private let decoder = JSONDecoder() + private var paramsConfig: ParamsConfig? + private var modelRemoteBaseURL: URL? + private var remainingTasks: Set = Set() + private var downloadingTasks: Set = Set() + private var maxDownloadingTasks: Int = 3 + + init(modelConfig: ModelConfig, + modelLocalBaseURL: URL, + startState: AppState, + chatState: ChatState) { + self.modelConfig = modelConfig + self.modelLocalBaseURL = modelLocalBaseURL + self.startState = startState + self.chatState = chatState + } + + func checkModelDownloadState(modelURL: URL?) 
{ + createModelFolderIfNeeded() + + guard let modelURL else { + switchToVerifying() + return + } + + modelRemoteBaseURL = modelURL.appending(path: "resolve").appending(path: "main") + + // create local params dir + let paramsConfigURL = modelLocalBaseURL.appending(path: Constants.paramsConfigFileName) + if fileManager.fileExists(atPath: paramsConfigURL.path()) { + // ndarray-cache.json already downloaded + loadParamsConfig() + switchToIndexing() + } else { + // download ndarray-cache.json + downloadParamsConfig() + } + } + + func startChat(chatState: ChatState) { + chatState.requestReloadChat( + modelID: modelConfig.modelID!, + modelLib: modelConfig.modelLib!, + modelPath: modelLocalBaseURL.path(), + estimatedVRAMReq: modelConfig.estimatedVRAMReq!, + displayName: modelConfig.modelID!.components(separatedBy: "-")[0] + ) + } + + func handleStart() { + // start downloading + switchToDownloading() + } + + func handlePause() { + // pause downloading + switchToPausing() + } + + func handleClear() { + assert(modelDownloadState == .downloading || modelDownloadState == .paused || modelDownloadState == .finished) + switchToClearing() + } + + func handleDelete() { + assert(modelDownloadState == .downloading || modelDownloadState == .paused || modelDownloadState == .finished || modelDownloadState == .failed) + switchToDeleting() + } +} + +private extension ModelState { + func createModelFolderIfNeeded() { + if !fileManager.fileExists(atPath: modelLocalBaseURL.path()) { + do { + try fileManager.createDirectory(at: modelLocalBaseURL, withIntermediateDirectories: true) + } catch { + print(error.localizedDescription) + } + } + } + + func loadParamsConfig() { + let paramsConfigURL = modelLocalBaseURL.appending(path: Constants.paramsConfigFileName) + assert(fileManager.fileExists(atPath: paramsConfigURL.path())) + do { + let fileHandle = try FileHandle(forReadingFrom: paramsConfigURL) + let data = fileHandle.readDataToEndOfFile() + paramsConfig = try self.decoder.decode(ParamsConfig.self, from: data) + } catch { + print(error.localizedDescription) + } + } + + func downloadParamsConfig() { + guard let modelRemoteBaseURL else { + return + } + + let paramsConfigURL = modelLocalBaseURL.appending(path: Constants.paramsConfigFileName) + let downloadTask = URLSession.shared.downloadTask(with: modelRemoteBaseURL.appending(path: Constants.paramsConfigFileName)) { + [weak self] urlOrNil, responseOrNil, errorOrNil in + guard let self else { return } + guard let fileURL = urlOrNil else { return } + do { + try? 
self.fileManager.removeItem(at: paramsConfigURL) + try self.fileManager.moveItem(at: fileURL, to: paramsConfigURL) + DispatchQueue.main.async { + self.loadParamsConfig() + self.switchToIndexing() + } + } catch { + print(error.localizedDescription) + } + } + downloadTask.resume() + } + + func switchToIndexing() { + guard let paramsConfig, let modelRemoteBaseURL else { + return + } + + modelDownloadState = .indexing + progress = 0 + total = modelConfig.tokenizerFiles.count + paramsConfig.records.count + + // collect tokenizer download tasks + for tokenizerFile in modelConfig.tokenizerFiles { + let remoteURL = modelRemoteBaseURL.appending(path: tokenizerFile) + let localURL = modelLocalBaseURL.appending(path: tokenizerFile) + + if fileManager.fileExists(atPath: localURL.path()) { + progress += 1 + } else { + remainingTasks.insert(DownloadTask(remoteURL: remoteURL, localURL: localURL)) + } + } + + // collect params download tasks + for paramsRecord in paramsConfig.records { + let remoteURL = modelRemoteBaseURL.appending(path: paramsRecord.dataPath) + let localURL = modelLocalBaseURL.appending(path: paramsRecord.dataPath) + + if fileManager.fileExists(atPath: localURL.path()) { + progress += 1 + } else { + remainingTasks.insert(DownloadTask(remoteURL: remoteURL, localURL: localURL)) + } + } + + if progress < total { + switchToPaused() + } else { + switchToFinished() + } + } + + func handleNewDownload(downloadTask: DownloadTask) { + // start one download task + assert(downloadingTasks.count < maxDownloadingTasks) + let task = URLSession.shared.downloadTask(with: downloadTask.remoteURL) { + [weak self] urlOrNil, responseOrNil, errorOrNil in + guard let self else { return } + guard let fileUrl = urlOrNil else { + DispatchQueue.main.async { + self.handleCancelDownload(downloadTask: downloadTask) + } + return + } + + do { + try self.fileManager.createDirectory(at: downloadTask.localURL.deletingLastPathComponent(), withIntermediateDirectories: true) + try? 
self.fileManager.removeItem(at: downloadTask.localURL) + try self.fileManager.moveItem(at: fileUrl, to: downloadTask.localURL) + } catch { + print(error.localizedDescription) + } + DispatchQueue.main.async { + self.handleFinishDownload(downloadTask: downloadTask) + } + } + downloadingTasks.insert(downloadTask) + task.resume() + } + + func handleFinishDownload(downloadTask: DownloadTask) { + // update the finished download task + remainingTasks.remove(downloadTask) + downloadingTasks.remove(downloadTask) + progress += 1 + assert(modelDownloadState == .downloading || + modelDownloadState == .pausing || + modelDownloadState == .clearing || + modelDownloadState == .deleting + ) + if modelDownloadState == .downloading { + if remainingTasks.isEmpty && downloadingTasks.isEmpty { + switchToFinished() + } else { + handleNextDownload() + } + } else if modelDownloadState == .pausing && downloadingTasks.isEmpty { + switchToPaused() + } else if modelDownloadState == .clearing && downloadingTasks.isEmpty { + clear() + } else if modelDownloadState == .deleting && downloadingTasks.isEmpty { + delete() + } + } + + func handleCancelDownload(downloadTask: DownloadTask) { + // withdraw the failed download task + assert(modelDownloadState == .downloading || modelDownloadState == .pausing) + downloadingTasks.remove(downloadTask) + if modelDownloadState == .downloading { + handleNextDownload() + } else if modelDownloadState == .pausing && downloadingTasks.count == 0 { + switchToPaused() + } + } + + func handleNextDownload() { + // start next download task + assert(modelDownloadState == .downloading) + for downloadTask in remainingTasks { + if !downloadingTasks.contains(downloadTask) { + handleNewDownload(downloadTask: downloadTask) + break + } + } + } + + func switchToPaused() { + modelDownloadState = .paused + } + + func switchToPausing() { + modelDownloadState = .pausing + } + + func switchToVerifying() { + modelDownloadState = .verifying + + let paramsConfigURL = modelLocalBaseURL.appending(path: Constants.paramsConfigFileName) + guard fileManager.fileExists(atPath: paramsConfigURL.path()) else { + switchToFailed() + return + } + + loadParamsConfig() + guard let paramsConfig else { + switchToFailed() + return + } + progress = 0 + total = modelConfig.tokenizerFiles.count + paramsConfig.records.count + + if !verifyTokenizers() { + switchToFailed() + return + } + + if !verifyParams() { + switchToFailed() + return + } + + switchToFinished() + } + + func verifyTokenizers() -> Bool { + for tokenizerFile in modelConfig.tokenizerFiles { + let localURL = modelLocalBaseURL.appending(path: tokenizerFile) + + if !fileManager.fileExists(atPath: localURL.path()) { + switchToFailed() + return false + } + progress += 1 + } + return true + } + + func verifyParams() -> Bool { + guard let paramsConfig else { + return false + } + + for paramsRecord in paramsConfig.records { + let localUrl = modelLocalBaseURL.appending(path: paramsRecord.dataPath) + + if !fileManager.fileExists(atPath: localUrl.path()) { + switchToFailed() + return false + } + + progress += 1 + } + return true + } + + func switchToClearing() { + if modelDownloadState == .paused { + modelDownloadState = .clearing + clear() + } else if modelDownloadState == .finished { + if chatState.modelID == modelConfig.modelID { + chatState.requestTerminateChat { [weak self] in + self?.clear() + } + } else { + clear() + } + } else { + modelDownloadState = .clearing + } + } + + func switchToDeleting() { + if modelDownloadState == .paused || modelDownloadState == .failed { + 
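+ // Paused/failed: no download callbacks are in flight, so the files can be removed
+ // immediately (other states defer to handleFinishDownload once downloads drain).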
modelDownloadState = .deleting + delete() + } else if modelDownloadState == .finished { + if chatState.modelID == modelConfig.modelID { + chatState.requestTerminateChat { [weak self] in + self?.delete() + } + } else { + delete() + } + } else { + modelDownloadState = .deleting + } + } + + func switchToFinished() { + modelDownloadState = .finished + } + + func switchToFailed() { + modelDownloadState = .failed + } + + func switchToDownloading() { + modelDownloadState = .downloading + for downloadTask in remainingTasks { + if downloadingTasks.count < maxDownloadingTasks { + handleNewDownload(downloadTask: downloadTask) + } else { + return + } + } + } + + func clear() { + do { + let fileURLs = try fileManager.contentsOfDirectory(at: modelLocalBaseURL, includingPropertiesForKeys: nil) + for fileURL in fileURLs where fileURL.lastPathComponent != Constants.modelConfigFileName { + try fileManager.removeItem(at: fileURL) + assert(!fileManager.fileExists(atPath: fileURL.path())) + } + assert(fileManager.fileExists(atPath: modelLocalBaseURL.appending(path: Constants.modelConfigFileName).path())) + switchToIndexing() + } catch { + print(error.localizedDescription) + } + } + + func delete() { + do { + try fileManager.removeItem(at: modelLocalBaseURL) + assert(!fileManager.fileExists(atPath: modelLocalBaseURL.path())) + startState.requestDeleteModel(modelID: modelConfig.modelID!) // TODO: can it decouple? + } catch { + print(error.localizedDescription) + } + } +} diff --git a/ios/MLCChat/Views/ChatView.swift b/ios/MLCChat/Views/ChatView.swift new file mode 100644 index 0000000..0ec5358 --- /dev/null +++ b/ios/MLCChat/Views/ChatView.swift @@ -0,0 +1,193 @@ +// +// ChatView.swift +// MLCChat +// + +import SwiftUI +import GameController + +struct ChatView: View { + @EnvironmentObject private var chatState: ChatState + + @State private var inputMessage: String = "" + @FocusState private var inputIsFocused: Bool + @Environment(\.dismiss) private var dismiss + @Namespace private var messagesBottomID + + // vision-related properties + @State private var showActionSheet: Bool = false + @State private var showImagePicker: Bool = false + @State private var imageConfirmed: Bool = false + @State private var imageSourceType: UIImagePickerController.SourceType = .photoLibrary + @State private var image: UIImage? 
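+ // Automation support: the "Automate" toolbar button asks for a filename prefix and
+ // then replays the questions from input.json via chatState.requestAutomation(...).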
+ + @State private var showingAlert = false + @State private var filenamePrefix = "" + + var body: some View { + VStack { + modelInfoView + messagesView + uploadImageView + messageInputView + } + .navigationBarTitle("MLC Chat: \(chatState.displayName)", displayMode: .inline) + .navigationBarBackButtonHidden() + .toolbar { + ToolbarItem(placement: .navigationBarLeading) { + Button { + dismiss() + } label: { + Image(systemName: "chevron.backward") + } + .buttonStyle(.borderless) + .disabled(!chatState.isInterruptible) + } + ToolbarItem(placement: .navigationBarTrailing) { + Button("Automate") { // Brave + automate() + } + .padding() + .disabled(!chatState.isResettable) + } + } + .alert("Run Automation", isPresented: $showingAlert) { + TextField("Filename prefix", text: $filenamePrefix) + Button("Run", action: runAutomation) + Button("Cancel", role: .cancel) { } + } + } +} + +private extension ChatView { + + var modelInfoView: some View { + Text(chatState.infoText) + .multilineTextAlignment(.center) + .opacity(0.5) + .listRowSeparator(.hidden) + } + + var messagesView: some View { + ScrollViewReader { scrollViewProxy in + ScrollView { + VStack { + let messageCount = chatState.messages.count + let hasSystemMessage = messageCount > 0 && chatState.messages[0].role == MessageRole.bot + let startIndex = hasSystemMessage ? 1 : 0 + + // display the system message + if hasSystemMessage { + MessageView(role: chatState.messages[0].role, message: chatState.messages[0].message) + } + + // display image + if let image, imageConfirmed { + ImageView(image: image) + } + + // display conversations + ForEach(chatState.messages[startIndex...], id: \.id) { message in + MessageView(role: message.role, message: message.message) + } + HStack { EmptyView() } + .id(messagesBottomID) + } + } + .onChange(of: chatState.messages) { _ in + withAnimation { + scrollViewProxy.scrollTo(messagesBottomID, anchor: .bottom) + } + } + } + } + + @ViewBuilder + var uploadImageView: some View { + if chatState.useVision && !imageConfirmed { + if image == nil { + Button("Upload picture to chat") { + showActionSheet = true + } + .actionSheet(isPresented: $showActionSheet) { + ActionSheet(title: Text("Choose from"), buttons: [ + .default(Text("Photo Library")) { + showImagePicker = true + imageSourceType = .photoLibrary + }, + .default(Text("Camera")) { + showImagePicker = true + imageSourceType = .camera + }, + .cancel() + ]) + } + .sheet(isPresented: $showImagePicker) { + ImagePicker(image: $image, + showImagePicker: $showImagePicker, + imageSourceType: imageSourceType) + } + .disabled(!chatState.isUploadable) + } else { + VStack { + if let image { + Image(uiImage: image) + .resizable() + .frame(width: 300, height: 300) + + HStack { + Button("Undo") { + self.image = nil + } + .padding() + + Button("Submit") { + imageConfirmed = true + chatState.requestProcessImage(image: image) + } + .padding() + } + } + } + } + } + } + + var messageInputView: some View { + HStack { + TextField("Inputs...", text: $inputMessage, axis: .vertical) + .textFieldStyle(RoundedBorderTextFieldStyle()) + .frame(minHeight: CGFloat(30)) + .focused($inputIsFocused) + .onSubmit { + let isKeyboardConnected = GCKeyboard.coalesced != nil + if isKeyboardConnected { + send() + } + } + Button("Send") { + send() + } + .bold() + .disabled(!(chatState.isChattable && inputMessage != "")) + } + .frame(minHeight: CGFloat(70)) + .padding() + } + + func send() { + inputIsFocused = false + chatState.requestGenerate(prompt: inputMessage) + inputMessage = "" + } + + func automate() 
{ + showingAlert.toggle() + } + + func runAutomation() { + inputIsFocused = false + chatState.requestAutomation(measurementFilename: filenamePrefix) + inputMessage = "" + } +} diff --git a/ios/MLCChat/Views/ImageProcessing.swift b/ios/MLCChat/Views/ImageProcessing.swift new file mode 100644 index 0000000..3d7260e --- /dev/null +++ b/ios/MLCChat/Views/ImageProcessing.swift @@ -0,0 +1,66 @@ +// +// ImageProcessing.swift +// MLCChat +// +// Created by Kathryn Chen on 7/8/23. +// + +import Foundation +import SwiftUI +import UIKit + +// adapted from Mohammad Azam: https://github.com/azamsharp/SwiftUICamera +// delegate task to the coordinator to produce the image +struct ImagePicker : UIViewControllerRepresentable { + typealias UIViewControllerType = UIImagePickerController + typealias Coordinator = ImagePickerCoordinator + + @Binding var image: UIImage? + @Binding var showImagePicker: Bool + var imageSourceType: UIImagePickerController.SourceType = .photoLibrary + + func makeCoordinator() -> ImagePicker.Coordinator { + return ImagePickerCoordinator(image: $image, showImagePicker: $showImagePicker) + } + + func makeUIViewController(context: UIViewControllerRepresentableContext) -> UIImagePickerController { + let picker = UIImagePickerController() + picker.sourceType = imageSourceType + picker.delegate = context.coordinator + return picker + } + + func updateUIViewController(_ uiViewController: UIImagePickerController, context: UIViewControllerRepresentableContext) {} +} + +// image picker coordinator handling selecting from library or taking a photo +class ImagePickerCoordinator: NSObject, UINavigationControllerDelegate, UIImagePickerControllerDelegate { + @Binding var image: UIImage? + @Binding var showImagePicker: Bool + + init(image: Binding, showImagePicker: Binding) { + _image = image + _showImagePicker = showImagePicker + } + + func imagePickerController(_ picker: UIImagePickerController, didFinishPickingMediaWithInfo info: [UIImagePickerController.InfoKey : Any]) { + if let optionalImage = info[UIImagePickerController.InfoKey.originalImage] as? UIImage { + image = optionalImage + showImagePicker = false + } + } + + func imagePickerControllerDidCancel(_ picker: UIImagePickerController) { + showImagePicker = false + } +} + +// resize the input image to given width and height +func resizeImage(image: UIImage, width: Int, height: Int) -> UIImage { + let shape = CGSize(width: width, height: height) + UIGraphicsBeginImageContextWithOptions(shape, true, 0.0) + image.draw(in: CGRect(x: 0, y: 0, width: width, height: height)) + let resizedImage: UIImage? = UIGraphicsGetImageFromCurrentImageContext() + UIGraphicsEndImageContext() + return resizedImage ?? image +} diff --git a/ios/MLCChat/Views/MessageView.swift b/ios/MLCChat/Views/MessageView.swift new file mode 100644 index 0000000..4553f6b --- /dev/null +++ b/ios/MLCChat/Views/MessageView.swift @@ -0,0 +1,66 @@ +// +// MessageView.swift +// MLCChat +// + +import SwiftUI + +struct MessageView: View { + let role: MessageRole; + let message: String + + var body: some View { + let textColor = role.isUser ? Color.white : Color(UIColor.label) + let background = role.isUser ? 
Color.blue : Color(UIColor.secondarySystemBackground) + + HStack { + if role.isUser { + Spacer() + } + Text(message) + .padding(10) + .foregroundColor(textColor) + .background(background) + .cornerRadius(10) + .textSelection(.enabled) + if !role.isUser { + Spacer() + } + } + .padding() + .listRowSeparator(.hidden) + } +} + +struct ImageView: View { + let image: UIImage + + var body: some View { + let background = Color.blue + HStack { + Spacer() + Image(uiImage: image) + .resizable() + .frame(width: 150, height: 150) + .padding(15) + .background(background) + .cornerRadius(20) + } + .padding() + .listRowSeparator(.hidden) + } +} + +struct MessageView_Previews: PreviewProvider { + static var previews: some View { + NavigationView { + VStack (spacing: 0){ + ScrollView { + MessageView(role: MessageRole.user, message: "Message 1") + MessageView(role: MessageRole.bot, message: "Message 2") + MessageView(role: MessageRole.user, message: "Message 3") + } + } + } + } +} diff --git a/ios/MLCChat/Views/ModelView.swift b/ios/MLCChat/Views/ModelView.swift new file mode 100644 index 0000000..4676fb2 --- /dev/null +++ b/ios/MLCChat/Views/ModelView.swift @@ -0,0 +1,97 @@ +// +// ModelView.swift +// MLCChat +// +// Created by Yaxing Cai on 5/14/23. +// + +import SwiftUI + +struct ModelView: View { + @EnvironmentObject private var modelState: ModelState + @EnvironmentObject private var chatState: ChatState + @Binding var isRemoving: Bool + + @State private var isShowingDeletionConfirmation: Bool = false + + var body: some View { + VStack(alignment: .leading) { + if (modelState.modelDownloadState == .finished) { + NavigationLink(destination: + ChatView() + .environmentObject(chatState) + .onAppear { + modelState.startChat(chatState: chatState) + } + ) { + HStack { + Text(modelState.modelConfig.modelID!) + Spacer() + if chatState.isCurrentModel(modelID: modelState.modelConfig.modelID!) 
{ + Image(systemName: "checkmark").foregroundColor(.blue) + } + } + } + .buttonStyle(.borderless) + } else { + Text(modelState.modelConfig.modelID!).opacity(0.5) + } + HStack{ + if modelState.modelDownloadState != .finished || isRemoving { + ProgressView(value: Double(modelState.progress) / Double(modelState.total)) + .progressViewStyle(.linear) + } + + if (modelState.modelDownloadState == .paused) { + Button { + modelState.handleStart() + } label: { + Image(systemName: "icloud.and.arrow.down") + } + .buttonStyle(.borderless) + } else if (modelState.modelDownloadState == .downloading) { + Button { + modelState.handlePause() + } label: { + Image(systemName: "stop.circle") + } + .buttonStyle(.borderless) + } else if (modelState.modelDownloadState == .failed) { + Image(systemName: "exclamationmark.triangle") + .foregroundColor(.red) + } + + if isRemoving { + Button(role: .destructive) { + isShowingDeletionConfirmation = true + } label: { + Image(systemName: "trash") + } + .confirmationDialog("Delete Model", isPresented: $isShowingDeletionConfirmation) { + Button("Delete Model", role: .destructive) { + modelState.handleDelete() + } + .disabled( + modelState.modelDownloadState != .downloading && + modelState.modelDownloadState != .paused && + modelState.modelDownloadState != .finished && + modelState.modelDownloadState != .failed) + Button("Clear Data") { + modelState.handleClear() + } + .disabled( + modelState.modelDownloadState != .downloading && + modelState.modelDownloadState != .paused && + modelState.modelDownloadState != .finished) + Button("Cancel", role: .cancel) { + isShowingDeletionConfirmation = false + } + } message: { + Text("Delete model will delete the all files with model config, and delete the entry in list. \n Clear model will keep the model config only, and keep the entry in list for future re-downloading.") + } + .buttonStyle(.borderless) + } + } + } + } +} diff --git a/ios/MLCChat/Views/StartView.swift b/ios/MLCChat/Views/StartView.swift new file mode 100644 index 0000000..0baa404 --- /dev/null +++ b/ios/MLCChat/Views/StartView.swift @@ -0,0 +1,47 @@ +// +// DownloadView.swift +// MLCChat +// +// Created by Yaxing Cai on 5/11/23. 
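+//  Root view of the app: lists the models registered in AppState; a downloaded model
+//  row links into ChatView.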
+// + +import SwiftUI + +struct StartView: View { + @EnvironmentObject private var appState: AppState + @State private var isAdding: Bool = false + @State private var isRemoving: Bool = false + @State private var inputModelUrl: String = "" + + var body: some View { + NavigationStack { + List{ + Section(header: Text("Models")) { + ForEach(appState.models) { modelState in + ModelView(isRemoving: $isRemoving) + .environmentObject(modelState) + .environmentObject(appState.chatState) + .environmentObject(appState.conversationsRecordManager) + } + if !isRemoving { + Button("Edit model") { + isRemoving = true + } + .buttonStyle(.borderless) + } else { + Button("Cancel edit model") { + isRemoving = false + } + .buttonStyle(.borderless) + } + } + } + .navigationTitle("MLC Chat") + .alert("Error", isPresented: $appState.alertDisplayed) { + Button("OK") { } + } message: { + Text(appState.alertMessage) + } + } + } +} diff --git a/ios/MLCChat/app-config.json b/ios/MLCChat/app-config.json new file mode 100644 index 0000000..4315ea6 --- /dev/null +++ b/ios/MLCChat/app-config.json @@ -0,0 +1,27 @@ +{ + "model_lib_path_for_prepare_libs": { + "meta-llama_Llama-2-7b-chat-hf-q3f16_1": "meta-llama_Llama-2-7b-chat-hf-q3f16_1/meta-llama_Llama-2-7b-chat-hf-q3f16_1-iphone.tar", + "google_gemma-2b-it-q3f16_1": "google_gemma-2b-it-q3f16_1/google_gemma-2b-it-q3f16_1-iphone.tar", + "google_gemma-2b-it-q4f16_1": "google_gemma-2b-it-q4f16_1/google_gemma-2b-it-q4f16_1-iphone.tar" + }, + "model_list": [ + { + "model_path": "", + "model_id": "meta-llama_Llama-2-7b-chat-hf-q3f16_1", + "model_lib": "llama_q3f16_1", + "estimated_vram_bytes": 0 + }, + { + "model_path": "", + "model_id": "google_gemma-2b-it-q3f16_1", + "model_lib": "gemma_q3f16_1", + "estimated_vram_bytes": 0 + }, + { + "model_path": "", + "model_id": "google_gemma-2b-it-q4f16_1", + "model_lib": "gemma_q4f16_1", + "estimated_vram_bytes": 0 + } + ] +} diff --git a/ios/MLCSwift/Package.swift b/ios/MLCSwift/Package.swift new file mode 100644 index 0000000..eac88db --- /dev/null +++ b/ios/MLCSwift/Package.swift @@ -0,0 +1,32 @@ +// swift-tools-version:5.5 +// The swift-tools-version declares the minimum version of Swift required to build this package. + +import PackageDescription + +let package = Package( + name: "MLCSwift", + products: [ + .library( + name: "MLCSwift", + targets: ["LLMChatObjC", "MLCSwift"] + ) + ], + dependencies: [], + targets: [ + .target( + name: "LLMChatObjC", + path: "Sources/ObjC", + cxxSettings: [ + .headerSearchPath("../../tvm_home/include"), + .headerSearchPath("../../tvm_home/3rdparty/dmlc-core/include"), + .headerSearchPath("../../tvm_home/3rdparty/dlpack/include") + ] + ), + .target( + name: "MLCSwift", + dependencies: ["LLMChatObjC"], + path: "Sources/Swift" + ) + ], + cxxLanguageStandard: .cxx17 +) diff --git a/ios/MLCSwift/README.md b/ios/MLCSwift/README.md new file mode 100644 index 0000000..3a7c2b5 --- /dev/null +++ b/ios/MLCSwift/README.md @@ -0,0 +1,4 @@ +# MLCSwift + +This is a simple swift package that exposes the chat module to swift. +Checkout our [documentation](https://llm.mlc.ai/docs/) for more examples. 
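+
+The snippet below is a minimal usage sketch (not taken from the upstream docs): it shows the
+chat loop that `ChatModule` expects, mirroring the flow documented in `LLMChat.h` and used by
+`ChatState` in the MLCChat app. The model library name and model path are placeholders; they
+must match a model library linked into your app and model artifacts already present on device.
+
+```swift
+import MLCSwift
+
+// ChatModule must run on a single dedicated thread (Metal resources are thread-local),
+// so all calls are pushed onto a ThreadWorker instead of a dispatch queue.
+let worker = ThreadWorker()
+worker.qualityOfService = .userInteractive
+worker.start()
+
+let chat = ChatModule()
+
+worker.push {
+    // Placeholders: use the model_lib / local path of a model your app actually bundles.
+    chat.reload("llama_q3f16_1",
+                modelPath: "/path/to/Llama-2-7b-chat-hf-q3f16_1",
+                appConfigJson: "")
+
+    chat.prefill("What is the capital of Canada?")
+    while !chat.stopped() {
+        chat.decode()
+        if let partial = chat.getMessage() {
+            print(partial)  // streamed partial reply for the current round
+        }
+    }
+}
+```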
diff --git a/ios/MLCSwift/Sources/ObjC/LLMChat.mm b/ios/MLCSwift/Sources/ObjC/LLMChat.mm new file mode 100644 index 0000000..43d1a61 --- /dev/null +++ b/ios/MLCSwift/Sources/ObjC/LLMChat.mm @@ -0,0 +1,333 @@ +// +// LLMChat.mm +// LLMChat +// +#import +#import +#include + +#include "LLMChat.h" + +#define TVM_USE_LIBBACKTRACE 0 +#define DMLC_USE_LOGGING_LIBRARY + +#include +#include + +using namespace tvm::runtime; + +enum PlaceInPrompt : int { + // The input message should have role names and corresponding seperators appended both + // prior to it and after it, making it a complete prompt. + kAll, + // The input message is only the beginning part of a prompt, no role name and separator should be + // appended after the message since there will be future messages appended after the message. + kBegin, + // The input message is in the middle of a prompt, nothing should be appended before or after the + // message. + kMiddle, + // The input message is the ending part of a prompt, no role name and separator should be appended + // prior to it since the message is concatenated to some prior messages. + kEnd, +}; + +@implementation ChatModule { + // Internal c++ classes + // chat-related module and functions + Module llm_chat_; + PackedFunc unload_func_; + PackedFunc reload_func_; + PackedFunc prefill_func_; + PackedFunc embed_func_; + PackedFunc prefill_with_embed_func_; + PackedFunc decode_func_; + PackedFunc get_message_; + PackedFunc stopped_func_; + PackedFunc reset_chat_func_; + PackedFunc runtime_stats_text_func_; + PackedFunc verbose_runtime_stats_text_func_; + PackedFunc process_system_prompts_func_; + // image-related module and functions + Module llm_image_mod_; + PackedFunc image_mod_unload_func_; + PackedFunc image_mod_reload_func_; + PackedFunc image_mod_embed_func_; + PackedFunc image_mod_reset_func_; + PackedFunc image_mod_runtime_stats_text_func_; + // helper variables + bool first_input_after_image; + std::vector image_data; + NSUInteger image_width; + NSUInteger image_height; + + std::unordered_map energy_events; + int unload_counter; + int reload_counter; + int reset_chat_counter; + int decode_counter; + int prefill_counter; + int get_message_counter; + int stopped_counter; +} + +- (instancetype)init { + if (self = [super init]) { + energy_events["init.start"] = std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + // load chat module + const PackedFunc* f_chat_create = Registry::Get("mlc.llm_chat_create"); + ICHECK(f_chat_create) << "Cannot find mlc.llm_chat_create"; + llm_chat_ = (*f_chat_create)(static_cast(kDLMetal), 0); + // load image module + const PackedFunc* f_image_mod_create = Registry::Get("mlc.llm_image_module_create"); + ICHECK(f_image_mod_create) << "Cannot find mlc.llm_image_module_create"; + llm_image_mod_ = (*f_image_mod_create)(static_cast(kDLMetal), 0); + + // chat-related functions + reload_func_ = llm_chat_->GetFunction("reload"); + unload_func_ = llm_chat_->GetFunction("unload"); + prefill_func_ = llm_chat_->GetFunction("prefill"); + embed_func_ = llm_chat_->GetFunction("embed"); + prefill_with_embed_func_ = llm_chat_->GetFunction("prefill_with_embed"); + decode_func_ = llm_chat_->GetFunction("decode"); + get_message_ = llm_chat_->GetFunction("get_message"); + stopped_func_ = llm_chat_->GetFunction("stopped"); + reset_chat_func_ = llm_chat_->GetFunction("reset_chat"); + runtime_stats_text_func_ = llm_chat_->GetFunction("runtime_stats_text"); + verbose_runtime_stats_text_func_ = 
llm_chat_->GetFunction("verbose_runtime_stats_text"); + process_system_prompts_func_ = llm_chat_->GetFunction("process_system_prompts"); + // image-module-related functions + image_mod_reload_func_ = llm_image_mod_->GetFunction("reload"); + image_mod_unload_func_ = llm_image_mod_->GetFunction("unload"); + image_mod_embed_func_ = llm_image_mod_->GetFunction("embed"); + image_mod_reset_func_ = llm_image_mod_->GetFunction("reset"); + image_mod_runtime_stats_text_func_ = llm_image_mod_->GetFunction("runtime_stats_text"); + // helper variables + first_input_after_image = false; + image_height = 224; + image_width = 224; + image_data.reserve(image_height * image_width * 4); + + ICHECK(reload_func_ != nullptr); + ICHECK(unload_func_ != nullptr); + ICHECK(prefill_func_ != nullptr); + ICHECK(embed_func_ != nullptr); + ICHECK(prefill_with_embed_func_ != nullptr); + ICHECK(decode_func_ != nullptr); + ICHECK(get_message_ != nullptr); + ICHECK(stopped_func_ != nullptr); + ICHECK(reset_chat_func_ != nullptr); + ICHECK(runtime_stats_text_func_ != nullptr); + ICHECK(verbose_runtime_stats_text_func_ != nullptr); + ICHECK(process_system_prompts_func_ != nullptr); + ICHECK(image_mod_unload_func_ != nullptr); + ICHECK(image_mod_reload_func_ != nullptr); + ICHECK(image_mod_embed_func_ != nullptr); + ICHECK(image_mod_reset_func_ != nullptr); + ICHECK(image_mod_runtime_stats_text_func_ != nullptr); + + energy_events["init.end"] = std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + + unload_counter = 0; + reload_counter = 0; + reset_chat_counter = 0; + decode_counter = 0; + prefill_counter = 0; + get_message_counter = 0; + stopped_counter = 0; + } + return self; +} + +- (void)unload { + energy_events["unload." + std::to_string(unload_counter) + ".start"] = std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + unload_func_(); + energy_events["unload." + std::to_string(unload_counter) + ".end"] = std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + unload_counter++; +} + +- (void)reload:(NSString*)modelLib + modelPath:(NSString*)modelPath + appConfigJson:(NSString*)appConfigJson { + std::string lib_prefix = modelLib.UTF8String; + std::string model_path = modelPath.UTF8String; + std::string app_config_json = appConfigJson.UTF8String; + std::replace(lib_prefix.begin(), lib_prefix.end(), '-', '_'); + lib_prefix += '_'; + Module lib = (*Registry::Get("runtime.SystemLib"))(lib_prefix); + + energy_events["reload." + std::to_string(reload_counter) + ".start"] = std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + reload_func_(lib, model_path, app_config_json); + energy_events["reload." + std::to_string(reload_counter) + ".end"] = std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + reload_counter++; +} + +- (void)resetChat { + energy_events["reset_chat." + std::to_string(reset_chat_counter) + ".start"] = std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + reset_chat_func_(); + energy_events["reset_chat." 
+ std::to_string(reset_chat_counter) + ".end"] = std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + reset_chat_counter++; +} + +- (void)prefill:(NSString*)input { + std::string prompt = input.UTF8String; + if (first_input_after_image) { + prefill_func_(prompt, true, (int)PlaceInPrompt::kEnd); + first_input_after_image = false; + } else { + energy_events["prefill." + std::to_string(prefill_counter) + ".start"] = std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + prefill_func_(prompt); + energy_events["prefill." + std::to_string(prefill_counter) + ".end"] = std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + prefill_counter++; + } +} + +- (void)decode { + energy_events["generate.decode." + std::to_string(decode_counter) + ".start"] = std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + + decode_func_(); + + energy_events["generate.decode." + std::to_string(decode_counter) + ".end"] = std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + decode_counter++; +} + +- (NSString*)getMessage { + energy_events["get_message." + std::to_string(get_message_counter) + ".start"] = std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + std::string ret = get_message_(); + energy_events["get_message." + std::to_string(get_message_counter) + ".end"] = std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + get_message_counter++; + return [NSString stringWithUTF8String:ret.c_str()]; +} + +- (bool)stopped { + energy_events["stopped." + std::to_string(stopped_counter) + ".start"] = std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + bool stopped = stopped_func_().operator bool(); + energy_events["stopped." 
+ std::to_string(stopped_counter) + ".end"] = std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + stopped_counter++; + + return stopped; +} + +- (NSString*)runtimeStatsText:(bool)useVision { + + energy_events["verbose_runtime_stats_text.start"] = std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + std::string chat_mod_stats = verbose_runtime_stats_text_func_(); + energy_events["verbose_runtime_stats_text.end"] = std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + + if (useVision) { + std::string image_mod_stats = image_mod_runtime_stats_text_func_(); + chat_mod_stats += ", " + image_mod_stats; + } + return [NSString stringWithUTF8String:chat_mod_stats.c_str()]; +} + +- (void)processSystemPrompts { + energy_events["process_system_prompts.start"] = std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + process_system_prompts_func_(); + energy_events["process_system_prompts.end"] = std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); +} + +- (void)evaluate { + LOG(INFO) << "Total-mem-budget=" << os_proc_available_memory() / (1 << 20) << "MB"; + energy_events["evaluate.start"] = std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + llm_chat_->GetFunction("evaluate")(); + energy_events["evaluate.end"] = std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + LOG(INFO) << "Left-mem-budget=" << os_proc_available_memory() / (1 << 20) << "MB"; +} + +- (void)unloadImageModule { + image_mod_unload_func_(); + first_input_after_image = false; +} + +- (void)reloadImageModule:(NSString*)modelLib modelPath:(NSString*)modelPath { + first_input_after_image = false; + std::string lib_prefix = modelLib.UTF8String; + std::string model_path = modelPath.UTF8String; + std::replace(lib_prefix.begin(), lib_prefix.end(), '-', '_'); + lib_prefix += '_'; + Module lib = (*Registry::Get("runtime.SystemLib"))(lib_prefix); + image_mod_reload_func_(lib, model_path); +} + +- (void)resetImageModule { + image_mod_reset_func_(); + first_input_after_image = false; +} + +- (void)prefillImage:(UIImage*)image + prevPlaceholder:(NSString*)prevPlaceholder + postPlaceholder:(NSString*)postPlaceholder { + // prefill the previous placeholder string + std::string prev_placeholder = prevPlaceholder.UTF8String; + prefill_func_(prev_placeholder, false, (int)PlaceInPrompt::kBegin); + + // prefill with image embedding + // step 1. get image rawdata: credit from https://stackoverflow.com/a/1262893 + CGImageRef imageRef = [image CGImage]; + CGColorSpaceRef colorSpace = CGColorSpaceCreateDeviceRGB(); + NSUInteger bytesPerPixel = 4; + NSUInteger bytesPerRow = bytesPerPixel * image_width; + NSUInteger bitsPerComponent = 8; + CGContextRef context = CGBitmapContextCreate( + image_data.data(), image_width, image_height, bitsPerComponent, bytesPerRow, colorSpace, + kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); + CGColorSpaceRelease(colorSpace); + CGContextDrawImage(context, CGRectMake(0, 0, image_width, image_height), imageRef); + CGContextRelease(context); + // step 2. 
create tvm NDArray + ShapeTuple shape = {1, int(image_height), int(image_width), 4}; + DLDataType dtype = DataType::UInt(8); + DLDevice device = DLDevice{kDLMetal, 0}; + size_t nbytes = size_t(dtype.bits / 8); + for (auto s : shape) { + nbytes *= (size_t)s; + } + NDArray input_image = NDArray::Empty(shape, dtype, device); + input_image.CopyFromBytes(image_data.data(), nbytes); + // step 3. prefill with image embedding + NDArray embedding = image_mod_embed_func_(input_image); + prefill_with_embed_func_(embedding, false); + + // prefill the post placeholder string + std::string post_placeholder = postPlaceholder.UTF8String; + prefill_func_(post_placeholder, false, (int)PlaceInPrompt::kMiddle); + + // update the flag + first_input_after_image = true; +} + +- (void)resetEnergyEvents { + energy_events.clear(); + unload_counter = 0; + reload_counter = 0; + reset_chat_counter = 0; + decode_counter = 0; + prefill_counter = 0; + get_message_counter = 0; + stopped_counter = 0; + +} + +- (void)saveEnergyEventsToCSVWithFilename:(NSString *)fileName { + + // path to documents + NSArray *paths = NSSearchPathForDirectoriesInDomains(NSDocumentDirectory, NSUserDomainMask, YES); + NSString *documentsDirectory = [paths objectAtIndex:0]; + NSString *filePath = [documentsDirectory stringByAppendingPathComponent:fileName]; + + // Create the file + NSFileManager *fileManager = [NSFileManager defaultManager]; + [fileManager createFileAtPath:filePath contents:nil attributes:nil]; + NSFileHandle *fileHandle = [NSFileHandle fileHandleForWritingAtPath:filePath]; + + // Iterate through the unordered_map and write to the file + for (const auto &pair : energy_events) { + NSString *line = [NSString stringWithFormat:@"%s,%s\n", pair.first.c_str(), pair.second.c_str()]; + [fileHandle writeData:[line dataUsingEncoding:NSUTF8StringEncoding]]; + } + + // Close the file + [fileHandle closeFile]; +} + +@end diff --git a/ios/MLCSwift/Sources/ObjC/include/LLMChat.h b/ios/MLCSwift/Sources/ObjC/include/LLMChat.h new file mode 100644 index 0000000..9521175 --- /dev/null +++ b/ios/MLCSwift/Sources/ObjC/include/LLMChat.h @@ -0,0 +1,132 @@ +// +// Use this file to import your target's public headers that you would like to expose to Swift. +// LLM Chat Module +// +// Exposed interface of Object-C, enables swift binding. +#import +#import +#include + +/** + * The chat module that can be used by the swift app. + * It is a centralized interface that also provides multimodal support, i.e. vision modules. + * + * A chat flow can be implemented as follows, for each round of conversation + * + * @code + * + * chat.prefill(input); + * while(!chat.stopped()) { + * displayReply(chat.getMessage()); + * chat.decode(); + * } + * + * @endcode + * + * The execution logic of this module should be placed on a dedicated thread. + * + * @seealso ThreadWorker + */ +@interface ChatModule : NSObject + +/** + * Unload the current model and free all memory. + * @note This function is useful to get memory estimation before launch next model. + */ +- (void)unload; + +/** + * Reload the chat module to a new model. + * + * @param modelLib The name of the modelLib + * @param modelPath The path to the model artifacts. + * @param appConfigJson The partial config that is used to partially override the model + * configuration. + */ +- (void)reload:(NSString*)modelLib + modelPath:(NSString*)modelPath + appConfigJson:(NSString*)appConfigJson; + +/** + * Reset the current chat session. 
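+ * The loaded model stays in memory; only the current conversation state is cleared.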
+ */ +- (void)resetChat; + +/** + * Run prefill stage for a given input and decode the first output token. + * + *@param input The user input prompt. + */ +- (void)prefill:(NSString*)input; + +/** + *Run one decode step to decode the next token. + */ +- (void)decode; + +/** + * @returns The output message in the current round. + */ +- (NSString*)getMessage; + +/** + * @returns Whether the current round stopped + */ +- (bool)stopped; + +/** + * Get the runtime statistics for the chat module, and optionally the image module. + * + *@param useVision Whether an image module is used. + */ +- (NSString*)runtimeStatsText:(bool)useVision; + +/** + * Pre-process by prefilling the system prompts, running prior to any user input. + */ +- (void)processSystemPrompts; + +/** + * \brief Run one round of prefill and decode. + * + * This function is not supposed to be used by apps. + * and is only included here when setting up the app + * for debugging purposes. + */ +- (void)evaluate; + +/** + * Unload the current image model and free all memory. + * @note This function is useful to get memory estimation before launch next model. + */ +- (void)unloadImageModule; + +/** + * Reload the image module to a new model. + * + * @param modelLib The name of the modelLib + * @param modelPath The path to the model artifacts. + */ +- (void)reloadImageModule:(NSString*)modelLib modelPath:(NSString*)modelPath; + +/** + * Reset the current image model. + */ +- (void)resetImageModule; + +/** + * Prefill the LLM with the embedding of the input image. + * + * @param image The uploaded image. + * @param prevPlaceholder The previous placeholder in the prompt, i.e. . + * @param postPlaceholder The post placeholder in the prompt, i.e. . + */ +- (void)prefillImage:(UIImage*)image + prevPlaceholder:(NSString*)prevPlaceholder + postPlaceholder:(NSString*)postPlaceholder; + +- (void)resetEnergyEvents; + +- (void)saveEnergyEventsToCSVWithFilename:(NSString *)fileName; + +@end diff --git a/ios/MLCSwift/Sources/Swift/LLMChat.swift b/ios/MLCSwift/Sources/Swift/LLMChat.swift new file mode 100644 index 0000000..fa7d889 --- /dev/null +++ b/ios/MLCSwift/Sources/Swift/LLMChat.swift @@ -0,0 +1 @@ +@_exported import LLMChatObjC diff --git a/ios/MLCSwift/Sources/Swift/ThreadWorker.swift b/ios/MLCSwift/Sources/Swift/ThreadWorker.swift new file mode 100644 index 0000000..79f1eb2 --- /dev/null +++ b/ios/MLCSwift/Sources/Swift/ThreadWorker.swift @@ -0,0 +1,31 @@ +import Foundation + +// A simple thread worker that is backed by a single thread +// +// Instead of dispatch queue, we need a dedicated thread for metal compute +// so all thread local resources are centralized at a single thread +public class ThreadWorker : Thread { + private var cond = NSCondition(); + private var queue = Array<()->Void>(); + + public override func main() { + Thread.setThreadPriority(1) + while (true) { + self.cond.lock() + while (queue.isEmpty) { + self.cond.wait() + } + let task = self.queue.removeFirst() + self.cond.unlock() + task() + } + } + + public func push(task: @escaping ()->Void) { + self.cond.lock() + self.queue.append(task) + self.cond.signal() + self.cond.unlock() + + } +} diff --git a/ios/MLCSwift/tvm_home b/ios/MLCSwift/tvm_home new file mode 120000 index 0000000..e15bf64 --- /dev/null +++ b/ios/MLCSwift/tvm_home @@ -0,0 +1 @@ +../../3rdparty/tvm \ No newline at end of file diff --git a/ios/README.md b/ios/README.md new file mode 100644 index 0000000..de94ee7 --- /dev/null +++ b/ios/README.md @@ -0,0 +1,3 @@ +# MLC-LLM IOS + +[Documentation 
page](https://llm.mlc.ai/docs/deploy/ios.html) diff --git a/ios/prepare_libs.sh b/ios/prepare_libs.sh new file mode 100755 index 0000000..d874238 --- /dev/null +++ b/ios/prepare_libs.sh @@ -0,0 +1,74 @@ +function help { + echo -e "OPTION:" + echo -e " -s, --simulator Build for Simulator" + echo -e " -a, --arch x86_64 | arm64 Simulator arch " + echo -e " -h, --help Prints this help\n" +} + +is_simulator="false" +arch="arm64" + +# Args while-loop +while [ "$1" != "" ]; +do + case $1 in + -s | --simulator ) is_simulator="true" + ;; + -a | --arch ) shift + arch=$1 + ;; + -h | --help ) help + exit + ;; + *) + echo "$script: illegal option $1" + usage + exit 1 # error + ;; + esac + shift +done + +set -euxo pipefail + +sysroot="iphoneos" +type="Release" + +if [ "$is_simulator" = "true" ]; then + if [ "$arch" = "arm64" ]; then + # iOS simulator on Apple processors + rustup target add aarch64-apple-ios-sim + else + # iOS simulator on x86 processors + rustup target add x86_64-apple-ios + fi + sysroot="iphonesimulator" + type="Debug" +else + # iOS devices + rustup target add aarch64-apple-ios +fi + +mkdir -p build/ && cd build/ + +cmake ../..\ + -DCMAKE_BUILD_TYPE=$type\ + -DCMAKE_SYSTEM_NAME=iOS\ + -DCMAKE_SYSTEM_VERSION=14.0\ + -DCMAKE_OSX_SYSROOT=$sysroot\ + -DCMAKE_OSX_ARCHITECTURES=$arch\ + -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0\ + -DCMAKE_BUILD_WITH_INSTALL_NAME_DIR=ON\ + -DCMAKE_SKIP_INSTALL_ALL_DEPENDENCY=ON\ + -DCMAKE_INSTALL_PREFIX=.\ + -DCMAKE_CXX_FLAGS="-O3"\ + -DMLC_LLM_INSTALL_STATIC_LIB=ON\ + -DUSE_METAL=ON +make mlc_llm_static +cmake --build . --target install --config release -j +cd .. + +rm -rf MLCSwift/tvm_home +ln -s ../../3rdparty/tvm MLCSwift/tvm_home + +python prepare_model_lib.py diff --git a/ios/prepare_model_lib.py b/ios/prepare_model_lib.py new file mode 100644 index 0000000..55ad6f7 --- /dev/null +++ b/ios/prepare_model_lib.py @@ -0,0 +1,29 @@ +import json +import os +from tvm.contrib import cc + + +def main(): + app_config = json.load(open("MLCChat/app-config.json", "r")) + target = "iphone" + artifact_path = os.path.abspath(os.path.join("../../../../", "melt_models_converted")) + + tar_list = [] + + for model, model_lib_path in app_config["model_lib_path_for_prepare_libs"].items(): + paths = [ + os.path.join(artifact_path, model_lib_path), + ] + valid_paths = [p for p in paths if os.path.isfile(p)] + if not valid_paths: + raise RuntimeError( + f"Cannot find iOS lib for {model} from the following candidate paths: {paths}" + ) + tar_list.append(valid_paths[0]) + + cc.create_staticlib(os.path.join("build", "lib", "libmodel_iphone.a"), tar_list) + print(f"Creating lib from {tar_list}..") + + +if __name__ == "__main__": + main() diff --git a/ios/prepare_params.sh b/ios/prepare_params.sh new file mode 100755 index 0000000..3814f6f --- /dev/null +++ b/ios/prepare_params.sh @@ -0,0 +1,32 @@ +#!/bin/bash +set -euxo pipefail + +# NOTE: this is optional, prepackage weight into app +rm -rf dist +mkdir -p dist + +declare -a builtin_list=( + # "Mistral-7B-Instruct-v0.2-q3f16_1" + # "OpenHermes-2.5-Mistral-7B-q3f16_1" + # "Llama-2-7b-chat-hf-q3f16_1" + # "RedPajama-INCITE-Chat-3B-v1-q4f16_1" + # "vicuna-v1-7b-q3f16_0" + # "rwkv-raven-1b5-q8f16_0" + # "rwkv-raven-3b-q8f16_0" + # "rwkv-raven-7b-q8f16_0" +) + +for model in "${builtin_list[@]}"; do + if [ -d ../dist/$model/params ]; then + cp -r ../dist/$model/params dist/$model + elif [ -d ../dist/prebuilt/$model ]; then + cp -r ../dist/prebuilt/$model dist/$model + elif [ -d ../dist/prebuilt/mlc-chat-$model ]; then + cp -r 
../dist/prebuilt/mlc-chat-$model dist/$model + elif [ -d ../dist/prebuilt/$model-MLC ]; then + cp -r ../dist/prebuilt/$model-MLC dist/$model + else + echo "Cannot find prebuilt weights for " $model + exit 1 + fi +done diff --git a/mlc_llm/__init__.py b/mlc_llm/__init__.py new file mode 100644 index 0000000..b74f007 --- /dev/null +++ b/mlc_llm/__init__.py @@ -0,0 +1,7 @@ +from . import dispatch +from . import quantization +from . import relax_model +from . import transform +from . import utils +from . import core +from .core import build_model, BuildArgs diff --git a/mlc_llm/build.py b/mlc_llm/build.py new file mode 100644 index 0000000..b7619aa --- /dev/null +++ b/mlc_llm/build.py @@ -0,0 +1,47 @@ +"""Script for building/compiling models.""" +import contextlib +import sys + +from mlc_llm import core + + +@contextlib.contextmanager +def debug_on_except(): + try: + yield + finally: + raised_exception = sys.exc_info()[1] + if not isinstance(raised_exception, Exception): + return + + import traceback + + try: + import ipdb as pdb + except ImportError: + import pdb + + traceback.print_exc() + pdb.post_mortem() + + +def main(): + """Main method for building model from command line.""" + empty_args = core.convert_build_args_to_argparser() # Create new ArgumentParser + parsed_args = empty_args.parse_args() # Parse through command line + + with contextlib.ExitStack() as stack: + # Enter an exception-catching context before post-processing + # the arguments, in case the post-processing itself raises an + # exception. + if parsed_args.pdb: + stack.enter_context(debug_on_except()) + + # Post processing of arguments + parsed_args = core._parse_args(parsed_args) # pylint: disable=protected-access + + core.build_model_from_args(parsed_args) + + +if __name__ == "__main__": + main() diff --git a/mlc_llm/core.py b/mlc_llm/core.py new file mode 100644 index 0000000..bd86e0a --- /dev/null +++ b/mlc_llm/core.py @@ -0,0 +1,1015 @@ +# pylint: disable=missing-docstring, redefined-outer-name, not-callable +import argparse +import functools +import json +import os +import pickle +from dataclasses import asdict, dataclass, field, fields +from typing import Any, Dict, Optional + +import mlc_llm +import tvm +import tvm.relax.backend.contrib.cublas as _ +from mlc_llm import utils +from mlc_llm.relax_model import ( + chatglm, + gpt_bigcode, + gpt_neox, + gptj, + llama, + llama_batched_vllm, + minigpt, + mistral, + param_manager, + rwkv, + stablelm_3b, +) +from mlc_llm.relax_model.commons import ( + create_shard_info_func, + create_shard_transformation_func, +) +from mlc_llm.relax_model.param_manager import ( + chain_parameter_transforms, + transform_params_for_each_rank, +) +from mlc_llm.transform import fuse_split_rotary_embedding, rewrite_attention +from tvm import dlight as dl +from tvm import relax +from tvm.contrib.nvcc import parse_compute_version +from tvm.relax.backend import get_patterns_with_prefix +from tvm.relax.backend.contrib.cutlass import annotate_workspace + + +@dataclass +class BuildArgs: + r"""BuildArgs is the dataclass that organizes the arguments we use in + building a model. + + To use :meth:`mlc_llm.build_model`, users pass in an instance of :class:`BuildArgs`; for + CLI entry points, an equivalent :class:`ArgumentParser` instance is generated based + on the definition of this class using :meth:`mlc_llm.convert_build_args_to_argparser`. + + Parameters + ---------- + model: str + The name of the model to build. 
If it is ``auto``, we will automatically + set the model name according to ``--model-path``, ``hf-path``, or the model + folders under ``--artifact-path/models``. + + hf_path: str + Hugging Face path from which to download params, tokenizer, and config. + + quantization: str + The quantization mode we use to compile. + + max_seq_len: int + The maximum allowed sequence length for the model. + + target: str + The target platform to compile the model for. + + db_path: str + Path to log database for all models. Default: ``./log_db/``. + + reuse_lib: str + Whether to reuse a previously generated lib. + + artifact_path: str + Where to store the output. + + use_cache: int + Whether to use previously pickled IRModule and skip trace. + + convert_weights_only: bool + Whether to only convert model weights and not build the model. If both + ``convert_weight_only`` and ``build_model_only`` are set, the behavior is undefined. + + build_model_only: bool + Whether to only build model and do not convert model weights. + + debug_dump: bool + Whether to dump debugging files during compilation. + + debug_load_script: bool + Whether to load the script for debugging. + + llvm_mingw: str + ``/path/to/llvm-mingw-root``, use llvm-mingw to cross compile to windows. + + system_lib: bool + A parameter to ``relax.build``. + + sep_embed: bool + Build with separated embedding layer, only applicable to LlaMa. This + feature is in testing stage, and will be formally replaced after massive + overhaul of embedding feature for all models and use cases. + + sliding_window: int + The sliding window size in sliding window attention (SWA). This optional field + overrides the `sliding_window` in config.json for those models that use SWA. + Currently only useful when compiling Mistral. + + prefill_chunk_size: int + The chunk size during prefilling. By default, the chunk size is the same as + max sequence length. Currently only useful when compiling Mistral. + + attention_sink_size: int + Number of attention sinks (https://arxiv.org/abs/2309.17453). + Only supported on mistral yet. + + cc_path: str + ``/path/to/cross_compiler_path``; currently only used for cross-compile + for nvidia/jetson device. + + use_safetensors: bool + Specifies whether to use ``.safetensors`` instead of the default ``.bin`` + when loading in model weights. + + enable_batching: bool + Build the model for batched inference. + This is a temporary flag used to control the model execution flow in single- + sequence and batching settings for now. We will eventually merge two flows + in the future and remove this flag then. + + no_cutlass_attn: bool + Disable offloading attention operations to CUTLASS. + + no_cutlass_norm: bool + Disable offloading layer and RMS norm operations to CUTLASS. + + no_cublas: bool + Disable the step that offloads matmul to cuBLAS. Without this flag, + matmul will be offloaded to cuBLAS if quantization mode is ``q0f16`` or + ``q0f32``, target is CUDA and TVM has been built with cuBLAS enabled. + + use_cuda_graph: bool + Specifies whether to enable CUDA Graph for the decoder. MLP and QKV + projection between two attention layers are put into a graph. + + num_shards: int + Number of shards to split the model into in tensor parallelism multi-gpu + inference. Only useful when ``build_model_only`` is set. + + use_flash_attn_mqa: bool + Offload multi-query attention workload to Flash Attention. + + pdb: bool + If set, drop into a pdb debugger on error. 
+ + use_vllm_attention: bool + Use vLLM paged KV cache and attention kernel, only relevant when enable_batching=True. + """ + model: str = field( + default="auto", + metadata={ + "help": ( + 'The name of the model to build. If it is "auto", we will ' + 'automatically set the model name according to "--model-path", ' + '"hf-path" or the model folders under "--artifact-path/models"' + ) + }, + ) + hf_path: str = field( + default=None, + metadata={"help": "Hugging Face path from which to download params, tokenizer, and config"}, + ) + quantization: str = field( + default="q4f16_1", + metadata={ + "help": "The quantization mode we use to compile.", + "choices": [*utils.quantization_schemes.keys()], + }, + ) + max_seq_len: int = field( + default=-1, + metadata={"help": "The maximum allowed sequence length for the model."}, + ) + max_vocab_size: int = field( + default=40000, + metadata={"help": "The maximum allowed vocabulary size for the model."}, + ) + target: str = field( + default="auto", + metadata={"help": "The target platform to compile the model for."}, + ) + reuse_lib: str = field( + default=None, metadata={"help": "Whether to reuse a previously generated lib."} + ) + artifact_path: str = field(default="dist", metadata={"help": "Where to store the output."}) + use_cache: int = field( + default=1, + metadata={"help": "Whether to use previously pickled IRModule and skip trace."}, + ) + convert_weights_only: bool = field( + default=False, + metadata={ + "dest": "convert_weights_only", + "action": "store_true", + "help": "Whether to only convert model weights and not build the model.", + }, + ) + build_model_only: bool = field( + default=False, + metadata={ + "help": "Whether to only build model and do not convert model weights.", + "action": "store_true", + }, + ) + debug_dump: bool = field( + default=False, + metadata={ + "help": "Whether to dump debugging files during compilation.", + "action": "store_true", + }, + ) + debug_load_script: bool = field( + default=False, + metadata={ + "help": "Whether to load the script for debugging.", + "action": "store_true", + }, + ) + llvm_mingw: str = field( + default="", + metadata={"help": "/path/to/llvm-mingw-root, use llvm-mingw to cross compile to windows."}, + ) + cc_path: str = field( + default="", + metadata={ + "help": ( + "/path/to/cross_compiler_path, Currently only used for " + "cross-compile for nvidia/jetson device." + ) + }, + ) + system_lib: bool = field( + default=False, + metadata={"help": "A parameter to `relax.build`.", "action": "store_true"}, + ) + sep_embed: bool = field( + default=False, + metadata={ + "help": ( + "Build with separated embedding layer, only applicable to LlaMa. " + "This feature is in testing stage, and will be formally replaced after " + "massive overhaul of embedding feature for all models and use cases" + ), + "action": "store_true", + }, + ) + use_safetensors: bool = field( + default=False, + metadata={ + "help": ( + "Specifies whether to use ``.safetensors`` instead of the default " + "``.bin`` when loading in model weights." + ), + "action": "store_true", + }, + ) + enable_batching: bool = field( + default=False, + metadata={ + "help": ( + "Build the model for batched inference." + "This is a temporary flag used to control the model execution flow in single-" + "sequence and batching settings for now. We will eventually merge two flows" + "in the future and remove this flag then." 
+ ), + "action": "store_true", + }, + ) + max_batch_size: int = field( + default=80, + metadata={ + "help": ( + "The maximum batch size for build. It has effect only when batching is enabled." + ), + }, + ) + no_cutlass_attn: bool = field( + default=False, + metadata={ + "help": ("Disable offloading attention operations to CUTLASS."), + "action": "store_true", + }, + ) + no_cutlass_norm: bool = field( + default=False, + metadata={ + "help": ("Disable offloading layer and RMS norm operations to CUTLASS."), + "action": "store_true", + }, + ) + no_cublas: bool = field( + default=False, + metadata={ + "help": ( + "Disable the step that offloads matmul to cuBLAS. Without this flag, " + "matmul will be offloaded to cuBLAS if quantization mode is q0f16 or q0f32, " + "target is CUDA and TVM has been built with cuBLAS enabled." + ), + "action": "store_true", + }, + ) + use_cuda_graph: bool = field( + default=False, + metadata={ + "help": ( + "Specifies whether to enable CUDA Graph for the decoder. MLP and QKV " + "projection between two attention layers are put into a graph." + ), + "action": "store_true", + }, + ) + num_shards: int = field( + default=1, + metadata={ + "help": ( + "Number of shards to split the model into in tensor parallelism multi-gpu " + "inference. Only useful when --build-model-only is set." + ), + }, + ) + use_presharded_weights: bool = field( + default=False, + metadata={ + "action": "store_true", + "help": "Produce separate weight sets for each shard.", + }, + ) + use_flash_attn_mqa: bool = field( + default=False, + metadata={ + "help": ("Offload multi-query attention workload to Flash Attention."), + "action": "store_true", + }, + ) + sliding_window: int = field( + default=-1, + metadata={ + "help": ( + "The sliding window size in sliding window attention (SWA). " + "This optional field overrides the `sliding_window` in config.json for " + "those models that use SWA. Currently only useful when compiling Mistral." + ), + }, + ) + prefill_chunk_size: int = field( + default=-1, + metadata={ + "help": ( + "The chunk size during prefilling. By default, the chunk size is " + "the same as the sliding window size or the max sequence length. " + "Currently only useful when compiling Mistral." + ), + }, + ) + attention_sink_size: int = field( + default=0, + metadata={ + "help": ( + "The number of attention sinks to keep in cache." + "Only supported on mistral yet." + ), + }, + ) + pdb: bool = field( + default=False, + metadata={ + "help": ("If set, drop into a pdb debugger on error"), + "action": "store_true", + }, + ) + use_vllm_attention: bool = field( + default=False, + metadata={ + "help": ( + "Use vLLM paged KV cache and attention kernel, only relevant when " + "enable_batching=True." 
+ ), + "action": "store_true", + }, + ) + + @property + def convert_weight_only(self): + """A backwards-compatibility helper""" + return self.convert_weights_only + + +def convert_build_args_to_argparser() -> argparse.ArgumentParser: + """Convert from BuildArgs to an equivalent ArgumentParser.""" + args = argparse.ArgumentParser() + for field in fields(BuildArgs): + name = field.name.replace("_", "-") + field_name = f"--{name}" + # `kwargs` contains `help`, `choices`, and `action` + kwargs = field.metadata.copy() + if field.type == bool: + # boolean arguments do not need to specify `type` + args.add_argument(field_name, default=field.default, **kwargs) + else: + args.add_argument(field_name, type=field.type, default=field.default, **kwargs) + + # Most models contain more than a single parameter (citation + # needed), so "weights" should be plural. The initial use of + # "--convert-weight-only" caused enough typos that it is worth + # fixing. The old argument spelling is retained for backwards + # compatibility. + args.add_argument( + "--convert-weight-only", + default=False, + dest="convert_weights_only", + action="store_true", + help="Equivalent to --convert-weights-only, retained for backwards compatibility.", + ) + + return args + + +def _parse_args(parsed) -> argparse.Namespace: + assert parsed.max_seq_len == -1 or parsed.max_seq_len > 0 + if parsed.use_safetensors: + try: + import safetensors # pylint: disable=import-outside-toplevel, unused-import + except ImportError as error: + raise ImportError( + "`use_safetensors` option is toggled, please install safetensors package." + ) from error + + parsed.export_kwargs = {} + parsed.lib_format = "so" + parsed.system_lib_prefix = None + parsed = _setup_model_path(parsed) + + utils.parse_target(parsed) + utils.argparse_postproc_common(parsed) + + if parsed.use_vllm_attention: + assert parsed.enable_batching, "--enable_batching is required for using vLLM attention." + assert parsed.target_kind == "cuda", "vLLM attention is only supported for CUDA." + assert tvm.get_global_func( + "tvm.contrib.vllm.single_query_cached_kv_attention", True + ), "TVM needs to be built with -DUSE_VLLM=ON." + + model_name = [ + parsed.model, + parsed.quantization.name, + ] + if parsed.use_presharded_weights: + model_name.append(f"presharded-{parsed.num_shards}gpu") + + parsed.artifact_path = os.path.join(parsed.artifact_path, "-".join(model_name)) + + return parsed + + +def _setup_model_path(args: argparse.Namespace): # pylint: disable=too-many-branches + if args.hf_path: + if args.model != "auto": + assert args.model == os.path.basename(args.hf_path), ( + 'When both "--model" and "--hf-path" is specified, the ' + 'value of "--model" is required to match the basename of "--hf-path". 
' + f'Got "--model {args.model}" and "--hf-path {args.hf_path}"' + ) + else: + args.model = os.path.basename(args.hf_path) + args.model_path = os.path.join(args.artifact_path, "models", args.model) + if os.path.exists(args.model_path): + print(f"Weights exist at {args.model_path}, skipping download.") + else: + os.makedirs(args.model_path, exist_ok=True) + os.system("git lfs install") + os.system(f"git clone https://huggingface.co/{args.hf_path} {args.model_path}") + print(f"Downloaded weights to {args.model_path}") + validate_config(args.model_path) + elif args.model != "auto": + if os.path.isdir(args.model): + args.model = os.path.normpath(args.model) # Remove potential trailing `/` + args.model_path = args.model + args.model = os.path.basename(args.model) + else: + args.model_path = os.path.join(args.artifact_path, "models", args.model) + validate_config(args.model_path) + else: + lookup_path = os.path.join(args.artifact_path, "models") + print(f'"--model" is set to "auto". Searching in {lookup_path} for existing models.') + for dirname in os.listdir(lookup_path): + if os.path.isdir(os.path.join(lookup_path, dirname)) and os.path.isfile( + os.path.join(lookup_path, dirname, "config.json") + ): + try: + validate_config(os.path.join(lookup_path, dirname)) + except: # pylint: disable=bare-except + pass + else: + args.model_path = os.path.join(lookup_path, dirname) + args.model = dirname + break + if args.model == "auto": + raise ValueError("Please specify either the model_path or the hf_path.") + + print(f'Using path "{args.model_path}" for model "{args.model}"') + return args + + +def validate_config(model_path: str): + if os.path.exists(os.path.join(model_path, "mlc-chat-config.json")): + raise KeyError( + f"The model located in the directory {model_path} has already been compiled " + "by MLC-LLM. There is no need to compile it again. If you wish to compile " + "a new model, please provide a directory (or hf-path) that contains the " + "pre-compiled model in raw HuggingFace format instead." + ) + if model_path.split("/")[-1].startswith("minigpt"): + # minigpt does not contain a config.json file so we skip the check + return + config_path = os.path.join(model_path, "config.json") + assert os.path.exists( + config_path + ), f"Expecting HuggingFace config, but file not found: {config_path}." + with open(config_path, encoding="utf-8") as i_f: + config = json.load(i_f) + assert ( + "model_type" in config + ), f"Invalid config format. Expecting HuggingFace config format in: {config_path}" + assert ( + config["model_type"] in utils.supported_model_types + ), f"Model type {config['model_type']} not supported." + + +def get_cuda_sm_version(): + major, minor = parse_compute_version(tvm.cuda(0).compute_version) + + if major == 8: + sm = 80 + else: + sm = 10 * major + minor + + return sm + + +def mod_transform_before_build( + mod: tvm.IRModule, + param_manager: param_manager.ParamManager, + args: argparse.Namespace, + config: Dict, +) -> tvm.IRModule: + """First-stage: Legalize ops and trace""" + if args.model.startswith("minigpt"): + model_names = ["embed"] + else: + model_names = [ + "prefill", + "decode", + ] + + if not args.use_vllm_attention: + model_names += [ + "create_kv_cache", + "softmax_with_temperature", + "get_metadata", + ] + else: + # This is equivalent to prefill but without KV cache. It is used for + # determining the number of paged cache blocks that can be allocated. 
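+ # (The resulting `model_names` list is also what DeadCodeElimination and
+ # RunCodegen below treat as the entry functions to keep.)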
+ model_names.append("evaluate") + + if args.sep_embed: + model_names = ["embed", "prefill_with_embed"] + model_names[1:] + if args.enable_batching: + model_names[2] = "decode_with_embed" + if args.model.lower().startswith("rwkv-"): + model_names += ["reset_kv_cache"] + + mod = param_manager.transform_dequantize()(mod) + mod = relax.transform.BundleModelParams()(mod) + + use_ft_quant = args.quantization.name in [ + "q4f16_ft", + "q8f16_ft", + "q4f16_ft_group", + "q8f16_ft_group", + ] + mod = mlc_llm.transform.FuseDecodeTranspose(skip_gemm=not use_ft_quant)(mod) + + if ( + not args.enable_batching + and hasattr(config, "num_attention_heads") + and hasattr(config, "hidden_size") + and hasattr(config, "position_embedding_base") + and getattr(config, "dtype", "float16") == "float16" + ): + max_seq_len = None + if args.max_seq_len > 0: + max_seq_len = args.max_seq_len + elif hasattr(config, "max_sequence_length"): + max_seq_len = config.max_sequence_length + + if max_seq_len: + num_key_value_heads = config.get_num_key_value_heads() + # pylint: disable=no-value-for-parameter + mod = fuse_split_rotary_embedding( + config.num_attention_heads // args.num_shards, + num_key_value_heads // args.num_shards, + config.hidden_size // args.num_shards, + config.position_embedding_base, + )(mod) + + if args.target_kind == "cuda": + patterns = [] + + has_cutlass = tvm.get_global_func("relax.ext.cutlass", True) + + if has_cutlass and not args.no_cutlass_attn: + # pylint: disable=no-value-for-parameter + if args.use_flash_attn_mqa: + mod = rewrite_attention(use_flash_mqa=True)(mod) + mod = rewrite_attention(use_flash_mqa=False)(mod) + patterns += get_patterns_with_prefix("cutlass.attention") + + if has_cutlass and not args.no_cutlass_norm: + patterns += get_patterns_with_prefix("cutlass.layer_norm") + patterns += get_patterns_with_prefix("cutlass.rms_norm") + + if has_cutlass and use_ft_quant: + patterns += get_patterns_with_prefix("cutlass.decode_matmul") + + has_cublas = tvm.get_global_func("relax.ext.cublas", True) + + if has_cublas and args.quantization.name in ("q0f16", "q0f32") and not args.no_cublas: + patterns += get_patterns_with_prefix("cublas") + + if len(patterns) > 0: + os.makedirs("./tmp", exist_ok=True) + + sm = get_cuda_sm_version() + options = {"cutlass": {"sm": sm, "find_first_valid": False}} + + if hasattr(config, "rms_norm_eps"): + options["cutlass"]["rms_eps"] = config.rms_norm_eps + + mod = tvm.transform.Sequential( + [ + relax.transform.FuseOpsByPattern( + patterns, bind_constants=False, annotate_codegen=True + ), + annotate_workspace, + relax.transform.AllocateWorkspace(), + relax.transform.RunCodegen(options, entry_functions=model_names), + ] + )(mod) + + if args.target_kind == "android": + mod = mlc_llm.transform.FuseTranspose1Matmul()(mod) + mod = mlc_llm.transform.FuseTranspose2Matmul()(mod) + mod = mlc_llm.transform.FuseTransposeMatmul()(mod) + mod = relax.pipeline.get_pipeline()(mod) # pylint: disable=no-value-for-parameter + mod = mlc_llm.transform.FuseDecodeMatmulEwise()(mod) + mod = mlc_llm.transform.FuseDecodeTake()(mod) + mod = relax.transform.DeadCodeElimination(model_names)(mod) + mod = mlc_llm.transform.CleanUpTIRAttrs()(mod) + mod_deploy = mod + + utils.debug_dump_script(mod_deploy, "mod_deploy.py", args) + + return mod_deploy + + +def dump_mlc_chat_config( + args: argparse.Namespace, + vocab_size: int, + max_window_size: int, + temperature: float = 0.7, + repetition_penalty: float = 1.0, + top_p: float = 0.95, + mean_gen_len: int = 128, + max_gen_len: int = 512, + 
shift_fill_factor: float = 0.3, + rwkv_world=False, +): + args.params_path = os.path.join(args.artifact_path, "params") + config: Dict[str, Any] = {} + + if args.reuse_lib: + config["model_lib"] = f"{args.reuse_lib}" + if not args.reuse_lib.endswith(args.quantization.name): + raise RuntimeError(f"Trying to reuse lib without suffix {args.quantization.name}") + else: + config["model_lib"] = f"{args.model}-{args.quantization.name}" + + config["local_id"] = f"{args.model}-{args.quantization.name}" + config["conv_template"] = args.conv_template + config["temperature"] = temperature + config["repetition_penalty"] = repetition_penalty + config["top_p"] = top_p + config["mean_gen_len"] = mean_gen_len + config["max_gen_len"] = max_gen_len + config["num_shards"] = args.num_shards + config["use_presharded_weights"] = args.use_presharded_weights + config["shift_fill_factor"] = shift_fill_factor + if rwkv_world: + config["tokenizer_files"] = ["tokenizer_model"] + else: + config["tokenizer_files"] = utils.get_tokenizer_files(args.params_path) + config["model_category"] = args.model_category + config["model_name"] = args.model + config["vocab_size"] = vocab_size + config["prefill_chunk_size"] = args.prefill_chunk_size + if args.sliding_window != -1: + # Do not add max window size if use sliding window + config["sliding_window"] = args.sliding_window + + # only use sinks if sliding window enabled + if args.attention_sink_size > 0: + config["attention_sink_size"] = args.attention_sink_size + else: + config["max_window_size"] = max_window_size + + args.chat_config_path = os.path.join(args.params_path, "mlc-chat-config.json") + with open(args.chat_config_path, "w", encoding="utf-8") as outfile: + json.dump(config, outfile, indent=4) + print(f"Finish exporting chat config to {args.chat_config_path}") + + +def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None: + target_kind = args.target_kind + if args.system_lib_prefix: + mod_deploy = mod_deploy.with_attrs({"system_lib_prefix": args.system_lib_prefix}) + + utils.debug_dump_script(mod_deploy, "mod_before_build.py", args) + utils.debug_dump_benchmark_script( + mod_deploy, f"{args.model}_{args.quantization.name}".replace("-", "_"), args + ) + + if target_kind != "cpu": + dispatch_target = ( + args.target + if args.target_kind != "webgpu" + else tvm.target.Target("apple/m1-gpu-restricted") + ) + with dispatch_target: + if args.target_kind == "android": + mod_deploy = mlc_llm.dispatch.DispatchTIROperatorAdreno()( # pylint: disable=not-callable + mod_deploy + ) + mod_deploy = dl.ApplyDefaultSchedule( # pylint: disable=not-callable + dl.gpu.Matmul(), + dl.gpu.GEMV(), + dl.gpu.Reduction(), + dl.gpu.GeneralReduction(), + dl.gpu.Fallback(), + )(mod_deploy) + mod_deploy = ( + mlc_llm.transform.LiftTIRGlobalBufferAlloc()( # pylint: disable=not-callable + mod_deploy + ) + ) + if not args.enable_batching: + mod_deploy = tvm.tir.transform.ForceNarrowIndexToInt32()(mod_deploy) + + if args.debug_load_script: + mod_deploy = utils.debug_load_script("mod_build_stage_debug.py", args) + + utils.debug_dump_script(mod_deploy, "mod_build_stage.py", args) + + use_cuda_graph = args.use_cuda_graph and target_kind == "cuda" + + with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": use_cuda_graph}): + # The num_input attribute is needed to capture transformed weights passed as input + # into a cuda graph. + # NOTE: CUDA graph for batching is not enabled and is left as a TODO item. 
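+ # Only the leading `num_input` arguments of `decode` are treated as runtime inputs;
+ # the remaining bundled model parameters can then be captured by the CUDA graph.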
+ if not args.enable_batching: + mod_deploy["decode"] = mod_deploy["decode"].with_attr({"num_input": 3}) + ex = relax.build(mod_deploy, args.target, system_lib=args.system_lib) + + output_filename = f"{args.model}-{args.quantization.name}-{target_kind}.{args.lib_format}" + + utils.debug_dump_shader(ex, f"{args.model}_{args.quantization.name}_{target_kind}", args) + args.lib_path = os.path.join(args.artifact_path, output_filename) + ex.export_library(args.lib_path, **args.export_kwargs) + print(f"Finish exporting to {args.lib_path}") + + +def build_model_from_args(args: argparse.Namespace): + if args.quantization == "q4f16_0": + print( + "WARNING: q4f16_1 is preferred to q4f16_0, " + "and it is highly recommended to use q4f16_1 instead" + ) + + use_ft_quant = args.quantization.name in [ + "q4f16_ft", + "q8f16_ft", + "q4f16_ft_group", + "q8f16_ft_group", + ] + + if args.num_shards > 1: + if (not args.build_model_only) and (not args.convert_weights_only): + raise ValueError( + "`num_shards` should be used together with " + "`--build-model-only` and `--convert-weight-only`" + ) + + if use_ft_quant and not args.use_presharded_weights: + print( + "WARNING: FT quantization with multi-gpus requires presharding weights." + "Forcing --use-presharded-weights." + ) + args.use_presharded_weights = True + + os.makedirs(args.artifact_path, exist_ok=True) + if args.debug_dump: + os.makedirs(os.path.join(args.artifact_path, "debug"), exist_ok=True) + cache_path = os.path.join(args.artifact_path, "mod_cache_before_build.pkl") + args.raw_params_path = os.path.join(args.artifact_path, "raw_params") + use_cache = args.use_cache and os.path.isfile(cache_path) + if args.sep_embed and args.model_category != "llama": + raise ValueError(f"separate embedding not supported on {args.model}") + + if args.model_category == "minigpt": + # Special case for minigpt, which neither provides nor requires a configuration. + config = {} + else: + with open(os.path.join(args.model_path, "config.json"), encoding="utf-8") as i_f: + config = json.load(i_f) + + if not use_cache or args.convert_weights_only: + model_generators = { + "llama": llama, + "mistral": mistral, + "stablelm_epoch": stablelm_3b, + "gpt_neox": gpt_neox, + "gpt_bigcode": gpt_bigcode, + "minigpt": minigpt, + "gptj": gptj, + "rwkv": rwkv, + "rwkv_world": rwkv, + "chatglm": chatglm, + } + + if args.use_vllm_attention: + model_generators["llama"] = llama_batched_vllm + model_generators["mistral"] = llama_batched_vllm + + assert args.model_category in model_generators, f"Model {args.model} not supported" + + mod, param_manager, params, model_config = model_generators[args.model_category].get_model( + args, config + ) + + if args.model_category == "mistral": + args.sliding_window = model_config.sliding_window + args.attention_sink_size = model_config.attention_sink_size + + for qspec_updater_class in param_manager.qspec_updater_classes: + qspec_updater = qspec_updater_class(param_manager) + qspec_updater.visit_module(mod) + + if not args.build_model_only: + parameter_transforms = [] + + # Run pre-quantization if provided. 
+ args.model_path = param_manager.run_pre_quantize(args.model_path) + param_manager.init_torch_pname_to_bin_name(args.use_safetensors) + parameter_transforms.append(param_manager.create_parameter_transformation()) + + # Run pre-sharding if required + if args.num_shards > 1 and args.use_presharded_weights: + mod_shard = create_shard_transformation_func(param_manager, args, model_config) + mod_shard = transform_params_for_each_rank(mod_shard, num_shards=args.num_shards) + parameter_transforms.append(mod_shard) + + # Chain all parameter transforms together. This allows + # ReorderTransformFunc to be applied to the single + # resulting parameter transformation function. + mod_transform = functools.reduce(chain_parameter_transforms, parameter_transforms) + + seq = tvm.ir.transform.Sequential( + [ + relax.transform.CanonicalizeBindings(), + relax.transform.EliminateCommonSubexpr(), + relax.transform.DeadCodeElimination(), + # TODO(Lunderberg): Implement + # relax.transform.Simplify() that applies + # canonicalization, CSE, and DCE until + # convergence. + relax.transform.CanonicalizeBindings(), + relax.transform.EliminateCommonSubexpr(), + relax.transform.DeadCodeElimination(), + param_manager.optimize_transform_param_order(), + ], + name="SimplifyModTransform", + ) + + mod_transform = seq(mod_transform) + + params = utils.convert_weights(mod_transform, param_manager, params, args) + + if args.num_shards > 1 and use_ft_quant: + preprocessed = [] + weight_preprocess_func = tvm.get_global_func("cutlass.ft_preprocess_weight") + is_int4 = args.quantization.name in ["q4f16_ft", "q4f16_ft_group"] + sm = get_cuda_sm_version() + + for p in params: + if p.dtype == "int8": + preprocessed.append(weight_preprocess_func(p, sm, is_int4)) + else: + preprocessed.append(p) + + params = preprocessed + + utils.save_params( + params, args.artifact_path, args.num_shards if args.use_presharded_weights else 1 + ) + + if args.model_category != "minigpt": + utils.copy_tokenizer(args) + if args.model_category == "rwkv" or args.model_category == "rwkv_world": + # TODO: refactor config into model definition + dump_mlc_chat_config( + args, + vocab_size=config["vocab_size"], + max_window_size=model_config.max_sequence_length, + max_gen_len=model_config.max_sequence_length, + top_p=0.6, + temperature=1.2, + repetition_penalty=0.996, + rwkv_world=True, + ) + elif args.model_category == "chatglm": + dump_mlc_chat_config( + args, + vocab_size=config["padded_vocab_size"], + max_window_size=model_config.max_sequence_length, + max_gen_len=model_config.max_sequence_length, + ) + else: + dump_mlc_chat_config( + args, + vocab_size=config["vocab_size"], + max_window_size=model_config.max_sequence_length, + max_gen_len=model_config.max_sequence_length, + ) + + if args.convert_weights_only: + exit(0) + + mod = mod_transform_before_build(mod, param_manager, args, model_config) + if args.num_shards > 1: + # We require a "create_sharding_info" function for all + # multi-GPU models, even if they are using pre-sharded + # weights. When using pre-sharded weights, the list of + # initialization-time transforms to apply is empty. + sharding_module = create_shard_info_func(param_manager, args, model_config) + mod.update(sharding_module) + + with open(cache_path, "wb") as outfile: + pickle.dump(mod, outfile) + print(f"Save a cached module to {cache_path}.") + else: + print( + f"Load cached module from {cache_path} and skip tracing. 
" + "You can use --use-cache=0 to retrace" + ) + with open(cache_path, "rb") as pkl: + mod = pickle.load(pkl) + if not args.reuse_lib: + build(mod, args) + else: + print(f"Reuse existing prebuilt lib {args.reuse_lib}...") + + +def build_model(args: BuildArgs) -> (Optional[str], Optional[str], Optional[str]): + r"""Builds/compiles a model. + + Parameters + ---------- + args : :class:`BuildArgs` + A dataclass of arguments for building models.mlc_llm/core.py + + Returns + ---------- + lib_path: Optional[str] + The path to the model library file. Return ``None`` if not applicable. + model_path: Optional[str] + The path to the folder of the model's parameters. Return ``None`` if not applicable. + chat_config_path: Optional[str] + The path to the chat config `.json` file. Return ``None`` if not applicable. + """ + # Convert BuildArgs to argparse.Namespace so that we can share the rest + # of the code with the command line workflow + build_args_as_dict = asdict(args) + build_args_namespace = argparse.Namespace(**build_args_as_dict) + args = _parse_args(build_args_namespace) + build_model_from_args(args) + + # Prepare output; some workflows may or may not have the paths to return + lib_path = args.lib_path if hasattr(args, "lib_path") else None + model_path = args.params_path if hasattr(args, "params_path") else None + chat_config_path = args.chat_config_path if hasattr(args, "chat_config_path") else None + + return lib_path, model_path, chat_config_path diff --git a/mlc_llm/dispatch/__init__.py b/mlc_llm/dispatch/__init__.py new file mode 100644 index 0000000..234b60a --- /dev/null +++ b/mlc_llm/dispatch/__init__.py @@ -0,0 +1,2 @@ +from .dispatch_tir_operator import DispatchTIROperator +from .dispatch_tir_operator_adreno import DispatchTIROperatorAdreno diff --git a/mlc_llm/dispatch/dispatch_tir_operator.py b/mlc_llm/dispatch/dispatch_tir_operator.py new file mode 100644 index 0000000..21a7d27 --- /dev/null +++ b/mlc_llm/dispatch/dispatch_tir_operator.py @@ -0,0 +1,53 @@ +# pylint: disable=missing-docstring +import tvm +from tvm import IRModule + + +@tvm.transform.module_pass(opt_level=0, name="DispatchTIROperator") +class DispatchTIROperator: # pylint: disable=too-few-public-methods + def __init__(self, model: str): + # pylint: disable=import-outside-toplevel + if model == "llama": + from .llama import lookup + + elif model == "gpt_neox": + from .gpt_neox import lookup + + elif model == "gpt_bigcode": + lookup = None + + elif model == "minigpt": + lookup = None + + elif model == "rwkv": + lookup = None + + elif model == "rwkv_world": + lookup = None + + elif model == "gptj": + lookup = None + + elif model == "chatglm": + lookup = None + + else: + raise ValueError(f"Model {model} not supported") + self.lookup = lookup + + # pylint: enable=import-outside-toplevel + + def transform_module( + self, + mod: IRModule, + ctx: tvm.transform.PassContext, + ) -> IRModule: + if self.lookup is None: + return mod + for gv in mod.functions: + scheduled_func = self.lookup(mod[gv]) + if scheduled_func is not None: + mod[gv] = scheduled_func + print("- Dispatch to pre-scheduled op:", gv.name_hint) + + return mod diff --git a/mlc_llm/dispatch/dispatch_tir_operator_adreno.py b/mlc_llm/dispatch/dispatch_tir_operator_adreno.py new file mode 100644 index 0000000..937a158 --- /dev/null +++ b/mlc_llm/dispatch/dispatch_tir_operator_adreno.py @@ -0,0 +1,8356 @@ +import tvm +from tvm import IRModule +from tvm.script import tir as T + + +@T.prim_func(private=True) +def fused_decode4_matmul3( + lv1587: 
T.Buffer((T.int64(512), T.int64(4096)), "uint32"), + lv1588: T.Buffer((T.int64(128), T.int64(4096)), "float16"), + lv1583: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), + var_matmul_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(4096)), "float16" + ), +): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1587[v_i // T.int64(8), v_j], lv1588[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv1587[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv1588[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1583[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = ( + var_matmul_intermediate[v_i0, v_i1, v_i2] + + lv1583[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + ) + + +@T.prim_func(private=True) +def fused_decode4_matmul3_after( + lv1587: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), + lv1588: T.Buffer((T.int64(128), T.int64(4096)), "float16"), + lv1583: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), + var_matmul_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(4096)), "float16" + ), +): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_matmul_intermediate_local = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(32768)), "float16", scope="local" + ) + var_matmul_intermediate_local_batch = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(32768)), "float16", scope="local" + ) + lv1587_local = T.alloc_buffer( + (T.int64(512), T.int64(4096)), "uint32", scope="local" + ) + lv1588_local = T.alloc_buffer( + (T.int64(128), T.int64(4096)), "float16", scope="local" + ) + lv1583_shared = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(2048)), "float16", scope="shared" + ) + for i0_i1_i2_fused_0 in T.thread_binding(T.int64(32), thread="blockIdx.x"): + for i0_i1_i2_fused_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial( + T.int64(32768), + i0_i1_i2_fused_0 * T.int64(1024) + + i0_i1_i2_fused_1 * T.int64(32) + + ax2_y * T.int64(4) + + i0_i1_i2_fused_2_init, + ) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) + for k_0 in range(T.int64(2)): + for ax2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2_2 in T.vectorized(T.int64(8)): + with T.block("lv1583_shared"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( 
+ T.int64(4096), + k_0 * T.int64(2048) + + ax2_1 * T.int64(64) + + (ax2_y * T.int64(8) + ax2_2), + ) + v2k = T.axis.spatial( + T.int64(2048), + ( + ax2_1 * T.int64(64) + + ax2_y * T.int64(8) + + ax2_2 + ), + ) + T.reads(lv1583[v0, v1, v2]) + T.writes(lv1583_shared[v0, v1, v2k]) + lv1583_shared[v0, v1, v2k] = lv1583[v0, v1, v2] + for k_1 in range(T.int64(8)): + for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for ax1 in T.vectorized(T.int64(4)): + with T.block("matmul_init_local"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2k = T.axis.spatial( + T.int64(32768), + i0_i1_i2_fused_0 * T.int64(1024) + + i0_i1_i2_fused_1 * T.int64(32) + + ax2_y * T.int64(4) + + ax1, + ) + T.reads() + T.writes( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + ) + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] = T.float16(0) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv1588_local"): + v0 = T.axis.spatial( + T.int64(128), + k_0 * T.int64(64) + + (k_1 * T.int64(8) + ax2_y) + + ax0, + ) + v1 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads(lv1588[v0, v1]) + T.writes(lv1588_local[v0, v1]) + lv1588_local[v0, v1] = lv1588[v0, v1] + for k_2 in range(T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv1587_local"): + v0 = T.axis.spatial( + T.int64(512), + k_0 * T.int64(256) + + (k_1 * T.int64(8) + ax2_y) * T.int64(4) + + k_2 + + ax0, + ) + v1 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads(lv1587[v0, v1]) + T.writes(lv1587_local[v0, v1]) + lv1587_local[v0, v1] = lv1587[v0, v1] + for k_3 in range(T.int64(8)): + for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + i0_i1_i2_fused_2, + ) + v_i2k = T.axis.spatial( + T.int64(32768), + i0_i1_i2_fused_0 * T.int64(1024) + + i0_i1_i2_fused_1 * T.int64(32) + + ax2_y * T.int64(4) + + i0_i1_i2_fused_2, + ) + v_k = T.axis.reduce( + T.int64(4096), + k_0 * T.int64(2048) + + (k_1 * T.int64(8) + ax2_y) * T.int64(32) + + k_2 * T.int64(8) + + k_3, + ) + v_ki = T.axis.reduce( + T.int64(2048), + (k_1 * T.int64(8) + ax2_y) * T.int64(32) + + k_2 * T.int64(8) + + k_3, + ) + T.reads( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ], + lv1583_shared[v_i0, v_i1, v_ki], + lv1587_local[v_k // T.int64(8), v_i2], + ) + T.writes( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + ) + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] = var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + lv1583_shared[ + v_i0, v_i1, v_ki + ] * ( + ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv1587_local[ + v_k // T.int64(8), v_i2 + ], + T.Cast( + "uint32", + v_k % T.int64(8), + ) + * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) + ) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("multiple_scale"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2k = T.axis.spatial( + T.int64(32768), + i0_i1_i2_fused_0 * T.int64(1024) + + i0_i1_i2_fused_1 * T.int64(32) + + ax2_y * 
T.int64(4) + + ax1, + ) + v0 = T.axis.spatial( + T.int64(128), + k_0 * T.int64(64) + + (k_1 * T.int64(8) + ax2_y) + + ax0, + ) + v1 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads( + lv1588_local[v0, v1], + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ], + ) + T.writes( + var_matmul_intermediate_local[v_i0, v_i1, v_i2k] + ) + var_matmul_intermediate_local[v_i0, v_i1, v_i2k] = ( + var_matmul_intermediate_local[v_i0, v_i1, v_i2k] + + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + * lv1588_local[v0, v1] + ) + for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("var_matmul_intermediate_update"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(1024), + i0_i1_i2_fused_1 * T.int64(32) + + ax2_y * T.int64(4) + + ax2, + ) + v_i2k = T.axis.spatial( + T.int64(32768), + i0_i1_i2_fused_0 * T.int64(1024) + + i0_i1_i2_fused_1 * T.int64(32) + + ax2_y * T.int64(4) + + ax2, + ) + T.reads(var_matmul_intermediate_local[v0, v1, v_i2k]) + T.writes(lv1583_shared[v0, v1, v2]) + lv1583_shared[v0, v1, v2] = var_matmul_intermediate_local[ + v0, v1, v_i2k + ] + for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("reduction_sum"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(1024), + i0_i1_i2_fused_1 * T.int64(32) + + ax2_y * T.int64(4) + + ax2, + ) + T.where(ax2_y < T.int64(4)) + T.reads(lv1583_shared[v0, v1, v2]) + T.writes(lv1583_shared[v0, v1, v2]) + lv1583_shared[v0, v1, v2] = ( + lv1583_shared[v0, v1, v2] + + lv1583_shared[v0, v1, v2 + T.int64(16)] + ) + for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax2, + ) + v_i2k = T.axis.spatial( + T.int64(1024), + i0_i1_i2_fused_1 * T.int64(32) + + ax2_y * T.int64(4) + + ax2, + ) + T.where(ax2_y < T.int64(1)) + T.reads(lv1583_shared[v0, v1, v_i2k]) + T.writes(var_matmul_intermediate[v0, v1, v2]) + var_matmul_intermediate[v0, v1, v2] = ( + lv1583_shared[v0, v1, v_i2k] + + lv1583_shared[v0, v1, v_i2k + T.int64(4)] + + lv1583_shared[v0, v1, v_i2k + T.int64(8)] + + lv1583_shared[v0, v1, v_i2k + T.int64(12)] + ) + + +def sch_fused_decode4_matmul3(func): + sch = tvm.tir.Schedule(func) + b0 = sch.get_block(name="decode", func_name="main") + b1 = sch.get_block(name="matmul", func_name="main") + l2, l3, l4, l5 = sch.get_loops(block=b1) + l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) + v7, v8, v9 = sch.sample_perfect_tile( + loop=l6, n=3, max_innermost_factor=4, decision=[32, 64, 2] + ) + l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) + v13, v14, v15 = sch.sample_perfect_tile( + loop=l5, n=3, max_innermost_factor=8, decision=[128, 4, 8] + ) + l16, l17, l18 = sch.split( + loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True + ) + sch.reorder(l10, l11, l16, l17, l18, l12) + sch.bind(loop=l10, thread_axis="blockIdx.x") + sch.bind(loop=l11, thread_axis="threadIdx.x") + sch.compute_inline(block=b0) + b19 = sch.cache_write(block=b1, 
write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) + b20 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="local") + b21 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="local") + b22 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") + sch.compute_at(block=b22, loop=l11, preserve_unit_loops=True, index=-1) + v23 = sch.sample_categorical( + candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3 + ) + sch.annotate( + block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch", ann_val=v23 + ) + sch.compute_at(block=b20, loop=l17, preserve_unit_loops=True, index=-1) + sch.compute_at(block=b21, loop=l16, preserve_unit_loops=True, index=-1) + l24, l25, l26, l27, l28, l29 = sch.get_loops(block=b20) + sch.vectorize(loop=l29) + l30, l31, l32, l33, l34 = sch.get_loops(block=b21) + sch.vectorize(loop=l34) + l35, l36, l37, l38, l39 = sch.get_loops(block=b19) + sch.vectorize(loop=l39) + sch.vectorize(loop=l12) + b40 = sch.decompose_reduction(block=b1, loop=l16) + sch.enter_postproc() + sch.unannotate(block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch") + l41, l42, l43, l44, l45 = sch.get_loops(block=b22) + l46, l47, l48 = sch.split(loop=l45, factors=[None, 64, 8], preserve_unit_iters=True) + sch.vectorize(loop=l48) + sch.bind(loop=l47, thread_axis="threadIdx.x") + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func(private=True) +def fused_decode6_fused_matmul7_add1( + lv1623: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), + lv1624: T.Buffer((T.int64(344), T.int64(4096)), "float16"), + lv200: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), + lv198: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), + p_output0_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(4096)), "float16" + ), +): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16") + var_matmul_intermediate = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(4096)), "float16" + ) + for i, j in T.grid(T.int64(11008), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1623[v_i // T.int64(8), v_j], lv1624[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv1623[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv1624[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(11008)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv200[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = ( + var_matmul_intermediate[v_i0, v_i1, v_i2] + + lv200[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + lv198[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + ) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, 
v_ax1, v_ax2] = ( + lv198[v_ax0, v_ax1, v_ax2] + + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + ) + + +@T.prim_func(private=True) +def fused_decode6_fused_matmul7_add1_after( + lv1623: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), + lv1624: T.Buffer((T.int64(344), T.int64(4096)), "float16"), + lv200: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), + lv198: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), + p_output0_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(4096)), "float16" + ), +): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_matmul_intermediate_local = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(16384)), "float16", scope="local" + ) + var_matmul_intermediate_local_batch = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(16384)), "float16", scope="local" + ) + lv1623_local = T.alloc_buffer( + (T.int64(1376), T.int64(4096)), "uint32", scope="local" + ) + lv1624_local = T.alloc_buffer( + (T.int64(344), T.int64(4096)), "float16", scope="local" + ) + lv200_shared = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(2752)), "float16", scope="shared" + ) + for i0_i1_i2_fused_0 in T.thread_binding(T.int64(8), thread="blockIdx.x"): + for i0_i1_i2_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + for ax2_y in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial( + T.int64(16384), + i0_i1_i2_fused_0 * T.int64(2048) + + i0_i1_i2_fused_1 * T.int64(16) + + ax2_y * T.int64(4) + + i0_i1_i2_fused_2_init, + ) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) + for k_0 in range(T.int64(4)): + for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(1), T.int64(3)): + for ax2_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + for ax2_y in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for ax2_2 in T.vectorized(T.int64(2)): + with T.block("lv200_shared"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(11008), + k_0 * T.int64(2752) + + ( + ax2_0 * T.int64(1024) + + ax2_1 * T.int64(8) + + (ax2_y * T.int64(2) + ax2_2) + ), + ) + v2k = T.axis.spatial( + T.int64(2752), + ( + ax2_0 * T.int64(1024) + + ax2_1 * T.int64(8) + + (ax2_y * T.int64(2) + ax2_2) + ), + ) + T.where( + (ax2_0 * T.int64(128) + ax2_1) < T.int64(344) + ) + T.reads(lv200[v0, v1, v2]) + T.writes(lv200_shared[v0, v1, v2k]) + lv200_shared[v0, v1, v2k] = lv200[v0, v1, v2] + for k_1 in range(T.int64(22)): + for ax2_y in T.thread_binding(T.int64(4), thread="threadIdx.y"): + with T.block("lv1624_check"): + T.where((k_1 * T.int64(4) + ax2_y) < T.int64(86)) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("matmul_init_local"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2k = T.axis.spatial( + T.int64(16384), + i0_i1_i2_fused_0 * T.int64(2048) + + i0_i1_i2_fused_1 * T.int64(16) + + ax2_y * T.int64(4) + + ax1, + ) + T.reads() + T.writes( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + ) + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] = T.float16(0) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv1624_local"): + v0 = T.axis.spatial( + 
T.int64(344), + k_0 * T.int64(86) + + (k_1 * T.int64(4) + ax2_y) + + ax0, + ) + v1 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(512) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads(lv1624[v0, v1]) + T.writes(lv1624_local[v0, v1]) + lv1624_local[v0, v1] = lv1624[v0, v1] + for k_2 in range(T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv1623_local"): + v0 = T.axis.spatial( + T.int64(1376), + k_0 * T.int64(344) + + (k_1 * T.int64(4) + ax2_y) + * T.int64(4) + + k_2 + + ax0, + ) + v1 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(512) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads(lv1623[v0, v1]) + T.writes(lv1623_local[v0, v1]) + lv1623_local[v0, v1] = lv1623[v0, v1] + for k_3 in range(T.int64(8)): + for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial( + T.int64(1), T.int64(0) + ) + v_i1 = T.axis.spatial( + T.int64(1), T.int64(0) + ) + v_i2 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(512) + + i0_i1_i2_fused_1 * T.int64(4) + + i0_i1_i2_fused_2, + ) + v_i2k = T.axis.spatial( + T.int64(16384), + i0_i1_i2_fused_0 * T.int64(2048) + + i0_i1_i2_fused_1 * T.int64(16) + + ax2_y * T.int64(4) + + i0_i1_i2_fused_2, + ) + v_k = T.axis.reduce( + T.int64(11008), + k_0 * T.int64(2752) + + (k_1 * T.int64(4) + ax2_y) + * T.int64(32) + + k_2 * T.int64(8) + + k_3, + ) + v_ki = T.axis.reduce( + T.int64(2752), + (k_1 * T.int64(4) + ax2_y) * T.int64(32) + + k_2 * T.int64(8) + + k_3, + ) + T.reads( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ], + lv200_shared[v_i0, v_i1, v_ki], + lv1623_local[v_k // T.int64(8), v_i2], + ) + T.writes( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + ) + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] = var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + lv200_shared[ + v_i0, v_i1, v_ki + ] * ( + ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv1623_local[ + v_k // T.int64(8), + v_i2, + ], + T.Cast( + "uint32", + v_k % T.int64(8), + ) + * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) + ) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("multiple_scale"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2k = T.axis.spatial( + T.int64(16384), + i0_i1_i2_fused_0 * T.int64(2048) + + i0_i1_i2_fused_1 * T.int64(16) + + ax2_y * T.int64(4) + + ax1, + ) + v0 = T.axis.spatial( + T.int64(344), + k_0 * T.int64(86) + + (k_1 * T.int64(4) + ax2_y) + + ax0, + ) + v1 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(512) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads( + lv1624_local[v0, v1], + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ], + ) + T.writes( + var_matmul_intermediate_local[ + v_i0, v_i1, v_i2k + ] + ) + var_matmul_intermediate_local[ + v_i0, v_i1, v_i2k + ] = ( + var_matmul_intermediate_local[ + v_i0, v_i1, v_i2k + ] + + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + * lv1624_local[v0, v1] + ) + for ax2_y in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("var_matmul_intermediate_update"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(2048), + i0_i1_i2_fused_1 * T.int64(16) + + ax2_y * T.int64(4) + + ax2, + ) + v_i2k = T.axis.spatial( + 
T.int64(16384), + i0_i1_i2_fused_0 * T.int64(2048) + + i0_i1_i2_fused_1 * T.int64(16) + + ax2_y * T.int64(4) + + ax2, + ) + T.reads(var_matmul_intermediate_local[v0, v1, v_i2k]) + T.writes(lv200_shared[v0, v1, v2]) + lv200_shared[v0, v1, v2] = var_matmul_intermediate_local[ + v0, v1, v_i2k + ] + for ax2_y in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(512) + + i0_i1_i2_fused_1 * T.int64(4) + + ax2, + ) + v_i2k = T.axis.spatial( + T.int64(2048), + i0_i1_i2_fused_1 * T.int64(16) + + ax2_y * T.int64(4) + + ax2, + ) + T.where(ax2_y < T.int64(1)) + T.reads(lv200_shared[v0, v1, v_i2k]) + T.writes(p_output0_intermediate[v0, v1, v2]) + p_output0_intermediate[v0, v1, v2] = ( + lv198[v0, v1, v2] + + lv200_shared[v0, v1, v_i2k] + + lv200_shared[v0, v1, v_i2k + T.int64(4)] + + lv200_shared[v0, v1, v_i2k + T.int64(8)] + + lv200_shared[v0, v1, v_i2k + T.int64(12)] + ) + + +def sch_fused_decode6_fused_matmul7_add1(func): + sch = tvm.tir.Schedule(func) + b0 = sch.get_block(name="decode", func_name="main") + b1 = sch.get_block(name="matmul", func_name="main") + l2, l3, l4, l5 = sch.get_loops(block=b1) + l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) + v7, v8, v9 = sch.sample_perfect_tile( + loop=l6, n=3, max_innermost_factor=4, decision=[8, 256, 2] + ) + l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) + v13, v14, v15 = sch.sample_perfect_tile( + loop=l5, n=3, max_innermost_factor=8, decision=[344, 4, 8] + ) + l16, l17, l18 = sch.split( + loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True + ) + sch.reorder(l10, l11, l16, l17, l18, l12) + sch.bind(loop=l10, thread_axis="blockIdx.x") + sch.bind(loop=l11, thread_axis="threadIdx.x") + sch.compute_inline(block=b0) + b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) + b20 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") + sch.compute_at(block=b20, loop=l11, preserve_unit_loops=True, index=-1) + v21 = sch.sample_categorical( + candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3 + ) + sch.annotate( + block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch", ann_val=v21 + ) + l22, l23, l24, l25, l26 = sch.get_loops(block=b19) + sch.vectorize(loop=l26) + sch.vectorize(loop=l12) + b27 = sch.decompose_reduction(block=b1, loop=l16) + b28 = sch.get_block(name="T_add", func_name="main") + sch.reverse_compute_inline(block=b28) + sch.enter_postproc() + sch.unannotate(block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch") + l29, l30, l31, l32, l33 = sch.get_loops(block=b20) + l34, l35, l36 = sch.split( + loop=l33, factors=[None, 256, 8], preserve_unit_iters=True + ) + sch.vectorize(loop=l36) + sch.bind(loop=l35, thread_axis="threadIdx.x") + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func(private=True) +def fused_decode5_fused_matmul6_multiply1( + lv1617: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), + lv1618: T.Buffer((T.int64(128), T.int64(11008)), "float16"), + lv1622: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), + lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), + p_output0_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(11008)), "float16" 
+ ), +): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") + var_matmul_intermediate = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(11008)), "float16" + ) + for i, j in T.grid(T.int64(4096), T.int64(11008)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1617[v_i // T.int64(8), v_j], lv1618[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv1617[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv1618[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1622[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = ( + var_matmul_intermediate[v_i0, v_i1, v_i2] + + lv1622[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + lv4[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + ) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( + lv4[v_ax0, v_ax1, v_ax2] * var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + ) + + +@T.prim_func(private=True) +def fused_decode5_fused_matmul6_multiply1_after( + lv1617: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), + lv1618: T.Buffer((T.int64(128), T.int64(11008)), "float16"), + lv1622: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), + lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), + p_output0_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(11008)), "float16" + ), +): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_matmul_intermediate_local = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(22016)), "float16", scope="local" + ) + var_matmul_intermediate_local_batch = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(22016)), "float16", scope="local" + ) + lv1617_local = T.alloc_buffer( + (T.int64(512), T.int64(11008)), "uint32", scope="local" + ) + lv1618_local = T.alloc_buffer( + (T.int64(128), T.int64(11008)), "float16", scope="local" + ) + lv1622_shared = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(1024)), "float16", scope="shared" + ) + for i0_i1_i2_fused_0 in T.thread_binding(T.int64(43), thread="blockIdx.x"): + for i0_i1_i2_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial( + T.int64(22016), + i0_i1_i2_fused_0 * T.int64(512) + + i0_i1_i2_fused_1 * T.int64(8) + + ax2_y * T.int64(4) + + i0_i1_i2_fused_2_init, + ) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) + 
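+                # Reduction over the 4096-wide k axis, tiled as k_0(4) x k_1(16) x threadIdx.y(2)
+                # x k_2(4) x k_3(8). Each k_0 step stages a 1024-element slice of lv1622 into
+                # shared memory, then every thread accumulates a 4-wide vector of partial dot
+                # products, dequantizing lv1617 in-register ((word >> 4*(k % 8)) & 0xF, minus
+                # the zero point 7) and applying the per-group scale lv1618 once per group of
+                # 32 in the "multiple_scale" block. The two threadIdx.y partials are merged
+                # after the k loop.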
for k_0 in range(T.int64(4)): + for ax2_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2_2 in T.vectorized(T.int64(8)): + with T.block("lv1622_shared"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(4096), + k_0 * T.int64(1024) + + ax2_y * T.int64(512) + + ax2_1 * T.int64(8) + + ax2_2, + ) + v2k = T.axis.spatial( + T.int64(1024), + ( + ax2_y * T.int64(512) + + ax2_1 * T.int64(8) + + ax2_2 + ), + ) + T.reads(lv1622[v0, v1, v2]) + T.writes(lv1622_shared[v0, v1, v2k]) + lv1622_shared[v0, v1, v2k] = lv1622[v0, v1, v2] + for k_1 in range(T.int64(16)): + for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax1 in T.vectorized(T.int64(4)): + with T.block("matmul_init_local"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2k = T.axis.spatial( + T.int64(22016), + i0_i1_i2_fused_0 * T.int64(512) + + i0_i1_i2_fused_1 * T.int64(8) + + ax2_y * T.int64(4) + + ax1, + ) + T.reads() + T.writes( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + ) + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] = T.float16(0) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv1618_local"): + v0 = T.axis.spatial( + T.int64(128), + k_0 * T.int64(32) + + (k_1 * T.int64(2) + ax2_y) + + ax0, + ) + v1 = T.axis.spatial( + T.int64(11008), + i0_i1_i2_fused_0 * T.int64(256) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads(lv1618[v0, v1]) + T.writes(lv1618_local[v0, v1]) + lv1618_local[v0, v1] = lv1618[v0, v1] + for k_2 in range(T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv1617_local"): + v0 = T.axis.spatial( + T.int64(512), + k_0 * T.int64(128) + + (k_1 * T.int64(2) + ax2_y) * T.int64(4) + + k_2 + + ax0, + ) + v1 = T.axis.spatial( + T.int64(11008), + i0_i1_i2_fused_0 * T.int64(256) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads(lv1617[v0, v1]) + T.writes(lv1617_local[v0, v1]) + lv1617_local[v0, v1] = lv1617[v0, v1] + for k_3 in range(T.int64(8)): + for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial( + T.int64(11008), + i0_i1_i2_fused_0 * T.int64(256) + + i0_i1_i2_fused_1 * T.int64(4) + + i0_i1_i2_fused_2, + ) + v_i2k = T.axis.spatial( + T.int64(22016), + i0_i1_i2_fused_0 * T.int64(512) + + i0_i1_i2_fused_1 * T.int64(8) + + ax2_y * T.int64(4) + + i0_i1_i2_fused_2, + ) + v_k = T.axis.reduce( + T.int64(4096), + k_0 * T.int64(1024) + + (k_1 * T.int64(2) + ax2_y) * T.int64(32) + + k_2 * T.int64(8) + + k_3, + ) + v_ki = T.axis.reduce( + T.int64(1024), + (k_1 * T.int64(2) + ax2_y) * T.int64(32) + + k_2 * T.int64(8) + + k_3, + ) + T.reads( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ], + lv1622_shared[v_i0, v_i1, v_ki], + lv1617_local[v_k // T.int64(8), v_i2], + ) + T.writes( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + ) + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] = var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + lv1622_shared[ + v_i0, v_i1, v_ki + ] * ( + ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv1617_local[ + v_k // T.int64(8), v_i2 + ], + T.Cast( + "uint32", + v_k % T.int64(8), + ) + * T.uint32(4), + ), + T.uint32(15), + ), + 
) + - T.float16(7) + ) + ) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("multiple_scale"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2k = T.axis.spatial( + T.int64(22016), + i0_i1_i2_fused_0 * T.int64(512) + + i0_i1_i2_fused_1 * T.int64(8) + + ax2_y * T.int64(4) + + ax1, + ) + v0 = T.axis.spatial( + T.int64(128), + k_0 * T.int64(32) + + (k_1 * T.int64(2) + ax2_y) + + ax0, + ) + v1 = T.axis.spatial( + T.int64(11008), + i0_i1_i2_fused_0 * T.int64(256) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads( + lv1618_local[v0, v1], + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ], + ) + T.writes( + var_matmul_intermediate_local[v_i0, v_i1, v_i2k] + ) + var_matmul_intermediate_local[v_i0, v_i1, v_i2k] = ( + var_matmul_intermediate_local[v_i0, v_i1, v_i2k] + + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + * lv1618_local[v0, v1] + ) + for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("var_matmul_intermediate_update"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(512), + i0_i1_i2_fused_1 * T.int64(8) + + ax2_y * T.int64(4) + + ax2, + ) + v_i2k = T.axis.spatial( + T.int64(22016), + i0_i1_i2_fused_0 * T.int64(512) + + i0_i1_i2_fused_1 * T.int64(8) + + ax2_y * T.int64(4) + + ax2, + ) + T.reads(var_matmul_intermediate_local[v0, v1, v_i2k]) + T.writes(lv1622_shared[v0, v1, v2]) + lv1622_shared[v0, v1, v2] = var_matmul_intermediate_local[ + v0, v1, v_i2k + ] + for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(11008), + i0_i1_i2_fused_0 * T.int64(256) + + i0_i1_i2_fused_1 * T.int64(4) + + ax2, + ) + v_i2k = T.axis.spatial( + T.int64(512), + i0_i1_i2_fused_1 * T.int64(8) + + ax2_y * T.int64(4) + + ax2, + ) + T.where(ax2_y < T.int64(1)) + T.reads(lv1622_shared[v0, v1, v_i2k], lv4[v0, v1, v2]) + T.writes(p_output0_intermediate[v0, v1, v2]) + p_output0_intermediate[v0, v1, v2] = lv4[v0, v1, v2] * ( + lv1622_shared[v0, v1, v_i2k] + + lv1622_shared[v0, v1, v_i2k + T.int64(4)] + ) + + +def sch_fused_decode5_fused_matmul6_multiply1(func): + sch = tvm.tir.Schedule(func) + b0 = sch.get_block(name="decode", func_name="main") + b1 = sch.get_block(name="matmul", func_name="main") + l2, l3, l4, l5 = sch.get_loops(block=b1) + l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) + v7, v8, v9 = sch.sample_perfect_tile( + loop=l6, n=3, max_innermost_factor=4, decision=[43, 64, 4] + ) + l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) + v13, v14, v15 = sch.sample_perfect_tile( + loop=l5, n=3, max_innermost_factor=8, decision=[128, 4, 8] + ) + l16, l17, l18 = sch.split( + loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True + ) + sch.reorder(l10, l11, l16, l17, l18, l12) + sch.bind(loop=l10, thread_axis="blockIdx.x") + sch.bind(loop=l11, thread_axis="threadIdx.x") + sch.compute_inline(block=b0) + b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) + b20 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="local") + b21 = sch.cache_read(block=b1, read_buffer_index=2, 
storage_scope="local") + b22 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") + sch.compute_at(block=b22, loop=l11, preserve_unit_loops=True, index=-1) + v23 = sch.sample_categorical( + candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3 + ) + sch.annotate( + block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch", ann_val=v23 + ) + sch.compute_at(block=b20, loop=l17, preserve_unit_loops=True, index=-1) + sch.compute_at(block=b21, loop=l16, preserve_unit_loops=True, index=-1) + l24, l25, l26, l27, l28, l29 = sch.get_loops(block=b20) + sch.vectorize(loop=l29) + l30, l31, l32, l33, l34 = sch.get_loops(block=b21) + sch.vectorize(loop=l34) + l35, l36, l37, l38, l39 = sch.get_loops(block=b19) + sch.vectorize(loop=l39) + sch.vectorize(loop=l12) + b40 = sch.decompose_reduction(block=b1, loop=l16) + b41 = sch.get_block(name="T_multiply", func_name="main") + sch.reverse_compute_inline(block=b41) + sch.enter_postproc() + sch.unannotate(block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch") + l42, l43, l44, l45, l46 = sch.get_loops(block=b22) + l47, l48, l49 = sch.split(loop=l46, factors=[None, 64, 8], preserve_unit_iters=True) + sch.vectorize(loop=l49) + sch.bind(loop=l48, thread_axis="threadIdx.x") + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func(private=True) +def fused_fused_decode9_matmul7( + lv19: T.Buffer((T.int64(512), T.int64(22016)), "uint32"), + lv20: T.Buffer((T.int64(128), T.int64(22016)), "float16"), + lv1654: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), + var_matmul_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(22016)), "float16" + ), +): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + p_output0_intermediate = T.alloc_buffer((T.int64(4096), T.int64(22016)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(22016)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv19[v_i // T.int64(8), v_j], lv20[v_i // T.int64(32), v_j]) + T.writes(p_output0_intermediate[v_i, v_j]) + p_output0_intermediate[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv19[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv20[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(22016), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1654[v_i0, v_i1, v_k], p_output0_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = ( + var_matmul_intermediate[v_i0, v_i1, v_i2] + + lv1654[v_i0, v_i1, v_k] * p_output0_intermediate[v_k, v_i2] + ) + + +@T.prim_func(private=True) +def fused_fused_decode9_matmul7_after( + lv19: T.Buffer((T.int64(512), T.int64(22016)), "uint32"), + lv20: T.Buffer((T.int64(128), T.int64(22016)), "float16"), + lv1654: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), + var_matmul_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(22016)), "float16" + ), +): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_matmul_intermediate_local = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(352256)), "float16", scope="local" + ) + var_matmul_intermediate_local_batch = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(352256)), 
"float16", scope="local" + ) + lv19_local = T.alloc_buffer((T.int64(512), T.int64(22016)), "uint32", scope="local") + lv20_local = T.alloc_buffer( + (T.int64(128), T.int64(22016)), "float16", scope="local" + ) + lv1654_shared = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="shared" + ) + for i0_i1_i2_fused_0 in T.thread_binding(T.int64(172), thread="blockIdx.x"): + for i0_i1_i2_fused_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax2_y in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial( + T.int64(352256), + i0_i1_i2_fused_0 * T.int64(2048) + + i0_i1_i2_fused_1 * T.int64(64) + + ax2_y * T.int64(4) + + i0_i1_i2_fused_2_init + ) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) + for k_0 in range(T.int64(1)): + for ax2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax2_y in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2_2 in T.vectorized(T.int64(8)): + with T.block("lv1654_shared"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(4096), + k_0 * T.int64(4096) + + ax2_y * T.int64(256) + + ax2_1 * T.int64(8) + + ax2_2, + ) + v2k = T.axis.spatial( + T.int64(4096), + ( + ax2_y * T.int64(256) + + ax2_1 * T.int64(8) + + ax2_2 + ), + ) + T.reads(lv1654[v0, v1, v2]) + T.writes(lv1654_shared[v0, v1, v2k]) + lv1654_shared[v0, v1, v2k] = lv1654[v0, v1, v2] + for k_1 in range(T.int64(8)): + for ax2_y in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax1 in T.vectorized(T.int64(4)): + with T.block("matmul_init_local"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2k = T.axis.spatial( + T.int64(352256), + i0_i1_i2_fused_0 * T.int64(2048) + + i0_i1_i2_fused_1 * T.int64(64) + + ax2_y * T.int64(4) + + ax1, + ) + T.reads() + T.writes( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + ) + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] = T.float16(0) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv20_local"): + v0 = T.axis.spatial( + T.int64(128), + k_0 * T.int64(128) + + (k_1 * T.int64(16) + ax2_y) + + ax0, + ) + v1 = T.axis.spatial( + T.int64(22016), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads(lv20[v0, v1]) + T.writes(lv20_local[v0, v1]) + lv20_local[v0, v1] = lv20[v0, v1] + for k_2 in range(T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv19_local"): + v0 = T.axis.spatial( + T.int64(512), + k_0 * T.int64(512) + + (k_1 * T.int64(16) + ax2_y) * T.int64(4) + + k_2 + + ax0, + ) + v1 = T.axis.spatial( + T.int64(22016), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads(lv19[v0, v1]) + T.writes(lv19_local[v0, v1]) + lv19_local[v0, v1] = lv19[v0, v1] + for k_3 in range(T.int64(8)): + for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial( + T.int64(22016), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + i0_i1_i2_fused_2, + ) + 
v_i2k = T.axis.spatial( + T.int64(352256), + i0_i1_i2_fused_0 * T.int64(2048) + + i0_i1_i2_fused_1 * T.int64(64) + + ax2_y * T.int64(4) + + i0_i1_i2_fused_2, + ) + v_k = T.axis.reduce( + T.int64(4096), + k_0 * T.int64(4096) + + (k_1 * T.int64(16) + ax2_y) * T.int64(32) + + k_2 * T.int64(8) + + k_3, + ) + v_ki = T.axis.reduce( + T.int64(4096), + (k_1 * T.int64(16) + ax2_y) * T.int64(32) + + k_2 * T.int64(8) + + k_3, + ) + T.reads( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ], + lv1654_shared[v_i0, v_i1, v_ki], + lv19_local[v_k // T.int64(8), v_i2], + ) + T.writes( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + ) + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] = var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + lv1654_shared[ + v_i0, v_i1, v_ki + ] * ( + ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv19_local[ + v_k // T.int64(8), v_i2 + ], + T.Cast( + "uint32", + v_k % T.int64(8), + ) + * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) + ) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("multiple_scale"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2k = T.axis.spatial( + T.int64(352256), + i0_i1_i2_fused_0 * T.int64(2048) + + i0_i1_i2_fused_1 * T.int64(64) + + ax2_y * T.int64(4) + + ax1, + ) + v0 = T.axis.spatial( + T.int64(128), + k_0 * T.int64(128) + + (k_1 * T.int64(16) + ax2_y) + + ax0, + ) + v1 = T.axis.spatial( + T.int64(22016), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads( + lv20_local[v0, v1], + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ], + ) + T.writes( + var_matmul_intermediate_local[v_i0, v_i1, v_i2k] + ) + var_matmul_intermediate_local[v_i0, v_i1, v_i2k] = ( + var_matmul_intermediate_local[v_i0, v_i1, v_i2k] + + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + * lv20_local[v0, v1] + ) + for ax2_y in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("var_matmul_intermediate_update"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(2048), + i0_i1_i2_fused_1 * T.int64(64) + + ax2_y * T.int64(4) + + ax2, + ) + v_i2k = T.axis.spatial( + T.int64(352256), + i0_i1_i2_fused_0 * T.int64(2048) + + i0_i1_i2_fused_1 * T.int64(64) + + ax2_y * T.int64(4) + + ax2, + ) + T.reads(var_matmul_intermediate_local[v0, v1, v_i2k]) + T.writes(lv1654_shared[v0, v1, v2]) + lv1654_shared[v0, v1, v2] = var_matmul_intermediate_local[ + v0, v1, v_i2k + ] + for ax2_y in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("reduction_1"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v_i2k = T.axis.spatial( + T.int64(2048), + i0_i1_i2_fused_1 * T.int64(64) + + ax2_y * T.int64(4) + + ax2, + ) + T.where(ax2_y < T.int64(8)) + T.reads(lv1654_shared[v0, v1, v_i2k]) + T.writes(lv1654_shared[v0, v1, v_i2k]) + lv1654_shared[v0, v1, v_i2k] = ( + lv1654_shared[v0, v1, v_i2k] + + lv1654_shared[v0, v1, v_i2k + T.int64(32)] + ) + for ax2_y in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("reduction_2"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v_i2k = T.axis.spatial( + T.int64(2048), + i0_i1_i2_fused_1 * T.int64(64) 
+ + ax2_y * T.int64(4) + + ax2, + ) + T.where(ax2_y < T.int64(4)) + T.reads(lv1654_shared[v0, v1, v_i2k]) + T.writes(lv1654_shared[v0, v1, v_i2k]) + lv1654_shared[v0, v1, v_i2k] = ( + lv1654_shared[v0, v1, v_i2k] + + lv1654_shared[v0, v1, v_i2k + T.int64(16)] + ) + for ax2_y in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(22016), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax2, + ) + v_i2k = T.axis.spatial( + T.int64(2048), + i0_i1_i2_fused_1 * T.int64(64) + + ax2_y * T.int64(4) + + ax2, + ) + T.where(ax2_y < T.int64(1)) + T.reads(lv1654_shared[v0, v1, v_i2k]) + T.writes(var_matmul_intermediate[v0, v1, v2]) + var_matmul_intermediate[v0, v1, v2] = ( + lv1654_shared[v0, v1, v_i2k] + + lv1654_shared[v0, v1, v_i2k + T.int64(4)] + + lv1654_shared[v0, v1, v_i2k + T.int64(8)] + + lv1654_shared[v0, v1, v_i2k + T.int64(12)] + ) + + +@T.prim_func(private=True) +def fused_fused_decode7_matmul4( + lv3: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), + lv4: T.Buffer((T.int64(128), T.int64(12288)), "float16"), + lv1615: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), + var_matmul_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(12288)), "float16" + ), +): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + p_output0_intermediate = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(12288)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv3[v_i // T.int64(8), v_j], lv4[v_i // T.int64(32), v_j]) + T.writes(p_output0_intermediate[v_i, v_j]) + p_output0_intermediate[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv3[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv4[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(12288), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1615[v_i0, v_i1, v_k], p_output0_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = ( + var_matmul_intermediate[v_i0, v_i1, v_i2] + + lv1615[v_i0, v_i1, v_k] * p_output0_intermediate[v_k, v_i2] + ) + + +@T.prim_func(private=True) +def fused_fused_decode7_matmul4_after( + lv3: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), + lv4: T.Buffer((T.int64(128), T.int64(12288)), "float16"), + lv1615: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), + var_matmul_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(12288)), "float16" + ), +): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_matmul_intermediate_local = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(24576)), "float16", scope="local" + ) + var_matmul_intermediate_local_batch = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(24576)), "float16", scope="local" + ) + lv3_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32", scope="local") + lv4_local = T.alloc_buffer((T.int64(128), T.int64(12288)), "float16", scope="local") + lv1615_shared = T.alloc_buffer( + (T.int64(1), 
T.int64(1), T.int64(1024)), "float16", scope="shared" + ) + for i0_i1_i2_fused_0 in T.thread_binding(T.int64(48), thread="blockIdx.x"): + for i0_i1_i2_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial( + T.int64(24576), + i0_i1_i2_fused_0 * T.int64(512) + + i0_i1_i2_fused_1 * T.int64(8) + + ax2_y * T.int64(4) + + i0_i1_i2_fused_2_init, + ) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) + for k_0 in range(T.int64(4)): + for ax2_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2_2 in T.vectorized(T.int64(8)): + with T.block("lv1615_shared"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(4096), + k_0 * T.int64(1024) + + ax2_y * T.int64(512) + + ax2_1 * T.int64(8) + + ax2_2, + ) + v2k = T.axis.spatial( + T.int64(1024), + ( + ax2_y * T.int64(512) + + ax2_1 * T.int64(8) + + ax2_2 + ), + ) + T.reads(lv1615[v0, v1, v2]) + T.writes(lv1615_shared[v0, v1, v2k]) + lv1615_shared[v0, v1, v2k] = lv1615[v0, v1, v2] + for k_1 in range(T.int64(16)): + for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax1 in T.vectorized(T.int64(4)): + with T.block("matmul_init_local"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2k = T.axis.spatial( + T.int64(24576), + i0_i1_i2_fused_0 * T.int64(512) + + i0_i1_i2_fused_1 * T.int64(8) + + ax2_y * T.int64(4) + + ax1, + ) + T.reads() + T.writes( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + ) + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] = T.float16(0) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv4_local"): + v0 = T.axis.spatial( + T.int64(128), + k_0 * T.int64(32) + + (k_1 * T.int64(2) + ax2_y) + + ax0, + ) + v1 = T.axis.spatial( + T.int64(12288), + i0_i1_i2_fused_0 * T.int64(256) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads(lv4[v0, v1]) + T.writes(lv4_local[v0, v1]) + lv4_local[v0, v1] = lv4[v0, v1] + for k_2 in range(T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv3_local"): + v0 = T.axis.spatial( + T.int64(512), + k_0 * T.int64(128) + + (k_1 * T.int64(2) + ax2_y) * T.int64(4) + + k_2 + + ax0, + ) + v1 = T.axis.spatial( + T.int64(12288), + i0_i1_i2_fused_0 * T.int64(256) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads(lv3[v0, v1]) + T.writes(lv3_local[v0, v1]) + lv3_local[v0, v1] = lv3[v0, v1] + for k_3 in range(T.int64(8)): + for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial( + T.int64(12288), + i0_i1_i2_fused_0 * T.int64(256) + + i0_i1_i2_fused_1 * T.int64(4) + + i0_i1_i2_fused_2, + ) + v_i2k = T.axis.spatial( + T.int64(24576), + i0_i1_i2_fused_0 * T.int64(512) + + i0_i1_i2_fused_1 * T.int64(8) + + ax2_y * T.int64(4) + + i0_i1_i2_fused_2, + ) + v_k = T.axis.reduce( + T.int64(4096), + k_0 * T.int64(1024) + + (k_1 * T.int64(2) + ax2_y) * T.int64(32) + + k_2 * 
T.int64(8) + + k_3, + ) + v_ki = T.axis.reduce( + T.int64(1024), + (k_1 * T.int64(2) + ax2_y) * T.int64(32) + + k_2 * T.int64(8) + + k_3, + ) + T.reads( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ], + lv1615_shared[v_i0, v_i1, v_ki], + lv3_local[v_k // T.int64(8), v_i2], + ) + T.writes( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + ) + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] = var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + lv1615_shared[ + v_i0, v_i1, v_ki + ] * ( + ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv3_local[ + v_k // T.int64(8), v_i2 + ], + T.Cast( + "uint32", + v_k % T.int64(8), + ) + * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) + ) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("multiple_scale"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2k = T.axis.spatial( + T.int64(24576), + i0_i1_i2_fused_0 * T.int64(512) + + i0_i1_i2_fused_1 * T.int64(8) + + ax2_y * T.int64(4) + + ax1, + ) + v0 = T.axis.spatial( + T.int64(128), + k_0 * T.int64(32) + + (k_1 * T.int64(2) + ax2_y) + + ax0, + ) + v1 = T.axis.spatial( + T.int64(12288), + i0_i1_i2_fused_0 * T.int64(256) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads( + lv4_local[v0, v1], + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ], + ) + T.writes( + var_matmul_intermediate_local[v_i0, v_i1, v_i2k] + ) + var_matmul_intermediate_local[v_i0, v_i1, v_i2k] = ( + var_matmul_intermediate_local[v_i0, v_i1, v_i2k] + + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + * lv4_local[v0, v1] + ) + for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("var_matmul_intermediate_update"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(512), + i0_i1_i2_fused_1 * T.int64(8) + + ax2_y * T.int64(4) + + ax2, + ) + v_i2k = T.axis.spatial( + T.int64(24576), + i0_i1_i2_fused_0 * T.int64(512) + + i0_i1_i2_fused_1 * T.int64(8) + + ax2_y * T.int64(4) + + ax2, + ) + T.reads(var_matmul_intermediate_local[v0, v1, v_i2k]) + T.writes(lv1615_shared[v0, v1, v2]) + lv1615_shared[v0, v1, v2] = var_matmul_intermediate_local[ + v0, v1, v_i2k + ] + for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(12288), + i0_i1_i2_fused_0 * T.int64(256) + + i0_i1_i2_fused_1 * T.int64(4) + + ax2, + ) + v_i2k = T.axis.spatial( + T.int64(512), + i0_i1_i2_fused_1 * T.int64(8) + + ax2_y * T.int64(4) + + ax2, + ) + T.where(ax2_y < T.int64(1)) + T.reads(lv1615_shared[v0, v1, v_i2k]) + T.writes(var_matmul_intermediate[v0, v1, v2]) + var_matmul_intermediate[v0, v1, v2] = ( + lv1615_shared[v0, v1, v_i2k] + + lv1615_shared[v0, v1, v_i2k + T.int64(4)] + ) + + +@T.prim_func(private=True) +def fused_decode5_fused_matmul6_silu1( + lv1611: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), + lv1612: T.Buffer((T.int64(128), T.int64(11008)), "float16"), + lv1622: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), + p_output0_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(11008)), "float16" + ), +): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + 
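+    # Reference (unscheduled) compute for decode + matmul + SiLU: dequantize the
+    # 4-bit weights (extract one nibble per k, subtract the zero point 7, multiply
+    # by the per-group scale lv1612), contract against lv1622, then gate the result
+    # with x * sigmoid(x). As an illustrative example (values not taken from any
+    # real checkpoint), a stored nibble of 9 decodes to (9 - 7) * scale = 2 * scale.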
var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") + var_matmul_intermediate = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(11008)), "float16" + ) + compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(11008)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1611[v_i // T.int64(8), v_j], lv1612[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv1611[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv1612[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1622[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = ( + var_matmul_intermediate[v_i0, v_i1, v_i2] + + lv1622[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + ) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) + T.writes(compute[v_i0, v_i1, v_i2]) + compute[v_i0, v_i1, v_i2] = T.sigmoid( + var_matmul_intermediate[v_i0, v_i1, v_i2] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + var_matmul_intermediate[v_ax0, v_ax1, v_ax2], + compute[v_ax0, v_ax1, v_ax2], + ) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + * compute[v_ax0, v_ax1, v_ax2] + ) + + +@T.prim_func(private=True) +def fused_decode5_fused_matmul6_silu1_after( + lv1611: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), + lv1612: T.Buffer((T.int64(128), T.int64(11008)), "float16"), + lv1622: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), + p_output0_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(11008)), "float16" + ), +): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_matmul_intermediate_local = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(22016)), "float16", scope="local" + ) + var_matmul_intermediate_local_batch = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(22016)), "float16", scope="local" + ) + lv1611_local = T.alloc_buffer( + (T.int64(512), T.int64(11008)), "uint32", scope="local" + ) + lv1612_local = T.alloc_buffer( + (T.int64(128), T.int64(11008)), "float16", scope="local" + ) + lv1622_shared = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(1024)), "float16", scope="shared" + ) + for i0_i1_i2_fused_0 in T.thread_binding(T.int64(43), thread="blockIdx.x"): + for i0_i1_i2_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial( + 
T.int64(22016), + i0_i1_i2_fused_0 * T.int64(512) + + i0_i1_i2_fused_1 * T.int64(8) + + ax2_y * T.int64(4) + + i0_i1_i2_fused_2_init, + ) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) + for k_0 in range(T.int64(4)): + for ax2_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2_2 in T.vectorized(T.int64(8)): + with T.block("lv1622_shared"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(4096), + k_0 * T.int64(1024) + + ax2_y * T.int64(512) + + ax2_1 * T.int64(8) + + ax2_2, + ) + v2k = T.axis.spatial( + T.int64(1024), + ( + ax2_y * T.int64(512) + + ax2_1 * T.int64(8) + + ax2_2 + ), + ) + T.reads(lv1622[v0, v1, v2]) + T.writes(lv1622_shared[v0, v1, v2k]) + lv1622_shared[v0, v1, v2k] = lv1622[v0, v1, v2] + for k_1 in range(T.int64(16)): + for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax1 in T.vectorized(T.int64(4)): + with T.block("matmul_init_local"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2k = T.axis.spatial( + T.int64(22016), + i0_i1_i2_fused_0 * T.int64(512) + + i0_i1_i2_fused_1 * T.int64(8) + + ax2_y * T.int64(4) + + ax1, + ) + T.reads() + T.writes( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + ) + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] = T.float16(0) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv1612_local"): + v0 = T.axis.spatial( + T.int64(128), + k_0 * T.int64(32) + + (k_1 * T.int64(2) + ax2_y) + + ax0, + ) + v1 = T.axis.spatial( + T.int64(11008), + i0_i1_i2_fused_0 * T.int64(256) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads(lv1612[v0, v1]) + T.writes(lv1612_local[v0, v1]) + lv1612_local[v0, v1] = lv1612[v0, v1] + for k_2 in range(T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv1611_local"): + v0 = T.axis.spatial( + T.int64(512), + k_0 * T.int64(128) + + (k_1 * T.int64(2) + ax2_y) * T.int64(4) + + k_2 + + ax0, + ) + v1 = T.axis.spatial( + T.int64(11008), + i0_i1_i2_fused_0 * T.int64(256) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads(lv1611[v0, v1]) + T.writes(lv1611_local[v0, v1]) + lv1611_local[v0, v1] = lv1611[v0, v1] + for k_3 in range(T.int64(8)): + for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial( + T.int64(11008), + i0_i1_i2_fused_0 * T.int64(256) + + i0_i1_i2_fused_1 * T.int64(4) + + i0_i1_i2_fused_2, + ) + v_i2k = T.axis.spatial( + T.int64(22016), + i0_i1_i2_fused_0 * T.int64(512) + + i0_i1_i2_fused_1 * T.int64(8) + + ax2_y * T.int64(4) + + i0_i1_i2_fused_2, + ) + v_k = T.axis.reduce( + T.int64(4096), + k_0 * T.int64(1024) + + (k_1 * T.int64(2) + ax2_y) * T.int64(32) + + k_2 * T.int64(8) + + k_3, + ) + v_ki = T.axis.reduce( + T.int64(1024), + (k_1 * T.int64(2) + ax2_y) * T.int64(32) + + k_2 * T.int64(8) + + k_3, + ) + T.reads( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ], + lv1622_shared[v_i0, v_i1, v_ki], + lv1611_local[v_k // T.int64(8), v_i2], + ) + T.writes( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + ) + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] = 
var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + lv1622_shared[ + v_i0, v_i1, v_ki + ] * ( + ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv1611_local[ + v_k // T.int64(8), v_i2 + ], + T.Cast( + "uint32", + v_k % T.int64(8), + ) + * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) + ) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("multiple_scale"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2k = T.axis.spatial( + T.int64(22016), + i0_i1_i2_fused_0 * T.int64(512) + + i0_i1_i2_fused_1 * T.int64(8) + + ax2_y * T.int64(4) + + ax1, + ) + v0 = T.axis.spatial( + T.int64(128), + k_0 * T.int64(32) + + (k_1 * T.int64(2) + ax2_y) + + ax0, + ) + v1 = T.axis.spatial( + T.int64(11008), + i0_i1_i2_fused_0 * T.int64(256) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads( + lv1612_local[v0, v1], + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ], + ) + T.writes( + var_matmul_intermediate_local[v_i0, v_i1, v_i2k] + ) + var_matmul_intermediate_local[v_i0, v_i1, v_i2k] = ( + var_matmul_intermediate_local[v_i0, v_i1, v_i2k] + + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + * lv1612_local[v0, v1] + ) + for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("var_matmul_intermediate_update"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(512), + i0_i1_i2_fused_1 * T.int64(8) + + ax2_y * T.int64(4) + + ax2, + ) + v_i2k = T.axis.spatial( + T.int64(22016), + i0_i1_i2_fused_0 * T.int64(512) + + i0_i1_i2_fused_1 * T.int64(8) + + ax2_y * T.int64(4) + + ax2, + ) + T.reads(var_matmul_intermediate_local[v0, v1, v_i2k]) + T.writes(lv1622_shared[v0, v1, v2]) + lv1622_shared[v0, v1, v2] = var_matmul_intermediate_local[ + v0, v1, v_i2k + ] + for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("reduction"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(512), + i0_i1_i2_fused_1 * T.int64(8) + + ax2_y * T.int64(4) + + ax2, + ) + T.where(ax2_y < T.int64(1)) + T.reads(lv1622_shared[v0, v1, v2]) + T.writes(lv1622_shared[v0, v1, v2]) + lv1622_shared[v0, v1, v2] = ( + lv1622_shared[v0, v1, v2] + + lv1622_shared[v0, v1, v2 + T.int64(4)] + ) + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(11008), + i0_i1_i2_fused_0 * T.int64(256) + + i0_i1_i2_fused_1 * T.int64(4) + + ax2, + ) + v_i2k = T.axis.spatial( + T.int64(512), + i0_i1_i2_fused_1 * T.int64(8) + + ax2_y * T.int64(4) + + ax2, + ) + T.where(ax2_y < T.int64(1)) + T.reads(lv1622_shared[v0, v1, v_i2k]) + T.writes(p_output0_intermediate[v0, v1, v2]) + p_output0_intermediate[v0, v1, v2] = lv1622_shared[ + v0, v1, v_i2k + ] * T.sigmoid(lv1622_shared[v0, v1, v_i2k]) + + +def sch_fused_decode5_fused_matmul6_silu1(func): + sch = tvm.tir.Schedule(func) + b0 = sch.get_block(name="decode", func_name="main") + b1 = sch.get_block(name="matmul", func_name="main") + l2, l3, l4, l5 = sch.get_loops(block=b1) + l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) + v7, v8, v9 = sch.sample_perfect_tile( + loop=l6, n=3, max_innermost_factor=4, decision=[43, 64, 4] + ) + l10, l11, l12 
= sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) + v13, v14, v15 = sch.sample_perfect_tile( + loop=l5, n=3, max_innermost_factor=8, decision=[128, 4, 8] + ) + l16, l17, l18 = sch.split( + loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True + ) + sch.reorder(l10, l11, l16, l17, l18, l12) + sch.bind(loop=l10, thread_axis="blockIdx.x") + sch.bind(loop=l11, thread_axis="threadIdx.x") + sch.compute_inline(block=b0) + b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) + b20 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="local") + b21 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="local") + b22 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") + sch.compute_at(block=b22, loop=l11, preserve_unit_loops=True, index=-1) + v23 = sch.sample_categorical( + candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3 + ) + sch.annotate( + block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch", ann_val=v23 + ) + sch.compute_at(block=b20, loop=l17, preserve_unit_loops=True, index=-1) + sch.compute_at(block=b21, loop=l16, preserve_unit_loops=True, index=-1) + l24, l25, l26, l27, l28, l29 = sch.get_loops(block=b20) + sch.vectorize(loop=l29) + l30, l31, l32, l33, l34 = sch.get_loops(block=b21) + sch.vectorize(loop=l34) + l35, l36, l37, l38, l39 = sch.get_loops(block=b19) + sch.vectorize(loop=l39) + sch.vectorize(loop=l12) + b40 = sch.decompose_reduction(block=b1, loop=l16) + b41 = sch.get_block(name="compute", func_name="main") + sch.compute_inline(block=b41) + b42 = sch.get_block(name="T_multiply", func_name="main") + sch.reverse_compute_inline(block=b42) + sch.enter_postproc() + sch.unannotate(block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch") + l43, l44, l45, l46, l47 = sch.get_loops(block=b22) + l48, l49, l50 = sch.split(loop=l47, factors=[None, 64, 8], preserve_unit_iters=True) + sch.vectorize(loop=l50) + sch.bind(loop=l49, thread_axis="threadIdx.x") + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + +@T.prim_func(private=True) +def fused_decode81_fused_matmul1_cast2( + lv1576: T.Buffer((T.int64(512), T.int64(64000)), "uint32"), + lv1577: T.Buffer((T.int64(128), T.int64(64000)), "float16"), + lv1575: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), + p_output0_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(64000)), "float32" + ), +): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(64000)), "float16") + var_matmul_intermediate = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(64000)), "float16" + ) + for i, j in T.grid(T.int64(4096), T.int64(64000)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1576[v_i // T.int64(8), v_j], lv1577[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv1576[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv1577[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(64000), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1575[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + 
T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = ( + var_matmul_intermediate[v_i0, v_i1, v_i2] + + lv1575[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + ) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(64000)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) + T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) + p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast( + "float32", var_matmul_intermediate[v_i0, v_i1, v_i2] + ) + +def sch_fused_decode81_fused_matmul1_cast2(func): + sch = tvm.tir.Schedule(func) + b0 = sch.get_block(name="decode", func_name="main") + b1 = sch.get_block(name="matmul", func_name="main") + l2, l3, l4, l5 = sch.get_loops(block=b1) + l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) + v7, v8, v9 = sch.sample_perfect_tile( + loop=l6, n=3, max_innermost_factor=4, decision=[160, 100, 4] + ) + l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) + v13, v14, v15 = sch.sample_perfect_tile( + loop=l5, n=3, max_innermost_factor=8, decision=[512, 8, 1] + ) + l16, l17, l18 = sch.split( + loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True + ) + sch.reorder(l10, l11, l16, l17, l18, l12) + sch.bind(loop=l10, thread_axis="blockIdx.x") + sch.bind(loop=l11, thread_axis="threadIdx.x") + sch.compute_inline(block=b0) + b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) + b20 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="local") + b21 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="local") + b22 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") + sch.compute_at(block=b22, loop=l11, preserve_unit_loops=True, index=-1) + v23 = sch.sample_categorical( + candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1 + ) + sch.annotate( + block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch", ann_val=v23 + ) + sch.compute_at(block=b20, loop=l17, preserve_unit_loops=True, index=-1) + sch.compute_at(block=b21, loop=l16, preserve_unit_loops=True, index=-1) + l24, l25, l26, l27, l28, l29 = sch.get_loops(block=b20) + sch.vectorize(loop=l29) + l30, l31, l32, l33, l34 = sch.get_loops(block=b21) + sch.vectorize(loop=l34) + l35, l36, l37, l38, l39 = sch.get_loops(block=b19) + sch.vectorize(loop=l39) + sch.vectorize(loop=l12) + b40 = sch.decompose_reduction(block=b1, loop=l16) + b41 = sch.get_block(name="compute", func_name="main") + sch.reverse_compute_inline(block=b41) + sch.enter_postproc() + sch.unannotate(block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch") + l42, l43, l44, l45, l46 = sch.get_loops(block=b22) + l47, l48, l49 = sch.split( + loop=l46, factors=[None, 100, 2], preserve_unit_iters=True + ) + sch.vectorize(loop=l49) + sch.bind(loop=l48, thread_axis="threadIdx.x") + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + + + +@T.prim_func(private=True) +def fused_decode4_fused_matmul4_add1( + lv1605: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), + lv1606: T.Buffer((T.int64(128), T.int64(4096)), "float16"), + lv197: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), + lv1581: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), + p_output0_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(4096)), "float16" 
+ ), +): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") + var_matmul_intermediate = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(4096)), "float16" + ) + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1605[v_i // T.int64(8), v_j], lv1606[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv1605[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv1606[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv197[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = ( + var_matmul_intermediate[v_i0, v_i1, v_i2] + + lv197[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + lv1581[v_ax0, v_ax1, v_ax2], + var_matmul_intermediate[v_ax0, v_ax1, v_ax2], + ) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( + lv1581[v_ax0, v_ax1, v_ax2] + + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + ) + + +@T.prim_func(private=True) +def fused_decode4_fused_matmul4_add1_after( + lv1605: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), + lv1606: T.Buffer((T.int64(128), T.int64(4096)), "float16"), + lv197: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), + lv1581: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), + p_output0_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(4096)), "float16" + ), +): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_matmul_intermediate_local = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(32768)), "float16", scope="local" + ) + var_matmul_intermediate_local_batch = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(32768)), "float16", scope="local" + ) + lv1605_local = T.alloc_buffer( + (T.int64(512), T.int64(4096)), "uint32", scope="local" + ) + lv1606_local = T.alloc_buffer( + (T.int64(128), T.int64(4096)), "float16", scope="local" + ) + lv197_shared = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(2048)), "float16", scope="shared" + ) + for i0_i1_i2_fused_0 in T.thread_binding(T.int64(32), thread="blockIdx.x"): + for i0_i1_i2_fused_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial( + T.int64(32768), + i0_i1_i2_fused_0 * T.int64(1024) + + i0_i1_i2_fused_1 * T.int64(32) + + ax2_y * T.int64(4) + + i0_i1_i2_fused_2_init, + ) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) + for k_0 in 
range(T.int64(2)): + for ax2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2_2 in T.vectorized(T.int64(8)): + with T.block("lv197_shared"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(4096), + k_0 * T.int64(2048) + + ax2_1 * T.int64(64) + + (ax2_y * T.int64(8) + ax2_2), + ) + v2k = T.axis.spatial( + T.int64(2048), + ( + ax2_1 * T.int64(64) + + ax2_y * T.int64(8) + + ax2_2 + ), + ) + T.reads(lv197[v0, v1, v2]) + T.writes(lv197_shared[v0, v1, v2k]) + lv197_shared[v0, v1, v2k] = lv197[v0, v1, v2] + for k_1 in range(T.int64(8)): + for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for ax1 in T.vectorized(T.int64(4)): + with T.block("matmul_init_local"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2k = T.axis.spatial( + T.int64(32768), + i0_i1_i2_fused_0 * T.int64(1024) + + i0_i1_i2_fused_1 * T.int64(32) + + ax2_y * T.int64(4) + + ax1, + ) + T.reads() + T.writes( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + ) + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] = T.float16(0) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv1606_local"): + v0 = T.axis.spatial( + T.int64(128), + k_0 * T.int64(64) + + (k_1 * T.int64(8) + ax2_y) + + ax0, + ) + v1 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads(lv1606[v0, v1]) + T.writes(lv1606_local[v0, v1]) + lv1606_local[v0, v1] = lv1606[v0, v1] + for k_2 in range(T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv1605_local"): + v0 = T.axis.spatial( + T.int64(512), + k_0 * T.int64(256) + + (k_1 * T.int64(8) + ax2_y) * T.int64(4) + + k_2 + + ax0, + ) + v1 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads(lv1605[v0, v1]) + T.writes(lv1605_local[v0, v1]) + lv1605_local[v0, v1] = lv1605[v0, v1] + for k_3 in range(T.int64(8)): + for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + i0_i1_i2_fused_2, + ) + v_i2k = T.axis.spatial( + T.int64(32768), + i0_i1_i2_fused_0 * T.int64(1024) + + i0_i1_i2_fused_1 * T.int64(32) + + ax2_y * T.int64(4) + + i0_i1_i2_fused_2, + ) + v_k = T.axis.reduce( + T.int64(4096), + k_0 * T.int64(2048) + + (k_1 * T.int64(8) + ax2_y) * T.int64(32) + + k_2 * T.int64(8) + + k_3, + ) + v_ki = T.axis.reduce( + T.int64(2048), + (k_1 * T.int64(8) + ax2_y) * T.int64(32) + + k_2 * T.int64(8) + + k_3, + ) + T.reads( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ], + lv197_shared[v_i0, v_i1, v_ki], + lv1605_local[v_k // T.int64(8), v_i2], + ) + T.writes( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + ) + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] = var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + lv197_shared[ + v_i0, v_i1, v_ki + ] * ( + ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv1605_local[ + v_k // T.int64(8), v_i2 + ], + T.Cast( + "uint32", + v_k % T.int64(8), + ) + * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) 
+ ) + ) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("multiple_scale"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2k = T.axis.spatial( + T.int64(32768), + i0_i1_i2_fused_0 * T.int64(1024) + + i0_i1_i2_fused_1 * T.int64(32) + + ax2_y * T.int64(4) + + ax1, + ) + v0 = T.axis.spatial( + T.int64(128), + k_0 * T.int64(64) + + (k_1 * T.int64(8) + ax2_y) + + ax0, + ) + v1 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads( + lv1606_local[v0, v1], + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ], + ) + T.writes( + var_matmul_intermediate_local[v_i0, v_i1, v_i2k] + ) + var_matmul_intermediate_local[v_i0, v_i1, v_i2k] = ( + var_matmul_intermediate_local[v_i0, v_i1, v_i2k] + + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + * lv1606_local[v0, v1] + ) + for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("var_matmul_intermediate_update"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(1024), + i0_i1_i2_fused_1 * T.int64(32) + + ax2_y * T.int64(4) + + ax2, + ) + v_i2k = T.axis.spatial( + T.int64(32768), + i0_i1_i2_fused_0 * T.int64(1024) + + i0_i1_i2_fused_1 * T.int64(32) + + ax2_y * T.int64(4) + + ax2, + ) + T.reads(var_matmul_intermediate_local[v0, v1, v_i2k]) + T.writes(lv197_shared[v0, v1, v2]) + lv197_shared[v0, v1, v2] = var_matmul_intermediate_local[ + v0, v1, v_i2k + ] + for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("reduction_sum"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(1024), + i0_i1_i2_fused_1 * T.int64(32) + + ax2_y * T.int64(4) + + ax2, + ) + T.where(ax2_y < T.int64(4)) + T.reads(lv197_shared[v0, v1, v2]) + T.writes(lv197_shared[v0, v1, v2]) + lv197_shared[v0, v1, v2] = ( + lv197_shared[v0, v1, v2] + + lv197_shared[v0, v1, v2 + T.int64(16)] + ) + for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax2, + ) + v_i2k = T.axis.spatial( + T.int64(1024), + i0_i1_i2_fused_1 * T.int64(32) + + ax2_y * T.int64(4) + + ax2, + ) + T.where(ax2_y < T.int64(1)) + T.reads(lv197_shared[v0, v1, v_i2k], lv1581[v0, v1, v2]) + T.writes(p_output0_intermediate[v0, v1, v2]) + p_output0_intermediate[v0, v1, v2] = ( + lv1581[v0, v1, v2] + + lv197_shared[v0, v1, v_i2k] + + lv197_shared[v0, v1, v_i2k + T.int64(4)] + + lv197_shared[v0, v1, v_i2k + T.int64(8)] + + lv197_shared[v0, v1, v_i2k + T.int64(12)] + ) + +@T.prim_func(private=True) +def fused_decode82_fused_matmul1_cast2( + lv1576: T.Buffer((T.int64(512), T.int64(64000)), "uint32"), + lv1577: T.Buffer((T.int64(128), T.int64(64000)), "float16"), + lv1575: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), + p_output0_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(64000)), "float32" + ), +): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(2048), T.int64(64000)), 
"float16") + var_matmul_intermediate = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(64000)), "float16" + ) + for i, j in T.grid(T.int64(2048), T.int64(64000)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1576[v_i // T.int64(8), v_j], lv1577[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv1576[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv1577[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(64000), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1575[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = ( + var_matmul_intermediate[v_i0, v_i1, v_i2] + + lv1575[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + ) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(64000)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) + T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) + p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast( + "float32", var_matmul_intermediate[v_i0, v_i1, v_i2] + ) + +def sch_fused_decode82_fused_matmul1_cast2(func): + sch = tvm.tir.Schedule(func) + b0 = sch.get_block(name="decode", func_name="main") + b1 = sch.get_block(name="matmul", func_name="main") + l2, l3, l4, l5 = sch.get_loops(block=b1) + l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) + v7, v8, v9 = sch.sample_perfect_tile( + loop=l6, n=3, max_innermost_factor=4, decision=[160, 100, 4] + ) + l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) + v13, v14, v15 = sch.sample_perfect_tile( + loop=l5, n=3, max_innermost_factor=8, decision=[512, 8, 1] + ) + l16, l17, l18 = sch.split( + loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True + ) + sch.reorder(l10, l11, l16, l17, l18, l12) + sch.bind(loop=l10, thread_axis="blockIdx.x") + sch.bind(loop=l11, thread_axis="threadIdx.x") + sch.compute_inline(block=b0) + b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) + b20 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="local") + b21 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="local") + b22 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") + sch.compute_at(block=b22, loop=l11, preserve_unit_loops=True, index=-1) + v23 = sch.sample_categorical( + candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1 + ) + sch.annotate( + block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch", ann_val=v23 + ) + sch.compute_at(block=b20, loop=l17, preserve_unit_loops=True, index=-1) + sch.compute_at(block=b21, loop=l16, preserve_unit_loops=True, index=-1) + l24, l25, l26, l27, l28, l29 = sch.get_loops(block=b20) + sch.vectorize(loop=l29) + l30, l31, l32, l33, l34 = sch.get_loops(block=b21) + sch.vectorize(loop=l34) + l35, l36, l37, l38, l39 = sch.get_loops(block=b19) + sch.vectorize(loop=l39) + sch.vectorize(loop=l12) + b40 = sch.decompose_reduction(block=b1, loop=l16) + b41 = sch.get_block(name="compute", 
func_name="main") + sch.reverse_compute_inline(block=b41) + sch.enter_postproc() + sch.unannotate(block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch") + l42, l43, l44, l45, l46 = sch.get_loops(block=b22) + l47, l48, l49 = sch.split( + loop=l46, factors=[None, 100, 2], preserve_unit_iters=True + ) + sch.vectorize(loop=l49) + sch.bind(loop=l48, thread_axis="threadIdx.x") + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + +def sch_fused_decode4_fused_matmul4_add1(func): + sch = tvm.tir.Schedule(func) + b0 = sch.get_block(name="decode", func_name="main") + b1 = sch.get_block(name="matmul", func_name="main") + l2, l3, l4, l5 = sch.get_loops(block=b1) + l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) + v7, v8, v9 = sch.sample_perfect_tile( + loop=l6, n=3, max_innermost_factor=4, decision=[32, 64, 2] + ) + l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) + v13, v14, v15 = sch.sample_perfect_tile( + loop=l5, n=3, max_innermost_factor=8, decision=[128, 4, 8] + ) + l16, l17, l18 = sch.split( + loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True + ) + sch.reorder(l10, l11, l16, l17, l18, l12) + sch.bind(loop=l10, thread_axis="blockIdx.x") + sch.bind(loop=l11, thread_axis="threadIdx.x") + sch.compute_inline(block=b0) + b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) + b20 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="local") + b21 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="local") + b22 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") + sch.compute_at(block=b22, loop=l11, preserve_unit_loops=True, index=-1) + v23 = sch.sample_categorical( + candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3 + ) + sch.annotate( + block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch", ann_val=v23 + ) + sch.compute_at(block=b20, loop=l17, preserve_unit_loops=True, index=-1) + sch.compute_at(block=b21, loop=l16, preserve_unit_loops=True, index=-1) + l24, l25, l26, l27, l28, l29 = sch.get_loops(block=b20) + sch.vectorize(loop=l29) + l30, l31, l32, l33, l34 = sch.get_loops(block=b21) + sch.vectorize(loop=l34) + l35, l36, l37, l38, l39 = sch.get_loops(block=b19) + sch.vectorize(loop=l39) + sch.vectorize(loop=l12) + b40 = sch.decompose_reduction(block=b1, loop=l16) + b41 = sch.get_block(name="T_add", func_name="main") + sch.reverse_compute_inline(block=b41) + sch.enter_postproc() + sch.unannotate(block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch") + l42, l43, l44, l45, l46 = sch.get_loops(block=b22) + l47, l48, l49 = sch.split(loop=l46, factors=[None, 64, 8], preserve_unit_iters=True) + sch.vectorize(loop=l49) + sch.bind(loop=l48, thread_axis="threadIdx.x") + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + +@T.prim_func(private=True) +def fused_decode3_fused_matmul1_cast2( + lv1576: T.Buffer((T.int64(512), T.int64(32000)), "uint32"), + lv1577: T.Buffer((T.int64(128), T.int64(32000)), "float16"), + lv1575: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), + p_output0_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(32000)), "float32" + ), +): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32000)), "float16") + var_matmul_intermediate = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(32000)), "float16" + ) + for i, j in T.grid(T.int64(4096), 
T.int64(32000)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1576[v_i // T.int64(8), v_j], lv1577[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv1576[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv1577[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1575[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = ( + var_matmul_intermediate[v_i0, v_i1, v_i2] + + lv1575[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + ) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) + T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) + p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast( + "float32", var_matmul_intermediate[v_i0, v_i1, v_i2] + ) + +@T.prim_func(private=True) +def fused_decode3_fused_matmul1_cast2_after( + lv1576: T.Buffer((T.int64(512), T.int64(32000)), "uint32"), + lv1577: T.Buffer((T.int64(128), T.int64(32000)), "float16"), + lv1575: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), + p_output0_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(32000)), "float32" + ), +): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_matmul_intermediate_local = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(512000)), "float16", scope="local" + ) + var_matmul_intermediate_local_batch = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(512000)), "float16", scope="local" + ) + lv1576_local = T.alloc_buffer( + (T.int64(512), T.int64(32000)), "uint32", scope="local" + ) + lv1577_local = T.alloc_buffer( + (T.int64(128), T.int64(32000)), "float16", scope="local" + ) + lv1575_shared = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="shared" + ) + for i0_i1_i2_fused_0 in T.thread_binding(T.int64(125), thread="blockIdx.x"): + for i0_i1_i2_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial( + T.int64(512000), + i0_i1_i2_fused_0 * T.int64(2048) + + i0_i1_i2_fused_1 * T.int64(32) + + ax2_y * T.int64(4) + + i0_i1_i2_fused_2_init + ) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) + for k_0 in range(T.int64(1)): + for ax2_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2_2 in T.vectorized(T.int64(8)): + with T.block("lv1575_shared"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(4096), + k_0 * T.int64(4096) + + ax2_y * T.int64(512) + + ax2_1 * T.int64(8) + 
ax2_2 + ) + v2k = T.axis.spatial( + T.int64(4096), + (ax2_y * T.int64(512) + + ax2_1 * T.int64(8) + ax2_2) + ) + T.reads(lv1575[v0, v1, v2]) + T.writes(lv1575_shared[v0, v1, v2k]) + lv1575_shared[v0, v1, v2k] = lv1575[v0, v1, v2] + for k_1 in range(T.int64(16)): + for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for ax1 in T.vectorized(T.int64(4)): + with T.block("matmul_init_local"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2k = T.axis.spatial( + T.int64(512000), + i0_i1_i2_fused_0 * T.int64(2048) + + i0_i1_i2_fused_1 * T.int64(32) + + ax2_y * T.int64(4) + ax1 + ) + T.reads() + T.writes(var_matmul_intermediate_local_batch[v_i0, v_i1, v_i2k]) + var_matmul_intermediate_local_batch[v_i0, v_i1, v_i2k] = T.float16(0) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv1577_local"): + v0 = T.axis.spatial( + T.int64(128), + k_0 * T.int64(128) + + (k_1 * T.int64(8) + ax2_y) + ax0 + ) + v1 = T.axis.spatial( + T.int64(32000), + i0_i1_i2_fused_0 * T.int64(256) + + i0_i1_i2_fused_1 * T.int64(4) + ax1 + ) + T.reads(lv1577[v0, v1]) + T.writes(lv1577_local[v0, v1]) + lv1577_local[v0, v1] = lv1577[v0, v1] + for k_2 in range(T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv1576_local"): + v0 = T.axis.spatial( + T.int64(512), + k_0 * T.int64(512) + + (k_1 * T.int64(8) + ax2_y) * T.int64(4) + + k_2 + ax0 + ) + v1 = T.axis.spatial( + T.int64(32000), + i0_i1_i2_fused_0 * T.int64(256) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1 + ) + T.reads(lv1576[v0, v1]) + T.writes(lv1576_local[v0, v1]) + lv1576_local[v0, v1] = lv1576[v0, v1] + for k_3 in range(T.int64(8)): + for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial( + T.int64(32000), + i0_i1_i2_fused_0 * T.int64(256) + + i0_i1_i2_fused_1 * T.int64(4) + + i0_i1_i2_fused_2 + ) + v_i2k = T.axis.spatial( + T.int64(512000), + i0_i1_i2_fused_0 * T.int64(2048) + + i0_i1_i2_fused_1 * T.int64(32) + + ax2_y * T.int64(4) + + i0_i1_i2_fused_2 + ) + v_k = T.axis.reduce( + T.int64(4096), + k_0 * T.int64(4096) + + (k_1 * T.int64(8) + ax2_y) * T.int64(32) + + k_2 * T.int64(8) + k_3 + ) + v_ki = T.axis.reduce( + T.int64(4096), + (k_1 * T.int64(8) + ax2_y) * T.int64(32) + + k_2 * T.int64(8) + k_3 + ) + T.reads( + var_matmul_intermediate_local_batch[v_i0, v_i1, v_i2k], + lv1575_shared[v_i0, v_i1, v_ki], lv1576_local[v_k // T.int64(8), v_i2] + ) + T.writes(var_matmul_intermediate_local_batch[v_i0, v_i1, v_i2k]) + var_matmul_intermediate_local_batch[v_i0, v_i1, v_i2k] = ( + var_matmul_intermediate_local_batch[v_i0, v_i1, v_i2k] + + lv1575_shared[v_i0, v_i1, v_ki] + * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv1576_local[v_k // T.int64(8), v_i2], + T.Cast("uint32", v_k % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7))) + ) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("multiple_scale"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2k = T.axis.spatial( + T.int64(512000), + i0_i1_i2_fused_0 * T.int64(2048) + + i0_i1_i2_fused_1 * T.int64(32) + + ax2_y * T.int64(4) + ax1 + ) + v0 = T.axis.spatial( + T.int64(128), + k_0 * T.int64(128) + + (k_1 * T.int64(8) + ax2_y) + ax0 + ) + v1 = T.axis.spatial( + T.int64(32000), + i0_i1_i2_fused_0 * T.int64(256) + + 
i0_i1_i2_fused_1 * T.int64(4) + ax1 + ) + T.reads( + lv1577_local[v0, v1], + var_matmul_intermediate_local_batch[v_i0, v_i1, v_i2k] + ) + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2k]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2k] = ( + var_matmul_intermediate_local[v_i0, v_i1, v_i2k] + + var_matmul_intermediate_local_batch[v_i0, v_i1, v_i2k] * lv1577_local[v0, v1] + ) + for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("var_matmul_intermediate_update"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(2048), + ax2_y * T.int64(256) + + i0_i1_i2_fused_1 * T.int64(4) + ax2 + ) + v_i2k = T.axis.spatial( + T.int64(512000), + i0_i1_i2_fused_0 * T.int64(2048) + + i0_i1_i2_fused_1 * T.int64(32) + + ax2_y * T.int64(4) + ax2 + ) + T.reads(var_matmul_intermediate_local[v0, v1, v_i2k]) + T.writes(lv1575_shared[v0, v1, v2]) + lv1575_shared[v0, v1, v2] = var_matmul_intermediate_local[v0, v1, v_i2k] + for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("reduction_2"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v_i2k = T.axis.spatial( + T.int64(2048), + ax2_y * T.int64(256) + + i0_i1_i2_fused_1 * T.int64(4) + ax2 + ) + T.where(ax2_y < T.int64(4)) + T.reads(lv1575_shared[v0, v1, v_i2k]) + T.writes(lv1575_shared[v0, v1, v_i2k]) + lv1575_shared[v0, v1, v_i2k] = ( + lv1575_shared[v0, v1, v_i2k] + lv1575_shared[v0, v1, v_i2k + T.int64(1024)] + ) + for ax2_y in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial( + T.int64(32000), + i0_i1_i2_fused_0 * T.int64(256) + + i0_i1_i2_fused_1 * T.int64(4) + ax2 + ) + v_i2k = T.axis.spatial( + T.int64(2048), + ax2_y * T.int64(256) + + i0_i1_i2_fused_1 * T.int64(4) + ax2 + ) + T.where(ax2_y < T.int64(1)) + T.reads(lv1575_shared[v0, v1, v_i2k]) + T.writes(p_output0_intermediate[v0, v1, v2]) + p_output0_intermediate[v0, v1, v2] = T.Cast( + "float32", lv1575_shared[v0, v1, v_i2k] + + lv1575_shared[v0, v1, v_i2k + T.int64(256)] + + lv1575_shared[v0, v1, v_i2k + T.int64(512)] + + lv1575_shared[v0, v1, v_i2k + T.int64(768)] + ) + + +def sch_fused_decode3_fused_matmul1_cast2(func): + sch = tvm.tir.Schedule(func) + b0 = sch.get_block(name="decode", func_name="main") + b1 = sch.get_block(name="matmul", func_name="main") + l2, l3, l4, l5 = sch.get_loops(block=b1) + l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) + v7, v8, v9 = sch.sample_perfect_tile( + loop=l6, n=3, max_innermost_factor=4, decision=[80, 100, 4] + ) + l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) + v13, v14, v15 = sch.sample_perfect_tile( + loop=l5, n=3, max_innermost_factor=8, decision=[512, 8, 1] + ) + l16, l17, l18 = sch.split( + loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True + ) + sch.reorder(l10, l11, l16, l17, l18, l12) + sch.bind(loop=l10, thread_axis="blockIdx.x") + sch.bind(loop=l11, thread_axis="threadIdx.x") + sch.compute_inline(block=b0) + b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) + b20 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="local") + 
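# b19 keeps the output tile in registers; b20/b21 cache the quantized-weight and scale reads locally, while b22 stages the activation row in shared memory +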
b21 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="local") + b22 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") + sch.compute_at(block=b22, loop=l11, preserve_unit_loops=True, index=-1) + v23 = sch.sample_categorical( + candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1 + ) + sch.annotate( + block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch", ann_val=v23 + ) + sch.compute_at(block=b20, loop=l17, preserve_unit_loops=True, index=-1) + sch.compute_at(block=b21, loop=l16, preserve_unit_loops=True, index=-1) + l24, l25, l26, l27, l28, l29 = sch.get_loops(block=b20) + sch.vectorize(loop=l29) + l30, l31, l32, l33, l34 = sch.get_loops(block=b21) + sch.vectorize(loop=l34) + l35, l36, l37, l38, l39 = sch.get_loops(block=b19) + sch.vectorize(loop=l39) + sch.vectorize(loop=l12) + b40 = sch.decompose_reduction(block=b1, loop=l16) + b41 = sch.get_block(name="compute", func_name="main") + sch.reverse_compute_inline(block=b41) + sch.enter_postproc() + sch.unannotate(block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch") + l42, l43, l44, l45, l46 = sch.get_loops(block=b22) + l47, l48, l49 = sch.split( + loop=l46, factors=[None, 100, 2], preserve_unit_iters=True + ) + sch.vectorize(loop=l49) + sch.bind(loop=l48, thread_axis="threadIdx.x") + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func(private=True) +def fused_decode2_fused_NT_matmul3_add( + lv50: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), + lv51: T.Buffer((T.int64(344), T.int64(4096)), "float16"), + p_lv5: T.handle, + p_lv3: T.handle, + p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv5 = T.match_buffer(p_lv5, (T.int64(1), n, T.int64(11008)), "float16") + lv3 = T.match_buffer(p_lv3, (T.int64(1), n, T.int64(4096)), "float16") + p_output0_intermediate = T.match_buffer( + p_output0, (T.int64(1), n, T.int64(4096)), "float16" + ) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16") + var_T_transpose_intermediate = T.alloc_buffer( + (T.int64(4096), T.int64(11008)), "float16" + ) + var_NT_matmul_intermediate = T.alloc_buffer( + (T.int64(1), n, T.int64(4096)), "float16" + ) + for i, j in T.grid(T.int64(11008), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv50[v_i // T.int64(8), v_j], lv51[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv50[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv51[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(4096), T.int64(11008)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(var_T_transpose_intermediate[v_ax0, v_ax1]) + var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(11008)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv5[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + + lv5[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k] + ) + for ax0, ax1, 
ax2 in T.grid(T.int64(1), n, T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + lv3[v_ax0, v_ax1, v_ax2], + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], + ) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( + lv3[v_ax0, v_ax1, v_ax2] + + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + ) + + +@T.prim_func(private=True) +def fused_decode2_fused_NT_matmul3_add_after( + lv8: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), + lv9: T.Buffer((T.int64(344), T.int64(4096)), "float16"), + p_lv5: T.handle, + p_lv3: T.handle, + p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + lv6 = T.match_buffer(p_lv5, (1, n, 11008), "float16") + lv2 = T.match_buffer(p_lv3, (1, n, 4096), "float16") + var_NT_matmul_intermediate = T.match_buffer(p_output0, (1, n, 4096), "float16") + + var_matmul_intermediate_local = T.alloc_buffer( + (T.int64(1), ((n+7)//8) * 8, T.int64(4096)), "float16", scope="local" + ) + var_matmul_intermediate_local_batch = T.alloc_buffer( + (T.int64(1), ((n+7)//8) * 8, T.int64(4096)), "float16", scope="local" + ) + lv8_local = T.alloc_buffer((T.int64(512), T.int64(4096)), "uint32", scope="local") + lv9_local = T.alloc_buffer( + (T.int64(128), T.int64(4096)), "float16", scope="local" + ) + #lv6_shared = T.alloc_buffer( + # (T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="shared" + #) + for i0_i1_i2_fused_n in T.thread_binding(((n+7)//8), thread="blockIdx.y"): + for i0_i1_i2_fused_0 in T.thread_binding(T.int64(32), thread="blockIdx.x"): + for i0_i1_i2_fused_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): + with T.block("n_check"): + T.where((i0_i1_i2_fused_n * T.int64(8) + ax2_y) < n) + for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) + v_i2 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + i0_i1_i2_fused_2_init + ) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) + for k_1 in range(T.int64(344)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("matmul_init_local"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) + v_i2k = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads() + T.writes( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + ) + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] = T.float16(0) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv9_local"): + v0 = T.axis.spatial( + T.int64(344), k_1 + ) + v1 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads(lv9[v0, v1]) + T.writes(lv9_local[v0, v1]) + lv9_local[v0, v1] = lv9[v0, v1] + for k_2 in range(T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv8_local"): + v0 = T.axis.spatial( + T.int64(1376), + k_1 * T.int64(4) + + k_2 + + ax0, + ) + v1 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + 
i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads(lv8[v0, v1]) + T.writes(lv8_local[v0, v1]) + lv8_local[v0, v1] = lv8[v0, v1] + for k_3 in range(T.int64(8)): + for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) + v_i2 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + i0_i1_i2_fused_2, + ) + v_k = T.axis.reduce( + T.int64(11008), + k_1 * T.int64(32) + + k_2 * T.int64(8) + + k_3, + ) + T.reads( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ], + lv6[v_i0, v_i1, v_k], + lv8_local[v_k // T.int64(8), v_i2], + ) + T.writes( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ] + ) + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ] = var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ] + lv6[ + v_i0, v_i1, v_k + ] * ( + ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv8_local[ + v_k // T.int64(8), v_i2 + ], + T.Cast( + "uint32", + v_k % T.int64(8), + ) + * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) + ) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("multiple_scale"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) + v_i2 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + v0 = T.axis.spatial( + T.int64(344), + k_1 + ) + v1 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads( + lv9_local[v0, v1], + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ], + ) + T.writes( + var_matmul_intermediate_local[v_i0, v_i1, v_i2] + ) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = ( + var_matmul_intermediate_local[v_i0, v_i1, v_i2] + + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ] + * lv9_local[v0, v1] + ) + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("var_matmul_intermediate_local"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) + v_i2 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax2, + ) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2[v_i0, v_i1, v_i2]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2[v_i0, v_i1, v_i2] + + +@T.prim_func(private=True) +def fused_decode_NT_matmul( + lv8: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), + lv9: T.Buffer((T.int64(128), T.int64(4096)), "float16"), + p_lv6: T.handle, + p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(4096)), "float16") + var_NT_matmul_intermediate = T.match_buffer( + p_output0, (T.int64(1), n, T.int64(4096)), "float16" + ) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") + var_T_transpose_intermediate = T.alloc_buffer( + (T.int64(4096), T.int64(4096)), "float16" + ) + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv8[v_i // T.int64(8), v_j], 
lv9[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv8[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv9[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(var_T_transpose_intermediate[v_ax0, v_ax1]) + var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv6[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + + lv6[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k] + ) + + +@T.prim_func(private=True) +def fused_decode_NT_matmul_after( + lv8: T.Buffer((512, 4096), "uint32"), + lv9: T.Buffer((128, 4096), "float16"), + p_lv6: T.handle, + p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int32() + lv6 = T.match_buffer(p_lv6, (1, n, 4096), "float16") + var_NT_matmul_intermediate = T.match_buffer(p_output0, (1, n, 4096), "float16") + # with T.block("root"): + decode_local = T.alloc_buffer((4096, 4096), "float16", scope="local") + lv8_local = T.alloc_buffer((512, 4096), "uint32", scope="local") + lv9_local = T.alloc_buffer((128, 4096), "float16", scope="local") + lv6_pad_local = T.alloc_buffer( + (1, (n + 31) // 32 * 32, 4096), "float16", scope="local" + ) + var_NT_matmul_intermediate_pad_local = T.alloc_buffer( + (1, (n + 31) // 32 * 32, 4096), "float16", scope="local" + ) + for i0_i1_fused_0_i0_i1_fused_1_0_fused in T.thread_binding( + (n + 31) // 32, thread="blockIdx.y" + ): + for i2_0 in T.thread_binding(32, thread="blockIdx.x"): + for i0_i1_fused_1_1 in T.thread_binding(8, thread="threadIdx.y"): + for i2_1 in T.thread_binding(16, thread="threadIdx.x"): + for i0_i1_fused_1_2_init in range(4): + for i2_2_init in T.vectorized(8): + with T.block("NT_matmul_init"): + v_i0 = T.axis.spatial(1, 0) + v_i1 = T.axis.spatial( + (n + 31) // 32 * 32, + i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 + + i0_i1_fused_1_1 * 4 + + i0_i1_fused_1_2_init, + ) + v_i2 = T.axis.spatial( + 4096, i2_0 * 128 + i2_1 * 8 + i2_2_init + ) + T.reads() + T.writes( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + ) + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] = T.float16(0) + for k_0 in range(128): + for ax0 in range(1): + for ax1 in T.vectorized(8): + with T.block("lv9_local"): + v0 = T.axis.spatial(128, k_0 + ax0) + v1 = T.axis.spatial( + 4096, i2_0 * 128 + i2_1 * 8 + ax1 + ) + T.reads(lv9[v0, v1]) + T.writes(lv9_local[v0, v1]) + lv9_local[v0, v1] = lv9[v0, v1] + for k_1 in range(4): + for ax0 in range(1): + for ax1 in T.vectorized(8): + with T.block("lv8_local"): + v0 = T.axis.spatial(512, k_0 * 4 + k_1 + ax0) + v1 = T.axis.spatial( + 4096, i2_0 * 128 + i2_1 * 8 + ax1 + ) + T.reads(lv8[v0, v1]) + T.writes(lv8_local[v0, v1]) + lv8_local[v0, v1] = lv8[v0, v1] + for k_2 in range(8): + for ax0 in range(1): + for ax1 in T.vectorized(8): + with T.block("decode"): + v_i = T.axis.spatial( + 4096, 
k_0 * 32 + k_1 * 8 + k_2 + ax0 + ) + v_j = T.axis.spatial( + 4096, i2_0 * 128 + i2_1 * 8 + ax1 + ) + T.reads( + lv8_local[v_i // 8, v_j], + lv9_local[v_i // 32, v_j], + ) + T.writes(decode_local[v_i, v_j]) + decode_local[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv8_local[v_i // 8, v_j], + T.Cast("uint32", v_i % 8) + * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv9_local[v_i // 32, v_j] + for ax0, ax1 in T.grid(1, 4): + for ax2 in T.vectorized(1): + with T.block("lv6_pad_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial( + (n + 31) // 32 * 32, + i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 + + i0_i1_fused_1_1 * 4 + + ax1, + ) + v2 = T.axis.spatial( + 4096, k_0 * 32 + k_1 * 8 + k_2 + ax2 + ) + T.reads(lv6[v0, v1, v2]) + T.writes(lv6_pad_local[v0, v1, v2]) + lv6_pad_local[v0, v1, v2] = T.if_then_else( + v1 < n, lv6[v0, v1, v2], T.float16(0) + ) + for i0_i1_fused_1_2 in range(4): + for i2_2 in T.vectorized(8): + with T.block("NT_matmul_update"): + v_i0 = T.axis.spatial(1, 0) + v_i1 = T.axis.spatial( + (n + 31) // 32 * 32, + i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 + + i0_i1_fused_1_1 * 4 + + i0_i1_fused_1_2, + ) + v_i2 = T.axis.spatial( + 4096, i2_0 * 128 + i2_1 * 8 + i2_2 + ) + v_k = T.axis.reduce( + 4096, k_0 * 32 + k_1 * 8 + k_2 + ) + T.reads( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ], + lv6_pad_local[v_i0, v_i1, v_k], + decode_local[v_k, v_i2], + ) + T.writes( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + ) + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] = ( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + + lv6_pad_local[v_i0, v_i1, v_k] + * decode_local[v_k, v_i2] + ) + for ax0, ax1 in T.grid(1, 4): + for ax2 in T.vectorized(8): + with T.block("var_NT_matmul_intermediate_pad_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial( + (n + 31) // 32 * 32, + i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 + + i0_i1_fused_1_1 * 4 + + ax1, + ) + v2 = T.axis.spatial(4096, i2_0 * 128 + i2_1 * 8 + ax2) + T.reads( + var_NT_matmul_intermediate_pad_local[v0, v1, v2] + ) + T.writes(var_NT_matmul_intermediate[v0, v1, v2]) + if v1 < n: + var_NT_matmul_intermediate[ + v0, v1, v2 + ] = var_NT_matmul_intermediate_pad_local[v0, v1, v2] + + +@T.prim_func(private=True) +def fused_decode1_fused_NT_matmul2_silu( + lv36: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), + lv37: T.Buffer((T.int64(128), T.int64(11008)), "float16"), + p_lv45: T.handle, + p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv45 = T.match_buffer(p_lv45, (T.int64(1), n, T.int64(4096)), "float16") + p_output0_intermediate = T.match_buffer( + p_output0, (T.int64(1), n, T.int64(11008)), "float16" + ) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") + var_T_transpose_intermediate = T.alloc_buffer( + (T.int64(11008), T.int64(4096)), "float16" + ) + var_NT_matmul_intermediate = T.alloc_buffer( + (T.int64(1), n, T.int64(11008)), "float16" + ) + compute = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(11008)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv36[v_i // T.int64(8), v_j], lv37[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv36[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - 
T.float16(7) + ) * lv37[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(11008), T.int64(4096)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(var_T_transpose_intermediate[v_ax0, v_ax1]) + var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv45[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + + lv45[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k] + ) + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(11008)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + T.writes(compute[v_i0, v_i1, v_i2]) + compute[v_i0, v_i1, v_i2] = T.sigmoid( + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], + compute[v_ax0, v_ax1, v_ax2], + ) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + * compute[v_ax0, v_ax1, v_ax2] + ) + + +@T.prim_func(private=True) +def fused_decode1_fused_NT_matmul2_silu_after( + lv36: T.Buffer((512, 11008), "uint32"), + lv37: T.Buffer((128, 11008), "float16"), + p_lv45: T.handle, + p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int32() + lv45 = T.match_buffer(p_lv45, (1, n, 4096), "float16") + p_output0_intermediate = T.match_buffer(p_output0, (1, n, 11008), "float16") + # with T.block("root"): + decode_local = T.alloc_buffer((4096, 11008), "float16", scope="local") + lv36_local = T.alloc_buffer((512, 11008), "uint32", scope="local") + lv37_local = T.alloc_buffer((128, 11008), "float16", scope="local") + lv45_pad_local = T.alloc_buffer( + (1, (n + 31) // 32 * 32, 4096), "float16", scope="local" + ) + var_NT_matmul_intermediate_pad_local = T.alloc_buffer( + (1, (n + 31) // 32 * 32, 11008), "float16", scope="local" + ) + for i0_i1_fused_0_i0_i1_fused_1_0_fused in T.thread_binding( + (n + 31) // 32, thread="blockIdx.y" + ): + for i2_0 in T.thread_binding(86, thread="blockIdx.x"): + for i0_i1_fused_1_1 in T.thread_binding(8, thread="threadIdx.y"): + for i2_1 in T.thread_binding(16, thread="threadIdx.x"): + for i0_i1_fused_1_2_init in range(4): + for i2_2_init in T.vectorized(8): + with T.block("NT_matmul_init"): + v_i0 = T.axis.spatial(1, 0) + v_i1 = T.axis.spatial( + (n + 31) // 32 * 32, + i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 + + i0_i1_fused_1_1 * 4 + + i0_i1_fused_1_2_init, + ) + v_i2 = T.axis.spatial( + 11008, i2_0 * 128 + i2_1 * 8 + i2_2_init + ) + T.reads() + T.writes( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + ) + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] = T.float16(0) + for k_0 in range(128): + for ax0 in range(1): + for ax1 in T.vectorized(8): + with T.block("lv37_local"): + v0 = T.axis.spatial(128, k_0 + ax0) + v1 = T.axis.spatial( + 11008, i2_0 * 128 
+ i2_1 * 8 + ax1 + ) + T.reads(lv37[v0, v1]) + T.writes(lv37_local[v0, v1]) + lv37_local[v0, v1] = lv37[v0, v1] + for k_1 in range(4): + for ax0 in range(1): + for ax1 in T.vectorized(8): + with T.block("lv36_local"): + v0 = T.axis.spatial(512, k_0 * 4 + k_1 + ax0) + v1 = T.axis.spatial( + 11008, i2_0 * 128 + i2_1 * 8 + ax1 + ) + T.reads(lv36[v0, v1]) + T.writes(lv36_local[v0, v1]) + lv36_local[v0, v1] = lv36[v0, v1] + for k_2 in range(8): + for ax0 in range(1): + for ax1 in T.vectorized(8): + with T.block("decode"): + v_i = T.axis.spatial( + 4096, k_0 * 32 + k_1 * 8 + k_2 + ax0 + ) + v_j = T.axis.spatial( + 11008, i2_0 * 128 + i2_1 * 8 + ax1 + ) + T.reads( + lv36_local[v_i // 8, v_j], + lv37_local[v_i // 32, v_j], + ) + T.writes(decode_local[v_i, v_j]) + decode_local[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv36_local[v_i // 8, v_j], + T.Cast("uint32", v_i % 8) + * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv37_local[v_i // 32, v_j] + for ax0, ax1 in T.grid(1, 4): + for ax2 in T.vectorized(1): + with T.block("lv45_pad_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial( + (n + 31) // 32 * 32, + i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 + + i0_i1_fused_1_1 * 4 + + ax1, + ) + v2 = T.axis.spatial( + 4096, k_0 * 32 + k_1 * 8 + k_2 + ax2 + ) + T.reads(lv45[v0, v1, v2]) + T.writes(lv45_pad_local[v0, v1, v2]) + lv45_pad_local[v0, v1, v2] = T.if_then_else( + v1 < n, lv45[v0, v1, v2], T.float16(0) + ) + for i0_i1_fused_1_2 in range(4): + for i2_2 in T.vectorized(8): + with T.block("NT_matmul_update"): + v_i0 = T.axis.spatial(1, 0) + v_i1 = T.axis.spatial( + (n + 31) // 32 * 32, + i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 + + i0_i1_fused_1_1 * 4 + + i0_i1_fused_1_2, + ) + v_i2 = T.axis.spatial( + 11008, i2_0 * 128 + i2_1 * 8 + i2_2 + ) + v_k = T.axis.reduce( + 4096, k_0 * 32 + k_1 * 8 + k_2 + ) + T.reads( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ], + lv45_pad_local[v_i0, v_i1, v_k], + decode_local[v_k, v_i2], + ) + T.writes( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + ) + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] = ( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + + lv45_pad_local[v_i0, v_i1, v_k] + * decode_local[v_k, v_i2] + ) + for ax0, ax1 in T.grid(1, 4): + for ax2 in T.vectorized(8): + with T.block("var_NT_matmul_intermediate_pad_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial( + (n + 31) // 32 * 32, + i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 + + i0_i1_fused_1_1 * 4 + + ax1, + ) + v2 = T.axis.spatial(11008, i2_0 * 128 + i2_1 * 8 + ax2) + T.reads( + var_NT_matmul_intermediate_pad_local[v0, v1, v2] + ) + T.writes(p_output0_intermediate[v0, v1, v2]) + if v1 < n: + p_output0_intermediate[ + v0, v1, v2 + ] = var_NT_matmul_intermediate_pad_local[ + v0, v1, v2 + ] * T.sigmoid( + var_NT_matmul_intermediate_pad_local[v0, v1, v2] + ) + + +@T.prim_func(private=True) +def fused_decode1_fused_NT_matmul2_multiply( + lv43: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), + lv44: T.Buffer((T.int64(128), T.int64(11008)), "float16"), + p_lv45: T.handle, + p_lv132: T.handle, + p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv45 = T.match_buffer(p_lv45, (T.int64(1), n, T.int64(4096)), "float16") + lv132 = T.match_buffer(p_lv132, (T.int64(1), n, T.int64(11008)), "float16") + p_output0_intermediate = T.match_buffer( + p_output0, (T.int64(1), n, T.int64(11008)), "float16" + ) + # with T.block("root"): + decode = 
T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") + var_T_transpose_intermediate = T.alloc_buffer( + (T.int64(11008), T.int64(4096)), "float16" + ) + var_NT_matmul_intermediate = T.alloc_buffer( + (T.int64(1), n, T.int64(11008)), "float16" + ) + for i, j in T.grid(T.int64(4096), T.int64(11008)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv43[v_i // T.int64(8), v_j], lv44[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv43[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv44[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(11008), T.int64(4096)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(var_T_transpose_intermediate[v_ax0, v_ax1]) + var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv45[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + + lv45[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + lv132[v_ax0, v_ax1, v_ax2], + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], + ) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( + lv132[v_ax0, v_ax1, v_ax2] + * var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + ) + + +@T.prim_func(private=True) +def fused_decode1_fused_NT_matmul2_multiply_after( + lv43: T.Buffer((512, 11008), "uint32"), + lv44: T.Buffer((128, 11008), "float16"), + p_lv45: T.handle, + p_lv132: T.handle, + p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int32() + lv45 = T.match_buffer(p_lv45, (1, n, 4096), "float16") + lv132 = T.match_buffer(p_lv132, (1, n, 11008), "float16") + p_output0_intermediate = T.match_buffer(p_output0, (1, n, 11008), "float16") + # with T.block("root"): + decode_local = T.alloc_buffer((4096, 11008), "float16", scope="local") + lv43_local = T.alloc_buffer((512, 11008), "uint32", scope="local") + lv44_local = T.alloc_buffer((128, 11008), "float16", scope="local") + lv45_pad_local = T.alloc_buffer( + (1, (n + 31) // 32 * 32, 4096), "float16", scope="local" + ) + var_NT_matmul_intermediate_pad_local = T.alloc_buffer( + (1, (n + 31) // 32 * 32, 11008), "float16", scope="local" + ) + for i0_i1_fused_0_i0_i1_fused_1_0_fused in T.thread_binding( + (n + 31) // 32, thread="blockIdx.y" + ): + for i2_0 in T.thread_binding(86, thread="blockIdx.x"): + for i0_i1_fused_1_1 in T.thread_binding(8, thread="threadIdx.y"): + for i2_1 in T.thread_binding(16, thread="threadIdx.x"): + for i0_i1_fused_1_2_init in range(4): + for i2_2_init in T.vectorized(8): + with T.block("NT_matmul_init"): + v_i0 = T.axis.spatial(1, 0) + v_i1 = T.axis.spatial( + (n + 31) // 32 * 32, + i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 + + i0_i1_fused_1_1 * 4 + + i0_i1_fused_1_2_init, + ) + v_i2 
= T.axis.spatial( + 11008, i2_0 * 128 + i2_1 * 8 + i2_2_init + ) + T.reads() + T.writes( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + ) + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] = T.float16(0) + for k_0 in range(128): + for ax0 in range(1): + for ax1 in T.vectorized(8): + with T.block("lv44_local"): + v0 = T.axis.spatial(128, k_0 + ax0) + v1 = T.axis.spatial( + 11008, i2_0 * 128 + i2_1 * 8 + ax1 + ) + T.reads(lv44[v0, v1]) + T.writes(lv44_local[v0, v1]) + lv44_local[v0, v1] = lv44[v0, v1] + for k_1 in range(4): + for ax0 in range(1): + for ax1 in T.vectorized(8): + with T.block("lv43_local"): + v0 = T.axis.spatial(512, k_0 * 4 + k_1 + ax0) + v1 = T.axis.spatial( + 11008, i2_0 * 128 + i2_1 * 8 + ax1 + ) + T.reads(lv43[v0, v1]) + T.writes(lv43_local[v0, v1]) + lv43_local[v0, v1] = lv43[v0, v1] + for k_2 in range(8): + for ax0 in range(1): + for ax1 in T.vectorized(8): + with T.block("decode"): + v_i = T.axis.spatial( + 4096, k_0 * 32 + k_1 * 8 + k_2 + ax0 + ) + v_j = T.axis.spatial( + 11008, i2_0 * 128 + i2_1 * 8 + ax1 + ) + T.reads( + lv43_local[v_i // 8, v_j], + lv44_local[v_i // 32, v_j], + ) + T.writes(decode_local[v_i, v_j]) + decode_local[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv43_local[v_i // 8, v_j], + T.Cast("uint32", v_i % 8) + * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv44_local[v_i // 32, v_j] + for ax0, ax1 in T.grid(1, 4): + for ax2 in T.vectorized(1): + with T.block("lv45_pad_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial( + (n + 31) // 32 * 32, + i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 + + i0_i1_fused_1_1 * 4 + + ax1, + ) + v2 = T.axis.spatial( + 4096, k_0 * 32 + k_1 * 8 + k_2 + ax2 + ) + T.reads(lv45[v0, v1, v2]) + T.writes(lv45_pad_local[v0, v1, v2]) + lv45_pad_local[v0, v1, v2] = T.if_then_else( + v1 < n, lv45[v0, v1, v2], T.float16(0) + ) + for i0_i1_fused_1_2 in range(4): + for i2_2 in T.vectorized(8): + with T.block("NT_matmul_update"): + v_i0 = T.axis.spatial(1, 0) + v_i1 = T.axis.spatial( + (n + 31) // 32 * 32, + i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 + + i0_i1_fused_1_1 * 4 + + i0_i1_fused_1_2, + ) + v_i2 = T.axis.spatial( + 11008, i2_0 * 128 + i2_1 * 8 + i2_2 + ) + v_k = T.axis.reduce( + 4096, k_0 * 32 + k_1 * 8 + k_2 + ) + T.reads( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ], + lv45_pad_local[v_i0, v_i1, v_k], + decode_local[v_k, v_i2], + ) + T.writes( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + ) + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] = ( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + + lv45_pad_local[v_i0, v_i1, v_k] + * decode_local[v_k, v_i2] + ) + for ax0, ax1 in T.grid(1, 4): + for ax2 in T.vectorized(8): + with T.block("var_NT_matmul_intermediate_pad_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial( + (n + 31) // 32 * 32, + i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 + + i0_i1_fused_1_1 * 4 + + ax1, + ) + v2 = T.axis.spatial(11008, i2_0 * 128 + i2_1 * 8 + ax2) + T.reads( + lv132[v0, v1, v2], + var_NT_matmul_intermediate_pad_local[v0, v1, v2], + ) + T.writes(p_output0_intermediate[v0, v1, v2]) + if v1 < n: + p_output0_intermediate[v0, v1, v2] = ( + lv132[v0, v1, v2] + * var_NT_matmul_intermediate_pad_local[ + v0, v1, v2 + ] + ) + + +@T.prim_func(private=True) +def fused_decode_fused_NT_matmul_add( + lv29: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), + lv30: T.Buffer((T.int64(128), T.int64(4096)), "float16"), + p_lv41: T.handle, + p_lv2: T.handle, 
+ p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv41 = T.match_buffer(p_lv41, (T.int64(1), n, T.int64(4096)), "float16") + lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(4096)), "float16") + p_output0_intermediate = T.match_buffer( + p_output0, (T.int64(1), n, T.int64(4096)), "float16" + ) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") + var_T_transpose_intermediate = T.alloc_buffer( + (T.int64(4096), T.int64(4096)), "float16" + ) + var_NT_matmul_intermediate = T.alloc_buffer( + (T.int64(1), n, T.int64(4096)), "float16" + ) + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv29[v_i // T.int64(8), v_j], lv30[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv29[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv30[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(var_T_transpose_intermediate[v_ax0, v_ax1]) + var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv41[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + + lv41[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + lv2[v_ax0, v_ax1, v_ax2], + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], + ) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( + lv2[v_ax0, v_ax1, v_ax2] + + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + ) + + +@T.prim_func(private=True) +def fused_decode_fused_NT_matmul_add_after( + lv8: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), + lv9: T.Buffer((T.int64(128), T.int64(4096)), "float16"), + p_lv41: T.handle, + p_lv2: T.handle, + p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + lv6 = T.match_buffer(p_lv41, (1, n, 4096), "float16") + lv2 = T.match_buffer(p_lv2, (1, n, 4096), "float16") + var_NT_matmul_intermediate = T.match_buffer(p_output0, (1, n, 4096), "float16") + + var_matmul_intermediate_local = T.alloc_buffer( + (T.int64(1), ((n+7)//8) * 8, T.int64(4096)), "float16", scope="local" + ) + var_matmul_intermediate_local_batch = T.alloc_buffer( + (T.int64(1), ((n+7)//8) * 8, T.int64(4096)), "float16", scope="local" + ) + lv8_local = T.alloc_buffer((T.int64(512), T.int64(4096)), "uint32", scope="local") + lv9_local = T.alloc_buffer( + (T.int64(128), T.int64(4096)), "float16", scope="local" + ) + #lv6_shared = T.alloc_buffer( + # (T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="shared" + #) + for i0_i1_i2_fused_n in T.thread_binding(((n+7)//8), thread="blockIdx.y"): + for i0_i1_i2_fused_0 in 
T.thread_binding(T.int64(32), thread="blockIdx.x"): + for i0_i1_i2_fused_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): + with T.block("n_check"): + T.where((i0_i1_i2_fused_n * T.int64(8) + ax2_y) < n) + for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) + v_i2 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + i0_i1_i2_fused_2_init + ) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) + for k_1 in range(T.int64(128)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("matmul_init_local"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) + v_i2k = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads() + T.writes( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + ) + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] = T.float16(0) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv9_local"): + v0 = T.axis.spatial( + T.int64(128), k_1 + ) + v1 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads(lv9[v0, v1]) + T.writes(lv9_local[v0, v1]) + lv9_local[v0, v1] = lv9[v0, v1] + for k_2 in range(T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv8_local"): + v0 = T.axis.spatial( + T.int64(512), + k_1 * T.int64(4) + + k_2 + + ax0, + ) + v1 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads(lv8[v0, v1]) + T.writes(lv8_local[v0, v1]) + lv8_local[v0, v1] = lv8[v0, v1] + for k_3 in range(T.int64(8)): + for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) + v_i2 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + i0_i1_i2_fused_2, + ) + v_k = T.axis.reduce( + T.int64(4096), + k_1 * T.int64(32) + + k_2 * T.int64(8) + + k_3, + ) + T.reads( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ], + lv6[v_i0, v_i1, v_k], + lv8_local[v_k // T.int64(8), v_i2], + ) + T.writes( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ] + ) + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ] = var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ] + lv6[ + v_i0, v_i1, v_k + ] * ( + ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv8_local[ + v_k // T.int64(8), v_i2 + ], + T.Cast( + "uint32", + v_k % T.int64(8), + ) + * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) + ) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("multiple_scale"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) + v_i2 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + v0 = T.axis.spatial( + T.int64(128), + k_1 + ) + 
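# the scale row is just k_1: each k_1 iteration covers one 32-element quantization group of the reduction axis +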
v1 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads( + lv9_local[v0, v1], + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ], + ) + T.writes( + var_matmul_intermediate_local[v_i0, v_i1, v_i2] + ) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = ( + var_matmul_intermediate_local[v_i0, v_i1, v_i2] + + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ] + * lv9_local[v0, v1] + ) + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("var_matmul_intermediate_local"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) + v_i2 = T.axis.spatial( + T.int64(4096), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax2, + ) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2[v_i0, v_i1, v_i2]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2[v_i0, v_i1, v_i2] + + +@T.prim_func(private=True) +def fused_decode4_fused_matmul6_add4( + lv1363: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), + lv1364: T.Buffer((T.int64(80), T.int64(2560)), "float16"), + lv2067: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), + linear_bias192: T.Buffer((T.int64(2560),), "float16"), + p_output0_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(2560)), "float16" + ), +): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(2560)), "float16") + var_matmul_intermediate = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(2560)), "float16" + ) + for i, j in T.grid(T.int64(2560), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1363[v_i // T.int64(8), v_j], lv1364[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv1363[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv1364[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(2560)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv2067[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = ( + var_matmul_intermediate[v_i0, v_i1, v_i2] + + lv2067[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias192[v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias192[v_ax2] + ) + + +def sch_fused_decode4_fused_matmul6_add4(func): + sch = tvm.tir.Schedule(func) + b0 = sch.get_block(name="decode", func_name="main") + b1 = sch.get_block(name="matmul", func_name="main") + l2, l3, l4, l5 = sch.get_loops(block=b1) + l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) + 
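# NOTE: this helper mirrors the other sch_* functions in this file: fuse the spatial axes and + # split them into (blockIdx.x, threadIdx.x, vector-lane) tiles, split the reduction axis, + # inline the 4-bit "decode" block into the matmul so the weights are dequantized on the fly, + # cache the output in local (register) storage and stage the activation in shared memory via + # cooperative fetch, then decompose the reduction and inline the elementwise epilogue into the + # write-back. +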
v7, v8, v9 = sch.sample_perfect_tile( + loop=l6, n=3, max_innermost_factor=4, decision=[10, 256, 1] + ) + l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) + v13, v14, v15 = sch.sample_perfect_tile( + loop=l5, n=3, max_innermost_factor=8, decision=[160, 8, 2] + ) + l16, l17, l18 = sch.split( + loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True + ) + sch.reorder(l10, l11, l16, l17, l18, l12) + sch.bind(loop=l10, thread_axis="blockIdx.x") + sch.bind(loop=l11, thread_axis="threadIdx.x") + sch.compute_inline(block=b0) + b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) + b20 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") + sch.compute_at(block=b20, loop=l11, preserve_unit_loops=True, index=-1) + v21 = sch.sample_categorical( + candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3 + ) + sch.annotate( + block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch", ann_val=v21 + ) + l22, l23, l24, l25, l26 = sch.get_loops(block=b19) + sch.vectorize(loop=l26) + sch.vectorize(loop=l12) + b27 = sch.decompose_reduction(block=b1, loop=l16) + b28 = sch.get_block(name="T_add", func_name="main") + sch.reverse_compute_inline(block=b28) + sch.enter_postproc() + sch.unannotate(block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch") + l29, l30, l31, l32, l33 = sch.get_loops(block=b20) + l34, l35, l36 = sch.split( + loop=l33, factors=[None, 256, 8], preserve_unit_iters=True + ) + sch.vectorize(loop=l36) + sch.bind(loop=l35, thread_axis="threadIdx.x") + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func(private=True) +def fused_decode6_fused_matmul9_add7_cast8_cast12_add5( + lv1393: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), + lv1394: T.Buffer((T.int64(320), T.int64(2560)), "float16"), + lv2121: T.Buffer((T.int64(1), T.int64(1), T.int64(10240)), "float16"), + linear_bias197: T.Buffer((T.int64(2560),), "float32"), + lv329: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), + p_output0_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(2560)), "float16" + ), +): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(10240), T.int64(2560)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) + var_compute_intermediate = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(2560)), "float16" + ) + var_compute_intermediate_1 = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(2560)), "float16" + ) + for i, j in T.grid(T.int64(10240), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1393[v_i // T.int64(8), v_j], lv1394[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv1393[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv1394[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(10240)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv2121[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, 
v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[ + v_i0, v_i1, v_i2 + ] + T.Cast("float32", lv2121[v_i0, v_i1, v_k]) * T.Cast( + "float32", var_decode_intermediate[v_k, v_i2] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias197[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = ( + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias197[v_ax2] + ) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast( + "float16", var_T_add_intermediate[v_i0, v_i1, v_i2] + ) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("compute_1"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_compute_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2]) + var_compute_intermediate_1[v_i0, v_i1, v_i2] = var_compute_intermediate[ + v_i0, v_i1, v_i2 + ] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], + lv329[v_ax0, v_ax1, v_ax2], + ) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( + var_compute_intermediate_1[v_ax0, v_ax1, v_ax2] + + lv329[v_ax0, v_ax1, v_ax2] + ) + + +def sch_fused_decode6_fused_matmul9_add7_cast8_cast12_add5(func): + sch = tvm.tir.Schedule(func) + b0 = sch.get_block(name="decode", func_name="main") + b1 = sch.get_block(name="matmul", func_name="main") + l2, l3, l4, l5 = sch.get_loops(block=b1) + l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) + v7, v8, v9 = sch.sample_perfect_tile( + loop=l6, n=3, max_innermost_factor=4, decision=[10, 256, 1] + ) + l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) + v13, v14, v15 = sch.sample_perfect_tile( + loop=l5, n=3, max_innermost_factor=8, decision=[640, 2, 8] + ) + l16, l17, l18 = sch.split( + loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True + ) + sch.reorder(l10, l11, l16, l17, l18, l12) + sch.bind(loop=l10, thread_axis="blockIdx.x") + sch.bind(loop=l11, thread_axis="threadIdx.x") + sch.compute_inline(block=b0) + b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) + b20 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") + sch.compute_at(block=b20, loop=l11, preserve_unit_loops=True, index=-1) + v21 = sch.sample_categorical( + candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3 + ) + sch.annotate( + block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch", ann_val=v21 + ) + l22, l23, l24, l25, l26 = sch.get_loops(block=b19) + sch.vectorize(loop=l26) + sch.vectorize(loop=l12) + b27 = sch.decompose_reduction(block=b1, loop=l16) + b28 = sch.get_block(name="T_add", func_name="main") + bb1 = sch.get_block(name="compute", func_name="main") + bb2 = 
sch.get_block(name="compute_1", func_name="main") + bb3 = sch.get_block(name="T_add_1", func_name="main") + sch.compute_inline(block=b28) + sch.compute_inline(block=bb1) + sch.compute_inline(block=bb2) + sch.reverse_compute_inline(block=bb3) + sch.enter_postproc() + sch.unannotate(block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch") + l29, l30, l31, l32, l33 = sch.get_loops(block=b20) + l34, l35, l36 = sch.split( + loop=l33, factors=[None, 256, 8], preserve_unit_iters=True + ) + sch.vectorize(loop=l36) + sch.bind(loop=l35, thread_axis="threadIdx.x") + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func(private=True) +def fused_decode5_fused_matmul8_add6_gelu1_cast11( + lv1387: T.Buffer((T.int64(320), T.int64(10240)), "uint32"), + lv1388: T.Buffer((T.int64(80), T.int64(10240)), "float16"), + lv2115: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), + linear_bias196: T.Buffer((T.int64(10240),), "float32"), + p_output0_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(10240)), "float16" + ), +): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(10240)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + T_multiply = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + T_multiply_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + T_add = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + var_T_multiply_intermediate = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(10240)) + ) + for i, j in T.grid(T.int64(2560), T.int64(10240)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1387[v_i // T.int64(8), v_j], lv1388[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv1387[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv1388[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(10240), T.int64(2560)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv2115[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[ + v_i0, v_i1, v_i2 + ] + T.Cast("float32", lv2115[v_i0, v_i1, v_k]) * T.Cast( + "float32", var_decode_intermediate[v_k, v_i2] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias196[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = ( + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias196[v_ax2] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) + T_multiply[v_ax0, v_ax1, 
v_ax2] = var_T_add_intermediate[ + v_ax0, v_ax1, v_ax2 + ] * T.float32(0.70710678118654757) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(T_multiply[v_i0, v_i1, v_i2]) + T.writes(compute[v_i0, v_i1, v_i2]) + compute[v_i0, v_i1, v_i2] = T.erf(T_multiply[v_i0, v_i1, v_i2]) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("T_multiply_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(compute[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2]) + T_multiply_1[v_ax0, v_ax1, v_ax2] = compute[ + v_ax0, v_ax1, v_ax2 + ] * T.float32(0.5) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2]) + T.writes(T_add[v_ax0, v_ax1, v_ax2]) + T_add[v_ax0, v_ax1, v_ax2] = ( + T.float32(0.5) + T_multiply_1[v_ax0, v_ax1, v_ax2] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("T_multiply_2"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + var_T_add_intermediate[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2] + ) + T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = ( + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2] + ) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("compute_1"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_multiply_intermediate[v_i0, v_i1, v_i2]) + T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) + p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast( + "float16", var_T_multiply_intermediate[v_i0, v_i1, v_i2] + ) + + +def sch_fused_decode5_fused_matmul8_add6_gelu1_cast11(func): + sch = tvm.tir.Schedule(func) + b0 = sch.get_block(name="decode", func_name="main") + b1 = sch.get_block(name="matmul", func_name="main") + l2, l3, l4, l5 = sch.get_loops(block=b1) + l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) + v7, v8, v9 = sch.sample_perfect_tile( + loop=l6, n=3, max_innermost_factor=4, decision=[10, 256, 4] + ) + l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) + v13, v14, v15 = sch.sample_perfect_tile( + loop=l5, n=3, max_innermost_factor=8, decision=[80, 4, 8] + ) + l16, l17, l18 = sch.split( + loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True + ) + sch.reorder(l10, l11, l16, l17, l18, l12) + sch.bind(loop=l10, thread_axis="blockIdx.x") + sch.bind(loop=l11, thread_axis="threadIdx.x") + sch.compute_inline(block=b0) + b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) + b20 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") + sch.compute_at(block=b20, loop=l11, preserve_unit_loops=True, index=-1) + v21 = sch.sample_categorical( + candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3 + ) + sch.annotate( + block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch", ann_val=v21 + ) + l22, l23, l24, l25, l26 = sch.get_loops(block=b19) + sch.vectorize(loop=l26) + sch.vectorize(loop=l12) + b27 = sch.decompose_reduction(block=b1, loop=l16) + b28 = sch.get_block(name="T_add", func_name="main") + bb1 = sch.get_block(name="T_multiply", func_name="main") + bb2 = 
sch.get_block(name="compute", func_name="main") + bb3 = sch.get_block(name="T_multiply_1", func_name="main") + bb4 = sch.get_block(name="T_add_1", func_name="main") + bb5 = sch.get_block(name="T_multiply_2", func_name="main") + bb6 = sch.get_block(name="compute_1", func_name="main") + sch.compute_inline(block=b28) + sch.compute_inline(block=bb1) + sch.compute_inline(block=bb2) + sch.compute_inline(block=bb3) + sch.compute_inline(block=bb4) + sch.compute_inline(block=bb5) + sch.reverse_compute_inline(block=bb6) + sch.enter_postproc() + sch.unannotate(block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch") + l29, l30, l31, l32, l33 = sch.get_loops(block=b20) + l34, l35, l36 = sch.split( + loop=l33, factors=[None, 256, 8], preserve_unit_iters=True + ) + sch.vectorize(loop=l36) + sch.bind(loop=l35, thread_axis="threadIdx.x") + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func(private=True) +def fused_decode4_fused_matmul6_add4_add5( + lv1381: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), + lv1382: T.Buffer((T.int64(80), T.int64(2560)), "float16"), + lv328: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), + linear_bias195: T.Buffer((T.int64(2560),), "float16"), + lv2062: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), + p_output0_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(2560)), "float16" + ), +): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(2560)), "float16") + var_matmul_intermediate = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(2560)), "float16" + ) + var_T_add_intermediate = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(2560)), "float16" + ) + for i, j in T.grid(T.int64(2560), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1381[v_i // T.int64(8), v_j], lv1382[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv1381[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv1382[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(2560)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv328[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = ( + var_matmul_intermediate[v_i0, v_i1, v_i2] + + lv328[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias195[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = ( + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias195[v_ax2] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + var_T_add_intermediate[v_ax0, v_ax1, v_ax2], lv2062[v_ax0, v_ax1, v_ax2] + ) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, 
v_ax2] = ( + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] + + lv2062[v_ax0, v_ax1, v_ax2] + ) + + +def sch_fused_decode4_fused_matmul6_add4_add5(func): + sch = tvm.tir.Schedule(func) + b0 = sch.get_block(name="decode", func_name="main") + b1 = sch.get_block(name="matmul", func_name="main") + l2, l3, l4, l5 = sch.get_loops(block=b1) + l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) + v7, v8, v9 = sch.sample_perfect_tile( + loop=l6, n=3, max_innermost_factor=4, decision=[10, 256, 1] + ) + l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) + v13, v14, v15 = sch.sample_perfect_tile( + loop=l5, n=3, max_innermost_factor=8, decision=[160, 8, 2] + ) + l16, l17, l18 = sch.split( + loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True + ) + sch.reorder(l10, l11, l16, l17, l18, l12) + sch.bind(loop=l10, thread_axis="blockIdx.x") + sch.bind(loop=l11, thread_axis="threadIdx.x") + sch.compute_inline(block=b0) + b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) + b20 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") + sch.compute_at(block=b20, loop=l11, preserve_unit_loops=True, index=-1) + v21 = sch.sample_categorical( + candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3 + ) + sch.annotate( + block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch", ann_val=v21 + ) + l22, l23, l24, l25, l26 = sch.get_loops(block=b19) + sch.vectorize(loop=l26) + sch.vectorize(loop=l12) + b27 = sch.decompose_reduction(block=b1, loop=l16) + b28 = sch.get_block(name="T_add", func_name="main") + bb4 = sch.get_block(name="T_add_1", func_name="main") + sch.compute_inline(block=b28) + sch.reverse_compute_inline(block=bb4) + sch.enter_postproc() + sch.unannotate(block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch") + l29, l30, l31, l32, l33 = sch.get_loops(block=b20) + l34, l35, l36 = sch.split( + loop=l33, factors=[None, 256, 8], preserve_unit_iters=True + ) + sch.vectorize(loop=l36) + sch.bind(loop=l35, thread_axis="threadIdx.x") + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func(private=True) +def fused_decode3_matmul3( + lv2515: T.Buffer((T.int64(320), T.int64(50432)), "uint32"), + lv2516: T.Buffer((T.int64(80), T.int64(50432)), "float32"), + lv705: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), + var_matmul_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(50432)), "float32" + ), +): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(50432))) + for i, j in T.grid(T.int64(2560), T.int64(50432)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv2515[v_i // T.int64(8), v_j], lv2516[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = ( + T.Cast( + "float32", + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv2515[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7), + ) + * lv2516[v_i // T.int64(32), v_j] + ) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(50432), T.int64(2560)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv705[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + 
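# NOTE: the T.init() branch zeroes the float32 accumulator before the reduction over the + # 2560-wide k axis; given the 50432-wide output, this kernel is presumably the final + # vocabulary projection, computed against the dequantized 4-bit weights. +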
var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = ( + var_matmul_intermediate[v_i0, v_i1, v_i2] + + lv705[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + ) + + +def sch_fused_decode3_matmul3(func): + sch = tvm.tir.Schedule(func) + b0 = sch.get_block(name="decode", func_name="main") + b1 = sch.get_block(name="matmul", func_name="main") + l2, l3, l4, l5 = sch.get_loops(block=b1) + l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) + v7, v8, v9 = sch.sample_perfect_tile( + loop=l6, n=3, max_innermost_factor=4, decision=[197, 128, 2] + ) + l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) + v13, v14, v15 = sch.sample_perfect_tile( + loop=l5, n=3, max_innermost_factor=8, decision=[80, 4, 8] + ) + l16, l17, l18 = sch.split( + loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True + ) + sch.reorder(l10, l11, l16, l17, l18, l12) + sch.bind(loop=l10, thread_axis="blockIdx.x") + sch.bind(loop=l11, thread_axis="threadIdx.x") + sch.compute_inline(block=b0) + b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) + b20 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") + sch.compute_at(block=b20, loop=l11, preserve_unit_loops=True, index=-1) + v21 = sch.sample_categorical( + candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3 + ) + sch.annotate( + block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch", ann_val=v21 + ) + l22, l23, l24, l25, l26 = sch.get_loops(block=b19) + sch.vectorize(loop=l26) + sch.vectorize(loop=l12) + b27 = sch.decompose_reduction(block=b1, loop=l16) + sch.enter_postproc() + sch.unannotate(block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch") + l29, l30, l31, l32, l33 = sch.get_loops(block=b20) + l34, l35, l36 = sch.split( + loop=l33, factors=[None, 128, 8], preserve_unit_iters=True + ) + sch.vectorize(loop=l36) + sch.bind(loop=l35, thread_axis="threadIdx.x") + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func(private=True) +def fused_decode6_fused_matmul9_add7_cast8_cast12_add5_cast7( + lv2509: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), + lv2510: T.Buffer((T.int64(320), T.int64(2560)), "float16"), + lv4105: T.Buffer((T.int64(1), T.int64(1), T.int64(10240)), "float16"), + linear_bias383: T.Buffer((T.int64(2560),), "float32"), + lv701: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), + p_output0_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(2560)), "float32" + ), +): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(10240), T.int64(2560)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) + var_compute_intermediate = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(2560)), "float16" + ) + var_compute_intermediate_1 = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(2560)), "float16" + ) + var_T_add_intermediate_1 = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(2560)), "float16" + ) + for i, j in T.grid(T.int64(10240), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv2509[v_i // T.int64(8), v_j], lv2510[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = ( + T.Cast( + 
"float16", + T.bitwise_and( + T.shift_right( + lv2509[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv2510[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(10240)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv4105[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[ + v_i0, v_i1, v_i2 + ] + T.Cast("float32", lv4105[v_i0, v_i1, v_k]) * T.Cast( + "float32", var_decode_intermediate[v_k, v_i2] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias383[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = ( + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias383[v_ax2] + ) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast( + "float16", var_T_add_intermediate[v_i0, v_i1, v_i2] + ) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("compute_1"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_compute_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2]) + var_compute_intermediate_1[v_i0, v_i1, v_i2] = var_compute_intermediate[ + v_i0, v_i1, v_i2 + ] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], + lv701[v_ax0, v_ax1, v_ax2], + ) + T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = ( + var_compute_intermediate_1[v_ax0, v_ax1, v_ax2] + + lv701[v_ax0, v_ax1, v_ax2] + ) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("compute_2"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_add_intermediate_1[v_i0, v_i1, v_i2]) + T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) + p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast( + "float32", var_T_add_intermediate_1[v_i0, v_i1, v_i2] + ) + + +def sch_fused_decode6_fused_matmul9_add7_cast8_cast12_add5_cast7(func): + sch = tvm.tir.Schedule(func) + b0 = sch.get_block(name="decode", func_name="main") + b1 = sch.get_block(name="matmul", func_name="main") + l2, l3, l4, l5 = sch.get_loops(block=b1) + l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) + v7, v8, v9 = sch.sample_perfect_tile( + loop=l6, n=3, max_innermost_factor=4, decision=[5, 256, 2] + ) + l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) + v13, v14, v15 = sch.sample_perfect_tile( + loop=l5, n=3, max_innermost_factor=8, decision=[320, 4, 8] + ) + l16, l17, l18 = sch.split( + loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True + ) + sch.reorder(l10, l11, l16, l17, l18, l12) + sch.bind(loop=l10, thread_axis="blockIdx.x") 
+ sch.bind(loop=l11, thread_axis="threadIdx.x") + sch.compute_inline(block=b0) + b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) + b20 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") + sch.compute_at(block=b20, loop=l11, preserve_unit_loops=True, index=-1) + v21 = sch.sample_categorical( + candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3 + ) + sch.annotate( + block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch", ann_val=v21 + ) + l22, l23, l24, l25, l26 = sch.get_loops(block=b19) + sch.vectorize(loop=l26) + sch.vectorize(loop=l12) + b27 = sch.decompose_reduction(block=b1, loop=l16) + b28 = sch.get_block(name="T_add", func_name="main") + bb1 = sch.get_block(name="compute", func_name="main") + bb2 = sch.get_block(name="compute_1", func_name="main") + bb3 = sch.get_block(name="T_add_1", func_name="main") + bb4 = sch.get_block(name="compute_2", func_name="main") + sch.compute_inline(block=b28) + sch.compute_inline(block=bb1) + sch.compute_inline(block=bb2) + sch.compute_inline(block=bb3) + sch.reverse_compute_inline(block=bb4) + sch.enter_postproc() + sch.unannotate(block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch") + l29, l30, l31, l32, l33 = sch.get_loops(block=b20) + l34, l35, l36 = sch.split( + loop=l33, factors=[None, 256, 8], preserve_unit_iters=True + ) + sch.vectorize(loop=l36) + sch.bind(loop=l35, thread_axis="threadIdx.x") + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func(private=True) +def fused_decode2_fused_NT_matmul3_add6_gelu1_cast11( + lv36: T.Buffer((T.int64(320), T.int64(10240)), "uint32"), + lv37: T.Buffer((T.int64(80), T.int64(10240)), "float16"), + p_lv57: T.handle, + linear_bias4: T.Buffer((T.int64(10240),), "float32"), + p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv57 = T.match_buffer(p_lv57, (T.int64(1), n, T.int64(2560)), "float16") + p_output0_intermediate = T.match_buffer( + p_output0, (T.int64(1), n, T.int64(10240)), "float16" + ) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(2560), T.int64(10240)), "float16") + var_T_transpose_intermediate = T.alloc_buffer( + (T.int64(10240), T.int64(2560)), "float16" + ) + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + T_multiply = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + compute = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + T_multiply_1 = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + T_add = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + var_T_multiply_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + for i, j in T.grid(T.int64(2560), T.int64(10240)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv36[v_i // T.int64(8), v_j], lv37[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv36[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv37[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(10240), T.int64(2560)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(var_T_transpose_intermediate[v_ax0, v_ax1]) + var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, 
v_ax0] + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(10240), T.int64(2560)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv57[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[ + v_i0, v_i1, v_i2 + ] + T.Cast("float32", lv57[v_i0, v_i1, v_k]) * T.Cast( + "float32", var_T_transpose_intermediate[v_i2, v_k] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias4[v_ax2] + ) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = ( + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias4[v_ax2] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) + T_multiply[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[ + v_ax0, v_ax1, v_ax2 + ] * T.float32(0.70710678118654757) + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(T_multiply[v_i0, v_i1, v_i2]) + T.writes(compute[v_i0, v_i1, v_i2]) + compute[v_i0, v_i1, v_i2] = T.erf(T_multiply[v_i0, v_i1, v_i2]) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_multiply_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(compute[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2]) + T_multiply_1[v_ax0, v_ax1, v_ax2] = compute[ + v_ax0, v_ax1, v_ax2 + ] * T.float32(0.5) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2]) + T.writes(T_add[v_ax0, v_ax1, v_ax2]) + T_add[v_ax0, v_ax1, v_ax2] = ( + T.float32(0.5) + T_multiply_1[v_ax0, v_ax1, v_ax2] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_multiply_2"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + var_T_add_intermediate[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2] + ) + T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = ( + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2] + ) + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("compute_1"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_multiply_intermediate[v_i0, v_i1, v_i2]) + T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) + p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast( + "float16", var_T_multiply_intermediate[v_i0, v_i1, v_i2] + ) + + +@T.prim_func(private=True) +def fused_decode2_fused_NT_matmul3_add6_gelu1_cast11_after( + lv36: T.Buffer((T.int64(320), T.int64(10240)), "uint32"), + lv37: T.Buffer((T.int64(80), T.int64(10240)), "float16"), + p_lv57: T.handle, + linear_bias4: T.Buffer((T.int64(10240),), "float32"), + p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True), "tir.noalias": T.bool(True)}) + n = T.int64() + lv57 = 
T.match_buffer(p_lv57, (T.int64(1), n, T.int64(2560)), "float16") + p_output0_intermediate = T.match_buffer( + p_output0, (T.int64(1), n, T.int64(10240)), "float16" + ) + with T.block("root"): + T.reads() + T.writes() + T.block_attr({"meta_schedule.thread_extent_low_inclusive": 32}) + decode_local = T.alloc_buffer( + (T.int64(2560), T.int64(10240)), "float16", scope="local" + ) + lv36_local = T.alloc_buffer( + (T.int64(320), T.int64(10240)), "uint32", scope="local" + ) + lv37_local = T.alloc_buffer( + (T.int64(80), T.int64(10240)), "float16", scope="local" + ) + lv57_pad_local = T.alloc_buffer( + (T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2560)), + "float16", + scope="local", + ) + var_NT_matmul_intermediate_pad_local = T.alloc_buffer( + ( + T.int64(1), + (n + T.int64(31)) // T.int64(32) * T.int64(32), + T.int64(10240), + ), + scope="local", + ) + for i0_i1_fused_0_i0_i1_fused_1_0_fused in T.thread_binding( + (n + T.int64(31)) // T.int64(32), thread="blockIdx.y" + ): + for i2_0 in T.thread_binding(T.int64(80), thread="blockIdx.x"): + for i0_i1_fused_1_1 in T.thread_binding( + T.int64(8), thread="threadIdx.y" + ): + for i2_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"): + for i0_i1_fused_1_2_init in range(T.int64(4)): + for i2_2_init in T.vectorized(T.int64(8)): + with T.block("NT_matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial( + (n + T.int64(31)) // T.int64(32) * T.int64(32), + i0_i1_fused_0_i0_i1_fused_1_0_fused + * T.int64(32) + + i0_i1_fused_1_1 * T.int64(4) + + i0_i1_fused_1_2_init, + ) + v_i2 = T.axis.spatial( + T.int64(10240), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + i2_2_init, + ) + T.reads() + T.writes( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + ) + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] = T.float32(0) + for k_0_0, k_0_1 in T.grid(T.int64(20), T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(8)): + with T.block("lv37_local"): + v0 = T.axis.spatial( + T.int64(80), + k_0_0 * T.int64(4) + k_0_1 + ax0, + ) + v1 = T.axis.spatial( + T.int64(10240), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + ax1, + ) + T.reads(lv37[v0, v1]) + T.writes(lv37_local[v0, v1]) + lv37_local[v0, v1] = lv37[v0, v1] + for k_1 in range(T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(8)): + with T.block("lv36_local"): + v0 = T.axis.spatial( + T.int64(320), + k_0_0 * T.int64(16) + + k_0_1 * T.int64(4) + + k_1 + + ax0, + ) + v1 = T.axis.spatial( + T.int64(10240), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + ax1, + ) + T.reads(lv36[v0, v1]) + T.writes(lv36_local[v0, v1]) + lv36_local[v0, v1] = lv36[v0, v1] + for k_2 in range(T.int64(8)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(8)): + with T.block("decode"): + v_i = T.axis.spatial( + T.int64(2560), + k_0_0 * T.int64(128) + + k_0_1 * T.int64(32) + + k_1 * T.int64(8) + + k_2 + + ax0, + ) + v_j = T.axis.spatial( + T.int64(10240), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + ax1, + ) + T.reads( + lv36_local[v_i // T.int64(8), v_j], + lv37_local[v_i // T.int64(32), v_j], + ) + T.writes(decode_local[v_i, v_j]) + decode_local[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv36_local[ + v_i // T.int64(8), + v_j, + ], + T.Cast( + "uint32", + v_i % T.int64(8), + ) + * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv37_local[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(1), T.int64(4)): + for ax2 in 
T.vectorized(T.int64(1)): + with T.block("lv57_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial( + (n + T.int64(31)) + // T.int64(32) + * T.int64(32), + i0_i1_fused_0_i0_i1_fused_1_0_fused + * T.int64(32) + + i0_i1_fused_1_1 * T.int64(4) + + ax1, + ) + v2 = T.axis.spatial( + T.int64(2560), + k_0_0 * T.int64(128) + + k_0_1 * T.int64(32) + + k_1 * T.int64(8) + + k_2 + + ax2, + ) + T.reads(lv57[v0, v1, v2]) + T.writes(lv57_pad_local[v0, v1, v2]) + lv57_pad_local[ + v0, v1, v2 + ] = T.if_then_else( + v1 < n, + lv57[v0, v1, v2], + T.float16(0), + ) + for i0_i1_fused_1_2 in range(T.int64(4)): + for i2_2 in T.vectorized(T.int64(8)): + with T.block("NT_matmul_update"): + v_i0 = T.axis.spatial( + T.int64(1), T.int64(0) + ) + v_i1 = T.axis.spatial( + (n + T.int64(31)) + // T.int64(32) + * T.int64(32), + i0_i1_fused_0_i0_i1_fused_1_0_fused + * T.int64(32) + + i0_i1_fused_1_1 * T.int64(4) + + i0_i1_fused_1_2, + ) + v_i2 = T.axis.spatial( + T.int64(10240), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + i2_2, + ) + v_k = T.axis.reduce( + T.int64(2560), + k_0_0 * T.int64(128) + + k_0_1 * T.int64(32) + + k_1 * T.int64(8) + + k_2, + ) + T.reads( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ], + lv57_pad_local[v_i0, v_i1, v_k], + decode_local[v_k, v_i2], + ) + T.writes( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + ) + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] = var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + T.Cast( + "float32", + lv57_pad_local[v_i0, v_i1, v_k], + ) * T.Cast( + "float32", decode_local[v_k, v_i2] + ) + for ax0, ax1 in T.grid(T.int64(1), T.int64(4)): + for ax2 in T.vectorized(T.int64(8)): + with T.block("var_NT_matmul_intermediate_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial( + (n + T.int64(31)) // T.int64(32) * T.int64(32), + i0_i1_fused_0_i0_i1_fused_1_0_fused + * T.int64(32) + + i0_i1_fused_1_1 * T.int64(4) + + ax1, + ) + v2 = T.axis.spatial( + T.int64(10240), + i2_0 * T.int64(128) + i2_1 * T.int64(8) + ax2, + ) + T.reads( + var_NT_matmul_intermediate_pad_local[ + v0, v1, v2 + ], + linear_bias4[v2], + ) + T.writes(p_output0_intermediate[v0, v1, v2]) + if v1 < n: + p_output0_intermediate[v0, v1, v2] = T.Cast( + "float16", + ( + var_NT_matmul_intermediate_pad_local[ + v0, v1, v2 + ] + + linear_bias4[v2] + ) + * ( + T.float32(0.5) + + T.erf( + ( + var_NT_matmul_intermediate_pad_local[ + v0, v1, v2 + ] + + linear_bias4[v2] + ) + * T.float32(0.70710678118654757) + ) + * T.float32(0.5) + ), + ) + + +@T.prim_func(private=True) +def fused_decode1_fused_NT_matmul1_add4( + lv8: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), + lv9: T.Buffer((T.int64(80), T.int64(2560)), "float16"), + p_lv9: T.handle, + linear_bias: T.Buffer((T.int64(2560),), "float16"), + p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv9_1 = T.match_buffer(p_lv9, (T.int64(1), n, T.int64(2560)), "float16") + p_output0_intermediate = T.match_buffer( + p_output0, (T.int64(1), n, T.int64(2560)), "float16" + ) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(2560), T.int64(2560)), "float16") + var_T_transpose_intermediate = T.alloc_buffer( + (T.int64(2560), T.int64(2560)), "float16" + ) + var_NT_matmul_intermediate = T.alloc_buffer( + (T.int64(1), n, T.int64(2560)), "float16" + ) + for i, j in T.grid(T.int64(2560), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv8[v_i // T.int64(8), v_j], lv9[v_i // T.int64(32), 
v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv8[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv9[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(2560), T.int64(2560)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(var_T_transpose_intermediate[v_ax0, v_ax1]) + var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv9_1[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + + lv9_1[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias[v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias[v_ax2] + ) + + +@T.prim_func(private=True) +def fused_decode1_fused_NT_matmul1_add4_after( + lv8: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), + lv9: T.Buffer((T.int64(80), T.int64(2560)), "float16"), + p_lv9: T.handle, + linear_bias: T.Buffer((T.int64(2560),), "float16"), + p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True), "tir.noalias": T.bool(True)}) + n = T.int64() + lv9_1 = T.match_buffer(p_lv9, (T.int64(1), n, T.int64(2560)), "float16") + p_output0_intermediate = T.match_buffer( + p_output0, (T.int64(1), n, T.int64(2560)), "float16" + ) + with T.block("root"): + T.reads() + T.writes() + T.block_attr({"meta_schedule.thread_extent_low_inclusive": 32}) + decode_local = T.alloc_buffer( + (T.int64(2560), T.int64(2560)), "float16", scope="local" + ) + lv8_local = T.alloc_buffer( + (T.int64(320), T.int64(2560)), "uint32", scope="local" + ) + lv9_local = T.alloc_buffer( + (T.int64(80), T.int64(2560)), "float16", scope="local" + ) + lv9_1_pad_local = T.alloc_buffer( + (T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2560)), + "float16", + scope="local", + ) + var_NT_matmul_intermediate_pad_local = T.alloc_buffer( + (T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2560)), + "float16", + scope="local", + ) + for i0_i1_fused_0_i0_i1_fused_1_0_fused in T.thread_binding( + (n + T.int64(31)) // T.int64(32), thread="blockIdx.y" + ): + for i2_0 in T.thread_binding(T.int64(20), thread="blockIdx.x"): + for i0_i1_fused_1_1 in T.thread_binding( + T.int64(8), thread="threadIdx.y" + ): + for i2_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"): + for i0_i1_fused_1_2_init in range(T.int64(4)): + for i2_2_init in T.vectorized(T.int64(8)): + with T.block("NT_matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial( + (n + T.int64(31)) // T.int64(32) * T.int64(32), + i0_i1_fused_0_i0_i1_fused_1_0_fused + * T.int64(32) + + i0_i1_fused_1_1 * T.int64(4) + + i0_i1_fused_1_2_init, + ) + v_i2 = T.axis.spatial( + T.int64(2560), + 
i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + i2_2_init, + ) + T.reads() + T.writes( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + ) + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] = T.float16(0) + for k_0_0, k_0_1 in T.grid(T.int64(20), T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(8)): + with T.block("lv9_local"): + v0 = T.axis.spatial( + T.int64(80), + k_0_0 * T.int64(4) + k_0_1 + ax0, + ) + v1 = T.axis.spatial( + T.int64(2560), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + ax1, + ) + T.reads(lv9[v0, v1]) + T.writes(lv9_local[v0, v1]) + lv9_local[v0, v1] = lv9[v0, v1] + for k_1 in range(T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(8)): + with T.block("lv8_local"): + v0 = T.axis.spatial( + T.int64(320), + k_0_0 * T.int64(16) + + k_0_1 * T.int64(4) + + k_1 + + ax0, + ) + v1 = T.axis.spatial( + T.int64(2560), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + ax1, + ) + T.reads(lv8[v0, v1]) + T.writes(lv8_local[v0, v1]) + lv8_local[v0, v1] = lv8[v0, v1] + for k_2 in range(T.int64(8)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(8)): + with T.block("decode"): + v_i = T.axis.spatial( + T.int64(2560), + k_0_0 * T.int64(128) + + k_0_1 * T.int64(32) + + k_1 * T.int64(8) + + k_2 + + ax0, + ) + v_j = T.axis.spatial( + T.int64(2560), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + ax1, + ) + T.reads( + lv8_local[v_i // T.int64(8), v_j], + lv9_local[v_i // T.int64(32), v_j], + ) + T.writes(decode_local[v_i, v_j]) + decode_local[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv8_local[ + v_i // T.int64(8), + v_j, + ], + T.Cast( + "uint32", + v_i % T.int64(8), + ) + * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv9_local[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(1), T.int64(4)): + for ax2 in T.vectorized(T.int64(1)): + with T.block("lv9_1_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial( + (n + T.int64(31)) + // T.int64(32) + * T.int64(32), + i0_i1_fused_0_i0_i1_fused_1_0_fused + * T.int64(32) + + i0_i1_fused_1_1 * T.int64(4) + + ax1, + ) + v2 = T.axis.spatial( + T.int64(2560), + k_0_0 * T.int64(128) + + k_0_1 * T.int64(32) + + k_1 * T.int64(8) + + k_2 + + ax2, + ) + T.reads(lv9_1[v0, v1, v2]) + T.writes(lv9_1_pad_local[v0, v1, v2]) + lv9_1_pad_local[ + v0, v1, v2 + ] = T.if_then_else( + v1 < n, + lv9_1[v0, v1, v2], + T.float16(0), + ) + for i0_i1_fused_1_2 in range(T.int64(4)): + for i2_2 in T.vectorized(T.int64(8)): + with T.block("NT_matmul_update"): + v_i0 = T.axis.spatial( + T.int64(1), T.int64(0) + ) + v_i1 = T.axis.spatial( + (n + T.int64(31)) + // T.int64(32) + * T.int64(32), + i0_i1_fused_0_i0_i1_fused_1_0_fused + * T.int64(32) + + i0_i1_fused_1_1 * T.int64(4) + + i0_i1_fused_1_2, + ) + v_i2 = T.axis.spatial( + T.int64(2560), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + i2_2, + ) + v_k = T.axis.reduce( + T.int64(2560), + k_0_0 * T.int64(128) + + k_0_1 * T.int64(32) + + k_1 * T.int64(8) + + k_2, + ) + T.reads( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ], + lv9_1_pad_local[v_i0, v_i1, v_k], + decode_local[v_k, v_i2], + ) + T.writes( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + ) + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] = ( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + + lv9_1_pad_local[v_i0, v_i1, v_k] + * decode_local[v_k, v_i2] + ) + for ax0, ax1 in T.grid(T.int64(1), T.int64(4)): + for ax2 in 
T.vectorized(T.int64(8)): + with T.block("var_NT_matmul_intermediate_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial( + (n + T.int64(31)) // T.int64(32) * T.int64(32), + i0_i1_fused_0_i0_i1_fused_1_0_fused + * T.int64(32) + + i0_i1_fused_1_1 * T.int64(4) + + ax1, + ) + v2 = T.axis.spatial( + T.int64(2560), + i2_0 * T.int64(128) + i2_1 * T.int64(8) + ax2, + ) + T.reads( + var_NT_matmul_intermediate_pad_local[ + v0, v1, v2 + ], + linear_bias[v2], + ) + T.writes(p_output0_intermediate[v0, v1, v2]) + if v1 < n: + p_output0_intermediate[v0, v1, v2] = ( + var_NT_matmul_intermediate_pad_local[ + v0, v1, v2 + ] + + linear_bias[v2] + ) + + +@T.prim_func(private=True) +def fused_decode3_fused_NT_matmul4_add7_cast8_cast12_add5( + lv43: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), + lv44: T.Buffer((T.int64(320), T.int64(2560)), "float16"), + p_lv63: T.handle, + linear_bias5: T.Buffer((T.int64(2560),), "float32"), + p_lv7: T.handle, + p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv63 = T.match_buffer(p_lv63, (T.int64(1), n, T.int64(10240)), "float16") + lv7 = T.match_buffer(p_lv7, (T.int64(1), n, T.int64(2560)), "float16") + p_output0_intermediate = T.match_buffer( + p_output0, (T.int64(1), n, T.int64(2560)), "float16" + ) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(10240), T.int64(2560)), "float16") + var_T_transpose_intermediate = T.alloc_buffer( + (T.int64(2560), T.int64(10240)), "float16" + ) + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + var_compute_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + var_compute_intermediate_1 = T.alloc_buffer( + (T.int64(1), n, T.int64(2560)), "float16" + ) + for i, j in T.grid(T.int64(10240), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv43[v_i // T.int64(8), v_j], lv44[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv43[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv44[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(2560), T.int64(10240)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(var_T_transpose_intermediate[v_ax0, v_ax1]) + var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(10240)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv63[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[ + v_i0, v_i1, v_i2 + ] + T.Cast("float32", lv63[v_i0, v_i1, v_k]) * T.Cast( + "float32", var_T_transpose_intermediate[v_i2, v_k] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias5[v_ax2] + ) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = ( + 
var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias5[v_ax2] + ) + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast( + "float16", var_T_add_intermediate[v_i0, v_i1, v_i2] + ) + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("compute_1"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_compute_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2]) + var_compute_intermediate_1[v_i0, v_i1, v_i2] = var_compute_intermediate[ + v_i0, v_i1, v_i2 + ] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], + lv7[v_ax0, v_ax1, v_ax2], + ) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( + var_compute_intermediate_1[v_ax0, v_ax1, v_ax2] + + lv7[v_ax0, v_ax1, v_ax2] + ) + + +@T.prim_func(private=True) +def fused_decode3_fused_NT_matmul4_add7_cast8_cast12_add5_after( + lv43: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), + lv44: T.Buffer((T.int64(320), T.int64(2560)), "float16"), + p_lv63: T.handle, + linear_bias5: T.Buffer((T.int64(2560),), "float32"), + p_lv7: T.handle, + p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True), "tir.noalias": T.bool(True)}) + n = T.int64() + lv63 = T.match_buffer(p_lv63, (T.int64(1), n, T.int64(10240)), "float16") + lv7 = T.match_buffer(p_lv7, (T.int64(1), n, T.int64(2560)), "float16") + p_output0_intermediate = T.match_buffer( + p_output0, (T.int64(1), n, T.int64(2560)), "float16" + ) + with T.block("root"): + T.reads() + T.writes() + T.block_attr({"meta_schedule.thread_extent_low_inclusive": 32}) + decode_local = T.alloc_buffer( + (T.int64(10240), T.int64(2560)), "float16", scope="local" + ) + lv43_local = T.alloc_buffer( + (T.int64(1280), T.int64(2560)), "uint32", scope="local" + ) + lv44_local = T.alloc_buffer( + (T.int64(320), T.int64(2560)), "float16", scope="local" + ) + lv63_pad_local = T.alloc_buffer( + ( + T.int64(1), + (n + T.int64(31)) // T.int64(32) * T.int64(32), + T.int64(10240), + ), + "float16", + scope="local", + ) + var_NT_matmul_intermediate_pad_local = T.alloc_buffer( + (T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2560)), + scope="local", + ) + for i0_i1_fused_0_i0_i1_fused_1_0_fused in T.thread_binding( + (n + T.int64(31)) // T.int64(32), thread="blockIdx.y" + ): + for i2_0 in T.thread_binding(T.int64(20), thread="blockIdx.x"): + for i0_i1_fused_1_1 in T.thread_binding( + T.int64(8), thread="threadIdx.y" + ): + for i2_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"): + for i0_i1_fused_1_2_init in range(T.int64(4)): + for i2_2_init in T.vectorized(T.int64(8)): + with T.block("NT_matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial( + (n + T.int64(31)) // T.int64(32) * T.int64(32), + i0_i1_fused_0_i0_i1_fused_1_0_fused + * T.int64(32) + + i0_i1_fused_1_1 * T.int64(4) + + i0_i1_fused_1_2_init, + ) + v_i2 = T.axis.spatial( + T.int64(2560), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + i2_2_init, + ) + T.reads() + T.writes( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + ) + 
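# NOTE: the accumulator lives in a register-local buffer padded so that the dynamic sequence + # length n is rounded up to a multiple of 32; padded rows read zeros from lv63_pad_local and + # are skipped at the guarded write-back (if v1 < n). +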
var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] = T.float32(0) + for k_0_0, k_0_1 in T.grid(T.int64(80), T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(8)): + with T.block("lv44_local"): + v0 = T.axis.spatial( + T.int64(320), + k_0_0 * T.int64(4) + k_0_1 + ax0, + ) + v1 = T.axis.spatial( + T.int64(2560), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + ax1, + ) + T.reads(lv44[v0, v1]) + T.writes(lv44_local[v0, v1]) + lv44_local[v0, v1] = lv44[v0, v1] + for k_1 in range(T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(8)): + with T.block("lv43_local"): + v0 = T.axis.spatial( + T.int64(1280), + k_0_0 * T.int64(16) + + k_0_1 * T.int64(4) + + k_1 + + ax0, + ) + v1 = T.axis.spatial( + T.int64(2560), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + ax1, + ) + T.reads(lv43[v0, v1]) + T.writes(lv43_local[v0, v1]) + lv43_local[v0, v1] = lv43[v0, v1] + for k_2 in range(T.int64(8)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(8)): + with T.block("decode"): + v_i = T.axis.spatial( + T.int64(10240), + k_0_0 * T.int64(128) + + k_0_1 * T.int64(32) + + k_1 * T.int64(8) + + k_2 + + ax0, + ) + v_j = T.axis.spatial( + T.int64(2560), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + ax1, + ) + T.reads( + lv43_local[v_i // T.int64(8), v_j], + lv44_local[v_i // T.int64(32), v_j], + ) + T.writes(decode_local[v_i, v_j]) + decode_local[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv43_local[ + v_i // T.int64(8), + v_j, + ], + T.Cast( + "uint32", + v_i % T.int64(8), + ) + * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv44_local[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(1), T.int64(4)): + for ax2 in T.vectorized(T.int64(1)): + with T.block("lv63_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial( + (n + T.int64(31)) + // T.int64(32) + * T.int64(32), + i0_i1_fused_0_i0_i1_fused_1_0_fused + * T.int64(32) + + i0_i1_fused_1_1 * T.int64(4) + + ax1, + ) + v2 = T.axis.spatial( + T.int64(10240), + k_0_0 * T.int64(128) + + k_0_1 * T.int64(32) + + k_1 * T.int64(8) + + k_2 + + ax2, + ) + T.reads(lv63[v0, v1, v2]) + T.writes(lv63_pad_local[v0, v1, v2]) + lv63_pad_local[ + v0, v1, v2 + ] = T.if_then_else( + v1 < n, + lv63[v0, v1, v2], + T.float16(0), + ) + for i0_i1_fused_1_2 in range(T.int64(4)): + for i2_2 in T.vectorized(T.int64(8)): + with T.block("NT_matmul_update"): + v_i0 = T.axis.spatial( + T.int64(1), T.int64(0) + ) + v_i1 = T.axis.spatial( + (n + T.int64(31)) + // T.int64(32) + * T.int64(32), + i0_i1_fused_0_i0_i1_fused_1_0_fused + * T.int64(32) + + i0_i1_fused_1_1 * T.int64(4) + + i0_i1_fused_1_2, + ) + v_i2 = T.axis.spatial( + T.int64(2560), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + i2_2, + ) + v_k = T.axis.reduce( + T.int64(10240), + k_0_0 * T.int64(128) + + k_0_1 * T.int64(32) + + k_1 * T.int64(8) + + k_2, + ) + T.reads( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ], + lv63_pad_local[v_i0, v_i1, v_k], + decode_local[v_k, v_i2], + ) + T.writes( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + ) + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] = var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + T.Cast( + "float32", + lv63_pad_local[v_i0, v_i1, v_k], + ) * T.Cast( + "float32", decode_local[v_k, v_i2] + ) + for ax0, ax1 in T.grid(T.int64(1), T.int64(4)): + for ax2 in T.vectorized(T.int64(8)): + with T.block("var_NT_matmul_intermediate_pad_local"): + v0 = 
T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial( + (n + T.int64(31)) // T.int64(32) * T.int64(32), + i0_i1_fused_0_i0_i1_fused_1_0_fused + * T.int64(32) + + i0_i1_fused_1_1 * T.int64(4) + + ax1, + ) + v2 = T.axis.spatial( + T.int64(2560), + i2_0 * T.int64(128) + i2_1 * T.int64(8) + ax2, + ) + T.reads( + var_NT_matmul_intermediate_pad_local[ + v0, v1, v2 + ], + linear_bias5[v2], + lv7[v0, v1, v2], + ) + T.writes(p_output0_intermediate[v0, v1, v2]) + if v1 < n: + p_output0_intermediate[v0, v1, v2] = ( + T.Cast( + "float16", + var_NT_matmul_intermediate_pad_local[ + v0, v1, v2 + ] + + linear_bias5[v2], + ) + + lv7[v0, v1, v2] + ) + + +@T.prim_func(private=True) +def fused_decode1_fused_NT_matmul1_add4_add5( + lv29: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), + lv30: T.Buffer((T.int64(80), T.int64(2560)), "float16"), + p_lv49: T.handle, + linear_bias3: T.Buffer((T.int64(2560),), "float16"), + p_lv2: T.handle, + p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv49 = T.match_buffer(p_lv49, (T.int64(1), n, T.int64(2560)), "float16") + lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(2560)), "float16") + p_output0_intermediate = T.match_buffer( + p_output0, (T.int64(1), n, T.int64(2560)), "float16" + ) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(2560), T.int64(2560)), "float16") + var_T_transpose_intermediate = T.alloc_buffer( + (T.int64(2560), T.int64(2560)), "float16" + ) + var_NT_matmul_intermediate = T.alloc_buffer( + (T.int64(1), n, T.int64(2560)), "float16" + ) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + for i, j in T.grid(T.int64(2560), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv29[v_i // T.int64(8), v_j], lv30[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv29[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv30[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(2560), T.int64(2560)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(var_T_transpose_intermediate[v_ax0, v_ax1]) + var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv49[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + + lv49[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias3[v_ax2] + ) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = ( + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias3[v_ax2] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + var_T_add_intermediate[v_ax0, 
v_ax1, v_ax2], lv2[v_ax0, v_ax1, v_ax2] + ) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] + lv2[v_ax0, v_ax1, v_ax2] + ) + + +@T.prim_func(private=True) +def fused_decode1_fused_NT_matmul1_add4_add5_after( + lv29: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), + lv30: T.Buffer((T.int64(80), T.int64(2560)), "float16"), + p_lv49: T.handle, + linear_bias3: T.Buffer((T.int64(2560),), "float16"), + p_lv2: T.handle, + p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True), "tir.noalias": T.bool(True)}) + n = T.int64() + lv49 = T.match_buffer(p_lv49, (T.int64(1), n, T.int64(2560)), "float16") + lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(2560)), "float16") + p_output0_intermediate = T.match_buffer( + p_output0, (T.int64(1), n, T.int64(2560)), "float16" + ) + with T.block("root"): + T.reads() + T.writes() + T.block_attr({"meta_schedule.thread_extent_low_inclusive": 32}) + decode_local = T.alloc_buffer( + (T.int64(2560), T.int64(2560)), "float16", scope="local" + ) + lv29_local = T.alloc_buffer( + (T.int64(320), T.int64(2560)), "uint32", scope="local" + ) + lv30_local = T.alloc_buffer( + (T.int64(80), T.int64(2560)), "float16", scope="local" + ) + lv49_pad_local = T.alloc_buffer( + (T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2560)), + "float16", + scope="local", + ) + var_NT_matmul_intermediate_pad_local = T.alloc_buffer( + (T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2560)), + "float16", + scope="local", + ) + for i0_i1_fused_0_i0_i1_fused_1_0_fused in T.thread_binding( + (n + T.int64(31)) // T.int64(32), thread="blockIdx.y" + ): + for i2_0 in T.thread_binding(T.int64(20), thread="blockIdx.x"): + for i0_i1_fused_1_1 in T.thread_binding( + T.int64(8), thread="threadIdx.y" + ): + for i2_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"): + for i0_i1_fused_1_2_init in range(T.int64(4)): + for i2_2_init in T.vectorized(T.int64(8)): + with T.block("NT_matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial( + (n + T.int64(31)) // T.int64(32) * T.int64(32), + i0_i1_fused_0_i0_i1_fused_1_0_fused + * T.int64(32) + + i0_i1_fused_1_1 * T.int64(4) + + i0_i1_fused_1_2_init, + ) + v_i2 = T.axis.spatial( + T.int64(2560), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + i2_2_init, + ) + T.reads() + T.writes( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + ) + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] = T.float16(0) + for k_0_0, k_0_1 in T.grid(T.int64(20), T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(8)): + with T.block("lv30_local"): + v0 = T.axis.spatial( + T.int64(80), + k_0_0 * T.int64(4) + k_0_1 + ax0, + ) + v1 = T.axis.spatial( + T.int64(2560), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + ax1, + ) + T.reads(lv30[v0, v1]) + T.writes(lv30_local[v0, v1]) + lv30_local[v0, v1] = lv30[v0, v1] + for k_1 in range(T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(8)): + with T.block("lv29_local"): + v0 = T.axis.spatial( + T.int64(320), + k_0_0 * T.int64(16) + + k_0_1 * T.int64(4) + + k_1 + + ax0, + ) + v1 = T.axis.spatial( + T.int64(2560), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + ax1, + ) + T.reads(lv29[v0, v1]) + T.writes(lv29_local[v0, v1]) + lv29_local[v0, v1] = lv29[v0, v1] + for k_2 in range(T.int64(8)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(8)): + with 
T.block("decode"): + v_i = T.axis.spatial( + T.int64(2560), + k_0_0 * T.int64(128) + + k_0_1 * T.int64(32) + + k_1 * T.int64(8) + + k_2 + + ax0, + ) + v_j = T.axis.spatial( + T.int64(2560), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + ax1, + ) + T.reads( + lv29_local[v_i // T.int64(8), v_j], + lv30_local[v_i // T.int64(32), v_j], + ) + T.writes(decode_local[v_i, v_j]) + decode_local[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv29_local[ + v_i // T.int64(8), + v_j, + ], + T.Cast( + "uint32", + v_i % T.int64(8), + ) + * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv30_local[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(1), T.int64(4)): + for ax2 in T.vectorized(T.int64(1)): + with T.block("lv49_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial( + (n + T.int64(31)) + // T.int64(32) + * T.int64(32), + i0_i1_fused_0_i0_i1_fused_1_0_fused + * T.int64(32) + + i0_i1_fused_1_1 * T.int64(4) + + ax1, + ) + v2 = T.axis.spatial( + T.int64(2560), + k_0_0 * T.int64(128) + + k_0_1 * T.int64(32) + + k_1 * T.int64(8) + + k_2 + + ax2, + ) + T.reads(lv49[v0, v1, v2]) + T.writes(lv49_pad_local[v0, v1, v2]) + lv49_pad_local[ + v0, v1, v2 + ] = T.if_then_else( + v1 < n, + lv49[v0, v1, v2], + T.float16(0), + ) + for i0_i1_fused_1_2 in range(T.int64(4)): + for i2_2 in T.vectorized(T.int64(8)): + with T.block("NT_matmul_update"): + v_i0 = T.axis.spatial( + T.int64(1), T.int64(0) + ) + v_i1 = T.axis.spatial( + (n + T.int64(31)) + // T.int64(32) + * T.int64(32), + i0_i1_fused_0_i0_i1_fused_1_0_fused + * T.int64(32) + + i0_i1_fused_1_1 * T.int64(4) + + i0_i1_fused_1_2, + ) + v_i2 = T.axis.spatial( + T.int64(2560), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + i2_2, + ) + v_k = T.axis.reduce( + T.int64(2560), + k_0_0 * T.int64(128) + + k_0_1 * T.int64(32) + + k_1 * T.int64(8) + + k_2, + ) + T.reads( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ], + lv49_pad_local[v_i0, v_i1, v_k], + decode_local[v_k, v_i2], + ) + T.writes( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + ) + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] = ( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + + lv49_pad_local[v_i0, v_i1, v_k] + * decode_local[v_k, v_i2] + ) + for ax0, ax1 in T.grid(T.int64(1), T.int64(4)): + for ax2 in T.vectorized(T.int64(8)): + with T.block("var_NT_matmul_intermediate_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial( + (n + T.int64(31)) // T.int64(32) * T.int64(32), + i0_i1_fused_0_i0_i1_fused_1_0_fused + * T.int64(32) + + i0_i1_fused_1_1 * T.int64(4) + + ax1, + ) + v2 = T.axis.spatial( + T.int64(2560), + i2_0 * T.int64(128) + i2_1 * T.int64(8) + ax2, + ) + T.reads( + var_NT_matmul_intermediate_pad_local[ + v0, v1, v2 + ], + linear_bias3[v2], + lv2[v0, v1, v2], + ) + T.writes(p_output0_intermediate[v0, v1, v2]) + if v1 < n: + p_output0_intermediate[v0, v1, v2] = ( + var_NT_matmul_intermediate_pad_local[ + v0, v1, v2 + ] + + linear_bias3[v2] + + lv2[v0, v1, v2] + ) + + +@T.prim_func(private=True) +def fused_decode3_fused_NT_matmul4_add7_cast8_cast12_add5_cast7( + lv1345: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), + lv1346: T.Buffer((T.int64(320), T.int64(2560)), "float16"), + p_lv2047: T.handle, + linear_bias191: T.Buffer((T.int64(2560),), "float32"), + p_lv317: T.handle, + p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv2047 = T.match_buffer(p_lv2047, (T.int64(1), n, T.int64(10240)), 
"float16") + lv317 = T.match_buffer(p_lv317, (T.int64(1), n, T.int64(2560)), "float16") + p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560))) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(10240), T.int64(2560)), "float16") + var_T_transpose_intermediate = T.alloc_buffer( + (T.int64(2560), T.int64(10240)), "float16" + ) + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + var_compute_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + var_compute_intermediate_1 = T.alloc_buffer( + (T.int64(1), n, T.int64(2560)), "float16" + ) + var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + for i, j in T.grid(T.int64(10240), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1345[v_i // T.int64(8), v_j], lv1346[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv1345[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv1346[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(2560), T.int64(10240)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(var_T_transpose_intermediate[v_ax0, v_ax1]) + var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(10240)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv2047[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[ + v_i0, v_i1, v_i2 + ] + T.Cast("float32", lv2047[v_i0, v_i1, v_k]) * T.Cast( + "float32", var_T_transpose_intermediate[v_i2, v_k] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias191[v_ax2] + ) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = ( + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias191[v_ax2] + ) + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast( + "float16", var_T_add_intermediate[v_i0, v_i1, v_i2] + ) + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("compute_1"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_compute_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2]) + var_compute_intermediate_1[v_i0, v_i1, v_i2] = var_compute_intermediate[ + v_i0, v_i1, v_i2 + ] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], + lv317[v_ax0, v_ax1, v_ax2], + ) + 
T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = ( + var_compute_intermediate_1[v_ax0, v_ax1, v_ax2] + + lv317[v_ax0, v_ax1, v_ax2] + ) + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("compute_2"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_add_intermediate_1[v_i0, v_i1, v_i2]) + T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) + p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast( + "float32", var_T_add_intermediate_1[v_i0, v_i1, v_i2] + ) + + +@T.prim_func(private=True) +def fused_decode3_fused_NT_matmul4_add7_cast8_cast12_add5_cast7_after( + lv1345: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), + lv1346: T.Buffer((T.int64(320), T.int64(2560)), "float16"), + p_lv2047: T.handle, + linear_bias191: T.Buffer((T.int64(2560),), "float32"), + p_lv317: T.handle, + p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True), "tir.noalias": T.bool(True)}) + n = T.int64() + lv2047 = T.match_buffer(p_lv2047, (T.int64(1), n, T.int64(10240)), "float16") + lv317 = T.match_buffer(p_lv317, (T.int64(1), n, T.int64(2560)), "float16") + p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560))) + with T.block("root"): + T.reads() + T.writes() + T.block_attr({"meta_schedule.thread_extent_low_inclusive": 32}) + decode_local = T.alloc_buffer( + (T.int64(10240), T.int64(2560)), "float16", scope="local" + ) + lv1345_local = T.alloc_buffer( + (T.int64(1280), T.int64(2560)), "uint32", scope="local" + ) + lv1346_local = T.alloc_buffer( + (T.int64(320), T.int64(2560)), "float16", scope="local" + ) + lv2047_pad_local = T.alloc_buffer( + ( + T.int64(1), + (n + T.int64(31)) // T.int64(32) * T.int64(32), + T.int64(10240), + ), + "float16", + scope="local", + ) + var_NT_matmul_intermediate_pad_local = T.alloc_buffer( + (T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2560)), + scope="local", + ) + for i0_i1_fused_0_i0_i1_fused_1_0_fused in T.thread_binding( + (n + T.int64(31)) // T.int64(32), thread="blockIdx.y" + ): + for i2_0 in T.thread_binding(T.int64(20), thread="blockIdx.x"): + for i0_i1_fused_1_1 in T.thread_binding( + T.int64(8), thread="threadIdx.y" + ): + for i2_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"): + for i0_i1_fused_1_2_init in range(T.int64(4)): + for i2_2_init in T.vectorized(T.int64(8)): + with T.block("NT_matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial( + (n + T.int64(31)) // T.int64(32) * T.int64(32), + i0_i1_fused_0_i0_i1_fused_1_0_fused + * T.int64(32) + + i0_i1_fused_1_1 * T.int64(4) + + i0_i1_fused_1_2_init, + ) + v_i2 = T.axis.spatial( + T.int64(2560), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + i2_2_init, + ) + T.reads() + T.writes( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + ) + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] = T.float32(0) + for k_0_0, k_0_1 in T.grid(T.int64(80), T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(8)): + with T.block("lv1346_local"): + v0 = T.axis.spatial( + T.int64(320), + k_0_0 * T.int64(4) + k_0_1 + ax0, + ) + v1 = T.axis.spatial( + T.int64(2560), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + ax1, + ) + T.reads(lv1346[v0, v1]) + T.writes(lv1346_local[v0, v1]) + lv1346_local[v0, v1] = lv1346[v0, v1] + for k_1 in range(T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(8)): + with T.block("lv1345_local"): + v0 = T.axis.spatial( + T.int64(1280), + 
k_0_0 * T.int64(16) + + k_0_1 * T.int64(4) + + k_1 + + ax0, + ) + v1 = T.axis.spatial( + T.int64(2560), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + ax1, + ) + T.reads(lv1345[v0, v1]) + T.writes(lv1345_local[v0, v1]) + lv1345_local[v0, v1] = lv1345[v0, v1] + for k_2 in range(T.int64(8)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(8)): + with T.block("decode"): + v_i = T.axis.spatial( + T.int64(10240), + k_0_0 * T.int64(128) + + k_0_1 * T.int64(32) + + k_1 * T.int64(8) + + k_2 + + ax0, + ) + v_j = T.axis.spatial( + T.int64(2560), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + ax1, + ) + T.reads( + lv1345_local[ + v_i // T.int64(8), v_j + ], + lv1346_local[ + v_i // T.int64(32), v_j + ], + ) + T.writes(decode_local[v_i, v_j]) + decode_local[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv1345_local[ + v_i // T.int64(8), + v_j, + ], + T.Cast( + "uint32", + v_i % T.int64(8), + ) + * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv1346_local[ + v_i // T.int64(32), v_j + ] + for ax0, ax1 in T.grid(T.int64(1), T.int64(4)): + for ax2 in T.vectorized(T.int64(1)): + with T.block("lv2047_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial( + (n + T.int64(31)) + // T.int64(32) + * T.int64(32), + i0_i1_fused_0_i0_i1_fused_1_0_fused + * T.int64(32) + + i0_i1_fused_1_1 * T.int64(4) + + ax1, + ) + v2 = T.axis.spatial( + T.int64(10240), + k_0_0 * T.int64(128) + + k_0_1 * T.int64(32) + + k_1 * T.int64(8) + + k_2 + + ax2, + ) + T.reads(lv2047[v0, v1, v2]) + T.writes(lv2047_pad_local[v0, v1, v2]) + lv2047_pad_local[ + v0, v1, v2 + ] = T.if_then_else( + v1 < n, + lv2047[v0, v1, v2], + T.float16(0), + ) + for i0_i1_fused_1_2 in range(T.int64(4)): + for i2_2 in T.vectorized(T.int64(8)): + with T.block("NT_matmul_update"): + v_i0 = T.axis.spatial( + T.int64(1), T.int64(0) + ) + v_i1 = T.axis.spatial( + (n + T.int64(31)) + // T.int64(32) + * T.int64(32), + i0_i1_fused_0_i0_i1_fused_1_0_fused + * T.int64(32) + + i0_i1_fused_1_1 * T.int64(4) + + i0_i1_fused_1_2, + ) + v_i2 = T.axis.spatial( + T.int64(2560), + i2_0 * T.int64(128) + + i2_1 * T.int64(8) + + i2_2, + ) + v_k = T.axis.reduce( + T.int64(10240), + k_0_0 * T.int64(128) + + k_0_1 * T.int64(32) + + k_1 * T.int64(8) + + k_2, + ) + T.reads( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ], + lv2047_pad_local[v_i0, v_i1, v_k], + decode_local[v_k, v_i2], + ) + T.writes( + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + ) + var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] = var_NT_matmul_intermediate_pad_local[ + v_i0, v_i1, v_i2 + ] + T.Cast( + "float32", + lv2047_pad_local[v_i0, v_i1, v_k], + ) * T.Cast( + "float32", decode_local[v_k, v_i2] + ) + for ax0, ax1 in T.grid(T.int64(1), T.int64(4)): + for ax2 in T.vectorized(T.int64(8)): + with T.block("var_NT_matmul_intermediate_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial( + (n + T.int64(31)) // T.int64(32) * T.int64(32), + i0_i1_fused_0_i0_i1_fused_1_0_fused + * T.int64(32) + + i0_i1_fused_1_1 * T.int64(4) + + ax1, + ) + v2 = T.axis.spatial( + T.int64(2560), + i2_0 * T.int64(128) + i2_1 * T.int64(8) + ax2, + ) + T.reads( + var_NT_matmul_intermediate_pad_local[ + v0, v1, v2 + ], + linear_bias191[v2], + lv317[v0, v1, v2], + ) + T.writes(p_output0_intermediate[v0, v1, v2]) + if v1 < n: + p_output0_intermediate[v0, v1, v2] = T.Cast( + "float32", + T.Cast( + "float16", + var_NT_matmul_intermediate_pad_local[ + v0, v1, v2 + ] + + linear_bias191[v2], + ) + + 
lv317[v0, v1, v2], + ) + + +@T.prim_func(private=True) +def fused_decode2_NT_matmul( + lv4: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), + lv5: T.Buffer((T.int64(128), T.int64(12288)), "float16"), + p_lv6: T.handle, + p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(4096)), "float16") + var_NT_matmul_intermediate = T.match_buffer( + p_output0, (T.int64(1), n, T.int64(12288)), "float16" + ) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16") + p_output0_intermediate = T.alloc_buffer((T.int64(12288), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(12288)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv4[v_i // T.int64(8), v_j], lv5[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv4[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv5[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(12288), T.int64(4096)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(p_output0_intermediate[v_ax0, v_ax1]) + p_output0_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(12288), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv6[v_i0, v_i1, v_k], p_output0_intermediate[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + + lv6[v_i0, v_i1, v_k] * p_output0_intermediate[v_i2, v_k] + ) + + +@T.prim_func(private=True) +def fused_decode2_NT_matmul_after( + lv8: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), + lv9: T.Buffer((T.int64(128), T.int64(12288)), "float16"), + p_lv6: T.handle, + p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + lv6 = T.match_buffer(p_lv6, (1, n, 4096), "float16") + var_NT_matmul_intermediate = T.match_buffer(p_output0, (1, n, 12288), "float16") + + var_matmul_intermediate_local = T.alloc_buffer( + (T.int64(1), ((n+7)//8) * 8, T.int64(12288)), "float16", scope="local" + ) + var_matmul_intermediate_local_batch = T.alloc_buffer( + (T.int64(1), ((n+7)//8) * 8, T.int64(12288)), "float16", scope="local" + ) + lv8_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32", scope="local") + lv9_local = T.alloc_buffer( + (T.int64(128), T.int64(12288)), "float16", scope="local" + ) + #lv6_shared = T.alloc_buffer( + # (T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="shared" + #) + for i0_i1_i2_fused_n in T.thread_binding(((n+7)//8), thread="blockIdx.y"): + for i0_i1_i2_fused_0 in T.thread_binding(T.int64(96), thread="blockIdx.x"): + for i0_i1_i2_fused_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): + with T.block("n_check"): + T.where((i0_i1_i2_fused_n * T.int64(8) + ax2_y) < n) + for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) + 
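+                            # Rows are padded to a multiple of 8 ((n+7)//8*8); the enclosing "n_check" block masks off threads whose row index is >= n.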
v_i2 = T.axis.spatial( + T.int64(12288), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + i0_i1_i2_fused_2_init + ) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) + for k_1 in range(T.int64(128)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("matmul_init_local"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) + v_i2k = T.axis.spatial( + T.int64(12288), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads() + T.writes( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + ) + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] = T.float16(0) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv9_local"): + v0 = T.axis.spatial( + T.int64(128), k_1 + ) + v1 = T.axis.spatial( + T.int64(12288), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads(lv9[v0, v1]) + T.writes(lv9_local[v0, v1]) + lv9_local[v0, v1] = lv9[v0, v1] + for k_2 in range(T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv8_local"): + v0 = T.axis.spatial( + T.int64(512), + k_1 * T.int64(4) + + k_2 + + ax0, + ) + v1 = T.axis.spatial( + T.int64(12288), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads(lv8[v0, v1]) + T.writes(lv8_local[v0, v1]) + lv8_local[v0, v1] = lv8[v0, v1] + for k_3 in range(T.int64(8)): + for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) + v_i2 = T.axis.spatial( + T.int64(12288), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + i0_i1_i2_fused_2, + ) + v_k = T.axis.reduce( + T.int64(4096), + k_1 * T.int64(32) + + k_2 * T.int64(8) + + k_3, + ) + T.reads( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ], + lv6[v_i0, v_i1, v_k], + lv8_local[v_k // T.int64(8), v_i2], + ) + T.writes( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ] + ) + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ] = var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ] + lv6[ + v_i0, v_i1, v_k + ] * ( + ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv8_local[ + v_k // T.int64(8), v_i2 + ], + T.Cast( + "uint32", + v_k % T.int64(8), + ) + * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) + ) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("multiple_scale"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) + v_i2 = T.axis.spatial( + T.int64(12288), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + v0 = T.axis.spatial( + T.int64(128), + k_1 + ) + v1 = T.axis.spatial( + T.int64(12288), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads( + lv9_local[v0, v1], + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ], + ) + T.writes( + var_matmul_intermediate_local[v_i0, v_i1, v_i2] + ) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = ( + var_matmul_intermediate_local[v_i0, v_i1, v_i2] + + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ] + * lv9_local[v0, v1] + ) + for ax0, 
ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("var_matmul_intermediate_local"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) + v_i2 = T.axis.spatial( + T.int64(12288), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax2, + ) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + + +@T.prim_func(private=True) +def fused_decode4_NT_matmul3( + lv13: T.Buffer((T.int64(512), T.int64(22016)), "uint32"), + lv14: T.Buffer((T.int64(128), T.int64(22016)), "float16"), + p_lv45: T.handle, + p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv45 = T.match_buffer(p_lv45, (T.int64(1), n, T.int64(4096)), "float16") + var_NT_matmul_intermediate = T.match_buffer( + p_output0, (T.int64(1), n, T.int64(22016)), "float16" + ) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(4096), T.int64(22016)), "float16") + p_output0_intermediate = T.alloc_buffer((T.int64(22016), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(22016)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv13[v_i // T.int64(8), v_j], lv14[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv13[v_i // T.int64(8), v_j], + T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv14[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(22016), T.int64(4096)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(p_output0_intermediate[v_ax0, v_ax1]) + p_output0_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(22016), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv45[v_i0, v_i1, v_k], p_output0_intermediate[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + + lv45[v_i0, v_i1, v_k] * p_output0_intermediate[v_i2, v_k] + ) + + +@T.prim_func(private=True) +def fused_decode4_NT_matmul3_after( + lv8: T.Buffer((T.int64(512), T.int64(22016)), "uint32"), + lv9: T.Buffer((T.int64(128), T.int64(22016)), "float16"), + p_lv6: T.handle, + p_output0: T.handle, +): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + lv6 = T.match_buffer(p_lv6, (1, n, 4096), "float16") + var_NT_matmul_intermediate = T.match_buffer(p_output0, (1, n, 22016), "float16") + + var_matmul_intermediate_local = T.alloc_buffer( + (T.int64(1), ((n+7)//8) * 8, T.int64(22016)), "float16", scope="local" + ) + var_matmul_intermediate_local_batch = T.alloc_buffer( + (T.int64(1), ((n+7)//8) * 8, T.int64(22016)), "float16", scope="local" + ) + lv8_local = T.alloc_buffer((T.int64(512), T.int64(22016)), "uint32", scope="local") + lv9_local = T.alloc_buffer( + (T.int64(128), T.int64(22016)), "float16", scope="local" + ) + #lv6_shared = T.alloc_buffer( + # (T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="shared" + #) + for 
i0_i1_i2_fused_n in T.thread_binding(((n+7)//8), thread="blockIdx.y"): + for i0_i1_i2_fused_0 in T.thread_binding(T.int64(172), thread="blockIdx.x"): + for i0_i1_i2_fused_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): + with T.block("n_check"): + T.where((i0_i1_i2_fused_n * T.int64(8) + ax2_y) < n) + for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) + v_i2 = T.axis.spatial( + T.int64(22016), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + i0_i1_i2_fused_2_init + ) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) + for k_1 in range(T.int64(128)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("matmul_init_local"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) + v_i2k = T.axis.spatial( + T.int64(22016), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads() + T.writes( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] + ) + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2k + ] = T.float16(0) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv9_local"): + v0 = T.axis.spatial( + T.int64(128), k_1 + ) + v1 = T.axis.spatial( + T.int64(22016), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads(lv9[v0, v1]) + T.writes(lv9_local[v0, v1]) + lv9_local[v0, v1] = lv9[v0, v1] + for k_2 in range(T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("lv8_local"): + v0 = T.axis.spatial( + T.int64(512), + k_1 * T.int64(4) + + k_2 + + ax0, + ) + v1 = T.axis.spatial( + T.int64(22016), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads(lv8[v0, v1]) + T.writes(lv8_local[v0, v1]) + lv8_local[v0, v1] = lv8[v0, v1] + for k_3 in range(T.int64(8)): + for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) + v_i2 = T.axis.spatial( + T.int64(22016), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + i0_i1_i2_fused_2, + ) + v_k = T.axis.reduce( + T.int64(4096), + k_1 * T.int64(32) + + k_2 * T.int64(8) + + k_3, + ) + T.reads( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ], + lv6[v_i0, v_i1, v_k], + lv8_local[v_k // T.int64(8), v_i2], + ) + T.writes( + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ] + ) + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ] = var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ] + lv6[ + v_i0, v_i1, v_k + ] * ( + ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv8_local[ + v_k // T.int64(8), v_i2 + ], + T.Cast( + "uint32", + v_k % T.int64(8), + ) + * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) + ) + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(4)): + with T.block("multiple_scale"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) + v_i2 = T.axis.spatial( + T.int64(22016), + i0_i1_i2_fused_0 * 
T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + v0 = T.axis.spatial( + T.int64(128), + k_1 + ) + v1 = T.axis.spatial( + T.int64(22016), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax1, + ) + T.reads( + lv9_local[v0, v1], + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ], + ) + T.writes( + var_matmul_intermediate_local[v_i0, v_i1, v_i2] + ) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = ( + var_matmul_intermediate_local[v_i0, v_i1, v_i2] + + var_matmul_intermediate_local_batch[ + v_i0, v_i1, v_i2 + ] + * lv9_local[v0, v1] + ) + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2 in T.vectorized(T.int64(4)): + with T.block("var_matmul_intermediate_local"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) + v_i2 = T.axis.spatial( + T.int64(22016), + i0_i1_i2_fused_0 * T.int64(128) + + i0_i1_i2_fused_1 * T.int64(4) + + ax2, + ) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + + + +@T.prim_func(private=True) +def fused_NT_matmul1_divide2_maximum1_minimum1_cast3(lv1593: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), p_lv1603: T.handle, p_lv1582: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv1603 = T.match_buffer(p_lv1603, (T.int64(1), n, T.int64(32), T.int64(128)), "float16") + lv1582 = T.match_buffer(p_lv1582, (T.int64(1), T.int64(1), T.int64(1), n), "float16") + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n)) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") + var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") + var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") + var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(128)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(lv1593[v_i0, v_i2, v_i1, v_k], lv1603[v_i0, v_i3, v_i1, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1593[v_i0, v_i2, v_i1, v_k] * lv1603[v_i0, v_i3, v_i1, v_k] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_divide"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_maximum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + 
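+            # Clamp the scaled scores from below at the smallest finite fp16 value (-65504) before the element-wise T.min with lv1582 (presumably the attention mask) in the next block.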
var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_minimum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1582[v_ax0, T.int64(0), v_ax2, v_ax3]) + T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1582[v_ax0, T.int64(0), v_ax2, v_ax3]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) + var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) + +@T.prim_func(private=True) +def fused_NT_matmul1_divide2_maximum1_minimum1_cast3_after( + lv1593: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), + p_lv1603: T.handle, + p_lv1582: T.handle, + p_output0: T.handle +): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + n = T.int64() + lv1603 = T.match_buffer(p_lv1603, (T.int64(1), n, T.int64(32), T.int64(128)), "float16") + lv1582 = T.match_buffer(p_lv1582, (T.int64(1), T.int64(1), T.int64(1), n), "float16") + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n)) + var_matmul_intermediate_local = T.alloc_buffer( + (1, ((n + 7) // 8) * 8, 4096), "float16", scope="local" + ) + lv1593_shared = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(1024)), "float16", scope="shared" + ) + for i_by in T.thread_binding(T.int64((n + 7) // 8), thread="blockIdx.y"): + for i_bx in T.thread_binding(T.int64(32), thread="blockIdx.x"): + for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for i_v8 in T.vectorized(T.int64(4)): + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(n), i_by * T.int64(8) + i_ty) + v_i2 = T.axis.spatial( + T.int64(4096), + i_bx * T.int64(128) + + i_tx * T.int64(4) + + i_v8, + ) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) + with T.block("lv1593_shared"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(32), i_bx) + v_i3 = T.axis.spatial(T.int64(128), i_tx * T.int64(4) + i_v8) + T.reads(lv1593[v_i0, v_i1, v_i2, v_i3]) + T.writes(lv1593_shared[v_i0, v_i1, v_i3]) + lv1593_shared[v_i0, v_i1, v_i3] = lv1593[v_i0, v_i1, v_i2, v_i3] + with T.block("matmul_compute"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1_1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(n), i_by * T.int64(8) + i_ty) + v_i2 = T.axis.spatial(T.int64(32), i_bx) + v_i3 = T.axis.spatial(T.int64(128), i_tx * T.int64(4) + i_v8) + v_ik = T.axis.spatial(T.int64(4096), i_bx * T.int64(128) + i_tx * T.int64(4) + i_v8) + T.where(i_by * T.int64(8) + i_ty < n) + T.reads(lv1593_shared[v_i0, v_i1_1, v_i3], lv1603[v_i0, v_i1, v_i2, v_i3]) + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_ik]) + 
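+                            # Each thread accumulates a partial dot product into local registers here; the reduction_* blocks below combine the partials through the shared lv1593_shared buffer.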
var_matmul_intermediate_local[v_i0, v_i1, v_ik] = var_matmul_intermediate_local[v_i0, v_i1, v_ik] + lv1603[v_i0, v_i1, v_i2, v_i3] * lv1593_shared[v_i0, v_i1_1, v_i3] + for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for i_v8 in T.vectorized(T.int64(4)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1_1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(n), i_by * T.int64(8) + i_ty) + v_ik = T.axis.spatial(T.int64(4096), i_bx * T.int64(128) + i_tx * T.int64(4) + i_v8) + v_i2 = T.axis.spatial(T.int64(1024), i_ty * T.int64(128) + i_tx * T.int64(4) + i_v8) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_ik]) + T.writes(lv1593_shared[v_i0, v_i1_1, v_i2]) + lv1593_shared[v_i0, v_i1_1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_ik] + for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for i_v8 in T.vectorized(T.int64(4)): + with T.block("reduction_1"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(1024), i_ty * T.int64(128) + i_tx * T.int64(4) + i_v8) + T.where(i_tx < T.int64(16)) + T.reads(lv1593_shared[v_i0, v_i1, v_i2]) + T.writes(lv1593_shared[v_i0, v_i1, v_i2]) + lv1593_shared[v_i0, v_i1, v_i2] = lv1593_shared[v_i0, v_i1, v_i2] + lv1593_shared[v_i0, v_i1, v_i2 + T.int64(64)] + for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for i_v8 in T.vectorized(T.int64(4)): + with T.block("reduction_2"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(1024), i_ty * T.int64(128) + i_tx * T.int64(4) + i_v8) + T.where(i_tx < T.int64(8)) + T.reads(lv1593_shared[v_i0, v_i1, v_i2]) + T.writes(lv1593_shared[v_i0, v_i1, v_i2]) + lv1593_shared[v_i0, v_i1, v_i2] = lv1593_shared[v_i0, v_i1, v_i2] + lv1593_shared[v_i0, v_i1, v_i2 + T.int64(32)] + for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for i_v8 in T.vectorized(T.int64(4)): + with T.block("reduction_3"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(1024), i_ty * T.int64(128) + i_tx * T.int64(4) + i_v8) + T.where(i_tx < T.int64(4)) + T.reads(lv1593_shared[v_i0, v_i1, v_i2]) + T.writes(lv1593_shared[v_i0, v_i1, v_i2]) + lv1593_shared[v_i0, v_i1, v_i2] = lv1593_shared[v_i0, v_i1, v_i2] + lv1593_shared[v_i0, v_i1, v_i2 + T.int64(16)] + for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for i_v8 in T.vectorized(T.int64(4)): + with T.block("reduction_4"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(1024), i_ty * T.int64(128) + i_tx * T.int64(4) + i_v8) + T.where(i_tx < T.int64(2)) + T.reads(lv1593_shared[v_i0, v_i1, v_i2]) + T.writes(lv1593_shared[v_i0, v_i1, v_i2]) + lv1593_shared[v_i0, v_i1, v_i2] = lv1593_shared[v_i0, v_i1, v_i2] + lv1593_shared[v_i0, v_i1, v_i2 + T.int64(8)] + for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for i_v8 in T.vectorized(T.int64(4)): + with 
T.block("reduction_4"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(1024), i_ty * T.int64(128) + i_tx * T.int64(4) + i_v8) + T.where(i_tx < T.int64(1)) + T.reads(lv1593_shared[v_i0, v_i1, v_i2]) + T.writes(lv1593_shared[v_i0, v_i1, v_i2]) + lv1593_shared[v_i0, v_i1, v_i2] = lv1593_shared[v_i0, v_i1, v_i2] + lv1593_shared[v_i0, v_i1, v_i2 + T.int64(4)] + for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for ax0 in range(T.int64(1)): + with T.block("Output_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(32), i_bx) + v_i2 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i3 = T.axis.spatial(T.int64(n), i_by * T.int64(8) + i_ty) + v_ik = T.axis.spatial(T.int64(1024), i_ty * T.int64(128)) + T.where(i_by * T.int64(8) + i_ty < n) + T.reads(lv1593_shared[v_i0, v_i2, v_ik]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) + var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", T.min(T.max((lv1593_shared[v_i0, v_i2, v_ik] + lv1593_shared[v_i0, v_i2, v_ik + T.int64(1)] + + lv1593_shared[v_i0, v_i2, v_ik + T.int64(2)] + lv1593_shared[v_i0, v_i2, v_ik + T.int64(3)]) + * T.float16(0.088397790055248615), T.float16(-65504)), lv1582[v_i0, T.int64(0), v_i2, v_i3])) + + + +# [gx,gy, gz] [lx, ly, lz] + +@T.prim_func(private=True) +def NT_matmul3(var_A: T.handle, var_B: T.handle, NT_matmul: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16") + B = T.match_buffer(var_B, (T.int64(1), T.int64(32), T.int64(1), n), "float16") + # with T.block("root"): + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128), n): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(A[v_i0, v_k, v_i2, v_i3], B[v_i0, v_i2, v_i1, v_k]) + T.writes(NT_matmul[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + NT_matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0) + NT_matmul[v_i0, v_i1, v_i2, v_i3] = NT_matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_k, v_i2, v_i3] * B[v_i0, v_i2, v_i1, v_k] + +@T.prim_func(private=True) +def NT_matmul3_after( + var_A: T.handle, + var_B: T.handle, + NT_matmul: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16") +): + + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16") + B = T.match_buffer(var_B, (T.int64(1), T.int64(32), T.int64(1), n), "float16") + var_matmul_intermediate_local = T.alloc_buffer( + (1, 8, 4096), "float16", scope="local" + ) + B_shared = T.alloc_buffer( + (T.int64(1), T.int64(1), T.int64(1024)), "float16", scope="shared" + ) + for i_bx in T.thread_binding(T.int64(32), thread="blockIdx.x"): + for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for i_v8 in T.vectorized(T.int64(4)): + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(8), i_ty) + v_i2 = T.axis.spatial( + T.int64(4096), + i_bx * T.int64(128) + i_tx * T.int64(4) + + i_v8, + ) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = 
T.float16(0) + for ax0 in range((n+255)//256): + with T.block("B_shared"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(32), i_bx) + v_i2 = T.axis.spatial(((n+255)//256) * 256, ax0 * T.int64(256) + i_ty * T.int64(32) + i_tx) + v_i2k = T.axis.spatial(T.int64(256), i_ty * T.int64(32) + i_tx) + #T.where(ax0 * T.int64(256) + i_ty * T.int64(32) + i_tx < n) + T.reads(B[v_i0, v_i1, T.int64(0), v_i2]) + T.writes(B_shared[v_i0, v_i1, v_i2k]) + B_shared[v_i0, T.int64(0), v_i2k] = T.if_then_else(v_i2 < n, B[v_i0, v_i1, T.int64(0), v_i2], T.float16(0)) + for ax1 in range(32): + #with T.block("n_check"): + # T.where(ax0 * T.int64(256) + ax1 * T.int64(8) + i_ty < n) + for i_v8 in T.vectorized(T.int64(4)): + with T.block("matmul_compute"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(((n+255)//256) * 256, ax0 * T.int64(256) + ax1 * T.int64(8) + i_ty) + v_i1_1 = T.axis.spatial(T.int64(8), i_ty) + v_i2 = T.axis.spatial(T.int64(32), i_bx) + v_i3 = T.axis.spatial(T.int64(128), i_tx * T.int64(4) + i_v8) + v_ik = T.axis.spatial(T.int64(256), ax1 * T.int64(8) + i_ty) + v_ik1 = T.axis.spatial(T.int64(4096), i_bx * T.int64(128) + i_tx * T.int64(4) + i_v8) + T.reads(B_shared[v_i0, T.int64(0), v_ik], A[v_i0, v_i1, v_i2, v_i3]) + T.writes(var_matmul_intermediate_local[v_i0, v_i1_1, v_ik1]) + var_matmul_intermediate_local[v_i0, v_i1_1, v_ik1] = var_matmul_intermediate_local[v_i0, v_i1_1, v_ik1] + T.if_then_else(v_i1 < n, A[v_i0, v_i1, v_i2, v_i3], T.float16(0)) * B_shared[v_i0, T.int64(0), v_ik] + + for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for i_v8 in T.vectorized(T.int64(4)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(8), i_ty) + v_i2 = T.axis.spatial(T.int64(4096), i_bx * T.int64(128) + i_tx * T.int64(4) + i_v8) + v_ik = T.axis.spatial(T.int64(1024), i_ty * T.int64(128) + i_tx * T.int64(4) + i_v8) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + T.writes(B_shared[v_i0, T.int64(0), v_ik]) + B_shared[v_i0, T.int64(0), v_ik] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for i_v8 in T.vectorized(T.int64(4)): + with T.block("reduction_1"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(1024), i_ty * T.int64(128) + i_tx * T.int64(4) + i_v8) + T.where(i_ty < T.int64(4)) + T.reads(B_shared[v_i0, v_i1, v_i2]) + T.writes(B_shared[v_i0, v_i1, v_i2]) + B_shared[v_i0, v_i1, v_i2] = B_shared[v_i0, v_i1, v_i2] + B_shared[v_i0, v_i1, v_i2 + T.int64(512)] + for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for i_v8 in T.vectorized(T.int64(4)): + with T.block("Output_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(32), i_bx) + v_i3 = T.axis.spatial(T.int64(128), i_tx * T.int64(4) + i_v8) + v_ik = T.axis.spatial(T.int64(1024), i_ty * T.int64(128) + i_tx * T.int64(4) + i_v8) + T.where(i_ty < 1) + T.reads(B_shared[v_i0, v_i1, v_ik]) + T.writes(NT_matmul[v_i0, v_i1, v_i2, v_i3]) + NT_matmul[v_i0, v_i1, v_i2, v_i3] = B_shared[v_i0, v_i1, v_ik] + B_shared[v_i0, v_i1, v_ik + T.int64(128)] + B_shared[v_i0, v_i1, v_ik + 
T.int64(256)] + B_shared[v_i0, v_i1, v_ik + T.int64(384)] + +@T.prim_func(private=True) +def rms_norm(var_A: T.handle, B: T.Buffer((T.int64(4096),), "float16"), var_rms_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16") + rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096)), "float16") + # with T.block("root"): + Ared_temp = T.alloc_buffer((T.int64(1), n)) + for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)): + with T.block("Ared_temp"): + v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k]) + T.reads(A[v_bsz, v_i, v_k]) + T.writes(Ared_temp[v_bsz, v_i]) + with T.init(): + Ared_temp[v_bsz, v_i] = T.float32(0) + Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + T.Cast("float32", A[v_bsz, v_i, v_k]) * T.Cast("float32", A[v_bsz, v_i, v_k]) + for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)): + with T.block("rms_norm"): + v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k]) + T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i]) + T.writes(rms_norm_1[v_bsz, v_i, v_k]) + rms_norm_1[v_bsz, v_i, v_k] = T.Cast("float16", T.Cast("float32", B[v_k]) * (T.Cast("float32", A[v_bsz, v_i, v_k]) / T.sqrt(Ared_temp[v_bsz, v_i] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)))) + +@T.prim_func(private=True) +def rms_norm_after(var_A: T.handle, B: T.Buffer((4096,), "float16"), var_rms_norm: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + n = T.int32() + A = T.match_buffer(var_A, (1, n, 4096), "float16") + rms_norm_1 = T.match_buffer(var_rms_norm, (1, n, 4096), "float16") + # with T.block("root"): + Ared_temp_shared = T.alloc_buffer((1, n), scope="shared") + Ared_temp_rf_local = T.alloc_buffer((64, 1, n), scope="local") + for ax0_fused in T.thread_binding(n, thread="blockIdx.x"): + for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + with T.block("Ared_temp_rf_init"): + vax1_fused_1, v0 = T.axis.remap("SS", [ax1_fused_1, ax0_fused]) + T.reads() + T.writes(Ared_temp_rf_local[vax1_fused_1, 0, v0]) + Ared_temp_rf_local[vax1_fused_1, 0, v0] = T.float32(0) + for ax1_fused_0, u in T.grid(64, 1): + with T.block("Ared_temp_rf_update"): + vax1_fused_1, v0, vax1_fused_0 = T.axis.remap("SSR", [ax1_fused_1, ax0_fused, ax1_fused_0]) + T.reads(Ared_temp_rf_local[vax1_fused_1, 0, v0], A[0, v0, vax1_fused_0 * 64 + vax1_fused_1]) + T.writes(Ared_temp_rf_local[vax1_fused_1, 0, v0]) + Ared_temp_rf_local[vax1_fused_1, 0, v0] = Ared_temp_rf_local[vax1_fused_1, 0, v0] + T.Cast("float32", A[0, v0, vax1_fused_0 * 64 + vax1_fused_1]) * T.Cast("float32", A[0, v0, vax1_fused_0 * 64 + vax1_fused_1]) + for ax1_fused in range(1): + for ax0 in T.thread_binding(64, thread="threadIdx.x"): + with T.block("Ared_temp"): + vax1_fused_1, v0 = T.axis.remap("RS", [ax0, ax0_fused]) + T.reads(Ared_temp_rf_local[vax1_fused_1, 0, v0]) + T.writes(Ared_temp_shared[0, v0]) + with T.init(): + Ared_temp_shared[0, v0] = T.float32(0) + Ared_temp_shared[0, v0] = Ared_temp_shared[0, v0] + Ared_temp_rf_local[vax1_fused_1, 0, v0] + for ax0_fused_0 in range(64): + for ax0_fused_1 in T.thread_binding(64, thread="threadIdx.x"): + with T.block("rms_norm"): + v0 = T.axis.spatial(n, ax0_fused) + v1 = T.axis.spatial(4096, ax0_fused_0 * 64 + ax0_fused_1) + T.reads(B[v1], A[0, v0, v1], Ared_temp_shared[0, v0]) + T.writes(rms_norm_1[0, v0, v1]) + rms_norm_1[0, v0, v1] = T.Cast("float16", T.Cast("float32", B[v1]) * 
(T.Cast("float32", A[0, v0, v1]) / T.sqrt(Ared_temp_shared[0, v0] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)))) + +@T.prim_func(private=True) +def slice(var_A: T.handle, slice_1: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16") + # with T.block("root"): + for i, j, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)): + with T.block("slice"): + v_i, v_j, v_k = T.axis.remap("SSS", [i, j, k]) + T.reads(A[v_i, n - T.int64(1), v_k]) + T.writes(slice_1[v_i, v_j, v_k]) + slice_1[v_i, v_j, v_k] = A[v_i, n - T.int64(1), v_k] + +@T.prim_func(private=True) +def slice_after(var_A: T.handle, slice_1: T.Buffer((1, 1, 4096), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + n = T.int32() + A = T.match_buffer(var_A, (1, n, 4096), "float16") + # with T.block("root"): + for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"): + for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"): + with T.block("slice"): + v0 = T.axis.spatial(4096, ax0_fused_0 * 256 + ax0_fused_1) + T.reads(A[0, n - 1, v0]) + T.writes(slice_1[0, 0, v0]) + slice_1[0, 0, v0] = A[0, n - 1, v0] + +@T.prim_func(private=True) +def NT_matmul2(var_A: T.handle, var_B: T.handle, var_NT_matmul: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + m = T.int64() + A = T.match_buffer(var_A, (T.int64(1), m, T.int64(32), T.int64(128)), "float16") + n = T.int64() + B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, m), "float16") + NT_matmul = T.match_buffer(var_NT_matmul, (T.int64(1), n, T.int64(32), T.int64(128)), "float16") + # with T.block("root"): + for i0, i1, i2, i3, k in T.grid(T.int64(1), n, T.int64(32), T.int64(128), m): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(A[v_i0, v_k, v_i2, v_i3], B[v_i0, v_i2, v_i1, v_k]) + T.writes(NT_matmul[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + NT_matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0) + NT_matmul[v_i0, v_i1, v_i2, v_i3] = NT_matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_k, v_i2, v_i3] * B[v_i0, v_i2, v_i1, v_k] + +@T.prim_func(private=True) +def NT_matmul2_after(var_A: T.handle, var_B: T.handle, var_NT_matmul: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + m = T.int32() + A = T.match_buffer(var_A, (1, m, 32, 128), "float16") + n = T.int32() + B = T.match_buffer(var_B, (1, 32, n, m), "float16") + NT_matmul = T.match_buffer(var_NT_matmul, (1, n, 32, 128), "float16") + # with T.block("root"): + NT_matmul_reindex_pad_local = T.alloc_buffer((32, 128, (n + 63) // 64 * 64), "float16", scope="local") + A_reindex_pad_shared = T.alloc_buffer((32, 128, (m + 15) // 16 * 16), "float16", scope="shared") + B_reindex_pad_shared = T.alloc_buffer((32, (n + 63) // 64 * 64, (m + 15) // 16 * 16), "float16", scope="shared") + for ax0_ax2_0_fused in T.thread_binding((n + 63) // 64 * 32, thread="blockIdx.y"): + for ax1_0 in T.thread_binding(4, thread="blockIdx.x"): + for ax2_1 in T.thread_binding(1, thread="vthread.y"): + for ax1_1 in T.thread_binding(1, thread="vthread.x"): + for ax2_2 in T.thread_binding(16, thread="threadIdx.y"): + for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax2_3_init, ax1_3_init in T.grid(4, 4): + with T.block("NT_matmul_init"): + v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((n + 63) // 64)) + 
v1 = T.axis.spatial(128, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init) + v2 = T.axis.spatial((n + 63) // 64 * 64, ax0_ax2_0_fused % ((n + 63) // 64) * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_init) + T.reads() + T.writes(NT_matmul_reindex_pad_local[v0, v1, v2]) + NT_matmul_reindex_pad_local[v0, v1, v2] = T.float16(0) + for ax3_0 in range((m + 15) // 16): + for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): + for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): + for ax0_ax1_ax2_fused_2 in range(2): + for ax0_ax1_ax2_fused_3 in T.vectorized(2): + with T.block("A_reindex_pad_shared"): + v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((n + 63) // 64)) + v1 = T.axis.spatial(128, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) + v2 = T.axis.spatial((m + 15) // 16 * 16, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) + T.reads(A[0, v2, v0, v1]) + T.writes(A_reindex_pad_shared[v0, v1, v2]) + T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) + A_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v2 < m, A[0, v2, v0, v1], T.float16(0)) + for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): + for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): + for ax0_ax1_ax2_fused_2 in range(4): + for ax0_ax1_ax2_fused_3 in T.vectorized(2): + with T.block("B_reindex_pad_shared"): + v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((n + 63) // 64)) + v1 = T.axis.spatial((n + 63) // 64 * 64, ax0_ax2_0_fused % ((n + 63) // 64) * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) + v2 = T.axis.spatial((m + 15) // 16 * 16, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) + T.reads(B[0, v0, v1, v2]) + T.writes(B_reindex_pad_shared[v0, v1, v2]) + T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) + B_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n and v2 < m, B[0, v0, v1, v2], T.float16(0)) + for ax3_1, ax2_3, ax1_3 in T.grid(16, 4, 4): + with T.block("NT_matmul_update"): + v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((n + 63) // 64)) + v1 = T.axis.spatial(128, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3) + v2 = T.axis.spatial((n + 63) // 64 * 64, ax0_ax2_0_fused % ((n + 63) // 64) * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3) + v3 = T.axis.reduce((m + 15) // 16 * 16, ax3_0 * 16 + ax3_1) + T.reads(NT_matmul_reindex_pad_local[v0, v1, v2], A_reindex_pad_shared[v0, v1, v3], B_reindex_pad_shared[v0, v2, v3]) + T.writes(NT_matmul_reindex_pad_local[v0, v1, v2]) + NT_matmul_reindex_pad_local[v0, v1, v2] = NT_matmul_reindex_pad_local[v0, v1, v2] + A_reindex_pad_shared[v0, v1, v3] * B_reindex_pad_shared[v0, v2, v3] + for ax0, ax1, ax2_0 in T.grid(1, 4, 2): + for ax2_1_1 in T.vectorized(2): + with T.block("NT_matmul_reindex_pad_local"): + v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((n + 63) // 64) + ax0) + v1 = T.axis.spatial(128, ax1_0 * 32 + ax1_2 * 4 + ax1) + v2 = T.axis.spatial((n + 63) // 64 * 64, ax0_ax2_0_fused % ((n + 63) // 64) * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1) + T.reads(NT_matmul_reindex_pad_local[v0, v1, v2]) + T.writes(NT_matmul[0, v2, v0, v1]) + if v2 < n: + NT_matmul[0, v2, v0, v1] = NT_matmul_reindex_pad_local[v0, v1, v2] + + +def get_dict_key(func): + return tvm.ir.structural_hash(func), func + + +tir_dispatch_dict = { + get_dict_key(fused_decode4_matmul3): 
fused_decode4_matmul3_after, + get_dict_key( + fused_decode6_fused_matmul7_add1 + ): fused_decode6_fused_matmul7_add1_after, + get_dict_key( + fused_decode5_fused_matmul6_multiply1 + ): fused_decode5_fused_matmul6_multiply1_after, + get_dict_key( + fused_decode5_fused_matmul6_silu1 + ): fused_decode5_fused_matmul6_silu1_after, + get_dict_key( + fused_decode4_fused_matmul4_add1 + ): fused_decode4_fused_matmul4_add1_after, + get_dict_key( + fused_decode3_fused_matmul1_cast2 + ): fused_decode3_fused_matmul1_cast2_after, + get_dict_key( + fused_decode2_fused_NT_matmul3_add + ): fused_decode2_fused_NT_matmul3_add_after, + get_dict_key(fused_decode_NT_matmul): fused_decode_NT_matmul_after, + get_dict_key(fused_decode2_NT_matmul): fused_decode2_NT_matmul_after, + get_dict_key(fused_decode4_NT_matmul3): fused_decode4_NT_matmul3_after, + get_dict_key( + fused_decode1_fused_NT_matmul2_silu + ): fused_decode1_fused_NT_matmul2_silu_after, + get_dict_key( + fused_decode1_fused_NT_matmul2_multiply + ): fused_decode1_fused_NT_matmul2_multiply_after, + get_dict_key( + fused_decode_fused_NT_matmul_add + ): fused_decode_fused_NT_matmul_add_after, + get_dict_key( + fused_decode4_fused_matmul6_add4 + ): sch_fused_decode4_fused_matmul6_add4(fused_decode4_fused_matmul6_add4), + get_dict_key( + fused_decode6_fused_matmul9_add7_cast8_cast12_add5 + ): sch_fused_decode6_fused_matmul9_add7_cast8_cast12_add5( + fused_decode6_fused_matmul9_add7_cast8_cast12_add5 + ), + get_dict_key( + fused_decode5_fused_matmul8_add6_gelu1_cast11 + ): sch_fused_decode5_fused_matmul8_add6_gelu1_cast11( + fused_decode5_fused_matmul8_add6_gelu1_cast11 + ), + get_dict_key(fused_decode81_fused_matmul1_cast2 + ): sch_fused_decode81_fused_matmul1_cast2(fused_decode81_fused_matmul1_cast2 + ), + get_dict_key( + fused_decode4_fused_matmul6_add4_add5 + ): sch_fused_decode4_fused_matmul6_add4_add5(fused_decode4_fused_matmul6_add4_add5), + get_dict_key(fused_decode3_matmul3): sch_fused_decode3_matmul3( + fused_decode3_matmul3 + ), + get_dict_key( + fused_decode6_fused_matmul9_add7_cast8_cast12_add5_cast7 + ): sch_fused_decode6_fused_matmul9_add7_cast8_cast12_add5_cast7( + fused_decode6_fused_matmul9_add7_cast8_cast12_add5_cast7 + ), + get_dict_key( + fused_decode2_fused_NT_matmul3_add6_gelu1_cast11 + ): fused_decode2_fused_NT_matmul3_add6_gelu1_cast11_after, + get_dict_key( + fused_decode1_fused_NT_matmul1_add4 + ): fused_decode1_fused_NT_matmul1_add4_after, + get_dict_key( + fused_decode3_fused_NT_matmul4_add7_cast8_cast12_add5 + ): fused_decode3_fused_NT_matmul4_add7_cast8_cast12_add5_after, + get_dict_key( + fused_decode1_fused_NT_matmul1_add4_add5 + ): fused_decode1_fused_NT_matmul1_add4_add5_after, + get_dict_key( + fused_decode3_fused_NT_matmul4_add7_cast8_cast12_add5_cast7 + ): fused_decode3_fused_NT_matmul4_add7_cast8_cast12_add5_cast7_after, + get_dict_key(fused_fused_decode9_matmul7): fused_fused_decode9_matmul7_after, + get_dict_key(fused_fused_decode7_matmul4): fused_fused_decode7_matmul4_after, + get_dict_key(fused_NT_matmul1_divide2_maximum1_minimum1_cast3): fused_NT_matmul1_divide2_maximum1_minimum1_cast3_after, + get_dict_key(NT_matmul3): NT_matmul3_after, + get_dict_key(slice): slice_after, + get_dict_key(rms_norm): rms_norm_after, + get_dict_key(NT_matmul2): NT_matmul2_after, +} + + +def lookup_func(func): + for (hash_value, func_before), f_after in tir_dispatch_dict.items(): + if tvm.ir.structural_hash(func) == hash_value and tvm.ir.structural_equal( + func, func_before + ): + return f_after + return None + + 
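+
+# Illustrative note, not part of the generated kernels: lookup_func compares the
+# structural hash first and then verifies full structural equality, so a hash
+# collision cannot dispatch the wrong schedule. The pass below applies this table
+# to every function of an IRModule. A minimal usage sketch (assuming `mod` is the
+# IRModule produced earlier in the build flow):
+#
+#     mod = DispatchTIROperatorAdreno()(mod)  # matched PrimFuncs are replaced
+#
+# Functions with no entry in tir_dispatch_dict are left untouched (lookup_func
+# returns None for them).
+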
+@tvm.transform.module_pass(opt_level=0, name="DispatchTIROperatorAdreno") +class DispatchTIROperatorAdreno: + def transform_module( + self, mod: IRModule, ctx: tvm.transform.PassContext + ) -> IRModule: + for gv in mod.functions: + scheduled_func = lookup_func(mod[gv]) + if scheduled_func is not None: + mod[gv] = scheduled_func + + return mod diff --git a/mlc_llm/dispatch/gpt_neox/__init__.py b/mlc_llm/dispatch/gpt_neox/__init__.py new file mode 100644 index 0000000..cdf7c94 --- /dev/null +++ b/mlc_llm/dispatch/gpt_neox/__init__.py @@ -0,0 +1,13 @@ +def lookup(func): + from . import dolly_v2_3b, redpajama_incite_chat_3b_v1, redpajama_q4f32 + + ret = dolly_v2_3b.lookup(func) + if ret is not None: + return ret + ret = redpajama_incite_chat_3b_v1.lookup(func) + if ret is not None: + return ret + ret = redpajama_q4f32.lookup(func) + if ret is not None: + return ret + return None diff --git a/mlc_llm/dispatch/gpt_neox/dolly_v2_3b.py b/mlc_llm/dispatch/gpt_neox/dolly_v2_3b.py new file mode 100644 index 0000000..274f081 --- /dev/null +++ b/mlc_llm/dispatch/gpt_neox/dolly_v2_3b.py @@ -0,0 +1,1034 @@ +# pylint: disable=missing-docstring,line-too-long,invalid-name,too-many-statements,too-many-locals +import tvm +from tvm import tir +from tvm.script import tir as T + +from .dolly_v2_3b_mod import Module as MOD + + +# fmt: off +def fused_NT_matmul1_add3(sch: tir.Schedule): + b0 = sch.get_block(name="NT_matmul", func_name="main") + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + + b1 = sch.get_block(name="T_add", func_name="main") + b2 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l3, l4, l5, l6 = sch.get_loops(block=b0) + v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l12, l13, l14, l15, l16 = sch.split(loop=l3, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) + v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[2, 4, 8, 1, 2]) + l22, l23, l24, l25, l26 = sch.split(loop=l4, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) + v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[40, 2, 16, 2, 1]) + l32, l33, l34, l35, l36 = sch.split(loop=l5, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) + v37, v38, v39 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[320, 8, 1]) + l40, l41, l42 = sch.split(loop=l6, factors=[v37, v38, v39], preserve_unit_iters=True) + sch.reorder(l12, l22, l32, l13, l23, l33, l14, l24, l34, l40, l41, l15, l25, l35, l42, l16, l26, l36) + l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) + sch.bind(loop=l43, thread_axis="blockIdx.x") + l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) + sch.bind(loop=l44, thread_axis="vthread.x") + l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) + sch.bind(loop=l45, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) + b46 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b46, loop=l45, preserve_unit_loops=True, index=-1) + b47 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", 
consumer_blocks=[b0]) + sch.compute_at(block=b47, loop=l40, preserve_unit_loops=True, index=-1) + l52, l53, l54 = sch.get_loops(block=b47)[-3:] + sch.fuse(l52, l53, l54, preserve_unit_iters=True) + v56 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch", ann_val=v56) + b57 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b57, loop=l40, preserve_unit_loops=True, index=-1) + l62, l63 = sch.get_loops(block=b57)[-2:] + sch.fuse(l62, l63, preserve_unit_iters=True) + v65 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v65) + sch.reverse_compute_inline(block=b1) + v66 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=2) + sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v66) + sch.enter_postproc() + sch.unannotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch") + l71 = sch.get_loops(block=b47)[-1] + _, l73, l74 = sch.split(loop=l71, factors=[None, 128, 4], preserve_unit_iters=True) + sch.vectorize(loop=l74) + sch.bind(loop=l73, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") + l79 = sch.get_loops(block=b57)[-1] + _, l81, l82 = sch.split(loop=l79, factors=[None, 128, 4], preserve_unit_iters=True) + sch.vectorize(loop=l82) + sch.bind(loop=l81, thread_axis="threadIdx.x") + b83 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b83, ann_key="meta_schedule.unroll_explicit") + b120 = sch.get_block(name="NT_matmul", func_name="main") + l124 = sch.get_loops(block=b120)[4] + sch.decompose_reduction(block=b120, loop=l124) + + b1 = sch.get_block("lv10_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + + _, b84, b85, b86, b87 = sch.get_child_blocks(b83) + l88 = sch.get_loops(block=b84)[0] + sch.annotate(block_or_loop=l88, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l88, ann_key="pragma_unroll_explicit", ann_val=1) + l95 = sch.get_loops(block=b85)[0] + sch.annotate(block_or_loop=l95, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l95, ann_key="pragma_unroll_explicit", ann_val=1) + l102 = sch.get_loops(block=b86)[0] + sch.annotate(block_or_loop=l102, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l102, ann_key="pragma_unroll_explicit", ann_val=1) + l114 = sch.get_loops(block=b87)[0] + sch.annotate(block_or_loop=l114, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l114, ann_key="pragma_unroll_explicit", ann_val=1) + + +def fused_NT_matmul1_add3_add5_add5(sch: tir.Schedule): + b0 = sch.get_block(name="NT_matmul", func_name="main") + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + + b1 = sch.get_block(name="T_add", func_name="main") + b2 = sch.get_block(name="T_add_1", func_name="main") + b3 = sch.get_block(name="T_add_2", func_name="main") + b4 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l5, l6, l7, l8 = 
sch.get_loops(block=b0) + v9, v10, v11, v12, v13 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l14, l15, l16, l17, l18 = sch.split(loop=l5, factors=[v9, v10, v11, v12, v13], preserve_unit_iters=True) + v19, v20, v21, v22, v23 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[2, 8, 4, 2, 1]) + l24, l25, l26, l27, l28 = sch.split(loop=l6, factors=[v19, v20, v21, v22, v23], preserve_unit_iters=True) + v29, v30, v31, v32, v33 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[20, 1, 64, 2, 1]) + l34, l35, l36, l37, l38 = sch.split(loop=l7, factors=[v29, v30, v31, v32, v33], preserve_unit_iters=True) + v39, v40, v41 = sch.sample_perfect_tile(loop=l8, n=3, max_innermost_factor=64, decision=[320, 1, 8]) + l42, l43, l44 = sch.split(loop=l8, factors=[v39, v40, v41], preserve_unit_iters=True) + sch.reorder(l14, l24, l34, l15, l25, l35, l16, l26, l36, l42, l43, l17, l27, l37, l44, l18, l28, l38) + l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) + sch.bind(loop=l45, thread_axis="blockIdx.x") + l46 = sch.fuse(l15, l25, l35, preserve_unit_iters=True) + sch.bind(loop=l46, thread_axis="vthread.x") + l47 = sch.fuse(l16, l26, l36, preserve_unit_iters=True) + sch.bind(loop=l47, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) + b48 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b48, loop=l47, preserve_unit_loops=True, index=-1) + b49 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b49, loop=l42, preserve_unit_loops=True, index=-1) + l54, l55, l56 = sch.get_loops(block=b49)[-3:] + sch.fuse(l54, l55, l56, preserve_unit_iters=True) + v58 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b49, ann_key="meta_schedule.cooperative_fetch", ann_val=v58) + b59 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b59, loop=l42, preserve_unit_loops=True, index=-1) + l64, l65 = sch.get_loops(block=b59)[-2:] + sch.fuse(l64, l65, preserve_unit_iters=True) + v67 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b59, ann_key="meta_schedule.cooperative_fetch", ann_val=v67) + sch.reverse_compute_inline(block=b3) + sch.reverse_compute_inline(block=b2) + sch.reverse_compute_inline(block=b1) + v68 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=4) + sch.annotate(block_or_loop=b4, ann_key="meta_schedule.unroll_explicit", ann_val=v68) + sch.enter_postproc() + sch.unannotate(block_or_loop=b49, ann_key="meta_schedule.cooperative_fetch") + l73 = sch.get_loops(block=b49)[-1] + _, l75, l76 = sch.split(loop=l73, factors=[None, 256, 4], preserve_unit_iters=True) + sch.vectorize(loop=l76) + sch.bind(loop=l75, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b59, ann_key="meta_schedule.cooperative_fetch") + l81 = sch.get_loops(block=b59)[-1] + _, l83, l84 = sch.split(loop=l81, factors=[None, 256, 4], preserve_unit_iters=True) + sch.vectorize(loop=l84) + sch.bind(loop=l83, thread_axis="threadIdx.x") + b85 = sch.get_block(name="root", func_name="main") + 
sch.unannotate(block_or_loop=b85, ann_key="meta_schedule.unroll_explicit") + b122 = sch.get_block(name="NT_matmul", func_name="main") + l126 = sch.get_loops(block=b122)[4] + sch.decompose_reduction(block=b122, loop=l126) + + b1 = sch.get_block("lv48_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + + _, b86, b87, b88, b89 = sch.get_child_blocks(b85) + l90 = sch.get_loops(block=b86)[0] + sch.annotate(block_or_loop=l90, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l90, ann_key="pragma_unroll_explicit", ann_val=1) + l97 = sch.get_loops(block=b87)[0] + sch.annotate(block_or_loop=l97, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l97, ann_key="pragma_unroll_explicit", ann_val=1) + l104 = sch.get_loops(block=b88)[0] + sch.annotate(block_or_loop=l104, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l104, ann_key="pragma_unroll_explicit", ann_val=1) + l116 = sch.get_loops(block=b89)[0] + sch.annotate(block_or_loop=l116, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l116, ann_key="pragma_unroll_explicit", ann_val=1) + + +def fused_NT_matmul1_add3_add5_add5_cast5(sch: tir.Schedule): + b0 = sch.get_block(name="NT_matmul", func_name="main") + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + + b1 = sch.get_block(name="T_add", func_name="main") + b2 = sch.get_block(name="T_add_1", func_name="main") + b3 = sch.get_block(name="T_add_2", func_name="main") + b4 = sch.get_block(name="compute", func_name="main") + b5 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l6, l7, l8, l9 = sch.get_loops(block=b0) + v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l15, l16, l17, l18, l19 = sch.split(loop=l6, factors=[v10, v11, v12, v13, v14], preserve_unit_iters=True) + v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[2, 2, 16, 2, 1]) + l25, l26, l27, l28, l29 = sch.split(loop=l7, factors=[v20, v21, v22, v23, v24], preserve_unit_iters=True) + v30, v31, v32, v33, v34 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[64, 2, 10, 1, 2]) + l35, l36, l37, l38, l39 = sch.split(loop=l8, factors=[v30, v31, v32, v33, v34], preserve_unit_iters=True) + v40, v41, v42 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64, decision=[64, 20, 2]) + l43, l44, l45 = sch.split(loop=l9, factors=[v40, v41, v42], preserve_unit_iters=True) + sch.reorder(l15, l25, l35, l16, l26, l36, l17, l27, l37, l43, l44, l18, l28, l38, l45, l19, l29, l39) + l46 = sch.fuse(l15, l25, l35, preserve_unit_iters=True) + sch.bind(loop=l46, thread_axis="blockIdx.x") + l47 = sch.fuse(l16, l26, l36, preserve_unit_iters=True) + sch.bind(loop=l47, thread_axis="vthread.x") + l48 = sch.fuse(l17, l27, l37, preserve_unit_iters=True) + sch.bind(loop=l48, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) + b49 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b49, loop=l48, 
preserve_unit_loops=True, index=-1) + b50 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b50, loop=l43, preserve_unit_loops=True, index=-1) + l55, l56, l57 = sch.get_loops(block=b50)[-3:] + sch.fuse(l55, l56, l57, preserve_unit_iters=True) + v59 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b50, ann_key="meta_schedule.cooperative_fetch", ann_val=v59) + b60 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b60, loop=l43, preserve_unit_loops=True, index=-1) + l65, l66 = sch.get_loops(block=b60)[-2:] + sch.fuse(l65, l66, preserve_unit_iters=True) + v68 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch", ann_val=v68) + sch.reverse_compute_inline(block=b4) + sch.reverse_compute_inline(block=b3) + sch.reverse_compute_inline(block=b2) + sch.reverse_compute_inline(block=b1) + v69 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=3) + sch.annotate(block_or_loop=b5, ann_key="meta_schedule.unroll_explicit", ann_val=v69) + sch.enter_postproc() + sch.unannotate(block_or_loop=b50, ann_key="meta_schedule.cooperative_fetch") + l74 = sch.get_loops(block=b50)[-1] + _, l76, l77 = sch.split(loop=l74, factors=[None, 160, 4], preserve_unit_iters=True) + sch.vectorize(loop=l77) + sch.bind(loop=l76, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch") + l82 = sch.get_loops(block=b60)[-1] + _, l84, l85 = sch.split(loop=l82, factors=[None, 160, 2], preserve_unit_iters=True) + sch.vectorize(loop=l85) + sch.bind(loop=l84, thread_axis="threadIdx.x") + b86 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b86, ann_key="meta_schedule.unroll_explicit") + b123 = sch.get_block(name="NT_matmul", func_name="main") + l127 = sch.get_loops(block=b123)[4] + sch.decompose_reduction(block=b123, loop=l127) + + b1 = sch.get_block("lv1815_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + + _, b87, b88, b89, b90 = sch.get_child_blocks(b86) + l91 = sch.get_loops(block=b87)[0] + sch.annotate(block_or_loop=l91, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l91, ann_key="pragma_unroll_explicit", ann_val=1) + l98 = sch.get_loops(block=b88)[0] + sch.annotate(block_or_loop=l98, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l98, ann_key="pragma_unroll_explicit", ann_val=1) + l105 = sch.get_loops(block=b89)[0] + sch.annotate(block_or_loop=l105, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l105, ann_key="pragma_unroll_explicit", ann_val=1) + l117 = sch.get_loops(block=b90)[0] + sch.annotate(block_or_loop=l117, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l117, ann_key="pragma_unroll_explicit", ann_val=1) + + +def fused_NT_matmul3_add4_gelu1(sch: tir.Schedule): + b0 = sch.get_block(name="NT_matmul", func_name="main") + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + + b1 = sch.get_block(name="T_add", func_name="main") + b2 = sch.get_block(name="T_multiply", func_name="main") + b3 = 
sch.get_block(name="compute", func_name="main") + b4 = sch.get_block(name="compute_1", func_name="main") + b5 = sch.get_block(name="compute_2", func_name="main") + b6 = sch.get_block(name="T_multiply_1", func_name="main") + b7 = sch.get_block(name="T_add_1", func_name="main") + b8 = sch.get_block(name="T_multiply_2", func_name="main") + b9 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l10, l11, l12, l13 = sch.get_loops(block=b0) + v14, v15, v16, v17, v18 = sch.sample_perfect_tile(loop=l10, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l19, l20, l21, l22, l23 = sch.split(loop=l10, factors=[v14, v15, v16, v17, v18], preserve_unit_iters=True) + v24, v25, v26, v27, v28 = sch.sample_perfect_tile(loop=l11, n=5, max_innermost_factor=64, decision=[2, 4, 8, 1, 2]) + l29, l30, l31, l32, l33 = sch.split(loop=l11, factors=[v24, v25, v26, v27, v28], preserve_unit_iters=True) + v34, v35, v36, v37, v38 = sch.sample_perfect_tile(loop=l12, n=5, max_innermost_factor=64, decision=[160, 4, 16, 1, 1]) + l39, l40, l41, l42, l43 = sch.split(loop=l12, factors=[v34, v35, v36, v37, v38], preserve_unit_iters=True) + v44, v45, v46 = sch.sample_perfect_tile(loop=l13, n=3, max_innermost_factor=64, decision=[64, 20, 2]) + l47, l48, l49 = sch.split(loop=l13, factors=[v44, v45, v46], preserve_unit_iters=True) + sch.reorder(l19, l29, l39, l20, l30, l40, l21, l31, l41, l47, l48, l22, l32, l42, l49, l23, l33, l43) + l50 = sch.fuse(l19, l29, l39, preserve_unit_iters=True) + sch.bind(loop=l50, thread_axis="blockIdx.x") + l51 = sch.fuse(l20, l30, l40, preserve_unit_iters=True) + sch.bind(loop=l51, thread_axis="vthread.x") + l52 = sch.fuse(l21, l31, l41, preserve_unit_iters=True) + sch.bind(loop=l52, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) + b53 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b53, loop=l52, preserve_unit_loops=True, index=-1) + b54 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b54, loop=l47, preserve_unit_loops=True, index=-1) + l59, l60, l61 = sch.get_loops(block=b54)[-3:] + sch.fuse(l59, l60, l61, preserve_unit_iters=True) + v63 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b54, ann_key="meta_schedule.cooperative_fetch", ann_val=v63) + b64 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b64, loop=l47, preserve_unit_loops=True, index=-1) + l69, l70 = sch.get_loops(block=b64)[-2:] + sch.fuse(l69, l70, preserve_unit_iters=True) + v72 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b64, ann_key="meta_schedule.cooperative_fetch", ann_val=v72) + sch.compute_inline(block=b7) + sch.compute_inline(block=b6) + sch.compute_inline(block=b5) + sch.compute_inline(block=b4) + sch.compute_inline(block=b3) + sch.compute_inline(block=b2) + sch.compute_inline(block=b1) + sch.reverse_compute_inline(block=b8) + v73 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=3) + sch.annotate(block_or_loop=b9, ann_key="meta_schedule.unroll_explicit", 
ann_val=v73) + sch.enter_postproc() + sch.unannotate(block_or_loop=b54, ann_key="meta_schedule.cooperative_fetch") + l85 = sch.get_loops(block=b54)[-1] + _, l87, l88 = sch.split(loop=l85, factors=[None, 128, 4], preserve_unit_iters=True) + sch.vectorize(loop=l88) + sch.bind(loop=l87, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b64, ann_key="meta_schedule.cooperative_fetch") + l93 = sch.get_loops(block=b64)[-1] + _, l95, l96 = sch.split(loop=l93, factors=[None, 128, 4], preserve_unit_iters=True) + sch.vectorize(loop=l96) + sch.bind(loop=l95, thread_axis="threadIdx.x") + b97 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b97, ann_key="meta_schedule.unroll_explicit") + b138 = sch.get_block(name="NT_matmul", func_name="main") + l142 = sch.get_loops(block=b138)[4] + sch.decompose_reduction(block=b138, loop=l142) + + b1 = sch.get_block("lv52_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + + b98, b99, b100, b101, b102 = sch.get_child_blocks(b97) + l103 = sch.get_loops(block=b98)[0] + sch.annotate(block_or_loop=l103, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l103, ann_key="pragma_unroll_explicit", ann_val=1) + l110 = sch.get_loops(block=b99)[0] + sch.annotate(block_or_loop=l110, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l110, ann_key="pragma_unroll_explicit", ann_val=1) + l117 = sch.get_loops(block=b100)[0] + sch.annotate(block_or_loop=l117, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l117, ann_key="pragma_unroll_explicit", ann_val=1) + l129 = sch.get_loops(block=b101)[0] + sch.annotate(block_or_loop=l129, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l129, ann_key="pragma_unroll_explicit", ann_val=1) + l135 = sch.get_loops(block=b102)[0] + sch.annotate(block_or_loop=l135, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l135, ann_key="pragma_unroll_explicit", ann_val=1) + + +def fused_NT_matmul4_add3(sch: tir.Schedule): + b0 = sch.get_block(name="NT_matmul", func_name="main") + + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + + b1 = sch.get_block(name="T_add", func_name="main") + b2 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l3, l4, l5, l6 = sch.get_loops(block=b0) + v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l12, l13, l14, l15, l16 = sch.split(loop=l3, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) + v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 16, 1, 2]) + l22, l23, l24, l25, l26 = sch.split(loop=l4, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) + v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[128, 1, 5, 2, 2]) + l32, l33, l34, l35, l36 = sch.split(loop=l5, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) + v37, v38, v39 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[256, 20, 2]) + l40, l41, l42 = sch.split(loop=l6, factors=[v37, v38, v39], preserve_unit_iters=True) + sch.reorder(l12, l22, l32, l13, l23, l33, l14, l24, l34, l40, l41, l15, l25, l35, l42, l16, 
l26, l36) + l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) + sch.bind(loop=l43, thread_axis="blockIdx.x") + l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) + sch.bind(loop=l44, thread_axis="vthread.x") + l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) + sch.bind(loop=l45, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) + b46 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b46, loop=l45, preserve_unit_loops=True, index=-1) + b47 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b47, loop=l40, preserve_unit_loops=True, index=-1) + l52, l53, l54 = sch.get_loops(block=b47)[-3:] + sch.fuse(l52, l53, l54, preserve_unit_iters=True) + v56 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch", ann_val=v56) + b57 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b57, loop=l40, preserve_unit_loops=True, index=-1) + l62, l63 = sch.get_loops(block=b57)[-2:] + sch.fuse(l62, l63, preserve_unit_iters=True) + v65 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) + sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v65) + sch.reverse_compute_inline(block=b1) + v66 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=2) + sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v66) + sch.enter_postproc() + sch.unannotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch") + l71 = sch.get_loops(block=b47)[-1] + _, l73, l74 = sch.split(loop=l71, factors=[None, 80, 2], preserve_unit_iters=True) + sch.vectorize(loop=l74) + sch.bind(loop=l73, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") + l79 = sch.get_loops(block=b57)[-1] + _, l81 = sch.split(loop=l79, factors=[None, 80], preserve_unit_iters=True) + sch.bind(loop=l81, thread_axis="threadIdx.x") + b82 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b82, ann_key="meta_schedule.unroll_explicit") + b118 = sch.get_block(name="NT_matmul", func_name="main") + l122 = sch.get_loops(block=b118)[4] + sch.decompose_reduction(block=b118, loop=l122) + + b1 = sch.get_block("lv56_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + + _, b83, b84, b85, b86 = sch.get_child_blocks(b82) + l87 = sch.get_loops(block=b83)[0] + sch.annotate(block_or_loop=l87, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l87, ann_key="pragma_unroll_explicit", ann_val=1) + l94 = sch.get_loops(block=b84)[0] + sch.annotate(block_or_loop=l94, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l94, ann_key="pragma_unroll_explicit", ann_val=1) + l100 = sch.get_loops(block=b85)[0] + sch.annotate(block_or_loop=l100, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l100, ann_key="pragma_unroll_explicit", ann_val=1) + l112 = sch.get_loops(block=b86)[0] + sch.annotate(block_or_loop=l112, 
ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l112, ann_key="pragma_unroll_explicit", ann_val=1) + + +def fused_NT_matmul_divide_maximum_minimum_cast2(sch: tir.Schedule): + b0 = sch.get_block(name="NT_matmul", func_name="main") + + sch.pad_einsum(b0, [1, 1, 1, 32, 1]) + l1, l2, l3, l4, l5 = sch.get_loops(b0) + l6, l7 = sch.split(l4, [None, 32]) + sch.reorder(l6, l1, l2, l3, l7, l5) + + b1 = sch.get_block(name="T_divide", func_name="main") + b2 = sch.get_block(name="T_maximum", func_name="main") + b3 = sch.get_block(name="T_minimum", func_name="main") + b4 = sch.get_block(name="compute", func_name="main") + b5 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l6, l7, l8, l9, l10 = sch.get_loops(block=b0) + v11, v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l16, l17, l18, l19, l20 = sch.split(loop=l6, factors=[v11, v12, v13, v14, v15], preserve_unit_iters=True) + v21, v22, v23, v24, v25 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[8, 2, 2, 1, 1]) + l26, l27, l28, l29, l30 = sch.split(loop=l7, factors=[v21, v22, v23, v24, v25], preserve_unit_iters=True) + v31, v32, v33, v34, v35 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l36, l37, l38, l39, l40 = sch.split(loop=l8, factors=[v31, v32, v33, v34, v35], preserve_unit_iters=True) + v41, v42, v43, v44, v45 = sch.sample_perfect_tile(loop=l9, n=5, max_innermost_factor=64, decision=[4, 1, 32, 1, 1]) + l46, l47, l48, l49, l50 = sch.split(loop=l9, factors=[v41, v42, v43, v44, v45], preserve_unit_iters=True) + v51, v52, v53 = sch.sample_perfect_tile(loop=l10, n=3, max_innermost_factor=64, decision=[2, 1, 40]) + l54, l55, l56 = sch.split(loop=l10, factors=[v51, v52, v53], preserve_unit_iters=True) + sch.reorder(l16, l26, l36, l46, l17, l27, l37, l47, l18, l28, l38, l48, l54, l55, l19, l29, l39, l49, l56, l20, l30, l40, l50) + l57 = sch.fuse(l16, l26, l36, l46, preserve_unit_iters=True) + sch.bind(loop=l57, thread_axis="blockIdx.x") + l58 = sch.fuse(l17, l27, l37, l47, preserve_unit_iters=True) + sch.bind(loop=l58, thread_axis="vthread.x") + l59 = sch.fuse(l18, l28, l38, l48, preserve_unit_iters=True) + sch.bind(loop=l59, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) + b60 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b60, loop=l59, preserve_unit_loops=True, index=-1) + b61 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b61, loop=l54, preserve_unit_loops=True, index=-1) + l66, l67, l68, l69 = sch.get_loops(block=b61)[-4:] + sch.fuse(l66, l67, l68, l69, preserve_unit_iters=True) + v71 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) + sch.annotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch", ann_val=v71) + b72 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b72, loop=l54, preserve_unit_loops=True, index=-1) + l77, l78, l79, l80 = sch.get_loops(block=b72)[-4:] + sch.fuse(l77, l78, l79, l80, preserve_unit_iters=True) + v82 = 
sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) + sch.annotate(block_or_loop=b72, ann_key="meta_schedule.cooperative_fetch", ann_val=v82) + sch.reverse_compute_inline(block=b4) + sch.reverse_compute_inline(block=b3) + sch.reverse_compute_inline(block=b2) + sch.reverse_compute_inline(block=b1) + v83 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=3) + sch.annotate(block_or_loop=b5, ann_key="meta_schedule.unroll_explicit", ann_val=v83) + sch.enter_postproc() + sch.unannotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch") + l88 = sch.get_loops(block=b61)[-1] + _, l90 = sch.split(loop=l88, factors=[None, 64], preserve_unit_iters=True) + sch.bind(loop=l90, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b72, ann_key="meta_schedule.cooperative_fetch") + l95 = sch.get_loops(block=b72)[-1] + _, l97 = sch.split(loop=l95, factors=[None, 64], preserve_unit_iters=True) + sch.bind(loop=l97, thread_axis="threadIdx.x") + b98 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b98, ann_key="meta_schedule.unroll_explicit") + + b136 = sch.get_block(name="NT_matmul", func_name="main") + l140 = sch.get_loops(block=b136)[4] + sch.decompose_reduction(block=b136, loop=l140) + + b1 = sch.get_block("lv1870_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + + _, b99, b100, b101, b102 = sch.get_child_blocks(b98) + l103 = sch.get_loops(block=b99)[0] + sch.annotate(block_or_loop=l103, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l103, ann_key="pragma_unroll_explicit", ann_val=1) + l109 = sch.get_loops(block=b100)[0] + sch.annotate(block_or_loop=l109, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l109, ann_key="pragma_unroll_explicit", ann_val=1) + l115 = sch.get_loops(block=b101)[0] + sch.annotate(block_or_loop=l115, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l115, ann_key="pragma_unroll_explicit", ann_val=1) + l129 = sch.get_loops(block=b102)[0] + sch.annotate(block_or_loop=l129, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l129, ann_key="pragma_unroll_explicit", ann_val=1) + + +def fused_NT_matmul2_divide1_maximum1_minimum1_cast7(sch: tir.Schedule): + b0 = sch.get_block(name="NT_matmul", func_name="main") + sch.pad_einsum(b0, [1, 1, 32, 32, 1]) + l1, l2, l3, l4, l5 = sch.get_loops(b0) + l6, l7 = sch.split(l3, [None, 32]) + l8, l9 = sch.split(l4, [None, 32]) + sch.reorder(l6, l8, l1, l2, l7, l9, l5) + + b1 = sch.get_block(name="T_divide", func_name="main") + b2 = sch.get_block(name="T_maximum", func_name="main") + b3 = sch.get_block(name="T_minimum", func_name="main") + b4 = sch.get_block(name="compute", func_name="main") + b5 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, _, l6, l7, l8, l9, l10 = sch.get_loops(block=b0) + v11, v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l16, l17, l18, l19, l20 = sch.split(loop=l6, factors=[v11, v12, v13, v14, v15], preserve_unit_iters=True) + v21, v22, v23, v24, v25 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[32, 1, 1, 1, 1]) + l26, l27, l28, l29, l30 = sch.split(loop=l7, factors=[v21, v22, v23, v24, v25], 
preserve_unit_iters=True) + v31, v32, v33, v34, v35 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[2, 1, 4, 1, 16]) + l36, l37, l38, l39, l40 = sch.split(loop=l8, factors=[v31, v32, v33, v34, v35], preserve_unit_iters=True) + v41, v42, v43, v44, v45 = sch.sample_perfect_tile(loop=l9, n=5, max_innermost_factor=64, decision=[4, 2, 16, 1, 1]) + l46, l47, l48, l49, l50 = sch.split(loop=l9, factors=[v41, v42, v43, v44, v45], preserve_unit_iters=True) + v51, v52, v53 = sch.sample_perfect_tile(loop=l10, n=3, max_innermost_factor=64, decision=[10, 1, 8]) + l54, l55, l56 = sch.split(loop=l10, factors=[v51, v52, v53], preserve_unit_iters=True) + sch.reorder(l16, l26, l36, l46, l17, l27, l37, l47, l18, l28, l38, l48, l54, l55, l19, l29, l39, l49, l56, l20, l30, l40, l50) + l57 = sch.fuse(l16, l26, l36, l46, preserve_unit_iters=True) + sch.bind(loop=l57, thread_axis="blockIdx.x") + l58 = sch.fuse(l17, l27, l37, l47, preserve_unit_iters=True) + sch.bind(loop=l58, thread_axis="vthread.x") + l59 = sch.fuse(l18, l28, l38, l48, preserve_unit_iters=True) + sch.bind(loop=l59, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) + b60 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b60, loop=l59, preserve_unit_loops=True, index=-1) + b61 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b61, loop=l54, preserve_unit_loops=True, index=-1) + l66, l67, l68, l69 = sch.get_loops(block=b61)[-4:] + sch.fuse(l66, l67, l68, l69, preserve_unit_iters=True) + v71 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch", ann_val=v71) + b72 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b72, loop=l54, preserve_unit_loops=True, index=-1) + l77, l78, l79, l80 = sch.get_loops(block=b72)[-4:] + sch.fuse(l77, l78, l79, l80, preserve_unit_iters=True) + v82 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b72, ann_key="meta_schedule.cooperative_fetch", ann_val=v82) + sch.reverse_compute_inline(block=b4) + sch.reverse_compute_inline(block=b3) + sch.reverse_compute_inline(block=b2) + sch.reverse_compute_inline(block=b1) + v83 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=4) + sch.annotate(block_or_loop=b5, ann_key="meta_schedule.unroll_explicit", ann_val=v83) + sch.enter_postproc() + sch.unannotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch") + l88 = sch.get_loops(block=b61)[-1] + _, l90, l91 = sch.split(loop=l88, factors=[None, 32, 4], preserve_unit_iters=True) + sch.vectorize(loop=l91) + sch.bind(loop=l90, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b72, ann_key="meta_schedule.cooperative_fetch") + l96 = sch.get_loops(block=b72)[-1] + _, l98, l99 = sch.split(loop=l96, factors=[None, 32, 2], preserve_unit_iters=True) + sch.vectorize(loop=l99) + sch.bind(loop=l98, thread_axis="threadIdx.x") + b100 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b100, ann_key="meta_schedule.unroll_explicit") + b140 = sch.get_block(name="NT_matmul", func_name="main") 
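+ # Added clarifying comment: decompose_reduction separates the NT_matmul init block
+ # from its update so the accumulator is zeroed before the reduction loop runs; the
+ # *_pad blocks introduced by pad_einsum are then inlined immediately below.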
+ l144 = sch.get_loops(block=b140)[5] + sch.decompose_reduction(block=b140, loop=l144) + + b1 = sch.get_block("lv35_pad") + sch.compute_inline(b1) + b1 = sch.get_block("lv36_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + + _, b101, b102, b103, b104 = sch.get_child_blocks(b100) + l105 = sch.get_loops(block=b101)[0] + sch.annotate(block_or_loop=l105, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l105, ann_key="pragma_unroll_explicit", ann_val=1) + l112 = sch.get_loops(block=b102)[0] + sch.annotate(block_or_loop=l112, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l112, ann_key="pragma_unroll_explicit", ann_val=1) + l119 = sch.get_loops(block=b103)[0] + sch.annotate(block_or_loop=l119, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l119, ann_key="pragma_unroll_explicit", ann_val=1) + l133 = sch.get_loops(block=b104)[0] + sch.annotate(block_or_loop=l133, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l133, ann_key="pragma_unroll_explicit", ann_val=1) + + +def matmul1(sch: tir.Schedule): + b0 = sch.get_block(name="matmul", func_name="main") + sch.pad_einsum(b0, [1, 1, 1, 1, 32]) + l1, l2, l3, l4, k = sch.get_loops(b0) + k0, k1 = sch.split(k, [None, 32]) + sch.reorder(l1, l2, l3, k0, l4, k1) + + b1 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + l2, l3, l4, _, l5, l6 = sch.get_loops(block=b0) + v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l12, l13, l14, l15, l16 = sch.split(loop=l2, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) + v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[8, 2, 2, 1, 1]) + l22, l23, l24, l25, l26 = sch.split(loop=l3, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) + v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l32, l33, l34, l35, l36 = sch.split(loop=l4, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) + v37, v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[5, 1, 16, 1, 1]) + l42, l43, l44, l45, l46 = sch.split(loop=l5, factors=[v37, v38, v39, v40, v41], preserve_unit_iters=True) + v47, v48, v49 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[4, 8, 1]) + l50, l51, l52 = sch.split(loop=l6, factors=[v47, v48, v49], preserve_unit_iters=True) + sch.reorder(l12, l22, l32, l42, l13, l23, l33, l43, l14, l24, l34, l44, k0, l50, l51, l15, l25, l35, l45, l52, l16, l26, l36, l46) + l53 = sch.fuse(l12, l22, l32, l42, preserve_unit_iters=True) + sch.bind(loop=l53, thread_axis="blockIdx.x") + l54 = sch.fuse(l13, l23, l33, l43, preserve_unit_iters=True) + sch.bind(loop=l54, thread_axis="vthread.x") + l55 = sch.fuse(l14, l24, l34, l44, preserve_unit_iters=True) + sch.bind(loop=l55, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) + b56 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b56, loop=l55, preserve_unit_loops=True, index=-1) + b57 = 
sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b57, loop=l50, preserve_unit_loops=True, index=-1) + l62, l63, l64, l65 = sch.get_loops(block=b57)[-4:] + sch.fuse(l62, l63, l64, l65, preserve_unit_iters=True) + v67 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) + sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v67) + b68 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b68, loop=l50, preserve_unit_loops=True, index=-1) + l73, l74, l75, l76 = sch.get_loops(block=b68)[-4:] + sch.fuse(l73, l74, l75, l76, preserve_unit_iters=True) + v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch", ann_val=v78) + v79 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=4) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v79) + sch.enter_postproc() + sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") + l84 = sch.get_loops(block=b57)[-1] + _, l86 = sch.split(loop=l84, factors=[None, 32], preserve_unit_iters=True) + sch.bind(loop=l86, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch") + l91 = sch.get_loops(block=b68)[-1] + _, l93, l94 = sch.split(loop=l91, factors=[None, 32, 4], preserve_unit_iters=True) + sch.vectorize(loop=l94) + sch.bind(loop=l93, thread_axis="threadIdx.x") + b95 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b95, ann_key="meta_schedule.unroll_explicit") + + b1 = sch.get_block("A_pad") + sch.compute_inline(b1) + b1 = sch.get_block("B_pad") + sch.compute_inline(b1) + + b96, b97, b98, b99 = sch.get_child_blocks(b95) + l100 = sch.get_loops(block=b96)[0] + sch.annotate(block_or_loop=l100, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l100, ann_key="pragma_unroll_explicit", ann_val=1) + l106 = sch.get_loops(block=b97)[0] + sch.annotate(block_or_loop=l106, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l106, ann_key="pragma_unroll_explicit", ann_val=1) + l113 = sch.get_loops(block=b98)[0] + sch.annotate(block_or_loop=l113, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l113, ann_key="pragma_unroll_explicit", ann_val=1) + l127 = sch.get_loops(block=b99)[0] + sch.annotate(block_or_loop=l127, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l127, ann_key="pragma_unroll_explicit", ann_val=1) + b134 = sch.get_block(name="matmul", func_name="main") + l138 = sch.get_loops(block=b134)[3] + sch.decompose_reduction(block=b134, loop=l138) + + +def matmul8(sch: tir.Schedule): + b0 = sch.get_block(name="matmul", func_name="main") + sch.pad_einsum(b0, [1, 1, 32, 1, 32]) + l1, l2, l3, l4, k = sch.get_loops(b0) + s0, s1 = sch.split(l3, [None, 32]) + k0, k1 = sch.split(k, [None, 32]) + sch.reorder(s0, l1, l2, s1, k0, l4, k1) + + b1 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l2, l3, l4, _, l5, l6 = sch.get_loops(block=b0) + v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l12, l13, l14, 
l15, l16 = sch.split(loop=l2, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) + v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[16, 1, 1, 1, 2]) + l22, l23, l24, l25, l26 = sch.split(loop=l3, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) + v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 32, 4, 1]) + l32, l33, l34, l35, l36 = sch.split(loop=l4, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) + v37, v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[2, 2, 5, 1, 4]) + l42, l43, l44, l45, l46 = sch.split(loop=l5, factors=[v37, v38, v39, v40, v41], preserve_unit_iters=True) + v47, v48, v49 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[2, 2, 8]) + l50, l51, l52 = sch.split(loop=l6, factors=[v47, v48, v49], preserve_unit_iters=True) + sch.reorder(l12, l22, l32, l42, l13, l23, l33, l43, l14, l24, l34, l44, k0, l50, l51, l15, l25, l35, l45, l52, l16, l26, l36, l46) + l53 = sch.fuse(l12, l22, l32, l42, preserve_unit_iters=True) + sch.bind(loop=l53, thread_axis="blockIdx.x") + l54 = sch.fuse(l13, l23, l33, l43, preserve_unit_iters=True) + sch.bind(loop=l54, thread_axis="vthread.x") + l55 = sch.fuse(l14, l24, l34, l44, preserve_unit_iters=True) + sch.bind(loop=l55, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) + b56 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b56, loop=l55, preserve_unit_loops=True, index=-1) + b57 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b57, loop=l50, preserve_unit_loops=True, index=-1) + l62, l63, l64, l65 = sch.get_loops(block=b57)[-4:] + sch.fuse(l62, l63, l64, l65, preserve_unit_iters=True) + v67 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v67) + b68 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b68, loop=l50, preserve_unit_loops=True, index=-1) + l73, l74, l75, l76 = sch.get_loops(block=b68)[-4:] + sch.fuse(l73, l74, l75, l76, preserve_unit_iters=True) + v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) + sch.annotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch", ann_val=v78) + v79 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=3) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v79) + sch.enter_postproc() + sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") + l84 = sch.get_loops(block=b57)[-1] + _, l86, l87 = sch.split(loop=l84, factors=[None, 40, 2], preserve_unit_iters=True) + sch.vectorize(loop=l87) + sch.bind(loop=l86, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch") + l92 = sch.get_loops(block=b68)[-1] + _, l94 = sch.split(loop=l92, factors=[None, 40], preserve_unit_iters=True) + sch.bind(loop=l94, thread_axis="threadIdx.x") + b95 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b95, 
ann_key="meta_schedule.unroll_explicit") + + b1 = sch.get_block("A_pad") + sch.compute_inline(b1) + b1 = sch.get_block("B_pad") + sch.compute_inline(b1) + b1 = sch.get_block("matmul_pad") + sch.reverse_compute_inline(b1) + + b96, b97, b98, b99 = sch.get_child_blocks(b95) + l100 = sch.get_loops(block=b96)[0] + sch.annotate(block_or_loop=l100, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l100, ann_key="pragma_unroll_explicit", ann_val=1) + l107 = sch.get_loops(block=b97)[0] + sch.annotate(block_or_loop=l107, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l107, ann_key="pragma_unroll_explicit", ann_val=1) + l113 = sch.get_loops(block=b98)[0] + sch.annotate(block_or_loop=l113, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l113, ann_key="pragma_unroll_explicit", ann_val=1) + l127 = sch.get_loops(block=b99)[0] + sch.annotate(block_or_loop=l127, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l127, ann_key="pragma_unroll_explicit", ann_val=1) + b134 = sch.get_block(name="matmul", func_name="main") + l138= sch.get_loops(block=b134)[4] + sch.decompose_reduction(block=b134, loop=l138) + + +def fused_layer_norm1_cast6(sch: tir.Schedule): + b0 = sch.get_block(name="A_red_temp", func_name="main") + b1 = sch.get_block(name="T_layer_norm", func_name="main") + b2 = sch.get_block(name="compute", func_name="main") + b3 = sch.get_block(name="root", func_name="main") + sch.reverse_compute_inline(block=b2) + v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125], decision=5) + l5, l6, l7 = sch.get_loops(block=b0) + l8, l9 = sch.split(loop=l7, factors=[None, v4], preserve_unit_iters=True) + sch.bind(loop=l9, thread_axis="threadIdx.x") + v10 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=1) + sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v10) + l11, l12, l13 = sch.get_loops(block=b1) + l14 = sch.fuse(l11, l12, l13, preserve_unit_iters=True) + l15, l16, l17 = sch.split(loop=l14, factors=[None, 256, 256], preserve_unit_iters=True) + sch.reorder(l16, l17, l15) + sch.bind(loop=l16, thread_axis="blockIdx.x") + sch.bind(loop=l17, thread_axis="threadIdx.x") + l18, l19, l20, l21 = sch.get_loops(block=b0) + l22 = sch.fuse(l18, l19, preserve_unit_iters=True) + sch.bind(loop=l22, thread_axis="blockIdx.x") + sch.enter_postproc() + b23 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b23, ann_key="meta_schedule.unroll_explicit") + b24, b25 = sch.get_child_blocks(b23) + l26, l27, l28 = sch.get_loops(block=b24) + sch.annotate(block_or_loop=l26, ann_key="pragma_auto_unroll_max_step", ann_val=16) + sch.annotate(block_or_loop=l26, ann_key="pragma_unroll_explicit", ann_val=1) + l29, l30, l31 = sch.get_loops(block=b25) + sch.annotate(block_or_loop=l29, ann_key="pragma_auto_unroll_max_step", ann_val=16) + sch.annotate(block_or_loop=l29, ann_key="pragma_unroll_explicit", ann_val=1) + + +def layer_norm1(sch: tir.Schedule): + b0 = sch.get_block(name="A_red_temp", func_name="main") + b1 = sch.get_block(name="T_layer_norm", func_name="main") + b2 = sch.get_block(name="root", func_name="main") + v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125], decision=4) + l4, l5, l6 = sch.get_loops(block=b0) + l7, l8 = 
sch.split(loop=l6, factors=[None, v3], preserve_unit_iters=True) + sch.bind(loop=l8, thread_axis="threadIdx.x") + v9 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=3) + sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v9) + l10, l11, l12 = sch.get_loops(block=b1) + l13 = sch.fuse(l10, l11, l12, preserve_unit_iters=True) + l14, l15, l16 = sch.split(loop=l13, factors=[None, 256, 256], preserve_unit_iters=True) + sch.reorder(l15, l16, l14) + sch.bind(loop=l15, thread_axis="blockIdx.x") + sch.bind(loop=l16, thread_axis="threadIdx.x") + l17, l18, l19, l20 = sch.get_loops(block=b0) + l21 = sch.fuse(l17, l18, preserve_unit_iters=True) + sch.bind(loop=l21, thread_axis="blockIdx.x") + sch.enter_postproc() + b22 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b22, ann_key="meta_schedule.unroll_explicit") + b23, b24 = sch.get_child_blocks(b22) + l25, l26, l27 = sch.get_loops(block=b23) + sch.annotate(block_or_loop=l25, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l25, ann_key="pragma_unroll_explicit", ann_val=1) + l28, l29, l30 = sch.get_loops(block=b24) + sch.annotate(block_or_loop=l28, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l28, ann_key="pragma_unroll_explicit", ann_val=1) + + +def sch_softmax_cast(cast_to_fp16: bool): + def f(sch: tir.Schedule): + if cast_to_fp16: + b_cast = sch.get_block("compute") + sch.reverse_compute_inline(b_cast) + b0 = sch.get_block("T_softmax_exp") + sch.compute_inline(b0) + b1 = sch.get_block("T_softmax_norm") + l2, l3, l4, l5 = sch.get_loops(b1) + _, l7 = sch.split(l5, [None, 128]) + sch.bind(l7, "threadIdx.x") + b8 = sch.get_block("T_softmax_expsum") + sch.compute_at(b8, l4) + sch.set_scope(b8, 0, "shared") + _, _, _, l12 = sch.get_loops(b8) + _, l14 = sch.split(l12, [None, 128]) + sch.bind(l14, "threadIdx.x") + b15 = sch.get_block("T_softmax_maxelem") + sch.compute_at(b15, l4) + sch.set_scope(b15, 0, "shared") + _, _, _, l19 = sch.get_loops(b15) + _, l21 = sch.split(l19, [None, 128]) + sch.bind(l21, "threadIdx.x") + l22 = sch.fuse(l2, l3, l4) + sch.bind(l22, "blockIdx.x") + return f + + +@T.prim_func +def softmax_cast_mxn_before(p_lv37: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n, m = T.int64(), T.int64() + lv37 = T.match_buffer(p_lv37, (T.int64(1), T.int64(32), n, m)) + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16") + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) + T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) + var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv37[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv37[v_i0, v_i1, v_i2, v_k]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_exp"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(lv37[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, 
v_i2]) + T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) + T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv37[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_expsum"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) + T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_norm"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) + T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"axis": 3}) + var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) + var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + + +@T.prim_func +def softmax_cast_mxn_after(var_A: T.handle, var_T_softmax_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + m = T.int64() + A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m)) + T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, m), dtype="float16") + # with T.block("root"): + for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): + with T.block("T_softmax_maxelem_o"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) + T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(m + T.int64(127)) // T.int64(128) * T.int64(128)]) + T.writes(T_softmax_norm[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):m]) + T_softmax_maxelem_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared") + T_softmax_expsum_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared") + for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): + for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_maxelem"): + v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) + v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) + T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i]) + T.writes(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) + with T.init(): + T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.max(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i], T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T.float32(-3.4028234663852886e+38))) + for i0, i1, i2_1, k_0 in T.grid(T.int64(1), 
T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): + for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_expsum"): + v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) + v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) + T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) + T.writes(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) + with T.init(): + T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float32(0) + T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] + T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, T.exp(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i] - T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]), T.float32(0)) + for i0_i1_i2_1_i3_fused_0 in range((T.int64(32) * T.int64(32) * m) // T.int64(128)): + for i0_i1_i2_1_i3_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_norm"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(32), (i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1) // T.int64(32) // m) + v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1) // m % T.int64(32)) + v_i3 = T.axis.spatial(m, (i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1) % m) + T.where(i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1 < T.int64(32) * T.int64(32) * m) + T.reads(T_softmax_expsum_pad_0_local[v_i0, v_i1, v_i2_i], A[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3], T_softmax_maxelem_pad_0_local[v_i0, v_i1, v_i2_i]) + T.writes(T_softmax_norm[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3]) + if v_i2_o * T.int64(32) + v_i2_i < n: + T_softmax_norm[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3] = T.Cast("float16", T.exp(A[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3] - T_softmax_maxelem_pad_0_local[v_i0, v_i1, v_i2_i]) / T_softmax_expsum_pad_0_local[v_i0, v_i1, v_i2_i]) + + +# fmt: on + + +def _get_dict(): + tvm.ir.assert_structural_equal(MOD["fused_softmax1_cast8"], softmax_cast_mxn_before) + func_dict = { + softmax_cast_mxn_before: softmax_cast_mxn_after, + } + for name, func in [ + ("fused_NT_matmul1_add3", fused_NT_matmul1_add3), + ("fused_NT_matmul1_add3_add5_add5", fused_NT_matmul1_add3_add5_add5), + ( + "fused_NT_matmul1_add3_add5_add5_cast5", + fused_NT_matmul1_add3_add5_add5_cast5, + ), + ("fused_NT_matmul3_add4_gelu1", fused_NT_matmul3_add4_gelu1), + ("fused_NT_matmul4_add3", fused_NT_matmul4_add3), + ( + "fused_NT_matmul_divide_maximum_minimum_cast2", + fused_NT_matmul_divide_maximum_minimum_cast2, + ), + ( + "fused_NT_matmul2_divide1_maximum1_minimum1_cast7", + fused_NT_matmul2_divide1_maximum1_minimum1_cast7, + ), + ("matmul1", matmul1), + ("matmul8", matmul8), + ("fused_softmax_cast3", sch_softmax_cast(True)), + ("fused_layer_norm1_cast6", fused_layer_norm1_cast6), + ("layer_norm1", layer_norm1), + ]: + sch = tir.Schedule(MOD[name]) + func(sch) + func_dict[MOD[name]] = sch.mod["main"] + return { + (tvm.ir.structural_hash(k), k): v.with_attr("tir.is_scheduled", True) + for k, v in func_dict.items() + } + + +DICT = _get_dict() + + +def lookup(func): + for (hash_value, func_before), f_after in DICT.items(): + if tvm.ir.structural_hash(func) == hash_value and tvm.ir.structural_equal( + func, func_before + ): + return f_after + return None diff --git 
a/mlc_llm/dispatch/gpt_neox/dolly_v2_3b_mod.py b/mlc_llm/dispatch/gpt_neox/dolly_v2_3b_mod.py new file mode 100644 index 0000000..e3ff44b --- /dev/null +++ b/mlc_llm/dispatch/gpt_neox/dolly_v2_3b_mod.py @@ -0,0 +1,511 @@ +# pylint: disable=pointless-string-statement,invalid-name,missing-docstring,line-too-long,too-many-locals,too-many-arguments,too-many-statements +from tvm.script import ir as I +from tvm.script import tir as T + +""" +Operators: +- fused_NT_matmul1_add3 +- fused_NT_matmul1_add3_add5_add5 +- fused_NT_matmul1_add3_add5_add5_cast5 +- fused_NT_matmul2_divide1_maximum1_minimum1_cast7 +- fused_NT_matmul3_add4_gelu1 +- fused_NT_matmul4_add3 +- fused_NT_matmul_divide_maximum_minimum_cast2 +- matmul1 +- matmul8 +- fused_softmax1_cast8 +- fused_softmax_cast3 +- fused_layer_norm1_cast6 +- layer_norm1 +""" + +# fmt: off + +@I.ir_module +class Module: + @T.prim_func + def fused_NT_matmul1_add3(p_lv10: T.handle, lv1173: T.Buffer((T.int64(2560), T.int64(2560)), "float16"), linear_bias: T.Buffer((T.int64(2560),), "float16"), p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv10 = T.match_buffer(p_lv10, (T.int64(1), n, T.int64(2560)), "float16") + var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv10[v_i0, v_i1, v_k], lv1173[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv10[v_i0, v_i1, v_k] * lv1173[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias[v_ax2] + + @T.prim_func + def fused_NT_matmul1_add3_add5_add5(p_lv48: T.handle, lv1194: T.Buffer((T.int64(2560), T.int64(2560)), "float16"), linear_bias3: T.Buffer((T.int64(2560),), "float16"), p_lv60: T.handle, p_lv2: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv48 = T.match_buffer(p_lv48, (T.int64(1), n, T.int64(2560)), "float16") + lv60 = T.match_buffer(p_lv60, (T.int64(1), n, T.int64(2560)), "float16") + lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(2560)), "float16") + var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + var_T_add_intermediate_2 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv48[v_i0, v_i1, v_k], lv1194[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, 
v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv48[v_i0, v_i1, v_k] * lv1194[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias3[v_ax2]) + T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias3[v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv60[v_ax0, v_ax1, v_ax2], var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_add_intermediate_2[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate_2[v_ax0, v_ax1, v_ax2] = lv60[v_ax0, v_ax1, v_ax2] + var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add_2"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate_2[v_ax0, v_ax1, v_ax2], lv2[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate_2[v_ax0, v_ax1, v_ax2] + lv2[v_ax0, v_ax1, v_ax2] + + @T.prim_func + def fused_NT_matmul1_add3_add5_add5_cast5(p_lv1815: T.handle, lv2496: T.Buffer((T.int64(2560), T.int64(2560)), "float16"), linear_bias189: T.Buffer((T.int64(2560),), "float16"), p_lv1827: T.handle, p_lv1772: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv1815 = T.match_buffer(p_lv1815, (T.int64(1), n, T.int64(2560)), "float16") + lv1827 = T.match_buffer(p_lv1827, (T.int64(1), n, T.int64(2560)), "float16") + lv1772 = T.match_buffer(p_lv1772, (T.int64(1), n, T.int64(2560)), "float16") + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560))) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + var_T_add_intermediate_2 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1815[v_i0, v_i1, v_k], lv2496[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv1815[v_i0, v_i1, v_k] * lv2496[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias189[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias189[v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv1827[v_ax0, v_ax1, v_ax2], var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + 
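# The bias add and the two element-wise adds that follow fuse the extra inputs lv1827 and lv1772 (residual-style adds) into the matmul epilogue; the final "compute" block then casts the fp16 result back to fp32. +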
T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = lv1827[v_ax0, v_ax1, v_ax2] + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add_2"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2], lv1772[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_add_intermediate_2[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate_2[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] + lv1772[v_ax0, v_ax1, v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_add_intermediate_2[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_T_add_intermediate_2[v_i0, v_i1, v_i2]) + + @T.prim_func + def fused_NT_matmul2_divide1_maximum1_minimum1_cast7(p_lv35: T.handle, p_lv36: T.handle, p_lv5: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv35 = T.match_buffer(p_lv35, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") + m = T.int64() + lv36 = T.match_buffer(p_lv36, (T.int64(1), T.int64(32), m, T.int64(80)), "float16") + lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m), "float16") + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m)) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") + var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") + var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") + var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, m, T.int64(80)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(lv35[v_i0, v_i1, v_i2, v_k], lv36[v_i0, v_i1, v_i3, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv35[v_i0, v_i1, v_i2, v_k] * lv36[v_i0, v_i1, v_i3, v_k] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_divide"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.11179039301310044) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_maximum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_minimum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + 
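# The element-wise minimum applies the additive attention mask lv5: masked positions, which typically carry a large negative fp16 value, are pulled down to that value before the softmax, while unmasked scores are left effectively unchanged. +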
T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) + T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) + var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) + + @T.prim_func + def fused_NT_matmul3_add4_gelu1(p_lv52: T.handle, lv1201: T.Buffer((T.int64(10240), T.int64(2560)), "float16"), linear_bias4: T.Buffer((T.int64(10240),), "float16"), p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv52 = T.match_buffer(p_lv52, (T.int64(1), n, T.int64(2560)), "float16") + var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(10240)), "float16") + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240)), "float16") + var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240)), "float16") + T_multiply = T.alloc_buffer((T.int64(1), n, T.int64(10240)), "float16") + compute = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + compute_1 = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + compute_2 = T.alloc_buffer((T.int64(1), n, T.int64(10240)), "float16") + T_multiply_1 = T.alloc_buffer((T.int64(1), n, T.int64(10240)), "float16") + T_add = T.alloc_buffer((T.int64(1), n, T.int64(10240)), "float16") + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(10240), T.int64(2560)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv52[v_i0, v_i1, v_k], lv1201[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv52[v_i0, v_i1, v_k] * lv1201[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias4[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias4[v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) + T_multiply[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T.float16(0.70710678118654757) + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(T_multiply[v_i0, v_i1, v_i2]) + T.writes(compute[v_i0, v_i1, v_i2]) + compute[v_i0, v_i1, v_i2] = T.Cast("float32", T_multiply[v_i0, v_i1, v_i2]) + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("compute_1"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(compute[v_i0, v_i1, v_i2]) + 
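# compute / compute_1 / compute_2 evaluate erf(x / sqrt(2)) in fp32 and cast back to fp16; together with the 0.5 scale and 0.5 offset below, this realizes the exact GELU, x * 0.5 * (1 + erf(x / sqrt(2))). +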
T.writes(compute_1[v_i0, v_i1, v_i2]) + compute_1[v_i0, v_i1, v_i2] = T.erf(compute[v_i0, v_i1, v_i2]) + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("compute_2"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(compute_1[v_i0, v_i1, v_i2]) + T.writes(compute_2[v_i0, v_i1, v_i2]) + compute_2[v_i0, v_i1, v_i2] = T.Cast("float16", compute_1[v_i0, v_i1, v_i2]) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_multiply_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(compute_2[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2]) + T_multiply_1[v_ax0, v_ax1, v_ax2] = compute_2[v_ax0, v_ax1, v_ax2] * T.float16(0.5) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2]) + T.writes(T_add[v_ax0, v_ax1, v_ax2]) + T_add[v_ax0, v_ax1, v_ax2] = T.float16(0.5) + T_multiply_1[v_ax0, v_ax1, v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_multiply_2"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2] + + @T.prim_func + def fused_NT_matmul4_add3(p_lv56: T.handle, lv1208: T.Buffer((T.int64(2560), T.int64(10240)), "float16"), linear_bias5: T.Buffer((T.int64(2560),), "float16"), p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv56 = T.match_buffer(p_lv56, (T.int64(1), n, T.int64(10240)), "float16") + var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(10240)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv56[v_i0, v_i1, v_k], lv1208[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv56[v_i0, v_i1, v_k] * lv1208[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias5[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias5[v_ax2] + + @T.prim_func + def fused_NT_matmul_divide_maximum_minimum_cast2(lv1869: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float16"), p_lv1870: T.handle, p_lv1839: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv1870 = T.match_buffer(p_lv1870, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") + lv1839 = T.match_buffer(p_lv1839, (T.int64(1), T.int64(1), T.int64(1), n), "float16") + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n)) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), 
T.int64(32), T.int64(1), n), "float16") + var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") + var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") + var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(80)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(lv1869[v_i0, v_i1, v_i2, v_k], lv1870[v_i0, v_i1, v_i3, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1869[v_i0, v_i1, v_i2, v_k] * lv1870[v_i0, v_i1, v_i3, v_k] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_divide"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.11179039301310044) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_maximum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_minimum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1839[v_ax0, T.int64(0), v_ax2, v_ax3]) + T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1839[v_ax0, T.int64(0), v_ax2, v_ax3]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) + var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) + + @T.prim_func + def fused_softmax1_cast8(p_lv43: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n, m = T.int64(), T.int64() + lv43 = T.match_buffer(p_lv43, (T.int64(1), T.int64(32), n, m)) + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16") + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) + T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) + var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + 
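# Numerically stable softmax over the last axis: reduce the row maximum first, exponentiate the shifted scores, sum, and normalize; the final block casts the fp32 result to fp16. +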
T.reads(lv43[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv43[v_i0, v_i1, v_i2, v_k]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_exp"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(lv43[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) + T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) + T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv43[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_expsum"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) + T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_norm"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) + T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"axis": 3}) + var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) + var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + + @T.prim_func + def fused_softmax_cast3(p_lv1877: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv1877 = T.match_buffer(p_lv1877, (T.int64(1), T.int64(32), T.int64(1), n)) + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n), "float16") + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) + T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) + var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1877[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv1877[v_i0, v_i1, v_i2, v_k]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_exp"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(lv1877[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) + T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) + T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv1877[v_i0, v_i1, v_i2, v_i3] - 
T_softmax_maxelem[v_i0, v_i1, v_i2]) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_expsum"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) + T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_norm"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) + T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"axis": 3}) + var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) + var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + + @T.prim_func + def matmul1(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n), "float16") + B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") + # with T.block("root"): + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(80), n): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) + T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0) + matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] + + @T.prim_func + def matmul8(var_A: T.handle, var_B: T.handle, var_matmul: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n, m = T.int64(), T.int64() + A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m), "float16") + B = T.match_buffer(var_B, (T.int64(1), T.int64(32), m, T.int64(80)), "float16") + matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") + # with T.block("root"): + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(80), m): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) + T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0) + matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] + + @T.prim_func + def layer_norm1(var_A: T.handle, B: T.Buffer((T.int64(2560),), "float32"), C: T.Buffer((T.int64(2560),), "float32"), var_T_layer_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(2560))) + T_layer_norm = T.match_buffer(var_T_layer_norm, (T.int64(1), n, T.int64(2560))) + # with 
T.block("root"): + A_red_temp_v0 = T.alloc_buffer((T.int64(1), n)) + A_red_temp_v1 = T.alloc_buffer((T.int64(1), n)) + for ax0, ax1, k2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("A_red_temp"): + v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) + T.reads(A[v_ax0, v_ax1, v_k2]) + T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) + with T.init(): + A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) + A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) + v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] + v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] * A[v_ax0, v_ax1, v_k2] + A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 + A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_layer_norm"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(A[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1], B[v_ax2], C[v_ax2]) + T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2]) + T_layer_norm[v_ax0, v_ax1, v_ax2] = (A[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * B[v_ax2] + C[v_ax2] + + @T.prim_func + def fused_layer_norm1_cast6(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: T.Buffer((T.int64(2560),), "float32"), p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(2560))) + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") + # with T.block("root"): + A_red_temp_v0 = T.alloc_buffer((T.int64(1), n)) + A_red_temp_v1 = T.alloc_buffer((T.int64(1), n)) + var_T_layer_norm_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + for ax0, ax1, k2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("A_red_temp"): + v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) + T.reads(lv6[v_ax0, v_ax1, v_k2]) + T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) + with T.init(): + A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) + A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) + v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] + v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, v_ax1, v_k2] + A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 + A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_layer_norm"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv6[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1], weight1[v_ax2], bias[v_ax2]) + T.writes(var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2] = (lv6[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * weight1[v_ax2] + bias[v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): + 
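# Final cast of the fp32 layer-norm output to fp16 for the fp16 matmuls downstream; the 0.00039062500000000002 factor used above is 1/2560, the inverse of the hidden size. +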
with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_layer_norm_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_layer_norm_intermediate[v_i0, v_i1, v_i2]) + +# fmt: on diff --git a/mlc_llm/dispatch/gpt_neox/redpajama_incite_chat_3b_v1.py b/mlc_llm/dispatch/gpt_neox/redpajama_incite_chat_3b_v1.py new file mode 100644 index 0000000..7c9d1c5 --- /dev/null +++ b/mlc_llm/dispatch/gpt_neox/redpajama_incite_chat_3b_v1.py @@ -0,0 +1,972 @@ +# pylint: disable=missing-docstring,line-too-long,invalid-name,too-many-statements,too-many-locals +import tvm +from tvm import tir +from tvm.script import tir as T + +from .redpajama_incite_chat_3b_v1_mod import Module as MOD + +# fmt: off + +def fused_NT_matmul1_add4(sch: tir.Schedule): + b0 = sch.get_block(name="NT_matmul", func_name="main") + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + + b1 = sch.get_block(name="T_add", func_name="main") + b2 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l3, l4, l5, l6 = sch.get_loops(block=b0) + v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l12, l13, l14, l15, l16 = sch.split(loop=l3, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) + v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 2, 32, 1, 2]) + l22, l23, l24, l25, l26 = sch.split(loop=l4, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) + v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[80, 1, 4, 4, 2]) + l32, l33, l34, l35, l36 = sch.split(loop=l5, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) + v37, v38, v39 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[128, 5, 4]) + l40, l41, l42 = sch.split(loop=l6, factors=[v37, v38, v39], preserve_unit_iters=True) + sch.reorder(l12, l22, l32, l13, l23, l33, l14, l24, l34, l40, l41, l15, l25, l35, l42, l16, l26, l36) + l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) + sch.bind(loop=l43, thread_axis="blockIdx.x") + l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) + sch.bind(loop=l44, thread_axis="vthread.x") + l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) + sch.bind(loop=l45, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b46 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b46, loop=l45, preserve_unit_loops=True, index=-1) + b47 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b47, loop=l40, preserve_unit_loops=True, index=-1) + l52, l53, l54 = sch.get_loops(block=b47)[-3:] + sch.fuse(l52, l53, l54, preserve_unit_iters=True) + v56 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch", ann_val=v56) + b57 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + 
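# Stage the second matmul operand (the weight matrix here) in shared memory as well; the meta_schedule.cooperative_fetch annotation below is resolved after enter_postproc() into a threadIdx.x binding (and, for some sampled factors, a vectorized inner loop). +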
sch.compute_at(block=b57, loop=l40, preserve_unit_loops=True, index=-1) + l62, l63 = sch.get_loops(block=b57)[-2:] + sch.fuse(l62, l63, preserve_unit_iters=True) + v65 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) + sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v65) + sch.reverse_compute_inline(block=b1) + v66 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=1) + sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v66) + sch.enter_postproc() + sch.unannotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch") + l71 = sch.get_loops(block=b47)[-1] + _, l73, l74 = sch.split(loop=l71, factors=[None, 64, 4], preserve_unit_iters=True) + sch.vectorize(loop=l74) + sch.bind(loop=l73, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") + l79 = sch.get_loops(block=b57)[-1] + _, l81 = sch.split(loop=l79, factors=[None, 64], preserve_unit_iters=True) + sch.bind(loop=l81, thread_axis="threadIdx.x") + b82 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b82, ann_key="meta_schedule.unroll_explicit") + + b1 = sch.get_block("lv9_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + + _, _, b85, _ = sch.get_child_blocks(b82) + l100 = sch.get_loops(block=b85)[0] + sch.annotate(block_or_loop=l100, ann_key="pragma_auto_unroll_max_step", ann_val=16) + sch.annotate(block_or_loop=l100, ann_key="pragma_unroll_explicit", ann_val=1) + b118 = sch.get_block(name="NT_matmul", func_name="main") + l122 = sch.get_loops(block=b118)[4] + sch.decompose_reduction(block=b118, loop=l122) + + +def fused_NT_matmul1_add4_add5(sch: tir.Schedule): + b0 = sch.get_block(name="NT_matmul", func_name="main") + b1 = sch.get_block(name="T_add", func_name="main") + b2 = sch.get_block(name="T_add_1", func_name="main") + b3 = sch.get_block(name="root", func_name="main") + + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l4, l5, l6, l7 = sch.get_loops(block=b0) + v8, v9, v10, v11, v12 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l13, l14, l15, l16, l17 = sch.split(loop=l4, factors=[v8, v9, v10, v11, v12], preserve_unit_iters=True) + v18, v19, v20, v21, v22 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[2, 8, 4, 2, 1]) + l23, l24, l25, l26, l27 = sch.split(loop=l5, factors=[v18, v19, v20, v21, v22], preserve_unit_iters=True) + v28, v29, v30, v31, v32 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[40, 2, 16, 1, 2]) + l33, l34, l35, l36, l37 = sch.split(loop=l6, factors=[v28, v29, v30, v31, v32], preserve_unit_iters=True) + v38, v39, v40 = sch.sample_perfect_tile(loop=l7, n=3, max_innermost_factor=64, decision=[160, 4, 4]) + l41, l42, l43 = sch.split(loop=l7, factors=[v38, v39, v40], preserve_unit_iters=True) + sch.reorder(l13, l23, l33, l14, l24, l34, l15, l25, l35, l41, l42, l16, l26, l36, l43, l17, l27, l37) + + l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) + sch.bind(loop=l44, thread_axis="blockIdx.x") + l45 = sch.fuse(l14, l24, l34, 
preserve_unit_iters=True) + sch.bind(loop=l45, thread_axis="vthread.x") + l46 = sch.fuse(l15, l25, l35, preserve_unit_iters=True) + sch.bind(loop=l46, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b47 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b47, loop=l46, preserve_unit_loops=True, index=-1) + b48 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b48, loop=l41, preserve_unit_loops=True, index=-1) + l53, l54, l55 = sch.get_loops(block=b48)[-3:] + sch.fuse(l53, l54, l55, preserve_unit_iters=True) + v57 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b48, ann_key="meta_schedule.cooperative_fetch", ann_val=v57) + b58 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b58, loop=l41, preserve_unit_loops=True, index=-1) + l63, l64 = sch.get_loops(block=b58)[-2:] + sch.fuse(l63, l64, preserve_unit_iters=True) + v66 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b58, ann_key="meta_schedule.cooperative_fetch", ann_val=v66) + sch.reverse_compute_inline(block=b2) + sch.reverse_compute_inline(block=b1) + v67 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=3) + sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v67) + sch.enter_postproc() + sch.unannotate(block_or_loop=b48, ann_key="meta_schedule.cooperative_fetch") + l72 = sch.get_loops(block=b48)[-1] + _, l74, l75 = sch.split(loop=l72, factors=[None, 64, 4], preserve_unit_iters=True) + sch.vectorize(loop=l75) + sch.bind(loop=l74, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b58, ann_key="meta_schedule.cooperative_fetch") + l80 = sch.get_loops(block=b58)[-1] + _, l82, l83 = sch.split(loop=l80, factors=[None, 64, 4], preserve_unit_iters=True) + sch.vectorize(loop=l83) + sch.bind(loop=l82, thread_axis="threadIdx.x") + b84 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b84, ann_key="meta_schedule.unroll_explicit") + + b1 = sch.get_block("lv49_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + + _, _, b87, _ = sch.get_child_blocks(b84) + l103 = sch.get_loops(block=b87)[0] + sch.annotate(block_or_loop=l103, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l103, ann_key="pragma_unroll_explicit", ann_val=1) + + b121 = sch.get_block(name="NT_matmul", func_name="main") + l125 = sch.get_loops(block=b121)[4] + sch.decompose_reduction(block=b121, loop=l125) + + +def fused_NT_matmul2_divide1_maximum1_minimum1_cast9(sch: tir.Schedule): + b0 = sch.get_block(name="NT_matmul", func_name="main") + sch.pad_einsum(b0, [1, 1, 32, 32, 1]) + l1, l2, l3, l4, l5 = sch.get_loops(b0) + l6, l7 = sch.split(l3, [None, 32]) + l8, l9 = sch.split(l4, [None, 32]) + sch.reorder(l6, l8, l1, l2, l7, l9, l5) + + b1 = sch.get_block(name="T_divide", func_name="main") + b2 = sch.get_block(name="T_maximum", func_name="main") + b3 = sch.get_block(name="T_minimum", func_name="main") + b4 = sch.get_block(name="compute", func_name="main") + b5 = 
sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, _, l6, l7, l8, l9, l10 = sch.get_loops(block=b0) + v11, v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l16, l17, l18, l19, l20 = sch.split(loop=l6, factors=[v11, v12, v13, v14, v15], preserve_unit_iters=True) + v21, v22, v23, v24, v25 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[16, 1, 1, 1, 2]) + l26, l27, l28, l29, l30 = sch.split(loop=l7, factors=[v21, v22, v23, v24, v25], preserve_unit_iters=True) + v31, v32, v33, v34, v35 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[1, 1, 16, 2, 4]) + l36, l37, l38, l39, l40 = sch.split(loop=l8, factors=[v31, v32, v33, v34, v35], preserve_unit_iters=True) + v41, v42, v43, v44, v45 = sch.sample_perfect_tile(loop=l9, n=5, max_innermost_factor=64, decision=[8, 1, 16, 1, 1]) + l46, l47, l48, l49, l50 = sch.split(loop=l9, factors=[v41, v42, v43, v44, v45], preserve_unit_iters=True) + v51, v52, v53 = sch.sample_perfect_tile(loop=l10, n=3, max_innermost_factor=64, decision=[4, 20, 1]) + l54, l55, l56 = sch.split(loop=l10, factors=[v51, v52, v53], preserve_unit_iters=True) + sch.reorder(l16, l26, l36, l46, l17, l27, l37, l47, l18, l28, l38, l48, l54, l55, l19, l29, l39, l49, l56, l20, l30, l40, l50) + l57 = sch.fuse(l16, l26, l36, l46, preserve_unit_iters=True) + sch.bind(loop=l57, thread_axis="blockIdx.x") + l58 = sch.fuse(l17, l27, l37, l47, preserve_unit_iters=True) + sch.bind(loop=l58, thread_axis="vthread.x") + l59 = sch.fuse(l18, l28, l38, l48, preserve_unit_iters=True) + sch.bind(loop=l59, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b60 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b60, loop=l59, preserve_unit_loops=True, index=-1) + b61 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b61, loop=l54, preserve_unit_loops=True, index=-1) + l66, l67, l68, l69 = sch.get_loops(block=b61)[-4: ] + sch.fuse(l66, l67, l68, l69, preserve_unit_iters=True) + v71 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch", ann_val=v71) + b72 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b72, loop=l54, preserve_unit_loops=True, index=-1) + l77, l78, l79, l80 = sch.get_loops(block=b72)[-4:] + sch.fuse(l77, l78, l79, l80, preserve_unit_iters=True) + v82 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b72, ann_key="meta_schedule.cooperative_fetch", ann_val=v82) + sch.reverse_compute_inline(block=b4) + sch.reverse_compute_inline(block=b3) + sch.reverse_compute_inline(block=b2) + sch.reverse_compute_inline(block=b1) + v83 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=0) + sch.annotate(block_or_loop=b5, ann_key="meta_schedule.unroll_explicit", ann_val=v83) + sch.enter_postproc() + sch.unannotate(block_or_loop=b61, 
ann_key="meta_schedule.cooperative_fetch") + l88 = sch.get_loops(block=b61)[-1] + _, l90, l91 = sch.split(loop=l88, factors=[None, 64, 4], preserve_unit_iters=True) + sch.vectorize(loop=l91) + sch.bind(loop=l90, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b72, ann_key="meta_schedule.cooperative_fetch") + l96 = sch.get_loops(block=b72)[-1] + _, l98, l99 = sch.split(loop=l96, factors=[None, 64, 4], preserve_unit_iters=True) + sch.vectorize(loop=l99) + sch.bind(loop=l98, thread_axis="threadIdx.x") + b100 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b100, ann_key="meta_schedule.unroll_explicit") + + b1 = sch.get_block("lv36_pad") + sch.compute_inline(b1) + b1 = sch.get_block("lv37_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + + b140 = sch.get_block(name="NT_matmul", func_name="main") + l144 = sch.get_loops(block=b140)[5] + sch.decompose_reduction(block=b140, loop=l144) + + +def fused_NT_matmul3_add6_gelu1_cast11(sch: tir.Schedule): + b0 = sch.get_block(name="NT_matmul", func_name="main") + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + + b1 = sch.get_block(name="T_add", func_name="main") + b2 = sch.get_block(name="T_multiply", func_name="main") + b3 = sch.get_block(name="compute", func_name="main") + b4 = sch.get_block(name="T_multiply_1", func_name="main") + b5 = sch.get_block(name="T_add_1", func_name="main") + b6 = sch.get_block(name="T_multiply_2", func_name="main") + b7 = sch.get_block(name="compute_1", func_name="main") + b8 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l9, l10, l11, l12 = sch.get_loops(block=b0) + v13, v14, v15, v16, v17 = sch.sample_perfect_tile(loop=l9, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l18, l19, l20, l21, l22 = sch.split(loop=l9, factors=[v13, v14, v15, v16, v17], preserve_unit_iters=True) + v23, v24, v25, v26, v27 = sch.sample_perfect_tile(loop=l10, n=5, max_innermost_factor=64, decision=[1, 1, 32, 4, 1]) + l28, l29, l30, l31, l32 = sch.split(loop=l10, factors=[v23, v24, v25, v26, v27], preserve_unit_iters=True) + v33, v34, v35, v36, v37 = sch.sample_perfect_tile(loop=l11, n=5, max_innermost_factor=64, decision=[320, 1, 4, 8, 1]) + l38, l39, l40, l41, l42 = sch.split(loop=l11, factors=[v33, v34, v35, v36, v37], preserve_unit_iters=True) + v43, v44, v45 = sch.sample_perfect_tile(loop=l12, n=3, max_innermost_factor=64, decision=[80, 32, 1]) + l46, l47, l48 = sch.split(loop=l12, factors=[v43, v44, v45], preserve_unit_iters=True) + sch.reorder(l18, l28, l38, l19, l29, l39, l20, l30, l40, l46, l47, l21, l31, l41, l48, l22, l32, l42) + l49 = sch.fuse(l18, l28, l38, preserve_unit_iters=True) + sch.bind(loop=l49, thread_axis="blockIdx.x") + l50 = sch.fuse(l19, l29, l39, preserve_unit_iters=True) + sch.bind(loop=l50, thread_axis="vthread.x") + l51 = sch.fuse(l20, l30, l40, preserve_unit_iters=True) + sch.bind(loop=l51, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) + b52 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b52, loop=l51, preserve_unit_loops=True, index=-1) + b53 = sch.cache_read(block=b0, 
read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b53, loop=l46, preserve_unit_loops=True, index=-1) + l58, l59, l60 = sch.get_loops(block=b53)[-3:] + sch.fuse(l58, l59, l60, preserve_unit_iters=True) + v62 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b53, ann_key="meta_schedule.cooperative_fetch", ann_val=v62) + b63 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b63, loop=l46, preserve_unit_loops=True, index=-1) + l68, l69 = sch.get_loops(block=b63)[-2:] + sch.fuse(l68, l69, preserve_unit_iters=True) + v71 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b63, ann_key="meta_schedule.cooperative_fetch", ann_val=v71) + sch.reverse_compute_inline(block=b7) + sch.compute_inline(block=b5) + sch.compute_inline(block=b4) + sch.compute_inline(block=b3) + sch.compute_inline(block=b2) + sch.compute_inline(block=b1) + sch.reverse_compute_inline(block=b6) + v72 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) + sch.annotate(block_or_loop=b8, ann_key="meta_schedule.unroll_explicit", ann_val=v72) + sch.enter_postproc() + sch.unannotate(block_or_loop=b53, ann_key="meta_schedule.cooperative_fetch") + l77 = sch.get_loops(block=b53)[-1] + _, l79, l80 = sch.split(loop=l77, factors=[None, 32, 4], preserve_unit_iters=True) + sch.vectorize(loop=l80) + sch.bind(loop=l79, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b63, ann_key="meta_schedule.cooperative_fetch") + l85 = sch.get_loops(block=b63)[-1] + _, l87, l88 = sch.split(loop=l85, factors=[None, 32, 4], preserve_unit_iters=True) + sch.vectorize(loop=l88) + sch.bind(loop=l87, thread_axis="threadIdx.x") + b89 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b89, ann_key="meta_schedule.unroll_explicit") + + b1 = sch.get_block("lv57_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + + _, _, b92, _ = sch.get_child_blocks(b89) + l108 = sch.get_loops(block=b92)[0] + sch.annotate(block_or_loop=l108, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l108, ann_key="pragma_unroll_explicit", ann_val=1) + + b126 = sch.get_block(name="NT_matmul", func_name="main") + l130 = sch.get_loops(block=b126)[4] + sch.decompose_reduction(block=b126, loop=l130) + + +def fused_NT_matmul4_add7_cast8_cast12_add5(sch: tir.Schedule): + b0 = sch.get_block(name="NT_matmul", func_name="main") + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + + b1 = sch.get_block(name="T_add", func_name="main") + b2 = sch.get_block(name="compute", func_name="main") + b3 = sch.get_block(name="compute_1", func_name="main") + b4 = sch.get_block(name="T_add_1", func_name="main") + b5 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l6, l7, l8, l9 = sch.get_loops(block=b0) + v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l15, l16, l17, l18, l19 = sch.split(loop=l6, factors=[v10, v11, v12, v13, v14], preserve_unit_iters=True) + 
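# Note: every NT_matmul schedule in this file applies the same "SSSRRSRS"
# tiling recipe -- each spatial loop is split into five tiles that are later
# fused and bound to blockIdx.x / vthread.x / threadIdx.x (plus two inner
# serial tiles), the reduction loop is split into three, both operands are
# staged through shared memory with cooperative fetching, and the accumulator
# is kept in a per-thread local cache that is written back after the reduction.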
v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[2, 4, 16, 1, 1]) + l25, l26, l27, l28, l29 = sch.split(loop=l7, factors=[v20, v21, v22, v23, v24], preserve_unit_iters=True) + v30, v31, v32, v33, v34 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[40, 1, 8, 1, 8]) + l35, l36, l37, l38, l39 = sch.split(loop=l8, factors=[v30, v31, v32, v33, v34], preserve_unit_iters=True) + v40, v41, v42 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64, decision=[320, 32, 1]) + l43, l44, l45 = sch.split(loop=l9, factors=[v40, v41, v42], preserve_unit_iters=True) + sch.reorder(l15, l25, l35, l16, l26, l36, l17, l27, l37, l43, l44, l18, l28, l38, l45, l19, l29, l39) + l46 = sch.fuse(l15, l25, l35, preserve_unit_iters=True) + sch.bind(loop=l46, thread_axis="blockIdx.x") + l47 = sch.fuse(l16, l26, l36, preserve_unit_iters=True) + sch.bind(loop=l47, thread_axis="vthread.x") + l48 = sch.fuse(l17, l27, l37, preserve_unit_iters=True) + sch.bind(loop=l48, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b49 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b49, loop=l48, preserve_unit_loops=True, index=-1) + b50 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b50, loop=l43, preserve_unit_loops=True, index=-1) + l55, l56, l57 = sch.get_loops(block=b50)[-3:] + sch.fuse(l55, l56, l57, preserve_unit_iters=True) + v59 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b50, ann_key="meta_schedule.cooperative_fetch", ann_val=v59) + b60 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b60, loop=l43, preserve_unit_loops=True, index=-1) + l65, l66 = sch.get_loops(block=b60)[-2:] + sch.fuse(l65, l66, preserve_unit_iters=True) + v68 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch", ann_val=v68) + sch.reverse_compute_inline(block=b4) + sch.reverse_compute_inline(block=b3) + sch.reverse_compute_inline(block=b2) + sch.reverse_compute_inline(block=b1) + v69 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=3) + sch.annotate(block_or_loop=b5, ann_key="meta_schedule.unroll_explicit", ann_val=v69) + sch.enter_postproc() + sch.unannotate(block_or_loop=b50, ann_key="meta_schedule.cooperative_fetch") + l74 = sch.get_loops(block=b50)[-1] + _, l76, l77 = sch.split(loop=l74, factors=[None, 128, 4], preserve_unit_iters=True) + sch.vectorize(loop=l77) + sch.bind(loop=l76, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch") + l82 = sch.get_loops(block=b60)[-1] + _, l84, l85 = sch.split(loop=l82, factors=[None, 128, 2], preserve_unit_iters=True) + sch.vectorize(loop=l85) + sch.bind(loop=l84, thread_axis="threadIdx.x") + b86 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b86, ann_key="meta_schedule.unroll_explicit") + + b1 = sch.get_block("lv63_pad") + sch.compute_inline(b1) + b2 = 
sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + + _, _, b89, _ = sch.get_child_blocks(b86) + l105 = sch.get_loops(block=b89)[0] + sch.annotate(block_or_loop=l105, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l105, ann_key="pragma_unroll_explicit", ann_val=1) + b123 = sch.get_block(name="NT_matmul", func_name="main") + l127 = sch.get_loops(block=b123)[4] + sch.decompose_reduction(block=b123, loop=l127) + + +def fused_NT_matmul4_add7_cast8_cast12_add5_cast7(sch: tir.Schedule): + b0 = sch.get_block(name="NT_matmul", func_name="main") + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + + b1 = sch.get_block(name="T_add", func_name="main") + b2 = sch.get_block(name="compute", func_name="main") + b3 = sch.get_block(name="compute_1", func_name="main") + b4 = sch.get_block(name="T_add_1", func_name="main") + b5 = sch.get_block(name="compute_2", func_name="main") + b6 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l7, l8, l9, l10 = sch.get_loops(block=b0) + v11, v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l16, l17, l18, l19, l20 = sch.split(loop=l7, factors=[v11, v12, v13, v14, v15], preserve_unit_iters=True) + v21, v22, v23, v24, v25 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[2, 1, 16, 2, 2]) + l26, l27, l28, l29, l30 = sch.split(loop=l8, factors=[v21, v22, v23, v24, v25], preserve_unit_iters=True) + v31, v32, v33, v34, v35 = sch.sample_perfect_tile(loop=l9, n=5, max_innermost_factor=64, decision=[64, 2, 10, 1, 2]) + l36, l37, l38, l39, l40 = sch.split(loop=l9, factors=[v31, v32, v33, v34, v35], preserve_unit_iters=True) + v41, v42, v43 = sch.sample_perfect_tile(loop=l10, n=3, max_innermost_factor=64, decision=[256, 20, 2]) + l44, l45, l46 = sch.split(loop=l10, factors=[v41, v42, v43], preserve_unit_iters=True) + sch.reorder(l16, l26, l36, l17, l27, l37, l18, l28, l38, l44, l45, l19, l29, l39, l46, l20, l30, l40) + l47 = sch.fuse(l16, l26, l36, preserve_unit_iters=True) + sch.bind(loop=l47, thread_axis="blockIdx.x") + l48 = sch.fuse(l17, l27, l37, preserve_unit_iters=True) + sch.bind(loop=l48, thread_axis="vthread.x") + l49 = sch.fuse(l18, l28, l38, preserve_unit_iters=True) + sch.bind(loop=l49, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b50 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b50, loop=l49, preserve_unit_loops=True, index=-1) + b51 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b51, loop=l44, preserve_unit_loops=True, index=-1) + l56, l57, l58 = sch.get_loops(block=b51)[-3:] + sch.fuse(l56, l57, l58, preserve_unit_iters=True) + v60 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b51, ann_key="meta_schedule.cooperative_fetch", ann_val=v60) + b61 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b61, loop=l44, preserve_unit_loops=True, index=-1) + l66, l67 = 
sch.get_loops(block=b61)[-2:] + sch.fuse(l66, l67, preserve_unit_iters=True) + v69 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch", ann_val=v69) + sch.reverse_compute_inline(block=b5) + sch.reverse_compute_inline(block=b4) + sch.reverse_compute_inline(block=b3) + sch.reverse_compute_inline(block=b2) + sch.reverse_compute_inline(block=b1) + v70 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) + sch.annotate(block_or_loop=b6, ann_key="meta_schedule.unroll_explicit", ann_val=v70) + sch.enter_postproc() + sch.unannotate(block_or_loop=b51, ann_key="meta_schedule.cooperative_fetch") + l75 = sch.get_loops(block=b51)[-1] + _, l77, l78 = sch.split(loop=l75, factors=[None, 80, 4], preserve_unit_iters=True) + sch.vectorize(loop=l78) + sch.bind(loop=l77, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch") + l83 = sch.get_loops(block=b61)[-1] + _, l85, l86 = sch.split(loop=l83, factors=[None, 80, 2], preserve_unit_iters=True) + sch.vectorize(loop=l86) + sch.bind(loop=l85, thread_axis="threadIdx.x") + b87 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b87, ann_key="meta_schedule.unroll_explicit") + + b1 = sch.get_block("lv2047_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + + _, _, b90, _ = sch.get_child_blocks(b87) + l106 = sch.get_loops(block=b90)[0] + sch.annotate(block_or_loop=l106, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l106, ann_key="pragma_unroll_explicit", ann_val=1) + b124 = sch.get_block(name="NT_matmul", func_name="main") + l128 = sch.get_loops(block=b124)[4] + sch.decompose_reduction(block=b124, loop=l128) + + +def fused_NT_matmul_divide_maximum_minimum_cast2(sch: tir.Schedule): + b0 = sch.get_block(name="NT_matmul", func_name="main") + sch.pad_einsum(b0, [1, 1, 1, 32, 1]) + l1, l2, l3, l4, l5 = sch.get_loops(b0) + l6, l7 = sch.split(l4, [None, 32]) + sch.reorder(l6, l1, l2, l3, l7, l5) + + b1 = sch.get_block(name="T_divide", func_name="main") + b2 = sch.get_block(name="T_maximum", func_name="main") + b3 = sch.get_block(name="T_minimum", func_name="main") + b4 = sch.get_block(name="compute", func_name="main") + b5 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l6, l7, l8, l9, l10 = sch.get_loops(block=b0) + v11, v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l16, l17, l18, l19, l20 = sch.split(loop=l6, factors=[v11, v12, v13, v14, v15], preserve_unit_iters=True) + v21, v22, v23, v24, v25 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[4, 1, 8, 1, 1]) + l26, l27, l28, l29, l30 = sch.split(loop=l7, factors=[v21, v22, v23, v24, v25], preserve_unit_iters=True) + v31, v32, v33, v34, v35 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l36, l37, l38, l39, l40 = sch.split(loop=l8, factors=[v31, v32, v33, v34, v35], preserve_unit_iters=True) + v41, v42, v43, v44, v45 = sch.sample_perfect_tile(loop=l9, n=5, max_innermost_factor=64, decision=[4, 1, 16, 2, 1]) + l46, l47, l48, l49, l50 = sch.split(loop=l9, factors=[v41, 
v42, v43, v44, v45], preserve_unit_iters=True) + v51, v52, v53 = sch.sample_perfect_tile(loop=l10, n=3, max_innermost_factor=64, decision=[5, 8, 2]) + l54, l55, l56 = sch.split(loop=l10, factors=[v51, v52, v53], preserve_unit_iters=True) + sch.reorder(l16, l26, l36, l46, l17, l27, l37, l47, l18, l28, l38, l48, l54, l55, l19, l29, l39, l49, l56, l20, l30, l40, l50) + l57 = sch.fuse(l16, l26, l36, l46, preserve_unit_iters=True) + sch.bind(loop=l57, thread_axis="blockIdx.x") + l58 = sch.fuse(l17, l27, l37, l47, preserve_unit_iters=True) + sch.bind(loop=l58, thread_axis="vthread.x") + l59 = sch.fuse(l18, l28, l38, l48, preserve_unit_iters=True) + sch.bind(loop=l59, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b60 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b60, loop=l59, preserve_unit_loops=True, index=-1) + b61 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b61, loop=l54, preserve_unit_loops=True, index=-1) + l66, l67, l68, l69 = sch.get_loops(block=b61)[-4:] + sch.fuse(l66, l67, l68, l69, preserve_unit_iters=True) + v71 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch", ann_val=v71) + b72 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b72, loop=l54, preserve_unit_loops=True, index=-1) + l77, l78, l79, l80 = sch.get_loops(block=b72)[-4:] + sch.fuse(l77, l78, l79, l80, preserve_unit_iters=True) + v82 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b72, ann_key="meta_schedule.cooperative_fetch", ann_val=v82) + sch.reverse_compute_inline(block=b4) + sch.reverse_compute_inline(block=b3) + sch.reverse_compute_inline(block=b2) + sch.reverse_compute_inline(block=b1) + v83 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) + sch.annotate(block_or_loop=b5, ann_key="meta_schedule.unroll_explicit", ann_val=v83) + sch.enter_postproc() + sch.unannotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch") + l88 = sch.get_loops(block=b61)[-1] + _, l90, l91 = sch.split(loop=l88, factors=[None, 128, 2], preserve_unit_iters=True) + sch.vectorize(loop=l91) + sch.bind(loop=l90, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b72, ann_key="meta_schedule.cooperative_fetch") + l96 = sch.get_loops(block=b72)[-1] + _, l98, l99 = sch.split(loop=l96, factors=[None, 128, 2], preserve_unit_iters=True) + sch.vectorize(loop=l99) + sch.bind(loop=l98, thread_axis="threadIdx.x") + b100 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b100, ann_key="meta_schedule.unroll_explicit") + + b1 = sch.get_block("lv2095_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + + _, _, b103, _ = sch.get_child_blocks(b100) + l119 = sch.get_loops(block=b103)[0] + sch.annotate(block_or_loop=l119, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l119, ann_key="pragma_unroll_explicit", ann_val=1) + + b140 
= sch.get_block(name="NT_matmul", func_name="main") + l144 = sch.get_loops(block=b140)[4] + sch.decompose_reduction(block=b140, loop=l144) + + +def fused_layer_norm1_cast8(sch: tir.Schedule): + b0 = sch.get_block(name="A_red_temp", func_name="main") + b1 = sch.get_block(name="T_layer_norm", func_name="main") + b2 = sch.get_block(name="compute", func_name="main") + b3 = sch.get_block(name="root", func_name="main") + sch.reverse_compute_inline(block=b2) + v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125], decision=5) + l5, l6, l7 = sch.get_loops(block=b0) + l8, l9 = sch.split(loop=l7, factors=[None, v4], preserve_unit_iters=True) + sch.bind(loop=l9, thread_axis="threadIdx.x") + v10 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) + sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v10) + l11, l12, l13 = sch.get_loops(block=b1) + l14 = sch.fuse(l11, l12, l13, preserve_unit_iters=True) + l15, l16, l17 = sch.split(loop=l14, factors=[None, 256, 256], preserve_unit_iters=True) + sch.reorder(l16, l17, l15) + sch.bind(loop=l16, thread_axis="blockIdx.x") + sch.bind(loop=l17, thread_axis="threadIdx.x") + l18, l19, l20, l21 = sch.get_loops(block=b0) + l22 = sch.fuse(l18, l19, preserve_unit_iters=True) + sch.bind(loop=l22, thread_axis="blockIdx.x") + sch.enter_postproc() + b23 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b23, ann_key="meta_schedule.unroll_explicit") + b24, b25 = sch.get_child_blocks(b23) + l26, l27, l28 = sch.get_loops(block=b24) + sch.annotate(block_or_loop=l26, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l26, ann_key="pragma_unroll_explicit", ann_val=1) + l29, l30, l31 = sch.get_loops(block=b25) + + +def layer_norm1(sch: tir.Schedule): + b0 = sch.get_block(name="A_red_temp", func_name="main") + b1 = sch.get_block(name="T_layer_norm", func_name="main") + b2 = sch.get_block(name="root", func_name="main") + v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125], decision=1) + l4, l5, l6 = sch.get_loops(block=b0) + l7, l8 = sch.split(loop=l6, factors=[None, v3], preserve_unit_iters=True) + sch.bind(loop=l8, thread_axis="threadIdx.x") + v9 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=3) + sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v9) + l10, l11, l12 = sch.get_loops(block=b1) + l13 = sch.fuse(l10, l11, l12, preserve_unit_iters=True) + l14, l15, l16 = sch.split(loop=l13, factors=[None, 256, 256], preserve_unit_iters=True) + sch.reorder(l15, l16, l14) + sch.bind(loop=l15, thread_axis="blockIdx.x") + sch.bind(loop=l16, thread_axis="threadIdx.x") + l17, l18, l19, l20 = sch.get_loops(block=b0) + l21 = sch.fuse(l17, l18, preserve_unit_iters=True) + sch.bind(loop=l21, thread_axis="blockIdx.x") + sch.enter_postproc() + b22 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b22, ann_key="meta_schedule.unroll_explicit") + b23, b24 = sch.get_child_blocks(b22) + l25, l26, l27 = sch.get_loops(block=b23) + sch.annotate(block_or_loop=l25, ann_key="pragma_auto_unroll_max_step", ann_val=512) + 
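# Note: layer_norm1 mirrors fused_layer_norm1_cast8 above -- the A_red_temp
# reduction axis is split by a sampled factor and bound to threadIdx.x for a
# cross-thread reduction, the elementwise T_layer_norm block is fused and
# distributed over blockIdx.x / threadIdx.x, and this postproc step turns the
# sampled unroll decision into the pragma annotations attached here.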
sch.annotate(block_or_loop=l25, ann_key="pragma_unroll_explicit", ann_val=1) + l28, l29, l30 = sch.get_loops(block=b24) + + +def matmul3(sch: tir.Schedule): + b0 = sch.get_block(name="matmul", func_name="main") + sch.pad_einsum(b0, [1, 1, 1, 1, 32]) + l1, l2, l3, l4, k = sch.get_loops(b0) + k0, k1 = sch.split(k, [None, 32]) + sch.reorder(l1, l2, l3, k0, l4, k1) + + b1 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + l2, l3, l4, _, l5, l6 = sch.get_loops(block=b0) + v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l12, l13, l14, l15, l16 = sch.split(loop=l2, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) + v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 16, 1, 2]) + l22, l23, l24, l25, l26 = sch.split(loop=l3, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) + v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l32, l33, l34, l35, l36 = sch.split(loop=l4, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) + v37, v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[8, 1, 10, 1, 1]) + l42, l43, l44, l45, l46 = sch.split(loop=l5, factors=[v37, v38, v39, v40, v41], preserve_unit_iters=True) + v47, v48, v49 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[1, 32, 1]) + l50, l51, l52 = sch.split(loop=l6, factors=[v47, v48, v49], preserve_unit_iters=True) + sch.reorder(l12, l22, l32, l42, l13, l23, l33, l43, l14, l24, l34, l44, k0, l50, l51, l15, l25, l35, l45, l52, l16, l26, l36, l46) + l53 = sch.fuse(l12, l22, l32, l42, preserve_unit_iters=True) + sch.bind(loop=l53, thread_axis="blockIdx.x") + l54 = sch.fuse(l13, l23, l33, l43, preserve_unit_iters=True) + sch.bind(loop=l54, thread_axis="vthread.x") + l55 = sch.fuse(l14, l24, l34, l44, preserve_unit_iters=True) + sch.bind(loop=l55, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b56 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b56, loop=l55, preserve_unit_loops=True, index=-1) + b57 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b57, loop=l50, preserve_unit_loops=True, index=-1) + l62, l63, l64, l65 = sch.get_loops(block=b57)[-4:] + sch.fuse(l62, l63, l64, l65, preserve_unit_iters=True) + v67 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v67) + b68 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b68, loop=l50, preserve_unit_loops=True, index=-1) + l73, l74, l75, l76 = sch.get_loops(block=b68)[-4:] + sch.fuse(l73, l74, l75, l76, preserve_unit_iters=True) + v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch", ann_val=v78) + v79 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 
0.20000000000000001, 0.20000000000000001], decision=3) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v79) + sch.enter_postproc() + sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") + l84 = sch.get_loops(block=b57)[-1] + _, l86, l87 = sch.split(loop=l84, factors=[None, 160, 4], preserve_unit_iters=True) + sch.vectorize(loop=l87) + sch.bind(loop=l86, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch") + l92 = sch.get_loops(block=b68)[-1] + _, l94, l95 = sch.split(loop=l92, factors=[None, 160, 2], preserve_unit_iters=True) + sch.vectorize(loop=l95) + sch.bind(loop=l94, thread_axis="threadIdx.x") + b96 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b96, ann_key="meta_schedule.unroll_explicit") + + b1 = sch.get_block("A_pad") + sch.compute_inline(b1) + b1 = sch.get_block("B_pad") + sch.compute_inline(b1) + + _, _, b99, _ = sch.get_child_blocks(b96) + l115 = sch.get_loops(block=b99)[0] + sch.annotate(block_or_loop=l115, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l115, ann_key="pragma_unroll_explicit", ann_val=1) + b136 = sch.get_block(name="matmul", func_name="main") + l140 = sch.get_loops(block=b136)[3] + sch.decompose_reduction(block=b136, loop=l140) + + +def matmul9(sch: tir.Schedule): + b0 = sch.get_block(name="matmul", func_name="main") + sch.pad_einsum(b0, [1, 1, 32, 1, 32]) + l1, l2, l3, l4, k = sch.get_loops(b0) + s0, s1 = sch.split(l3, [None, 32]) + k0, k1 = sch.split(k, [None, 32]) + sch.reorder(s0, l1, l2, s1, k0, l4, k1) + + b1 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l2, l3, l4, _, l5, l6 = sch.get_loops(block=b0) + v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l12, l13, l14, l15, l16 = sch.split(loop=l2, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) + v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[16, 1, 1, 1, 2]) + l22, l23, l24, l25, l26 = sch.split(loop=l3, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) + v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[8, 1, 8, 2, 1]) + l32, l33, l34, l35, l36 = sch.split(loop=l4, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) + v37, v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[2, 1, 5, 2, 4]) + l42, l43, l44, l45, l46 = sch.split(loop=l5, factors=[v37, v38, v39, v40, v41], preserve_unit_iters=True) + v47, v48, v49 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[16, 1, 2]) + l50, l51, l52 = sch.split(loop=l6, factors=[v47, v48, v49], preserve_unit_iters=True) + sch.reorder(l12, l22, l32, l42, l13, l23, l33, l43, l14, l24, l34, l44, k0, l50, l51, l15, l25, l35, l45, l52, l16, l26, l36, l46) + l53 = sch.fuse(l12, l22, l32, l42, preserve_unit_iters=True) + sch.bind(loop=l53, thread_axis="blockIdx.x") + l54 = sch.fuse(l13, l23, l33, l43, preserve_unit_iters=True) + sch.bind(loop=l54, thread_axis="vthread.x") + l55 = sch.fuse(l14, l24, l34, l44, preserve_unit_iters=True) + sch.bind(loop=l55, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, 
ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b56 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b56, loop=l55, preserve_unit_loops=True, index=-1) + b57 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b57, loop=l50, preserve_unit_loops=True, index=-1) + l62, l63, l64, l65 = sch.get_loops(block=b57)[-4:] + sch.fuse(l62, l63, l64, l65, preserve_unit_iters=True) + v67 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v67) + b68 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b68, loop=l50, preserve_unit_loops=True, index=-1) + l73, l74, l75, l76 = sch.get_loops(block=b68)[-4:] + sch.fuse(l73, l74, l75, l76, preserve_unit_iters=True) + v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch", ann_val=v78) + v79 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=0) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v79) + + b1 = sch.get_block("A_pad") + sch.compute_inline(b1) + b1 = sch.get_block("B_pad") + sch.compute_inline(b1) + b1 = sch.get_block("matmul_pad") + sch.reverse_compute_inline(b1) + + sch.enter_postproc() + sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") + l84 = sch.get_loops(block=b57)[-1] + _, l86, l87 = sch.split(loop=l84, factors=[None, 40, 4], preserve_unit_iters=True) + sch.vectorize(loop=l87) + sch.bind(loop=l86, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch") + l92 = sch.get_loops(block=b68)[-1] + _, l94, l95 = sch.split(loop=l92, factors=[None, 40, 2], preserve_unit_iters=True) + sch.vectorize(loop=l95) + sch.bind(loop=l94, thread_axis="threadIdx.x") + b96 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b96, ann_key="meta_schedule.unroll_explicit") + b136 = sch.get_block(name="matmul", func_name="main") + l140 = sch.get_loops(block=b136)[4] + sch.decompose_reduction(block=b136, loop=l140) + + +def softmax_1xn(sch: tir.Schedule): + has_cast = True + if has_cast: + b_cast = sch.get_block("compute") + sch.reverse_compute_inline(b_cast) + + b0 = sch.get_block("T_softmax_exp") + sch.compute_inline(b0) + b1 = sch.get_block("T_softmax_norm") + l2, l3, l4, l5 = sch.get_loops(b1) + _, l7 = sch.split(l5, [None, 128]) + sch.bind(l7, "threadIdx.x") + b8 = sch.get_block("T_softmax_expsum") + sch.compute_at(b8, l4) + sch.set_scope(b8, 0, "shared") + _, _, _, l12 = sch.get_loops(b8) + _, l14 = sch.split(l12, [None, 128]) + sch.bind(l14, "threadIdx.x") + b15 = sch.get_block("T_softmax_maxelem") + sch.compute_at(b15, l4) + sch.set_scope(b15, 0, "shared") + _, _, _, l19 = sch.get_loops(b15) + _, l21 = sch.split(l19, [None, 128]) + sch.bind(l21, "threadIdx.x") + l22 = sch.fuse(l2, l3, l4) + sch.bind(l22, "blockIdx.x") + + +def fused_min_max_triu_te_broadcast_to(sch: tir.Schedule): + b0 = sch.get_block("T_broadcast_to") + sch.reverse_compute_inline(b0) + b1 = sch.get_block("make_diag_mask_te") + i, j = sch.get_loops(b1) + i = sch.fuse(i, j) + i, j = 
sch.split(i, [None, 128]) + sch.bind(i, "blockIdx.x") + sch.bind(j, "threadIdx.x") + + +@T.prim_func +def softmax_mxn_before(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + m = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, m)) + T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, m)) + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) + T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], rxplaceholder[v_i0, v_i1, v_i2, v_k]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_exp"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) + T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) + T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(rxplaceholder[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_expsum"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) + T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_norm"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) + T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"axis": 3}) + T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] + + +@T.prim_func +def softmax_mxn_after(var_A: T.handle, var_T_softmax_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + m = T.int64() + A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m)) + T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, m)) + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) + for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): + with T.block("T_softmax_maxelem_o"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) + T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(m + T.int64(127)) // T.int64(128) * T.int64(128)]) + T.writes(T_softmax_maxelem[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) + T_softmax_maxelem_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), 
scope="shared") + for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): + for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_maxelem"): + v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) + v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) + T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i]) + T.writes(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) + with T.init(): + T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.max(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i], T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T.float32(-3.4028234663852886e+38))) + for i0_i1_i2_1_fused_0 in range(T.int64(8)): + for i0_i1_i2_1_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_maxelem_cache_write"): + v_i1_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) // T.int64(32)) + v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32)) + T.where(v_i2_o * T.int64(32) + (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32) < n) + T.reads(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) + T.writes(T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) + T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i] = T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] + for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): + with T.block("T_softmax_expsum_o"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) + T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(m + T.int64(127)) // T.int64(128) * T.int64(128)], T_softmax_maxelem[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) + T.writes(T_softmax_expsum[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) + T_softmax_expsum_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared") + for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): + for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_expsum"): + v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) + v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) + T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) + T.writes(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) + with T.init(): + T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float32(0) + T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] + T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, T.exp(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i] - T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]), T.float32(0)) + for i0_i1_i2_1_fused_0 in range(T.int64(8)): + for i0_i1_i2_1_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_expsum_cache_write"): + v_i1_i = T.axis.spatial(T.int64(32), 
(i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) // T.int64(32)) + v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32)) + T.where(v_i2_o * T.int64(32) + (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32) < n) + T.reads(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) + T.writes(T_softmax_expsum[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) + T_softmax_expsum[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] + for i0_i1_i2_fused_i3_fused_0 in T.thread_binding((n * T.int64(32) * m + T.int64(255)) // T.int64(256), thread="blockIdx.x"): + for i0_i1_i2_fused_i3_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("T_softmax_norm"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(32), (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) // m // n) + v_i2 = T.axis.spatial(n, (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) // m % n) + v_i3 = T.axis.spatial(m, (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) % m) + T.where(i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1 < n * T.int64(32) * m) + T.reads(T_softmax_expsum[v_i0, v_i1, v_i2], A[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) + T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) + T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T.exp(A[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) / T_softmax_expsum[v_i0, v_i1, v_i2] + + + +def _get_dict(): + # tvm.ir.assert_structural_equal(MOD["softmax"], softmax_mxn_before) + func_dict = { + # softmax_mxn_before: softmax_mxn_after, + } + for name, func in [ + # fmt: off + ("fused_layer_norm1_cast8", fused_layer_norm1_cast8), + ("fused_NT_matmul1_add4_add5", fused_NT_matmul1_add4_add5), + ("fused_NT_matmul2_divide1_maximum1_minimum1_cast9", fused_NT_matmul2_divide1_maximum1_minimum1_cast9), + ("fused_NT_matmul4_add7_cast8_cast12_add5", fused_NT_matmul4_add7_cast8_cast12_add5), + ("fused_NT_matmul3_add6_gelu1_cast11", fused_NT_matmul3_add6_gelu1_cast11), + ("fused_NT_matmul_divide_maximum_minimum_cast2", fused_NT_matmul_divide_maximum_minimum_cast2), + ("matmul3", matmul3), + ("fused_NT_matmul1_add4", fused_NT_matmul1_add4), + ("matmul9", matmul9), + ("layer_norm1", layer_norm1), + ("fused_NT_matmul4_add7_cast8_cast12_add5_cast7", fused_NT_matmul4_add7_cast8_cast12_add5_cast7), + ("fused_min_max_triu_te_broadcast_to", fused_min_max_triu_te_broadcast_to), + ("fused_softmax_cast3", softmax_1xn), + # fmt: on + ]: + # print(f"############### {name} ###############") + sch = tir.Schedule(MOD[name]) + func(sch) + # sch.mod["main"].show(black_format=False) + func_dict[MOD[name]] = sch.mod["main"] + return { + (tvm.ir.structural_hash(k), k): v.with_attr("tir.is_scheduled", True) + for k, v in func_dict.items() + } + + +DICT = _get_dict() + + +def lookup(func): + for (hash_value, func_before), f_after in DICT.items(): + if tvm.ir.structural_hash(func) == hash_value and tvm.ir.structural_equal( + func, func_before + ): + return f_after + return None diff --git a/mlc_llm/dispatch/gpt_neox/redpajama_incite_chat_3b_v1_mod.py b/mlc_llm/dispatch/gpt_neox/redpajama_incite_chat_3b_v1_mod.py new file mode 100644 index 0000000..b71567b --- /dev/null +++ b/mlc_llm/dispatch/gpt_neox/redpajama_incite_chat_3b_v1_mod.py @@ -0,0 +1,722 @@ +# pylint: 
disable=pointless-string-statement,invalid-name,missing-docstring,line-too-long,too-many-locals,too-many-arguments,too-many-statements +from tvm.script import ir as I +from tvm.script import tir as T + +# fmt: off + +@I.ir_module +class Module: + @T.prim_func + def cast7(var_A: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(2560)), "float16") + compute = T.match_buffer(var_compute, (T.int64(1), n, T.int64(2560))) + # with T.block("root"): + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(A[v_i0, v_i1, v_i2]) + T.writes(compute[v_i0, v_i1, v_i2]) + compute[v_i0, v_i1, v_i2] = T.Cast("float32", A[v_i0, v_i1, v_i2]) + + @T.prim_func + def extend_te(var_A: T.handle, var_concat_te: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), T.int64(1), n, n), "float16") + m = T.int64() + concat_te = T.match_buffer(var_concat_te, (T.int64(1), T.int64(1), n, m), "float16") + # with T.block("root"): + for b, _, i, j in T.grid(T.int64(1), T.int64(1), n, m): + with T.block("concat_te"): + v_b, v__, v_i, v_j = T.axis.remap("SSSS", [b, _, i, j]) + T.reads(A[v_b, v__, v_i, v_j + n - m]) + T.writes(concat_te[v_b, v__, v_i, v_j]) + concat_te[v_b, v__, v_i, v_j] = T.if_then_else(v_j < m - n, T.float16(65504), A[v_b, v__, v_i, v_j + n - m]) + + @T.prim_func + def full(var_T_full: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + T_full = T.match_buffer(var_T_full, (T.int64(1), T.int64(1), T.int64(1), n), "float16") + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), n): + with T.block("T_full"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads() + T.writes(T_full[v_ax0, v_ax1, v_ax2, v_ax3]) + T_full[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(65504) + + @T.prim_func + def fused_NT_matmul1_add4(p_lv9: T.handle, lv1173: T.Buffer((T.int64(2560), T.int64(2560)), "float16"), linear_bias: T.Buffer((T.int64(2560),), "float16"), p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv9 = T.match_buffer(p_lv9, (T.int64(1), n, T.int64(2560)), "float16") + var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv9[v_i0, v_i1, v_k], lv1173[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv9[v_i0, v_i1, v_k] * lv1173[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias[v_ax2] + + @T.prim_func + def fused_NT_matmul1_add4_add5(p_lv49: T.handle, lv1194: T.Buffer((T.int64(2560), 
T.int64(2560)), "float16"), linear_bias3: T.Buffer((T.int64(2560),), "float16"), p_lv2: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv49 = T.match_buffer(p_lv49, (T.int64(1), n, T.int64(2560)), "float16") + lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(2560)), "float16") + var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv49[v_i0, v_i1, v_k], lv1194[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv49[v_i0, v_i1, v_k] * lv1194[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias3[v_ax2]) + T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias3[v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2], lv2[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] + lv2[v_ax0, v_ax1, v_ax2] + + @T.prim_func + def fused_NT_matmul2_divide1_maximum1_minimum1_cast9(p_lv36: T.handle, p_lv37: T.handle, p_lv5: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv36 = T.match_buffer(p_lv36, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") + m = T.int64() + lv37 = T.match_buffer(p_lv37, (T.int64(1), T.int64(32), m, T.int64(80)), "float16") + lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m), "float16") + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m)) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") + var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") + var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") + var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, m, T.int64(80)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(lv36[v_i0, v_i1, v_i2, v_k], lv37[v_i0, v_i1, v_i3, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv36[v_i0, v_i1, v_i2, v_k] * lv37[v_i0, v_i1, v_i3, v_k] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_divide"): + v_ax0, v_ax1, 
v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.11179039301310044) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_maximum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_minimum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) + T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) + var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) + + @T.prim_func + def fused_NT_matmul3_add6_gelu1_cast11(p_lv57: T.handle, lv1201: T.Buffer((T.int64(10240), T.int64(2560)), "float16"), linear_bias4: T.Buffer((T.int64(10240),), "float32"), p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv57 = T.match_buffer(p_lv57, (T.int64(1), n, T.int64(2560)), "float16") + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(10240)), "float16") + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + T_multiply = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + compute = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + T_multiply_1 = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + T_add = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + var_T_multiply_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(10240), T.int64(2560)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv57[v_i0, v_i1, v_k], lv1201[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv57[v_i0, v_i1, v_k]) * T.Cast("float32", lv1201[v_i2, v_k]) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias4[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + 
linear_bias4[v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) + T_multiply[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T.float32(0.70710678118654757) + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(T_multiply[v_i0, v_i1, v_i2]) + T.writes(compute[v_i0, v_i1, v_i2]) + compute[v_i0, v_i1, v_i2] = T.erf(T_multiply[v_i0, v_i1, v_i2]) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_multiply_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(compute[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2]) + T_multiply_1[v_ax0, v_ax1, v_ax2] = compute[v_ax0, v_ax1, v_ax2] * T.float32(0.5) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2]) + T.writes(T_add[v_ax0, v_ax1, v_ax2]) + T_add[v_ax0, v_ax1, v_ax2] = T.float32(0.5) + T_multiply_1[v_ax0, v_ax1, v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_multiply_2"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("compute_1"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_multiply_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_multiply_intermediate[v_i0, v_i1, v_i2]) + + @T.prim_func + def fused_NT_matmul4_add7_cast8_cast12_add5(p_lv63: T.handle, lv1208: T.Buffer((T.int64(2560), T.int64(10240)), "float16"), linear_bias5: T.Buffer((T.int64(2560),), "float32"), p_lv53: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv63 = T.match_buffer(p_lv63, (T.int64(1), n, T.int64(10240)), "float16") + lv53 = T.match_buffer(p_lv53, (T.int64(1), n, T.int64(2560)), "float16") + var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + var_compute_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + var_compute_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(10240)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv63[v_i0, v_i1, v_k], lv1208[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv63[v_i0, v_i1, v_k]) * T.Cast("float32", lv1208[v_i2, v_k]) + for ax0, ax1, 
ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias5[v_ax2]) + T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias5[v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_add_intermediate_1[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_add_intermediate_1[v_i0, v_i1, v_i2]) + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("compute_1"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_compute_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2]) + var_compute_intermediate_1[v_i0, v_i1, v_i2] = var_compute_intermediate[v_i0, v_i1, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], lv53[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_compute_intermediate_1[v_ax0, v_ax1, v_ax2] + lv53[v_ax0, v_ax1, v_ax2] + + @T.prim_func + def fused_NT_matmul4_add7_cast8_cast12_add5_cast7(p_lv2047: T.handle, lv2510: T.Buffer((T.int64(2560), T.int64(10240)), "float16"), linear_bias191: T.Buffer((T.int64(2560),), "float32"), p_lv2037: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv2047 = T.match_buffer(p_lv2047, (T.int64(1), n, T.int64(10240)), "float16") + lv2037 = T.match_buffer(p_lv2037, (T.int64(1), n, T.int64(2560)), "float16") + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560))) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + var_compute_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + var_compute_intermediate_2 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(10240)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv2047[v_i0, v_i1, v_k], lv2510[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv2047[v_i0, v_i1, v_k]) * T.Cast("float32", lv2510[v_i2, v_k]) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias191[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias191[v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = 
T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2]) + var_compute_intermediate_1[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_add_intermediate[v_i0, v_i1, v_i2]) + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("compute_1"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_compute_intermediate_1[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate_2[v_i0, v_i1, v_i2]) + var_compute_intermediate_2[v_i0, v_i1, v_i2] = var_compute_intermediate_1[v_i0, v_i1, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_compute_intermediate_2[v_ax0, v_ax1, v_ax2], lv2037[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_compute_intermediate_2[v_ax0, v_ax1, v_ax2] + lv2037[v_ax0, v_ax1, v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("compute_2"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_add_intermediate_1[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_T_add_intermediate_1[v_i0, v_i1, v_i2]) + + @T.prim_func + def fused_NT_matmul_divide_maximum_minimum_cast2(lv2094: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float16"), p_lv2095: T.handle, p_lv2063: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv2095 = T.match_buffer(p_lv2095, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") + lv2063 = T.match_buffer(p_lv2063, (T.int64(1), T.int64(1), T.int64(1), n), "float16") + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n)) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") + var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") + var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") + var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(80)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(lv2094[v_i0, v_i1, v_i2, v_k], lv2095[v_i0, v_i1, v_i3, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv2094[v_i0, v_i1, v_i2, v_k] * lv2095[v_i0, v_i1, v_i3, v_k] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_divide"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.11179039301310044) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_maximum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3]) + T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_minimum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv2063[v_ax0, T.int64(0), v_ax2, v_ax3]) + T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv2063[v_ax0, T.int64(0), v_ax2, v_ax3]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) + var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) + + @T.prim_func + def fused_layer_norm1_cast8(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: T.Buffer((T.int64(2560),), "float32"), p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(2560))) + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") + # with T.block("root"): + A_red_temp_v0 = T.alloc_buffer((T.int64(1), n)) + A_red_temp_v1 = T.alloc_buffer((T.int64(1), n)) + var_T_layer_norm_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + for ax0, ax1, k2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("A_red_temp"): + v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) + T.reads(lv6[v_ax0, v_ax1, v_k2]) + T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) + with T.init(): + A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) + A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) + v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] + v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, v_ax1, v_k2] + A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 + A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_layer_norm"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv6[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1], weight1[v_ax2], bias[v_ax2]) + T.writes(var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2] = (lv6[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * weight1[v_ax2] + bias[v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_layer_norm_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", 
var_T_layer_norm_intermediate[v_i0, v_i1, v_i2]) + + @T.prim_func + def fused_min_max_triu_te_broadcast_to(p_output0: T.handle, n: T.int64): + T.func_attr({"tir.noalias": T.bool(True)}) + var_T_broadcast_to_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), n, n), "float16") + # with T.block("root"): + var_make_diag_mask_te_intermediate = T.alloc_buffer((n, n), "float16") + for i, j in T.grid(n, n): + with T.block("make_diag_mask_te"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads() + T.writes(var_make_diag_mask_te_intermediate[v_i, v_j]) + var_make_diag_mask_te_intermediate[v_i, v_j] = T.Select(v_i < v_j, T.float16(-65504), T.float16(65504)) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), n, n): + with T.block("T_broadcast_to"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_make_diag_mask_te_intermediate[v_ax2, v_ax3]) + T.writes(var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_make_diag_mask_te_intermediate[v_ax2, v_ax3] + + @T.prim_func + def fused_softmax1_cast10(p_lv44: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n, m = T.int64(), T.int64() + lv44 = T.match_buffer(p_lv44, (T.int64(1), T.int64(32), n, m)) + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16") + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) + T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) + var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv44[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv44[v_i0, v_i1, v_i2, v_k]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_exp"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(lv44[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) + T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) + T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv44[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_expsum"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) + T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_norm"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) + T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"axis": 3}) + var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), 
n, m): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) + var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + + @T.prim_func + def fused_softmax_cast3(p_lv2102: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv2102 = T.match_buffer(p_lv2102, (T.int64(1), T.int64(32), T.int64(1), n)) + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n), "float16") + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) + T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) + var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv2102[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv2102[v_i0, v_i1, v_i2, v_k]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_exp"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(lv2102[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) + T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) + T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv2102[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_expsum"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) + T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_norm"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) + T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"axis": 3}) + var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) + var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + + @T.prim_func + def layer_norm1(var_A: T.handle, B: T.Buffer((T.int64(2560),), "float32"), C: T.Buffer((T.int64(2560),), "float32"), var_T_layer_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, 
T.int64(2560))) + T_layer_norm = T.match_buffer(var_T_layer_norm, (T.int64(1), n, T.int64(2560))) + # with T.block("root"): + A_red_temp_v0 = T.alloc_buffer((T.int64(1), n)) + A_red_temp_v1 = T.alloc_buffer((T.int64(1), n)) + for ax0, ax1, k2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("A_red_temp"): + v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) + T.reads(A[v_ax0, v_ax1, v_k2]) + T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) + with T.init(): + A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) + A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) + v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] + v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] * A[v_ax0, v_ax1, v_k2] + A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 + A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_layer_norm"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(A[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1], B[v_ax2], C[v_ax2]) + T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2]) + T_layer_norm[v_ax0, v_ax1, v_ax2] = (A[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * B[v_ax2] + C[v_ax2] + + @T.prim_func + def matmul3(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n), "float16") + B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") + # with T.block("root"): + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(80), n): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) + T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0) + matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] + + @T.prim_func + def matmul9(var_A: T.handle, var_B: T.handle, var_matmul: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n, m = T.int64(), T.int64() + A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m), "float16") + B = T.match_buffer(var_B, (T.int64(1), T.int64(32), m, T.int64(80)), "float16") + matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") + # with T.block("root"): + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(80), m): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) + T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0) + matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] + + @T.prim_func + def reshape3(var_A: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (n, T.int64(32), T.int64(80)), 
"float16") + T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(32), T.int64(80)), "float16") + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(80)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(A[((v_ax3 // T.int64(80) + v_ax2) // T.int64(32) + v_ax0 * n + v_ax1) % n, (v_ax3 // T.int64(80) + v_ax2) % T.int64(32), v_ax3 % T.int64(80)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[((v_ax3 // T.int64(80) + v_ax2) // T.int64(32) + v_ax0 * n + v_ax1) % n, (v_ax3 // T.int64(80) + v_ax2) % T.int64(32), v_ax3 % T.int64(80)] + + @T.prim_func + def reshape5(var_A: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n), "int32") + T_reshape = T.match_buffer(var_T_reshape, (n,), "int32") + # with T.block("root"): + for ax0 in range(n): + with T.block("T_reshape"): + v_ax0 = T.axis.spatial(n, ax0) + T.reads(A[T.int64(0), v_ax0 % n]) + T.writes(T_reshape[v_ax0]) + T_reshape[v_ax0] = A[T.int64(0), v_ax0 % n] + + @T.prim_func + def reshape6(var_A: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (n, T.int64(2560)), "float16") + T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(2560)), "float16") + # with T.block("root"): + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(A[(v_ax2 // T.int64(2560) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(2560)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) + T_reshape[v_ax0, v_ax1, v_ax2] = A[(v_ax2 // T.int64(2560) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(2560)] + + @T.prim_func + def reshape7(var_A: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(2560)), "float16") + T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(32), T.int64(80)), "float16") + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(80)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(A[T.int64(0), ((v_ax2 * T.int64(80) + v_ax3) // T.int64(2560) + v_ax0 * n + v_ax1) % n, (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[T.int64(0), ((v_ax2 * T.int64(80) + v_ax3) // T.int64(2560) + v_ax0 * n + v_ax1) % n, (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)] + + @T.prim_func + def reshape8(var_A: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(80)), "float16") + T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(2560)), "float16") + # with T.block("root"): + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(A[T.int64(0), (v_ax2 // T.int64(2560) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(2560) // T.int64(80), v_ax2 % T.int64(80)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) + T_reshape[v_ax0, v_ax1, v_ax2] = A[T.int64(0), (v_ax2 // T.int64(2560) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(2560) // T.int64(80), 
v_ax2 % T.int64(80)] + + @T.prim_func + def rotary_embedding(var_A: T.handle, B: T.Buffer((T.int64(2048), T.int64(80)), "float16"), C: T.Buffer((T.int64(2048), T.int64(80)), "float16"), var_rotary: T.handle, m: T.int64): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(80)), "float16") + rotary = T.match_buffer(var_rotary, (T.int64(1), n, T.int64(32), T.int64(80)), "float16") + # with T.block("root"): + for i_batch_size, i_seq_len, i_num_heads, i_head_dim in T.grid(T.int64(1), n, T.int64(32), T.int64(80)): + with T.block("rotary"): + v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim = T.axis.remap("SSSS", [i_batch_size, i_seq_len, i_num_heads, i_head_dim]) + T.reads(B[m + v_i_seq_len - n, v_i_head_dim], A[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim - T.int64(40):v_i_head_dim - T.int64(40) + T.int64(81)], C[m + v_i_seq_len - n, v_i_head_dim]) + T.writes(rotary[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim]) + rotary[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim] = T.Select(v_i_head_dim < T.int64(80), B[m + v_i_seq_len - n, v_i_head_dim] * A[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim] + C[m + v_i_seq_len - n, v_i_head_dim] * T.Select(v_i_head_dim < T.int64(40), A[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim + T.int64(40)] * T.float16(-1), A[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim - T.int64(40)]), A[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim]) + + @T.prim_func + def slice(var_A: T.handle, slice_1: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(2560))) + # with T.block("root"): + for i, _, k in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("slice"): + v_i, v__, v_k = T.axis.remap("SSS", [i, _, k]) + T.reads(A[v_i, n - T.int64(1), v_k]) + T.writes(slice_1[v_i, v__, v_k]) + slice_1[v_i, v__, v_k] = A[v_i, n - T.int64(1), v_k] + + @T.prim_func + def squeeze1(var_A: T.handle, var_T_squeeze: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(80)), "float16") + T_squeeze = T.match_buffer(var_T_squeeze, (n, T.int64(32), T.int64(80)), "float16") + # with T.block("root"): + for ax0, ax1, ax2 in T.grid(n, T.int64(32), T.int64(80)): + with T.block("T_squeeze"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(A[T.int64(0), v_ax0, v_ax1, v_ax2]) + T.writes(T_squeeze[v_ax0, v_ax1, v_ax2]) + T_squeeze[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax0, v_ax1, v_ax2] + + @T.prim_func + def take_decode1(A: T.Buffer((T.int64(50432), T.int64(320)), "uint32"), B: T.Buffer((T.int64(50432), T.int64(80)), "float16"), var_C: T.handle, var_take_decode: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + C = T.match_buffer(var_C, (n,), "int32") + take_decode = T.match_buffer(var_take_decode, (n, T.int64(2560)), "float16") + # with T.block("root"): + for i, j in T.grid(n, T.int64(2560)): + with T.block("take_decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[C[v_i], v_j // T.int64(8)], C[v_i], B[C[v_i], v_j // T.int64(32)]) + T.writes(take_decode[v_i, v_j]) + take_decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[C[v_i], v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * B[C[v_i], v_j // T.int64(32)] + + 
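+    # ------------------------------------------------------------------
+    # NOTE (illustrative comment, not emitted by the exporter): `take_decode1`
+    # above, like the `decode*`/`fused_decode*` kernels elsewhere in these
+    # generated modules, applies the same 4-bit group-quantization scheme:
+    # eight 4-bit nibbles are packed into one uint32, a nibble is extracted
+    # with a shift/mask, re-centered by subtracting 7, and multiplied by a
+    # per-group float16 scale (group size 32, hence the `// T.int64(32)`
+    # index into the scale buffer). A minimal NumPy sketch of the same
+    # formula for one embedding row, matching the `take_decode1` layout
+    # (the helper name `dequantize_row` is ours, purely for reference):
+    #
+    #     import numpy as np
+    #
+    #     def dequantize_row(packed_row, scale_row):
+    #         # packed_row: (320,) uint32, scale_row: (80,) float16 -> (2560,) float16
+    #         j = np.arange(2560)
+    #         nibble = (packed_row[j // 8] >> (j % 8) * 4) & 0xF
+    #         return (nibble.astype(np.float16) - np.float16(7)) * scale_row[j // 32]
+    # ------------------------------------------------------------------
+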
@T.prim_func + def transpose3(var_A: T.handle, var_T_transpose: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(80)), "float16") + T_transpose = T.match_buffer(var_T_transpose, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, T.int64(80)): + with T.block("T_transpose"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3]) + T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) + T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3] + + @T.prim_func + def transpose6(var_A: T.handle, var_T_transpose: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") + T_transpose = T.match_buffer(var_T_transpose, (T.int64(1), n, T.int64(32), T.int64(80)), "float16") + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(80)): + with T.block("T_transpose"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3]) + T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) + T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3] + + +# fmt: on diff --git a/mlc_llm/dispatch/gpt_neox/redpajama_incite_chat_3b_v1_tune.py b/mlc_llm/dispatch/gpt_neox/redpajama_incite_chat_3b_v1_tune.py new file mode 100644 index 0000000..460bec0 --- /dev/null +++ b/mlc_llm/dispatch/gpt_neox/redpajama_incite_chat_3b_v1_tune.py @@ -0,0 +1,1010 @@ +from tvm.script import ir as I +from tvm.script import tir as T + +""" + ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done +----------------------------------------------------------------------------------------------------------------------------------------------------------- + 0 | cast | 1 | 1 | 0.0000 | 27.7422 | 27.7422 | 4 | Y + 1 | cast6 | 1 | 1 | 0.0000 | 27.6800 | 27.6800 | 4 | Y + 2 | decode4 | 26214400 | 1 | 167.6862 | 156.3301 | 156.3301 | 172 | Y + 3 | decode5 | 104857600 | 1 | 128.5783 | 815.5153 | 815.5153 | 172 | Y + 4 | decode6 | 104857600 | 1 | 128.6586 | 815.0066 | 815.0066 | 179 | Y + 5 | divide2 | 50432 | 1 | 1.8169 | 27.7575 | 27.7575 | 4 | Y + 6 | fused_NT_matmul1_add4 | 1678049280 | 1 | 2178.8097 | 770.1679 | 770.1679 | 1088 | Y + 7 | fused_NT_matmul1_add4_add5 | 1678376960 | 1 | 2130.5374 | 787.7717 | 787.7717 | 1215 | Y + 8 | fused_NT_matmul2_divide1_maximum1_minimum1_cast9 | 85458944 | 1 | 1211.9454 | 70.5139 | 70.5139 | 192 | Y + 9 | fused_NT_matmul3_add6_gelu1_cast11 | 6717440000 | 1 | 2129.3171 | 3154.7391 | 3154.7391 | 4416 | Y + 10 | fused_NT_matmul4_add7_cast8_cast12_add5 | 6711541760 | 1 | 2072.7296 | 3238.0208 | 3238.0208 | 4544 | Y + 11 | fused_NT_matmul4_add7_cast8_cast12_add5_cast7 | 6711541760 | 1 | 2091.5892 | 3208.8241 | 3208.8241 | 4416 | Y + 12 | fused_NT_matmul_divide_maximum_minimum_cast2 | 667648 | 1 | 23.3021 | 28.6519 | 28.6519 | 64 | Y + 13 | fused_decode1_fused_matmul4_add2_gelu_cast4 | 157337600 | 1 | 812.5380 | 193.6372 | 193.6372 | 319 | Y + 14 | fused_decode2_fused_matmul5_add3_cast1_cast5_add1 | 157291520 | 1 | 730.8166 | 215.2271 | 215.2271 | 320 | Y + 15 | fused_decode2_fused_matmul5_add3_cast1_cast5_add1_cast | 157291520 | 1 | 729.0229 | 215.7566 | 215.7566 | 319 | Y + 16 | fused_decode3_matmul6 | 774635520 | 1 | 868.1608 | 
892.2719 | 892.2719 | 1331 | Y + 17 | fused_decode_fused_matmul2_add | 39324160 | 1 | 733.2646 | 53.6289 | 53.6289 | 191 | Y + 18 | fused_decode_fused_matmul2_add_add1 | 39326720 | 1 | 740.8926 | 53.0802 | 53.0802 | 192 | Y + 19 | fused_layer_norm1_cast8 | 4587520 | 1 | 76.3188 | 60.1099 | 60.1099 | 50 | Y + 20 | fused_layer_norm_cast1 | 35840 | 1 | 0.6533 | 54.8634 | 54.8634 | 159 | Y + 21 | fused_reshape2_squeeze | 1 | 1 | 0.0000 | 27.5470 | 27.5470 | 4 | Y + 22 | fused_slice1_cast6 | 1 | 1 | 0.0000 | 27.5899 | 27.5899 | 4 | Y + 23 | fused_transpose4_reshape4 | 1 | 1 | 0.0000 | 27.5157 | 27.5157 | 4 | Y + 24 | layer_norm | 35840 | 1 | 0.6506 | 55.0910 | 55.0910 | 160 | Y + 25 | layer_norm1 | 4587520 | 1 | 74.6941 | 61.4174 | 61.4174 | 50 | Y + 26 | matmul3 | 163840 | 1 | 5.8011 | 28.2428 | 28.2428 | 64 | Y + 27 | matmul9 | 20971520 | 1 | 571.2811 | 36.7096 | 36.7096 | 192 | Y + 28 | reshape | 1 | 1 | 0.0000 | 27.9399 | 27.9399 | 1 | Y + 29 | reshape1 | 1 | 1 | 0.0000 | 27.6659 | 27.6659 | 4 | Y + 30 | reshape2 | 1 | 1 | 0.0000 | 27.6446 | 27.6446 | 4 | Y + 31 | softmax2 | 201728 | 1 | 2.8631 | 70.4578 | 70.4578 | 186 | Y + 32 | squeeze | 1 | 1 | 0.0000 | 27.3156 | 27.3156 | 4 | Y + 33 | take_decode | 10240 | 1 | 0.3712 | 27.5835 | 27.5835 | 4 | Y + 34 | transpose2 | 1 | 1 | 0.0000 | 27.6975 | 27.6975 | 4 | Y +----------------------------------------------------------------------------------------------------------------------------------------------------------- +""" + +# fmt: off + +@I.ir_module +class Module: + @T.prim_func + def cast(A: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), compute: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(A[v_i0, v_i1, v_i2]) + T.writes(compute[v_i0, v_i1, v_i2]) + compute[v_i0, v_i1, v_i2] = T.Cast("float32", A[v_i0, v_i1, v_i2]) + + @T.prim_func + def cast6(A: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), compute: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(A[v_i0, v_i1, v_i2]) + T.writes(compute[v_i0, v_i1, v_i2]) + compute[v_i0, v_i1, v_i2] = A[v_i0, v_i1, v_i2] + + @T.prim_func + def decode4(A: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), B: T.Buffer((T.int64(80), T.int64(2560)), "float16"), T_transpose: T.Buffer((T.int64(2560), T.int64(2560)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(2560), T.int64(2560)), "float16") + for i, j in T.grid(T.int64(2560), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * B[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(2560), T.int64(2560)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, 
v_ax1] = decode[v_ax1, v_ax0] + + @T.prim_func + def decode5(A: T.Buffer((T.int64(320), T.int64(10240)), "uint32"), B: T.Buffer((T.int64(80), T.int64(10240)), "float16"), T_transpose: T.Buffer((T.int64(10240), T.int64(2560)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(2560), T.int64(10240)), "float16") + for i, j in T.grid(T.int64(2560), T.int64(10240)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * B[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(10240), T.int64(2560)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + + @T.prim_func + def decode6(A: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), B: T.Buffer((T.int64(320), T.int64(2560)), "float16"), T_transpose: T.Buffer((T.int64(2560), T.int64(10240)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(10240), T.int64(2560)), "float16") + for i, j in T.grid(T.int64(10240), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * B[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(2560), T.int64(10240)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + + @T.prim_func + def divide2(A: T.Buffer((T.int64(1), T.int64(1), T.int64(50432)), "float32"), B: T.Buffer((), "float32"), T_divide: T.Buffer((T.int64(1), T.int64(1), T.int64(50432)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(50432)): + with T.block("T_divide"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(A[v_ax0, v_ax1, v_ax2], B[()]) + T.writes(T_divide[v_ax0, v_ax1, v_ax2]) + T_divide[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2] / B[()] + + @T.prim_func + def fused_decode1_fused_matmul4_add2_gelu_cast4(lv32: T.Buffer((T.int64(320), T.int64(10240)), "uint32"), lv33: T.Buffer((T.int64(80), T.int64(10240)), "float16"), lv2115: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), linear_bias196: T.Buffer((T.int64(10240),), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(10240)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(10240)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + T_multiply = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + T_multiply_1 = 
T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + T_add = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + var_T_multiply_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + for i, j in T.grid(T.int64(2560), T.int64(10240)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv32[v_i // T.int64(8), v_j], lv33[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv32[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv33[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(10240), T.int64(2560)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv2115[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv2115[v_i0, v_i1, v_k]) * T.Cast("float32", var_decode_intermediate[v_k, v_i2]) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias196[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias196[v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) + T_multiply[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T.float32(0.70710678118654757) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(T_multiply[v_i0, v_i1, v_i2]) + T.writes(compute[v_i0, v_i1, v_i2]) + compute[v_i0, v_i1, v_i2] = T.erf(T_multiply[v_i0, v_i1, v_i2]) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("T_multiply_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(compute[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2]) + T_multiply_1[v_ax0, v_ax1, v_ax2] = compute[v_ax0, v_ax1, v_ax2] * T.float32(0.5) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2]) + T.writes(T_add[v_ax0, v_ax1, v_ax2]) + T_add[v_ax0, v_ax1, v_ax2] = T.float32(0.5) + T_multiply_1[v_ax0, v_ax1, v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("T_multiply_2"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("compute_1"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", 
[i0, i1, i2]) + T.reads(var_T_multiply_intermediate[v_i0, v_i1, v_i2]) + T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) + p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_multiply_intermediate[v_i0, v_i1, v_i2]) + + @T.prim_func + def fused_decode2_fused_matmul5_add3_cast1_cast5_add1(lv38: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), lv39: T.Buffer((T.int64(320), T.int64(2560)), "float16"), lv2121: T.Buffer((T.int64(1), T.int64(1), T.int64(10240)), "float16"), linear_bias197: T.Buffer((T.int64(2560),), "float32"), lv8: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(10240), T.int64(2560)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) + var_compute_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") + var_compute_intermediate_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") + for i, j in T.grid(T.int64(10240), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv38[v_i // T.int64(8), v_j], lv39[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv38[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv39[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(10240)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv2121[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv2121[v_i0, v_i1, v_k]) * T.Cast("float32", var_decode_intermediate[v_k, v_i2]) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias197[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias197[v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_add_intermediate[v_i0, v_i1, v_i2]) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("compute_1"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_compute_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2]) + var_compute_intermediate_1[v_i0, v_i1, v_i2] = var_compute_intermediate[v_i0, v_i1, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + 
T.reads(var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], lv8[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_compute_intermediate_1[v_ax0, v_ax1, v_ax2] + lv8[v_ax0, v_ax1, v_ax2] + + @T.prim_func + def fused_decode2_fused_matmul5_add3_cast1_cast5_add1_cast(lv1154: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), lv1155: T.Buffer((T.int64(320), T.int64(2560)), "float16"), lv4105: T.Buffer((T.int64(1), T.int64(1), T.int64(10240)), "float16"), linear_bias383: T.Buffer((T.int64(2560),), "float32"), lv380: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(10240), T.int64(2560)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) + var_compute_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") + var_compute_intermediate_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") + var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") + for i, j in T.grid(T.int64(10240), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1154[v_i // T.int64(8), v_j], lv1155[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv1154[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv1155[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(10240)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv4105[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv4105[v_i0, v_i1, v_k]) * T.Cast("float32", var_decode_intermediate[v_k, v_i2]) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias383[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias383[v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_add_intermediate[v_i0, v_i1, v_i2]) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("compute_1"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_compute_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2]) + var_compute_intermediate_1[v_i0, v_i1, v_i2] = var_compute_intermediate[v_i0, v_i1, v_i2] + for ax0, ax1, ax2 in 
T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], lv380[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_compute_intermediate_1[v_ax0, v_ax1, v_ax2] + lv380[v_ax0, v_ax1, v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("compute_2"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_add_intermediate_1[v_i0, v_i1, v_i2]) + T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) + p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_T_add_intermediate_1[v_i0, v_i1, v_i2]) + + @T.prim_func + def fused_decode3_matmul6(lv1160: T.Buffer((T.int64(320), T.int64(50432)), "uint32"), lv1161: T.Buffer((T.int64(80), T.int64(50432)), "float32"), lv384: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(50432)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(50432))) + for i, j in T.grid(T.int64(2560), T.int64(50432)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1160[v_i // T.int64(8), v_j], lv1161[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.Cast("float16", T.bitwise_and(T.shift_right(lv1160[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv1161[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(50432), T.int64(2560)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv384[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv384[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + + @T.prim_func + def fused_decode_fused_matmul2_add(lv8: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), lv9: T.Buffer((T.int64(80), T.int64(2560)), "float16"), lv2067: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), linear_bias192: T.Buffer((T.int64(2560),), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(2560)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") + for i, j in T.grid(T.int64(2560), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv8[v_i // T.int64(8), v_j], lv9[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv8[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv9[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(2560)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv2067[v_i0, v_i1, 
v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2067[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias192[v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias192[v_ax2] + + @T.prim_func + def fused_decode_fused_matmul2_add_add1(lv26: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), lv27: T.Buffer((T.int64(80), T.int64(2560)), "float16"), lv7: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), linear_bias195: T.Buffer((T.int64(2560),), "float16"), lv2062: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(2560)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") + var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") + for i, j in T.grid(T.int64(2560), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv26[v_i // T.int64(8), v_j], lv27[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv26[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv27[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(2560)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv7[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv7[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias195[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias195[v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2], lv2062[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] + lv2062[v_ax0, v_ax1, v_ax2] + + @T.prim_func + def fused_layer_norm_cast1(lv2064: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), weight67: T.Buffer((T.int64(2560),), "float32"), bias65: T.Buffer((T.int64(2560),), "float32"), 
var_compute_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + A_red_temp_v0 = T.alloc_buffer((T.int64(1), T.int64(1))) + A_red_temp_v1 = T.alloc_buffer((T.int64(1), T.int64(1))) + var_T_layer_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) + for ax0, ax1, k2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("A_red_temp"): + v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) + T.reads(lv2064[v_ax0, v_ax1, v_k2]) + T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) + with T.init(): + A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) + A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) + v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + lv2064[v_ax0, v_ax1, v_k2] + v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + lv2064[v_ax0, v_ax1, v_k2] * lv2064[v_ax0, v_ax1, v_k2] + A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 + A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_layer_norm"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv2064[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1], weight67[v_ax2], bias65[v_ax2]) + T.writes(var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2] = (lv2064[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * weight67[v_ax2] + bias65[v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_layer_norm_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_layer_norm_intermediate[v_i0, v_i1, v_i2]) + + @T.prim_func + def fused_reshape2_squeeze(lv2080: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), var_T_squeeze_intermediate: T.Buffer((T.int64(1), T.int64(32), T.int64(80)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_T_reshape_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(80)), "float16") + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(80)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(lv2080[T.int64(0), T.int64(0), (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)]) + T.writes(var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv2080[T.int64(0), T.int64(0), (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(80)): + with T.block("T_squeeze"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_reshape_intermediate[T.int64(0), v_ax0, v_ax1, v_ax2]) + T.writes(var_T_squeeze_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_squeeze_intermediate[v_ax0, v_ax1, v_ax2] = var_T_reshape_intermediate[T.int64(0), v_ax0, v_ax1, v_ax2] + + @T.prim_func + def fused_slice1_cast6(lv4113: T.Buffer((T.int64(1), 
T.int64(1), T.int64(2560)), "float32"), var_compute_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_slice_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) + for i, _, k in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("slice"): + v_i, v__, v_k = T.axis.remap("SSS", [i, _, k]) + T.reads(lv4113[v_i, T.int64(0), v_k]) + T.writes(var_slice_intermediate[v_i, v__, v_k]) + var_slice_intermediate[v_i, v__, v_k] = lv4113[v_i, T.int64(0), v_k] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_slice_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = var_slice_intermediate[v_i0, v_i1, v_i2] + + @T.prim_func + def fused_transpose4_reshape4(lv2105: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float16"), var_T_reshape_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_T_transpose_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(80)), "float16") + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(80)): + with T.block("T_transpose"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(lv2105[v_ax0, v_ax2, v_ax1, v_ax3]) + T.writes(var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv2105[v_ax0, v_ax2, v_ax1, v_ax3] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_transpose_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(2560) // T.int64(80), v_ax2 % T.int64(80)]) + T.writes(var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2] = var_T_transpose_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(2560) // T.int64(80), v_ax2 % T.int64(80)] + + @T.prim_func + def layer_norm(A: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), B: T.Buffer((T.int64(2560),), "float32"), C: T.Buffer((T.int64(2560),), "float32"), T_layer_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + A_red_temp_v0 = T.alloc_buffer((T.int64(1), T.int64(1))) + A_red_temp_v1 = T.alloc_buffer((T.int64(1), T.int64(1))) + for ax0, ax1, k2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("A_red_temp"): + v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) + T.reads(A[v_ax0, v_ax1, v_k2]) + T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) + with T.init(): + A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) + A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) + v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] + v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] * A[v_ax0, v_ax1, v_k2] + A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 + A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_layer_norm"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(A[v_ax0, v_ax1, v_ax2], 
A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1], B[v_ax2], C[v_ax2]) + T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2]) + T_layer_norm[v_ax0, v_ax1, v_ax2] = (A[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * B[v_ax2] + C[v_ax2] + + @T.prim_func + def reshape(A: T.Buffer((T.int64(1), T.int64(1)), "int32"), T_reshape: T.Buffer((T.int64(1),), "int32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for ax0 in range(T.int64(1)): + with T.block("T_reshape"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + T.reads(A[T.int64(0), T.int64(0)]) + T.writes(T_reshape[v_ax0]) + T_reshape[v_ax0] = A[T.int64(0), T.int64(0)] + + @T.prim_func + def reshape1(A: T.Buffer((T.int64(1), T.int64(2560)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(A[T.int64(0), v_ax2 % T.int64(2560)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) + T_reshape[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax2 % T.int64(2560)] + + @T.prim_func + def reshape2(A: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(80)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(80)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(A[T.int64(0), T.int64(0), (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[T.int64(0), T.int64(0), (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)] + + @T.prim_func + def softmax2(A: T.Buffer((T.int64(1), T.int64(1), T.int64(50432)), "float32"), T_softmax_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(50432)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(1))) + T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(50432))) + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(1))) + for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(50432)): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + T.reads(A[v_i0, v_i1, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0, v_i1], A[v_i0, v_i1, v_k]) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(50432)): + with T.block("T_softmax_exp"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(A[v_i0, v_i1, v_i2], T_softmax_maxelem[v_i0, v_i1]) + T.writes(T_softmax_exp[v_i0, v_i1, v_i2]) + T_softmax_exp[v_i0, v_i1, v_i2] = T.exp(A[v_i0, v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1]) + for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(50432)): + with T.block("T_softmax_expsum"): + v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + 
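# Second reduction pass of the numerically stable softmax over the 50432-entry vocabulary axis: + # accumulate the sum of exp(A - max); the final T_softmax_norm block divides by this sum. +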
T.reads(T_softmax_exp[v_i0, v_i1, v_k]) + T.writes(T_softmax_expsum[v_i0, v_i1]) + with T.init(): + T_softmax_expsum[v_i0, v_i1] = T.float32(0) + T_softmax_expsum[v_i0, v_i1] = T_softmax_expsum[v_i0, v_i1] + T_softmax_exp[v_i0, v_i1, v_k] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(50432)): + with T.block("T_softmax_norm"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2], T_softmax_expsum[v_i0, v_i1]) + T.writes(T_softmax_norm[v_i0, v_i1, v_i2]) + T.block_attr({"axis": 2}) + T_softmax_norm[v_i0, v_i1, v_i2] = T_softmax_exp[v_i0, v_i1, v_i2] / T_softmax_expsum[v_i0, v_i1] + + @T.prim_func + def squeeze(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(80)), "float16"), T_squeeze: T.Buffer((T.int64(1), T.int64(32), T.int64(80)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(80)): + with T.block("T_squeeze"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(A[T.int64(0), v_ax0, v_ax1, v_ax2]) + T.writes(T_squeeze[v_ax0, v_ax1, v_ax2]) + T_squeeze[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax0, v_ax1, v_ax2] + + @T.prim_func + def take_decode(A: T.Buffer((T.int64(50432), T.int64(320)), "uint32"), B: T.Buffer((T.int64(50432), T.int64(80)), "float16"), C: T.Buffer((T.int64(1),), "int32"), take_decode_1: T.Buffer((T.int64(1), T.int64(2560)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for i, j in T.grid(T.int64(1), T.int64(2560)): + with T.block("take_decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[C[v_i], v_j // T.int64(8)], C[v_i], B[C[v_i], v_j // T.int64(32)]) + T.writes(take_decode_1[v_i, v_j]) + take_decode_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[C[v_i], v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * B[C[v_i], v_j // T.int64(32)] + + @T.prim_func + def transpose2(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(80)), "float16"), T_transpose: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(80)): + with T.block("T_transpose"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3]) + T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) + T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3] + + ####################################### Dynamic Shape ####################################### + + @T.prim_func + def fused_NT_matmul1_add4(p_lv9: T.handle, lv1173: T.Buffer((T.int64(2560), T.int64(2560)), "float16"), linear_bias: T.Buffer((T.int64(2560),), "float16"), p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.meta_var(T.int64(128)) + lv9 = T.match_buffer(p_lv9, (T.int64(1), n, T.int64(2560)), "float16") + var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv9[v_i0, v_i1, v_k], lv1173[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with 
T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv9[v_i0, v_i1, v_k] * lv1173[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias[v_ax2] + + @T.prim_func + def fused_NT_matmul1_add4_add5(p_lv49: T.handle, lv1194: T.Buffer((T.int64(2560), T.int64(2560)), "float16"), linear_bias3: T.Buffer((T.int64(2560),), "float16"), p_lv2: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.meta_var(T.int64(128)) + lv49 = T.match_buffer(p_lv49, (T.int64(1), n, T.int64(2560)), "float16") + lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(2560)), "float16") + var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv49[v_i0, v_i1, v_k], lv1194[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv49[v_i0, v_i1, v_k] * lv1194[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias3[v_ax2]) + T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias3[v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2], lv2[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] + lv2[v_ax0, v_ax1, v_ax2] + + @T.prim_func + def fused_NT_matmul2_divide1_maximum1_minimum1_cast9(p_lv36: T.handle, p_lv37: T.handle, p_lv5: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.meta_var(T.int64(128)) + m = T.meta_var(T.int64(128)) + lv36 = T.match_buffer(p_lv36, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") + lv37 = T.match_buffer(p_lv37, (T.int64(1), T.int64(32), m, T.int64(80)), "float16") + lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m), "float16") + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m)) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") + var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") + var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") + 
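# The loop nests below follow the usual attention-score pattern for 32 heads of head size 80: + # scores = min(max((Q @ K^T) * 0.11179..., -65504), mask), where 0.11179... is roughly 1/sqrt(80) + # rounded through float16 and -65504 is the float16 lower bound; the result is cast to float32 + # so the softmax that follows runs in full precision. lv36/lv37 appear to be the per-head + # query/key tensors and lv5 the additive attention mask. +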
var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, m, T.int64(80)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(lv36[v_i0, v_i1, v_i2, v_k], lv37[v_i0, v_i1, v_i3, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv36[v_i0, v_i1, v_i2, v_k] * lv37[v_i0, v_i1, v_i3, v_k] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_divide"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.11179039301310044) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_maximum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_minimum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) + T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) + var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) + + @T.prim_func + def fused_NT_matmul3_add6_gelu1_cast11(p_lv57: T.handle, lv1201: T.Buffer((T.int64(10240), T.int64(2560)), "float16"), linear_bias4: T.Buffer((T.int64(10240),), "float32"), p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.meta_var(T.int64(128)) + lv57 = T.match_buffer(p_lv57, (T.int64(1), n, T.int64(2560)), "float16") + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(10240)), "float16") + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + T_multiply = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + compute = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + T_multiply_1 = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + T_add = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + var_T_multiply_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(10240), T.int64(2560)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = 
T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv57[v_i0, v_i1, v_k], lv1201[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv57[v_i0, v_i1, v_k] * lv1201[v_i2, v_k]) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias4[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias4[v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) + T_multiply[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T.float32(0.70710678118654757) + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(T_multiply[v_i0, v_i1, v_i2]) + T.writes(compute[v_i0, v_i1, v_i2]) + compute[v_i0, v_i1, v_i2] = T.erf(T_multiply[v_i0, v_i1, v_i2]) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_multiply_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(compute[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2]) + T_multiply_1[v_ax0, v_ax1, v_ax2] = compute[v_ax0, v_ax1, v_ax2] * T.float32(0.5) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2]) + T.writes(T_add[v_ax0, v_ax1, v_ax2]) + T_add[v_ax0, v_ax1, v_ax2] = T.float32(0.5) + T_multiply_1[v_ax0, v_ax1, v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_multiply_2"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("compute_1"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_multiply_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_multiply_intermediate[v_i0, v_i1, v_i2]) + + @T.prim_func + def fused_NT_matmul4_add7_cast8_cast12_add5(p_lv63: T.handle, lv1208: T.Buffer((T.int64(2560), T.int64(10240)), "float16"), linear_bias5: T.Buffer((T.int64(2560),), "float32"), p_lv53: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.meta_var(T.int64(128)) + lv63 = T.match_buffer(p_lv63, (T.int64(1), n, T.int64(10240)), "float16") + lv53 = T.match_buffer(p_lv53, (T.int64(1), n, T.int64(2560)), "float16") + var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + 
var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + var_compute_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + var_compute_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(10240)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv63[v_i0, v_i1, v_k], lv1208[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv63[v_i0, v_i1, v_k] * lv1208[v_i2, v_k]) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias5[v_ax2]) + T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias5[v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_add_intermediate_1[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_add_intermediate_1[v_i0, v_i1, v_i2]) + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("compute_1"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_compute_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2]) + var_compute_intermediate_1[v_i0, v_i1, v_i2] = var_compute_intermediate[v_i0, v_i1, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], lv53[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_compute_intermediate_1[v_ax0, v_ax1, v_ax2] + lv53[v_ax0, v_ax1, v_ax2] + + @T.prim_func + def fused_NT_matmul4_add7_cast8_cast12_add5_cast7(p_lv2047: T.handle, lv2510: T.Buffer((T.int64(2560), T.int64(10240)), "float16"), linear_bias191: T.Buffer((T.int64(2560),), "float32"), p_lv2037: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.meta_var(T.int64(128)) + lv2047 = T.match_buffer(p_lv2047, (T.int64(1), n, T.int64(10240)), "float16") + lv2037 = T.match_buffer(p_lv2037, (T.int64(1), n, T.int64(2560)), "float16") + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560))) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + var_compute_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + var_compute_intermediate_2 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(10240)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv2047[v_i0, v_i1, v_k], lv2510[v_i2, v_k]) + 
T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv2047[v_i0, v_i1, v_k] * lv2510[v_i2, v_k]) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias191[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias191[v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2]) + var_compute_intermediate_1[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_add_intermediate[v_i0, v_i1, v_i2]) + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("compute_1"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_compute_intermediate_1[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate_2[v_i0, v_i1, v_i2]) + var_compute_intermediate_2[v_i0, v_i1, v_i2] = var_compute_intermediate_1[v_i0, v_i1, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_compute_intermediate_2[v_ax0, v_ax1, v_ax2], lv2037[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_compute_intermediate_2[v_ax0, v_ax1, v_ax2] + lv2037[v_ax0, v_ax1, v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("compute_2"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_add_intermediate_1[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_T_add_intermediate_1[v_i0, v_i1, v_i2]) + + @T.prim_func + def fused_NT_matmul_divide_maximum_minimum_cast2(lv2094: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float16"), p_lv2095: T.handle, p_lv2063: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.meta_var(T.int64(128)) + lv2095 = T.match_buffer(p_lv2095, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") + lv2063 = T.match_buffer(p_lv2063, (T.int64(1), T.int64(1), T.int64(1), n), "float16") + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n)) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") + var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") + var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") + var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(80)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(lv2094[v_i0, v_i1, v_i2, v_k], lv2095[v_i0, v_i1, v_i3, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + var_NT_matmul_intermediate[v_i0, 
v_i1, v_i2, v_i3] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv2094[v_i0, v_i1, v_i2, v_k] * lv2095[v_i0, v_i1, v_i3, v_k] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_divide"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.11179039301310044) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_maximum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_minimum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv2063[v_ax0, T.int64(0), v_ax2, v_ax3]) + T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv2063[v_ax0, T.int64(0), v_ax2, v_ax3]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) + var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) + + @T.prim_func + def fused_layer_norm1_cast8(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: T.Buffer((T.int64(2560),), "float32"), p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.meta_var(T.int64(128)) + lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(2560))) + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") + # with T.block("root"): + A_red_temp_v0 = T.alloc_buffer((T.int64(1), n)) + A_red_temp_v1 = T.alloc_buffer((T.int64(1), n)) + var_T_layer_norm_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + for ax0, ax1, k2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("A_red_temp"): + v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) + T.reads(lv6[v_ax0, v_ax1, v_k2]) + T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) + with T.init(): + A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) + A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) + v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] + v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, v_ax1, v_k2] + A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 + A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_layer_norm"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv6[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, 
v_ax1], weight1[v_ax2], bias[v_ax2]) + T.writes(var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2] = (lv6[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * weight1[v_ax2] + bias[v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_layer_norm_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_layer_norm_intermediate[v_i0, v_i1, v_i2]) + + @T.prim_func + def layer_norm1(var_A: T.handle, B: T.Buffer((T.int64(2560),), "float32"), C: T.Buffer((T.int64(2560),), "float32"), var_T_layer_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.meta_var(T.int64(128)) + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(2560))) + T_layer_norm = T.match_buffer(var_T_layer_norm, (T.int64(1), n, T.int64(2560))) + # with T.block("root"): + A_red_temp_v0 = T.alloc_buffer((T.int64(1), n)) + A_red_temp_v1 = T.alloc_buffer((T.int64(1), n)) + for ax0, ax1, k2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("A_red_temp"): + v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) + T.reads(A[v_ax0, v_ax1, v_k2]) + T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) + with T.init(): + A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) + A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) + v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] + v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] * A[v_ax0, v_ax1, v_k2] + A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 + A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_layer_norm"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(A[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1], B[v_ax2], C[v_ax2]) + T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2]) + T_layer_norm[v_ax0, v_ax1, v_ax2] = (A[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * B[v_ax2] + C[v_ax2] + + @T.prim_func + def matmul3(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.meta_var(T.int64(32)) + A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n), "float16") + B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") + # with T.block("root"): + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(80), n): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) + T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0) + matmul[v_i0, v_i1, v_i2, 
v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] + + @T.prim_func + def matmul9(var_A: T.handle, var_B: T.handle, var_matmul: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.meta_var(T.int64(128)) + m = T.meta_var(T.int64(32)) + A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m), "float16") + B = T.match_buffer(var_B, (T.int64(1), T.int64(32), m, T.int64(80)), "float16") + matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") + # with T.block("root"): + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(80), m): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) + T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0) + matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] + +# fmt: on diff --git a/mlc_llm/dispatch/gpt_neox/redpajama_q4f32.py b/mlc_llm/dispatch/gpt_neox/redpajama_q4f32.py new file mode 100644 index 0000000..b6e9123 --- /dev/null +++ b/mlc_llm/dispatch/gpt_neox/redpajama_q4f32.py @@ -0,0 +1,840 @@ +# pylint: disable=missing-docstring,line-too-long,invalid-name,too-many-statements,too-many-locals +import tvm +from tvm import tir +from tvm.script import tir as T + +from .redpajama_q4f32_mod import Module as MOD + +# fmt: off + +def fused_NT_matmul1_divide_maximum_minimum(sch: tir.Schedule): + b0 = sch.get_block(name="NT_matmul", func_name="main") + sch.pad_einsum(b0, [1, 1, 32, 32, 1]) + l1, l2, l3, l4, l5 = sch.get_loops(b0) + l6, l7 = sch.split(l3, [None, 32]) + l8, l9 = sch.split(l4, [None, 32]) + sch.reorder(l6, l8, l1, l2, l7, l9, l5) + + b1 = sch.get_block(name="T_divide", func_name="main") + b2 = sch.get_block(name="T_maximum", func_name="main") + b3 = sch.get_block(name="T_minimum", func_name="main") + b4 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, _, l5, l6, l7, l8, l9 = sch.get_loops(block=b0) + v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l15, l16, l17, l18, l19 = sch.split(loop=l5, factors=[v10, v11, v12, v13, v14], preserve_unit_iters=True) + v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[16, 1, 2, 1, 1]) + l25, l26, l27, l28, l29 = sch.split(loop=l6, factors=[v20, v21, v22, v23, v24], preserve_unit_iters=True) + v30, v31, v32, v33, v34 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[8, 1, 4, 1, 4]) + l35, l36, l37, l38, l39 = sch.split(loop=l7, factors=[v30, v31, v32, v33, v34], preserve_unit_iters=True) + v40, v41, v42, v43, v44 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[8, 1, 4, 2, 2]) + l45, l46, l47, l48, l49 = sch.split(loop=l8, factors=[v40, v41, v42, v43, v44], preserve_unit_iters=True) + v50, v51, v52 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64, decision=[10, 4, 2]) + l53, l54, l55 = sch.split(loop=l9, factors=[v50, v51, v52], preserve_unit_iters=True) + sch.reorder(l15, l25, l35, l45, l16, l26, l36, l46, l17, l27, l37, l47, l53, l54, l18, l28, l38, l48, l55, l19, l29, l39, l49) + l56 = sch.fuse(l15, l25, l35, l45, preserve_unit_iters=True) + sch.bind(loop=l56, thread_axis="blockIdx.x") + l57 = sch.fuse(l16, l26, 
l36, l46, preserve_unit_iters=True) + sch.bind(loop=l57, thread_axis="vthread.x") + l58 = sch.fuse(l17, l27, l37, l47, preserve_unit_iters=True) + sch.bind(loop=l58, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b59 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b59, loop=l58, preserve_unit_loops=True, index=-1) + b60 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b60, loop=l53, preserve_unit_loops=True, index=-1) + l65, l66, l67, l68 = sch.get_loops(block=b60)[-4:] + sch.fuse(l65, l66, l67, l68, preserve_unit_iters=True) + v70 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=3) + sch.annotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch", ann_val=v70) + b71 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b71, loop=l53, preserve_unit_loops=True, index=-1) + l76, l77, l78, l79 = sch.get_loops(block=b71)[-4:] + sch.fuse(l76, l77, l78, l79, preserve_unit_iters=True) + v81 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch", ann_val=v81) + sch.reverse_compute_inline(block=b3) + sch.reverse_compute_inline(block=b2) + sch.reverse_compute_inline(block=b1) + v82 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=1) + sch.annotate(block_or_loop=b4, ann_key="meta_schedule.unroll_explicit", ann_val=v82) + sch.enter_postproc() + sch.unannotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch") + l87 = sch.get_loops(block=b60)[-1] + _, l89, l90 = sch.split(loop=l87, factors=[None, 32, 4], preserve_unit_iters=True) + sch.vectorize(loop=l90) + sch.bind(loop=l89, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch") + l95 = sch.get_loops(block=b71)[-1] + _, l97 = sch.split(loop=l95, factors=[None, 32], preserve_unit_iters=True) + sch.bind(loop=l97, thread_axis="threadIdx.x") + b98 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b98, ann_key="meta_schedule.unroll_explicit") + + b1 = sch.get_block("lv34_pad") + sch.compute_inline(b1) + b1 = sch.get_block("lv35_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + + b140 = sch.get_block(name="NT_matmul", func_name="main") + l144 = sch.get_loops(block=b140)[5] + sch.decompose_reduction(block=b140, loop=l144) + + b101 = sch.get_child_blocks(b98)[2] + l116 = sch.get_loops(block=b101)[0] + sch.annotate(block_or_loop=l116, ann_key="pragma_auto_unroll_max_step", ann_val=16) + sch.annotate(block_or_loop=l116, ann_key="pragma_unroll_explicit", ann_val=1) + + +def fused_NT_matmul2_add2_gelu(sch: tir.Schedule): + b0 = sch.get_block(name="NT_matmul", func_name="main") + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + + b1 = sch.get_block(name="T_add", func_name="main") + b2 = sch.get_block(name="T_multiply", func_name="main") + b3 = 
sch.get_block(name="compute", func_name="main") + b4 = sch.get_block(name="T_multiply_1", func_name="main") + b5 = sch.get_block(name="T_add_1", func_name="main") + b6 = sch.get_block(name="T_multiply_2", func_name="main") + b7 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l8, l9, l10, l11 = sch.get_loops(block=b0) + v12, v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l17, l18, l19, l20, l21 = sch.split(loop=l8, factors=[v12, v13, v14, v15, v16], preserve_unit_iters=True) + v22, v23, v24, v25, v26 = sch.sample_perfect_tile(loop=l9, n=5, max_innermost_factor=64, decision=[1, 2, 16, 2, 2]) + l27, l28, l29, l30, l31 = sch.split(loop=l9, factors=[v22, v23, v24, v25, v26], preserve_unit_iters=True) + v32, v33, v34, v35, v36 = sch.sample_perfect_tile(loop=l10, n=5, max_innermost_factor=64, decision=[320, 1, 8, 4, 1]) + l37, l38, l39, l40, l41 = sch.split(loop=l10, factors=[v32, v33, v34, v35, v36], preserve_unit_iters=True) + v42, v43, v44 = sch.sample_perfect_tile(loop=l11, n=3, max_innermost_factor=64, decision=[160, 4, 4]) + l45, l46, l47 = sch.split(loop=l11, factors=[v42, v43, v44], preserve_unit_iters=True) + sch.reorder(l17, l27, l37, l18, l28, l38, l19, l29, l39, l45, l46, l20, l30, l40, l47, l21, l31, l41) + l48 = sch.fuse(l17, l27, l37, preserve_unit_iters=True) + sch.bind(loop=l48, thread_axis="blockIdx.x") + l49 = sch.fuse(l18, l28, l38, preserve_unit_iters=True) + sch.bind(loop=l49, thread_axis="vthread.x") + l50 = sch.fuse(l19, l29, l39, preserve_unit_iters=True) + sch.bind(loop=l50, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) + b51 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b51, loop=l50, preserve_unit_loops=True, index=-1) + b52 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b52, loop=l45, preserve_unit_loops=True, index=-1) + l57, l58, l59 = sch.get_loops(block=b52)[-3:] + sch.fuse(l57, l58, l59, preserve_unit_iters=True) + v61 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b52, ann_key="meta_schedule.cooperative_fetch", ann_val=v61) + b62 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b62, loop=l45, preserve_unit_loops=True, index=-1) + l67, l68 = sch.get_loops(block=b62)[-2:] + sch.fuse(l67, l68, preserve_unit_iters=True) + v70 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b62, ann_key="meta_schedule.cooperative_fetch", ann_val=v70) + sch.compute_inline(block=b5) + sch.compute_inline(block=b4) + sch.compute_inline(block=b3) + sch.compute_inline(block=b2) + sch.compute_inline(block=b1) + sch.reverse_compute_inline(block=b6) + v71 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=1) + sch.annotate(block_or_loop=b7, ann_key="meta_schedule.unroll_explicit", ann_val=v71) + sch.enter_postproc() + sch.unannotate(block_or_loop=b52, 
ann_key="meta_schedule.cooperative_fetch") + l76 = sch.get_loops(block=b52)[-1] + _, l78, l79 = sch.split(loop=l76, factors=[None, 64, 2], preserve_unit_iters=True) + sch.vectorize(loop=l79) + sch.bind(loop=l78, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b62, ann_key="meta_schedule.cooperative_fetch") + l84 = sch.get_loops(block=b62)[-1] + _, l86 = sch.split(loop=l84, factors=[None, 64], preserve_unit_iters=True) + sch.bind(loop=l86, thread_axis="threadIdx.x") + b87 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b87, ann_key="meta_schedule.unroll_explicit") + + b1 = sch.get_block("lv51_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + + _, _, b90, _ = sch.get_child_blocks(b87) + l105 = sch.get_loops(block=b90)[0] + sch.annotate(block_or_loop=l105, ann_key="pragma_auto_unroll_max_step", ann_val=16) + sch.annotate(block_or_loop=l105, ann_key="pragma_unroll_explicit", ann_val=1) + b123 = sch.get_block(name="NT_matmul", func_name="main") + l127 = sch.get_loops(block=b123)[4] + sch.decompose_reduction(block=b123, loop=l127) + + +def fused_NT_matmul3_add_cast_add1(sch: tir.Schedule): + b0 = sch.get_block(name="NT_matmul", func_name="main") + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + b1 = sch.get_block(name="T_add", func_name="main") + b2 = sch.get_block(name="compute", func_name="main") + b3 = sch.get_block(name="T_add_1", func_name="main") + b4 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l5, l6, l7, l8 = sch.get_loops(block=b0) + v9, v10, v11, v12, v13 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l14, l15, l16, l17, l18 = sch.split(loop=l5, factors=[v9, v10, v11, v12, v13], preserve_unit_iters=True) + v19, v20, v21, v22, v23 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[1, 4, 32, 1, 1]) + l24, l25, l26, l27, l28 = sch.split(loop=l6, factors=[v19, v20, v21, v22, v23], preserve_unit_iters=True) + v29, v30, v31, v32, v33 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[40, 1, 4, 16, 1]) + l34, l35, l36, l37, l38 = sch.split(loop=l7, factors=[v29, v30, v31, v32, v33], preserve_unit_iters=True) + v39, v40, v41 = sch.sample_perfect_tile(loop=l8, n=3, max_innermost_factor=64, decision=[640, 4, 4]) + l42, l43, l44 = sch.split(loop=l8, factors=[v39, v40, v41], preserve_unit_iters=True) + sch.reorder(l14, l24, l34, l15, l25, l35, l16, l26, l36, l42, l43, l17, l27, l37, l44, l18, l28, l38) + l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) + sch.bind(loop=l45, thread_axis="blockIdx.x") + l46 = sch.fuse(l15, l25, l35, preserve_unit_iters=True) + sch.bind(loop=l46, thread_axis="vthread.x") + l47 = sch.fuse(l16, l26, l36, preserve_unit_iters=True) + sch.bind(loop=l47, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b48 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b48, loop=l47, preserve_unit_loops=True, index=-1) + b49 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b49, loop=l42, 
preserve_unit_loops=True, index=-1) + l54, l55, l56 = sch.get_loops(block=b49)[-3:] + sch.fuse(l54, l55, l56, preserve_unit_iters=True) + v58 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b49, ann_key="meta_schedule.cooperative_fetch", ann_val=v58) + b59 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b59, loop=l42, preserve_unit_loops=True, index=-1) + l64, l65 = sch.get_loops(block=b59)[-2:] + sch.fuse(l64, l65, preserve_unit_iters=True) + v67 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b59, ann_key="meta_schedule.cooperative_fetch", ann_val=v67) + sch.reverse_compute_inline(block=b3) + sch.reverse_compute_inline(block=b2) + sch.reverse_compute_inline(block=b1) + v68 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) + sch.annotate(block_or_loop=b4, ann_key="meta_schedule.unroll_explicit", ann_val=v68) + sch.enter_postproc() + sch.unannotate(block_or_loop=b49, ann_key="meta_schedule.cooperative_fetch") + l73 = sch.get_loops(block=b49)[-1] + _, l75, l76 = sch.split(loop=l73, factors=[None, 128, 2], preserve_unit_iters=True) + sch.vectorize(loop=l76) + sch.bind(loop=l75, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b59, ann_key="meta_schedule.cooperative_fetch") + l81 = sch.get_loops(block=b59)[-1] + _, l83, l84 = sch.split(loop=l81, factors=[None, 128, 2], preserve_unit_iters=True) + sch.vectorize(loop=l84) + sch.bind(loop=l83, thread_axis="threadIdx.x") + b85 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b85, ann_key="meta_schedule.unroll_explicit") + + b1 = sch.get_block("lv56_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + + _, _, b88, _ = sch.get_child_blocks(b85) + l104 = sch.get_loops(block=b88)[0] + sch.annotate(block_or_loop=l104, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l104, ann_key="pragma_unroll_explicit", ann_val=1) + b121 = sch.get_block(name="NT_matmul", func_name="main") + l125 = sch.get_loops(block=b121)[4] + sch.decompose_reduction(block=b121, loop=l125) + + +def fused_NT_matmul4_divide2_maximum1_minimum1(sch: tir.Schedule): + b0 = sch.get_block(name="NT_matmul", func_name="main") + sch.pad_einsum(b0, [1, 1, 1, 32, 1]) + l1, l2, l3, l4, l5 = sch.get_loops(b0) + l6, l7 = sch.split(l4, [None, 32]) + sch.reorder(l6, l1, l2, l3, l7, l5) + + b1 = sch.get_block(name="T_divide", func_name="main") + b2 = sch.get_block(name="T_maximum", func_name="main") + b3 = sch.get_block(name="T_minimum", func_name="main") + b4 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l5, l6, l7, l8, l9 = sch.get_loops(block=b0) + v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l15, l16, l17, l18, l19 = sch.split(loop=l5, factors=[v10, v11, v12, v13, v14], preserve_unit_iters=True) + v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[16, 2, 1, 1, 1]) + l25, l26, l27, l28, l29 = sch.split(loop=l6, factors=[v20, v21, v22, v23, v24], preserve_unit_iters=True) + v30, v31, v32, v33, v34 = 
sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l35, l36, l37, l38, l39 = sch.split(loop=l7, factors=[v30, v31, v32, v33, v34], preserve_unit_iters=True) + v40, v41, v42, v43, v44 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[2, 1, 32, 1, 2]) + l45, l46, l47, l48, l49 = sch.split(loop=l8, factors=[v40, v41, v42, v43, v44], preserve_unit_iters=True) + v50, v51, v52 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64, decision=[20, 2, 2]) + l53, l54, l55 = sch.split(loop=l9, factors=[v50, v51, v52], preserve_unit_iters=True) + sch.reorder(l15, l25, l35, l45, l16, l26, l36, l46, l17, l27, l37, l47, l53, l54, l18, l28, l38, l48, l55, l19, l29, l39, l49) + l56 = sch.fuse(l15, l25, l35, l45, preserve_unit_iters=True) + sch.bind(loop=l56, thread_axis="blockIdx.x") + l57 = sch.fuse(l16, l26, l36, l46, preserve_unit_iters=True) + sch.bind(loop=l57, thread_axis="vthread.x") + l58 = sch.fuse(l17, l27, l37, l47, preserve_unit_iters=True) + sch.bind(loop=l58, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b59 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b59, loop=l58, preserve_unit_loops=True, index=-1) + b60 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b60, loop=l53, preserve_unit_loops=True, index=-1) + l65, l66, l67, l68 = sch.get_loops(block=b60)[-4:] + sch.fuse(l65, l66, l67, l68, preserve_unit_iters=True) + v70 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch", ann_val=v70) + b71 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b71, loop=l53, preserve_unit_loops=True, index=-1) + l76, l77, l78, l79 = sch.get_loops(block=b71)[-4:] + sch.fuse(l76, l77, l78, l79, preserve_unit_iters=True) + v81 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=0) + sch.annotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch", ann_val=v81) + sch.reverse_compute_inline(block=b3) + sch.reverse_compute_inline(block=b2) + sch.reverse_compute_inline(block=b1) + v82 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=0) + sch.annotate(block_or_loop=b4, ann_key="meta_schedule.unroll_explicit", ann_val=v82) + sch.enter_postproc() + sch.unannotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch") + l87 = sch.get_loops(block=b60)[-1] + _, l89, l90 = sch.split(loop=l87, factors=[None, 16, 2], preserve_unit_iters=True) + sch.vectorize(loop=l90) + sch.bind(loop=l89, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch") + l95 = sch.get_loops(block=b71)[-1] + _, l97 = sch.split(loop=l95, factors=[None, 16], preserve_unit_iters=True) + sch.bind(loop=l97, thread_axis="threadIdx.x") + b98 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b98, ann_key="meta_schedule.unroll_explicit") + + b1 = sch.get_block("lv1836_pad") + sch.compute_inline(b1) + b2 = 
sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + + b140 = sch.get_block(name="NT_matmul", func_name="main") + l144 = sch.get_loops(block=b140)[4] + sch.decompose_reduction(block=b140, loop=l144) + + +def fused_NT_matmul_add(sch: tir.Schedule): + b0 = sch.get_block(name="NT_matmul", func_name="main") + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + + b1 = sch.get_block(name="T_add", func_name="main") + b2 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l3, l4, l5, l6 = sch.get_loops(block=b0) + v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l12, l13, l14, l15, l16 = sch.split(loop=l3, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) + v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[2, 4, 8, 1, 2]) + l22, l23, l24, l25, l26 = sch.split(loop=l4, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) + v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[64, 5, 8, 1, 1]) + l32, l33, l34, l35, l36 = sch.split(loop=l5, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) + v37, v38, v39 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[320, 2, 4]) + l40, l41, l42 = sch.split(loop=l6, factors=[v37, v38, v39], preserve_unit_iters=True) + sch.reorder(l12, l22, l32, l13, l23, l33, l14, l24, l34, l40, l41, l15, l25, l35, l42, l16, l26, l36) + l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) + sch.bind(loop=l43, thread_axis="blockIdx.x") + l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) + sch.bind(loop=l44, thread_axis="vthread.x") + l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) + sch.bind(loop=l45, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b46 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b46, loop=l45, preserve_unit_loops=True, index=-1) + b47 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b47, loop=l40, preserve_unit_loops=True, index=-1) + l52, l53, l54 = sch.get_loops(block=b47)[-3:] + sch.fuse(l52, l53, l54, preserve_unit_iters=True) + v56 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch", ann_val=v56) + b57 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b57, loop=l40, preserve_unit_loops=True, index=-1) + l62, l63 = sch.get_loops(block=b57)[-2:] + sch.fuse(l62, l63, preserve_unit_iters=True) + v65 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=0) + sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v65) + sch.reverse_compute_inline(block=b1) + v66 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=3) + sch.annotate(block_or_loop=b2, 
ann_key="meta_schedule.unroll_explicit", ann_val=v66) + sch.enter_postproc() + sch.unannotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch") + l71 = sch.get_loops(block=b47)[-1] + _, l73, l74 = sch.split(loop=l71, factors=[None, 64, 2], preserve_unit_iters=True) + sch.vectorize(loop=l74) + sch.bind(loop=l73, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") + l79 = sch.get_loops(block=b57)[-1] + _, l81 = sch.split(loop=l79, factors=[None, 64], preserve_unit_iters=True) + sch.bind(loop=l81, thread_axis="threadIdx.x") + b82 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b82, ann_key="meta_schedule.unroll_explicit") + + b1 = sch.get_block("lv7_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + + _, _, b85, _ = sch.get_child_blocks(b82) + l100 = sch.get_loops(block=b85)[0] + sch.annotate(block_or_loop=l100, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l100, ann_key="pragma_unroll_explicit", ann_val=1) + b118 = sch.get_block(name="NT_matmul", func_name="main") + l122 = sch.get_loops(block=b118)[4] + sch.decompose_reduction(block=b118, loop=l122) + + +def fused_NT_matmul_add_add1(sch: tir.Schedule): + b0 = sch.get_block(name="NT_matmul", func_name="main") + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + + b1 = sch.get_block(name="T_add", func_name="main") + b2 = sch.get_block(name="T_add_1", func_name="main") + b3 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l4, l5, l6, l7 = sch.get_loops(block=b0) + v8, v9, v10, v11, v12 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l13, l14, l15, l16, l17 = sch.split(loop=l4, factors=[v8, v9, v10, v11, v12], preserve_unit_iters=True) + v18, v19, v20, v21, v22 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[2, 2, 32, 1, 1]) + l23, l24, l25, l26, l27 = sch.split(loop=l5, factors=[v18, v19, v20, v21, v22], preserve_unit_iters=True) + v28, v29, v30, v31, v32 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[80, 2, 1, 16, 1]) + l33, l34, l35, l36, l37 = sch.split(loop=l6, factors=[v28, v29, v30, v31, v32], preserve_unit_iters=True) + v38, v39, v40 = sch.sample_perfect_tile(loop=l7, n=3, max_innermost_factor=64, decision=[320, 1, 8]) + l41, l42, l43 = sch.split(loop=l7, factors=[v38, v39, v40], preserve_unit_iters=True) + sch.reorder(l13, l23, l33, l14, l24, l34, l15, l25, l35, l41, l42, l16, l26, l36, l43, l17, l27, l37) + l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) + sch.bind(loop=l44, thread_axis="blockIdx.x") + l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) + sch.bind(loop=l45, thread_axis="vthread.x") + l46 = sch.fuse(l15, l25, l35, preserve_unit_iters=True) + sch.bind(loop=l46, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b47 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b47, loop=l46, preserve_unit_loops=True, index=-1) + b48 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", 
consumer_blocks=[b0]) + sch.compute_at(block=b48, loop=l41, preserve_unit_loops=True, index=-1) + l53, l54, l55 = sch.get_loops(block=b48)[-3:] + sch.fuse(l53, l54, l55, preserve_unit_iters=True) + v57 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=3) + sch.annotate(block_or_loop=b48, ann_key="meta_schedule.cooperative_fetch", ann_val=v57) + b58 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b58, loop=l41, preserve_unit_loops=True, index=-1) + l63, l64 = sch.get_loops(block=b58)[-2:] + sch.fuse(l63, l64, preserve_unit_iters=True) + v66 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=3) + sch.annotate(block_or_loop=b58, ann_key="meta_schedule.cooperative_fetch", ann_val=v66) + sch.reverse_compute_inline(block=b2) + sch.reverse_compute_inline(block=b1) + v67 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) + sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v67) + sch.enter_postproc() + sch.unannotate(block_or_loop=b48, ann_key="meta_schedule.cooperative_fetch") + l72 = sch.get_loops(block=b48)[-1] + _, l74, l75 = sch.split(loop=l72, factors=[None, 32, 4], preserve_unit_iters=True) + sch.vectorize(loop=l75) + sch.bind(loop=l74, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b58, ann_key="meta_schedule.cooperative_fetch") + l80 = sch.get_loops(block=b58)[-1] + _, l82, l83 = sch.split(loop=l80, factors=[None, 32, 4], preserve_unit_iters=True) + sch.vectorize(loop=l83) + sch.bind(loop=l82, thread_axis="threadIdx.x") + b84 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b84, ann_key="meta_schedule.unroll_explicit") + + b1 = sch.get_block("lv45_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + + _, _, b87, _ = sch.get_child_blocks(b84) + l103 = sch.get_loops(block=b87)[0] + sch.annotate(block_or_loop=l103, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l103, ann_key="pragma_unroll_explicit", ann_val=1) + b121 = sch.get_block(name="NT_matmul", func_name="main") + l125 = sch.get_loops(block=b121)[4] + sch.decompose_reduction(block=b121, loop=l125) + + + +def layer_norm(sch: tir.Schedule): + b0 = sch.get_block(name="A_red_temp", func_name="main") + b1 = sch.get_block(name="T_layer_norm", func_name="main") + b2 = sch.get_block(name="root", func_name="main") + v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125], decision=4) + _, _, l6 = sch.get_loops(block=b0) + _, l8 = sch.split(loop=l6, factors=[None, v3], preserve_unit_iters=True) + sch.bind(loop=l8, thread_axis="threadIdx.x") + v9 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=3) + sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v9) + l10, l11, l12 = sch.get_loops(block=b1) + l13 = sch.fuse(l10, l11, l12, preserve_unit_iters=True) + l14, l15, l16 = sch.split(loop=l13, factors=[None, 256, 256], preserve_unit_iters=True) + sch.reorder(l15, l16, l14) + sch.bind(loop=l15, thread_axis="blockIdx.x") + sch.bind(loop=l16, thread_axis="threadIdx.x") 
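+ # Descriptive note: at this point the elementwise T_layer_norm block has been flattened, split into a 256-block x 256-thread layout (with a leftover inner serial loop), and bound to blockIdx.x/threadIdx.x. + # The reduction block A_red_temp is handled next: its outer batch and sequence loops are fused and bound to blockIdx.x, while its inner reduction loop was already split and bound to threadIdx.x above.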
+ l17, l18, _, _ = sch.get_loops(block=b0) + l21 = sch.fuse(l17, l18, preserve_unit_iters=True) + sch.bind(loop=l21, thread_axis="blockIdx.x") + sch.enter_postproc() + b22 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b22, ann_key="meta_schedule.unroll_explicit") + b23, _ = sch.get_child_blocks(b22) + l25, _, _ = sch.get_loops(block=b23) + sch.annotate(block_or_loop=l25, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l25, ann_key="pragma_unroll_explicit", ann_val=1) + + +def matmul(sch: tir.Schedule): + b0 = sch.get_block(name="matmul", func_name="main") + sch.pad_einsum(b0, [1, 1, 32, 1, 32]) + l1, l2, l3, l4, k = sch.get_loops(b0) + s0, s1 = sch.split(l3, [None, 32]) + k0, k1 = sch.split(k, [None, 32]) + sch.reorder(s0, l1, l2, s1, k0, l4, k1) + + b1 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l2, l3, l4, _, l5, l6 = sch.get_loops(block=b0) + v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l12, l13, l14, l15, l16 = sch.split(loop=l2, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) + v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[8, 4, 1, 1, 1]) + l22, l23, l24, l25, l26 = sch.split(loop=l3, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) + v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[16, 4, 2, 1, 1]) + l32, l33, l34, l35, l36 = sch.split(loop=l4, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) + v37, v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[1, 1, 80, 1, 1]) + l42, l43, l44, l45, l46 = sch.split(loop=l5, factors=[v37, v38, v39, v40, v41], preserve_unit_iters=True) + v47, v48, v49 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[8, 4, 1]) + l50, l51, l52 = sch.split(loop=l6, factors=[v47, v48, v49], preserve_unit_iters=True) + sch.reorder(l12, l22, l32, l42, l13, l23, l33, l43, l14, l24, l34, l44, k0, l50, l51, l15, l25, l35, l45, l52, l16, l26, l36, l46) + + l53 = sch.fuse(l12, l22, l32, l42, preserve_unit_iters=True) + sch.bind(loop=l53, thread_axis="blockIdx.x") + l54 = sch.fuse(l13, l23, l33, l43, preserve_unit_iters=True) + sch.bind(loop=l54, thread_axis="vthread.x") + l55 = sch.fuse(l14, l24, l34, l44, preserve_unit_iters=True) + sch.bind(loop=l55, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b56 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b56, loop=l55, preserve_unit_loops=True, index=-1) + b57 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b57, loop=l50, preserve_unit_loops=True, index=-1) + l62, l63, l64, l65 = sch.get_loops(block=b57)[-4:] + sch.fuse(l62, l63, l64, l65, preserve_unit_iters=True) + v67 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=3) + sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v67) + b68 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b68, loop=l50, 
preserve_unit_loops=True, index=-1) + l73, l74, l75, l76 = sch.get_loops(block=b68)[-4:] + sch.fuse(l73, l74, l75, l76, preserve_unit_iters=True) + v78 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch", ann_val=v78) + v79 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v79) + sch.enter_postproc() + sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") + l84 = sch.get_loops(block=b57)[-1] + _, l86, l87 = sch.split(loop=l84, factors=[None, 160, 4], preserve_unit_iters=True) + sch.vectorize(loop=l87) + sch.bind(loop=l86, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch") + l92 = sch.get_loops(block=b68)[-1] + _, l94, l95 = sch.split(loop=l92, factors=[None, 160, 2], preserve_unit_iters=True) + sch.vectorize(loop=l95) + sch.bind(loop=l94, thread_axis="threadIdx.x") + b96 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b96, ann_key="meta_schedule.unroll_explicit") + + b1 = sch.get_block("A_pad") + sch.compute_inline(b1) + b1 = sch.get_block("B_pad") + sch.compute_inline(b1) + b1 = sch.get_block("matmul_1_pad") + sch.reverse_compute_inline(b1) + + _, _, b99, _ = sch.get_child_blocks(b96) + l115 = sch.get_loops(block=b99)[0] + sch.annotate(block_or_loop=l115, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l115, ann_key="pragma_unroll_explicit", ann_val=1) + b136 = sch.get_block(name="matmul", func_name="main") + l140 = sch.get_loops(block=b136)[4] + sch.decompose_reduction(block=b136, loop=l140) + + + +def matmul8(sch: tir.Schedule): + b0 = sch.get_block(name="matmul", func_name="main") + sch.pad_einsum(b0, [1, 1, 1, 1, 32]) + l1, l2, l3, l4, k = sch.get_loops(b0) + k0, k1 = sch.split(k, [None, 32]) + sch.reorder(l1, l2, l3, k0, l4, k1) + + b1 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + l2, l3, l4, _, l5, l6 = sch.get_loops(block=b0) + v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l12, l13, l14, l15, l16 = sch.split(loop=l2, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) + v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[16, 1, 2, 1, 1]) + l22, l23, l24, l25, l26 = sch.split(loop=l3, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) + v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l32, l33, l34, l35, l36 = sch.split(loop=l4, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) + v37, v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[2, 1, 40, 1, 1]) + l42, l43, l44, l45, l46 = sch.split(loop=l5, factors=[v37, v38, v39, v40, v41], preserve_unit_iters=True) + v47, v48, v49 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[8, 2, 2]) + l50, l51, l52 = sch.split(loop=l6, factors=[v47, v48, v49], preserve_unit_iters=True) + sch.reorder(l12, l22, l32, l42, l13, l23, l33, l43, l14, l24, l34, l44, k0, l50, l51, l15, l25, l35, l45, l52, l16, l26, l36, l46) 
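+ # Descriptive note: after this reorder, the outermost tile of each spatial axis is fused and bound to blockIdx.x, the second level to vthread.x, and the third to threadIdx.x, following the SSSRRSRS tiling structure annotated on the matmul block above.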
+ + l53 = sch.fuse(l12, l22, l32, l42, preserve_unit_iters=True) + sch.bind(loop=l53, thread_axis="blockIdx.x") + l54 = sch.fuse(l13, l23, l33, l43, preserve_unit_iters=True) + sch.bind(loop=l54, thread_axis="vthread.x") + l55 = sch.fuse(l14, l24, l34, l44, preserve_unit_iters=True) + sch.bind(loop=l55, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b56 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b56, loop=l55, preserve_unit_loops=True, index=-1) + b57 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b57, loop=l50, preserve_unit_loops=True, index=-1) + l62, l63, l64, l65 = sch.get_loops(block=b57)[-4:] + sch.fuse(l62, l63, l64, l65, preserve_unit_iters=True) + v67 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v67) + b68 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b68, loop=l50, preserve_unit_loops=True, index=-1) + l73, l74, l75, l76 = sch.get_loops(block=b68)[-4:] + sch.fuse(l73, l74, l75, l76, preserve_unit_iters=True) + v78 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch", ann_val=v78) + v79 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=0) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v79) + sch.enter_postproc() + sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") + l84 = sch.get_loops(block=b57)[-1] + _, l86 = sch.split(loop=l84, factors=[None, 80], preserve_unit_iters=True) + sch.bind(loop=l86, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch") + l91 = sch.get_loops(block=b68)[-1] + _, l93 = sch.split(loop=l91, factors=[None, 80], preserve_unit_iters=True) + sch.bind(loop=l93, thread_axis="threadIdx.x") + b94 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b94, ann_key="meta_schedule.unroll_explicit") + + b1 = sch.get_block("A_pad") + sch.compute_inline(b1) + b1 = sch.get_block("B_pad") + sch.compute_inline(b1) + + b132 = sch.get_block(name="matmul", func_name="main") + l136 = sch.get_loops(block=b132)[3] + sch.decompose_reduction(block=b132, loop=l136) + + +@T.prim_func +def softmax_mxn_before(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + m = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, m)) + T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, m)) + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) + T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + 
T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], rxplaceholder[v_i0, v_i1, v_i2, v_k]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_exp"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) + T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) + T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(rxplaceholder[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_expsum"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) + T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_norm"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) + T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"axis": 3}) + T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] + + +@T.prim_func +def softmax_mxn_after(var_A: T.handle, var_T_softmax_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + m = T.int64() + A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m)) + T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, m)) + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) + for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): + with T.block("T_softmax_maxelem_o"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) + T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(m + T.int64(127)) // T.int64(128) * T.int64(128)]) + T.writes(T_softmax_maxelem[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) + T_softmax_maxelem_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared") + for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): + for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_maxelem"): + v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) + v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) + T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i]) + T.writes(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) + with T.init(): + T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.max(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i], T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, A[v_i0, v_i1_i, v_i2_o * T.int64(32) + 
v_i2_i, v_k_i], T.float32(-3.4028234663852886e+38))) + for i0_i1_i2_1_fused_0 in range(T.int64(8)): + for i0_i1_i2_1_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_maxelem_cache_write"): + v_i1_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) // T.int64(32)) + v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32)) + T.where(v_i2_o * T.int64(32) + (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32) < n) + T.reads(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) + T.writes(T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) + T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i] = T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] + for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): + with T.block("T_softmax_expsum_o"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) + T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(m + T.int64(127)) // T.int64(128) * T.int64(128)], T_softmax_maxelem[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) + T.writes(T_softmax_expsum[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) + T_softmax_expsum_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared") + for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): + for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_expsum"): + v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) + v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) + T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) + T.writes(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) + with T.init(): + T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float32(0) + T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] + T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, T.exp(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i] - T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]), T.float32(0)) + for i0_i1_i2_1_fused_0 in range(T.int64(8)): + for i0_i1_i2_1_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_expsum_cache_write"): + v_i1_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) // T.int64(32)) + v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32)) + T.where(v_i2_o * T.int64(32) + (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32) < n) + T.reads(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) + T.writes(T_softmax_expsum[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) + T_softmax_expsum[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] + for i0_i1_i2_fused_i3_fused_0 in T.thread_binding((n * T.int64(32) * m + T.int64(255)) // T.int64(256), thread="blockIdx.x"): + for i0_i1_i2_fused_i3_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("T_softmax_norm"): + v_i0 = T.axis.spatial(T.int64(1), 
T.int64(0)) + v_i1 = T.axis.spatial(T.int64(32), (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) // m // n) + v_i2 = T.axis.spatial(n, (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) // m % n) + v_i3 = T.axis.spatial(m, (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) % m) + T.where(i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1 < n * T.int64(32) * m) + T.reads(T_softmax_expsum[v_i0, v_i1, v_i2], A[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) + T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) + T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T.exp(A[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) / T_softmax_expsum[v_i0, v_i1, v_i2] + + +def fused_min_max_triu_te_broadcast_to(sch: tir.Schedule): + b0 = sch.get_block("T_broadcast_to") + sch.reverse_compute_inline(b0) + b1 = sch.get_block("make_diag_mask_te") + i, j = sch.get_loops(b1) + i = sch.fuse(i, j) + i, j = sch.split(i, [None, 128]) + sch.bind(i, "blockIdx.x") + sch.bind(j, "threadIdx.x") + +def softmax_1xn(sch: tir.Schedule): + has_cast = False + if has_cast: + b_cast = sch.get_block("compute") + sch.reverse_compute_inline(b_cast) + + b0 = sch.get_block("T_softmax_exp") + sch.compute_inline(b0) + b1 = sch.get_block("T_softmax_norm") + l2, l3, l4, l5 = sch.get_loops(b1) + _, l7 = sch.split(l5, [None, 128]) + sch.bind(l7, "threadIdx.x") + b8 = sch.get_block("T_softmax_expsum") + sch.compute_at(b8, l4) + sch.set_scope(b8, 0, "shared") + _, _, _, l12 = sch.get_loops(b8) + _, l14 = sch.split(l12, [None, 128]) + sch.bind(l14, "threadIdx.x") + b15 = sch.get_block("T_softmax_maxelem") + sch.compute_at(b15, l4) + sch.set_scope(b15, 0, "shared") + _, _, _, l19 = sch.get_loops(b15) + _, l21 = sch.split(l19, [None, 128]) + sch.bind(l21, "threadIdx.x") + l22 = sch.fuse(l2, l3, l4) + sch.bind(l22, "blockIdx.x") + +def _get_dict(): + tvm.ir.assert_structural_equal(MOD["softmax"], softmax_mxn_before) + func_dict = { + softmax_mxn_before: softmax_mxn_after, + } + for name, func in [ + # fmt: off + ("fused_NT_matmul1_divide_maximum_minimum", fused_NT_matmul1_divide_maximum_minimum), + ("fused_NT_matmul2_add2_gelu", fused_NT_matmul2_add2_gelu), + ("fused_NT_matmul3_add_cast_add1", fused_NT_matmul3_add_cast_add1), + ("fused_NT_matmul4_divide2_maximum1_minimum1", fused_NT_matmul4_divide2_maximum1_minimum1), + ("fused_NT_matmul_add", fused_NT_matmul_add), + ("fused_NT_matmul_add_add1", fused_NT_matmul_add_add1), + ("layer_norm", layer_norm), + ("matmul", matmul), + ("matmul8", matmul8), + ("softmax2", softmax_1xn), + ("fused_min_max_triu_te_broadcast_to", fused_min_max_triu_te_broadcast_to), + # fmt: on + ]: + # print(f"############### {name} ###############") + sch = tir.Schedule(MOD[name]) + func(sch) + # sch.mod["main"].show(black_format=False) + func_dict[MOD[name]] = sch.mod["main"] + return { + (tvm.ir.structural_hash(k), k): v.with_attr("tir.is_scheduled", True) + for k, v in func_dict.items() + } + + +DICT = _get_dict() + + +def lookup(func): + for (hash_value, func_before), f_after in DICT.items(): + if tvm.ir.structural_hash(func) == hash_value and tvm.ir.structural_equal( + func, func_before + ): + return f_after + return None diff --git a/mlc_llm/dispatch/gpt_neox/redpajama_q4f32_mod.py b/mlc_llm/dispatch/gpt_neox/redpajama_q4f32_mod.py new file mode 100644 index 0000000..b6c4cbc --- /dev/null +++ b/mlc_llm/dispatch/gpt_neox/redpajama_q4f32_mod.py @@ -0,0 +1,577 @@ +# pylint: 
disable=pointless-string-statement,invalid-name,missing-docstring,line-too-long,too-many-locals,too-many-arguments,too-many-statements +from tvm.script import ir as I +from tvm.script import tir as T + +# fmt: off + +@I.ir_module +class Module: + @T.prim_func + def extend_te(var_A: T.handle, var_concat_te: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), T.int64(1), n, n)) + m = T.int64() + concat_te = T.match_buffer(var_concat_te, (T.int64(1), T.int64(1), n, m)) + # with T.block("root"): + for b, _, i, j in T.grid(T.int64(1), T.int64(1), n, m): + with T.block("concat_te"): + v_b, v__, v_i, v_j = T.axis.remap("SSSS", [b, _, i, j]) + T.reads(A[v_b, v__, v_i, v_j + n - m]) + T.writes(concat_te[v_b, v__, v_i, v_j]) + concat_te[v_b, v__, v_i, v_j] = T.if_then_else(v_j < m - n, T.float32(3.4028234663852886e+38), A[v_b, v__, v_i, v_j + n - m]) + + @T.prim_func + def full(var_T_full: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + T_full = T.match_buffer(var_T_full, (T.int64(1), T.int64(1), T.int64(1), n)) + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), n): + with T.block("T_full"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads() + T.writes(T_full[v_ax0, v_ax1, v_ax2, v_ax3]) + T_full[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(3.4028234663852886e+38) + + @T.prim_func + def fused_NT_matmul1_divide_maximum_minimum(p_lv34: T.handle, p_lv35: T.handle, p_lv5: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv34 = T.match_buffer(p_lv34, (T.int64(1), T.int64(32), n, T.int64(80))) + m = T.int64() + lv35 = T.match_buffer(p_lv35, (T.int64(1), T.int64(32), m, T.int64(80))) + lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m)) + var_T_minimum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m)) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, m, T.int64(80)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(lv34[v_i0, v_i1, v_i2, v_k], lv35[v_i0, v_i1, v_i3, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv34[v_i0, v_i1, v_i2, v_k] * lv35[v_i0, v_i1, v_i3, v_k] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_divide"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.11180339723346898) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_maximum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + 
var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float32(-3.4028234663852886e+38)) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_minimum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) + T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) + + @T.prim_func + def fused_NT_matmul2_add2_gelu(p_lv51: T.handle, lv38: T.Buffer((T.int64(10240), T.int64(2560)), "float32"), linear_bias4: T.Buffer((T.int64(10240),), "float32"), p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv51 = T.match_buffer(p_lv51, (T.int64(1), n, T.int64(2560))) + var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(10240))) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + T_multiply = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + compute = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + T_multiply_1 = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + T_add = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(10240), T.int64(2560)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv51[v_i0, v_i1, v_k], lv38[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv51[v_i0, v_i1, v_k] * lv38[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias4[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias4[v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) + T_multiply[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T.float32(0.70710678118654757) + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(T_multiply[v_i0, v_i1, v_i2]) + T.writes(compute[v_i0, v_i1, v_i2]) + compute[v_i0, v_i1, v_i2] = T.erf(T_multiply[v_i0, v_i1, v_i2]) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_multiply_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(compute[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2]) + T_multiply_1[v_ax0, v_ax1, v_ax2] = compute[v_ax0, v_ax1, v_ax2] * T.float32(0.5) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(T_multiply_1[v_ax0, 
v_ax1, v_ax2]) + T.writes(T_add[v_ax0, v_ax1, v_ax2]) + T_add[v_ax0, v_ax1, v_ax2] = T.float32(0.5) + T_multiply_1[v_ax0, v_ax1, v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_multiply_2"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2] + + @T.prim_func + def fused_NT_matmul3_add_cast_add1(p_lv56: T.handle, lv45: T.Buffer((T.int64(2560), T.int64(10240)), "float32"), linear_bias5: T.Buffer((T.int64(2560),), "float32"), p_lv49: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv56 = T.match_buffer(p_lv56, (T.int64(1), n, T.int64(10240))) + lv49 = T.match_buffer(p_lv49, (T.int64(1), n, T.int64(2560))) + var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560))) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + var_compute_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(10240)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv56[v_i0, v_i1, v_k], lv45[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv56[v_i0, v_i1, v_k] * lv45[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias5[v_ax2]) + T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias5[v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_add_intermediate_1[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = var_T_add_intermediate_1[v_i0, v_i1, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_compute_intermediate[v_ax0, v_ax1, v_ax2], lv49[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_compute_intermediate[v_ax0, v_ax1, v_ax2] + lv49[v_ax0, v_ax1, v_ax2] + + @T.prim_func + def fused_NT_matmul4_divide2_maximum1_minimum1(lv1835: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float32"), p_lv1836: T.handle, p_lv1806: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv1836 = T.match_buffer(p_lv1836, (T.int64(1), T.int64(32), n, T.int64(80))) + lv1806 = T.match_buffer(p_lv1806, (T.int64(1), T.int64(1), T.int64(1), n)) + var_T_minimum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n)) + # with T.block("root"): + var_NT_matmul_intermediate = 
T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) + var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) + var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(80)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(lv1835[v_i0, v_i1, v_i2, v_k], lv1836[v_i0, v_i1, v_i3, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1835[v_i0, v_i1, v_i2, v_k] * lv1836[v_i0, v_i1, v_i3, v_k] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_divide"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.11180339723346898) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_maximum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float32(-3.4028234663852886e+38)) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_minimum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1806[v_ax0, T.int64(0), v_ax2, v_ax3]) + T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1806[v_ax0, T.int64(0), v_ax2, v_ax3]) + + @T.prim_func + def fused_NT_matmul_add(p_lv7: T.handle, lv10: T.Buffer((T.int64(2560), T.int64(2560)), "float32"), linear_bias: T.Buffer((T.int64(2560),), "float32"), p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv7 = T.match_buffer(p_lv7, (T.int64(1), n, T.int64(2560))) + var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560))) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv7[v_i0, v_i1, v_k], lv10[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv7[v_i0, v_i1, v_k] * lv10[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + 
var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias[v_ax2] + + @T.prim_func + def fused_NT_matmul_add_add1(p_lv45: T.handle, lv31: T.Buffer((T.int64(2560), T.int64(2560)), "float32"), linear_bias3: T.Buffer((T.int64(2560),), "float32"), p_lv2: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv45 = T.match_buffer(p_lv45, (T.int64(1), n, T.int64(2560))) + lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(2560))) + var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560))) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv45[v_i0, v_i1, v_k], lv31[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv45[v_i0, v_i1, v_k] * lv31[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias3[v_ax2]) + T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias3[v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2], lv2[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] + lv2[v_ax0, v_ax1, v_ax2] + + @T.prim_func + def fused_min_max_triu_te_broadcast_to(p_output0: T.handle, n: T.int64): + T.func_attr({"tir.noalias": T.bool(True)}) + var_T_broadcast_to_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), n, n)) + # with T.block("root"): + var_make_diag_mask_te_intermediate = T.alloc_buffer((n, n)) + for i, j in T.grid(n, n): + with T.block("make_diag_mask_te"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads() + T.writes(var_make_diag_mask_te_intermediate[v_i, v_j]) + var_make_diag_mask_te_intermediate[v_i, v_j] = T.Select(v_i < v_j, T.float32(-3.4028234663852886e+38), T.float32(3.4028234663852886e+38)) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), n, n): + with T.block("T_broadcast_to"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_make_diag_mask_te_intermediate[v_ax2, v_ax3]) + T.writes(var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_make_diag_mask_te_intermediate[v_ax2, v_ax3] + + @T.prim_func + def layer_norm(var_A: T.handle, B: T.Buffer((T.int64(2560),), "float32"), C: T.Buffer((T.int64(2560),), "float32"), var_T_layer_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(2560))) + T_layer_norm = T.match_buffer(var_T_layer_norm, (T.int64(1), n, T.int64(2560))) + # with T.block("root"): + A_red_temp_v0 = 
T.alloc_buffer((T.int64(1), n)) + A_red_temp_v1 = T.alloc_buffer((T.int64(1), n)) + for ax0, ax1, k2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("A_red_temp"): + v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) + T.reads(A[v_ax0, v_ax1, v_k2]) + T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) + with T.init(): + A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) + A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) + v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] + v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] * A[v_ax0, v_ax1, v_k2] + A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 + A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_layer_norm"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(A[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1], B[v_ax2], C[v_ax2]) + T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2]) + T_layer_norm[v_ax0, v_ax1, v_ax2] = (A[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * B[v_ax2] + C[v_ax2] + + @T.prim_func + def matmul(var_A: T.handle, var_B: T.handle, var_matmul: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n, m = T.int64(), T.int64() + A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m)) + B = T.match_buffer(var_B, (T.int64(1), T.int64(32), m, T.int64(80))) + matmul_1 = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(80))) + # with T.block("root"): + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(80), m): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) + T.writes(matmul_1[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + matmul_1[v_i0, v_i1, v_i2, v_i3] = T.float32(0) + matmul_1[v_i0, v_i1, v_i2, v_i3] = matmul_1[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] + + @T.prim_func + def matmul8(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n)) + B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(80))) + # with T.block("root"): + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(80), n): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) + T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + matmul[v_i0, v_i1, v_i2, v_i3] = T.float32(0) + matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] + + @T.prim_func + def reshape(var_A: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n), "int32") + T_reshape = T.match_buffer(var_T_reshape, (n,), "int32") + # with T.block("root"): + for ax0 in range(n): + with T.block("T_reshape"): + v_ax0 = T.axis.spatial(n, ax0) + T.reads(A[T.int64(0), 
v_ax0 % n]) + T.writes(T_reshape[v_ax0]) + T_reshape[v_ax0] = A[T.int64(0), v_ax0 % n] + + @T.prim_func + def reshape1(var_A: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (n, T.int64(2560))) + T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(2560))) + # with T.block("root"): + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(A[(v_ax2 // T.int64(2560) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(2560)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) + T_reshape[v_ax0, v_ax1, v_ax2] = A[(v_ax2 // T.int64(2560) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(2560)] + + @T.prim_func + def reshape2(var_A: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(2560))) + T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(32), T.int64(80))) + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(80)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(A[T.int64(0), ((v_ax2 * T.int64(80) + v_ax3) // T.int64(2560) + v_ax0 * n + v_ax1) % n, (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[T.int64(0), ((v_ax2 * T.int64(80) + v_ax3) // T.int64(2560) + v_ax0 * n + v_ax1) % n, (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)] + + @T.prim_func + def reshape3(var_A: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + m = T.int64() + A = T.match_buffer(var_A, (m, T.int64(32), T.int64(80))) + T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), m, T.int64(32), T.int64(80))) + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), m, T.int64(32), T.int64(80)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(A[((v_ax3 // T.int64(80) + v_ax2) // T.int64(32) + v_ax0 * m + v_ax1) % m, (v_ax3 // T.int64(80) + v_ax2) % T.int64(32), v_ax3 % T.int64(80)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[((v_ax3 // T.int64(80) + v_ax2) // T.int64(32) + v_ax0 * m + v_ax1) % m, (v_ax3 // T.int64(80) + v_ax2) % T.int64(32), v_ax3 % T.int64(80)] + + @T.prim_func + def reshape4(var_A: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(80))) + T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(2560))) + # with T.block("root"): + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(A[T.int64(0), (v_ax2 // T.int64(2560) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(2560) // T.int64(80), v_ax2 % T.int64(80)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) + T_reshape[v_ax0, v_ax1, v_ax2] = A[T.int64(0), (v_ax2 // T.int64(2560) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(2560) // T.int64(80), v_ax2 % T.int64(80)] + + @T.prim_func + def rotary_embedding(var_A: T.handle, B: T.Buffer((T.int64(2048), T.int64(80)), "float32"), C: T.Buffer((T.int64(2048), T.int64(80)), "float32"), var_rotary: T.handle, m: T.int64): + T.func_attr({"tir.noalias": T.bool(True)}) + n = 
T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(80))) + rotary = T.match_buffer(var_rotary, (T.int64(1), n, T.int64(32), T.int64(80))) + # with T.block("root"): + for i_batch_size, i_seq_len, i_num_heads, i_head_dim in T.grid(T.int64(1), n, T.int64(32), T.int64(80)): + with T.block("rotary"): + v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim = T.axis.remap("SSSS", [i_batch_size, i_seq_len, i_num_heads, i_head_dim]) + T.reads(B[m + v_i_seq_len - n, v_i_head_dim], A[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim - T.int64(40):v_i_head_dim - T.int64(40) + T.int64(81)], C[m + v_i_seq_len - n, v_i_head_dim]) + T.writes(rotary[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim]) + rotary[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim] = T.Select(v_i_head_dim < T.int64(80), B[m + v_i_seq_len - n, v_i_head_dim] * A[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim] + C[m + v_i_seq_len - n, v_i_head_dim] * T.Select(v_i_head_dim < T.int64(40), A[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim + T.int64(40)] * T.float32(-1), A[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim - T.int64(40)]), A[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim]) + + @T.prim_func + def slice(var_A: T.handle, slice_1: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(2560))) + # with T.block("root"): + for i, _, k in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("slice"): + v_i, v__, v_k = T.axis.remap("SSS", [i, _, k]) + T.reads(A[v_i, n - T.int64(1), v_k]) + T.writes(slice_1[v_i, v__, v_k]) + slice_1[v_i, v__, v_k] = A[v_i, n - T.int64(1), v_k] + + @T.prim_func + def softmax(var_A: T.handle, var_T_softmax_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n, m = T.int64(), T.int64() + A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m)) + T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, m)) + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) + T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(A[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], A[v_i0, v_i1, v_i2, v_k]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_exp"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(A[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) + T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) + T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(A[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_expsum"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) + T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + 
T_softmax_exp[v_i0, v_i1, v_i2, v_k] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_norm"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) + T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"axis": 3}) + T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] + + @T.prim_func + def softmax2(var_A: T.handle, var_T_softmax_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n)) + T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), T.int64(1), n)) + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) + T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(A[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], A[v_i0, v_i1, v_i2, v_k]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_exp"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(A[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) + T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) + T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(A[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_expsum"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) + T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_norm"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) + T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"axis": 3}) + T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] + + @T.prim_func + def squeeze(var_A: T.handle, var_T_squeeze: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(80))) + T_squeeze = T.match_buffer(var_T_squeeze, (n, T.int64(32), T.int64(80))) + # with T.block("root"): + for ax0, ax1, ax2 in T.grid(n, T.int64(32), T.int64(80)): + with T.block("T_squeeze"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(A[T.int64(0), v_ax0, v_ax1, v_ax2]) + T.writes(T_squeeze[v_ax0, v_ax1, v_ax2]) + T_squeeze[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax0, v_ax1, v_ax2] + + @T.prim_func + def take_decode(A: T.Buffer((T.int64(50432), T.int64(320)), "uint32"), B: T.Buffer((T.int64(50432), 
T.int64(80)), "uint32"), var_C: T.handle, var_take_decode: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + C = T.match_buffer(var_C, (n,), "int32") + take_decode_1 = T.match_buffer(var_take_decode, (n, T.int64(2560))) + # with T.block("root"): + for i, j in T.grid(n, T.int64(2560)): + with T.block("take_decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[C[v_i], v_j // T.int64(8)], C[v_i], B[C[v_i], v_j // T.int64(32)]) + T.writes(take_decode_1[v_i, v_j]) + take_decode_1[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(A[C[v_i], v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(B[C[v_i], v_j // T.int64(32)], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(B[C[v_i], v_j // T.int64(32)], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + + @T.prim_func + def transpose(var_A: T.handle, var_T_transpose: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(80))) + T_transpose = T.match_buffer(var_T_transpose, (T.int64(1), T.int64(32), n, T.int64(80))) + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, T.int64(80)): + with T.block("T_transpose"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3]) + T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) + T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3] + + @T.prim_func + def transpose1(var_A: T.handle, var_T_transpose: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, T.int64(80))) + T_transpose = T.match_buffer(var_T_transpose, (T.int64(1), n, T.int64(32), T.int64(80))) + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(80)): + with T.block("T_transpose"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3]) + T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) + T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3] +# fmt: on diff --git a/mlc_llm/dispatch/gpt_neox/redpajama_q4f32_tune.py b/mlc_llm/dispatch/gpt_neox/redpajama_q4f32_tune.py new file mode 100644 index 0000000..1b1169e --- /dev/null +++ b/mlc_llm/dispatch/gpt_neox/redpajama_q4f32_tune.py @@ -0,0 +1,743 @@ +# pylint: disable=pointless-string-statement,invalid-name,missing-docstring,line-too-long,too-many-locals,too-many-arguments,too-many-statements +from tvm.script import ir as I +from tvm.script import tir as T + +# fmt: off + +@I.ir_module +class Module: + @T.prim_func + def cast1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), compute: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(A[v_i0, v_i1, v_i2]) + T.writes(compute[v_i0, v_i1, v_i2]) + compute[v_i0, v_i1, v_i2] = A[v_i0, v_i1, v_i2] + + @T.prim_func + def decode(A: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), B: T.Buffer((T.int64(80), T.int64(2560)), "uint32"), T_transpose: T.Buffer((T.int64(2560), T.int64(2560)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # 
with T.block("root"): + decode_1 = T.alloc_buffer((T.int64(2560), T.int64(2560))) + for i, j in T.grid(T.int64(2560), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j]) + T.writes(decode_1[v_i, v_j]) + decode_1[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(B[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(B[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + for ax0, ax1 in T.grid(T.int64(2560), T.int64(2560)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode_1[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = decode_1[v_ax1, v_ax0] + + @T.prim_func + def decode1(A: T.Buffer((T.int64(320), T.int64(10240)), "uint32"), B: T.Buffer((T.int64(80), T.int64(10240)), "uint32"), T_transpose: T.Buffer((T.int64(10240), T.int64(2560)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(2560), T.int64(10240))) + for i, j in T.grid(T.int64(2560), T.int64(10240)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(B[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(B[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + for ax0, ax1 in T.grid(T.int64(10240), T.int64(2560)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + + @T.prim_func + def decode2(A: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), B: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), T_transpose: T.Buffer((T.int64(2560), T.int64(10240)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(10240), T.int64(2560))) + for i, j in T.grid(T.int64(10240), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(B[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(B[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + for ax0, ax1 in T.grid(T.int64(2560), T.int64(10240)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + + @T.prim_func + def divide1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(50432)), "float32"), B: T.Buffer((), "float32"), 
T_divide: T.Buffer((T.int64(1), T.int64(1), T.int64(50432)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(50432)): + with T.block("T_divide"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(A[v_ax0, v_ax1, v_ax2], B[()]) + T.writes(T_divide[v_ax0, v_ax1, v_ax2]) + T_divide[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2] / B[()] + + @T.prim_func + def fused_decode3_matmul1(lv1352: T.Buffer((T.int64(320), T.int64(50432)), "uint32"), lv1353: T.Buffer((T.int64(80), T.int64(50432)), "uint32"), lv1800: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(50432)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(50432))) + for i, j in T.grid(T.int64(2560), T.int64(50432)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1352[v_i // T.int64(8), v_j], lv1353[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1352[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1353[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1353[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(50432), T.int64(2560)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1800[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1800[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + + @T.prim_func + def fused_decode4_fused_matmul7_add3(lv1363: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), lv1364: T.Buffer((T.int64(80), T.int64(2560)), "uint32"), lv1808: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), linear_bias192: T.Buffer((T.int64(2560),), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(2560))) + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) + for i, j in T.grid(T.int64(2560), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1363[v_i // T.int64(8), v_j], lv1364[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1363[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1364[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1364[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(2560)): + with 
T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1808[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1808[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias192[v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias192[v_ax2] + + @T.prim_func + def fused_decode4_fused_matmul7_add3_add4(lv1381: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), lv1382: T.Buffer((T.int64(80), T.int64(2560)), "uint32"), lv5: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), linear_bias195: T.Buffer((T.int64(2560),), "float32"), lv1805: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(2560))) + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) + for i, j in T.grid(T.int64(2560), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1381[v_i // T.int64(8), v_j], lv1382[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1381[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1382[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1382[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(2560)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv5[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv5[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias195[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias195[v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2], lv1805[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + 
p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] + lv1805[v_ax0, v_ax1, v_ax2] + + @T.prim_func + def fused_decode5_fused_matmul9_add5_gelu1(lv1387: T.Buffer((T.int64(320), T.int64(10240)), "uint32"), lv1388: T.Buffer((T.int64(80), T.int64(10240)), "uint32"), lv1852: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), linear_bias196: T.Buffer((T.int64(10240),), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(10240)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(10240))) + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + T_multiply = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + T_multiply_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + T_add = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + for i, j in T.grid(T.int64(2560), T.int64(10240)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1387[v_i // T.int64(8), v_j], lv1388[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1387[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1388[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1388[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(10240), T.int64(2560)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1852[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1852[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias196[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias196[v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) + T_multiply[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T.float32(0.70710678118654757) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(T_multiply[v_i0, v_i1, v_i2]) + T.writes(compute[v_i0, v_i1, v_i2]) + compute[v_i0, v_i1, v_i2] = T.erf(T_multiply[v_i0, v_i1, v_i2]) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("T_multiply_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", 
[ax0, ax1, ax2]) + T.reads(compute[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2]) + T_multiply_1[v_ax0, v_ax1, v_ax2] = compute[v_ax0, v_ax1, v_ax2] * T.float32(0.5) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2]) + T.writes(T_add[v_ax0, v_ax1, v_ax2]) + T_add[v_ax0, v_ax1, v_ax2] = T.float32(0.5) + T_multiply_1[v_ax0, v_ax1, v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("T_multiply_2"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2] + + @T.prim_func + def fused_decode6_fused_matmul10_add3_cast1_add4(lv1393: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), lv1394: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), lv1857: T.Buffer((T.int64(1), T.int64(1), T.int64(10240)), "float32"), linear_bias197: T.Buffer((T.int64(2560),), "float32"), lv6: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(10240), T.int64(2560))) + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) + var_compute_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) + for i, j in T.grid(T.int64(10240), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1393[v_i // T.int64(8), v_j], lv1394[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1393[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1394[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1394[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(10240)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1857[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1857[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias197[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias197[v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = 
T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = var_T_add_intermediate[v_i0, v_i1, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_compute_intermediate[v_ax0, v_ax1, v_ax2], lv6[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_compute_intermediate[v_ax0, v_ax1, v_ax2] + lv6[v_ax0, v_ax1, v_ax2] + + @T.prim_func + def fused_reshape7_squeeze1(lv1821: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), var_T_squeeze_intermediate: T.Buffer((T.int64(1), T.int64(32), T.int64(80)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_T_reshape_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(80))) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(80)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(lv1821[T.int64(0), T.int64(0), (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)]) + T.writes(var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv1821[T.int64(0), T.int64(0), (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(80)): + with T.block("T_squeeze"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_reshape_intermediate[T.int64(0), v_ax0, v_ax1, v_ax2]) + T.writes(var_T_squeeze_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_squeeze_intermediate[v_ax0, v_ax1, v_ax2] = var_T_reshape_intermediate[T.int64(0), v_ax0, v_ax1, v_ax2] + + @T.prim_func + def fused_slice1_cast1(lv3599: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), var_compute_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_slice_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) + for i, _, k in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("slice"): + v_i, v__, v_k = T.axis.remap("SSS", [i, _, k]) + T.reads(lv3599[v_i, T.int64(0), v_k]) + T.writes(var_slice_intermediate[v_i, v__, v_k]) + var_slice_intermediate[v_i, v__, v_k] = lv3599[v_i, T.int64(0), v_k] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_slice_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = var_slice_intermediate[v_i0, v_i1, v_i2] + + @T.prim_func + def fused_transpose7_reshape8(lv1844: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float32"), var_T_reshape_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_T_transpose_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(80))) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(80)): + with T.block("T_transpose"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(lv1844[v_ax0, v_ax2, v_ax1, v_ax3]) + 
T.writes(var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv1844[v_ax0, v_ax2, v_ax1, v_ax3] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_transpose_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(2560) // T.int64(80), v_ax2 % T.int64(80)]) + T.writes(var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2] = var_T_transpose_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(2560) // T.int64(80), v_ax2 % T.int64(80)] + + @T.prim_func + def layer_norm1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), B: T.Buffer((T.int64(2560),), "float32"), C: T.Buffer((T.int64(2560),), "float32"), T_layer_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + A_red_temp_v0 = T.alloc_buffer((T.int64(1), T.int64(1))) + A_red_temp_v1 = T.alloc_buffer((T.int64(1), T.int64(1))) + for ax0, ax1, k2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("A_red_temp"): + v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) + T.reads(A[v_ax0, v_ax1, v_k2]) + T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) + with T.init(): + A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) + A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) + v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] + v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] * A[v_ax0, v_ax1, v_k2] + A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 + A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_layer_norm"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(A[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1], B[v_ax2], C[v_ax2]) + T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2]) + T_layer_norm[v_ax0, v_ax1, v_ax2] = (A[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * B[v_ax2] + C[v_ax2] + + @T.prim_func + def reshape5(A: T.Buffer((T.int64(1), T.int64(1)), "int32"), T_reshape: T.Buffer((T.int64(1),), "int32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for ax0 in range(T.int64(1)): + with T.block("T_reshape"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + T.reads(A[T.int64(0), T.int64(0)]) + T.writes(T_reshape[v_ax0]) + T_reshape[v_ax0] = A[T.int64(0), T.int64(0)] + + @T.prim_func + def reshape6(A: T.Buffer((T.int64(1), T.int64(2560)), "float32"), T_reshape: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(A[T.int64(0), v_ax2 % T.int64(2560)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) + T_reshape[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax2 % T.int64(2560)] + + @T.prim_func + def reshape7(A: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), 
T_reshape: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(80)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(80)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(A[T.int64(0), T.int64(0), (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[T.int64(0), T.int64(0), (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)] + + @T.prim_func + def softmax1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(50432)), "float32"), T_softmax_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(50432)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(1))) + T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(50432))) + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(1))) + for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(50432)): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + T.reads(A[v_i0, v_i1, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0, v_i1], A[v_i0, v_i1, v_k]) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(50432)): + with T.block("T_softmax_exp"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(A[v_i0, v_i1, v_i2], T_softmax_maxelem[v_i0, v_i1]) + T.writes(T_softmax_exp[v_i0, v_i1, v_i2]) + T_softmax_exp[v_i0, v_i1, v_i2] = T.exp(A[v_i0, v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1]) + for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(50432)): + with T.block("T_softmax_expsum"): + v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + T.reads(T_softmax_exp[v_i0, v_i1, v_k]) + T.writes(T_softmax_expsum[v_i0, v_i1]) + with T.init(): + T_softmax_expsum[v_i0, v_i1] = T.float32(0) + T_softmax_expsum[v_i0, v_i1] = T_softmax_expsum[v_i0, v_i1] + T_softmax_exp[v_i0, v_i1, v_k] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(50432)): + with T.block("T_softmax_norm"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2], T_softmax_expsum[v_i0, v_i1]) + T.writes(T_softmax_norm[v_i0, v_i1, v_i2]) + T.block_attr({"axis": 2}) + T_softmax_norm[v_i0, v_i1, v_i2] = T_softmax_exp[v_i0, v_i1, v_i2] / T_softmax_expsum[v_i0, v_i1] + + @T.prim_func + def squeeze1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(80)), "float32"), T_squeeze: T.Buffer((T.int64(1), T.int64(32), T.int64(80)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(80)): + with T.block("T_squeeze"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(A[T.int64(0), v_ax0, v_ax1, v_ax2]) + T.writes(T_squeeze[v_ax0, v_ax1, v_ax2]) + T_squeeze[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax0, v_ax1, v_ax2] + + @T.prim_func + def take_decode1(A: T.Buffer((T.int64(50432), T.int64(320)), "uint32"), B: T.Buffer((T.int64(50432), T.int64(80)), "uint32"), C: T.Buffer((T.int64(1),), "int32"), take_decode: T.Buffer((T.int64(1), T.int64(2560)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for i, j in T.grid(T.int64(1), T.int64(2560)): + with 
T.block("take_decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[C[v_i], v_j // T.int64(8)], C[v_i], B[C[v_i], v_j // T.int64(32)]) + T.writes(take_decode[v_i, v_j]) + take_decode[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(A[C[v_i], v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(B[C[v_i], v_j // T.int64(32)], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(B[C[v_i], v_j // T.int64(32)], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + + @T.prim_func + def transpose6(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(80)), "float32"), T_transpose: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(80)): + with T.block("T_transpose"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3]) + T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) + T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3] + + ########## Dynamic shape ########## + + @T.prim_func + def fused_NT_matmul1_divide_maximum_minimum(p_lv34: T.handle, p_lv35: T.handle, p_lv5: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.meta_var(T.int64(128)) + m = T.meta_var(T.int64(128)) + lv34 = T.match_buffer(p_lv34, (T.int64(1), T.int64(32), n, T.int64(80))) + lv35 = T.match_buffer(p_lv35, (T.int64(1), T.int64(32), m, T.int64(80))) + lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m)) + var_T_minimum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m)) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, m, T.int64(80)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(lv34[v_i0, v_i1, v_i2, v_k], lv35[v_i0, v_i1, v_i3, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv34[v_i0, v_i1, v_i2, v_k] * lv35[v_i0, v_i1, v_i3, v_k] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_divide"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.11180339723346898) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_maximum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float32(-3.4028234663852886e+38)) 
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_minimum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) + T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) + + @T.prim_func + def fused_NT_matmul2_add2_gelu(p_lv51: T.handle, lv38: T.Buffer((T.int64(10240), T.int64(2560)), "float32"), linear_bias4: T.Buffer((T.int64(10240),), "float32"), p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.meta_var(T.int64(128)) + lv51 = T.match_buffer(p_lv51, (T.int64(1), n, T.int64(2560))) + var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(10240))) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + T_multiply = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + compute = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + T_multiply_1 = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + T_add = T.alloc_buffer((T.int64(1), n, T.int64(10240))) + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(10240), T.int64(2560)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv51[v_i0, v_i1, v_k], lv38[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv51[v_i0, v_i1, v_k] * lv38[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias4[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias4[v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) + T_multiply[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T.float32(0.70710678118654757) + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(T_multiply[v_i0, v_i1, v_i2]) + T.writes(compute[v_i0, v_i1, v_i2]) + compute[v_i0, v_i1, v_i2] = T.erf(T_multiply[v_i0, v_i1, v_i2]) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_multiply_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(compute[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2]) + T_multiply_1[v_ax0, v_ax1, v_ax2] = compute[v_ax0, v_ax1, v_ax2] * T.float32(0.5) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2]) + T.writes(T_add[v_ax0, v_ax1, v_ax2]) + T_add[v_ax0, v_ax1, v_ax2] = T.float32(0.5) + T_multiply_1[v_ax0, v_ax1, v_ax2] + for 
ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): + with T.block("T_multiply_2"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2] + + @T.prim_func + def fused_NT_matmul3_add_cast_add1(p_lv56: T.handle, lv45: T.Buffer((T.int64(2560), T.int64(10240)), "float32"), linear_bias5: T.Buffer((T.int64(2560),), "float32"), p_lv49: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.meta_var(T.int64(128)) + lv56 = T.match_buffer(p_lv56, (T.int64(1), n, T.int64(10240))) + lv49 = T.match_buffer(p_lv49, (T.int64(1), n, T.int64(2560))) + var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560))) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + var_compute_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(10240)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv56[v_i0, v_i1, v_k], lv45[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv56[v_i0, v_i1, v_k] * lv45[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias5[v_ax2]) + T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias5[v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_add_intermediate_1[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = var_T_add_intermediate_1[v_i0, v_i1, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_compute_intermediate[v_ax0, v_ax1, v_ax2], lv49[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_compute_intermediate[v_ax0, v_ax1, v_ax2] + lv49[v_ax0, v_ax1, v_ax2] + + @T.prim_func + def fused_NT_matmul4_divide2_maximum1_minimum1(lv1835: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float32"), p_lv1836: T.handle, p_lv1806: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.meta_var(T.int64(128)) + lv1836 = T.match_buffer(p_lv1836, (T.int64(1), T.int64(32), n, T.int64(80))) + lv1806 = T.match_buffer(p_lv1806, (T.int64(1), T.int64(1), T.int64(1), n)) + var_T_minimum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n)) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) + var_T_divide_intermediate = T.alloc_buffer((T.int64(1), 
T.int64(32), T.int64(1), n)) + var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(80)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(lv1835[v_i0, v_i1, v_i2, v_k], lv1836[v_i0, v_i1, v_i3, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1835[v_i0, v_i1, v_i2, v_k] * lv1836[v_i0, v_i1, v_i3, v_k] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_divide"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.11180339723346898) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_maximum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float32(-3.4028234663852886e+38)) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_minimum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1806[v_ax0, T.int64(0), v_ax2, v_ax3]) + T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1806[v_ax0, T.int64(0), v_ax2, v_ax3]) + + @T.prim_func + def fused_NT_matmul_add(p_lv7: T.handle, lv10: T.Buffer((T.int64(2560), T.int64(2560)), "float32"), linear_bias: T.Buffer((T.int64(2560),), "float32"), p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.meta_var(T.int64(128)) + lv7 = T.match_buffer(p_lv7, (T.int64(1), n, T.int64(2560))) + var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560))) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv7[v_i0, v_i1, v_k], lv10[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv7[v_i0, v_i1, v_k] * lv10[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias[v_ax2] + + 
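+    # In the "NT_matmul" kernels of this dynamic-shape section, the second operand is indexed as
+    # [i2, k], i.e. the product is taken against the transposed weight (A @ W^T), with the bias added
+    # in a trailing "T_add" block. The sequence length n is pinned to a concrete value via T.meta_var
+    # (e.g. T.int64(128)), presumably so these variants can be tuned with static shapes.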
@T.prim_func + def fused_NT_matmul_add_add1(p_lv45: T.handle, lv31: T.Buffer((T.int64(2560), T.int64(2560)), "float32"), linear_bias3: T.Buffer((T.int64(2560),), "float32"), p_lv2: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.meta_var(T.int64(128)) + lv45 = T.match_buffer(p_lv45, (T.int64(1), n, T.int64(2560))) + lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(2560))) + var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560))) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv45[v_i0, v_i1, v_k], lv31[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv45[v_i0, v_i1, v_k] * lv31[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias3[v_ax2]) + T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias3[v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2], lv2[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] + lv2[v_ax0, v_ax1, v_ax2] + + @T.prim_func + def layer_norm(var_A: T.handle, B: T.Buffer((T.int64(2560),), "float32"), C: T.Buffer((T.int64(2560),), "float32"), var_T_layer_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.meta_var(T.int64(128)) + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(2560))) + T_layer_norm = T.match_buffer(var_T_layer_norm, (T.int64(1), n, T.int64(2560))) + # with T.block("root"): + A_red_temp_v0 = T.alloc_buffer((T.int64(1), n)) + A_red_temp_v1 = T.alloc_buffer((T.int64(1), n)) + for ax0, ax1, k2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("A_red_temp"): + v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) + T.reads(A[v_ax0, v_ax1, v_k2]) + T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) + with T.init(): + A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) + A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) + v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] + v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] * A[v_ax0, v_ax1, v_k2] + A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 + A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("T_layer_norm"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(A[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1], B[v_ax2], C[v_ax2]) + T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2]) + T_layer_norm[v_ax0, v_ax1, v_ax2] = (A[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * 
T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * B[v_ax2] + C[v_ax2] + + @T.prim_func + def matmul(var_A: T.handle, var_B: T.handle, var_matmul: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.meta_var(T.int64(128)) + m = T.meta_var(T.int64(32)) + A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m)) + B = T.match_buffer(var_B, (T.int64(1), T.int64(32), m, T.int64(80))) + matmul_1 = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(80))) + # with T.block("root"): + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(80), m): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) + T.writes(matmul_1[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + matmul_1[v_i0, v_i1, v_i2, v_i3] = T.float32(0) + matmul_1[v_i0, v_i1, v_i2, v_i3] = matmul_1[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] + + @T.prim_func + def matmul8(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.meta_var(T.int64(32)) + A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n)) + B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(80))) + # with T.block("root"): + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(80), n): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) + T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + matmul[v_i0, v_i1, v_i2, v_i3] = T.float32(0) + matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] + +# fmt: on diff --git a/mlc_llm/dispatch/llama/__init__.py b/mlc_llm/dispatch/llama/__init__.py new file mode 100644 index 0000000..2374080 --- /dev/null +++ b/mlc_llm/dispatch/llama/__init__.py @@ -0,0 +1 @@ +from .main import lookup_func as lookup diff --git a/mlc_llm/dispatch/llama/main.py b/mlc_llm/dispatch/llama/main.py new file mode 100644 index 0000000..166739b --- /dev/null +++ b/mlc_llm/dispatch/llama/main.py @@ -0,0 +1,6712 @@ +import tvm +from tvm import IRModule +from tvm.script import tir as T + + +# fmt: off +@T.prim_func +def fused_min_max_triu_te_broadcast_to(p_output0: T.handle, n: T.int64): + T.func_attr({"tir.noalias": T.bool(True)}) + var_T_broadcast_to_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), n, n), "float16") + # with T.block("root"): + var_make_diag_mask_te_intermediate = T.alloc_buffer((n, n), "float16") + for i, j in T.grid(n, n): + with T.block("make_diag_mask_te"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads() + T.writes(var_make_diag_mask_te_intermediate[v_i, v_j]) + var_make_diag_mask_te_intermediate[v_i, v_j] = T.Select(v_i < v_j, T.float16(-65504), T.float16(65504)) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), n, n): + with T.block("T_broadcast_to"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_make_diag_mask_te_intermediate[v_ax2, v_ax3]) + T.writes(var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + 
var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_make_diag_mask_te_intermediate[v_ax2, v_ax3] + + +def fused_min_max_triu_te_broadcast_to_sch_func(): + sch = tvm.tir.Schedule(fused_min_max_triu_te_broadcast_to) + b0 = sch.get_block("T_broadcast_to") + sch.reverse_compute_inline(b0) + return sch.mod["main"] + + +@T.prim_func +def rms_norm_before(var_rxplaceholder: T.handle, rxplaceholder: T.Buffer((T.int64(4096),), "float32"), var_rms_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + rxplaceholder_1 = T.match_buffer(var_rxplaceholder, (T.int64(1), n, T.int64(4096))) + rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096))) + # with T.block("root"): + rxplaceholderred_temp = T.alloc_buffer((T.int64(1), n)) + for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)): + with T.block("rxplaceholderred_temp"): + v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k]) + T.reads(rxplaceholder_1[v_bsz, v_i, v_k]) + T.writes(rxplaceholderred_temp[v_bsz, v_i]) + with T.init(): + rxplaceholderred_temp[v_bsz, v_i] = T.float32(0) + rxplaceholderred_temp[v_bsz, v_i] = rxplaceholderred_temp[v_bsz, v_i] + rxplaceholder_1[v_bsz, v_i, v_k] * rxplaceholder_1[v_bsz, v_i, v_k] + for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)): + with T.block("rms_norm"): + v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k]) + T.reads(rxplaceholder[v_k], rxplaceholder_1[v_bsz, v_i, v_k], rxplaceholderred_temp[v_bsz, v_i]) + T.writes(rms_norm_1[v_bsz, v_i, v_k]) + rms_norm_1[v_bsz, v_i, v_k] = rxplaceholder[v_k] * (rxplaceholder_1[v_bsz, v_i, v_k] / T.sqrt(rxplaceholderred_temp[v_bsz, v_i] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))) + + +@T.prim_func +def rms_norm_after(var_A: T.handle, var_weight: T.Buffer((T.int64(4096),), "float32"), var_rms_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096))) + rms_norm = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096))) + # with T.block("root"): + for i_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): + with T.block("compute_o"): + v_bsz = T.axis.spatial(T.int64(1), T.int64(0)) + v_i_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i_0) + T.reads(A[v_bsz, v_i_o * T.int64(32):v_i_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)]) + T.writes(rms_norm[v_bsz, T.int64(0) : T.int64(n), T.int64(0):T.int64(4096)]) + sq_sum_pad_local = T.alloc_buffer((T.int64(32),), scope="shared") + for bsz, i_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(16)): + for k_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("compute"): + v_i_i = T.axis.spatial(T.int64(32), i_1) + v_k_i = T.axis.reduce(T.int64(4096), k_0 * T.int64(256) + k_1) + T.reads(A[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k_i]) + T.writes(sq_sum_pad_local[v_i_i]) + with T.init(): + sq_sum_pad_local[v_i_i] = T.float32(0) + sq_sum_pad_local[v_i_i] = sq_sum_pad_local[v_i_i] + T.if_then_else(v_i_o * T.int64(32) + v_i_i < n, A[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k_i], T.float32(0)) * T.if_then_else(v_i_o * T.int64(32) + v_i_i < n, A[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k_i], T.float32(0)) + for bsz_i_fused_1, k_0 in T.grid(T.int64(32), T.int64(16)): + for k_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("compute_cache_write"): + v_bsz = T.axis.spatial(T.int64(1), T.int64(0)) + v_i_i = T.axis.spatial(n, bsz_i_fused_1) + v_k = T.axis.spatial(T.int64(4096), k_0 * 
T.int64(256) + k_1) + T.reads(A[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k], var_weight[v_k], sq_sum_pad_local[v_i_i]) + T.writes(rms_norm[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k]) + if v_i_i < n: + rms_norm[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k] = var_weight[v_k] * (A[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k] / T.sqrt(sq_sum_pad_local[v_i_i] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))) + + +@T.prim_func +def rms_norm_fp16_before(var_rxplaceholder: T.handle, rxplaceholder: T.Buffer((T.int64(4096),), "float16"), var_rms_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + rxplaceholder_1 = T.match_buffer(var_rxplaceholder, (T.int64(1), n, T.int64(4096)), "float16") + rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096)), "float16") + # with T.block("root"): + rxplaceholderred_temp = T.alloc_buffer((T.int64(1), n)) + for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)): + with T.block("rxplaceholderred_temp"): + v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k]) + T.reads(rxplaceholder_1[v_bsz, v_i, v_k]) + T.writes(rxplaceholderred_temp[v_bsz, v_i]) + with T.init(): + rxplaceholderred_temp[v_bsz, v_i] = T.float32(0) + rxplaceholderred_temp[v_bsz, v_i] = rxplaceholderred_temp[v_bsz, v_i] + T.Cast("float32", rxplaceholder_1[v_bsz, v_i, v_k]) * T.Cast("float32", rxplaceholder_1[v_bsz, v_i, v_k]) + for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)): + with T.block("rms_norm"): + v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k]) + T.reads(rxplaceholder[v_k], rxplaceholder_1[v_bsz, v_i, v_k], rxplaceholderred_temp[v_bsz, v_i]) + T.writes(rms_norm_1[v_bsz, v_i, v_k]) + rms_norm_1[v_bsz, v_i, v_k] = T.Cast("float16", T.Cast("float32", rxplaceholder[v_k]) * (T.Cast("float32", rxplaceholder_1[v_bsz, v_i, v_k]) / T.sqrt(rxplaceholderred_temp[v_bsz, v_i] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)))) + + +@T.prim_func +def rms_norm_fp16_after(var_A: T.handle, var_weight: T.Buffer((T.int64(4096),), "float16"), var_rms_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), dtype="float16") + rms_norm = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096)), dtype="float16") + # with T.block("root"): + for i_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): + with T.block("compute_o"): + v_bsz = T.axis.spatial(T.int64(1), T.int64(0)) + v_i_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i_0) + T.reads(A[v_bsz, v_i_o * T.int64(32):v_i_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)]) + T.writes(rms_norm[v_bsz, T.int64(0) : T.int64(n), T.int64(0):T.int64(4096)]) + sq_sum_pad_local = T.alloc_buffer((T.int64(32),), scope="shared") + for bsz, i_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(16)): + for k_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("compute"): + v_i_i = T.axis.spatial(T.int64(32), i_1) + v_k_i = T.axis.reduce(T.int64(4096), k_0 * T.int64(256) + k_1) + T.reads(A[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k_i]) + T.writes(sq_sum_pad_local[v_i_i]) + with T.init(): + sq_sum_pad_local[v_i_i] = T.float32(0) + sq_sum_pad_local[v_i_i] = sq_sum_pad_local[v_i_i] + T.if_then_else(v_i_o * T.int64(32) + v_i_i < n, T.Cast("float32", A[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k_i]) * T.Cast("float32", A[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k_i]), T.float32(0)) + for bsz_i_fused_1, k_0 in T.grid(T.int64(32), T.int64(16)): + for k_1 in 
T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("compute_cache_write"): + v_bsz = T.axis.spatial(T.int64(1), T.int64(0)) + v_i_i = T.axis.spatial(n, bsz_i_fused_1) + v_k = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + k_1) + T.reads(A[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k], var_weight[v_k], sq_sum_pad_local[v_i_i]) + T.writes(rms_norm[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k]) + if v_i_i < n: + rms_norm[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k] = T.Cast("float16", T.Cast("float32", var_weight[v_k]) * (T.Cast("float32", A[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k]) / T.sqrt(sq_sum_pad_local[v_i_i] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)))) + + +@T.prim_func +def softmax_before(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, n)) + T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, n)) + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) + T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, n)) + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, n): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], rxplaceholder[v_i0, v_i1, v_i2, v_k]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, n): + with T.block("T_softmax_exp"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) + T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) + T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(rxplaceholder[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, n): + with T.block("T_softmax_expsum"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) + T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, n): + with T.block("T_softmax_norm"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) + T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"axis": 3}) + T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] + + +@T.prim_func +def softmax_after(var_A: T.handle, var_T_softmax_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, n)) + T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, n)) + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) + for i2_0 in T.thread_binding((n + T.int64(31)) 
// T.int64(32), thread="blockIdx.x"): + with T.block("T_softmax_maxelem_o"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) + T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(n + T.int64(127)) // T.int64(128) * T.int64(128)]) + T.writes(T_softmax_maxelem[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) + T_softmax_maxelem_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared") + for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (n + T.int64(127)) // T.int64(128)): + for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_maxelem"): + v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) + v_k_i = T.axis.reduce(T.int64(32) * ((n + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) + T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i]) + T.writes(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) + with T.init(): + T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.max(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i], T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < n, A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T.float32(-3.4028234663852886e+38))) + for i0_i1_i2_1_fused_0 in range(T.int64(8)): + for i0_i1_i2_1_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_maxelem_cache_write"): + v_i1_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) // T.int64(32)) + v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32)) + T.where(v_i2_o * T.int64(32) + (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32) < n) + T.reads(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) + T.writes(T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) + T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i] = T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] + for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): + with T.block("T_softmax_expsum_o"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) + T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(n + T.int64(127)) // T.int64(128) * T.int64(128)], T_softmax_maxelem[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) + T.writes(T_softmax_expsum[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) + T_softmax_expsum_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared") + for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (n + T.int64(127)) // T.int64(128)): + for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_expsum"): + v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) + v_k_i = T.axis.reduce(T.int64(32) * ((n + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) + T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) + T.writes(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) + with T.init(): + T_softmax_expsum_pad_0_local[v_i0, 
v_i1_i, v_i2_i] = T.float32(0) + T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] + T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < n, T.exp(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i] - T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]), T.float32(0)) + for i0_i1_i2_1_fused_0 in range(T.int64(8)): + for i0_i1_i2_1_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_expsum_cache_write"): + v_i1_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) // T.int64(32)) + v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32)) + T.where(v_i2_o * T.int64(32) + (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32) < n) + T.reads(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) + T.writes(T_softmax_expsum[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) + T_softmax_expsum[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] + for i0_i1_i2_fused_i3_fused_0 in T.thread_binding((n * T.int64(32) * n + T.int64(255)) // T.int64(256), thread="blockIdx.x"): + for i0_i1_i2_fused_i3_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("T_softmax_norm"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(32), (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) // n // n) + v_i2 = T.axis.spatial(n, (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) // n % n) + v_i3 = T.axis.spatial(n, (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) % n) + T.where(i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1 < n * T.int64(32) * n) + T.reads(T_softmax_expsum[v_i0, v_i1, v_i2], A[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) + T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) + T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T.exp(A[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) / T_softmax_expsum[v_i0, v_i1, v_i2] + + +@T.prim_func +def softmax_mxn_before(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + m = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, m)) + T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, m)) + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) + T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], rxplaceholder[v_i0, v_i1, v_i2, v_k]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_exp"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) + T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) + T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = 
T.exp(rxplaceholder[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_expsum"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) + T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_norm"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) + T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"axis": 3}) + T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] + + +@T.prim_func +def softmax_mxn_after(var_A: T.handle, var_T_softmax_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + m = T.int64() + A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m)) + T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, m)) + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) + for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): + with T.block("T_softmax_maxelem_o"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) + T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(m + T.int64(127)) // T.int64(128) * T.int64(128)]) + T.writes(T_softmax_maxelem[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) + T_softmax_maxelem_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared") + for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): + for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_maxelem"): + v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) + v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) + T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i]) + T.writes(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) + with T.init(): + T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.max(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i], T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T.float32(-3.4028234663852886e+38))) + for i0_i1_i2_1_fused_0 in range(T.int64(8)): + for i0_i1_i2_1_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_maxelem_cache_write"): + v_i1_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) // T.int64(32)) + v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32)) + T.where(v_i2_o * T.int64(32) + (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32) < n) + T.reads(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) + 
T.writes(T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) + T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i] = T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] + for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): + with T.block("T_softmax_expsum_o"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) + T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(m + T.int64(127)) // T.int64(128) * T.int64(128)], T_softmax_maxelem[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) + T.writes(T_softmax_expsum[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) + T_softmax_expsum_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared") + for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): + for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_expsum"): + v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) + v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) + T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) + T.writes(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) + with T.init(): + T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float32(0) + T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] + T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, T.exp(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i] - T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]), T.float32(0)) + for i0_i1_i2_1_fused_0 in range(T.int64(8)): + for i0_i1_i2_1_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_expsum_cache_write"): + v_i1_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) // T.int64(32)) + v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32)) + T.where(v_i2_o * T.int64(32) + (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32) < n) + T.reads(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) + T.writes(T_softmax_expsum[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) + T_softmax_expsum[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] + for i0_i1_i2_fused_i3_fused_0 in T.thread_binding((n * T.int64(32) * m + T.int64(255)) // T.int64(256), thread="blockIdx.x"): + for i0_i1_i2_fused_i3_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("T_softmax_norm"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(32), (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) // m // n) + v_i2 = T.axis.spatial(n, (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) // m % n) + v_i3 = T.axis.spatial(m, (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) % m) + T.where(i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1 < n * T.int64(32) * m) + T.reads(T_softmax_expsum[v_i0, v_i1, v_i2], A[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) + T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) + T_softmax_norm[v_i0, 
v_i1, v_i2, v_i3] = T.exp(A[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) / T_softmax_expsum[v_i0, v_i1, v_i2] + +@T.prim_func +def softmax_cast_mxn_before(p_lv37: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n, m = T.int64(), T.int64() + lv37 = T.match_buffer(p_lv37, (T.int64(1), T.int64(32), n, m)) + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16") + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) + T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) + var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv37[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv37[v_i0, v_i1, v_i2, v_k]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_exp"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(lv37[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) + T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) + T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv37[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_expsum"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) + T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_norm"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) + T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"axis": 3}) + var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) + var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + + +@T.prim_func +def softmax_cast_mxn_after(var_A: T.handle, var_T_softmax_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + m = T.int64() + A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m)) + T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, m), dtype="float16") + # with T.block("root"): + for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): + with T.block("T_softmax_maxelem_o"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2_o = T.axis.spatial((n + T.int64(31)) 
// T.int64(32), i2_0) + T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(m + T.int64(127)) // T.int64(128) * T.int64(128)]) + T.writes(T_softmax_norm[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):m]) + T_softmax_maxelem_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared") + T_softmax_expsum_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared") + for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): + for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_maxelem"): + v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) + v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) + T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i]) + T.writes(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) + with T.init(): + T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.max(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i], T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T.float32(-3.4028234663852886e+38))) + for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): + for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_expsum"): + v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) + v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) + T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) + T.writes(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) + with T.init(): + T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float32(0) + T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] + T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, T.exp(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i] - T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]), T.float32(0)) + for i0_i1_i2_1_i3_fused_0 in range((T.int64(32) * T.int64(32) * m) // T.int64(128)): + for i0_i1_i2_1_i3_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_norm"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(32), (i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1) // T.int64(32) // m) + v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1) // m % T.int64(32)) + v_i3 = T.axis.spatial(m, (i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1) % m) + T.where(i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1 < T.int64(32) * T.int64(32) * m) + T.reads(T_softmax_expsum_pad_0_local[v_i0, v_i1, v_i2_i], A[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3], T_softmax_maxelem_pad_0_local[v_i0, v_i1, v_i2_i]) + T.writes(T_softmax_norm[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3]) + if v_i2_o * T.int64(32) + v_i2_i < n: + T_softmax_norm[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3] = T.Cast("float16", T.exp(A[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3] - T_softmax_maxelem_pad_0_local[v_i0, v_i1, v_i2_i]) / T_softmax_expsum_pad_0_local[v_i0, v_i1, v_i2_i]) + + 
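+# A minimal NumPy reference for what the softmax_cast_mxn_before/after pair above
+# computes: a numerically stable softmax over the last axis of a (1, 32, n, m)
+# float32 score tensor, followed by a cast to float16. This is an illustrative
+# sketch for spot-checking the scheduled kernel against dense inputs; the helper
+# name `softmax_cast_mxn_ref` and the NumPy dependency are assumptions and nothing
+# in the dispatch mapping refers to it.
+import numpy as np
+
+
+def softmax_cast_mxn_ref(x: np.ndarray) -> np.ndarray:
+    """x: float32 array of shape (1, 32, n, m); returns the float16 softmax."""
+    x_max = x.max(axis=-1, keepdims=True)   # corresponds to T_softmax_maxelem
+    e = np.exp(x - x_max)                   # corresponds to T_softmax_exp
+    s = e.sum(axis=-1, keepdims=True)       # corresponds to T_softmax_expsum
+    return (e / s).astype(np.float16)       # T_softmax_norm followed by the cast
+
+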
+@T.prim_func +def softmax_mxn_fp16_before(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + m = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, m), "float16") + T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, m), "float16") + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n), "float16") + T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n), "float16") + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float16(-65504) + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], rxplaceholder[v_i0, v_i1, v_i2, v_k]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_exp"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) + T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) + T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(rxplaceholder[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_expsum"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_expsum[v_i0, v_i1, v_i2] = T.float16(0) + T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_norm"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) + T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"axis": 3}) + T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] + +@T.prim_func +def softmax_mxn_fp16_after(var_A: T.handle, var_T_softmax_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + m = T.int64() + A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m), dtype="float16") + T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, m), dtype="float16") + # with T.block("root"): + for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): + with T.block("T_softmax_maxelem_o"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) + T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(m + T.int64(127)) // T.int64(128) * T.int64(128)]) + T.writes(T_softmax_norm[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):m]) + T_softmax_maxelem_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared", dtype="float16") + T_softmax_expsum_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared", 
dtype="float16") + for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): + for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_maxelem"): + v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) + v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) + T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i]) + T.writes(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) + with T.init(): + T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float16(-65504) + T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.max(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i], T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T.float16(-65504))) + for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): + for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_expsum"): + v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) + v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) + T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) + T.writes(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) + with T.init(): + T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float16(0) + T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] + T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, T.exp(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i] - T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]), T.float16(0)) + for i0_i1_i2_1_i3_fused_0 in range((T.int64(32) * T.int64(32) * m) // T.int64(128)): + for i0_i1_i2_1_i3_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_norm"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(32), (i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1) // T.int64(32) // m) + v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1) // m % T.int64(32)) + v_i3 = T.axis.spatial(m, (i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1) % m) + T.where(i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1 < T.int64(32) * T.int64(32) * m) + T.reads(T_softmax_expsum_pad_0_local[v_i0, v_i1, v_i2_i], A[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3], T_softmax_maxelem_pad_0_local[v_i0, v_i1, v_i2_i]) + T.writes(T_softmax_norm[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3]) + if v_i2_o * T.int64(32) + v_i2_i < n: + T_softmax_norm[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3] = T.exp(A[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3] - T_softmax_maxelem_pad_0_local[v_i0, v_i1, v_i2_i]) / T_softmax_expsum_pad_0_local[v_i0, v_i1, v_i2_i] + + +@T.prim_func +def softmax_fp16_before(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, n), "float16") + T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, n), "float16") + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n), "float16") + T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, n), "float16") + T_softmax_expsum = 
T.alloc_buffer((T.int64(1), T.int64(32), n), "float16") + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, n): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float16(-65504) + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], rxplaceholder[v_i0, v_i1, v_i2, v_k]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, n): + with T.block("T_softmax_exp"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) + T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) + T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(rxplaceholder[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, n): + with T.block("T_softmax_expsum"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_expsum[v_i0, v_i1, v_i2] = T.float16(0) + T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, n): + with T.block("T_softmax_norm"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) + T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"axis": 3}) + T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] + + +@T.prim_func +def softmax_fp16_after(var_A: T.handle, var_T_softmax_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, n), dtype="float16") + T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, n), dtype="float16") + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n), dtype="float16") + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n), dtype="float16") + for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): + with T.block("T_softmax_maxelem_o"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) + T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(n + T.int64(127)) // T.int64(128) * T.int64(128)]) + T.writes(T_softmax_maxelem[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) + T_softmax_maxelem_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared", dtype="float16") + for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (n + T.int64(127)) // T.int64(128)): + for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_maxelem"): + v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) + v_k_i = T.axis.reduce(T.int64(32) * ((n + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) + T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i]) + T.writes(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) + with T.init(): + 
T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float16(-65504) + T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.max(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i], T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < n, A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T.float16(-65504))) + for i0_i1_i2_1_fused_0 in range(T.int64(8)): + for i0_i1_i2_1_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_maxelem_cache_write"): + v_i1_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) // T.int64(32)) + v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32)) + T.where(v_i2_o * T.int64(32) + (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32) < n) + T.reads(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) + T.writes(T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) + T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i] = T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] + for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): + with T.block("T_softmax_expsum_o"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) + T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(n + T.int64(127)) // T.int64(128) * T.int64(128)], T_softmax_maxelem[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) + T.writes(T_softmax_expsum[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) + T_softmax_expsum_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared", dtype="float16") + for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (n + T.int64(127)) // T.int64(128)): + for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_expsum"): + v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) + v_k_i = T.axis.reduce(T.int64(32) * ((n + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) + T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) + T.writes(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) + with T.init(): + T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float16(0) + T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] + T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < n, T.exp(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i] - T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]), T.float16(0)) + for i0_i1_i2_1_fused_0 in range(T.int64(8)): + for i0_i1_i2_1_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + with T.block("T_softmax_expsum_cache_write"): + v_i1_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) // T.int64(32)) + v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32)) + T.where(v_i2_o * T.int64(32) + (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32) < n) + T.reads(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) + T.writes(T_softmax_expsum[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) + T_softmax_expsum[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] + 
for i0_i1_i2_fused_i3_fused_0 in T.thread_binding((n * T.int64(32) * n + T.int64(255)) // T.int64(256), thread="blockIdx.x"): + for i0_i1_i2_fused_i3_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("T_softmax_norm"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(32), (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) // n // n) + v_i2 = T.axis.spatial(n, (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) // n % n) + v_i3 = T.axis.spatial(n, (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) % n) + T.where(i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1 < n * T.int64(32) * n) + T.reads(T_softmax_expsum[v_i0, v_i1, v_i2], A[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) + T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) + T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T.exp(A[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) / T_softmax_expsum[v_i0, v_i1, v_i2] + + +@T.prim_func +def softmax_1xn_before(var_inp0: T.handle, var_T_softmax_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + inp0 = T.match_buffer(var_inp0, (T.int64(1), T.int64(32), T.int64(1), n)) + T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), T.int64(1), n)) + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) + T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(inp0[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], inp0[v_i0, v_i1, v_i2, v_k]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_exp"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inp0[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) + T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) + T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(inp0[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_expsum"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) + T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_norm"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) + T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"axis": 3}) + T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] + + +@T.prim_func +def softmax_cast_1xn_before(p_lv1614: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + 
lv1614 = T.match_buffer(p_lv1614, (T.int64(1), T.int64(32), T.int64(1), n)) + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n), "float16") + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) + T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) + var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1614[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv1614[v_i0, v_i1, v_i2, v_k]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_exp"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(lv1614[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) + T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) + T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv1614[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_expsum"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) + T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_norm"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) + T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"axis": 3}) + var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) + var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + + +@T.prim_func +def softmax_1xn_fp16_before(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), T.int64(1), n), "float16") + T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), T.int64(1), n), "float16") + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)), "float16") + T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)), "float16") + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_maxelem"): + v_i0, 
v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float16(-65504) + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], rxplaceholder[v_i0, v_i1, v_i2, v_k]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_exp"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) + T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) + T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(rxplaceholder[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_expsum"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_expsum[v_i0, v_i1, v_i2] = T.float16(0) + T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_softmax_norm"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) + T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"axis": 3}) + T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] + + +def softmax_1xn_sch_func(f_softmax, cast_to_fp16: bool = False): + sch = tvm.tir.Schedule(f_softmax) + if cast_to_fp16: + b_cast = sch.get_block("compute") + sch.reverse_compute_inline(b_cast) + + b0 = sch.get_block("T_softmax_exp") + sch.compute_inline(b0) + b1 = sch.get_block("T_softmax_norm") + l2, l3, l4, l5 = sch.get_loops(b1) + l6, l7 = sch.split(l5, [None, 128]) + sch.bind(l7, "threadIdx.x") + b8 = sch.get_block("T_softmax_expsum") + sch.compute_at(b8, l4) + sch.set_scope(b8, 0, "shared") + l9, l10, l11, l12 = sch.get_loops(b8) + l13, l14 = sch.split(l12, [None, 128]) + sch.bind(l14, "threadIdx.x") + b15 = sch.get_block("T_softmax_maxelem") + sch.compute_at(b15, l4) + sch.set_scope(b15, 0, "shared") + l16, l17, l18, l19 = sch.get_loops(b15) + l20, l21 = sch.split(l19, [None, 128]) + sch.bind(l21, "threadIdx.x") + l22 = sch.fuse(l2, l3, l4) + sch.bind(l22, "blockIdx.x") + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func +def matmul1_before(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), T.int64(1), n)) + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(1), T.int64(32), n, T.int64(128))) + # with T.block("root"): + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), n): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k], rxplaceholder_1[v_i0, v_i1, v_k, v_i3]) + T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + matmul[v_i0, v_i1, v_i2, v_i3] = T.float32(0) + matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + 
rxplaceholder[v_i0, v_i1, v_i2, v_k] * rxplaceholder_1[v_i0, v_i1, v_k, v_i3] + + +@T.prim_func +def matmul1_after(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float32")): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), T.int64(1), n)) + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(1), T.int64(32), n, T.int64(128))) + # with T.block("root"): + matmul_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), scope="local") + rxplaceholder_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), (n + T.int64(127)) // T.int64(128) * T.int64(128)), scope="shared") + rxplaceholder_1_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), (n + T.int64(127)) // T.int64(128) * T.int64(128), T.int64(128)), scope="shared") + for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 512, "pragma_unroll_explicit": 1}): + for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(T.int64(1), thread="vthread.x"): + for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(T.int64(128), thread="threadIdx.x"): + for i0_3_init, i1_3_init, i2_3_init, i3_3_init, i0_4_init, i1_4_init, i2_4_init, i3_4_init in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(1), T.int64(1)): + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), i0_3_init + i0_4_init) + v_i1 = T.axis.spatial(T.int64(32), i0_2_i1_2_i2_2_i3_2_fused // T.int64(8) * T.int64(2) + i1_3_init * T.int64(2) + i1_4_init) + v_i2 = T.axis.spatial(T.int64(1), i2_3_init + i2_4_init) + v_i3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_0_i3_0_fused * T.int64(8) + i0_2_i1_2_i2_2_i3_2_fused % T.int64(8) + i3_3_init + i3_4_init) + T.reads() + T.writes(matmul_local[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + matmul_local[v_i0, v_i1, v_i2, v_i3] = T.float32(0) + for k_0, k_1_0 in T.grid((n + T.int64(127)) // T.int64(128), T.int64(8)): + for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(4)): + with T.block("rxplaceholder_pad_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_ax3_fused_0 * T.int64(512) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) // T.int64(16)) + v2 = T.axis.spatial(T.int64(1), T.int64(0)) + v3 = T.axis.spatial((n + T.int64(127)) // T.int64(128) * T.int64(128), k_0 * T.int64(128) + k_1_0 * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(512) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) % T.int64(16)) + T.reads(rxplaceholder[v0, v1, v2, v3]) + T.writes(rxplaceholder_pad_shared[v0, v1, v2, v3]) + rxplaceholder_pad_shared[v0, v1, v2, v3] = T.if_then_else(v3 < n, rxplaceholder[v0, v1, v2, v3], T.float32(0)) + for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(8)): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(4)): + with T.block("rxplaceholder_1_pad_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(32), 
(ax0_ax1_ax2_ax3_fused_0 * T.int64(512) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) // T.int64(128)) + v2 = T.axis.spatial((n + T.int64(127)) // T.int64(128) * T.int64(128), k_0 * T.int64(128) + k_1_0 * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(512) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) % T.int64(128) // T.int64(8)) + v3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_0_i3_0_fused * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(512) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) % T.int64(8)) + T.reads(rxplaceholder_1[v0, v1, v2, v3]) + T.writes(rxplaceholder_1_pad_shared[v0, v1, v2, v3]) + rxplaceholder_1_pad_shared[v0, v1, v2, v3] = T.if_then_else(v2 < n, rxplaceholder_1[v0, v1, v2, v3], T.float32(0)) + for k_1_1, i0_3, i1_3, i2_3, i3_3, k_1_2, i0_4, i1_4, i2_4, i3_4 in T.grid(T.int64(2), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(8), T.int64(1), T.int64(2), T.int64(1), T.int64(1)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), i0_3 + i0_4) + v_i1 = T.axis.spatial(T.int64(32), i0_2_i1_2_i2_2_i3_2_fused // T.int64(8) * T.int64(2) + i1_3 * T.int64(2) + i1_4) + v_i2 = T.axis.spatial(T.int64(1), i2_3 + i2_4) + v_i3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_0_i3_0_fused * T.int64(8) + i0_2_i1_2_i2_2_i3_2_fused % T.int64(8) + i3_3 + i3_4) + v_k = T.axis.reduce((n + T.int64(127)) // T.int64(128) * T.int64(128), k_0 * T.int64(128) + k_1_0 * T.int64(16) + k_1_1 * T.int64(8) + k_1_2) + T.reads(matmul_local[v_i0, v_i1, v_i2, v_i3], rxplaceholder_pad_shared[v_i0, v_i1, v_i2, v_k], rxplaceholder_1_pad_shared[v_i0, v_i1, v_k, v_i3]) + T.writes(matmul_local[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + matmul_local[v_i0, v_i1, v_i2, v_i3] = matmul_local[v_i0, v_i1, v_i2, v_i3] + rxplaceholder_pad_shared[v_i0, v_i1, v_i2, v_k] * rxplaceholder_1_pad_shared[v_i0, v_i1, v_k, v_i3] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(2), T.int64(1), T.int64(1)): + with T.block("matmul_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(32), i0_2_i1_2_i2_2_i3_2_fused // T.int64(8) * T.int64(2) + ax1) + v2 = T.axis.spatial(T.int64(1), ax2) + v3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_0_i3_0_fused * T.int64(8) + i0_2_i1_2_i2_2_i3_2_fused % T.int64(8) + ax3) + T.reads(matmul_local[v0, v1, v2, v3]) + T.writes(matmul[v0, v1, v2, v3]) + matmul[v0, v1, v2, v3] = matmul_local[v0, v1, v2, v3] + + +@T.prim_func +def matmul2_before(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), var_matmul: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + inp0 = T.match_buffer(var_inp0, (T.int64(1), n, T.int64(4096))) + matmul = T.match_buffer(var_matmul, (T.int64(1), n, T.int64(4096))) + # with T.block("root"): + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(inp0[v_i0, v_i1, v_k], inp1[v_k, v_i2]) + T.writes(matmul[v_i0, v_i1, v_i2]) + with T.init(): + matmul[v_i0, v_i1, v_i2] = T.float32(0) + matmul[v_i0, v_i1, v_i2] = matmul[v_i0, v_i1, v_i2] + inp0[v_i0, v_i1, v_k] * inp1[v_k, v_i2] + +def matmul2_sch_func(): + sch = tvm.tir.Schedule(matmul2_before) + b0 = sch.get_block("matmul") + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = 
sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + b0 = sch.get_block(name="matmul", func_name="main") + b1 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l2, l3, l4, l5 = sch.get_loops(block=b0) + v6, v7, v8, v9, v10 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l11, l12, l13, l14, l15 = sch.split(loop=l2, factors=[v6, v7, v8, v9, v10], preserve_unit_iters=True) + v16, v17, v18, v19, v20 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[2, 2, 2, 4, 1]) + l21, l22, l23, l24, l25 = sch.split(loop=l3, factors=[v16, v17, v18, v19, v20], preserve_unit_iters=True) + v26, v27, v28, v29, v30 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[128, 2, 16, 1, 1]) + l31, l32, l33, l34, l35 = sch.split(loop=l4, factors=[v26, v27, v28, v29, v30], preserve_unit_iters=True) + v36, v37, v38 = sch.sample_perfect_tile(loop=l5, n=3, max_innermost_factor=64, decision=[512, 4, 2]) + l39, l40, l41 = sch.split(loop=l5, factors=[v36, v37, v38], preserve_unit_iters=True) + sch.reorder(l11, l21, l31, l12, l22, l32, l13, l23, l33, l39, l40, l14, l24, l34, l41, l15, l25, l35) + l42 = sch.fuse(l11, l21, l31, preserve_unit_iters=True) + sch.bind(loop=l42, thread_axis="blockIdx.x") + l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) + sch.bind(loop=l43, thread_axis="vthread.x") + l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) + sch.bind(loop=l44, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b45 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b45, loop=l44, preserve_unit_loops=True, index=-1) + b46 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b46, loop=l39, preserve_unit_loops=True, index=-1) + _, l47, l48, l49, l50, l51, l52, l53 = sch.get_loops(block=b46) + l54 = sch.fuse(l51, l52, l53, preserve_unit_iters=True) + v55 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b46, ann_key="meta_schedule.cooperative_fetch", ann_val=v55) + b56 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b56, loop=l39, preserve_unit_loops=True, index=-1) + _, l57, l58, l59, l60, l61, l62 = sch.get_loops(block=b56) + l63 = sch.fuse(l61, l62, preserve_unit_iters=True) + v64 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b56, ann_key="meta_schedule.cooperative_fetch", ann_val=v64) + v65 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=4) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v65) + sch.enter_postproc() + sch.unannotate(block_or_loop=b46, ann_key="meta_schedule.cooperative_fetch") + _, l66, l67, l68, l69, l70 = sch.get_loops(block=b46) + l71, l72, l73 = sch.split(loop=l70, factors=[None, 32, 2], preserve_unit_iters=True) + sch.vectorize(loop=l73) + sch.bind(loop=l72, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b56, 
ann_key="meta_schedule.cooperative_fetch") + _, l74, l75, l76, l77, l78 = sch.get_loops(block=b56) + l79, l80, l81 = sch.split(loop=l78, factors=[None, 32, 2], preserve_unit_iters=True) + sch.vectorize(loop=l81) + sch.bind(loop=l80, thread_axis="threadIdx.x") + b82 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b82, ann_key="meta_schedule.unroll_explicit") + _, b83, b84, b85, b86, _ = sch.get_child_blocks(b82) + _, l87, l88, l89, l90, l91, l92, l93 = sch.get_loops(block=b83) + sch.annotate(block_or_loop=l87, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l87, ann_key="pragma_unroll_explicit", ann_val=1) + _, l94, l95, l96, l97, l98, l99, l100 = sch.get_loops(block=b84) + sch.annotate(block_or_loop=l94, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l94, ann_key="pragma_unroll_explicit", ann_val=1) + _, l101, l102, l103, l104, l105, l106, l107, l108, l109, l110, l111, l112 = sch.get_loops(block=b85) + sch.annotate(block_or_loop=l101, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l101, ann_key="pragma_unroll_explicit", ann_val=1) + _, l113, l114, l115, l116, l117, l118 = sch.get_loops(block=b86) + sch.annotate(block_or_loop=l113, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l113, ann_key="pragma_unroll_explicit", ann_val=1) + b119 = sch.get_block(name="matmul", func_name="main") + _, l120, l121, l122, l123, l124, l125, l126, l127, l128, l129, l130, l131 = sch.get_loops(block=b119) + b132 = sch.decompose_reduction(block=b119, loop=l123) + b1 = sch.get_block("inp0_pad") + sch.compute_inline(b1) + b2 = sch.get_block("matmul_pad") + sch.reverse_compute_inline(b2) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func +def matmul5_before(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_matmul: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, n)) + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(1), T.int64(32), n, T.int64(128))) + matmul_1 = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(128))) + # with T.block("root"): + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(128), n): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(rxplaceholder[T.int64(0), v_i1, v_i2, v_k], rxplaceholder_1[T.int64(0), v_i1, v_k, v_i3]) + T.writes(matmul_1[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + matmul_1[v_i0, v_i1, v_i2, v_i3] = T.float32(0) + matmul_1[v_i0, v_i1, v_i2, v_i3] = matmul_1[v_i0, v_i1, v_i2, v_i3] + rxplaceholder[T.int64(0), v_i1, v_i2, v_k] * rxplaceholder_1[T.int64(0), v_i1, v_k, v_i3] + + +@T.prim_func +def matmul5_after(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_matmul: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, n)) + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(1), T.int64(32), n, T.int64(128))) + matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(128))) + # with T.block("root"): + C_pad = T.alloc_buffer((T.int64(1), T.int64(32), (n + T.int64(128) - T.int64(1)) // T.int64(128) * T.int64(128), T.int64(128))) + C_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), (n + T.int64(127)) // 
T.int64(128) * T.int64(128), T.int64(128)), scope="local") + A_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), (n + T.int64(127)) // T.int64(128) * T.int64(128), (n + T.int64(127)) // T.int64(128) * T.int64(128)), scope="shared") + B_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), (n + T.int64(127)) // T.int64(128) * T.int64(128), T.int64(128)), scope="shared") + for i2_0 in range((n + T.int64(127)) // T.int64(128)): + for i0_0_i1_0_i2_1_0_i3_0_fused in T.thread_binding(T.int64(256), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 1024, "pragma_unroll_explicit": 1}): + for i0_1_i1_1_i2_1_1_i3_1_fused in T.thread_binding(T.int64(4), thread="vthread.x"): + for i0_2_i1_2_i2_1_2_i3_2_fused in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for i0_3_init, i1_3_init, i2_1_3_init, i3_3_init, i0_4_init, i1_4_init, i2_1_4_init, i3_4_init in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(1), T.int64(1), T.int64(4), T.int64(1)): + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), i0_3_init + i0_4_init) + v_i1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_0_fused // T.int64(8) + i1_3_init + i1_4_init) + v_i2 = T.axis.spatial((n + T.int64(128) - T.int64(1)) // T.int64(128) * T.int64(128), i2_0 * T.int64(128) + i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(8) // T.int64(2) * T.int64(32) + i0_1_i1_1_i2_1_1_i3_1_fused // T.int64(2) * T.int64(16) + i0_2_i1_2_i2_1_2_i3_2_fused // T.int64(16) * T.int64(4) + i2_1_3_init * T.int64(4) + i2_1_4_init) + v_i3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(2) * T.int64(64) + i0_1_i1_1_i2_1_1_i3_1_fused % T.int64(2) * T.int64(32) + i0_2_i1_2_i2_1_2_i3_2_fused % T.int64(16) * T.int64(2) + i3_3_init + i3_4_init) + T.reads() + T.writes(C_pad_local[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + C_pad_local[v_i0, v_i1, v_i2, v_i3] = T.float32(0) + for k_0, k_1_0 in T.grid((n + T.int64(127)) // T.int64(128), T.int64(16)): + for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(4)): + with T.block("A_pad_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_0_fused // T.int64(8)) + v2 = T.axis.spatial((n + T.int64(127)) // T.int64(128) * T.int64(128), i2_0 * T.int64(128) + i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(8) // T.int64(2) * T.int64(32) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(256) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) // T.int64(8)) + v3 = T.axis.spatial((n + T.int64(127)) // T.int64(128) * T.int64(128), k_0 * T.int64(128) + k_1_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(256) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) % T.int64(8)) + T.reads(rxplaceholder[v0, v1, v2, v3]) + T.writes(A_pad_shared[v0, v1, v2, v3]) + A_pad_shared[v0, v1, v2, v3] = T.if_then_else(v2 < n and v3 < n, rxplaceholder[v0, v1, v2, v3], T.float32(0)) + for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(4)): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)): + with T.block("B_pad_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_0_fused // T.int64(8)) + v2 = T.axis.spatial((n + T.int64(127)) // T.int64(128) * T.int64(128), k_0 * T.int64(128) + k_1_0 * T.int64(8) + 
(ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(64)) + v3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(2) * T.int64(64) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(64)) + T.reads(rxplaceholder_1[v0, v1, v2, v3]) + T.writes(B_pad_shared[v0, v1, v2, v3]) + B_pad_shared[v0, v1, v2, v3] = T.if_then_else(v2 < n, rxplaceholder_1[v0, v1, v2, v3], T.float32(0)) + for k_1_1, i0_3, i1_3, i2_1_3, i3_3, k_1_2, i0_4, i1_4, i2_1_4, i3_4 in T.grid(T.int64(2), T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(4), T.int64(1)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), i0_3 + i0_4) + v_i1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_0_fused // T.int64(8) + i1_3 + i1_4) + v_i2 = T.axis.spatial((n + T.int64(128) - T.int64(1)) // T.int64(128) * T.int64(128), i2_0 * T.int64(128) + i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(8) // T.int64(2) * T.int64(32) + i0_1_i1_1_i2_1_1_i3_1_fused // T.int64(2) * T.int64(16) + i0_2_i1_2_i2_1_2_i3_2_fused // T.int64(16) * T.int64(4) + i2_1_3 * T.int64(4) + i2_1_4) + v_i3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(2) * T.int64(64) + i0_1_i1_1_i2_1_1_i3_1_fused % T.int64(2) * T.int64(32) + i0_2_i1_2_i2_1_2_i3_2_fused % T.int64(16) * T.int64(2) + i3_3 + i3_4) + v_k = T.axis.reduce((n + T.int64(128) - T.int64(1)) // T.int64(128) * T.int64(128), k_0 * T.int64(128) + k_1_0 * T.int64(8) + k_1_1 * T.int64(4) + k_1_2) + T.reads(C_pad_local[v_i0, v_i1, v_i2, v_i3], A_pad_shared[T.int64(0), v_i1, v_i2, v_k], B_pad_shared[T.int64(0), v_i1, v_k, v_i3]) + T.writes(C_pad_local[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + C_pad_local[v_i0, v_i1, v_i2, v_i3] = C_pad_local[v_i0, v_i1, v_i2, v_i3] + A_pad_shared[T.int64(0), v_i1, v_i2, v_k] * B_pad_shared[T.int64(0), v_i1, v_k, v_i3] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(4), T.int64(2)): + with T.block("C_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_0_fused // T.int64(8) + ax1) + v2 = T.axis.spatial((n + T.int64(127)) // T.int64(128) * T.int64(128), i2_0 * T.int64(128) + i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(8) // T.int64(2) * T.int64(32) + i0_1_i1_1_i2_1_1_i3_1_fused // T.int64(2) * T.int64(16) + i0_2_i1_2_i2_1_2_i3_2_fused // T.int64(16) * T.int64(4) + ax2) + v3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(2) * T.int64(64) + i0_1_i1_1_i2_1_1_i3_1_fused % T.int64(2) * T.int64(32) + i0_2_i1_2_i2_1_2_i3_2_fused % T.int64(16) * T.int64(2) + ax3) + T.reads(C_pad_local[v0, v1, v2, v3]) + T.writes(C_pad[v0, v1, v2, v3]) + C_pad[v0, v1, v2, v3] = C_pad_local[v0, v1, v2, v3] + for i0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): + for i1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for i2, i3 in T.grid(n, T.int64(128)): + with T.block("C_pad"): + vi0, vi1, vi2, vi3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(C_pad[vi0, vi1, vi2, vi3]) + T.writes(matmul[vi0, vi1, vi2, vi3]) + matmul[vi0, vi1, vi2, vi3] = C_pad[vi0, vi1, vi2, vi3] + +@T.prim_func +def matmul5_with_m_before(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_matmul: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n, m = T.int64(), T.int64() + A = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, m)) + B = 
T.match_buffer(var_rxplaceholder_1, (T.int64(1), T.int64(32), m, T.int64(128))) + matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(128))) + # with T.block("root"): + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(128), m): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) + T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + matmul[v_i0, v_i1, v_i2, v_i3] = T.float32(0) + matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] + + +@T.prim_func +def matmul5_with_m_after(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_matmul: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + m = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, m)) + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(1), T.int64(32), m, T.int64(128))) + matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(128))) + # with T.block("root"): + C_pad = T.alloc_buffer((T.int64(1), T.int64(32), (n + T.int64(128) - T.int64(1)) // T.int64(128) * T.int64(128), T.int64(128))) + C_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), (n + T.int64(127)) // T.int64(128) * T.int64(128), T.int64(128)), scope="local") + A_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), (n + T.int64(127)) // T.int64(128) * T.int64(128), (m + T.int64(127)) // T.int64(128) * T.int64(128)), scope="shared") + B_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), (m + T.int64(127)) // T.int64(128) * T.int64(128), T.int64(128)), scope="shared") + for i2_0 in range((n + T.int64(127)) // T.int64(128)): + for i0_0_i1_0_i2_1_0_i3_0_fused in T.thread_binding(T.int64(256), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 1024, "pragma_unroll_explicit": 1}): + for i0_1_i1_1_i2_1_1_i3_1_fused in T.thread_binding(T.int64(4), thread="vthread.x"): + for i0_2_i1_2_i2_1_2_i3_2_fused in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for i0_3_init, i1_3_init, i2_1_3_init, i3_3_init, i0_4_init, i1_4_init, i2_1_4_init, i3_4_init in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(1), T.int64(1), T.int64(4), T.int64(1)): + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), i0_3_init + i0_4_init) + v_i1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_0_fused // T.int64(8) + i1_3_init + i1_4_init) + v_i2 = T.axis.spatial((n + T.int64(128) - T.int64(1)) // T.int64(128) * T.int64(128), i2_0 * T.int64(128) + i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(8) // T.int64(2) * T.int64(32) + i0_1_i1_1_i2_1_1_i3_1_fused // T.int64(2) * T.int64(16) + i0_2_i1_2_i2_1_2_i3_2_fused // T.int64(16) * T.int64(4) + i2_1_3_init * T.int64(4) + i2_1_4_init) + v_i3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(2) * T.int64(64) + i0_1_i1_1_i2_1_1_i3_1_fused % T.int64(2) * T.int64(32) + i0_2_i1_2_i2_1_2_i3_2_fused % T.int64(16) * T.int64(2) + i3_3_init + i3_4_init) + T.reads() + T.writes(C_pad_local[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + C_pad_local[v_i0, v_i1, v_i2, v_i3] = T.float32(0) + for k_0, k_1_0 in T.grid((m + T.int64(127)) // T.int64(128), T.int64(16)): + for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for 
ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(4)): + with T.block("A_pad_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_0_fused // T.int64(8)) + v2 = T.axis.spatial((n + T.int64(127)) // T.int64(128) * T.int64(128), i2_0 * T.int64(128) + i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(8) // T.int64(2) * T.int64(32) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(256) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) // T.int64(8)) + v3 = T.axis.spatial((m + T.int64(127)) // T.int64(128) * T.int64(128), k_0 * T.int64(128) + k_1_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(256) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) % T.int64(8)) + T.reads(rxplaceholder[v0, v1, v2, v3]) + T.writes(A_pad_shared[v0, v1, v2, v3]) + A_pad_shared[v0, v1, v2, v3] = T.if_then_else(v2 < n and v3 < m, rxplaceholder[v0, v1, v2, v3], T.float32(0)) + for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(4)): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)): + with T.block("B_pad_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_0_fused // T.int64(8)) + v2 = T.axis.spatial((m + T.int64(127)) // T.int64(128) * T.int64(128), k_0 * T.int64(128) + k_1_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(64)) + v3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(2) * T.int64(64) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(64)) + T.reads(rxplaceholder_1[v0, v1, v2, v3]) + T.writes(B_pad_shared[v0, v1, v2, v3]) + B_pad_shared[v0, v1, v2, v3] = T.if_then_else(v2 < m, rxplaceholder_1[v0, v1, v2, v3], T.float32(0)) + for k_1_1, i0_3, i1_3, i2_1_3, i3_3, k_1_2, i0_4, i1_4, i2_1_4, i3_4 in T.grid(T.int64(2), T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(4), T.int64(1)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), i0_3 + i0_4) + v_i1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_0_fused // T.int64(8) + i1_3 + i1_4) + v_i2 = T.axis.spatial((n + T.int64(128) - T.int64(1)) // T.int64(128) * T.int64(128), i2_0 * T.int64(128) + i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(8) // T.int64(2) * T.int64(32) + i0_1_i1_1_i2_1_1_i3_1_fused // T.int64(2) * T.int64(16) + i0_2_i1_2_i2_1_2_i3_2_fused // T.int64(16) * T.int64(4) + i2_1_3 * T.int64(4) + i2_1_4) + v_i3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(2) * T.int64(64) + i0_1_i1_1_i2_1_1_i3_1_fused % T.int64(2) * T.int64(32) + i0_2_i1_2_i2_1_2_i3_2_fused % T.int64(16) * T.int64(2) + i3_3 + i3_4) + v_k = T.axis.reduce((m + T.int64(128) - T.int64(1)) // T.int64(128) * T.int64(128), k_0 * T.int64(128) + k_1_0 * T.int64(8) + k_1_1 * T.int64(4) + k_1_2) + T.reads(C_pad_local[v_i0, v_i1, v_i2, v_i3], A_pad_shared[T.int64(0), v_i1, v_i2, v_k], B_pad_shared[T.int64(0), v_i1, v_k, v_i3]) + T.writes(C_pad_local[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + C_pad_local[v_i0, v_i1, v_i2, v_i3] = C_pad_local[v_i0, v_i1, v_i2, v_i3] + A_pad_shared[T.int64(0), v_i1, v_i2, v_k] * B_pad_shared[T.int64(0), v_i1, v_k, v_i3] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(4), T.int64(2)): + with T.block("C_pad_local"): + v0 = 
T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_0_fused // T.int64(8) + ax1) + v2 = T.axis.spatial((n + T.int64(127)) // T.int64(128) * T.int64(128), i2_0 * T.int64(128) + i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(8) // T.int64(2) * T.int64(32) + i0_1_i1_1_i2_1_1_i3_1_fused // T.int64(2) * T.int64(16) + i0_2_i1_2_i2_1_2_i3_2_fused // T.int64(16) * T.int64(4) + ax2) + v3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(2) * T.int64(64) + i0_1_i1_1_i2_1_1_i3_1_fused % T.int64(2) * T.int64(32) + i0_2_i1_2_i2_1_2_i3_2_fused % T.int64(16) * T.int64(2) + ax3) + T.reads(C_pad_local[v0, v1, v2, v3]) + T.writes(C_pad[v0, v1, v2, v3]) + C_pad[v0, v1, v2, v3] = C_pad_local[v0, v1, v2, v3] + for i0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): + for i1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for i2, i3 in T.grid(n, T.int64(128)): + with T.block("C_pad"): + vi0, vi1, vi2, vi3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(C_pad[vi0, vi1, vi2, vi3]) + T.writes(matmul[vi0, vi1, vi2, vi3]) + matmul[vi0, vi1, vi2, vi3] = C_pad[vi0, vi1, vi2, vi3] + + +@T.prim_func +def NT_matmul_before(var_rxplaceholder: T.handle, rxplaceholder: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), var_NT_matmul: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + rxplaceholder_1 = T.match_buffer(var_rxplaceholder, (T.int64(1), n, T.int64(4096))) + NT_matmul = T.match_buffer(var_NT_matmul, (T.int64(1), n, T.int64(4096))) + # with T.block("root"): + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(rxplaceholder_1[v_i0, v_i1, v_k], rxplaceholder[v_i2, v_k]) + T.writes(NT_matmul[v_i0, v_i1, v_i2]) + with T.init(): + NT_matmul[v_i0, v_i1, v_i2] = T.float32(0) + NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + rxplaceholder_1[v_i0, v_i1, v_k] * rxplaceholder[v_i2, v_k] + + +@T.prim_func +def NT_matmul_after(var_rxplaceholder: T.handle, rxplaceholder: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), var_NT_matmul: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + rxplaceholder_1 = T.match_buffer(var_rxplaceholder, (T.int64(1), n, T.int64(4096))) + NT_matmul_1 = T.match_buffer(var_NT_matmul, (T.int64(1), n, T.int64(4096))) + # with T.block("root"): + for i1_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.y"): + with T.block("NT_matmul_o"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i1_0) + T.reads(rxplaceholder_1[T.Add(v_i0, T.int64(0)), v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)], rxplaceholder[T.int64(0):T.int64(4096), T.int64(0):T.int64(4096)]) + T.writes(NT_matmul_1[v_i0, v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)]) + C_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="local") + A_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="shared") + rxplaceholder_shared = T.alloc_buffer((T.int64(4096), T.int64(4096)), scope="shared") + for i0_0_i1_1_0_i2_0_fused in T.thread_binding(T.int64(128), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i0_1_i1_1_1_i2_1_fused in T.thread_binding(T.int64(1), thread="vthread.x"): + for i0_2_i1_1_2_i2_2_fused in 
T.thread_binding(T.int64(64), thread="threadIdx.x"): + for i1_1_3_init, i2_3_init, i1_1_4_init, i2_4_init in T.grid(T.int64(1), T.int64(2), T.int64(4), T.int64(2)): + with T.block("NT_matmul_init"): + v_i1_i = T.axis.spatial(T.int64(32), i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(4) + i1_1_3_init * T.int64(4) + i1_1_4_init) + v_i2_i = T.axis.spatial(T.int64(4096), i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + i2_3_init * T.int64(2) + i2_4_init) + T.reads() + T.writes(C_pad_local[T.int64(0), v_i1_i, v_i2_i]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + C_pad_local[T.int64(0), v_i1_i, v_i2_i] = T.float32(0) + for k_0 in range(T.int64(128)): + for ax0_ax1_ax2_fused_0 in range(T.int64(8)): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0_ax1_ax2_fused_2 in T.vectorized(T.int64(2)): + with T.block("A_pad_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(128) + ax0_ax1_ax2_fused_1 * T.int64(2) + ax0_ax1_ax2_fused_2) // T.int64(32)) + v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(128) + ax0_ax1_ax2_fused_1 * T.int64(2) + ax0_ax1_ax2_fused_2) % T.int64(32)) + T.reads(rxplaceholder_1[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2]) + T.writes(A_pad_shared[v0, v1, v2]) + A_pad_shared[v0, v1, v2] = T.if_then_else(v_i1_o * T.int64(32) + v1 < n, rxplaceholder_1[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2], T.float32(0)) + for ax0_ax1_fused_0 in range(T.int64(8)): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(T.int64(2)): + with T.block("rxplaceholder_shared"): + v0 = T.axis.spatial(T.int64(4096), i0_0_i1_1_0_i2_0_fused * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) // T.int64(32)) + v1 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) % T.int64(32)) + T.reads(rxplaceholder[v0, v1]) + T.writes(rxplaceholder_shared[v0, v1]) + rxplaceholder_shared[v0, v1] = rxplaceholder[v0, v1] + for k_1, i0_3, i1_1_3, i2_3, k_2, i0_4, i1_1_4, i2_4 in T.grid(T.int64(8), T.int64(1), T.int64(1), T.int64(2), T.int64(4), T.int64(1), T.int64(4), T.int64(2)): + with T.block("NT_matmul_update"): + v_i1_i = T.axis.spatial(T.int64(32), i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(4) + i1_1_3 * T.int64(4) + i1_1_4) + v_i2_i = T.axis.spatial(T.int64(4096), i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + i2_3 * T.int64(2) + i2_4) + v_k_i = T.axis.reduce(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(4) + k_2) + T.reads(C_pad_local[T.int64(0), v_i1_i, v_i2_i], A_pad_shared[T.int64(0), v_i1_i, v_k_i], rxplaceholder_shared[v_i2_i, v_k_i]) + T.writes(C_pad_local[T.int64(0), v_i1_i, v_i2_i]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + C_pad_local[T.int64(0), v_i1_i, v_i2_i] = C_pad_local[T.int64(0), v_i1_i, v_i2_i] + A_pad_shared[T.int64(0), v_i1_i, v_k_i] * rxplaceholder_shared[v_i2_i, v_k_i] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(4), T.int64(4)): + with T.block("C_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(32), 
i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(4) + ax1) + v2 = T.axis.spatial(T.int64(4096), i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + ax2) + T.reads(C_pad_local[v0, v1, v2]) + T.writes(NT_matmul_1[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2]) + # if T.int64(0) <= v_i0 and v_i0 < T.int64(1) and T.int64(0) <= v_i1_o * T.int64(32) + v1 and v_i1_o * T.int64(32) + v1 < n: + if v_i1_o * T.int64(32) + v1 < n: + NT_matmul_1[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2] = C_pad_local[v0, v1, v2] + + +@T.prim_func +def NT_matmul4_before(var_rxplaceholder: T.handle, rxplaceholder: T.Buffer((T.int64(32000), T.int64(4096)), "float32"), var_NT_matmul: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + rxplaceholder_1 = T.match_buffer(var_rxplaceholder, (T.int64(1), n, T.int64(4096))) + NT_matmul = T.match_buffer(var_NT_matmul, (T.int64(1), n, T.int64(32000))) + # with T.block("root"): + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(32000), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(rxplaceholder_1[v_i0, v_i1, v_k], rxplaceholder[v_i2, v_k]) + T.writes(NT_matmul[v_i0, v_i1, v_i2]) + with T.init(): + NT_matmul[v_i0, v_i1, v_i2] = T.float32(0) + NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + rxplaceholder_1[v_i0, v_i1, v_k] * rxplaceholder[v_i2, v_k] + + +def NT_matmul4_sch_func(): + sch = tvm.tir.Schedule(NT_matmul4_before) + b0 = sch.get_block("NT_matmul") + sch.pad_einsum(b0, [1, 32, 256, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + b0 = sch.get_block(name="NT_matmul", func_name="main") + b1 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l2, l3, l4, l5 = sch.get_loops(block=b0) + v6, v7, v8, v9, v10 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l11, l12, l13, l14, l15 = sch.split(loop=l2, factors=[v6, v7, v8, v9, v10], preserve_unit_iters=True) + v16, v17, v18, v19, v20 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 8, 4, 1]) + l21, l22, l23, l24, l25 = sch.split(loop=l3, factors=[v16, v17, v18, v19, v20], preserve_unit_iters=True) + v26, v27, v28, v29, v30 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[668, 1, 8, 1, 6]) + l31, l32, l33, l34, l35 = sch.split(loop=l4, factors=[v26, v27, v28, v29, v30], preserve_unit_iters=True) + v36, v37, v38 = sch.sample_perfect_tile(loop=l5, n=3, max_innermost_factor=64, decision=[128, 4, 8]) + l39, l40, l41 = sch.split(loop=l5, factors=[v36, v37, v38], preserve_unit_iters=True) + sch.reorder(l11, l21, l31, l12, l22, l32, l13, l23, l33, l39, l40, l14, l24, l34, l41, l15, l25, l35) + l42 = sch.fuse(l11, l21, l31, preserve_unit_iters=True) + sch.bind(loop=l42, thread_axis="blockIdx.x") + l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) + sch.bind(loop=l43, thread_axis="vthread.x") + l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) + sch.bind(loop=l44, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b45 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b45, loop=l44, preserve_unit_loops=True, 
index=-1) + b46 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b46, loop=l39, preserve_unit_loops=True, index=-1) + _, l47, l48, l49, l50, l51, l52, l53 = sch.get_loops(block=b46) + l54 = sch.fuse(l51, l52, l53, preserve_unit_iters=True) + v55 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=3) + sch.annotate(block_or_loop=b46, ann_key="meta_schedule.cooperative_fetch", ann_val=v55) + b56 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b56, loop=l39, preserve_unit_loops=True, index=-1) + _, l57, l58, l59, l60, l61, l62 = sch.get_loops(block=b56) + l63 = sch.fuse(l61, l62, preserve_unit_iters=True) + v64 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=3) + sch.annotate(block_or_loop=b56, ann_key="meta_schedule.cooperative_fetch", ann_val=v64) + v65 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=3) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v65) + sch.enter_postproc() + sch.unannotate(block_or_loop=b46, ann_key="meta_schedule.cooperative_fetch") + _, l66, l67, l68, l69, l70 = sch.get_loops(block=b46) + l71, l72, l73 = sch.split(loop=l70, factors=[None, 64, 4], preserve_unit_iters=True) + sch.vectorize(loop=l73) + sch.bind(loop=l72, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b56, ann_key="meta_schedule.cooperative_fetch") + _, l74, l75, l76, l77, l78 = sch.get_loops(block=b56) + l79, l80, l81 = sch.split(loop=l78, factors=[None, 64, 4], preserve_unit_iters=True) + sch.vectorize(loop=l81) + sch.bind(loop=l80, thread_axis="threadIdx.x") + b82 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b82, ann_key="meta_schedule.unroll_explicit") + _, b83, b84, b85, b86, _ = sch.get_child_blocks(b82) + _, l87, l88, l89, l90, l91, l92, l93 = sch.get_loops(block=b83) + sch.annotate(block_or_loop=l87, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l87, ann_key="pragma_unroll_explicit", ann_val=1) + _, l94, l95, l96, l97, l98, l99, l100 = sch.get_loops(block=b84) + sch.annotate(block_or_loop=l94, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l94, ann_key="pragma_unroll_explicit", ann_val=1) + _, l101, l102, l103, l104, l105, l106, l107, l108, l109, l110, l111, l112 = sch.get_loops(block=b85) + sch.annotate(block_or_loop=l101, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l101, ann_key="pragma_unroll_explicit", ann_val=1) + _, l113, l114, l115, l116, l117, l118 = sch.get_loops(block=b86) + sch.annotate(block_or_loop=l113, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l113, ann_key="pragma_unroll_explicit", ann_val=1) + b119 = sch.get_block(name="NT_matmul", func_name="main") + _, l120, l121, l122, l123, l124, l125, l126, l127, l128, l129, l130, l131 = sch.get_loops(block=b119) + b132 = sch.decompose_reduction(block=b119, loop=l123) + b1 = sch.get_block("rxplaceholder_1_pad") + sch.compute_inline(b1) + b3 = sch.get_block("NT_matmul_pad") + sch.reverse_compute_inline(b3) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func +def NT_matmul9_before(rxplaceholder: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), 
"float32"), rxplaceholder_1: T.Buffer((T.int64(32000), T.int64(4096)), "float32"), NT_matmul: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(rxplaceholder[v_i0, v_i1, v_k], rxplaceholder_1[v_i2, v_k]) + T.writes(NT_matmul[v_i0, v_i1, v_i2]) + with T.init(): + NT_matmul[v_i0, v_i1, v_i2] = T.float32(0) + NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + rxplaceholder[v_i0, v_i1, v_k] * rxplaceholder_1[v_i2, v_k] + + +def NT_matmul9_sch_func(): + sch = tvm.tir.Schedule(NT_matmul9_before) + b0 = sch.get_block(name="NT_matmul", func_name="main") + b1 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + l2, l3, l4, l5 = sch.get_loops(block=b0) + v6, v7, v8, v9, v10 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l11, l12, l13, l14, l15 = sch.split(loop=l2, factors=[v6, v7, v8, v9, v10], preserve_unit_iters=True) + v16, v17, v18, v19, v20 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l21, l22, l23, l24, l25 = sch.split(loop=l3, factors=[v16, v17, v18, v19, v20], preserve_unit_iters=True) + v26, v27, v28, v29, v30 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[668, 1, 48, 1, 1]) + l31, l32, l33, l34, l35 = sch.split(loop=l4, factors=[v26, v27, v28, v29, v30], preserve_unit_iters=True) + v36, v37, v38 = sch.sample_perfect_tile(loop=l5, n=3, max_innermost_factor=64, decision=[64, 64, 1]) + l39, l40, l41 = sch.split(loop=l5, factors=[v36, v37, v38], preserve_unit_iters=True) + sch.reorder(l11, l21, l31, l12, l22, l32, l13, l23, l33, l39, l40, l14, l24, l34, l41, l15, l25, l35) + l42 = sch.fuse(l11, l21, l31, preserve_unit_iters=True) + sch.bind(loop=l42, thread_axis="blockIdx.x") + l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) + sch.bind(loop=l43, thread_axis="vthread.x") + l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) + sch.bind(loop=l44, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b45 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b45, loop=l44, preserve_unit_loops=True, index=-1) + b46 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b46, loop=l39, preserve_unit_loops=True, index=-1) + l47, l48, l49, l50, l51, l52, l53 = sch.get_loops(block=b46) + l54 = sch.fuse(l51, l52, l53, preserve_unit_iters=True) + v55 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b46, ann_key="meta_schedule.cooperative_fetch", ann_val=v55) + b56 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b56, loop=l39, preserve_unit_loops=True, index=-1) + l57, l58, l59, l60, l61, l62 = sch.get_loops(block=b56) + l63 = sch.fuse(l61, l62, preserve_unit_iters=True) + v64 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + 
sch.annotate(block_or_loop=b56, ann_key="meta_schedule.cooperative_fetch", ann_val=v64) + v65 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=4) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v65) + sch.enter_postproc() + sch.unannotate(block_or_loop=b46, ann_key="meta_schedule.cooperative_fetch") + l66, l67, l68, l69, l70 = sch.get_loops(block=b46) + l71, l72, l73 = sch.split(loop=l70, factors=[None, 48, 2], preserve_unit_iters=True) + sch.vectorize(loop=l73) + sch.bind(loop=l72, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b56, ann_key="meta_schedule.cooperative_fetch") + l74, l75, l76, l77, l78 = sch.get_loops(block=b56) + l79, l80, l81 = sch.split(loop=l78, factors=[None, 48, 2], preserve_unit_iters=True) + sch.vectorize(loop=l81) + sch.bind(loop=l80, thread_axis="threadIdx.x") + b82 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b82, ann_key="meta_schedule.unroll_explicit") + b83, b84, b85, b86 = sch.get_child_blocks(b82) + l87, l88, l89, l90, l91, l92, l93 = sch.get_loops(block=b83) + sch.annotate(block_or_loop=l87, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l87, ann_key="pragma_unroll_explicit", ann_val=1) + l94, l95, l96, l97, l98, l99, l100 = sch.get_loops(block=b84) + sch.annotate(block_or_loop=l94, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l94, ann_key="pragma_unroll_explicit", ann_val=1) + l101, l102, l103, l104, l105, l106, l107, l108, l109, l110, l111, l112 = sch.get_loops(block=b85) + sch.annotate(block_or_loop=l101, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l101, ann_key="pragma_unroll_explicit", ann_val=1) + l113, l114, l115, l116, l117, l118 = sch.get_loops(block=b86) + sch.annotate(block_or_loop=l113, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l113, ann_key="pragma_unroll_explicit", ann_val=1) + b119 = sch.get_block(name="NT_matmul", func_name="main") + l120, l121, l122, l123, l124, l125, l126, l127, l128, l129, l130, l131 = sch.get_loops(block=b119) + b132 = sch.decompose_reduction(block=b119, loop=l123) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + + +@T.prim_func +def fused_matmul1_add1(p_lv39: T.handle, lv40: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), p_lv2: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv39 = T.match_buffer(p_lv39, (T.int64(1), n, T.int64(4096))) + lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(4096))) + var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096))) + # with T.block("root"): + var_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096))) + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv39[v_i0, v_i1, v_k], lv40[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv39[v_i0, v_i1, v_k] * lv40[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + 
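+                # residual add: accumulate the matmul result onto lv2 element-wise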
T.reads(lv2[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = lv2[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + +def fused_matmul1_add1_sch_func(): + sch = tvm.tir.Schedule(fused_matmul1_add1) + b0 = sch.get_block("matmul") + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + b0 = sch.get_block(name="matmul", func_name="main") + b1 = sch.get_block(name="T_add", func_name="main") + b2 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l3, l4, l5, l6 = sch.get_loops(block=b0) + v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l12, l13, l14, l15, l16 = sch.split(loop=l3, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) + v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 8, 4, 1, 1]) + l22, l23, l24, l25, l26 = sch.split(loop=l4, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) + v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[128, 2, 16, 1, 1]) + l32, l33, l34, l35, l36 = sch.split(loop=l5, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) + v37, v38, v39 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[512, 4, 2]) + l40, l41, l42 = sch.split(loop=l6, factors=[v37, v38, v39], preserve_unit_iters=True) + sch.reorder(l12, l22, l32, l13, l23, l33, l14, l24, l34, l40, l41, l15, l25, l35, l42, l16, l26, l36) + l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) + sch.bind(loop=l43, thread_axis="blockIdx.x") + l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) + sch.bind(loop=l44, thread_axis="vthread.x") + l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) + sch.bind(loop=l45, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b46 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b46, loop=l45, preserve_unit_loops=True, index=-1) + b47 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b47, loop=l40, preserve_unit_loops=True, index=-1) + _, l48, l49, l50, l51, l52, l53, l54 = sch.get_loops(block=b47) + l55 = sch.fuse(l52, l53, l54, preserve_unit_iters=True) + v56 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch", ann_val=v56) + b57 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b57, loop=l40, preserve_unit_loops=True, index=-1) + _, l58, l59, l60, l61, l62, l63 = sch.get_loops(block=b57) + l64 = sch.fuse(l62, l63, preserve_unit_iters=True) + v65 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v65) + sch.reverse_compute_inline(block=b1) + v66 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], 
probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=4) + sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v66) + sch.enter_postproc() + sch.unannotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch") + _, l67, l68, l69, l70, l71 = sch.get_loops(block=b47) + l72, l73, l74 = sch.split(loop=l71, factors=[None, 64, 2], preserve_unit_iters=True) + sch.vectorize(loop=l74) + sch.bind(loop=l73, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") + _, l75, l76, l77, l78, l79 = sch.get_loops(block=b57) + l80, l81, l82 = sch.split(loop=l79, factors=[None, 64, 2], preserve_unit_iters=True) + sch.vectorize(loop=l82) + sch.bind(loop=l81, thread_axis="threadIdx.x") + b83 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b83, ann_key="meta_schedule.unroll_explicit") + _, b84, b85, b86, b87, _ = sch.get_child_blocks(b83) + _, l88, l89, l90, l91, l92, l93, l94 = sch.get_loops(block=b84) + sch.annotate(block_or_loop=l88, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l88, ann_key="pragma_unroll_explicit", ann_val=1) + _, l95, l96, l97, l98, l99, l100, l101 = sch.get_loops(block=b85) + sch.annotate(block_or_loop=l95, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l95, ann_key="pragma_unroll_explicit", ann_val=1) + _, l102, l103, l104, l105, l106, l107, l108, l109, l110, l111, l112, l113 = sch.get_loops(block=b86) + sch.annotate(block_or_loop=l102, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l102, ann_key="pragma_unroll_explicit", ann_val=1) + _, l114, l115, l116, l117, l118, l119 = sch.get_loops(block=b87) + sch.annotate(block_or_loop=l114, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l114, ann_key="pragma_unroll_explicit", ann_val=1) + b120 = sch.get_block(name="matmul", func_name="main") + _, l121, l122, l123, l124, l125, l126, l127, l128, l129, l130, l131, l132 = sch.get_loops(block=b120) + b133 = sch.decompose_reduction(block=b120, loop=l124) + b1 = sch.get_block("lv39_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func +def fused_matmul3_multiply(p_lv43: T.handle, lv46: T.Buffer((T.int64(4096), T.int64(11008)), "float32"), p_lv48: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv43 = T.match_buffer(p_lv43, (T.int64(1), n, T.int64(4096))) + lv48 = T.match_buffer(p_lv48, (T.int64(1), n, T.int64(11008))) + var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008))) + # with T.block("root"): + var_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008))) + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv43[v_i0, v_i1, v_k], lv46[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv43[v_i0, v_i1, v_k] * lv46[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = 
T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv48[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = lv48[v_ax0, v_ax1, v_ax2] * var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + +def fused_matmul3_multiply_sch_func(): + sch = tvm.tir.Schedule(fused_matmul3_multiply) + b0 = sch.get_block("matmul") + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + b0 = sch.get_block(name="matmul", func_name="main") + b1 = sch.get_block(name="T_multiply", func_name="main") + b2 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l3, l4, l5, l6 = sch.get_loops(block=b0) + v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l12, l13, l14, l15, l16 = sch.split(loop=l3, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) + v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 4, 2, 4, 1]) + l22, l23, l24, l25, l26 = sch.split(loop=l4, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) + v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[344, 2, 16, 1, 1]) + l32, l33, l34, l35, l36 = sch.split(loop=l5, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) + v37, v38, v39 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[512, 1, 8]) + l40, l41, l42 = sch.split(loop=l6, factors=[v37, v38, v39], preserve_unit_iters=True) + sch.reorder(l12, l22, l32, l13, l23, l33, l14, l24, l34, l40, l41, l15, l25, l35, l42, l16, l26, l36) + l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) + sch.bind(loop=l43, thread_axis="blockIdx.x") + l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) + sch.bind(loop=l44, thread_axis="vthread.x") + l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) + sch.bind(loop=l45, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b46 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b46, loop=l45, preserve_unit_loops=True, index=-1) + b47 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b47, loop=l40, preserve_unit_loops=True, index=-1) + _, l48, l49, l50, l51, l52, l53, l54 = sch.get_loops(block=b47) + l55 = sch.fuse(l52, l53, l54, preserve_unit_iters=True) + v56 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch", ann_val=v56) + b57 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b57, loop=l40, preserve_unit_loops=True, index=-1) + _, l58, l59, l60, l61, l62, l63 = sch.get_loops(block=b57) + l64 = sch.fuse(l62, l63, preserve_unit_iters=True) + v65 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v65) + sch.reverse_compute_inline(block=b1) + v66 = 
sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=3) + sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v66) + sch.enter_postproc() + sch.unannotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch") + _, l67, l68, l69, l70, l71 = sch.get_loops(block=b47) + l72, l73, l74 = sch.split(loop=l71, factors=[None, 32, 2], preserve_unit_iters=True) + sch.vectorize(loop=l74) + sch.bind(loop=l73, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") + _, l75, l76, l77, l78, l79 = sch.get_loops(block=b57) + l80, l81, l82 = sch.split(loop=l79, factors=[None, 32, 2], preserve_unit_iters=True) + sch.vectorize(loop=l82) + sch.bind(loop=l81, thread_axis="threadIdx.x") + b83 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b83, ann_key="meta_schedule.unroll_explicit") + _, b84, b85, b86, b87, _ = sch.get_child_blocks(b83) + _, l88, l89, l90, l91, l92, l93, l94 = sch.get_loops(block=b84) + sch.annotate(block_or_loop=l88, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l88, ann_key="pragma_unroll_explicit", ann_val=1) + _, l95, l96, l97, l98, l99, l100, l101 = sch.get_loops(block=b85) + sch.annotate(block_or_loop=l95, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l95, ann_key="pragma_unroll_explicit", ann_val=1) + _, l102, l103, l104, l105, l106, l107, l108, l109, l110, l111, l112, l113 = sch.get_loops(block=b86) + sch.annotate(block_or_loop=l102, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l102, ann_key="pragma_unroll_explicit", ann_val=1) + _, l114, l115, l116, l117, l118, l119 = sch.get_loops(block=b87) + sch.annotate(block_or_loop=l114, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l114, ann_key="pragma_unroll_explicit", ann_val=1) + b120 = sch.get_block(name="matmul", func_name="main") + _, l121, l122, l123, l124, l125, l126, l127, l128, l129, l130, l131, l132 = sch.get_loops(block=b120) + b133 = sch.decompose_reduction(block=b120, loop=l124) + b1 = sch.get_block("lv43_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func +def fused_matmul3_silu(p_lv43: T.handle, lv44: T.Buffer((T.int64(4096), T.int64(11008)), "float32"), p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv43 = T.match_buffer(p_lv43, (T.int64(1), n, T.int64(4096))) + var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008))) + # with T.block("root"): + var_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008))) + compute = T.alloc_buffer((T.int64(1), n, T.int64(11008))) + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv43[v_i0, v_i1, v_k], lv44[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv43[v_i0, v_i1, v_k] * lv44[v_k, v_i2] + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(11008)): + with T.block("compute"): + v_i0, v_i1, 
v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) + T.writes(compute[v_i0, v_i1, v_i2]) + compute[v_i0, v_i1, v_i2] = T.sigmoid(var_matmul_intermediate[v_i0, v_i1, v_i2]) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2] + + +def fused_matmul3_silu_sch_func(): + sch = tvm.tir.Schedule(fused_matmul3_silu) + b0 = sch.get_block("matmul") + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + b0 = sch.get_block(name="matmul", func_name="main") + b1 = sch.get_block(name="compute", func_name="main") + b2 = sch.get_block(name="T_multiply", func_name="main") + b3 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l4, l5, l6, l7 = sch.get_loops(block=b0) + v8, v9, v10, v11, v12 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l13, l14, l15, l16, l17 = sch.split(loop=l4, factors=[v8, v9, v10, v11, v12], preserve_unit_iters=True) + v18, v19, v20, v21, v22 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[1, 2, 2, 8, 1]) + l23, l24, l25, l26, l27 = sch.split(loop=l5, factors=[v18, v19, v20, v21, v22], preserve_unit_iters=True) + v28, v29, v30, v31, v32 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[344, 2, 16, 1, 1]) + l33, l34, l35, l36, l37 = sch.split(loop=l6, factors=[v28, v29, v30, v31, v32], preserve_unit_iters=True) + v38, v39, v40 = sch.sample_perfect_tile(loop=l7, n=3, max_innermost_factor=64, decision=[512, 1, 8]) + l41, l42, l43 = sch.split(loop=l7, factors=[v38, v39, v40], preserve_unit_iters=True) + sch.reorder(l13, l23, l33, l14, l24, l34, l15, l25, l35, l41, l42, l16, l26, l36, l43, l17, l27, l37) + l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) + sch.bind(loop=l44, thread_axis="blockIdx.x") + l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) + sch.bind(loop=l45, thread_axis="vthread.x") + l46 = sch.fuse(l15, l25, l35, preserve_unit_iters=True) + sch.bind(loop=l46, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b47 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b47, loop=l46, preserve_unit_loops=True, index=-1) + b48 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b48, loop=l41, preserve_unit_loops=True, index=-1) + _, l49, l50, l51, l52, l53, l54, l55 = sch.get_loops(block=b48) + l56 = sch.fuse(l53, l54, l55, preserve_unit_iters=True) + v57 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b48, ann_key="meta_schedule.cooperative_fetch", ann_val=v57) + b58 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b58, loop=l41, 
preserve_unit_loops=True, index=-1) + _, l59, l60, l61, l62, l63, l64 = sch.get_loops(block=b58) + l65 = sch.fuse(l63, l64, preserve_unit_iters=True) + v66 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=3) + sch.annotate(block_or_loop=b58, ann_key="meta_schedule.cooperative_fetch", ann_val=v66) + sch.compute_inline(block=b1) + v67 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=4) + sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v67) + l68, l69, l70 = sch.get_loops(block=b2) + l71 = sch.fuse(l68, l69, l70, preserve_unit_iters=True) + l72, l73, l74 = sch.split(loop=l71, factors=[None, 256, 256], preserve_unit_iters=True) + sch.reorder(l73, l74, l72) + sch.bind(loop=l73, thread_axis="blockIdx.x") + sch.bind(loop=l74, thread_axis="threadIdx.x") + sch.enter_postproc() + sch.unannotate(block_or_loop=b48, ann_key="meta_schedule.cooperative_fetch") + _, l75, l76, l77, l78, l79 = sch.get_loops(block=b48) + l80, l81, l82 = sch.split(loop=l79, factors=[None, 32, 2], preserve_unit_iters=True) + sch.vectorize(loop=l82) + sch.bind(loop=l81, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b58, ann_key="meta_schedule.cooperative_fetch") + _, l83, l84, l85, l86, l87 = sch.get_loops(block=b58) + l88, l89, l90 = sch.split(loop=l87, factors=[None, 32, 4], preserve_unit_iters=True) + sch.vectorize(loop=l90) + sch.bind(loop=l89, thread_axis="threadIdx.x") + b91 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b91, ann_key="meta_schedule.unroll_explicit") + _, b92, b93, b94, b95, _, b96 = sch.get_child_blocks(b91) + _, l97, l98, l99, l100, l101, l102, l103 = sch.get_loops(block=b92) + sch.annotate(block_or_loop=l97, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l97, ann_key="pragma_unroll_explicit", ann_val=1) + _, l104, l105, l106, l107, l108, l109, l110 = sch.get_loops(block=b93) + sch.annotate(block_or_loop=l104, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l104, ann_key="pragma_unroll_explicit", ann_val=1) + _, l111, l112, l113, l114, l115, l116, l117, l118, l119, l120, l121, l122 = sch.get_loops(block=b94) + sch.annotate(block_or_loop=l111, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l111, ann_key="pragma_unroll_explicit", ann_val=1) + _, l123, l124, l125, l126, l127, l128 = sch.get_loops(block=b95) + sch.annotate(block_or_loop=l123, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l123, ann_key="pragma_unroll_explicit", ann_val=1) + l129, l130, l131 = sch.get_loops(block=b96) + sch.annotate(block_or_loop=l129, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l129, ann_key="pragma_unroll_explicit", ann_val=1) + b132 = sch.get_block(name="matmul", func_name="main") + _, l133, l134, l135, l136, l137, l138, l139, l140, l141, l142, l143, l144 = sch.get_loops(block=b132) + b145 = sch.decompose_reduction(block=b132, loop=l136) + b1 = sch.get_block("lv43_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func +def fused_matmul4_add1(p_lv49: T.handle, lv50: T.Buffer((T.int64(11008), T.int64(4096)), "float32"), p_lv42: T.handle, p_output0: T.handle): + 
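+    # fused down-projection: (1, n, 11008) x (11008, 4096) matmul followed by a residual add with lv42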
T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv49 = T.match_buffer(p_lv49, (T.int64(1), n, T.int64(11008))) + lv42 = T.match_buffer(p_lv42, (T.int64(1), n, T.int64(4096))) + var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096))) + # with T.block("root"): + var_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096))) + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(11008)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv49[v_i0, v_i1, v_k], lv50[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv49[v_i0, v_i1, v_k] * lv50[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv42[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = lv42[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + +def fused_matmul4_add1_sch_func(): + sch = tvm.tir.Schedule(fused_matmul4_add1) + b0 = sch.get_block("matmul") + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + b0 = sch.get_block(name="matmul", func_name="main") + b1 = sch.get_block(name="T_add", func_name="main") + b2 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l3, l4, l5, l6 = sch.get_loops(block=b0) + v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l12, l13, l14, l15, l16 = sch.split(loop=l3, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) + v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 4, 8, 1]) + l22, l23, l24, l25, l26 = sch.split(loop=l4, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) + v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[128, 2, 16, 1, 1]) + l32, l33, l34, l35, l36 = sch.split(loop=l5, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) + v37, v38, v39 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[1376, 2, 4]) + l40, l41, l42 = sch.split(loop=l6, factors=[v37, v38, v39], preserve_unit_iters=True) + sch.reorder(l12, l22, l32, l13, l23, l33, l14, l24, l34, l40, l41, l15, l25, l35, l42, l16, l26, l36) + l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) + sch.bind(loop=l43, thread_axis="blockIdx.x") + l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) + sch.bind(loop=l44, thread_axis="vthread.x") + l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) + sch.bind(loop=l45, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b46 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b46, loop=l45, preserve_unit_loops=True, index=-1) + b47 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", 
consumer_blocks=[b0]) + sch.compute_at(block=b47, loop=l40, preserve_unit_loops=True, index=-1) + _, l48, l49, l50, l51, l52, l53, l54 = sch.get_loops(block=b47) + l55 = sch.fuse(l52, l53, l54, preserve_unit_iters=True) + v56 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch", ann_val=v56) + b57 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b57, loop=l40, preserve_unit_loops=True, index=-1) + _, l58, l59, l60, l61, l62, l63 = sch.get_loops(block=b57) + l64 = sch.fuse(l62, l63, preserve_unit_iters=True) + v65 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v65) + sch.reverse_compute_inline(block=b1) + v66 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=3) + sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v66) + sch.enter_postproc() + sch.unannotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch") + _, l67, l68, l69, l70, l71 = sch.get_loops(block=b47) + l72, l73, l74 = sch.split(loop=l71, factors=[None, 64, 2], preserve_unit_iters=True) + sch.vectorize(loop=l74) + sch.bind(loop=l73, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") + _, l75, l76, l77, l78, l79 = sch.get_loops(block=b57) + l80, l81, l82 = sch.split(loop=l79, factors=[None, 64, 2], preserve_unit_iters=True) + sch.vectorize(loop=l82) + sch.bind(loop=l81, thread_axis="threadIdx.x") + b83 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b83, ann_key="meta_schedule.unroll_explicit") + _, b84, b85, b86, b87, _ = sch.get_child_blocks(b83) + _, l88, l89, l90, l91, l92, l93, l94 = sch.get_loops(block=b84) + sch.annotate(block_or_loop=l88, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l88, ann_key="pragma_unroll_explicit", ann_val=1) + _, l95, l96, l97, l98, l99, l100, l101 = sch.get_loops(block=b85) + sch.annotate(block_or_loop=l95, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l95, ann_key="pragma_unroll_explicit", ann_val=1) + _, l102, l103, l104, l105, l106, l107, l108, l109, l110, l111, l112, l113 = sch.get_loops(block=b86) + sch.annotate(block_or_loop=l102, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l102, ann_key="pragma_unroll_explicit", ann_val=1) + _, l114, l115, l116, l117, l118, l119 = sch.get_loops(block=b87) + sch.annotate(block_or_loop=l114, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l114, ann_key="pragma_unroll_explicit", ann_val=1) + b120 = sch.get_block(name="matmul", func_name="main") + _, l121, l122, l123, l124, l125, l126, l127, l128, l129, l130, l131, l132 = sch.get_loops(block=b120) + b133 = sch.decompose_reduction(block=b120, loop=l124) + b1 = sch.get_block("lv49_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func +def fused_NT_matmul_add1_before(p_lv39: T.handle, linear_weight3: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), p_lv2: T.handle, 
p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv39 = T.match_buffer(p_lv39, (T.int64(1), n, T.int64(4096))) + lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(4096))) + var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096))) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096))) + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv39[v_i0, v_i1, v_k], linear_weight3[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv39[v_i0, v_i1, v_k] * linear_weight3[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv2[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = lv2[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def fused_NT_matmul_add1_after(p_lv33: T.handle, linear_weight3: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), p_lv2: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + lv33 = T.match_buffer(p_lv33, (T.int64(1), n, T.int64(4096))) + lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(4096))) + var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096))) + # with T.block("root"): + for i1_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.y"): + with T.block("NT_matmul_o"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i1_0) + T.reads(lv33[T.Add(v_i0, T.int64(0)), v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)], linear_weight3[T.int64(0):T.int64(4096), T.int64(0):T.int64(4096)], lv2[v_i0, v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)]) + T.writes(var_T_add_intermediate[v_i0, v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)]) + var_NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="local") + lv33_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="shared") + linear_weight3_shared = T.alloc_buffer((T.int64(4096), T.int64(4096)), scope="shared") + for i0_0_i1_1_0_i2_0_fused in T.thread_binding(T.int64(128), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i0_1_i1_1_1_i2_1_fused in T.thread_binding(T.int64(4), thread="vthread.x"): + for i0_2_i1_1_2_i2_2_fused in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for i1_1_3_init, i2_3_init, i1_1_4_init, i2_4_init in T.grid(T.int64(1), T.int64(4), T.int64(1), T.int64(1)): + with T.block("NT_matmul_init"): + v_i1_i = T.axis.spatial(T.int64(32), i0_1_i1_1_1_i2_1_fused * T.int64(8) + i0_2_i1_1_2_i2_2_fused // T.int64(8) + i1_1_3_init + i1_1_4_init) + v_i2_i = T.axis.spatial(T.int64(4096), i2_4_init + i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + i2_3_init) + T.reads() + 
T.writes(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] = T.float32(0) + for k_0 in range(T.int64(128)): + for ax0_ax1_ax2_fused_0 in range(T.int64(8)): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0_ax1_ax2_fused_2 in T.vectorized(T.int64(2)): + with T.block("lv33_pad_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(128) + ax0_ax1_ax2_fused_1 * T.int64(2) + ax0_ax1_ax2_fused_2) // T.int64(32)) + v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(128) + ax0_ax1_ax2_fused_1 * T.int64(2) + ax0_ax1_ax2_fused_2) % T.int64(32)) + T.reads(lv33[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2]) + T.writes(lv33_pad_shared[v0, v1, v2]) + lv33_pad_shared[v0, v1, v2] = T.if_then_else(v_i1_o * T.int64(32) + v1 < n, lv33[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2], T.float32(0)) + for ax0_ax1_fused_0 in range(T.int64(8)): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(T.int64(2)): + with T.block("linear_weight3_shared"): + v0 = T.axis.spatial(T.int64(4096), i0_0_i1_1_0_i2_0_fused * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) // T.int64(32)) + v1 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) % T.int64(32)) + T.reads(linear_weight3[v0, v1]) + T.writes(linear_weight3_shared[v0, v1]) + linear_weight3_shared[v0, v1] = linear_weight3[v0, v1] + for k_1, i0_3, i1_1_3, i2_3, k_2, i0_4, i1_1_4, i2_4 in T.grid(T.int64(8), T.int64(1), T.int64(1), T.int64(4), T.int64(4), T.int64(1), T.int64(1), T.int64(1)): + with T.block("NT_matmul_update"): + v_i1_i = T.axis.spatial(T.int64(32), i0_1_i1_1_1_i2_1_fused * T.int64(8) + i0_2_i1_1_2_i2_2_fused // T.int64(8) + i1_1_3 + i1_1_4) + v_i2_i = T.axis.spatial(T.int64(4096), i2_4 + i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + i2_3) + v_k_i = T.axis.reduce(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(4) + k_2) + T.reads(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i], lv33_pad_shared[T.int64(0), v_i1_i, v_k_i], linear_weight3_shared[v_i2_i, v_k_i]) + T.writes(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] = var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] + lv33_pad_shared[T.int64(0), v_i1_i, v_k_i] * linear_weight3_shared[v_i2_i, v_k_i] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4)): + with T.block("var_NT_matmul_intermediate_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(32), i0_1_i1_1_1_i2_1_fused * T.int64(8) + i0_2_i1_1_2_i2_2_fused // T.int64(8) + ax1) + v2 = T.axis.spatial(T.int64(4096), i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + ax2) + T.reads(lv2[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2], var_NT_matmul_intermediate_pad_local[v0, v1, v2]) + 
T.writes(var_T_add_intermediate[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2]) + # if T.int64(0) <= v_i0 and v_i0 < T.int64(1) and T.int64(0) <= v_i1_o * T.int64(32) + v1 and v_i1_o * T.int64(32) + v1 < n: + if v_i1_o * T.int64(32) + v1 < n: + var_T_add_intermediate[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2] = lv2[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2] + var_NT_matmul_intermediate_pad_local[v0, v1, v2] + + +@T.prim_func +def fused_NT_matmul1_divide_add_maximum_before(p_lv28: T.handle, p_lv29: T.handle, p_lv5: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv28 = T.match_buffer(p_lv28, (T.int64(1), T.int64(32), n, T.int64(128))) + lv29 = T.match_buffer(p_lv29, (T.int64(1), T.int64(32), n, T.int64(128))) + lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, n)) + var_T_maximum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, n)) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, n)) + var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, n)) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, n)) + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, n, T.int64(128)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(lv28[T.int64(0), v_i1, v_i2, v_k], lv29[T.int64(0), v_i1, v_i3, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv28[T.int64(0), v_i1, v_i2, v_k] * lv29[T.int64(0), v_i1, v_i3, v_k] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, n): + with T.block("T_divide"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.088388349161020605) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, n): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] + lv5[v_ax0, T.int64(0), v_ax2, v_ax3] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, n): + with T.block("T_maximum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float32(-3.4028234663852886e+38)) + + +@T.prim_func +def fused_NT_matmul1_divide_add_maximum_after(p_lv22: T.handle, p_lv23: T.handle, p_lv5: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + lv22 = T.match_buffer(p_lv22, (T.int64(1), T.int64(32), n, T.int64(128))) + lv23 = T.match_buffer(p_lv23, (T.int64(1), T.int64(32), n, T.int64(128))) + lv5 = T.match_buffer(p_lv5, 
(T.int64(1), T.int64(1), n, n)) + var_T_maximum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, n)) + # with T.block("root"): + for i2_0_i3_0_fused in T.thread_binding((n + T.int64(31)) // T.int64(32) * ((n + T.int64(31)) // T.int64(32)), thread="blockIdx.y"): + with T.block("NT_matmul_o"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0_i3_0_fused // ((n + T.int64(31)) // T.int64(32))) + v_i3_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0_i3_0_fused % ((n + T.int64(31)) // T.int64(32))) + T.reads(lv22[T.int64(0), T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(128)], lv23[T.int64(0), T.int64(0):T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(128)], lv5[v_i0, T.int64(0), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32)]) + T.writes(var_T_maximum_intermediate[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32)]) + C_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(32)), scope="local") + A_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(128)), scope="shared") + B_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(128)), scope="shared") + for i0_0_i1_0_i2_1_0_i3_1_0_fused in T.thread_binding(T.int64(128), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 512, "pragma_unroll_explicit": 1}): + for i0_1_i1_1_i2_1_1_i3_1_1_fused in T.thread_binding(T.int64(4), thread="vthread.x"): + for i0_2_i1_2_i2_1_2_i3_1_2_fused in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for i1_3_init, i2_1_3_init, i3_1_3_init, i1_4_init, i2_1_4_init, i3_1_4_init in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1)): + with T.block("NT_matmul_init"): + v_i1_i = T.axis.spatial(T.int64(32), i1_4_init + i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + i1_3_init) + v_i2_i = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + i2_1_3_init + i2_1_4_init) + v_i3_i = T.axis.spatial(T.int64(32), i3_1_4_init + i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + i3_1_3_init) + T.reads() + T.writes(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] = T.float32(0) + for k_0 in range(T.int64(16)): + for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)): + with T.block("A_pad_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4)) + v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(8)) + v3 = T.axis.spatial(T.int64(128), 
k_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(8)) + T.reads(lv22[v0, v1, v_i2_o * T.int64(32) + v2, v3]) + T.writes(A_pad_shared[v0, v1, v2, v3]) + A_pad_shared[v0, v1, v2, v3] = T.if_then_else(v_i2_o * T.int64(32) + v2 < n, lv22[v0, v1, v_i2_o * T.int64(32) + v2, v3], T.float32(0)) + for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)): + with T.block("B_pad_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4)) + v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(8)) + v3 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(8)) + T.reads(lv23[v0, v1, v_i3_o * T.int64(32) + v2, v3]) + T.writes(B_pad_shared[v0, v1, v2, v3]) + B_pad_shared[v0, v1, v2, v3] = T.if_then_else(v_i3_o * T.int64(32) + v2 < n, lv23[v0, v1, v_i3_o * T.int64(32) + v2, v3], T.float32(0)) + for k_1, i0_3, i1_3, i2_1_3, i3_1_3, k_2, i0_4, i1_4, i2_1_4, i3_1_4 in T.grid(T.int64(4), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(1), T.int64(1), T.int64(1), T.int64(1)): + with T.block("NT_matmul_update"): + v_i1_i = T.axis.spatial(T.int64(32), i1_4 + i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + i1_3) + v_i2_i = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + i2_1_3 + i2_1_4) + v_i3_i = T.axis.spatial(T.int64(32), i3_1_4 + i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + i3_1_3) + v_k_i = T.axis.reduce(T.int64(128), k_0 * T.int64(8) + k_1 * T.int64(2) + k_2) + T.reads(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i], A_pad_shared[T.int64(0), v_i1_i, v_i2_i, v_k_i], B_pad_shared[T.int64(0), v_i1_i, v_i3_i, v_k_i]) + T.writes(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] = C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] + A_pad_shared[T.int64(0), v_i1_i, v_i2_i, v_k_i] * B_pad_shared[T.int64(0), v_i1_i, v_i3_i, v_k_i] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)): + with T.block("C_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + ax1) + v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + ax2) + v3 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + ax3) + T.reads(C_pad_local[v0, v1, v2, v3], lv5[v_i0 + v0, T.int64(0), v_i2_o * T.int64(32) + v2, 
v_i3_o * T.int64(32) + v3]) + T.writes(var_T_maximum_intermediate[v_i0 + v0, v1, v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3]) + # if T.int64(0) <= v_i0 and v_i0 < T.int64(1) and T.int64(0) <= v_i2_o * T.int64(32) + v2 and v_i2_o * T.int64(32) + v2 < n and T.int64(0) <= v_i3_o * T.int64(32) + v3 and v_i3_o * T.int64(32) + v3 < n: + if v_i2_o * T.int64(32) + v2 < n and v_i3_o * T.int64(32) + v3 < n: + var_T_maximum_intermediate[v_i0 + v0, v1, v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3] = T.max(C_pad_local[v0, v1, v2, v3] * T.float32(0.088388349161020605) + lv5[v_i0 + v0, T.int64(0), v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3], T.float32(-3.4028234663852886e+38)) + +@T.prim_func +def fused_NT_matmul1_divide_add_maximum_with_m_before(p_lv30: T.handle, p_lv31: T.handle, p_lv7: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv30 = T.match_buffer(p_lv30, (T.int64(1), T.int64(32), n, T.int64(128))) + m = T.int64() + lv31 = T.match_buffer(p_lv31, (T.int64(1), T.int64(32), m, T.int64(128))) + lv7 = T.match_buffer(p_lv7, (T.int64(1), T.int64(1), n, m)) + var_T_maximum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m)) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, m, T.int64(128)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(lv30[v_i0, v_i1, v_i2, v_k], lv31[v_i0, v_i1, v_i3, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv30[v_i0, v_i1, v_i2, v_k] * lv31[v_i0, v_i1, v_i3, v_k] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_divide"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.088388349161020605) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv7[v_ax0, T.int64(0), v_ax2, v_ax3]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] + lv7[v_ax0, T.int64(0), v_ax2, v_ax3] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_maximum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float32(-3.4028234663852886e+38)) + +@T.prim_func +def fused_NT_matmul1_divide_add_maximum_with_m_after(p_lv22: T.handle, p_lv23: T.handle, p_lv5: T.handle, 
p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + m = T.int64() + lv22 = T.match_buffer(p_lv22, (T.int64(1), T.int64(32), n, T.int64(128))) + lv23 = T.match_buffer(p_lv23, (T.int64(1), T.int64(32), m, T.int64(128))) + lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m)) + var_T_maximum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m)) + # with T.block("root"): + for i2_0_i3_0_fused in T.thread_binding((n + T.int64(31)) // T.int64(32) * ((m + T.int64(31)) // T.int64(32)), thread="blockIdx.y"): + with T.block("NT_matmul_o"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0_i3_0_fused // ((m + T.int64(31)) // T.int64(32))) + v_i3_o = T.axis.spatial((m + T.int64(31)) // T.int64(32), i2_0_i3_0_fused % ((m + T.int64(31)) // T.int64(32))) + T.reads(lv22[T.int64(0), T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(128)], lv23[T.int64(0), T.int64(0):T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(128)], lv5[v_i0, T.int64(0), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32)]) + T.writes(var_T_maximum_intermediate[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32)]) + C_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(32)), scope="local") + A_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(128)), scope="shared") + B_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(128)), scope="shared") + for i0_0_i1_0_i2_1_0_i3_1_0_fused in T.thread_binding(T.int64(128), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 512, "pragma_unroll_explicit": 1}): + for i0_1_i1_1_i2_1_1_i3_1_1_fused in T.thread_binding(T.int64(4), thread="vthread.x"): + for i0_2_i1_2_i2_1_2_i3_1_2_fused in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for i1_3_init, i2_1_3_init, i3_1_3_init, i1_4_init, i2_1_4_init, i3_1_4_init in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1)): + with T.block("NT_matmul_init"): + v_i1_i = T.axis.spatial(T.int64(32), i1_4_init + i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + i1_3_init) + v_i2_i = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + i2_1_3_init + i2_1_4_init) + v_i3_i = T.axis.spatial(T.int64(32), i3_1_4_init + i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + i3_1_3_init) + T.reads() + T.writes(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] = T.float32(0) + for k_0 in range(T.int64(16)): + for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)): + with T.block("A_pad_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(32), 
i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4)) + v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(8)) + v3 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(8)) + T.reads(lv22[v0, v1, v_i2_o * T.int64(32) + v2, v3]) + T.writes(A_pad_shared[v0, v1, v2, v3]) + A_pad_shared[v0, v1, v2, v3] = T.if_then_else(v_i2_o * T.int64(32) + v2 < n, lv22[v0, v1, v_i2_o * T.int64(32) + v2, v3], T.float32(0)) + for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)): + with T.block("B_pad_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4)) + v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(8)) + v3 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(8)) + T.reads(lv23[v0, v1, v_i3_o * T.int64(32) + v2, v3]) + T.writes(B_pad_shared[v0, v1, v2, v3]) + B_pad_shared[v0, v1, v2, v3] = T.if_then_else(v_i3_o * T.int64(32) + v2 < m, lv23[v0, v1, v_i3_o * T.int64(32) + v2, v3], T.float32(0)) + for k_1, i0_3, i1_3, i2_1_3, i3_1_3, k_2, i0_4, i1_4, i2_1_4, i3_1_4 in T.grid(T.int64(4), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(1), T.int64(1), T.int64(1), T.int64(1)): + with T.block("NT_matmul_update"): + v_i1_i = T.axis.spatial(T.int64(32), i1_4 + i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + i1_3) + v_i2_i = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + i2_1_3 + i2_1_4) + v_i3_i = T.axis.spatial(T.int64(32), i3_1_4 + i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + i3_1_3) + v_k_i = T.axis.reduce(T.int64(128), k_0 * T.int64(8) + k_1 * T.int64(2) + k_2) + T.reads(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i], A_pad_shared[T.int64(0), v_i1_i, v_i2_i, v_k_i], B_pad_shared[T.int64(0), v_i1_i, v_i3_i, v_k_i]) + T.writes(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] = C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] + A_pad_shared[T.int64(0), v_i1_i, v_i2_i, v_k_i] * B_pad_shared[T.int64(0), v_i1_i, v_i3_i, v_k_i] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)): + with T.block("C_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + ax1) + v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // 
T.int64(8) + ax2) + v3 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + ax3) + T.reads(C_pad_local[v0, v1, v2, v3], lv5[v_i0 + v0, T.int64(0), v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3]) + T.writes(var_T_maximum_intermediate[v_i0 + v0, v1, v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3]) + # if T.int64(0) <= v_i0 and v_i0 < T.int64(1) and T.int64(0) <= v_i2_o * T.int64(32) + v2 and v_i2_o * T.int64(32) + v2 < n and T.int64(0) <= v_i3_o * T.int64(32) + v3 and v_i3_o * T.int64(32) + v3 < n: + if v_i2_o * T.int64(32) + v2 < n and v_i3_o * T.int64(32) + v3 < m: + var_T_maximum_intermediate[v_i0 + v0, v1, v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3] = T.max(C_pad_local[v0, v1, v2, v3] * T.float32(0.088388349161020605) + lv5[v_i0 + v0, T.int64(0), v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3], T.float32(-3.4028234663852886e+38)) + + +@T.prim_func +def fused_NT_matmul6_divide1_add2_maximum1_before(lv2732: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float32"), p_lv2733: T.handle, p_lv2709: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv2733 = T.match_buffer(p_lv2733, (T.int64(1), T.int64(32), n, T.int64(128))) + lv2709 = T.match_buffer(p_lv2709, (T.int64(1), T.int64(1), T.int64(1), n)) + var_T_maximum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n)) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) + var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(128)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(lv2732[T.int64(0), v_i1, v_i2, v_k], lv2733[T.int64(0), v_i1, v_i3, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv2732[T.int64(0), v_i1, v_i2, v_k] * lv2733[T.int64(0), v_i1, v_i3, v_k] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_divide"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.088388349161020605) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv2709[v_ax0, T.int64(0), v_ax2, v_ax3]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] + lv2709[v_ax0, T.int64(0), v_ax2, v_ax3] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_maximum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, 
ax2, ax3]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float32(-3.4028234663852886e+38)) + + +@T.prim_func +def fused_NT_matmul6_divide1_add2_maximum1_after(lv2732: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float32"), p_lv2733: T.handle, p_lv2709: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + lv2733 = T.match_buffer(p_lv2733, (T.int64(1), T.int64(32), n, T.int64(128))) + lv2709 = T.match_buffer(p_lv2709, (T.int64(1), T.int64(1), T.int64(1), n)) + var_T_maximum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n)) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) + var_NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32)), scope="local") + lv2732_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), scope="shared") + lv2733_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(128)), scope="shared") + for i3_0 in range((n + T.int64(31)) // T.int64(32)): + for i0_0_i1_0_i2_0_i3_1_0_fused in T.thread_binding(T.int64(32), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}): + for i0_1_i1_1_i2_1_i3_1_1_fused in T.thread_binding(T.int64(1), thread="vthread.x"): + for i0_2_i1_2_i2_2_i3_1_2_fused in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for i0_3_init, i1_3_init, i2_3_init, i3_1_3_init, i0_4_init, i1_4_init, i2_4_init, i3_1_4_init in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1)): + with T.block("NT_matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), i0_3_init + i0_4_init) + v_i1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_0_i3_1_0_fused // T.int64(4) * T.int64(4) + i0_2_i1_2_i2_2_i3_1_2_fused // T.int64(8) + i1_3_init + i1_4_init) + v_i2 = T.axis.spatial(T.int64(1), i2_3_init + i2_4_init) + v_i3 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), i3_0 * T.int64(32) + i0_0_i1_0_i2_0_i3_1_0_fused % T.int64(4) * T.int64(8) + i0_2_i1_2_i2_2_i3_1_2_fused % T.int64(8) + i3_1_3_init + i3_1_4_init) + T.reads() + T.writes(var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2, v_i3] = T.float32(0) + for k_0 in range(T.int64(8)): + for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)): + with T.block("lv2732_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_0_i3_1_0_fused // T.int64(4) * T.int64(4) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(64) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(16)) + v2 = T.axis.spatial(T.int64(1), T.int64(0)) + v3 = T.axis.spatial(T.int64(128), k_0 * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(64) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % 
T.int64(16)) + T.reads(lv2732[v0, v1, v2, v3]) + T.writes(lv2732_shared[v0, v1, v2, v3]) + lv2732_shared[v0, v1, v2, v3] = lv2732[v0, v1, v2, v3] + for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(4)): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(4)): + with T.block("lv2733_pad_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_0_i3_1_0_fused // T.int64(4) * T.int64(4) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) // T.int64(128)) + v2 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), i3_0 * T.int64(32) + i0_0_i1_0_i2_0_i3_1_0_fused % T.int64(4) * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) % T.int64(128) // T.int64(16)) + v3 = T.axis.spatial(T.int64(128), k_0 * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) % T.int64(16)) + T.reads(lv2733[v0, v1, v2, v3]) + T.writes(lv2733_pad_shared[v0, v1, v2, v3]) + lv2733_pad_shared[v0, v1, v2, v3] = T.if_then_else(v2 < n, lv2733[v0, v1, v2, v3], T.float32(0)) + for k_1, i0_3, i1_3, i2_3, i3_1_3, k_2, i0_4, i1_4, i2_4, i3_1_4 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(16), T.int64(1), T.int64(1), T.int64(1), T.int64(1)): + with T.block("NT_matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), i0_3 + i0_4) + v_i1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_0_i3_1_0_fused // T.int64(4) * T.int64(4) + i0_2_i1_2_i2_2_i3_1_2_fused // T.int64(8) + i1_3 + i1_4) + v_i2 = T.axis.spatial(T.int64(1), i2_3 + i2_4) + v_i3 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), i3_0 * T.int64(32) + i0_0_i1_0_i2_0_i3_1_0_fused % T.int64(4) * T.int64(8) + i0_2_i1_2_i2_2_i3_1_2_fused % T.int64(8) + i3_1_3 + i3_1_4) + v_k = T.axis.reduce(T.int64(128), k_0 * T.int64(16) + k_1 * T.int64(16) + k_2) + T.reads(var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2, v_i3], lv2732_shared[v_i0, v_i1, v_i2, v_k], lv2733_pad_shared[v_i0, v_i1, v_i3, v_k]) + T.writes(var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2, v_i3] + lv2732_shared[v_i0, v_i1, v_i2, v_k] * lv2733_pad_shared[v_i0, v_i1, v_i3, v_k] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_NT_matmul_intermediate_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_0_i3_1_0_fused // T.int64(4) * T.int64(4) + i0_2_i1_2_i2_2_i3_1_2_fused // T.int64(8) + ax1) + v2 = T.axis.spatial(T.int64(1), ax2) + v3 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), i3_0 * T.int64(32) + i0_0_i1_0_i2_0_i3_1_0_fused % T.int64(4) * T.int64(8) + i0_2_i1_2_i2_2_i3_1_2_fused % T.int64(8) + ax3) + T.reads(var_NT_matmul_intermediate_pad_local[v0, v1, v2, v3]) + T.writes(var_NT_matmul_intermediate[v0, v1, v2, v3]) + if v3 < n: + var_NT_matmul_intermediate[v0, v1, v2, v3] = var_NT_matmul_intermediate_pad_local[v0, v1, v2, v3] + for ax0_ax1_ax2_ax3_fused_0 in T.thread_binding(n, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 
64, "pragma_unroll_explicit": 1}): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + with T.block("T_add"): + v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_ax1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_ax3_fused_0 * T.int64(32) + ax0_ax1_ax2_ax3_fused_1) // n) + v_ax2 = T.axis.spatial(T.int64(1), T.int64(0)) + v_ax3 = T.axis.spatial(n, (ax0_ax1_ax2_ax3_fused_0 * T.int64(32) + ax0_ax1_ax2_ax3_fused_1) % n) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv2709[v_ax0, T.int64(0), v_ax2, v_ax3]) + T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.088388349161020605) + lv2709[v_ax0, T.int64(0), v_ax2, v_ax3], T.float32(-3.4028234663852886e+38)) + + +@T.prim_func +def fused_NT_matmul2_multiply_before(p_lv43: T.handle, linear_weight6: T.Buffer((T.int64(11008), T.int64(4096)), "float32"), p_lv48: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv43 = T.match_buffer(p_lv43, (T.int64(1), n, T.int64(4096))) + lv48 = T.match_buffer(p_lv48, (T.int64(1), n, T.int64(11008))) + var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008))) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008))) + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv43[v_i0, v_i1, v_k], linear_weight6[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv43[v_i0, v_i1, v_k] * linear_weight6[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv48[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = lv48[v_ax0, v_ax1, v_ax2] * var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def fused_NT_matmul2_multiply_after(p_lv37: T.handle, linear_weight6: T.Buffer((T.int64(11008), T.int64(4096)), "float32"), p_lv42: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + lv37 = T.match_buffer(p_lv37, (T.int64(1), n, T.int64(4096))) + lv42 = T.match_buffer(p_lv42, (T.int64(1), n, T.int64(11008))) + var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008))) + # with T.block("root"): + for i1_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.y"): + with T.block("NT_matmul_o"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i1_0) + T.reads(lv37[T.Add(v_i0, T.int64(0)), v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)], linear_weight6[T.int64(0):T.int64(11008), T.int64(0):T.int64(4096)], lv42[v_i0, v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(11008)]) + T.writes(var_T_multiply_intermediate[v_i0, v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(11008)]) + 
var_NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(11008)), scope="local") + lv37_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="shared") + linear_weight6_shared = T.alloc_buffer((T.int64(11008), T.int64(4096)), scope="shared") + for i0_0_i1_1_0_i2_0_fused in T.thread_binding(T.int64(344), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i0_1_i1_1_1_i2_1_fused in T.thread_binding(T.int64(1), thread="vthread.x"): + for i0_2_i1_1_2_i2_2_fused in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for i1_1_3_init, i2_3_init, i1_1_4_init, i2_4_init in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(2)): + with T.block("NT_matmul_init"): + v_i1_i = T.axis.spatial(T.int64(32), i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(4) + i1_1_3_init * T.int64(2) + i1_1_4_init) + v_i2_i = T.axis.spatial(T.int64(11008), i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + i2_3_init * T.int64(2) + i2_4_init) + T.reads() + T.writes(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] = T.float32(0) + for k_0 in range(T.int64(128)): + for ax0_ax1_ax2_fused_0 in range(T.int64(4)): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0_ax1_ax2_fused_2 in T.vectorized(T.int64(4)): + with T.block("lv37_pad_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2) // T.int64(32)) + v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2) % T.int64(32)) + T.reads(lv37[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2]) + T.writes(lv37_pad_shared[v0, v1, v2]) + lv37_pad_shared[v0, v1, v2] = T.if_then_else(v_i1_o * T.int64(32) + v1 < n, lv37[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2], T.float32(0)) + for ax0_ax1_fused_0 in range(T.int64(8)): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(T.int64(2)): + with T.block("linear_weight6_shared"): + v0 = T.axis.spatial(T.int64(11008), i0_0_i1_1_0_i2_0_fused * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) // T.int64(32)) + v1 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) % T.int64(32)) + T.reads(linear_weight6[v0, v1]) + T.writes(linear_weight6_shared[v0, v1]) + linear_weight6_shared[v0, v1] = linear_weight6[v0, v1] + for k_1, i0_3, i1_1_3, i2_3, k_2, i0_4, i1_1_4, i2_4 in T.grid(T.int64(8), T.int64(1), T.int64(2), T.int64(2), T.int64(4), T.int64(1), T.int64(2), T.int64(2)): + with T.block("NT_matmul_update"): + v_i1_i = T.axis.spatial(T.int64(32), i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(4) + i1_1_3 * T.int64(2) + i1_1_4) + v_i2_i = T.axis.spatial(T.int64(11008), i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + i2_3 * T.int64(2) + i2_4) + v_k_i = T.axis.reduce(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(4) + k_2) + T.reads(var_NT_matmul_intermediate_pad_local[T.int64(0), 
v_i1_i, v_i2_i], lv37_pad_shared[T.int64(0), v_i1_i, v_k_i], linear_weight6_shared[v_i2_i, v_k_i]) + T.writes(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] = var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] + lv37_pad_shared[T.int64(0), v_i1_i, v_k_i] * linear_weight6_shared[v_i2_i, v_k_i] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(4), T.int64(4)): + with T.block("var_NT_matmul_intermediate_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(32), i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(4) + ax1) + v2 = T.axis.spatial(T.int64(11008), i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + ax2) + T.reads(lv42[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2], var_NT_matmul_intermediate_pad_local[v0, v1, v2]) + T.writes(var_T_multiply_intermediate[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2]) + # if T.int64(0) <= v_i0 and v_i0 < T.int64(1) and T.int64(0) <= v_i1_o * T.int64(32) + v1 and v_i1_o * T.int64(32) + v1 < n: + if v_i1_o * T.int64(32) + v1 < n: + var_T_multiply_intermediate[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2] = lv42[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2] * var_NT_matmul_intermediate_pad_local[v0, v1, v2] + + +@T.prim_func +def fused_NT_matmul2_silu_before(p_lv43: T.handle, linear_weight4: T.Buffer((T.int64(11008), T.int64(4096)), "float32"), p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv43 = T.match_buffer(p_lv43, (T.int64(1), n, T.int64(4096))) + var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008))) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008))) + compute = T.alloc_buffer((T.int64(1), n, T.int64(11008))) + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv43[v_i0, v_i1, v_k], linear_weight4[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv43[v_i0, v_i1, v_k] * linear_weight4[v_i2, v_k] + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(11008)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + T.writes(compute[v_i0, v_i1, v_i2]) + compute[v_i0, v_i1, v_i2] = T.sigmoid(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def fused_NT_matmul2_silu_after(p_lv37: T.handle, linear_weight4: T.Buffer((T.int64(11008), T.int64(4096)), "float32"), p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + lv37 = 
T.match_buffer(p_lv37, (T.int64(1), n, T.int64(4096))) + var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008))) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008))) + for i1_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.y"): + with T.block("NT_matmul_o"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i1_0) + T.reads(lv37[T.Add(v_i0, T.int64(0)), v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)], linear_weight4[T.int64(0):T.int64(11008), T.int64(0):T.int64(4096)]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(11008)]) + var_NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(11008)), scope="local") + lv37_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="shared") + linear_weight4_shared = T.alloc_buffer((T.int64(11008), T.int64(4096)), scope="shared") + for i0_0_i1_1_0_i2_0_fused in T.thread_binding(T.int64(344), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i0_1_i1_1_1_i2_1_fused in T.thread_binding(T.int64(1), thread="vthread.x"): + for i0_2_i1_1_2_i2_2_fused in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for i1_1_3_init, i2_3_init, i1_1_4_init, i2_4_init in T.grid(T.int64(2), T.int64(4), T.int64(2), T.int64(1)): + with T.block("NT_matmul_init"): + v_i1_i = T.axis.spatial(T.int64(32), i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(4) + i1_1_3_init * T.int64(2) + i1_1_4_init) + v_i2_i = T.axis.spatial(T.int64(11008), i2_4_init + i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + i2_3_init) + T.reads() + T.writes(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] = T.float32(0) + for k_0 in range(T.int64(128)): + for ax0_ax1_ax2_fused_0 in range(T.int64(4)): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0_ax1_ax2_fused_2 in T.vectorized(T.int64(4)): + with T.block("lv37_pad_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2) // T.int64(32)) + v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2) % T.int64(32)) + T.reads(lv37[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2]) + T.writes(lv37_pad_shared[v0, v1, v2]) + lv37_pad_shared[v0, v1, v2] = T.if_then_else(v_i1_o * T.int64(32) + v1 < n, lv37[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2], T.float32(0)) + for ax0_ax1_fused_0 in range(T.int64(8)): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(T.int64(2)): + with T.block("linear_weight4_shared"): + v0 = T.axis.spatial(T.int64(11008), i0_0_i1_1_0_i2_0_fused * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) // T.int64(32)) + v1 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * 
T.int64(2) + ax0_ax1_fused_2) % T.int64(32)) + T.reads(linear_weight4[v0, v1]) + T.writes(linear_weight4_shared[v0, v1]) + linear_weight4_shared[v0, v1] = linear_weight4[v0, v1] + for k_1, i0_3, i1_1_3, i2_3, k_2, i0_4, i1_1_4, i2_4 in T.grid(T.int64(8), T.int64(1), T.int64(2), T.int64(4), T.int64(4), T.int64(1), T.int64(2), T.int64(1)): + with T.block("NT_matmul_update"): + v_i1_i = T.axis.spatial(T.int64(32), i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(4) + i1_1_3 * T.int64(2) + i1_1_4) + v_i2_i = T.axis.spatial(T.int64(11008), i2_4 + i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + i2_3) + v_k_i = T.axis.reduce(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(4) + k_2) + T.reads(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i], lv37_pad_shared[T.int64(0), v_i1_i, v_k_i], linear_weight4_shared[v_i2_i, v_k_i]) + T.writes(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] = var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] + lv37_pad_shared[T.int64(0), v_i1_i, v_k_i] * linear_weight4_shared[v_i2_i, v_k_i] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(4), T.int64(4)): + with T.block("var_NT_matmul_intermediate_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(32), i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(4) + ax1) + v2 = T.axis.spatial(T.int64(11008), i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + ax2) + T.reads(var_NT_matmul_intermediate_pad_local[v0, v1, v2]) + T.writes(var_NT_matmul_intermediate[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2]) + # if T.int64(0) <= v_i0 and v_i0 < T.int64(1) and T.int64(0) <= v_i1_o * T.int64(32) + v1 and v_i1_o * T.int64(32) + v1 < n: + if v_i1_o * T.int64(32) + v1 < n: + var_NT_matmul_intermediate[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2] = var_NT_matmul_intermediate_pad_local[v0, v1, v2] + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for ax0_ax1_ax2_fused_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax0_ax1_ax2_fused_0 in range((n * T.int64(11008) + T.int64(65535)) // T.int64(65536)): + with T.block("T_multiply"): + v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_ax1 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * T.int64(65536) + ax0_ax1_ax2_fused_1 * T.int64(256) + ax0_ax1_ax2_fused_2) // T.int64(11008)) + v_ax2 = T.axis.spatial(T.int64(11008), (ax0_ax1_ax2_fused_0 * T.int64(65536) + ax0_ax1_ax2_fused_1 * T.int64(256) + ax0_ax1_ax2_fused_2) % T.int64(11008)) + T.where((ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1) * T.int64(256) + ax0_ax1_ax2_fused_2 < n * T.int64(11008)) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] * T.sigmoid(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + + +@T.prim_func +def fused_NT_matmul3_add1_before(p_lv49: T.handle, linear_weight5: T.Buffer((T.int64(4096), T.int64(11008)), "float32"), p_lv42: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv49 = T.match_buffer(p_lv49, 
(T.int64(1), n, T.int64(11008))) + lv42 = T.match_buffer(p_lv42, (T.int64(1), n, T.int64(4096))) + var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096))) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096))) + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(11008)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv49[v_i0, v_i1, v_k], linear_weight5[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv49[v_i0, v_i1, v_k] * linear_weight5[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv42[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = lv42[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def fused_NT_matmul3_add1_after(p_lv43: T.handle, linear_weight5: T.Buffer((T.int64(4096), T.int64(11008)), "float32"), p_lv36: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + lv43 = T.match_buffer(p_lv43, (T.int64(1), n, T.int64(11008))) + lv36 = T.match_buffer(p_lv36, (T.int64(1), n, T.int64(4096))) + var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096))) + # with T.block("root"): + for i1_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.y"): + with T.block("NT_matmul_o"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i1_0) + T.reads(lv43[T.Add(v_i0, T.int64(0)), v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(11008)], linear_weight5[T.int64(0):T.int64(4096), T.int64(0):T.int64(11008)], lv36[v_i0, v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)]) + T.writes(var_T_add_intermediate[v_i0, v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)]) + var_NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="local") + lv43_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(11008)), scope="shared") + linear_weight5_shared = T.alloc_buffer((T.int64(4096), T.int64(11008)), scope="shared") + for i0_0_i1_1_0_i2_0_fused in T.thread_binding(T.int64(128), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i0_1_i1_1_1_i2_1_fused in T.thread_binding(T.int64(4), thread="vthread.x"): + for i0_2_i1_1_2_i2_2_fused in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for i1_1_3_init, i2_3_init, i1_1_4_init, i2_4_init in T.grid(T.int64(2), T.int64(2), T.int64(1), T.int64(1)): + with T.block("NT_matmul_init"): + v_i1_i = T.axis.spatial(T.int64(32), i0_1_i1_1_1_i2_1_fused // T.int64(2) * T.int64(16) + i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(2) + i1_1_3_init + i1_1_4_init) + v_i2_i = T.axis.spatial(T.int64(4096), i2_4_init + i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_1_i1_1_1_i2_1_fused % T.int64(2) * T.int64(16) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(2) + i2_3_init) + T.reads() + 
T.writes(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] = T.float32(0) + for k_0 in range(T.int64(344)): + for ax0_ax1_ax2_fused_0 in range(T.int64(4)): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0_ax1_ax2_fused_2 in T.vectorized(T.int64(4)): + with T.block("lv43_pad_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2) // T.int64(32)) + v2 = T.axis.spatial(T.int64(11008), k_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2) % T.int64(32)) + T.reads(lv43[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2]) + T.writes(lv43_pad_shared[v0, v1, v2]) + lv43_pad_shared[v0, v1, v2] = T.if_then_else(v_i1_o * T.int64(32) + v1 < n, lv43[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2], T.float32(0)) + for ax0_ax1_fused_0 in range(T.int64(8)): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(T.int64(2)): + with T.block("linear_weight5_shared"): + v0 = T.axis.spatial(T.int64(4096), i0_0_i1_1_0_i2_0_fused * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) // T.int64(32)) + v1 = T.axis.spatial(T.int64(11008), k_0 * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) % T.int64(32)) + T.reads(linear_weight5[v0, v1]) + T.writes(linear_weight5_shared[v0, v1]) + linear_weight5_shared[v0, v1] = linear_weight5[v0, v1] + for k_1, i0_3, i1_1_3, i2_3, k_2, i0_4, i1_1_4, i2_4 in T.grid(T.int64(8), T.int64(1), T.int64(2), T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(1)): + with T.block("NT_matmul_update"): + v_i1_i = T.axis.spatial(T.int64(32), i0_1_i1_1_1_i2_1_fused // T.int64(2) * T.int64(16) + i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(2) + i1_1_3 + i1_1_4) + v_i2_i = T.axis.spatial(T.int64(4096), i2_4 + i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_1_i1_1_1_i2_1_fused % T.int64(2) * T.int64(16) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(2) + i2_3) + v_k_i = T.axis.reduce(T.int64(11008), k_0 * T.int64(32) + k_1 * T.int64(4) + k_2) + T.reads(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i], lv43_pad_shared[T.int64(0), v_i1_i, v_k_i], linear_weight5_shared[v_i2_i, v_k_i]) + T.writes(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] = var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] + lv43_pad_shared[T.int64(0), v_i1_i, v_k_i] * linear_weight5_shared[v_i2_i, v_k_i] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(2), T.int64(2)): + with T.block("var_NT_matmul_intermediate_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(32), i0_1_i1_1_1_i2_1_fused // T.int64(2) * T.int64(16) + i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(2) + ax1) + v2 = T.axis.spatial(T.int64(4096), i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_1_i1_1_1_i2_1_fused % T.int64(2) * T.int64(16) + 
i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(2) + ax2) + T.reads(lv36[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2], var_NT_matmul_intermediate_pad_local[v0, v1, v2]) + T.writes(var_T_add_intermediate[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2]) + # if T.int64(0) <= v_i0 and v_i0 < T.int64(1) and T.int64(0) <= v_i1_o * T.int64(32) + v1 and v_i1_o * T.int64(32) + v1 < n: + if v_i1_o * T.int64(32) + v1 < n: + var_T_add_intermediate[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2] = lv36[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2] + var_NT_matmul_intermediate_pad_local[v0, v1, v2] + + + +@T.prim_func +def fused_NT_matmul_divide_maximum_minimum_cast_before(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16"), p_lv1606: T.handle, p_lv1582: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv1606 = T.match_buffer(p_lv1606, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") + lv1582 = T.match_buffer(p_lv1582, (T.int64(1), T.int64(1), T.int64(1), n), "float16") + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n)) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") + var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") + var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") + var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(128)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(lv1605[v_i0, v_i1, v_i2, v_k], lv1606[v_i0, v_i1, v_i3, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1605[v_i0, v_i1, v_i2, v_k] * lv1606[v_i0, v_i1, v_i3, v_k] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_divide"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_maximum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_minimum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1582[v_ax0, T.int64(0), v_ax2, v_ax3]) + T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1582[v_ax0, T.int64(0), v_ax2, v_ax3]) 
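+ # The "compute" block below casts the clamped float16 attention scores back to float32 for the output buffer.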
+ for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) + var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) + +def fused_NT_matmul_divide_maximum_minimum_cast_sch_func(): + sch = tvm.tir.Schedule(fused_NT_matmul_divide_maximum_minimum_cast_before) + b_cast = sch.get_block("compute") + sch.reverse_compute_inline(b_cast) + b0 = sch.get_block("NT_matmul") + sch.pad_einsum(b0, [1, 1, 1, 32, 1]) + l1, l2, l3, l4, l5 = sch.get_loops(b0) + l6, l7 = sch.split(l4, [None, 32]) + sch.reorder(l6, l1, l2, l3, l7, l5) + + b0 = sch.get_block(name="NT_matmul", func_name="main") + b1 = sch.get_block(name="T_divide", func_name="main") + b2 = sch.get_block(name="T_maximum", func_name="main") + b3 = sch.get_block(name="T_minimum", func_name="main") + b4 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l5, l6, l7, l8, l9 = sch.get_loops(block=b0) + v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l15, l16, l17, l18, l19 = sch.split(loop=l5, factors=[v10, v11, v12, v13, v14], preserve_unit_iters=True) + v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[8, 1, 4, 1, 1]) + l25, l26, l27, l28, l29 = sch.split(loop=l6, factors=[v20, v21, v22, v23, v24], preserve_unit_iters=True) + v30, v31, v32, v33, v34 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l35, l36, l37, l38, l39 = sch.split(loop=l7, factors=[v30, v31, v32, v33, v34], preserve_unit_iters=True) + v40, v41, v42, v43, v44 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[2, 1, 16, 1, 1]) + l45, l46, l47, l48, l49 = sch.split(loop=l8, factors=[v40, v41, v42, v43, v44], preserve_unit_iters=True) + v50, v51, v52 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64, decision=[4, 4, 8]) + l53, l54, l55 = sch.split(loop=l9, factors=[v50, v51, v52], preserve_unit_iters=True) + sch.reorder(l15, l25, l35, l45, l16, l26, l36, l46, l17, l27, l37, l47, l53, l54, l18, l28, l38, l48, l55, l19, l29, l39, l49) + l56 = sch.fuse(l15, l25, l35, l45, preserve_unit_iters=True) + sch.bind(loop=l56, thread_axis="blockIdx.x") + l57 = sch.fuse(l16, l26, l36, l46, preserve_unit_iters=True) + sch.bind(loop=l57, thread_axis="vthread.x") + l58 = sch.fuse(l17, l27, l37, l47, preserve_unit_iters=True) + sch.bind(loop=l58, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b59 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b59, loop=l58, preserve_unit_loops=True, index=-1) + b60 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b60, loop=l53, preserve_unit_loops=True, index=-1) + _, l61, l62, l63, l64, l65, l66, l67, l68 = sch.get_loops(block=b60) + l69 = sch.fuse(l65, l66, l67, l68, preserve_unit_iters=True) + v70 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) 
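+ # v70 records the sampled cooperative-fetch factor for this shared-memory cache read; the annotation is resolved in enter_postproc below, where the fused fetch loop is split and bound to threadIdx.x.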
+ sch.annotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch", ann_val=v70) + b71 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b71, loop=l53, preserve_unit_loops=True, index=-1) + _, l72, l73, l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b71) + l80 = sch.fuse(l76, l77, l78, l79, preserve_unit_iters=True) + v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) + sch.annotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch", ann_val=v81) + sch.reverse_compute_inline(block=b3) + sch.compute_inline(block=b1) + v82 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=1) + sch.annotate(block_or_loop=b4, ann_key="meta_schedule.unroll_explicit", ann_val=v82) + + # inline ewise + sch.reverse_compute_inline(b2) + # l83, l84, l85, l86 = sch.get_loops(block=b2) + # l87 = sch.fuse(l83, l84, l85, l86, preserve_unit_iters=True) + # v88 = sch.sample_categorical(candidates=[32, 64, 128, 256], probs=[0.25, 0.25, 0.25, 0.25], decision=0) + # l89, l90 = sch.split(loop=l87, factors=[None, v88], preserve_unit_iters=True) + # sch.bind(loop=l89, thread_axis="blockIdx.x") + # sch.bind(loop=l90, thread_axis="threadIdx.x") + + sch.enter_postproc() + sch.unannotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch") + _, l91, l92, l93, l94, l95 = sch.get_loops(block=b60) + l96, l97 = sch.split(loop=l95, factors=[None, 64], preserve_unit_iters=True) + sch.bind(loop=l97, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch") + _, l98, l99, l100, l101, l102 = sch.get_loops(block=b71) + l103, l104 = sch.split(loop=l102, factors=[None, 64], preserve_unit_iters=True) + sch.bind(loop=l104, thread_axis="threadIdx.x") + b105 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b105, ann_key="meta_schedule.unroll_explicit") + _, b106, b107, b108, b109, _ = sch.get_child_blocks(b105) + _, l111, l112, l113, l114, l115, l116 = sch.get_loops(block=b106) + sch.annotate(block_or_loop=l111, ann_key="pragma_auto_unroll_max_step", ann_val=16) + sch.annotate(block_or_loop=l111, ann_key="pragma_unroll_explicit", ann_val=1) + _, l117, l118, l119, l120, l121, l122 = sch.get_loops(block=b107) + sch.annotate(block_or_loop=l117, ann_key="pragma_auto_unroll_max_step", ann_val=16) + sch.annotate(block_or_loop=l117, ann_key="pragma_unroll_explicit", ann_val=1) + _, l123, l124, l125, l126, l127, l128, l129, l130, l131, l132, l133, l134, l135, l136 = sch.get_loops(block=b108) + sch.annotate(block_or_loop=l123, ann_key="pragma_auto_unroll_max_step", ann_val=16) + sch.annotate(block_or_loop=l123, ann_key="pragma_unroll_explicit", ann_val=1) + _, l137, l138, l139, l140, l141, l142, l143 = sch.get_loops(block=b109) + sch.annotate(block_or_loop=l137, ann_key="pragma_auto_unroll_max_step", ann_val=16) + sch.annotate(block_or_loop=l137, ann_key="pragma_unroll_explicit", ann_val=1) + + b146 = sch.get_block(name="NT_matmul", func_name="main") + l0, l147, l148, l149, l150, l151, l152, l153, l154, l155, l156, l157, l158, l159, l160 = sch.get_loops(block=b146) + sch.bind(l0, "blockIdx.y") + b161 = sch.decompose_reduction(block=b146, loop=l150) + + b1 = sch.get_block("lv1606_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + + return 
sch.mod["main"].with_attr("tir.is_scheduled", 1) + +@T.prim_func +def fused_NT_matmul_divide_maximum_minimum_before(lv1540: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float32"), p_lv1541: T.handle, p_lv1517: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv1541 = T.match_buffer(p_lv1541, (T.int64(1), T.int64(32), n, T.int64(128))) + lv1517 = T.match_buffer(p_lv1517, (T.int64(1), T.int64(1), T.int64(1), n)) + var_T_minimum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n)) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) + var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) + var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(128)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(lv1540[v_i0, v_i1, v_i2, v_k], lv1541[v_i0, v_i1, v_i3, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1540[v_i0, v_i1, v_i2, v_k] * lv1541[v_i0, v_i1, v_i3, v_k] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_divide"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.088388349161020605) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_maximum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float32(-3.4028234663852886e+38)) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): + with T.block("T_minimum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1517[v_ax0, T.int64(0), v_ax2, v_ax3]) + T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1517[v_ax0, T.int64(0), v_ax2, v_ax3]) + +def fused_NT_matmul_divide_maximum_minimum_sch_func(): + sch = tvm.tir.Schedule(fused_NT_matmul_divide_maximum_minimum_before) + b0 = sch.get_block("NT_matmul") + sch.pad_einsum(b0, [1, 1, 1, 32, 1]) + l1, l2, l3, l4, l5 = sch.get_loops(b0) + l6, l7 = sch.split(l4, [None, 32]) + sch.reorder(l6, l1, l2, l3, l7, l5) + + b0 = sch.get_block(name="NT_matmul", func_name="main") + b1 = sch.get_block(name="T_divide", func_name="main") + b2 = sch.get_block(name="T_maximum", func_name="main") + b3 = sch.get_block(name="T_minimum", func_name="main") + b4 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, 
ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l5, l6, l7, l8, l9 = sch.get_loops(block=b0) + v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l15, l16, l17, l18, l19 = sch.split(loop=l5, factors=[v10, v11, v12, v13, v14], preserve_unit_iters=True) + v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[8, 1, 4, 1, 1]) + l25, l26, l27, l28, l29 = sch.split(loop=l6, factors=[v20, v21, v22, v23, v24], preserve_unit_iters=True) + v30, v31, v32, v33, v34 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l35, l36, l37, l38, l39 = sch.split(loop=l7, factors=[v30, v31, v32, v33, v34], preserve_unit_iters=True) + v40, v41, v42, v43, v44 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[2, 1, 16, 1, 1]) + l45, l46, l47, l48, l49 = sch.split(loop=l8, factors=[v40, v41, v42, v43, v44], preserve_unit_iters=True) + v50, v51, v52 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64, decision=[4, 4, 8]) + l53, l54, l55 = sch.split(loop=l9, factors=[v50, v51, v52], preserve_unit_iters=True) + sch.reorder(l15, l25, l35, l45, l16, l26, l36, l46, l17, l27, l37, l47, l53, l54, l18, l28, l38, l48, l55, l19, l29, l39, l49) + l56 = sch.fuse(l15, l25, l35, l45, preserve_unit_iters=True) + sch.bind(loop=l56, thread_axis="blockIdx.x") + l57 = sch.fuse(l16, l26, l36, l46, preserve_unit_iters=True) + sch.bind(loop=l57, thread_axis="vthread.x") + l58 = sch.fuse(l17, l27, l37, l47, preserve_unit_iters=True) + sch.bind(loop=l58, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b59 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b59, loop=l58, preserve_unit_loops=True, index=-1) + b60 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b60, loop=l53, preserve_unit_loops=True, index=-1) + _, l61, l62, l63, l64, l65, l66, l67, l68 = sch.get_loops(block=b60) + l69 = sch.fuse(l65, l66, l67, l68, preserve_unit_iters=True) + v70 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) + sch.annotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch", ann_val=v70) + b71 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b71, loop=l53, preserve_unit_loops=True, index=-1) + _, l72, l73, l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b71) + l80 = sch.fuse(l76, l77, l78, l79, preserve_unit_iters=True) + v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) + sch.annotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch", ann_val=v81) + sch.reverse_compute_inline(block=b3) + sch.compute_inline(block=b1) + v82 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=1) + sch.annotate(block_or_loop=b4, ann_key="meta_schedule.unroll_explicit", ann_val=v82) + + # inline ewise + sch.reverse_compute_inline(b2) + # l83, l84, l85, l86 = sch.get_loops(block=b2) + # l87 = sch.fuse(l83, l84, l85, l86, preserve_unit_iters=True) + # v88 = 
sch.sample_categorical(candidates=[32, 64, 128, 256], probs=[0.25, 0.25, 0.25, 0.25], decision=0) + # l89, l90 = sch.split(loop=l87, factors=[None, v88], preserve_unit_iters=True) + # sch.bind(loop=l89, thread_axis="blockIdx.x") + # sch.bind(loop=l90, thread_axis="threadIdx.x") + + sch.enter_postproc() + sch.unannotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch") + _, l91, l92, l93, l94, l95 = sch.get_loops(block=b60) + l96, l97 = sch.split(loop=l95, factors=[None, 64], preserve_unit_iters=True) + sch.bind(loop=l97, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch") + _, l98, l99, l100, l101, l102 = sch.get_loops(block=b71) + l103, l104 = sch.split(loop=l102, factors=[None, 64], preserve_unit_iters=True) + sch.bind(loop=l104, thread_axis="threadIdx.x") + b105 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b105, ann_key="meta_schedule.unroll_explicit") + _, b106, b107, b108, b109, _ = sch.get_child_blocks(b105) + _, l111, l112, l113, l114, l115, l116 = sch.get_loops(block=b106) + sch.annotate(block_or_loop=l111, ann_key="pragma_auto_unroll_max_step", ann_val=16) + sch.annotate(block_or_loop=l111, ann_key="pragma_unroll_explicit", ann_val=1) + _, l117, l118, l119, l120, l121, l122 = sch.get_loops(block=b107) + sch.annotate(block_or_loop=l117, ann_key="pragma_auto_unroll_max_step", ann_val=16) + sch.annotate(block_or_loop=l117, ann_key="pragma_unroll_explicit", ann_val=1) + _, l123, l124, l125, l126, l127, l128, l129, l130, l131, l132, l133, l134, l135, l136 = sch.get_loops(block=b108) + sch.annotate(block_or_loop=l123, ann_key="pragma_auto_unroll_max_step", ann_val=16) + sch.annotate(block_or_loop=l123, ann_key="pragma_unroll_explicit", ann_val=1) + _, l137, l138, l139, l140, l141, l142, l143 = sch.get_loops(block=b109) + sch.annotate(block_or_loop=l137, ann_key="pragma_auto_unroll_max_step", ann_val=16) + sch.annotate(block_or_loop=l137, ann_key="pragma_unroll_explicit", ann_val=1) + + b146 = sch.get_block(name="NT_matmul", func_name="main") + l0, l147, l148, l149, l150, l151, l152, l153, l154, l155, l156, l157, l158, l159, l160 = sch.get_loops(block=b146) + sch.bind(l0, "blockIdx.y") + b161 = sch.decompose_reduction(block=b146, loop=l150) + + b1 = sch.get_block("lv1541_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + +@T.prim_func +def fused_NT_matmul1_add3_before(p_lv39: T.handle, lv1848: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), p_lv2: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv39 = T.match_buffer(p_lv39, (T.int64(1), n, T.int64(4096)), "float16") + lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(4096)), "float16") + var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16") + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)), "float16") + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv39[v_i0, v_i1, v_k], lv1848[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv39[v_i0, v_i1, v_k] * 
lv1848[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv2[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = lv2[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + +def fused_NT_matmul1_add3_sch_func(): + sch = tvm.tir.Schedule(fused_NT_matmul1_add3_before) + b0 = sch.get_block("NT_matmul") + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + + b0 = sch.get_block(name="NT_matmul", func_name="main") + b1 = sch.get_block(name="T_add", func_name="main") + b2 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l3, l4, l5, l6 = sch.get_loops(block=b0) + v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l12, l13, l14, l15, l16 = sch.split(loop=l3, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) + v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 2, 8, 1, 2]) + l22, l23, l24, l25, l26 = sch.split(loop=l4, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) + v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[256, 1, 4, 4, 1]) + l32, l33, l34, l35, l36 = sch.split(loop=l5, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) + v37, v38, v39 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[256, 1, 16]) + l40, l41, l42 = sch.split(loop=l6, factors=[v37, v38, v39], preserve_unit_iters=True) + sch.reorder(l12, l22, l32, l13, l23, l33, l14, l24, l34, l40, l41, l15, l25, l35, l42, l16, l26, l36) + l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) + sch.bind(loop=l43, thread_axis="blockIdx.x") + l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) + sch.bind(loop=l44, thread_axis="vthread.x") + l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) + sch.bind(loop=l45, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b46 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b46, loop=l45, preserve_unit_loops=True, index=-1) + b47 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b47, loop=l40, preserve_unit_loops=True, index=-1) + _, l48, l49, l50, l51, l52, l53, l54 = sch.get_loops(block=b47) + l55 = sch.fuse(l52, l53, l54, preserve_unit_iters=True) + v56 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch", ann_val=v56) + b57 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b57, loop=l40, preserve_unit_loops=True, index=-1) + _, l58, l59, l60, l61, l62, l63 = sch.get_loops(block=b57) + l64 = sch.fuse(l62, l63, preserve_unit_iters=True) + v65 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + 
sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v65) + sch.reverse_compute_inline(block=b1) + v66 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) + sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v66) + sch.enter_postproc() + sch.unannotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch") + _, l67, l68, l69, l70, l71 = sch.get_loops(block=b47) + l72, l73, l74 = sch.split(loop=l71, factors=[None, 32, 2], preserve_unit_iters=True) + sch.vectorize(loop=l74) + sch.bind(loop=l73, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") + _, l75, l76, l77, l78, l79 = sch.get_loops(block=b57) + l80, l81, l82 = sch.split(loop=l79, factors=[None, 32, 4], preserve_unit_iters=True) + sch.vectorize(loop=l82) + sch.bind(loop=l81, thread_axis="threadIdx.x") + b83 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b83, ann_key="meta_schedule.unroll_explicit") + _, b84, b85, b86, b87, _ = sch.get_child_blocks(b83) + _, l88, l89, l90, l91, l92, l93, l94 = sch.get_loops(block=b84) + sch.annotate(block_or_loop=l88, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l88, ann_key="pragma_unroll_explicit", ann_val=1) + _, l95, l96, l97, l98, l99, l100, l101 = sch.get_loops(block=b85) + sch.annotate(block_or_loop=l95, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l95, ann_key="pragma_unroll_explicit", ann_val=1) + _, l102, l103, l104, l105, l106, l107, l108, l109, l110, l111, l112, l113 = sch.get_loops(block=b86) + sch.annotate(block_or_loop=l102, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l102, ann_key="pragma_unroll_explicit", ann_val=1) + _, l114, l115, l116, l117, l118, l119 = sch.get_loops(block=b87) + sch.annotate(block_or_loop=l114, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l114, ann_key="pragma_unroll_explicit", ann_val=1) + b120 = sch.get_block(name="NT_matmul", func_name="main") + l0, l121, l122, l123, l124, l125, l126, l127, l128, l129, l130, l131, l132 = sch.get_loops(block=b120) + sch.bind(l0, "blockIdx.y") + b133 = sch.decompose_reduction(block=b120, loop=l124) + + b1 = sch.get_block("lv39_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func +def fused_NT_matmul2_divide1_add2_maximum1_before(p_lv28: T.handle, p_lv29: T.handle, p_lv5: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv28 = T.match_buffer(p_lv28, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") + lv29 = T.match_buffer(p_lv29, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") + lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, n), "float16") + var_T_maximum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, n), "float16") + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, n), "float16") + var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, n), "float16") + var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, n), "float16") + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, n, T.int64(128)): + with 
T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(lv28[v_i0, v_i1, v_i2, v_k], lv29[v_i0, v_i1, v_i3, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv28[v_i0, v_i1, v_i2, v_k] * lv29[v_i0, v_i1, v_i3, v_k] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, n): + with T.block("T_divide"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, n): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] + lv5[v_ax0, T.int64(0), v_ax2, v_ax3] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, n): + with T.block("T_maximum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) + + +def fused_NT_matmul2_divide1_add2_maximum1_sch_func(func): + sch = tvm.tir.Schedule(func) + b0 = sch.get_block("NT_matmul") + sch.pad_einsum(b0, [1, 1, 32, 32, 1]) + l1, l2, l3, l4, l5 = sch.get_loops(b0) + l6, l7 = sch.split(l3, [None, 32]) + l8, l9 = sch.split(l4, [None, 32]) + sch.reorder(l6, l8, l1, l2, l7, l9, l5) + + b0 = sch.get_block(name="NT_matmul", func_name="main") + b1 = sch.get_block(name="T_divide", func_name="main") + b2 = sch.get_block(name="T_add", func_name="main") + b3 = sch.get_block(name="T_maximum", func_name="main") + b4 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, _, l5, l6, l7, l8, l9 = sch.get_loops(block=b0) + v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l15, l16, l17, l18, l19 = sch.split(loop=l5, factors=[v10, v11, v12, v13, v14], preserve_unit_iters=True) + v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[32, 1, 1, 1, 1]) + l25, l26, l27, l28, l29 = sch.split(loop=l6, factors=[v20, v21, v22, v23, v24], preserve_unit_iters=True) + v30, v31, v32, v33, v34 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[2, 1, 8, 1, 2]) + l35, l36, l37, l38, l39 = sch.split(loop=l7, factors=[v30, v31, v32, v33, v34], preserve_unit_iters=True) + v40, v41, v42, v43, v44 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[2, 1, 8, 1, 2]) + l45, l46, l47, l48, l49 = sch.split(loop=l8, factors=[v40, v41, v42, v43, v44], preserve_unit_iters=True) + v50, v51, v52 = sch.sample_perfect_tile(loop=l9, n=3, 
max_innermost_factor=64, decision=[8, 16, 1]) + l53, l54, l55 = sch.split(loop=l9, factors=[v50, v51, v52], preserve_unit_iters=True) + sch.reorder(l15, l25, l35, l45, l16, l26, l36, l46, l17, l27, l37, l47, l53, l54, l18, l28, l38, l48, l55, l19, l29, l39, l49) + l56 = sch.fuse(l15, l25, l35, l45, preserve_unit_iters=True) + sch.bind(loop=l56, thread_axis="blockIdx.x") + l57 = sch.fuse(l16, l26, l36, l46, preserve_unit_iters=True) + sch.bind(loop=l57, thread_axis="vthread.x") + l58 = sch.fuse(l17, l27, l37, l47, preserve_unit_iters=True) + sch.bind(loop=l58, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b59 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b59, loop=l58, preserve_unit_loops=True, index=-1) + b60 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b60, loop=l53, preserve_unit_loops=True, index=-1) + _, _, l61, l62, l63, l64, l65, l66, l67, l68 = sch.get_loops(block=b60) + l69 = sch.fuse(l65, l66, l67, l68, preserve_unit_iters=True) + v70 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch", ann_val=v70) + b71 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b71, loop=l53, preserve_unit_loops=True, index=-1) + _, _, l72, l73, l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b71) + l80 = sch.fuse(l76, l77, l78, l79, preserve_unit_iters=True) + v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch", ann_val=v81) + sch.reverse_compute_inline(block=b3) + sch.compute_inline(block=b1) + v82 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) + sch.annotate(block_or_loop=b4, ann_key="meta_schedule.unroll_explicit", ann_val=v82) + l83, l84, l85, l86 = sch.get_loops(block=b2) + l87 = sch.fuse(l83, l84, l85, l86, preserve_unit_iters=True) + v88 = sch.sample_categorical(candidates=[32, 64, 128, 256], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + l89, l90 = sch.split(loop=l87, factors=[None, v88], preserve_unit_iters=True) + sch.bind(loop=l89, thread_axis="blockIdx.x") + sch.bind(loop=l90, thread_axis="threadIdx.x") + sch.enter_postproc() + sch.unannotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch") + _, _, l91, l92, l93, l94, l95 = sch.get_loops(block=b60) + l96, l97, l98 = sch.split(loop=l95, factors=[None, 64, 2], preserve_unit_iters=True) + sch.vectorize(loop=l98) + sch.bind(loop=l97, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch") + _, _, l99, l100, l101, l102, l103 = sch.get_loops(block=b71) + l104, l105, l106 = sch.split(loop=l103, factors=[None, 64, 2], preserve_unit_iters=True) + sch.vectorize(loop=l106) + sch.bind(loop=l105, thread_axis="threadIdx.x") + b107 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b107, ann_key="meta_schedule.unroll_explicit") + _, _, b108, b109, b110, b111, _, b112 = sch.get_child_blocks(b107) + _, _, l113, l114, l115, 
l116, l117, l118, l119 = sch.get_loops(block=b108) + sch.annotate(block_or_loop=l113, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l113, ann_key="pragma_unroll_explicit", ann_val=1) + _, _, l120, l121, l122, l123, l124, l125, l126 = sch.get_loops(block=b109) + sch.annotate(block_or_loop=l120, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l120, ann_key="pragma_unroll_explicit", ann_val=1) + _, _, l127, l128, l129, l130, l131, l132, l133, l134, l135, l136, l137, l138, l139, l140 = sch.get_loops(block=b110) + sch.annotate(block_or_loop=l127, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l127, ann_key="pragma_unroll_explicit", ann_val=1) + _, _, l141, l142, l143, l144, l145, l146, l147 = sch.get_loops(block=b111) + sch.annotate(block_or_loop=l141, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l141, ann_key="pragma_unroll_explicit", ann_val=1) + l148, l149 = sch.get_loops(block=b112) + sch.annotate(block_or_loop=l148, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l148, ann_key="pragma_unroll_explicit", ann_val=1) + b150 = sch.get_block(name="NT_matmul", func_name="main") + l0, l1, l151, l152, l153, l154, l155, l156, l157, l158, l159, l160, l161, l162, l163, l164 = sch.get_loops(block=b150) + l2 = sch.fuse(l0, l1) + sch.bind(l2, "blockIdx.y") + b165 = sch.decompose_reduction(block=b150, loop=l154) + + b1 = sch.get_block("lv28_pad") + sch.compute_inline(b1) + b2 = sch.get_block("lv29_pad") + sch.compute_inline(b2) + b3 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b3) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func +def fused_NT_matmul2_divide1_maximum1_minimum1_cast3_before(p_lv28: T.handle, p_lv29: T.handle, p_lv5: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv28 = T.match_buffer(p_lv28, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") + m = T.int64() + lv29 = T.match_buffer(p_lv29, (T.int64(1), T.int64(32), m, T.int64(128)), "float16") + lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m), "float16") + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m)) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") + var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") + var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") + var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, m, T.int64(128)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(lv28[v_i0, v_i1, v_i2, v_k], lv29[v_i0, v_i1, v_i3, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv28[v_i0, v_i1, v_i2, v_k] * lv29[v_i0, v_i1, v_i3, v_k] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_divide"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_divide_intermediate[v_ax0, v_ax1, 
v_ax2, v_ax3]) + var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_maximum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_minimum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) + T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) + var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) + +@T.prim_func +def fused_NT_matmul2_divide1_maximum1_minimum1_cast3_after(p_lv22: T.handle, p_lv23: T.handle, p_lv5: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + m = T.int64() + lv22 = T.match_buffer(p_lv22, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") + lv23 = T.match_buffer(p_lv23, (T.int64(1), T.int64(32), m, T.int64(128)), "float16") + lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m), "float16") + var_T_maximum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m)) + # with T.block("root"): + for i2_0_i3_0_fused in T.thread_binding((n + T.int64(31)) // T.int64(32) * ((m + T.int64(31)) // T.int64(32)), thread="blockIdx.y"): + with T.block("NT_matmul_o"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0_i3_0_fused // ((m + T.int64(31)) // T.int64(32))) + v_i3_o = T.axis.spatial((m + T.int64(31)) // T.int64(32), i2_0_i3_0_fused % ((m + T.int64(31)) // T.int64(32))) + T.reads(lv22[T.int64(0), T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(128)], lv23[T.int64(0), T.int64(0):T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(128)], lv5[v_i0, T.int64(0), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32)]) + T.writes(var_T_maximum_intermediate[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32)]) + C_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(32)), "float16", scope="local") + A_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(128)), "float16", scope="shared") + B_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(128)), "float16", scope="shared") + for i0_0_i1_0_i2_1_0_i3_1_0_fused in T.thread_binding(T.int64(128), 
thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 512, "pragma_unroll_explicit": 1}): + for i0_1_i1_1_i2_1_1_i3_1_1_fused in T.thread_binding(T.int64(4), thread="vthread.x"): + for i0_2_i1_2_i2_1_2_i3_1_2_fused in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for i1_3_init, i2_1_3_init, i3_1_3_init, i1_4_init, i2_1_4_init, i3_1_4_init in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1)): + with T.block("NT_matmul_init"): + v_i1_i = T.axis.spatial(T.int64(32), i1_4_init + i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + i1_3_init) + v_i2_i = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + i2_1_3_init + i2_1_4_init) + v_i3_i = T.axis.spatial(T.int64(32), i3_1_4_init + i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + i3_1_3_init) + T.reads() + T.writes(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] = T.float32(0) + for k_0 in range(T.int64(16)): + for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)): + with T.block("A_pad_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4)) + v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(8)) + v3 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(8)) + T.reads(lv22[v0, v1, v_i2_o * T.int64(32) + v2, v3]) + T.writes(A_pad_shared[v0, v1, v2, v3]) + A_pad_shared[v0, v1, v2, v3] = T.if_then_else(v_i2_o * T.int64(32) + v2 < n, lv22[v0, v1, v_i2_o * T.int64(32) + v2, v3], T.float32(0)) + for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)): + with T.block("B_pad_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4)) + v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(8)) + v3 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(8)) + T.reads(lv23[v0, v1, v_i3_o * T.int64(32) + v2, v3]) + T.writes(B_pad_shared[v0, v1, v2, v3]) + B_pad_shared[v0, v1, v2, v3] = T.if_then_else(v_i3_o * T.int64(32) + v2 < m, lv23[v0, v1, v_i3_o * T.int64(32) + v2, v3], T.float32(0)) + for k_1, i0_3, i1_3, i2_1_3, i3_1_3, k_2, i0_4, i1_4, i2_1_4, i3_1_4 in T.grid(T.int64(4), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(1), T.int64(1), T.int64(1), 
T.int64(1)): + with T.block("NT_matmul_update"): + v_i1_i = T.axis.spatial(T.int64(32), i1_4 + i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + i1_3) + v_i2_i = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + i2_1_3 + i2_1_4) + v_i3_i = T.axis.spatial(T.int64(32), i3_1_4 + i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + i3_1_3) + v_k_i = T.axis.reduce(T.int64(128), k_0 * T.int64(8) + k_1 * T.int64(2) + k_2) + T.reads(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i], A_pad_shared[T.int64(0), v_i1_i, v_i2_i, v_k_i], B_pad_shared[T.int64(0), v_i1_i, v_i3_i, v_k_i]) + T.writes(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] = C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] + A_pad_shared[T.int64(0), v_i1_i, v_i2_i, v_k_i] * B_pad_shared[T.int64(0), v_i1_i, v_i3_i, v_k_i] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)): + with T.block("C_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + ax1) + v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + ax2) + v3 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + ax3) + T.reads(C_pad_local[v0, v1, v2, v3], lv5[v_i0 + v0, T.int64(0), v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3]) + T.writes(var_T_maximum_intermediate[v_i0 + v0, v1, v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3]) + # if T.int64(0) <= v_i0 and v_i0 < T.int64(1) and T.int64(0) <= v_i2_o * T.int64(32) + v2 and v_i2_o * T.int64(32) + v2 < n and T.int64(0) <= v_i3_o * T.int64(32) + v3 and v_i3_o * T.int64(32) + v3 < n: + if v_i2_o * T.int64(32) + v2 < n and v_i3_o * T.int64(32) + v3 < m: + var_T_maximum_intermediate[v_i0 + v0, v1, v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3] = T.Cast("float32", T.min(T.max(C_pad_local[v0, v1, v2, v3] * T.float32(0.088397790055248615), T.float16(-65504)), lv5[v_i0 + v0, T.int64(0), v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3])) + +@T.prim_func +def fused_NT_matmul2_divide1_maximum1_minimum1_before(p_lv28: T.handle, p_lv29: T.handle, p_lv5: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv28 = T.match_buffer(p_lv28, (T.int64(1), T.int64(32), n, T.int64(128))) + m = T.int64() + lv29 = T.match_buffer(p_lv29, (T.int64(1), T.int64(32), m, T.int64(128))) + lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m)) + var_T_minimum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m)) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + for i0, i1, i2, i3, k in T.grid(T.int64(1), 
T.int64(32), n, m, T.int64(128)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(lv28[v_i0, v_i1, v_i2, v_k], lv29[v_i0, v_i1, v_i3, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float32(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv28[v_i0, v_i1, v_i2, v_k] * lv29[v_i0, v_i1, v_i3, v_k] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_divide"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.088388349161020605) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_maximum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float32(-3.4028234663852886e+38)) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_minimum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) + T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) + +@T.prim_func +def fused_NT_matmul2_divide1_maximum1_minimum1_after(p_lv22: T.handle, p_lv23: T.handle, p_lv5: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + n = T.int64() + m = T.int64() + lv22 = T.match_buffer(p_lv22, (T.int64(1), T.int64(32), n, T.int64(128)), "float32") + lv23 = T.match_buffer(p_lv23, (T.int64(1), T.int64(32), m, T.int64(128)), "float32") + lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m), "float32") + var_T_maximum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m)) + # with T.block("root"): + for i2_0_i3_0_fused in T.thread_binding((n + T.int64(31)) // T.int64(32) * ((m + T.int64(31)) // T.int64(32)), thread="blockIdx.y"): + with T.block("NT_matmul_o"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0_i3_0_fused // ((m + T.int64(31)) // T.int64(32))) + v_i3_o = T.axis.spatial((m + T.int64(31)) // T.int64(32), i2_0_i3_0_fused % ((m + T.int64(31)) // T.int64(32))) + T.reads(lv22[T.int64(0), T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(128)], lv23[T.int64(0), T.int64(0):T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(128)], lv5[v_i0, T.int64(0), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32)]) + T.writes(var_T_maximum_intermediate[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32)]) 
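+                # Scheduled body (as emitted by the meta-schedule trace): the dynamic (n, m) space is
+                # padded to 32x32 tiles. A_pad_shared / B_pad_shared are zero-filled past n / m, the
+                # 32x32 output tile is accumulated in C_pad_local, and the write-back block below masks
+                # out padded rows/columns before applying the fused divide (as a multiply), the fp16
+                # lower-bound clamp, and the min with the mask lv5.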
+ C_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(32)), "float32", scope="local") + A_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(128)), "float32", scope="shared") + B_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(128)), "float32", scope="shared") + for i0_0_i1_0_i2_1_0_i3_1_0_fused in T.thread_binding(T.int64(128), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 512, "pragma_unroll_explicit": 1}): + for i0_1_i1_1_i2_1_1_i3_1_1_fused in T.thread_binding(T.int64(4), thread="vthread.x"): + for i0_2_i1_2_i2_1_2_i3_1_2_fused in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for i1_3_init, i2_1_3_init, i3_1_3_init, i1_4_init, i2_1_4_init, i3_1_4_init in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1)): + with T.block("NT_matmul_init"): + v_i1_i = T.axis.spatial(T.int64(32), i1_4_init + i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + i1_3_init) + v_i2_i = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + i2_1_3_init + i2_1_4_init) + v_i3_i = T.axis.spatial(T.int64(32), i3_1_4_init + i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + i3_1_3_init) + T.reads() + T.writes(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] = T.float32(0) + for k_0 in range(T.int64(16)): + for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)): + with T.block("A_pad_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4)) + v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(8)) + v3 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(8)) + T.reads(lv22[v0, v1, v_i2_o * T.int64(32) + v2, v3]) + T.writes(A_pad_shared[v0, v1, v2, v3]) + A_pad_shared[v0, v1, v2, v3] = T.if_then_else(v_i2_o * T.int64(32) + v2 < n, lv22[v0, v1, v_i2_o * T.int64(32) + v2, v3], T.float32(0)) + for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)): + with T.block("B_pad_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4)) + v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(8)) + v3 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(8)) + T.reads(lv23[v0, 
v1, v_i3_o * T.int64(32) + v2, v3]) + T.writes(B_pad_shared[v0, v1, v2, v3]) + B_pad_shared[v0, v1, v2, v3] = T.if_then_else(v_i3_o * T.int64(32) + v2 < m, lv23[v0, v1, v_i3_o * T.int64(32) + v2, v3], T.float32(0)) + for k_1, i0_3, i1_3, i2_1_3, i3_1_3, k_2, i0_4, i1_4, i2_1_4, i3_1_4 in T.grid(T.int64(4), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(1), T.int64(1), T.int64(1), T.int64(1)): + with T.block("NT_matmul_update"): + v_i1_i = T.axis.spatial(T.int64(32), i1_4 + i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + i1_3) + v_i2_i = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + i2_1_3 + i2_1_4) + v_i3_i = T.axis.spatial(T.int64(32), i3_1_4 + i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + i3_1_3) + v_k_i = T.axis.reduce(T.int64(128), k_0 * T.int64(8) + k_1 * T.int64(2) + k_2) + T.reads(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i], A_pad_shared[T.int64(0), v_i1_i, v_i2_i, v_k_i], B_pad_shared[T.int64(0), v_i1_i, v_i3_i, v_k_i]) + T.writes(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] = C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] + A_pad_shared[T.int64(0), v_i1_i, v_i2_i, v_k_i] * B_pad_shared[T.int64(0), v_i1_i, v_i3_i, v_k_i] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)): + with T.block("C_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + ax1) + v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + ax2) + v3 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + ax3) + T.reads(C_pad_local[v0, v1, v2, v3], lv5[v_i0 + v0, T.int64(0), v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3]) + T.writes(var_T_maximum_intermediate[v_i0 + v0, v1, v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3]) + # if T.int64(0) <= v_i0 and v_i0 < T.int64(1) and T.int64(0) <= v_i2_o * T.int64(32) + v2 and v_i2_o * T.int64(32) + v2 < n and T.int64(0) <= v_i3_o * T.int64(32) + v3 and v_i3_o * T.int64(32) + v3 < n: + if v_i2_o * T.int64(32) + v2 < n and v_i3_o * T.int64(32) + v3 < m: + var_T_maximum_intermediate[v_i0 + v0, v1, v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3] = T.min(T.max(C_pad_local[v0, v1, v2, v3] * T.float32(0.088397790055248615), T.float16(-65504)), lv5[v_i0 + v0, T.int64(0), v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3]) + +def fused_NT_matmul2_divide1_add2_maximum1_sch_func(func): + sch = tvm.tir.Schedule(func) + b0 = sch.get_block("NT_matmul") + sch.pad_einsum(b0, [1, 1, 32, 32, 1]) + l1, l2, l3, l4, l5 = sch.get_loops(b0) + l6, l7 = sch.split(l3, [None, 32]) + l8, l9 = sch.split(l4, [None, 32]) + sch.reorder(l6, l8, l1, l2, l7, l9, l5) + + b0 = sch.get_block(name="NT_matmul", func_name="main") + b1 = sch.get_block(name="T_divide", func_name="main") + b2 = 
sch.get_block(name="T_add", func_name="main") + b3 = sch.get_block(name="T_maximum", func_name="main") + b4 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, _, l5, l6, l7, l8, l9 = sch.get_loops(block=b0) + v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l15, l16, l17, l18, l19 = sch.split(loop=l5, factors=[v10, v11, v12, v13, v14], preserve_unit_iters=True) + v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[32, 1, 1, 1, 1]) + l25, l26, l27, l28, l29 = sch.split(loop=l6, factors=[v20, v21, v22, v23, v24], preserve_unit_iters=True) + v30, v31, v32, v33, v34 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[2, 1, 8, 1, 2]) + l35, l36, l37, l38, l39 = sch.split(loop=l7, factors=[v30, v31, v32, v33, v34], preserve_unit_iters=True) + v40, v41, v42, v43, v44 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[2, 1, 8, 1, 2]) + l45, l46, l47, l48, l49 = sch.split(loop=l8, factors=[v40, v41, v42, v43, v44], preserve_unit_iters=True) + v50, v51, v52 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64, decision=[8, 16, 1]) + l53, l54, l55 = sch.split(loop=l9, factors=[v50, v51, v52], preserve_unit_iters=True) + sch.reorder(l15, l25, l35, l45, l16, l26, l36, l46, l17, l27, l37, l47, l53, l54, l18, l28, l38, l48, l55, l19, l29, l39, l49) + l56 = sch.fuse(l15, l25, l35, l45, preserve_unit_iters=True) + sch.bind(loop=l56, thread_axis="blockIdx.x") + l57 = sch.fuse(l16, l26, l36, l46, preserve_unit_iters=True) + sch.bind(loop=l57, thread_axis="vthread.x") + l58 = sch.fuse(l17, l27, l37, l47, preserve_unit_iters=True) + sch.bind(loop=l58, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b59 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b59, loop=l58, preserve_unit_loops=True, index=-1) + b60 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b60, loop=l53, preserve_unit_loops=True, index=-1) + _, _, l61, l62, l63, l64, l65, l66, l67, l68 = sch.get_loops(block=b60) + l69 = sch.fuse(l65, l66, l67, l68, preserve_unit_iters=True) + v70 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch", ann_val=v70) + b71 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b71, loop=l53, preserve_unit_loops=True, index=-1) + _, _, l72, l73, l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b71) + l80 = sch.fuse(l76, l77, l78, l79, preserve_unit_iters=True) + v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch", ann_val=v81) + sch.reverse_compute_inline(block=b3) + sch.compute_inline(block=b1) + v82 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) + sch.annotate(block_or_loop=b4, ann_key="meta_schedule.unroll_explicit", 
ann_val=v82) + l83, l84, l85, l86 = sch.get_loops(block=b2) + l87 = sch.fuse(l83, l84, l85, l86, preserve_unit_iters=True) + v88 = sch.sample_categorical(candidates=[32, 64, 128, 256], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + l89, l90 = sch.split(loop=l87, factors=[None, v88], preserve_unit_iters=True) + sch.bind(loop=l89, thread_axis="blockIdx.x") + sch.bind(loop=l90, thread_axis="threadIdx.x") + sch.enter_postproc() + sch.unannotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch") + _, _, l91, l92, l93, l94, l95 = sch.get_loops(block=b60) + l96, l97, l98 = sch.split(loop=l95, factors=[None, 64, 2], preserve_unit_iters=True) + sch.vectorize(loop=l98) + sch.bind(loop=l97, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch") + _, _, l99, l100, l101, l102, l103 = sch.get_loops(block=b71) + l104, l105, l106 = sch.split(loop=l103, factors=[None, 64, 2], preserve_unit_iters=True) + sch.vectorize(loop=l106) + sch.bind(loop=l105, thread_axis="threadIdx.x") + b107 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b107, ann_key="meta_schedule.unroll_explicit") + _, _, b108, b109, b110, b111, _, b112 = sch.get_child_blocks(b107) + _, _, l113, l114, l115, l116, l117, l118, l119 = sch.get_loops(block=b108) + sch.annotate(block_or_loop=l113, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l113, ann_key="pragma_unroll_explicit", ann_val=1) + _, _, l120, l121, l122, l123, l124, l125, l126 = sch.get_loops(block=b109) + sch.annotate(block_or_loop=l120, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l120, ann_key="pragma_unroll_explicit", ann_val=1) + _, _, l127, l128, l129, l130, l131, l132, l133, l134, l135, l136, l137, l138, l139, l140 = sch.get_loops(block=b110) + sch.annotate(block_or_loop=l127, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l127, ann_key="pragma_unroll_explicit", ann_val=1) + _, _, l141, l142, l143, l144, l145, l146, l147 = sch.get_loops(block=b111) + sch.annotate(block_or_loop=l141, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l141, ann_key="pragma_unroll_explicit", ann_val=1) + l148, l149 = sch.get_loops(block=b112) + sch.annotate(block_or_loop=l148, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l148, ann_key="pragma_unroll_explicit", ann_val=1) + b150 = sch.get_block(name="NT_matmul", func_name="main") + l0, l1, l151, l152, l153, l154, l155, l156, l157, l158, l159, l160, l161, l162, l163, l164 = sch.get_loops(block=b150) + l2 = sch.fuse(l0, l1) + sch.bind(l2, "blockIdx.y") + b165 = sch.decompose_reduction(block=b150, loop=l154) + + b1 = sch.get_block("lv28_pad") + sch.compute_inline(b1) + b2 = sch.get_block("lv29_pad") + sch.compute_inline(b2) + b3 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b3) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func +def fused_NT_matmul3_multiply1_before(p_lv43: T.handle, lv1866: T.Buffer((T.int64(11008), T.int64(4096)), "float16"), p_lv48: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv43 = T.match_buffer(p_lv43, (T.int64(1), n, T.int64(4096)), "float16") + lv48 = T.match_buffer(p_lv48, (T.int64(1), n, T.int64(11008)), "float16") + var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008)), "float16") + # with T.block("root"): + 
var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16") + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv43[v_i0, v_i1, v_k], lv1866[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv43[v_i0, v_i1, v_k] * lv1866[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv48[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = lv48[v_ax0, v_ax1, v_ax2] * var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + +def fused_NT_matmul3_multiply1_sch_func(): + sch = tvm.tir.Schedule(fused_NT_matmul3_multiply1_before) + b0 = sch.get_block("NT_matmul") + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + + b0 = sch.get_block(name="NT_matmul", func_name="main") + b1 = sch.get_block(name="T_multiply", func_name="main") + b2 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l3, l4, l5, l6 = sch.get_loops(block=b0) + v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l12, l13, l14, l15, l16 = sch.split(loop=l3, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) + v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 8, 2, 2]) + l22, l23, l24, l25, l26 = sch.split(loop=l4, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) + v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[344, 4, 8, 1, 1]) + l32, l33, l34, l35, l36 = sch.split(loop=l5, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) + v37, v38, v39 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[128, 16, 2]) + l40, l41, l42 = sch.split(loop=l6, factors=[v37, v38, v39], preserve_unit_iters=True) + sch.reorder(l12, l22, l32, l13, l23, l33, l14, l24, l34, l40, l41, l15, l25, l35, l42, l16, l26, l36) + l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) + sch.bind(loop=l43, thread_axis="blockIdx.x") + l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) + sch.bind(loop=l44, thread_axis="vthread.x") + l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) + sch.bind(loop=l45, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b46 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b46, loop=l45, preserve_unit_loops=True, index=-1) + b47 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b47, loop=l40, preserve_unit_loops=True, index=-1) + _, l48, l49, l50, l51, l52, l53, l54 = sch.get_loops(block=b47) + l55 = sch.fuse(l52, l53, l54, preserve_unit_iters=True) + v56 
= sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch", ann_val=v56) + b57 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b57, loop=l40, preserve_unit_loops=True, index=-1) + _, l58, l59, l60, l61, l62, l63 = sch.get_loops(block=b57) + l64 = sch.fuse(l62, l63, preserve_unit_iters=True) + v65 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v65) + sch.reverse_compute_inline(block=b1) + v66 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=1) + sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v66) + sch.enter_postproc() + sch.unannotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch") + _, l67, l68, l69, l70, l71 = sch.get_loops(block=b47) + l72, l73, l74 = sch.split(loop=l71, factors=[None, 64, 4], preserve_unit_iters=True) + sch.vectorize(loop=l74) + sch.bind(loop=l73, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") + _, l75, l76, l77, l78, l79 = sch.get_loops(block=b57) + l80, l81, l82 = sch.split(loop=l79, factors=[None, 64, 2], preserve_unit_iters=True) + sch.vectorize(loop=l82) + sch.bind(loop=l81, thread_axis="threadIdx.x") + b83 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b83, ann_key="meta_schedule.unroll_explicit") + _, b84, b85, b86, b87, _ = sch.get_child_blocks(b83) + _, l88, l89, l90, l91, l92, l93, l94 = sch.get_loops(block=b84) + sch.annotate(block_or_loop=l88, ann_key="pragma_auto_unroll_max_step", ann_val=16) + sch.annotate(block_or_loop=l88, ann_key="pragma_unroll_explicit", ann_val=1) + _, l95, l96, l97, l98, l99, l100, l101 = sch.get_loops(block=b85) + sch.annotate(block_or_loop=l95, ann_key="pragma_auto_unroll_max_step", ann_val=16) + sch.annotate(block_or_loop=l95, ann_key="pragma_unroll_explicit", ann_val=1) + _, l102, l103, l104, l105, l106, l107, l108, l109, l110, l111, l112, l113 = sch.get_loops(block=b86) + sch.annotate(block_or_loop=l102, ann_key="pragma_auto_unroll_max_step", ann_val=16) + sch.annotate(block_or_loop=l102, ann_key="pragma_unroll_explicit", ann_val=1) + _, l114, l115, l116, l117, l118, l119 = sch.get_loops(block=b87) + sch.annotate(block_or_loop=l114, ann_key="pragma_auto_unroll_max_step", ann_val=16) + sch.annotate(block_or_loop=l114, ann_key="pragma_unroll_explicit", ann_val=1) + b120 = sch.get_block(name="NT_matmul", func_name="main") + l0, l121, l122, l123, l124, l125, l126, l127, l128, l129, l130, l131, l132 = sch.get_loops(block=b120) + sch.bind(l0, "blockIdx.y") + b133 = sch.decompose_reduction(block=b120, loop=l124) + + b1 = sch.get_block("lv43_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func +def fused_NT_matmul3_silu1_before(p_lv43: T.handle, lv1857: T.Buffer((T.int64(11008), T.int64(4096)), "float16"), p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv43 = T.match_buffer(p_lv43, (T.int64(1), n, T.int64(4096)), "float16") + var_T_multiply_intermediate = T.match_buffer(p_output0, 
(T.int64(1), n, T.int64(11008)), "float16") + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16") + compute = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16") + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv43[v_i0, v_i1, v_k], lv1857[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv43[v_i0, v_i1, v_k] * lv1857[v_i2, v_k] + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(11008)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + T.writes(compute[v_i0, v_i1, v_i2]) + compute[v_i0, v_i1, v_i2] = T.sigmoid(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2] + + +def fused_NT_matmul3_silu1_sch_func(): + sch = tvm.tir.Schedule(fused_NT_matmul3_silu1_before) + b0 = sch.get_block("NT_matmul") + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + + b0 = sch.get_block(name="NT_matmul", func_name="main") + b1 = sch.get_block(name="compute", func_name="main") + b2 = sch.get_block(name="T_multiply", func_name="main") + b3 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l4, l5, l6, l7 = sch.get_loops(block=b0) + v8, v9, v10, v11, v12 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l13, l14, l15, l16, l17 = sch.split(loop=l4, factors=[v8, v9, v10, v11, v12], preserve_unit_iters=True) + v18, v19, v20, v21, v22 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[1, 1, 8, 4, 1]) + l23, l24, l25, l26, l27 = sch.split(loop=l5, factors=[v18, v19, v20, v21, v22], preserve_unit_iters=True) + v28, v29, v30, v31, v32 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[344, 4, 8, 1, 1]) + l33, l34, l35, l36, l37 = sch.split(loop=l6, factors=[v28, v29, v30, v31, v32], preserve_unit_iters=True) + v38, v39, v40 = sch.sample_perfect_tile(loop=l7, n=3, max_innermost_factor=64, decision=[128, 16, 2]) + l41, l42, l43 = sch.split(loop=l7, factors=[v38, v39, v40], preserve_unit_iters=True) + sch.reorder(l13, l23, l33, l14, l24, l34, l15, l25, l35, l41, l42, l16, l26, l36, l43, l17, l27, l37) + l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) + sch.bind(loop=l44, thread_axis="blockIdx.x") + l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) + sch.bind(loop=l45, thread_axis="vthread.x") + l46 = sch.fuse(l15, l25, l35, preserve_unit_iters=True) + sch.bind(loop=l46, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, 
ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b47 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b47, loop=l46, preserve_unit_loops=True, index=-1) + b48 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b48, loop=l41, preserve_unit_loops=True, index=-1) + _, l49, l50, l51, l52, l53, l54, l55 = sch.get_loops(block=b48) + l56 = sch.fuse(l53, l54, l55, preserve_unit_iters=True) + v57 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b48, ann_key="meta_schedule.cooperative_fetch", ann_val=v57) + b58 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b58, loop=l41, preserve_unit_loops=True, index=-1) + _, l59, l60, l61, l62, l63, l64 = sch.get_loops(block=b58) + l65 = sch.fuse(l63, l64, preserve_unit_iters=True) + v66 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b58, ann_key="meta_schedule.cooperative_fetch", ann_val=v66) + sch.compute_inline(block=b1) + v67 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=1) + sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v67) + + # reverse compute inline the silu part + sch.reverse_compute_inline(b2) + # l68, l69, l70 = sch.get_loops(block=b2) + # l71 = sch.fuse(l68, l69, l70, preserve_unit_iters=True) + # l72, l73, l74 = sch.split(loop=l71, factors=[None, 256, 256], preserve_unit_iters=True) + #sch.reorder(l73, l74, l72) + # sch.bind(loop=l73, thread_axis="blockIdx.x") + # sch.bind(loop=l74, thread_axis="threadIdx.x") + sch.enter_postproc() + sch.unannotate(block_or_loop=b48, ann_key="meta_schedule.cooperative_fetch") + _, l75, l76, l77, l78, l79 = sch.get_loops(block=b48) + l80, l81, l82 = sch.split(loop=l79, factors=[None, 64, 4], preserve_unit_iters=True) + sch.vectorize(loop=l82) + sch.bind(loop=l81, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b58, ann_key="meta_schedule.cooperative_fetch") + _, l83, l84, l85, l86, l87 = sch.get_loops(block=b58) + l88, l89, l90 = sch.split(loop=l87, factors=[None, 64, 2], preserve_unit_iters=True) + sch.vectorize(loop=l90) + sch.bind(loop=l89, thread_axis="threadIdx.x") + b91 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b91, ann_key="meta_schedule.unroll_explicit") + _, b92, b93, b94, b95, _ = sch.get_child_blocks(b91) + _, l97, l98, l99, l100, l101, l102, l103 = sch.get_loops(block=b92) + sch.annotate(block_or_loop=l97, ann_key="pragma_auto_unroll_max_step", ann_val=16) + sch.annotate(block_or_loop=l97, ann_key="pragma_unroll_explicit", ann_val=1) + _, l104, l105, l106, l107, l108, l109, l110 = sch.get_loops(block=b93) + sch.annotate(block_or_loop=l104, ann_key="pragma_auto_unroll_max_step", ann_val=16) + sch.annotate(block_or_loop=l104, ann_key="pragma_unroll_explicit", ann_val=1) + _, l111, l112, l113, l114, l115, l116, l117, l118, l119, l120, l121, l122 = sch.get_loops(block=b94) + sch.annotate(block_or_loop=l111, ann_key="pragma_auto_unroll_max_step", ann_val=16) + sch.annotate(block_or_loop=l111, ann_key="pragma_unroll_explicit", ann_val=1) + _, l123, l124, l125, l126, l127, l128 = sch.get_loops(block=b95) + sch.annotate(block_or_loop=l123, 
ann_key="pragma_auto_unroll_max_step", ann_val=16) + sch.annotate(block_or_loop=l123, ann_key="pragma_unroll_explicit", ann_val=1) + # l129, l130, l131 = sch.get_loops(block=b96) + # sch.annotate(block_or_loop=l129, ann_key="pragma_auto_unroll_max_step", ann_val=16) + # sch.annotate(block_or_loop=l129, ann_key="pragma_unroll_explicit", ann_val=1) + b132 = sch.get_block(name="NT_matmul", func_name="main") + l0, l133, l134, l135, l136, l137, l138, l139, l140, l141, l142, l143, l144 = sch.get_loops(block=b132) + sch.bind(l0, "blockIdx.y") + b145 = sch.decompose_reduction(block=b132, loop=l136) + + b1 = sch.get_block("lv43_pad") + sch.compute_inline(b1) + + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func +def fused_NT_matmul4_add3_before(p_lv49: T.handle, lv1875: T.Buffer((T.int64(4096), T.int64(11008)), "float16"), p_lv42: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv49 = T.match_buffer(p_lv49, (T.int64(1), n, T.int64(11008)), "float16") + lv42 = T.match_buffer(p_lv42, (T.int64(1), n, T.int64(4096)), "float16") + var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16") + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)), "float16") + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(11008)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv49[v_i0, v_i1, v_k], lv1875[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv49[v_i0, v_i1, v_k] * lv1875[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv42[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = lv42[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + +def fused_NT_matmul4_add3_sch_func(): + sch = tvm.tir.Schedule(fused_NT_matmul4_add3_before) + b0 = sch.get_block("NT_matmul") + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + + b0 = sch.get_block(name="NT_matmul", func_name="main") + b1 = sch.get_block(name="T_add", func_name="main") + b2 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l3, l4, l5, l6 = sch.get_loops(block=b0) + v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l12, l13, l14, l15, l16 = sch.split(loop=l3, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) + v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 8, 1, 4]) + l22, l23, l24, l25, l26 = sch.split(loop=l4, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) + v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[128, 2, 8, 2, 1]) + l32, l33, l34, l35, l36 = sch.split(loop=l5, factors=[v27, v28, v29, v30, v31], 
preserve_unit_iters=True) + v37, v38, v39 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[688, 16, 1]) + l40, l41, l42 = sch.split(loop=l6, factors=[v37, v38, v39], preserve_unit_iters=True) + sch.reorder(l12, l22, l32, l13, l23, l33, l14, l24, l34, l40, l41, l15, l25, l35, l42, l16, l26, l36) + l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) + sch.bind(loop=l43, thread_axis="blockIdx.x") + l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) + sch.bind(loop=l44, thread_axis="vthread.x") + l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) + sch.bind(loop=l45, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b46 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b46, loop=l45, preserve_unit_loops=True, index=-1) + b47 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b47, loop=l40, preserve_unit_loops=True, index=-1) + _, l48, l49, l50, l51, l52, l53, l54 = sch.get_loops(block=b47) + l55 = sch.fuse(l52, l53, l54, preserve_unit_iters=True) + v56 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch", ann_val=v56) + b57 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b57, loop=l40, preserve_unit_loops=True, index=-1) + _, l58, l59, l60, l61, l62, l63 = sch.get_loops(block=b57) + l64 = sch.fuse(l62, l63, preserve_unit_iters=True) + v65 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v65) + sch.reverse_compute_inline(block=b1) + v66 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) + sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v66) + sch.enter_postproc() + sch.unannotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch") + _, l67, l68, l69, l70, l71 = sch.get_loops(block=b47) + l72, l73, l74 = sch.split(loop=l71, factors=[None, 64, 4], preserve_unit_iters=True) + sch.vectorize(loop=l74) + sch.bind(loop=l73, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") + _, l75, l76, l77, l78, l79 = sch.get_loops(block=b57) + l80, l81, l82 = sch.split(loop=l79, factors=[None, 64, 2], preserve_unit_iters=True) + sch.vectorize(loop=l82) + sch.bind(loop=l81, thread_axis="threadIdx.x") + b83 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b83, ann_key="meta_schedule.unroll_explicit") + _, b84, b85, b86, b87, _ = sch.get_child_blocks(b83) + _, l88, l89, l90, l91, l92, l93, l94 = sch.get_loops(block=b84) + sch.annotate(block_or_loop=l88, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l88, ann_key="pragma_unroll_explicit", ann_val=1) + _, l95, l96, l97, l98, l99, l100, l101 = sch.get_loops(block=b85) + sch.annotate(block_or_loop=l95, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l95, ann_key="pragma_unroll_explicit", ann_val=1) + _, l102, 
l103, l104, l105, l106, l107, l108, l109, l110, l111, l112, l113 = sch.get_loops(block=b86) + sch.annotate(block_or_loop=l102, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l102, ann_key="pragma_unroll_explicit", ann_val=1) + _, l114, l115, l116, l117, l118, l119 = sch.get_loops(block=b87) + sch.annotate(block_or_loop=l114, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l114, ann_key="pragma_unroll_explicit", ann_val=1) + b120 = sch.get_block(name="NT_matmul", func_name="main") + l0, l121, l122, l123, l124, l125, l126, l127, l128, l129, l130, l131, l132 = sch.get_loops(block=b120) + sch.bind(l0, "blockIdx.y") + b133 = sch.decompose_reduction(block=b120, loop=l124) + + b1 = sch.get_block("lv49_pad") + sch.compute_inline(b1) + b2 = sch.get_block("var_NT_matmul_intermediate_pad") + sch.reverse_compute_inline(b2) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func +def matmul1_fp16_before(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), T.int64(1), n), "float16") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") + # with T.block("root"): + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), n): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k], rxplaceholder_1[v_i0, v_i1, v_k, v_i3]) + T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0) + matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + rxplaceholder[v_i0, v_i1, v_i2, v_k] * rxplaceholder_1[v_i0, v_i1, v_k, v_i3] + + +def matmul1_fp16_sch_func(): + sch = tvm.tir.Schedule(matmul1_fp16_before) + b0 = sch.get_block("matmul") + sch.pad_einsum(b0, [1, 1, 1, 1, 128]) + l1, l2, l3, l4, l5 = sch.get_loops(b0) + sch.split(l5, [None, 128]) + + b0 = sch.get_block(name="matmul", func_name="main") + b1 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + l2, l3, l4, l5, ko, l6 = sch.get_loops(block=b0) + v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l12, l13, l14, l15, l16 = sch.split(loop=l2, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) + v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[2, 1, 16, 1, 1]) + l22, l23, l24, l25, l26 = sch.split(loop=l3, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) + v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l32, l33, l34, l35, l36 = sch.split(loop=l4, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) + v37, v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[8, 1, 16, 1, 1]) + l42, l43, l44, l45, l46 = sch.split(loop=l5, factors=[v37, v38, v39, v40, v41], preserve_unit_iters=True) + v47, v48, v49 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[4, 16, 2]) + l50, l51, l52 = sch.split(loop=l6, factors=[v47, v48, v49], preserve_unit_iters=True) + sch.reorder(l12, l22, 
l32, l42, l13, l23, l33, l43, l14, l24, l34, l44, ko, l50, l51, l15, l25, l35, l45, l52, l16, l26, l36, l46) + l53 = sch.fuse(l12, l22, l32, l42, preserve_unit_iters=True) + sch.bind(loop=l53, thread_axis="blockIdx.x") + l54 = sch.fuse(l13, l23, l33, l43, preserve_unit_iters=True) + sch.bind(loop=l54, thread_axis="vthread.x") + l55 = sch.fuse(l14, l24, l34, l44, preserve_unit_iters=True) + sch.bind(loop=l55, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b56 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b56, loop=l55, preserve_unit_loops=True, index=-1) + b57 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b57, loop=l50, preserve_unit_loops=True, index=-1) + l58, l59, l60, _, l61, l62, l63, l64, l65 = sch.get_loops(block=b57) + l66 = sch.fuse(l62, l63, l64, l65, preserve_unit_iters=True) + v67 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) + sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v67) + b68 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b68, loop=l50, preserve_unit_loops=True, index=-1) + l69, l70, l71, _, l72, l73, l74, l75, l76 = sch.get_loops(block=b68) + l77 = sch.fuse(l73, l74, l75, l76, preserve_unit_iters=True) + v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch", ann_val=v78) + v79 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v79) + sch.enter_postproc() + sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") + l80, l81, l82, _, l83, l84 = sch.get_loops(block=b57) + l85, l86 = sch.split(loop=l84, factors=[None, 256], preserve_unit_iters=True) + sch.bind(loop=l86, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch") + l87, l88, l89, _, l90, l91 = sch.get_loops(block=b68) + l92, l93, l94 = sch.split(loop=l91, factors=[None, 256, 2], preserve_unit_iters=True) + sch.vectorize(loop=l94) + sch.bind(loop=l93, thread_axis="threadIdx.x") + b95 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b95, ann_key="meta_schedule.unroll_explicit") + _, _, b96, b97, b98, b99 = sch.get_child_blocks(b95) + l100, l101, l102, _, l103, l104, l105 = sch.get_loops(block=b96) + sch.annotate(block_or_loop=l100, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l100, ann_key="pragma_unroll_explicit", ann_val=1) + l106, l107, l108, _, l109, l110, l111, l112 = sch.get_loops(block=b97) + sch.annotate(block_or_loop=l106, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l106, ann_key="pragma_unroll_explicit", ann_val=1) + l113, l114, l115, _, l116, l117, l118, l119, l120, l121, l122, l123, l124, l125, l126 = sch.get_loops(block=b98) + sch.annotate(block_or_loop=l113, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l113, 
ann_key="pragma_unroll_explicit", ann_val=1) + l127, l128, l129, l130, l131, l132, l133 = sch.get_loops(block=b99) + sch.annotate(block_or_loop=l127, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l127, ann_key="pragma_unroll_explicit", ann_val=1) + b134 = sch.get_block(name="matmul", func_name="main") + l135, l136, l137, ko, l138, l139, l140, l141, l142, l143, l144, l145, l146, l147, l148 = sch.get_loops(block=b134) + b149 = sch.decompose_reduction(block=b134, loop=ko) + + b1 = sch.get_block("rxplaceholder_pad") + sch.compute_inline(b1) + b2 = sch.get_block("rxplaceholder_1_pad") + sch.compute_inline(b2) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func +def matmul8_fp16_before(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_matmul: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, n), "float16") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") + matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") + # with T.block("root"): + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(128), n): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k], rxplaceholder_1[v_i0, v_i1, v_k, v_i3]) + T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0) + matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + rxplaceholder[v_i0, v_i1, v_i2, v_k] * rxplaceholder_1[v_i0, v_i1, v_k, v_i3] + +@T.prim_func +def matmul8_with_m_fp16_before(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_matmul: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + m = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, m), "float16") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(1), T.int64(32), m, T.int64(128)), "float16") + matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") + # with T.block("root"): + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(128), m): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k], rxplaceholder_1[v_i0, v_i1, v_k, v_i3]) + T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0) + matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + rxplaceholder[v_i0, v_i1, v_i2, v_k] * rxplaceholder_1[v_i0, v_i1, v_k, v_i3] + +def matmul8_fp16_sch_func(func): + sch = tvm.tir.Schedule(func) + b0 = sch.get_block("matmul") + sch.pad_einsum(b0, [1, 1, 32, 1, 128]) + l1, l2, l3, l4, l5 = sch.get_loops(b0) + l6, l7 = sch.split(l3, [None, 32]) + l8, l9 = sch.split(l5, [None, 128]) + sch.reorder(l6, l1, l2, l7, l4, l8, l9) + + b0 = sch.get_block(name="matmul", func_name="main") + b1 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l2, l3, l4, l5, ko, l6 = sch.get_loops(block=b0) + v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l12, l13, l14, l15, l16 = sch.split(loop=l2, factors=[v7, v8, v9, v10, v11], 
preserve_unit_iters=True) + v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[32, 1, 1, 1, 1]) + l22, l23, l24, l25, l26 = sch.split(loop=l3, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) + v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 4, 2, 4]) + l32, l33, l34, l35, l36 = sch.split(loop=l4, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) + v37, v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[4, 1, 16, 2, 1]) + l42, l43, l44, l45, l46 = sch.split(loop=l5, factors=[v37, v38, v39, v40, v41], preserve_unit_iters=True) + v47, v48, v49 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[16, 1, 8]) + l50, l51, l52 = sch.split(loop=l6, factors=[v47, v48, v49], preserve_unit_iters=True) + sch.reorder(l12, l22, l32, l42, l13, l23, l33, l43, l14, l24, l34, l44, ko, l50, l51, l15, l25, l35, l45, l52, l16, l26, l36, l46) + l53 = sch.fuse(l12, l22, l32, l42, preserve_unit_iters=True) + sch.bind(loop=l53, thread_axis="blockIdx.x") + l54 = sch.fuse(l13, l23, l33, l43, preserve_unit_iters=True) + sch.bind(loop=l54, thread_axis="vthread.x") + l55 = sch.fuse(l14, l24, l34, l44, preserve_unit_iters=True) + sch.bind(loop=l55, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b56 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b56, loop=l55, preserve_unit_loops=True, index=-1) + b57 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b57, loop=l50, preserve_unit_loops=True, index=-1) + _, l58, l59, l60, _, l61, l62, l63, l64, l65 = sch.get_loops(block=b57) + l66 = sch.fuse(l62, l63, l64, l65, preserve_unit_iters=True) + v67 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v67) + b68 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b68, loop=l50, preserve_unit_loops=True, index=-1) + _, l69, l70, l71, _, l72, l73, l74, l75, l76 = sch.get_loops(block=b68) + l77 = sch.fuse(l73, l74, l75, l76, preserve_unit_iters=True) + v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) + sch.annotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch", ann_val=v78) + v79 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=3) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v79) + sch.enter_postproc() + sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") + _, l80, l81, l82, _, l83, l84 = sch.get_loops(block=b57) + l85, l86, l87 = sch.split(loop=l84, factors=[None, 64, 4], preserve_unit_iters=True) + sch.vectorize(loop=l87) + sch.bind(loop=l86, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch") + _, l88, l89, l90, _, l91, l92 = sch.get_loops(block=b68) + l93, l94 = sch.split(loop=l92, factors=[None, 64], preserve_unit_iters=True) + sch.bind(loop=l94, 
thread_axis="threadIdx.x") + b95 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b95, ann_key="meta_schedule.unroll_explicit") + _, _, b96, b97, b98, b99, _ = sch.get_child_blocks(b95) + _, l100, l101, l102, _, l103, l104, l105, l106 = sch.get_loops(block=b96) + sch.annotate(block_or_loop=l100, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l100, ann_key="pragma_unroll_explicit", ann_val=1) + _, l107, l108, l109, _, l110, l111, l112 = sch.get_loops(block=b97) + sch.annotate(block_or_loop=l107, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l107, ann_key="pragma_unroll_explicit", ann_val=1) + _, l113, l114, l115, _, l116, l117, l118, l119, l120, l121, l122, l123, l124, l125, l126 = sch.get_loops(block=b98) + sch.annotate(block_or_loop=l113, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l113, ann_key="pragma_unroll_explicit", ann_val=1) + _, l127, l128, l129, l130, l131, l132, l133 = sch.get_loops(block=b99) + sch.annotate(block_or_loop=l127, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l127, ann_key="pragma_unroll_explicit", ann_val=1) + b134 = sch.get_block(name="matmul", func_name="main") + l0, l135, l136, l137, ko, l138, l139, l140, l141, l142, l143, l144, l145, l146, l147, l148 = sch.get_loops(block=b134) + sch.bind(l0, "blockIdx.y") + b149 = sch.decompose_reduction(block=b134, loop=ko) + + b1 = sch.get_block("rxplaceholder_pad") + sch.compute_inline(b1) + b2 = sch.get_block("rxplaceholder_1_pad") + sch.compute_inline(b2) + b3 = sch.get_block("matmul_pad") + sch.reverse_compute_inline(b3) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func +def NT_matmul1_fp16_before(var_rxplaceholder: T.handle, rxplaceholder: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), var_NT_matmul: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + rxplaceholder_1 = T.match_buffer(var_rxplaceholder, (T.int64(1), n, T.int64(4096)), "float16") + NT_matmul = T.match_buffer(var_NT_matmul, (T.int64(1), n, T.int64(4096)), "float16") + # with T.block("root"): + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(rxplaceholder_1[v_i0, v_i1, v_k], rxplaceholder[v_i2, v_k]) + T.writes(NT_matmul[v_i0, v_i1, v_i2]) + with T.init(): + NT_matmul[v_i0, v_i1, v_i2] = T.float16(0) + NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + rxplaceholder_1[v_i0, v_i1, v_k] * rxplaceholder[v_i2, v_k] + + +def NT_matmul1_fp16_sch_func(): + sch = tvm.tir.Schedule(NT_matmul1_fp16_before) + b0 = sch.get_block("NT_matmul") + sch.pad_einsum(b0, [1, 32, 1, 1]) + l1, l2, l3, l4 = sch.get_loops(b0) + l5, l6 = sch.split(l2, [None, 32]) + sch.reorder(l5, l1, l6, l3, l4) + + b0 = sch.get_block(name="NT_matmul", func_name="main") + b1 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + _, l2, l3, l4, l5 = sch.get_loops(block=b0) + v6, v7, v8, v9, v10 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l11, l12, l13, l14, l15 = sch.split(loop=l2, factors=[v6, v7, v8, v9, v10], preserve_unit_iters=True) + v16, v17, v18, v19, v20 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 4, 2, 4]) + l21, l22, l23, l24, l25 = sch.split(loop=l3, 
factors=[v16, v17, v18, v19, v20], preserve_unit_iters=True) + v26, v27, v28, v29, v30 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[128, 1, 16, 1, 2]) + l31, l32, l33, l34, l35 = sch.split(loop=l4, factors=[v26, v27, v28, v29, v30], preserve_unit_iters=True) + v36, v37, v38 = sch.sample_perfect_tile(loop=l5, n=3, max_innermost_factor=64, decision=[512, 2, 4]) + l39, l40, l41 = sch.split(loop=l5, factors=[v36, v37, v38], preserve_unit_iters=True) + sch.reorder(l11, l21, l31, l12, l22, l32, l13, l23, l33, l39, l40, l14, l24, l34, l41, l15, l25, l35) + l42 = sch.fuse(l11, l21, l31, preserve_unit_iters=True) + sch.bind(loop=l42, thread_axis="blockIdx.x") + l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) + sch.bind(loop=l43, thread_axis="vthread.x") + l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) + sch.bind(loop=l44, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b45 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b45, loop=l44, preserve_unit_loops=True, index=-1) + b46 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b46, loop=l39, preserve_unit_loops=True, index=-1) + _, l47, l48, l49, l50, l51, l52, l53 = sch.get_loops(block=b46) + l54 = sch.fuse(l51, l52, l53, preserve_unit_iters=True) + v55 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b46, ann_key="meta_schedule.cooperative_fetch", ann_val=v55) + b56 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b56, loop=l39, preserve_unit_loops=True, index=-1) + _, l57, l58, l59, l60, l61, l62 = sch.get_loops(block=b56) + l63 = sch.fuse(l61, l62, preserve_unit_iters=True) + v64 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) + sch.annotate(block_or_loop=b56, ann_key="meta_schedule.cooperative_fetch", ann_val=v64) + v65 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=4) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v65) + sch.enter_postproc() + sch.unannotate(block_or_loop=b46, ann_key="meta_schedule.cooperative_fetch") + _, l66, l67, l68, l69, l70 = sch.get_loops(block=b46) + l71, l72, l73 = sch.split(loop=l70, factors=[None, 64, 4], preserve_unit_iters=True) + sch.vectorize(loop=l73) + sch.bind(loop=l72, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b56, ann_key="meta_schedule.cooperative_fetch") + _, l74, l75, l76, l77, l78 = sch.get_loops(block=b56) + l79, l80, l81 = sch.split(loop=l78, factors=[None, 64, 4], preserve_unit_iters=True) + sch.vectorize(loop=l81) + sch.bind(loop=l80, thread_axis="threadIdx.x") + b82 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b82, ann_key="meta_schedule.unroll_explicit") + _, b83, b84, b85, b86, _ = sch.get_child_blocks(b82) + _, l87, l88, l89, l90, l91, l92, l93 = sch.get_loops(block=b83) + sch.annotate(block_or_loop=l87, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l87, ann_key="pragma_unroll_explicit", ann_val=1) + _, l94, l95, l96, 
l97, l98, l99, l100 = sch.get_loops(block=b84) + sch.annotate(block_or_loop=l94, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l94, ann_key="pragma_unroll_explicit", ann_val=1) + _, l101, l102, l103, l104, l105, l106, l107, l108, l109, l110, l111, l112 = sch.get_loops(block=b85) + sch.annotate(block_or_loop=l101, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l101, ann_key="pragma_unroll_explicit", ann_val=1) + _, l113, l114, l115, l116, l117, l118 = sch.get_loops(block=b86) + sch.annotate(block_or_loop=l113, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l113, ann_key="pragma_unroll_explicit", ann_val=1) + b119 = sch.get_block(name="NT_matmul", func_name="main") + l0, l120, l121, l122, l123, l124, l125, l126, l127, l128, l129, l130, l131 = sch.get_loops(block=b119) + sch.bind(l0, "blockIdx.y") + b132 = sch.decompose_reduction(block=b119, loop=l123) + + b1 = sch.get_block("rxplaceholder_1_pad") + sch.compute_inline(b1) + b2 = sch.get_block("NT_matmul_pad") + sch.reverse_compute_inline(b2) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func +def decode6(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(4096), T.int64(4096))) + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(rxplaceholder[v_i // T.int64(8), v_j], rxplaceholder_1[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(rxplaceholder[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(rxplaceholder_1[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(rxplaceholder_1[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + + +@T.prim_func +def decode7(rxplaceholder: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(11008)), "uint32"), T_transpose: T.Buffer((T.int64(11008), T.int64(4096)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(4096), T.int64(11008))) + for i, j in T.grid(T.int64(4096), T.int64(11008)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(rxplaceholder[v_i // T.int64(8), v_j], rxplaceholder_1[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(rxplaceholder[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(rxplaceholder_1[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(rxplaceholder_1[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + for ax0, 
ax1 in T.grid(T.int64(11008), T.int64(4096)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + + +@T.prim_func +def decode8(rxplaceholder: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), rxplaceholder_1: T.Buffer((T.int64(344), T.int64(4096)), "uint32"), T_transpose: T.Buffer((T.int64(4096), T.int64(11008)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(11008), T.int64(4096))) + for i, j in T.grid(T.int64(11008), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(rxplaceholder[v_i // T.int64(8), v_j], rxplaceholder_1[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(rxplaceholder[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(rxplaceholder_1[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(rxplaceholder_1[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + for ax0, ax1 in T.grid(T.int64(4096), T.int64(11008)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + + +@T.prim_func +def decode4_fp16(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(4096)), "float16"), rxplaceholder_2: T.Buffer((T.int64(128), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(rxplaceholder[v_i // T.int64(8), v_j], rxplaceholder_1[v_i // T.int64(32), v_j], rxplaceholder_2[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = T.Cast("float16", T.bitwise_and(T.shift_right(rxplaceholder[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * rxplaceholder_1[v_i // T.int64(32), v_j] + rxplaceholder_2[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + +@T.prim_func +def decode5_fp16(rxplaceholder: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(11008)), "float16"), rxplaceholder_2: T.Buffer((T.int64(128), T.int64(11008)), "float16"), T_transpose: T.Buffer((T.int64(11008), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(11008)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(rxplaceholder[v_i // T.int64(8), v_j], rxplaceholder_1[v_i // T.int64(32), v_j], rxplaceholder_2[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] 
= T.Cast("float16", T.bitwise_and(T.shift_right(rxplaceholder[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * rxplaceholder_1[v_i // T.int64(32), v_j] + rxplaceholder_2[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(11008), T.int64(4096)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + +@T.prim_func +def decode6_fp16(rxplaceholder: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), rxplaceholder_1: T.Buffer((T.int64(344), T.int64(4096)), "float16"), rxplaceholder_2: T.Buffer((T.int64(344), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(11008)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(11008), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(rxplaceholder[v_i // T.int64(8), v_j], rxplaceholder_1[v_i // T.int64(32), v_j], rxplaceholder_2[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = T.Cast("float16", T.bitwise_and(T.shift_right(rxplaceholder[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * rxplaceholder_1[v_i // T.int64(32), v_j] + rxplaceholder_2[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(4096), T.int64(11008)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + + +@T.prim_func +def decode_int3_fp16(A: T.Buffer((T.int64(412), T.int64(4096)), "uint32"), B: T.Buffer((T.int64(103), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + decode_1 = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[v_i // T.int64(10), v_j], B[v_i // T.int64(40), v_j]) + T.writes(decode_1[v_i, v_j]) + decode_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[v_i // T.int64(10), v_j], T.Cast("uint32", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // T.int64(40), v_j] + for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode_1[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = decode_1[v_ax1, v_ax0] + +@T.prim_func +def decode1_int3_fp16(A: T.Buffer((T.int64(412), T.int64(11008)), "uint32"), B: T.Buffer((T.int64(103), T.int64(11008)), "float16"), T_transpose: T.Buffer((T.int64(11008), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(11008)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[v_i // T.int64(10), v_j], B[v_i // T.int64(40), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[v_i // T.int64(10), v_j], T.Cast("uint32", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // T.int64(40), 
v_j] + for ax0, ax1 in T.grid(T.int64(11008), T.int64(4096)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + +@T.prim_func +def decode2_int3_fp16(A: T.Buffer((T.int64(1104), T.int64(4096)), "uint32"), B: T.Buffer((T.int64(276), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(11008)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(11008), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[v_i // T.int64(10), v_j], B[v_i // T.int64(40), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[v_i // T.int64(10), v_j], T.Cast("uint32", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // T.int64(40), v_j] + for ax0, ax1 in T.grid(T.int64(4096), T.int64(11008)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + + +@T.prim_func +def decode_int3_int16_fp16(A: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), B: T.Buffer((T.int64(103), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + decode_1 = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[v_i // T.int64(5), v_j], B[v_i // T.int64(40), v_j]) + T.writes(decode_1[v_i, v_j]) + decode_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // T.int64(40), v_j] + for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode_1[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = decode_1[v_ax1, v_ax0] + +@T.prim_func +def decode1_int3_int16_fp16(A: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), B: T.Buffer((T.int64(103), T.int64(11008)), "float16"), T_transpose: T.Buffer((T.int64(11008), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(11008)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[v_i // T.int64(5), v_j], B[v_i // T.int64(40), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // T.int64(40), v_j] + for ax0, ax1 in T.grid(T.int64(11008), T.int64(4096)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + +@T.prim_func +def decode2_int3_int16_fp16(A: T.Buffer((T.int64(2208), T.int64(4096)), "uint16"), B: T.Buffer((T.int64(276), 
T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(11008)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(11008), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[v_i // T.int64(5), v_j], B[v_i // T.int64(40), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // T.int64(40), v_j] + for ax0, ax1 in T.grid(T.int64(4096), T.int64(11008)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + + +def decode_sch_func(orig_func): + sch = tvm.tir.Schedule(orig_func) + b0 = sch.get_block(name="decode", func_name="main") + l1, l2 = sch.get_loops(block=b0) + l3, l4 = sch.split(loop=l1, factors=[None, 8], preserve_unit_iters=True) + v5, v6, v7 = sch.sample_perfect_tile(loop=l3, n=3, max_innermost_factor=4, decision=[32, 8, 2]) + l8, l9, l10 = sch.split(loop=l3, factors=[v5, v6, v7], preserve_unit_iters=True) + v11, v12 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[256, 16]) + l13, l14 = sch.split(loop=l2, factors=[v11, v12], preserve_unit_iters=True) + sch.reorder(l8, l13, l9, l14, l10, l4) + sch.bind(loop=l8, thread_axis="blockIdx.y") + sch.bind(loop=l13, thread_axis="blockIdx.x") + sch.bind(loop=l9, thread_axis="threadIdx.y") + sch.bind(loop=l14, thread_axis="threadIdx.x") + sch.unroll(loop=l4) + b15 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="shared") + sch.compute_inline(block=b15) + b16 = sch.get_block(name="T_transpose", func_name="main") + sch.reverse_compute_at(block=b16, loop=l13, preserve_unit_loops=True, index=-1) + b17 = sch.get_block(name="T_transpose", func_name="main") + l18, l19, l20, l21 = sch.get_loops(block=b17) + l22 = sch.fuse(l20, l21, preserve_unit_iters=True) + l23, l24, l25 = sch.split(loop=l22, factors=[None, v12, 4], preserve_unit_iters=True) + sch.bind(loop=l24, thread_axis="threadIdx.x") + sch.vectorize(loop=l25) + sch.storage_align(block=b0, buffer_index=0, axis=0, factor=32, offset=1) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +@T.prim_func +def fused_decode3_matmul1_before(lv2931: T.Buffer((T.int64(512), T.int64(32000)), "uint32"), lv2932: T.Buffer((T.int64(128), T.int64(32000)), "uint32"), lv1511: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32000))) + for i, j in T.grid(T.int64(4096), T.int64(32000)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv2931[v_i // T.int64(8), v_j], lv2932[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv2931[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv2932[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", 
T.shift_left(T.bitwise_and(T.shift_right(lv2932[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1511[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1511[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + + +@T.prim_func +def fused_decode3_matmul1_after(lv1123: T.Buffer((T.int64(512), T.int64(32000)), "uint32"), lv1124: T.Buffer((T.int64(128), T.int64(32000)), "uint32"), lv1511: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + # with T.block("root"): + var_decode_intermediate_pad_local = T.alloc_buffer((T.int64(4096), T.int64(32000)), scope="local") + var_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), scope="local") + lv1511_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="shared") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(125), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax1_ax2_fused_2 in T.vectorized(T.int64(4)): + with T.block("lv1511_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2) + T.reads(lv1511[v0, v1, v2]) + T.writes(lv1511_shared[v0, v1, v2]) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv1511_shared[v0, v1, v2] = lv1511[v0, v1, v2] + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(64)): + for ax0_0 in range(T.int64(8)): + for ax0_1 in T.unroll(T.int64(8)): + for ax1 in range(T.int64(1)): + with T.block("var_decode_intermediate_pad"): + v0 = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) + v1 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1123[v0 // T.int64(8), v1], lv1124[v0 // T.int64(32), v1]) + T.writes(var_decode_intermediate_pad_local[v0, v1]) + var_decode_intermediate_pad_local[v0, v1] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1123[v0 // T.int64(8), v1], T.Cast("uint32", v0 % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1124[v0 // T.int64(32), v1], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1124[v0 // T.int64(32), v1], 
T.uint32(16)), T.uint32(65535)), T.uint32(16))) + for k_0_1_k_1_fused in range(T.int64(64)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused) + T.reads(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2], lv1511_shared[v_i0, v_i1, v_k], var_decode_intermediate_pad_local[v_k, v_i2]) + T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + lv1511_shared[v_i0, v_i1, v_k] * var_decode_intermediate_pad_local[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_pad_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(var_matmul_intermediate_pad_local[v0, v1, v2]) + T.writes(var_matmul_intermediate[v0, v1, v2]) + var_matmul_intermediate[v0, v1, v2] = var_matmul_intermediate_pad_local[v0, v1, v2] + + +@T.prim_func +def fused_decode4_fused_matmul5_add3_before(lv3184: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv3185: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), lv452: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), lv2710: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096))) + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096))) + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv3184[v_i // T.int64(8), v_j], lv3185[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv3184[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv3185[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv3185[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv452[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv452[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv2710[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv2710[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def 
fused_decode4_fused_matmul5_add3_after(lv1143: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv1144: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), lv3: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), lv2710: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(4096)), scope="local") + var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local") + lv3_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="shared") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax1_ax2_fused_2 in T.vectorized(T.int64(4)): + with T.block("lv3_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2) + T.reads(lv3[v0, v1, v2]) + T.writes(lv3_shared[v0, v1, v2]) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv3_shared[v0, v1, v2] = lv3[v0, v1, v2] + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(64)): + for ax0_0 in range(T.int64(8)): + for ax0_1 in T.unroll(T.int64(8)): + for ax1 in range(T.int64(1)): + with T.block("decode"): + v_j = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) + v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1143[v_j // T.int64(8), v_i], lv1144[v_j // T.int64(32), v_i]) + T.writes(var_decode_intermediate_local[v_j, v_i]) + var_decode_intermediate_local[v_j, v_i] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1143[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1144[v_j // T.int64(32), v_i], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1144[v_j // T.int64(32), v_i], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + for k_0_1_k_1_fused in range(T.int64(64)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv3_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, 
v_i2] + lv3_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(lv2710[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) + T.writes(p_output0_intermediate[v0, v1, v2]) + p_output0_intermediate[v0, v1, v2] = lv2710[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2] + + +@T.prim_func +def fused_decode4_matmul5_before(lv3166: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv3167: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), lv2712: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096))) + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv3166[v_i // T.int64(8), v_j], lv3167[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv3166[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv3167[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv3167[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv2712[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2712[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + + +@T.prim_func +def fused_decode4_matmul5_after(lv1128: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv1129: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), lv2712: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(4096)), scope="local") + var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local") + lv2712_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="shared") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax1_ax2_fused_2 in T.vectorized(T.int64(4)): + with T.block("lv2712_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = 
T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2) + T.reads(lv2712[v0, v1, v2]) + T.writes(lv2712_shared[v0, v1, v2]) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv2712_shared[v0, v1, v2] = lv2712[v0, v1, v2] + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(64)): + for ax0_0 in range(T.int64(8)): + for ax0_1 in T.unroll(T.int64(8)): + for ax1 in range(T.int64(1)): + with T.block("decode"): + v_j = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) + v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1128[v_j // T.int64(8), v_i], lv1129[v_j // T.int64(32), v_i]) + T.writes(var_decode_intermediate_local[v_j, v_i]) + var_decode_intermediate_local[v_j, v_i] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1128[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1129[v_j // T.int64(32), v_i], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1129[v_j // T.int64(32), v_i], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + for k_0_1_k_1_fused in range(T.int64(64)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2712_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2712_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(var_matmul_intermediate_local[v0, v1, v2]) + T.writes(var_matmul_intermediate[v0, v1, v2]) + var_matmul_intermediate[v0, v1, v2] = var_matmul_intermediate_local[v0, v1, v2] + + +@T.prim_func +def fused_decode5_fused_matmul8_multiply1_before(lv1617: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), lv1618: T.Buffer((T.int64(128), T.int64(11008)), "uint32"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008))) + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008))) + for i, j in T.grid(T.int64(4096), T.int64(11008)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1617[v_i // T.int64(8), v_j], lv1618[v_i // T.int64(32), v_j]) + 
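+ # The eight 4-bit codes of each uint32 in lv1617 are unpacked with a shift by
+ # (v_i % 8) * 4 and a mask of 0xF; lv1618 packs, per group of 32 reduction
+ # elements, the scale in its low 16 bits and the additive offset in its high
+ # 16 bits, each stored as the upper half of a float32 (bfloat16-style) and
+ # restored via shift_left(..., 16) plus a float32 reinterpret. The decoded
+ # weight is code * scale + offset.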
T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1617[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1618[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1618[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv2749[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2749[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv4[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv4[v_ax0, v_ax1, v_ax2] * var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def fused_decode5_fused_matmul8_multiply1_after(lv1153: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), lv1154: T.Buffer((T.int64(128), T.int64(11008)), "uint32"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), lv5: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float32")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(11008)), scope="local") + var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="local") + lv2749_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="shared") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(43), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax1_ax2_fused_2 in T.vectorized(T.int64(4)): + with T.block("lv2749_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2) + T.reads(lv2749[v0, v1, v2]) + T.writes(lv2749_shared[v0, v1, v2]) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv2749_shared[v0, v1, v2] = lv2749[v0, v1, v2] + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(64)): + for ax0_0 in 
range(T.int64(8)): + for ax0_1 in T.unroll(T.int64(8)): + for ax1 in range(T.int64(1)): + with T.block("decode"): + v_j = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) + v_i = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1153[v_j // T.int64(8), v_i], lv1154[v_j // T.int64(32), v_i]) + T.writes(var_decode_intermediate_local[v_j, v_i]) + var_decode_intermediate_local[v_j, v_i] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1153[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1154[v_j // T.int64(32), v_i], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1154[v_j // T.int64(32), v_i], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + for k_0_1_k_1_fused in range(T.int64(64)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2749_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2749_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(lv5[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) + T.writes(p_output0_intermediate[v0, v1, v2]) + p_output0_intermediate[v0, v1, v2] = lv5[v0, v1, v2] * var_matmul_intermediate_local[v0, v1, v2] + + +@T.prim_func +def fused_decode5_fused_matmul8_silu1_before(lv1611: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), lv1612: T.Buffer((T.int64(128), T.int64(11008)), "uint32"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008))) + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008))) + compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008))) + for i, j in T.grid(T.int64(4096), T.int64(11008)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1611[v_i // T.int64(8), v_j], lv1612[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1611[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1612[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1612[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, 
k]) + T.reads(lv2749[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2749[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) + T.writes(compute[v_i0, v_i1, v_i2]) + compute[v_i0, v_i1, v_i2] = T.sigmoid(var_matmul_intermediate[v_i0, v_i1, v_i2]) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def fused_decode5_fused_matmul8_silu1_after(lv1148: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), lv1149: T.Buffer((T.int64(128), T.int64(11008)), "uint32"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float32")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(11008)), scope="local") + var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="local") + lv2749_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="shared") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(43), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax1_ax2_fused_2 in T.vectorized(T.int64(4)): + with T.block("lv2749_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2) + T.reads(lv2749[v0, v1, v2]) + T.writes(lv2749_shared[v0, v1, v2]) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv2749_shared[v0, v1, v2] = lv2749[v0, v1, v2] + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(64)): + for ax0_0 in range(T.int64(8)): + for ax0_1 in T.unroll(T.int64(8)): + for ax1 in range(T.int64(1)): + with T.block("decode"): + v_j = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) + v_i = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1148[v_j // T.int64(8), v_i], lv1149[v_j // T.int64(32), v_i]) + T.writes(var_decode_intermediate_local[v_j, v_i]) 
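+ # Hand-scheduled version of fused_decode5_fused_matmul8_silu1_before: the
+ # activation row lv2749 is staged once per thread block in shared memory
+ # through a vectorized cooperative fetch, each of the 256 threads owns a single
+ # output column, and the weights are dequantized on the fly into registers
+ # ("local" scope) eight codes (one uint32 word) per unrolled inner loop, so the
+ # full float weight matrix is never materialized. The SiLU epilogue
+ # x * sigmoid(x) is fused into the write-back instead of running as the
+ # separate compute/T_multiply stages of the unscheduled function.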
+ var_decode_intermediate_local[v_j, v_i] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1148[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1149[v_j // T.int64(32), v_i], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1149[v_j // T.int64(32), v_i], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + for k_0_1_k_1_fused in range(T.int64(64)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2749_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2749_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(var_matmul_intermediate_local[v0, v1, v2]) + T.writes(p_output0_intermediate[v0, v1, v2]) + p_output0_intermediate[v0, v1, v2] = var_matmul_intermediate_local[v0, v1, v2] * T.sigmoid(var_matmul_intermediate_local[v0, v1, v2]) + + +@T.prim_func +def fused_decode6_fused_matmul9_add3_before(lv1623: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), lv1624: T.Buffer((T.int64(344), T.int64(4096)), "uint32"), lv230: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float32"), lv228: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(11008), T.int64(4096))) + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096))) + for i, j in T.grid(T.int64(11008), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1623[v_i // T.int64(8), v_j], lv1624[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1623[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1624[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1624[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(11008)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv230[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv230[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), 
T.int64(1), T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv228[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv228[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def fused_decode6_fused_matmul9_add3_after(lv1158: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), lv1159: T.Buffer((T.int64(344), T.int64(4096)), "uint32"), lv6: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float32"), lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + # with T.block("root"): + var_decode_intermediate_local = T.alloc_buffer((T.int64(11008), T.int64(4096)), scope="local") + var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local") + lv6_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="shared") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(2)): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(22)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("lv6_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = T.axis.spatial(T.int64(11008), k_0_0 * T.int64(5504) + (ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1)) + T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(5504)) + T.reads(lv6[v0, v1, v2]) + T.writes(lv6_shared[v0, v1, v2]) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv6_shared[v0, v1, v2] = lv6[v0, v1, v2] + for k_0_1 in range(T.int64(86)): + for ax0_0 in range(T.int64(8)): + for ax0_1 in T.unroll(T.int64(8)): + for ax1 in range(T.int64(1)): + with T.block("decode"): + v_j = T.axis.spatial(T.int64(11008), k_0_0 * T.int64(5504) + k_0_1 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) + v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1158[v_j // T.int64(8), v_i], lv1159[v_j // T.int64(32), v_i]) + T.writes(var_decode_intermediate_local[v_j, v_i]) + var_decode_intermediate_local[v_j, v_i] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1158[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1159[v_j // T.int64(32), v_i], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1159[v_j // T.int64(32), v_i], T.uint32(16)), T.uint32(65535)), T.uint32(16))) + for k_0_2_k_1_fused in range(T.int64(64)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + 
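+ # The 11008-long reduction axis is tiled 2 x 86 x 64: each outer k_0_0 step
+ # keeps one 5504-element slice of lv6 resident in shared memory (22 fetch
+ # iterations of 256 threads, bounded by the T.where guard), k_0_1 walks that
+ # slice in chunks of 64 dequantized weights, and v_k below reassembles the
+ # flat reduction index before the accumulation; the residual add with lv4 is
+ # fused into the final write-back.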
v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(11008), k_0_0 * T.int64(5504) + k_0_1 * T.int64(64) + k_0_2_k_1_fused) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv6_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv6_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(lv4[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) + T.writes(p_output0_intermediate[v0, v1, v2]) + p_output0_intermediate[v0, v1, v2] = lv4[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2] + + +@T.prim_func +def fused_decode3_matmul1_fp16_before(lv5865: T.Buffer((T.int64(512), T.int64(32000)), "uint32"), lv5866: T.Buffer((T.int64(128), T.int64(32000)), "float16"), lv5867: T.Buffer((T.int64(128), T.int64(32000)), "float16"), lv2705: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32000)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(32000)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv5865[v_i // T.int64(8), v_j], lv5866[v_i // T.int64(32), v_j], lv5867[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = T.Cast("float16", T.bitwise_and(T.shift_right(lv5865[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * lv5866[v_i // T.int64(32), v_j] + lv5867[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv2705[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2705[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + + +@T.prim_func +def fused_decode3_matmul1_fp16_after(lv1123: T.Buffer((T.int64(512), T.int64(32000)), "uint32"), lv5866: T.Buffer((T.int64(128), T.int64(32000)), "float16"), lv5867: T.Buffer((T.int64(128), T.int64(32000)), "float16"), lv1511: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float16")): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + # with T.block("root"): + var_decode_intermediate_pad_local = T.alloc_buffer((T.int64(4096), T.int64(32000)), scope="local", dtype="float16") + var_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), scope="local", dtype="float16") + lv1511_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="shared", dtype="float16") + for i0_i1_i2_0_fused in 
T.thread_binding(T.int64(125), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax1_ax2_fused_2 in T.vectorized(T.int64(4)): + with T.block("lv1511_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2) + T.reads(lv1511[v0, v1, v2]) + T.writes(lv1511_shared[v0, v1, v2]) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv1511_shared[v0, v1, v2] = lv1511[v0, v1, v2] + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(64)): + for ax0_0 in range(T.int64(8)): + for ax0_1 in T.unroll(T.int64(8)): + for ax1 in range(T.int64(1)): + with T.block("var_decode_intermediate_pad"): + v0 = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) + v1 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1123[v0 // T.int64(8), v1], lv5866[v0 // T.int64(32), v1], lv5867[v0 // T.int64(32), v1]) + T.writes(var_decode_intermediate_pad_local[v0, v1]) + var_decode_intermediate_pad_local[v0, v1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1123[v0 // T.int64(8), v1], T.Cast("uint32", v0 % T.int64(8) * T.int64(4))), T.uint32(15))) * lv5866[v0 // T.int64(32), v1] + lv5867[v0 // T.int64(32), v1] + for k_0_1_k_1_fused in range(T.int64(64)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused) + T.reads(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2], lv1511_shared[v_i0, v_i1, v_k], var_decode_intermediate_pad_local[v_k, v_i2]) + T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + lv1511_shared[v_i0, v_i1, v_k] * var_decode_intermediate_pad_local[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_pad_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(var_matmul_intermediate_pad_local[v0, v1, v2]) + T.writes(var_matmul_intermediate[v0, v1, v2]) + var_matmul_intermediate[v0, v1, v2] = var_matmul_intermediate_pad_local[v0, v1, v2] + + +@T.prim_func +def fused_decode3_matmul1_cast_fp16_before(lv1803: T.Buffer((T.int64(512), T.int64(32000)), "uint32"), lv1804: T.Buffer((T.int64(128), T.int64(32000)), "float16"), lv1805: T.Buffer((T.int64(128), T.int64(32000)), "float16"), lv3025: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), 
T.int64(32000)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32000)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(32000)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1803[v_i // T.int64(8), v_j], lv1804[v_i // T.int64(32), v_j], lv1805[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1803[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * lv1804[v_i // T.int64(32), v_j] + lv1805[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv3025[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv3025[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) + T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) + p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_matmul_intermediate[v_i0, v_i1, v_i2]) + + +@T.prim_func +def fused_decode3_matmul1_cast_fp16_after(lv1123: T.Buffer((T.int64(512), T.int64(32000)), "uint32"), lv5866: T.Buffer((T.int64(128), T.int64(32000)), "float16"), lv5867: T.Buffer((T.int64(128), T.int64(32000)), "float16"), lv1511: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + # with T.block("root"): + var_decode_intermediate_pad_local = T.alloc_buffer((T.int64(4096), T.int64(32000)), scope="local", dtype="float16") + var_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), scope="local", dtype="float16") + lv1511_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="shared", dtype="float16") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(125), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax1_ax2_fused_2 in T.vectorized(T.int64(4)): + with T.block("lv1511_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2) + T.reads(lv1511[v0, v1, v2]) + T.writes(lv1511_shared[v0, v1, v2]) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv1511_shared[v0, v1, v2] = lv1511[v0, v1, v2] + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), 
T.int64(0)) + v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(64)): + for ax0_0 in range(T.int64(8)): + for ax0_1 in T.unroll(T.int64(8)): + for ax1 in range(T.int64(1)): + with T.block("var_decode_intermediate_pad"): + v0 = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) + v1 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1123[v0 // T.int64(8), v1], lv5866[v0 // T.int64(32), v1], lv5867[v0 // T.int64(32), v1]) + T.writes(var_decode_intermediate_pad_local[v0, v1]) + var_decode_intermediate_pad_local[v0, v1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1123[v0 // T.int64(8), v1], T.Cast("uint32", v0 % T.int64(8) * T.int64(4))), T.uint32(15))) * lv5866[v0 // T.int64(32), v1] + lv5867[v0 // T.int64(32), v1] + for k_0_1_k_1_fused in range(T.int64(64)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused) + T.reads(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2], lv1511_shared[v_i0, v_i1, v_k], var_decode_intermediate_pad_local[v_k, v_i2]) + T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + lv1511_shared[v_i0, v_i1, v_k] * var_decode_intermediate_pad_local[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_pad_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(var_matmul_intermediate_pad_local[v0, v1, v2]) + T.writes(var_matmul_intermediate[v0, v1, v2]) + var_matmul_intermediate[v0, v1, v2] = T.Cast("float32", var_matmul_intermediate_pad_local[v0, v1, v2]) + + +@T.prim_func +def fused_decode4_fused_matmul5_add3_fp16_before(lv35: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv36: T.Buffer((T.int64(128), T.int64(4096)), "float16"), lv37: T.Buffer((T.int64(128), T.int64(4096)), "float16"), lv2: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv2710: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv35[v_i // T.int64(8), v_j], lv36[v_i // T.int64(32), v_j], lv37[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = T.Cast("float16", T.bitwise_and(T.shift_right(lv35[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * lv36[v_i // T.int64(32), v_j] + lv37[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): + with 
T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv2[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv2710[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv2710[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def fused_decode4_fused_matmul5_add3_fp16_after(lv1143: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv36: T.Buffer((T.int64(128), T.int64(4096)), "float16"), lv37: T.Buffer((T.int64(128), T.int64(4096)), "float16"), lv3: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv2710: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(4096)), scope="local", dtype="float16") + var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local", dtype="float16") + lv3_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="shared", dtype="float16") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax1_ax2_fused_2 in T.vectorized(T.int64(4)): + with T.block("lv3_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2) + T.reads(lv3[v0, v1, v2]) + T.writes(lv3_shared[v0, v1, v2]) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv3_shared[v0, v1, v2] = lv3[v0, v1, v2] + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(64)): + for ax0_0 in range(T.int64(8)): + for ax0_1 in T.unroll(T.int64(8)): + for ax1 in range(T.int64(1)): + with T.block("decode"): + v_j = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) + v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1143[v_j // T.int64(8), v_i], lv36[v_j // T.int64(32), v_i], lv37[v_j // T.int64(32), v_i]) + T.writes(var_decode_intermediate_local[v_j, v_i]) + var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", 
T.bitwise_and(T.shift_right(lv1143[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * lv36[v_j // T.int64(32), v_i] + lv37[v_j // T.int64(32), v_i] + for k_0_1_k_1_fused in range(T.int64(64)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv3_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv3_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(lv2710[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) + T.writes(p_output0_intermediate[v0, v1, v2]) + p_output0_intermediate[v0, v1, v2] = lv2710[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2] + + +@T.prim_func +def fused_decode4_matmul5_fp16_before(lv11: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv12: T.Buffer((T.int64(128), T.int64(4096)), "float16"), lv13: T.Buffer((T.int64(128), T.int64(4096)), "float16"), lv2712: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv11[v_i // T.int64(8), v_j], lv12[v_i // T.int64(32), v_j], lv13[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = T.Cast("float16", T.bitwise_and(T.shift_right(lv11[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * lv12[v_i // T.int64(32), v_j] + lv13[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv2712[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2712[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + + +@T.prim_func +def fused_decode4_matmul5_fp16_after(lv1128: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv12: T.Buffer((T.int64(128), T.int64(4096)), "float16"), lv13: T.Buffer((T.int64(128), T.int64(4096)), "float16"), lv2712: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(4096)), scope="local", dtype="float16") + 
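# An illustrative NumPy sketch (not one of the kernels in this file) of the 4-bit decode rule shared by
# the *_fp16 "before"/"after" pairs in this section (the *_int3_fp16 variants further below use a
# different packing): each uint32 of the packed weight holds eight 4-bit codes along the reduction axis,
# and every group of 32 rows shares one float16 scale and one float16 additive term, i.e.
#   w[k, n] = ((packed[k // 8, n] >> (k % 8) * 4) & 0xF) * scale[k // 32, n] + bias[k // 32, n].
# The scheduled ("after") kernels stage the activation row in shared memory, decode the weight tile
# into registers, and accumulate over a 64 x 64 split of the k axis. The names `packed`, `scale`,
# and `bias` below are placeholders for buffers shaped like lv1128 / lv12 / lv13.
import numpy as np

def dequant_q4f16_ref(packed: np.ndarray, scale: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """Unpack a (K // 8, N) uint32 buffer into (K, N) float16 weights."""
    K = packed.shape[0] * 8
    k = np.arange(K)
    # Select the 4-bit field for each row k, then apply the per-32-row scale and bias.
    codes = (packed[k // 8, :] >> ((k % 8) * 4)[:, None]) & 0xF
    return (codes.astype(np.float16) * scale[k // 32, :] + bias[k // 32, :]).astype(np.float16)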
var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local", dtype="float16") + lv2712_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="shared", dtype="float16") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax1_ax2_fused_2 in T.vectorized(T.int64(4)): + with T.block("lv2712_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2) + T.reads(lv2712[v0, v1, v2]) + T.writes(lv2712_shared[v0, v1, v2]) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv2712_shared[v0, v1, v2] = lv2712[v0, v1, v2] + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(64)): + for ax0_0 in range(T.int64(8)): + for ax0_1 in T.unroll(T.int64(8)): + for ax1 in range(T.int64(1)): + with T.block("decode"): + v_j = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) + v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1128[v_j // T.int64(8), v_i], lv12[v_j // T.int64(32), v_i], lv13[v_j // T.int64(32), v_i]) + T.writes(var_decode_intermediate_local[v_j, v_i]) + var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1128[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * lv12[v_j // T.int64(32), v_i] + lv13[v_j // T.int64(32), v_i] + for k_0_1_k_1_fused in range(T.int64(64)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2712_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2712_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(var_matmul_intermediate_local[v0, v1, v2]) + T.writes(var_matmul_intermediate[v0, v1, v2]) + var_matmul_intermediate[v0, v1, v2] = var_matmul_intermediate_local[v0, v1, v2] + + +@T.prim_func +def fused_decode5_fused_matmul8_multiply1_fp16_before(lv51: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), lv52: T.Buffer((T.int64(128), T.int64(11008)), 
"float16"), lv53: T.Buffer((T.int64(128), T.int64(11008)), "float16"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv5: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(11008)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv51[v_i // T.int64(8), v_j], lv52[v_i // T.int64(32), v_j], lv53[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = T.Cast("float16", T.bitwise_and(T.shift_right(lv51[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * lv52[v_i // T.int64(32), v_j] + lv53[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv2749[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2749[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv5[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv5[v_ax0, v_ax1, v_ax2] * var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def fused_decode5_fused_matmul8_multiply1_fp16_after(lv1153: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), lv52: T.Buffer((T.int64(128), T.int64(11008)), "float16"), lv53: T.Buffer((T.int64(128), T.int64(11008)), "float16"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv5: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(11008)), scope="local", dtype="float16") + var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="local", dtype="float16") + lv2749_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="shared", dtype="float16") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(43), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax1_ax2_fused_2 in T.vectorized(T.int64(4)): + with T.block("lv2749_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = 
T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2) + T.reads(lv2749[v0, v1, v2]) + T.writes(lv2749_shared[v0, v1, v2]) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv2749_shared[v0, v1, v2] = lv2749[v0, v1, v2] + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(64)): + for ax0_0 in range(T.int64(8)): + for ax0_1 in T.unroll(T.int64(8)): + for ax1 in range(T.int64(1)): + with T.block("decode"): + v_j = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) + v_i = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1153[v_j // T.int64(8), v_i], lv52[v_j // T.int64(32), v_i], lv53[v_j // T.int64(32), v_i]) + T.writes(var_decode_intermediate_local[v_j, v_i]) + var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1153[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * lv52[v_j // T.int64(32), v_i] + lv53[v_j // T.int64(32), v_i] + for k_0_1_k_1_fused in range(T.int64(64)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2749_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2749_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(lv5[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) + T.writes(p_output0_intermediate[v0, v1, v2]) + p_output0_intermediate[v0, v1, v2] = lv5[v0, v1, v2] * var_matmul_intermediate_local[v0, v1, v2] + + +@T.prim_func +def fused_decode5_fused_matmul8_silu1_fp16_before(lv43: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), lv44: T.Buffer((T.int64(128), T.int64(11008)), "float16"), lv45: T.Buffer((T.int64(128), T.int64(11008)), "float16"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16") + compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(11008)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv43[v_i // T.int64(8), v_j], lv44[v_i // T.int64(32), v_j], lv45[v_i // T.int64(32), v_j]) + 
T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = T.Cast("float16", T.bitwise_and(T.shift_right(lv43[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * lv44[v_i // T.int64(32), v_j] + lv45[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv2749[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2749[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) + T.writes(compute[v_i0, v_i1, v_i2]) + compute[v_i0, v_i1, v_i2] = T.sigmoid(var_matmul_intermediate[v_i0, v_i1, v_i2]) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def fused_decode5_fused_matmul8_silu1_fp16_after(lv1148: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), lv44: T.Buffer((T.int64(128), T.int64(11008)), "float16"), lv45: T.Buffer((T.int64(128), T.int64(11008)), "float16"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(11008)), scope="local", dtype="float16") + var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="local", dtype="float16") + lv2749_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="shared", dtype="float16") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(43), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax1_ax2_fused_2 in T.vectorized(T.int64(4)): + with T.block("lv2749_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2) + T.reads(lv2749[v0, v1, v2]) + T.writes(lv2749_shared[v0, v1, v2]) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv2749_shared[v0, v1, v2] = lv2749[v0, v1, v2] + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + 
T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(64)): + for ax0_0 in range(T.int64(8)): + for ax0_1 in T.unroll(T.int64(8)): + for ax1 in range(T.int64(1)): + with T.block("decode"): + v_j = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) + v_i = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1148[v_j // T.int64(8), v_i], lv44[v_j // T.int64(32), v_i], lv45[v_j // T.int64(32), v_i]) + T.writes(var_decode_intermediate_local[v_j, v_i]) + var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1148[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * lv44[v_j // T.int64(32), v_i] + lv45[v_j // T.int64(32), v_i] + for k_0_1_k_1_fused in range(T.int64(64)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2749_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2749_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(var_matmul_intermediate_local[v0, v1, v2]) + T.writes(p_output0_intermediate[v0, v1, v2]) + p_output0_intermediate[v0, v1, v2] = var_matmul_intermediate_local[v0, v1, v2] * T.sigmoid(var_matmul_intermediate_local[v0, v1, v2]) + + +@T.prim_func +def fused_decode6_fused_matmul9_add3_fp16_before(lv59: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), lv60: T.Buffer((T.int64(344), T.int64(4096)), "float16"), lv61: T.Buffer((T.int64(344), T.int64(4096)), "float16"), lv5: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv3: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(11008), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv59[v_i // T.int64(8), v_j], lv60[v_i // T.int64(32), v_j], lv61[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = T.Cast("float16", T.bitwise_and(T.shift_right(lv59[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * lv60[v_i // T.int64(32), v_j] + lv61[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(11008)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv5[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + 
T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv5[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv3[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv3[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def fused_decode6_fused_matmul9_add3_fp16_after(lv1158: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), lv60: T.Buffer((T.int64(344), T.int64(4096)), "float16"), lv61: T.Buffer((T.int64(344), T.int64(4096)), "float16"), lv6: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + # with T.block("root"): + var_decode_intermediate_local = T.alloc_buffer((T.int64(11008), T.int64(4096)), scope="local", dtype="float16") + var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local", dtype="float16") + lv6_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="shared", dtype="float16") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(2)): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(22)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("lv6_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = T.axis.spatial(T.int64(11008), k_0_0 * T.int64(5504) + (ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1)) + T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(5504)) + T.reads(lv6[v0, v1, v2]) + T.writes(lv6_shared[v0, v1, v2]) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv6_shared[v0, v1, v2] = lv6[v0, v1, v2] + for k_0_1 in range(T.int64(86)): + for ax0_0 in range(T.int64(8)): + for ax0_1 in T.unroll(T.int64(8)): + for ax1 in range(T.int64(1)): + with T.block("decode"): + v_j = T.axis.spatial(T.int64(11008), k_0_0 * T.int64(5504) + k_0_1 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) + v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1158[v_j // T.int64(8), v_i], lv60[v_j // T.int64(32), v_i], lv61[v_j // T.int64(32), v_i]) + T.writes(var_decode_intermediate_local[v_j, v_i]) + var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1158[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % 
T.int64(8) * T.int64(4))), T.uint32(15))) * lv60[v_j // T.int64(32), v_i] + lv61[v_j // T.int64(32), v_i] + for k_0_2_k_1_fused in range(T.int64(64)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(11008), k_0_0 * T.int64(5504) + k_0_1 * T.int64(64) + k_0_2_k_1_fused) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv6_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv6_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(lv4[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) + T.writes(p_output0_intermediate[v0, v1, v2]) + p_output0_intermediate[v0, v1, v2] = lv4[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2] + + +@T.prim_func +def fused_decode3_matmul1_cast_int3_fp16_before(lv2931: T.Buffer((T.int64(412), T.int64(32000)), "uint32"), lv2932: T.Buffer((T.int64(103), T.int64(32000)), "float16"), lv3025: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32000)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(32000)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv2931[v_i // T.int64(10), v_j], lv2932[v_i // T.int64(40), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv2931[v_i // T.int64(10), v_j], T.Cast("uint32", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv2932[v_i // T.int64(40), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv3025[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv3025[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) + T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) + p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_matmul_intermediate[v_i0, v_i1, v_i2]) + + +@T.prim_func +def fused_decode3_matmul1_cast_int3_fp16_after(lv1123: T.Buffer((T.int64(412), T.int64(32000)), "uint32"), lv5866: T.Buffer((T.int64(103), T.int64(32000)), "float16"), lv1511: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: 
T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + # with T.block("root"): + var_decode_intermediate_pad_local = T.alloc_buffer((T.int64(4120), T.int64(32000)), scope="local", dtype="float16") + var_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), scope="local", dtype="float16") + lv1511_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope="shared", dtype="float16") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(125), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("lv1511_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1) + T.reads(lv1511[v0, v1, v2]) + T.writes(lv1511_shared[v0, v1, v2]) + T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120)) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv1511_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv1511[v0, v1, v2], T.float16(0)) + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(103)): + for ax0_0 in T.unroll(T.int64(40)): + for ax1 in range(T.int64(1)): + with T.block("var_decode_intermediate_pad"): + v0 = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0) + v1 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1123[v0 // T.int64(10), v1], lv5866[v0 // T.int64(40), v1]) + T.writes(var_decode_intermediate_pad_local[v0, v1]) + var_decode_intermediate_pad_local[v0, v1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1123[v0 // T.int64(10), v1], T.Cast("uint32", v0 % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3) + for k_0_1_k_1_fused in range(T.int64(40)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused) + T.reads(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2], lv1511_shared[v_i0, v_i1, v_k], var_decode_intermediate_pad_local[v_k, v_i2]) + T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + lv1511_shared[v_i0, v_i1, v_k] * var_decode_intermediate_pad_local[v_k, v_i2] * lv5866[v_k // T.int64(40), v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_pad_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(var_matmul_intermediate_pad_local[v0, v1, v2]) + 
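# An illustrative NumPy sketch (not one of the kernels in this file) of the int3 decode rule used by
# the *_int3_fp16 pairs in this section: each uint32 packs ten 3-bit codes along the reduction axis,
# every group of 40 rows shares one float16 scale, and the zero point is fixed at 3, i.e.
#   w[k, n] = (((packed[k // 10, n] >> (k % 10) * 3) & 0x7) - 3) * scale[k // 40, n].
# The scheduled ("after") kernels pad k from 4096 to 4120 (= 103 * 40), zero-fill the activation tail
# in shared memory, and store only the shifted code minus 3 in the local decode buffer, folding the
# multiply by scale[k // 40, n] into matmul_update, so the padded rows do not change the result.
# The names `packed` and `scale` below are placeholders for buffers shaped like lv1123 / lv5866.
import numpy as np

def dequant_q3f16_ref(packed: np.ndarray, scale: np.ndarray, k_logical: int = 4096) -> np.ndarray:
    """Unpack a (ceil(K_pad / 10), N) uint32 buffer into (k_logical, N) float16 weights."""
    k = np.arange(k_logical)
    # Select the 3-bit field for each row k, subtract the fixed zero point, apply the per-40-row scale.
    codes = (packed[k // 10, :] >> ((k % 10) * 3)[:, None]) & 0x7
    return ((codes.astype(np.float16) - np.float16(3)) * scale[k // 40, :]).astype(np.float16)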
T.writes(var_matmul_intermediate[v0, v1, v2]) + var_matmul_intermediate[v0, v1, v2] = T.Cast("float32", var_matmul_intermediate_pad_local[v0, v1, v2]) + + +@T.prim_func +def fused_decode4_fused_matmul5_add3_int3_fp16_before(lv1605: T.Buffer((T.int64(412), T.int64(4096)), "uint32"), lv1606: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv164: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv1518: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1605[v_i // T.int64(10), v_j], lv1606[v_i // T.int64(40), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv1605[v_i // T.int64(10), v_j], T.Cast("uint32", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1606[v_i // T.int64(40), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv164[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv164[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv1518[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv1518[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def fused_decode4_fused_matmul5_add3_int3_fp16_after(lv1143: T.Buffer((T.int64(412), T.int64(4096)), "uint32"), lv36: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv3: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv2710: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate_local = T.alloc_buffer((T.int64(4120), T.int64(4096)), scope="local", dtype="float16") + var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local", dtype="float16") + lv3_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope="shared", dtype="float16") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("lv3_shared"): + v0 = 
T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1) + T.reads(lv3[v0, v1, v2]) + T.writes(lv3_shared[v0, v1, v2]) + T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120)) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv3_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv3[v0, v1, v2], T.float16(0)) + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(103)): + for ax0_0 in T.unroll(T.int64(40)): + for ax1 in range(T.int64(1)): + with T.block("decode"): + v_j = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0) + v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1143[v_j // T.int64(10), v_i], lv36[v_j // T.int64(40), v_i]) + T.writes(var_decode_intermediate_local[v_j, v_i]) + var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1143[v_j // T.int64(10), v_i], T.Cast("uint32", v_j % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3) + for k_0_1_k_1_fused in range(T.int64(40)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv3_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv3_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * lv36[v_k // T.int64(40), v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(lv2710[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) + T.writes(p_output0_intermediate[v0, v1, v2]) + p_output0_intermediate[v0, v1, v2] = lv2710[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2] + + +@T.prim_func +def fused_decode4_matmul5_int3_fp16_before(lv1587: T.Buffer((T.int64(412), T.int64(4096)), "uint32"), lv1588: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv1520: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1587[v_i // T.int64(10), v_j], lv1588[v_i // T.int64(40), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv1587[v_i // T.int64(10), v_j], T.Cast("uint32", v_i % T.int64(10)) * T.uint32(3)), 
T.uint32(7))) - T.float16(3)) * lv1588[v_i // T.int64(40), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1520[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1520[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + + +@T.prim_func +def fused_decode4_matmul5_int3_fp16_after(lv1128: T.Buffer((T.int64(412), T.int64(4096)), "uint32"), lv12: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv2712: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate_local = T.alloc_buffer((T.int64(4120), T.int64(4096)), scope="local", dtype="float16") + var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local", dtype="float16") + lv2712_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope="shared", dtype="float16") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("lv2712_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1) + T.reads(lv2712[v0, v1, v2]) + T.writes(lv2712_shared[v0, v1, v2]) + T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120)) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv2712_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv2712[v0, v1, v2], T.float16(0)) + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(103)): + for ax0_0 in T.unroll(T.int64(40)): + for ax1 in range(T.int64(1)): + with T.block("decode"): + v_j = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0) + v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1128[v_j // T.int64(10), v_i], lv12[v_j // T.int64(40), v_i]) + T.writes(var_decode_intermediate_local[v_j, v_i]) + var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1128[v_j // T.int64(10), v_i], T.Cast("uint32", v_j % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3) + for k_0_1_k_1_fused in range(T.int64(40)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k 
= T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2712_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2712_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * lv12[v_k // T.int64(40), v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(var_matmul_intermediate_local[v0, v1, v2]) + T.writes(var_matmul_intermediate[v0, v1, v2]) + var_matmul_intermediate[v0, v1, v2] = var_matmul_intermediate_local[v0, v1, v2] + + +@T.prim_func +def fused_decode5_fused_matmul8_multiply1_int3_fp16_before(lv1617: T.Buffer((T.int64(412), T.int64(11008)), "uint32"), lv1618: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv1557: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv3: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(11008)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1617[v_i // T.int64(10), v_j], lv1618[v_i // T.int64(40), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv1617[v_i // T.int64(10), v_j], T.Cast("uint32", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1618[v_i // T.int64(40), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1557[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1557[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv3[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv3[v_ax0, v_ax1, v_ax2] * var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def fused_decode5_fused_matmul8_multiply1_int3_fp16_after(lv1153: T.Buffer((T.int64(412), T.int64(11008)), "uint32"), lv52: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv5: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate_local = 
T.alloc_buffer((T.int64(4120), T.int64(11008)), scope="local", dtype="float16") + var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="local", dtype="float16") + lv2749_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope="shared", dtype="float16") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(43), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("lv2749_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1) + T.reads(lv2749[v0, v1, v2]) + T.writes(lv2749_shared[v0, v1, v2]) + T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120)) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv2749_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv2749[v0, v1, v2], T.float16(0)) + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(103)): + for ax0_0 in T.unroll(T.int64(40)): + for ax1 in range(T.int64(1)): + with T.block("decode"): + v_j = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0) + v_i = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1153[v_j // T.int64(10), v_i], lv52[v_j // T.int64(40), v_i]) + T.writes(var_decode_intermediate_local[v_j, v_i]) + var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1153[v_j // T.int64(10), v_i], T.Cast("uint32", v_j % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3) + for k_0_1_k_1_fused in range(T.int64(40)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2749_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2749_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * lv52[v_k // T.int64(40), v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(lv5[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) + T.writes(p_output0_intermediate[v0, v1, v2]) + p_output0_intermediate[v0, v1, v2] = lv5[v0, v1, v2] * var_matmul_intermediate_local[v0, v1, v2] + + +@T.prim_func +def fused_decode5_fused_matmul8_silu1_int3_fp16_before(lv1611: T.Buffer((T.int64(412), 
T.int64(11008)), "uint32"), lv1612: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv1557: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16") + compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(11008)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1611[v_i // T.int64(10), v_j], lv1612[v_i // T.int64(40), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv1611[v_i // T.int64(10), v_j], T.Cast("uint32", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1612[v_i // T.int64(40), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1557[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1557[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) + T.writes(compute[v_i0, v_i1, v_i2]) + compute[v_i0, v_i1, v_i2] = T.sigmoid(var_matmul_intermediate[v_i0, v_i1, v_i2]) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def fused_decode5_fused_matmul8_silu1_int3_fp16_after(lv1148: T.Buffer((T.int64(412), T.int64(11008)), "uint32"), lv44: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate_local = T.alloc_buffer((T.int64(4120), T.int64(11008)), scope="local", dtype="float16") + var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="local", dtype="float16") + lv2749_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope="shared", dtype="float16") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(43), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): 
+ with T.block("lv2749_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1) + T.reads(lv2749[v0, v1, v2]) + T.writes(lv2749_shared[v0, v1, v2]) + T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120)) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv2749_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv2749[v0, v1, v2], T.float16(0)) + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(103)): + for ax0_0 in T.unroll(T.int64(40)): + for ax1 in range(T.int64(1)): + with T.block("decode"): + v_j = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0) + v_i = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1148[v_j // T.int64(10), v_i], lv44[v_j // T.int64(40), v_i]) + T.writes(var_decode_intermediate_local[v_j, v_i]) + var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1148[v_j // T.int64(10), v_i], T.Cast("uint32", v_j % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3) + for k_0_1_k_1_fused in range(T.int64(40)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2749_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2749_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * lv44[v_k // T.int64(40), v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(var_matmul_intermediate_local[v0, v1, v2]) + T.writes(p_output0_intermediate[v0, v1, v2]) + p_output0_intermediate[v0, v1, v2] = var_matmul_intermediate_local[v0, v1, v2] * T.sigmoid(var_matmul_intermediate_local[v0, v1, v2]) + + +@T.prim_func +def fused_decode6_fused_matmul9_add3_int3_fp16_before(lv1623: T.Buffer((T.int64(1104), T.int64(4096)), "uint32"), lv1624: T.Buffer((T.int64(276), T.int64(4096)), "float16"), lv167: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv165: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(11008), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1623[v_i // T.int64(10), v_j], 
lv1624[v_i // T.int64(40), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv1623[v_i // T.int64(10), v_j], T.Cast("uint32", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1624[v_i // T.int64(40), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(11008)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv167[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv167[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv165[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv165[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def fused_decode6_fused_matmul9_add3_int3_fp16_after(lv1158: T.Buffer((T.int64(1104), T.int64(4096)), "uint32"), lv60: T.Buffer((T.int64(276), T.int64(4096)), "float16"), lv6: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + # with T.block("root"): + var_decode_intermediate_local = T.alloc_buffer((T.int64(11040), T.int64(4096)), scope="local", dtype="float16") + var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local", dtype="float16") + lv6_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11040)), scope="shared", dtype="float16") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(2)): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(22)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("lv6_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = T.axis.spatial(T.int64(11040), k_0_0 * T.int64(5520) + (ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1)) + T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(5520)) + T.reads(lv6[v0, v1, v2]) + T.writes(lv6_shared[v0, v1, v2]) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv6_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(11008), lv6[v0, v1, v2], T.float16(0)) + for k_0_1 in range(T.int64(69)): + for ax0_0 in T.unroll(T.int64(80)): + for ax1 in range(T.int64(1)): + 
with T.block("decode"): + v_j = T.axis.spatial(T.int64(11040), k_0_0 * T.int64(5520) + k_0_1 * T.int64(80) + ax0_0) + v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1158[v_j // T.int64(10), v_i], lv60[v_j // T.int64(40), v_i]) + T.writes(var_decode_intermediate_local[v_j, v_i]) + var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1158[v_j // T.int64(10), v_i], T.Cast("uint32", v_j % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3) + for k_0_2_k_1_fused in range(T.int64(80)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(11040), k_0_0 * T.int64(5520) + k_0_1 * T.int64(80) + k_0_2_k_1_fused) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv6_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv6_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * lv60[v_k // T.int64(40), v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(lv4[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) + T.writes(p_output0_intermediate[v0, v1, v2]) + p_output0_intermediate[v0, v1, v2] = lv4[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2] + + +@T.prim_func +def fused_decode3_matmul1_cast_int3_int16_fp16_before(lv2931: T.Buffer((T.int64(824), T.int64(32000)), "uint16"), lv2932: T.Buffer((T.int64(103), T.int64(32000)), "float16"), lv3025: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32000)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(32000)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv2931[v_i // T.int64(5), v_j], lv2932[v_i // T.int64(40), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv2931[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv2932[v_i // T.int64(40), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv3025[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv3025[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + 
T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) + T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) + p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_matmul_intermediate[v_i0, v_i1, v_i2]) + + +@T.prim_func +def fused_decode3_matmul1_cast_int3_int16_fp16_after(lv1123: T.Buffer((T.int64(824), T.int64(32000)), "uint16"), lv5866: T.Buffer((T.int64(103), T.int64(32000)), "float16"), lv1511: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + # with T.block("root"): + var_decode_intermediate_pad_local = T.alloc_buffer((T.int64(4120), T.int64(32000)), scope="local", dtype="float16") + var_scale_intermediate_local = T.alloc_buffer((T.int64(103), T.int64(4096)), scope="local", dtype="float16") + var_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), scope="local", dtype="float16") + lv1511_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope="shared", dtype="float16") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(125), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("lv1511_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1) + T.reads(lv1511[v0, v1, v2]) + T.writes(lv1511_shared[v0, v1, v2]) + T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120)) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv1511_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv1511[v0, v1, v2], T.float16(0)) + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(103)): + for ax0_0 in T.unroll(T.int64(40)): + for ax1 in range(T.int64(1)): + with T.block("var_decode_intermediate_pad"): + v0 = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0) + v1 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1123[v0 // T.int64(5), v1]) + T.writes(var_decode_intermediate_pad_local[v0, v1]) + var_decode_intermediate_pad_local[v0, v1] = T.Cast("float16", T.Cast("int16", T.bitwise_and(T.shift_right(T.Cast("uint16", lv1123[v0 // T.int64(5), v1]), T.Cast("uint16", v0 % T.int64(5)) * T.uint16(3)), T.uint16(7))) - T.int16(3)) + for ax0_0 in range(T.int64(1)): + for ax1 in range(T.int64(1)): + with T.block("scale"): + v_j = T.axis.spatial(T.int64(103), k_0_0 + ax0_0) + v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv5866[v_j, v_i]) + T.writes(var_scale_intermediate_local[v_j, v_i]) + var_scale_intermediate_local[v_j, v_i] = lv5866[v_j, v_i] + for k_0_1_k_1_fused in range(T.int64(40)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + 
v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused) + T.reads(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2], lv1511_shared[v_i0, v_i1, v_k], var_decode_intermediate_pad_local[v_k, v_i2], var_scale_intermediate_local[v_k // T.int64(40), v_i2]) + T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + lv1511_shared[v_i0, v_i1, v_k] * var_decode_intermediate_pad_local[v_k, v_i2] * var_scale_intermediate_local[v_k // T.int64(40), v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_pad_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(var_matmul_intermediate_pad_local[v0, v1, v2]) + T.writes(var_matmul_intermediate[v0, v1, v2]) + var_matmul_intermediate[v0, v1, v2] = T.Cast("float32", var_matmul_intermediate_pad_local[v0, v1, v2]) + + +@T.prim_func +def fused_decode4_fused_matmul5_add3_int3_int16_fp16_before(lv1605: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), lv1606: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv164: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv1518: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1605[v_i // T.int64(5), v_j], lv1606[v_i // T.int64(40), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1605[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1606[v_i // T.int64(40), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv164[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv164[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv1518[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv1518[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def fused_decode4_fused_matmul5_add3_int3_int16_fp16_after(lv1143: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), lv36: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv3: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), 
"float16"), lv2710: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate_local = T.alloc_buffer((T.int64(4120), T.int64(4096)), scope="local", dtype="float16") + var_scale_intermediate_local = T.alloc_buffer((T.int64(103), T.int64(4096)), scope="local", dtype="float16") + var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local", dtype="float16") + lv3_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope="shared", dtype="float16") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("lv3_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1) + T.reads(lv3[v0, v1, v2]) + T.writes(lv3_shared[v0, v1, v2]) + T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120)) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv3_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv3[v0, v1, v2], T.float16(0)) + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(103)): + for ax0_0 in T.unroll(T.int64(40)): + for ax1 in range(T.int64(1)): + with T.block("decode"): + v_j = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0) + v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1143[v_j // T.int64(5), v_i]) + T.writes(var_decode_intermediate_local[v_j, v_i]) + var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.Cast("int16", T.bitwise_and(T.shift_right(T.Cast("uint16", lv1143[v_j // T.int64(5), v_i]), T.Cast("uint16", v_j % T.int64(5)) * T.uint16(3)), T.uint16(7))) - T.int16(3)) + for ax0_0 in range(T.int64(1)): + for ax1 in range(T.int64(1)): + with T.block("scale"): + v_j = T.axis.spatial(T.int64(103), k_0_0 + ax0_0) + v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv36[v_j, v_i]) + T.writes(var_scale_intermediate_local[v_j, v_i]) + var_scale_intermediate_local[v_j, v_i] = lv36[v_j, v_i] + for k_0_1_k_1_fused in range(T.int64(40)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv3_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2], var_scale_intermediate_local[v_k // T.int64(40), v_i2]) + T.writes(var_matmul_intermediate_local[v_i0, 
v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv3_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * var_scale_intermediate_local[v_k // T.int64(40), v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(lv2710[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) + T.writes(p_output0_intermediate[v0, v1, v2]) + p_output0_intermediate[v0, v1, v2] = lv2710[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2] + + +@T.prim_func +def fused_decode4_matmul5_int3_int16_fp16_before(lv1587: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), lv1588: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv1520: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1587[v_i // T.int64(5), v_j], lv1588[v_i // T.int64(40), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1587[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1588[v_i // T.int64(40), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1520[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1520[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + + +@T.prim_func +def fused_decode4_matmul5_int3_int16_fp16_after(lv1128: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), lv12: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv2712: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate_local = T.alloc_buffer((T.int64(4120), T.int64(4096)), scope="local", dtype="float16") + var_scale_intermediate_local = T.alloc_buffer((T.int64(103), T.int64(4096)), scope="local", dtype="float16") + var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local", dtype="float16") + lv2712_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope="shared", dtype="float16") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with 
T.block("lv2712_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1) + T.reads(lv2712[v0, v1, v2]) + T.writes(lv2712_shared[v0, v1, v2]) + T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120)) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv2712_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv2712[v0, v1, v2], T.float16(0)) + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(103)): + for ax0_0 in T.unroll(T.int64(40)): + for ax1 in range(T.int64(1)): + with T.block("decode"): + v_j = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0) + v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1128[v_j // T.int64(5), v_i]) + T.writes(var_decode_intermediate_local[v_j, v_i]) + var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.Cast("int16", T.bitwise_and(T.shift_right(T.Cast("uint16", lv1128[v_j // T.int64(5), v_i]), T.Cast("uint16", v_j % T.int64(5)) * T.uint16(3)), T.uint16(7))) - T.int16(3)) + for ax0_0 in range(T.int64(1)): + for ax1 in range(T.int64(1)): + with T.block("scale"): + v_j = T.axis.spatial(T.int64(103), k_0_0 + ax0_0) + v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv12[v_j, v_i]) + T.writes(var_scale_intermediate_local[v_j, v_i]) + var_scale_intermediate_local[v_j, v_i] = lv12[v_j, v_i] + for k_0_1_k_1_fused in range(T.int64(40)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2712_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2], var_scale_intermediate_local[v_k // T.int64(40), v_i2]) + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2712_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * var_scale_intermediate_local[v_k // T.int64(40), v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(var_matmul_intermediate_local[v0, v1, v2]) + T.writes(var_matmul_intermediate[v0, v1, v2]) + var_matmul_intermediate[v0, v1, v2] = var_matmul_intermediate_local[v0, v1, v2] + + +@T.prim_func +def fused_decode5_fused_matmul8_multiply1_int3_int16_fp16_before(lv1617: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), lv1618: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv1557: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv3: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): + T.func_attr({"tir.noalias": 
T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(11008)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1617[v_i // T.int64(5), v_j], lv1618[v_i // T.int64(40), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1617[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1618[v_i // T.int64(40), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1557[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1557[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv3[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv3[v_ax0, v_ax1, v_ax2] * var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def fused_decode5_fused_matmul8_multiply1_int3_int16_fp16_after(lv1153: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), lv52: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv5: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate_local = T.alloc_buffer((T.int64(4120), T.int64(11008)), scope="local", dtype="float16") + var_scale_intermediate_local = T.alloc_buffer((T.int64(103), T.int64(4096)), scope="local", dtype="float16") + var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="local", dtype="float16") + lv2749_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope="shared", dtype="float16") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(43), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("lv2749_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1) + T.reads(lv2749[v0, v1, v2]) + T.writes(lv2749_shared[v0, v1, v2]) + T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120)) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv2749_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv2749[v0, v1, 
v2], T.float16(0)) + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(103)): + for ax0_0 in T.unroll(T.int64(40)): + for ax1 in range(T.int64(1)): + with T.block("decode"): + v_j = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0) + v_i = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1153[v_j // T.int64(5), v_i]) + T.writes(var_decode_intermediate_local[v_j, v_i]) + var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.Cast("int16", T.bitwise_and(T.shift_right(T.Cast("uint16", lv1153[v_j // T.int64(5), v_i]), T.Cast("uint16", v_j % T.int64(5)) * T.uint16(3)), T.uint16(7))) - T.int16(3)) + for ax0_0 in range(T.int64(1)): + for ax1 in range(T.int64(1)): + with T.block("scale"): + v_j = T.axis.spatial(T.int64(103), k_0_0 + ax0_0) + v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv52[v_j, v_i]) + T.writes(var_scale_intermediate_local[v_j, v_i]) + var_scale_intermediate_local[v_j, v_i] = lv52[v_j, v_i] + for k_0_1_k_1_fused in range(T.int64(40)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2749_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2], var_scale_intermediate_local[v_k // T.int64(40), v_i2]) + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2749_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * var_scale_intermediate_local[v_k // T.int64(40), v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(lv5[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) + T.writes(p_output0_intermediate[v0, v1, v2]) + p_output0_intermediate[v0, v1, v2] = lv5[v0, v1, v2] * var_matmul_intermediate_local[v0, v1, v2] + + +@T.prim_func +def fused_decode5_fused_matmul8_silu1_int3_int16_fp16_before(lv1611: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), lv1612: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv1557: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16") + compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(11008)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1611[v_i // T.int64(5), v_j], lv1612[v_i // T.int64(40), v_j]) + 
T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1611[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1612[v_i // T.int64(40), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1557[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1557[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) + T.writes(compute[v_i0, v_i1, v_i2]) + compute[v_i0, v_i1, v_i2] = T.sigmoid(var_matmul_intermediate[v_i0, v_i1, v_i2]) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def fused_decode5_fused_matmul8_silu1_int3_int16_fp16_after(lv1148: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), lv44: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate_local = T.alloc_buffer((T.int64(4120), T.int64(11008)), scope="local", dtype="float16") + var_scale_intermediate_local = T.alloc_buffer((T.int64(103), T.int64(11008)), scope="local", dtype="float16") + var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="local", dtype="float16") + lv2749_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope="shared", dtype="float16") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(43), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("lv2749_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1) + T.reads(lv2749[v0, v1, v2]) + T.writes(lv2749_shared[v0, v1, v2]) + T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120)) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv2749_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv2749[v0, v1, v2], T.float16(0)) + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = 
T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(103)): + for ax0_0 in T.unroll(T.int64(40)): + for ax1 in range(T.int64(1)): + with T.block("decode"): + v_j = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0) + v_i = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1148[v_j // T.int64(5), v_i]) + T.writes(var_decode_intermediate_local[v_j, v_i]) + var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.Cast("int16", T.bitwise_and(T.shift_right(T.Cast("uint16", lv1148[v_j // T.int64(5), v_i]), T.Cast("uint16", v_j % T.int64(5)) * T.uint16(3)), T.uint16(7))) - T.int16(3)) + for ax0_0 in range(T.int64(1)): + for ax1 in range(T.int64(1)): + with T.block("scale"): + v_j = T.axis.spatial(T.int64(103), k_0_0 + ax0_0) + v_i = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv44[v_j, v_i]) + T.writes(var_scale_intermediate_local[v_j, v_i]) + var_scale_intermediate_local[v_j, v_i] = lv44[v_j, v_i] + for k_0_1_k_1_fused in range(T.int64(40)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2749_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2], var_scale_intermediate_local[v_k // T.int64(40), v_i2]) + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2749_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * var_scale_intermediate_local[v_k // T.int64(40), v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(var_matmul_intermediate_local[v0, v1, v2]) + T.writes(p_output0_intermediate[v0, v1, v2]) + p_output0_intermediate[v0, v1, v2] = var_matmul_intermediate_local[v0, v1, v2] * T.sigmoid(var_matmul_intermediate_local[v0, v1, v2]) + + +@T.prim_func +def fused_decode6_fused_matmul9_add3_int3_int16_fp16_before(lv1623: T.Buffer((T.int64(2208), T.int64(4096)), "uint16"), lv1624: T.Buffer((T.int64(276), T.int64(4096)), "float16"), lv167: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv165: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(11008), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1623[v_i // T.int64(5), v_j], lv1624[v_i // T.int64(40), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1623[v_i // 
T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1624[v_i // T.int64(40), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(11008)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv167[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv167[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv165[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv165[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def fused_decode6_fused_matmul9_add3_int3_int16_fp16_after(lv1158: T.Buffer((T.int64(2208), T.int64(4096)), "uint16"), lv60: T.Buffer((T.int64(276), T.int64(4096)), "float16"), lv6: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + # with T.block("root"): + var_decode_intermediate_local = T.alloc_buffer((T.int64(11040), T.int64(4096)), scope="local", dtype="float16") + var_scale_intermediate_local = T.alloc_buffer((T.int64(276), T.int64(4096)), scope="local", dtype="float16") + var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local", dtype="float16") + lv6_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11040)), scope="shared", dtype="float16") + for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): + for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(44)): + for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("lv2749_shared"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(1), T.int64(0)) + v2 = T.axis.spatial(T.int64(11040), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1) + T.reads(lv6[v0, v1, v2]) + T.writes(lv6_shared[v0, v1, v2]) + T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(11040)) + T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) + lv6_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(11008), lv6[v0, v1, v2], T.float16(0)) + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + T.reads() + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0_0 in range(T.int64(138)): + for ax0_0 in T.unroll(T.int64(80)): + for ax1 in range(T.int64(1)): + with T.block("decode"): + v_j = T.axis.spatial(T.int64(11040), k_0_0 * T.int64(80) + ax0_0) + v_i = 
T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv1158[v_j // T.int64(5), v_i]) + T.writes(var_decode_intermediate_local[v_j, v_i]) + var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.Cast("int16", T.bitwise_and(T.shift_right(T.Cast("uint16", lv1158[v_j // T.int64(5), v_i]), T.Cast("uint16", v_j % T.int64(5)) * T.uint16(3)), T.uint16(7))) - T.int16(3)) + for ax0_0 in T.unroll(T.int64(2)): + for ax1 in range(T.int64(1)): + with T.block("scale"): + v_j = T.axis.spatial(T.int64(276), k_0_0 * T.int64(2) + ax0_0) + v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) + T.reads(lv60[v_j, v_i]) + T.writes(var_scale_intermediate_local[v_j, v_i]) + var_scale_intermediate_local[v_j, v_i] = lv60[v_j, v_i] + for k_0_2_k_1_fused in range(T.int64(80)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) + v_k = T.axis.reduce(T.int64(11040), k_0_0 * T.int64(80) + k_0_2_k_1_fused) + T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv6_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2], var_scale_intermediate_local[v_k // T.int64(40), v_i2]) + T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) + var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv6_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * var_scale_intermediate_local[v_k // T.int64(40), v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + with T.block("var_matmul_intermediate_local"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) + T.reads(lv4[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) + T.writes(p_output0_intermediate[v0, v1, v2]) + p_output0_intermediate[v0, v1, v2] = lv4[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2] +################################################ + +def get_dict_key(func): + return tvm.ir.structural_hash(func), func + + +tir_dispatch_dict = { + get_dict_key(fused_min_max_triu_te_broadcast_to): fused_min_max_triu_te_broadcast_to_sch_func(), + get_dict_key(rms_norm_before): rms_norm_after, + get_dict_key(rms_norm_fp16_before): rms_norm_fp16_after, + get_dict_key(softmax_before): softmax_after, + get_dict_key(softmax_mxn_before): softmax_mxn_after, + get_dict_key(softmax_cast_mxn_before): softmax_cast_mxn_after, + get_dict_key(softmax_fp16_before): softmax_fp16_after, + get_dict_key(softmax_mxn_fp16_before): softmax_mxn_fp16_after, + get_dict_key(softmax_1xn_before): softmax_1xn_sch_func(softmax_1xn_before), + get_dict_key(softmax_cast_1xn_before): softmax_1xn_sch_func(softmax_cast_1xn_before, cast_to_fp16=True), + get_dict_key(softmax_1xn_fp16_before): softmax_1xn_sch_func(softmax_1xn_fp16_before), + get_dict_key(matmul1_before): matmul1_after, + get_dict_key(matmul2_before): matmul2_sch_func(), + get_dict_key(matmul5_before): matmul5_after, + get_dict_key(matmul5_with_m_before): matmul5_with_m_after, + get_dict_key(NT_matmul_before): NT_matmul_after, + get_dict_key(NT_matmul4_before): NT_matmul4_sch_func(), + get_dict_key(NT_matmul9_before): NT_matmul9_sch_func(), + get_dict_key(fused_matmul1_add1): fused_matmul1_add1_sch_func(), + get_dict_key(fused_matmul3_multiply): fused_matmul3_multiply_sch_func(), + get_dict_key(fused_matmul3_silu): 
fused_matmul3_silu_sch_func(), + get_dict_key(fused_matmul4_add1): fused_matmul4_add1_sch_func(), + get_dict_key(fused_NT_matmul_add1_before): fused_NT_matmul_add1_after, + get_dict_key(fused_NT_matmul1_divide_add_maximum_before): fused_NT_matmul1_divide_add_maximum_after, + get_dict_key(fused_NT_matmul1_divide_add_maximum_with_m_before): fused_NT_matmul1_divide_add_maximum_with_m_after, + get_dict_key(fused_NT_matmul6_divide1_add2_maximum1_before): fused_NT_matmul6_divide1_add2_maximum1_after, + get_dict_key(fused_NT_matmul2_multiply_before): fused_NT_matmul2_multiply_after, + get_dict_key(fused_NT_matmul2_silu_before): fused_NT_matmul2_silu_after, + get_dict_key(fused_NT_matmul3_add1_before): fused_NT_matmul3_add1_after, + get_dict_key(fused_NT_matmul_divide_maximum_minimum_cast_before): fused_NT_matmul_divide_maximum_minimum_cast_sch_func(), + get_dict_key(fused_NT_matmul_divide_maximum_minimum_before): fused_NT_matmul_divide_maximum_minimum_sch_func(), + get_dict_key(fused_NT_matmul1_add3_before): fused_NT_matmul1_add3_sch_func(), + get_dict_key(fused_NT_matmul2_divide1_add2_maximum1_before): fused_NT_matmul2_divide1_add2_maximum1_sch_func(fused_NT_matmul2_divide1_add2_maximum1_before), + get_dict_key(fused_NT_matmul2_divide1_maximum1_minimum1_cast3_before): fused_NT_matmul2_divide1_maximum1_minimum1_cast3_after, + get_dict_key(fused_NT_matmul2_divide1_maximum1_minimum1_before): fused_NT_matmul2_divide1_maximum1_minimum1_after, + get_dict_key(fused_NT_matmul3_multiply1_before): fused_NT_matmul3_multiply1_sch_func(), + get_dict_key(fused_NT_matmul3_silu1_before): fused_NT_matmul3_silu1_sch_func(), + get_dict_key(fused_NT_matmul4_add3_before): fused_NT_matmul4_add3_sch_func(), + get_dict_key(matmul1_fp16_before): matmul1_fp16_sch_func(), + get_dict_key(matmul8_fp16_before): matmul8_fp16_sch_func(matmul8_fp16_before), + get_dict_key(matmul8_with_m_fp16_before): matmul8_fp16_sch_func(matmul8_with_m_fp16_before), + get_dict_key(NT_matmul1_fp16_before): NT_matmul1_fp16_sch_func(), + get_dict_key(decode6): decode_sch_func(decode6), + get_dict_key(decode7): decode_sch_func(decode7), + get_dict_key(decode8): decode_sch_func(decode8), + get_dict_key(decode4_fp16): decode_sch_func(decode4_fp16), + get_dict_key(decode5_fp16): decode_sch_func(decode5_fp16), + get_dict_key(decode6_fp16): decode_sch_func(decode6_fp16), + get_dict_key(decode_int3_fp16): decode_sch_func(decode_int3_fp16), + get_dict_key(decode1_int3_fp16): decode_sch_func(decode1_int3_fp16), + get_dict_key(decode2_int3_fp16): decode_sch_func(decode2_int3_fp16), + get_dict_key(decode_int3_int16_fp16): decode_sch_func(decode_int3_int16_fp16), + get_dict_key(decode1_int3_int16_fp16): decode_sch_func(decode1_int3_int16_fp16), + get_dict_key(decode2_int3_int16_fp16): decode_sch_func(decode2_int3_int16_fp16), + get_dict_key(fused_decode3_matmul1_before): fused_decode3_matmul1_after, + get_dict_key(fused_decode4_fused_matmul5_add3_before): fused_decode4_fused_matmul5_add3_after, + get_dict_key(fused_decode4_matmul5_before): fused_decode4_matmul5_after, + get_dict_key(fused_decode5_fused_matmul8_multiply1_before): fused_decode5_fused_matmul8_multiply1_after, + get_dict_key(fused_decode5_fused_matmul8_silu1_before): fused_decode5_fused_matmul8_silu1_after, + get_dict_key(fused_decode6_fused_matmul9_add3_before): fused_decode6_fused_matmul9_add3_after, + get_dict_key(fused_decode3_matmul1_fp16_before): fused_decode3_matmul1_fp16_after, + get_dict_key(fused_decode3_matmul1_cast_fp16_before): fused_decode3_matmul1_cast_fp16_after, + 
get_dict_key(fused_decode4_fused_matmul5_add3_fp16_before): fused_decode4_fused_matmul5_add3_fp16_after, + get_dict_key(fused_decode4_matmul5_fp16_before): fused_decode4_matmul5_fp16_after, + get_dict_key(fused_decode5_fused_matmul8_multiply1_fp16_before): fused_decode5_fused_matmul8_multiply1_fp16_after, + get_dict_key(fused_decode5_fused_matmul8_silu1_fp16_before): fused_decode5_fused_matmul8_silu1_fp16_after, + get_dict_key(fused_decode6_fused_matmul9_add3_fp16_before): fused_decode6_fused_matmul9_add3_fp16_after, + get_dict_key(fused_decode3_matmul1_cast_int3_fp16_before): fused_decode3_matmul1_cast_int3_fp16_after, + get_dict_key(fused_decode4_fused_matmul5_add3_int3_fp16_before): fused_decode4_fused_matmul5_add3_int3_fp16_after, + get_dict_key(fused_decode4_matmul5_int3_fp16_before): fused_decode4_matmul5_int3_fp16_after, + get_dict_key(fused_decode5_fused_matmul8_multiply1_int3_fp16_before): fused_decode5_fused_matmul8_multiply1_int3_fp16_after, + get_dict_key(fused_decode5_fused_matmul8_silu1_int3_fp16_before): fused_decode5_fused_matmul8_silu1_int3_fp16_after, + get_dict_key(fused_decode6_fused_matmul9_add3_int3_fp16_before): fused_decode6_fused_matmul9_add3_int3_fp16_after, + get_dict_key(fused_decode3_matmul1_cast_int3_int16_fp16_before): fused_decode3_matmul1_cast_int3_int16_fp16_after, + get_dict_key(fused_decode4_fused_matmul5_add3_int3_int16_fp16_before): fused_decode4_fused_matmul5_add3_int3_int16_fp16_after, + get_dict_key(fused_decode4_matmul5_int3_int16_fp16_before): fused_decode4_matmul5_int3_int16_fp16_after, + get_dict_key(fused_decode5_fused_matmul8_multiply1_int3_int16_fp16_before): fused_decode5_fused_matmul8_multiply1_int3_int16_fp16_after, + get_dict_key(fused_decode5_fused_matmul8_silu1_int3_int16_fp16_before): fused_decode5_fused_matmul8_silu1_int3_int16_fp16_after, + get_dict_key(fused_decode6_fused_matmul9_add3_int3_int16_fp16_before): fused_decode6_fused_matmul9_add3_int3_int16_fp16_after, +} +# fmt: on + + +def lookup_func(func): + for (hash_value, func_before), f_after in tir_dispatch_dict.items(): + if tvm.ir.structural_hash(func) == hash_value and tvm.ir.structural_equal( + func, func_before + ): + return f_after + return None diff --git a/mlc_llm/quantization/__init__.py b/mlc_llm/quantization/__init__.py new file mode 100644 index 0000000..6284df6 --- /dev/null +++ b/mlc_llm/quantization/__init__.py @@ -0,0 +1,232 @@ +from .quantization import FQuantize +from .quantization import QuantizationScheme +from .quantization import QuantizationSpec, NoQuantizationSpec, ParamQuantKind +from .quantization import QuantSpecUpdater +from .group_quantization import GroupQuantizationSpec +from .autogptq_quantization import AutogptqQuantizationSpec +from .ft_quantization import FTQuantizationSpec, FTQuantizeUpdater + + +# The predefined quantization schemes. 
+quantization_schemes = { + "autogptq_llama_q4f16_0": QuantizationScheme( + name="autogptq_llama_q4f16_0", + linear_weight=AutogptqQuantizationSpec( + dtype="float16", + mode="int4", + sym=False, + group_size=128, + ), + embedding_table=NoQuantizationSpec("float16"), + final_fc_weight=NoQuantizationSpec("float16"), + ), + "autogptq_llama_q4f16_1": QuantizationScheme( + name="autogptq_llama_q4f16_1", + linear_weight=AutogptqQuantizationSpec( + dtype="float16", + mode="int4", + sym=False, + group_size=-1, + ), + embedding_table=NoQuantizationSpec("float16"), + final_fc_weight=NoQuantizationSpec("float16"), + ), + "q0f16": QuantizationScheme("q0f16", NoQuantizationSpec("float16")), + "q0f32": QuantizationScheme("q0f32", NoQuantizationSpec("float32")), + "q3f16_0": QuantizationScheme( + name="q3f16_0", + linear_weight=GroupQuantizationSpec( + dtype="float16", + mode="int3", + sym=True, + storage_nbit=16, + group_size=40, + transpose=True, + ), + embedding_table=GroupQuantizationSpec( + dtype="float16", + mode="int3", + sym=True, + storage_nbit=16, + group_size=40, + transpose=False, + ), + final_fc_weight="same_as_linear_weight", + ), + "q3f16_1": QuantizationScheme( + name="q3f16_1", + linear_weight=GroupQuantizationSpec( + dtype="float16", + mode="int3", + sym=True, + storage_nbit=16, + group_size=40, + transpose=False, + ), + embedding_table="same_as_linear_weight", + final_fc_weight="same_as_linear_weight", + ), + "q4f16_0": QuantizationScheme( + name="q4f16_0", + linear_weight=GroupQuantizationSpec( + dtype="float16", + mode="int4", + sym=True, + storage_nbit=32, + group_size=32, + transpose=True, + ), + embedding_table=GroupQuantizationSpec( + dtype="float16", + mode="int4", + sym=True, + storage_nbit=32, + group_size=32, + transpose=False, + ), + final_fc_weight="same_as_linear_weight", + ), + "q4f16_1": QuantizationScheme( + name="q4f16_1", + linear_weight=GroupQuantizationSpec( + dtype="float16", + mode="int4", + sym=True, + storage_nbit=32, + group_size=32, + transpose=False, + ), + embedding_table="same_as_linear_weight", + final_fc_weight="same_as_linear_weight", + ), + "q4f16_2": QuantizationScheme( + name="q4f16_2", + linear_weight=GroupQuantizationSpec( + dtype="float16", + mode="int4", + sym=True, + storage_nbit=32, + group_size=32, + transpose=False, + ), + embedding_table=NoQuantizationSpec("float16"), + final_fc_weight=NoQuantizationSpec("float16"), + ), + "q4f16_ft": QuantizationScheme( + name="q4f16_ft", + linear_weight=FTQuantizationSpec( + dtype="float16", + nbit=4, + group_size=-1, + ), + embedding_table=GroupQuantizationSpec( + dtype="float16", + mode="int4", + sym=True, + storage_nbit=32, + group_size=32, + transpose=False, + ), + final_fc_weight="same_as_linear_weight", + qspec_updater_class=FTQuantizeUpdater, + ), + "q4f16_ft_group": QuantizationScheme( + name="q4f16_ft_group", + linear_weight=FTQuantizationSpec( + dtype="float16", + nbit=4, + group_size=64, + ), + embedding_table=GroupQuantizationSpec( + dtype="float16", + mode="int4", + sym=True, + storage_nbit=32, + group_size=32, + transpose=False, + ), + final_fc_weight="same_as_linear_weight", + qspec_updater_class=FTQuantizeUpdater, + ), + "q4f32_0": QuantizationScheme( + name="q4f32_0", + linear_weight=GroupQuantizationSpec( + dtype="float32", + mode="int4", + sym=False, + storage_nbit=32, + group_size=32, + transpose=True, + ), + embedding_table=GroupQuantizationSpec( + dtype="float32", + mode="int4", + sym=False, + storage_nbit=32, + group_size=32, + transpose=False, + ), + 
final_fc_weight="same_as_linear_weight", + ), + "q4f32_1": QuantizationScheme( + name="q4f32_1", + linear_weight=GroupQuantizationSpec( + dtype="float32", + mode="int4", + sym=False, + storage_nbit=32, + group_size=32, + transpose=False, + ), + embedding_table="same_as_linear_weight", + final_fc_weight="same_as_linear_weight", + ), + "q8f16_ft": QuantizationScheme( + name="q8f16_ft", + linear_weight=FTQuantizationSpec( + dtype="float16", + nbit=8, + ), + embedding_table=GroupQuantizationSpec( + dtype="float16", + mode="int8", + sym=True, + storage_nbit=32, + group_size=32, + transpose=False, + ), + final_fc_weight="same_as_linear_weight", + qspec_updater_class=FTQuantizeUpdater, + ), + "q8f16_ft_group": QuantizationScheme( + name="q8f16_ft_group", + linear_weight=FTQuantizationSpec( + dtype="float16", + nbit=8, + group_size=64, + ), + embedding_table=GroupQuantizationSpec( + dtype="float16", + mode="int8", + sym=True, + storage_nbit=32, + group_size=32, + transpose=False, + ), + final_fc_weight="same_as_linear_weight", + qspec_updater_class=FTQuantizeUpdater, + ), + "q8f16_1": QuantizationScheme( + name="q8f16_1", + linear_weight=GroupQuantizationSpec( + dtype="float16", + mode="int8", + sym=True, + storage_nbit=32, + group_size=32, + transpose=False, + ), + embedding_table="same_as_linear_weight", + final_fc_weight="same_as_linear_weight", + ), +} diff --git a/mlc_llm/quantization/autogptq_quantization.py b/mlc_llm/quantization/autogptq_quantization.py new file mode 100644 index 0000000..2cdc186 --- /dev/null +++ b/mlc_llm/quantization/autogptq_quantization.py @@ -0,0 +1,193 @@ +from dataclasses import dataclass +from typing import Any, List, Literal, Optional, Tuple +from tvm import relax, te, tir, topi +from . import tir_utils +from .quantization import QuantizationSpec +from .quantization import FQuantize, FTEDequantize, convert_TE_func + + +@dataclass +class AutogptqQuantizationSpec(QuantizationSpec): + """The quantization specification for group quantization algorithm.""" + + mode: Literal["int2", "int3", "int4", "int8"] + sym: bool + group_size: int + storage_nbit: int = 32 + + quantized_suffix = ["qweight", "qzeros", "scales", "g_idx"] + + def get_loaded_tensor_info( + self, pname: str, param_info: relax.TensorStructInfo + ) -> Tuple[List[str], List[relax.TensorStructInfo]]: + assert self.storage_nbit == 32, "Only support 32bit storage currently" + + quantized_pnames = self.quant_convert_pname_fwd(pname) + if len(quantized_pnames) == 1: + return quantized_pnames, [param_info] + else: + assert len(quantized_pnames) == 4 + assert param_info.ndim == 2 + nbit = int(self.mode[-1]) + tensor_info = [] + outfeatures, infeatures = param_info.shape.values + group_size = self.group_size if self.group_size != -1 else infeatures + + def get_quantized_shape_dtype(quantized_pname: str): + if quantized_pname.endswith("qweight"): + return (infeatures // self.storage_nbit * nbit, outfeatures), "uint32" + elif quantized_pname.endswith("qzeros"): + return ( + infeatures // group_size, + outfeatures // self.storage_nbit * nbit, + ), "uint32" + elif quantized_pname.endswith("scales"): + return (infeatures // group_size, outfeatures), "float16" + elif quantized_pname.endswith("g_idx"): + return (infeatures,), "uint32" + else: + raise ValueError(f"Unrecognized quantized parameter name {quantized_pname}") + + for quantized_pname in quantized_pnames: + shape, dtype = get_quantized_shape_dtype(quantized_pname) + tensor_info.append(relax.TensorStructInfo(shape, dtype)) + + return quantized_pnames, tensor_info 
+ + def quant_convert_pname_fwd(self, torch_pname: str) -> List[str]: + # For Llama: + if "_proj.weight" in torch_pname: + return [torch_pname.replace("weight", suffix) for suffix in self.quantized_suffix] + return [torch_pname] + + def run_prequantize(self, model_path: str) -> str: + # Requires auto-gptq >= 0.2.0 + try: + import auto_gptq # pylint: disable=import-outside-toplevel + import transformers # pylint: disable=import-outside-toplevel + except ImportError: + raise ImportError( + "Please install auto_gptq package (version >= 0.2.0) and " + "transformers package to use AutoGPTQ quantization." + ) + import os + from transformers import AutoTokenizer + from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig + + quantized_model_path = ( + model_path + + f"-gptq-i{self.mode[-1]}" + + ("-sym" if self.sym else "") + + f"-g{self.group_size}" + ) + if os.path.isdir(quantized_model_path): + return quantized_model_path + + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) + examples = [ + tokenizer( + "MLC LLM is a universal solution that allows any language models " + "to be deployed natively on a diverse set of hardware backends and " + "native applications, plus a productive framework for everyone to " + "further optimize model performance for their own use cases." + ) + ] + quantize_config = BaseQuantizeConfig( + bits=int(self.mode[-1]), # number of quantization bits + desc_act=False, # disable activation-order (act-order) quantization + group_size=self.group_size, # quantization group size (-1 means no grouping) + ) + + model = AutoGPTQForCausalLM.from_pretrained(model_path, quantize_config) + model.quantize(examples) + + # save the quantized model and tokenizer + model.save_quantized(quantized_model_path) + tokenizer.save_pretrained(quantized_model_path) + return quantized_model_path + + def get_quantize_func(self, param_info: relax.TensorStructInfo) -> Optional[FQuantize]: + return None + + def get_dequantize_func( + self, + param_info: relax.TensorStructInfo, + qparam_info: List[relax.TensorStructInfo], + ) -> Optional[FQuantize]: + return convert_TE_func( + decoding_func( + sym=self.sym, + nbit=int(self.mode[-1]), + storage_nbit=self.storage_nbit, + dim_length=param_info.shape.values[-1], + dtype=self.dtype, + ), + func_name="decode", + ) + + def convert_param_bkwd(self, torch_pname: str, torch_param): + target_dtype = ( + self.dtype if "_proj."
not in torch_pname or "scales" in torch_pname else "uint32" + ) + + # For Llama + combined_layers = ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"] + if any([name in torch_pname for name in combined_layers]): + return None + return [(torch_pname, torch_param.astype(target_dtype))] + + def compute_relax_param(self, relax_pname: str, torch_params: List[Any]): + import numpy as np + + # For Llama + if "query_key_value_proj" in relax_pname: + assert len(torch_params) == 3 + elif "gate_up_proj" in relax_pname: + assert len(torch_params) == 2 + else: + raise ValueError("Unexpected param loading") + + if "g_idx" in relax_pname: + return torch_params[0].astype("uint32") + else: + target_dtype = self.dtype if "scales" in relax_pname else "uint32" + return np.concatenate(torch_params, axis=-1).astype(target_dtype) + + +def decoding_func( + sym: bool, + nbit: int, + storage_nbit: int, + dim_length: tir.PrimExpr, + dtype: str = "float16", +) -> FTEDequantize: + assert dtype in ["float16"], "Only support float16 currently" + assert sym == False, "Only support sym=False currently" + assert storage_nbit == 32, "Only support storage_nbit=32 currently" + + def te_decode_asym(qweight, qzeros, scales, g_idx): + n_float_per_u32 = 32 // nbit + + def f_decode_asym(i, j): + zeros = tir_utils._tir_u32_to_int_to_float( + nbit, + qzeros[g_idx[i], j // n_float_per_u32], + j % n_float_per_u32, + dtype=dtype, + ) + data_float = tir_utils._tir_u32_to_int_to_float( + nbit, + qweight[i // n_float_per_u32, j], + i % n_float_per_u32, + dtype=dtype, + ) + scale_float, bias_float = scales[g_idx[i], j], zeros + 1 + w = (data_float - bias_float) * scale_float + return w + + shape = (dim_length, qweight.shape[1]) + w = te.compute(shape=shape, fcompute=f_decode_asym, name="decode") + w = topi.transpose(w) + return w + + return te_decode_asym diff --git a/mlc_llm/quantization/ft_quantization.py b/mlc_llm/quantization/ft_quantization.py new file mode 100644 index 0000000..286ca9a --- /dev/null +++ b/mlc_llm/quantization/ft_quantization.py @@ -0,0 +1,219 @@ +from dataclasses import dataclass +from typing import List, Optional + +import tvm +from tvm.contrib.nvcc import parse_compute_version +from tvm import relax, te, tir, topi +from tvm.script import tir as T +from tvm.relax.expr_functor import visitor + +from . import tir_utils +from .quantization import QuantizationSpec, QuantSpecUpdater +from .quantization import FQuantize, convert_TE_func +from .group_quantization import GroupQuantizationSpec + + +@dataclass +class FTQuantizationSpec(QuantizationSpec): + """The quantization specification for the FasterTransformer kernel.""" + + def __init__(self, dtype, nbit, group_size=-1): + super().__init__(dtype) + self.nbit = nbit + assert group_size in [-1, 64, 128], f"Group size {group_size} is not supported." 
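# group_size == -1 means a single scale per output row (the whole K dimension is one group
# in encoding_func below); 64 and 128 are the only grouped variants accepted here.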
+ self.group_size = group_size + + if tvm.cuda(0).exist: + major, minor = parse_compute_version(tvm.cuda(0).compute_version) + if major == 8: + self.sm = 80 + else: + self.sm = 10 * major + minor + else: + self.sm = None + + self.do_preprocess = True + + def get_quantize_func(self, param_info: relax.TensorStructInfo) -> Optional[FQuantize]: + assert self.sm is not None + + def f_quantize(bb: relax.BlockBuilder, inputs: List[relax.Expr]): + encoded_data = bb.emit_te( + encoding_func( + self.nbit, + 8, + group_size=self.group_size, + dtype=self.dtype, + ), + inputs[0], + primfunc_name_hint="encode", + ) + + packed_weight = bb.normalize(encoded_data[0]) + + if self.do_preprocess: + encoded_weight = bb.emit( + relax.call_pure_packed( + "cutlass.ft_preprocess_weight", + packed_weight, + self.sm, + self.nbit == 4, + sinfo_args=packed_weight.struct_info, + ) + ) + else: + encoded_weight = packed_weight + + return bb.emit(relax.Tuple([encoded_weight, encoded_data[1]])) + + return f_quantize + + def get_dequantize_func( + self, + param_info: relax.TensorStructInfo, + qparam_info: List[relax.TensorStructInfo], + ) -> Optional[FQuantize]: + return convert_TE_func( + decoding_func( + self.nbit, + storage_nbit=8, + group_size=self.group_size, + ), + func_name="decode", + ) + + +def encoding_func(nbit: int, storage_nbit: int, group_size: int, dtype: str = "float32"): + def te_encode_sym(weight: te.Tensor): + """Encode the weight tensor of shape [N, K] into a quantized weight tensor of shape + [K, N // float_per_int] and a scale tensor of shape [K // group_size, N] + """ + n_float_per_int = storage_nbit // nbit + max_int_value = (1 << (nbit - 1)) - 1 + + cur_group_size = weight.shape[1] if group_size == -1 else group_size + scale_min_shape = (tir.ceildiv(weight.shape[1], cur_group_size), weight.shape[0]) + k = te.reduce_axis((0, cur_group_size), name="k") + max_abs_value = te.compute( + shape=scale_min_shape, + fcompute=lambda group, i: te.max( + te.abs( + tir.if_then_else( + group * cur_group_size + k < weight.shape[1], + weight[i, group * cur_group_size + k], + tir.const(0, dtype=weight.dtype), + ) + ), + axis=k, + ), + name="max_abs_value", + ) + + def f_compute_scale(*idx): + max_value = tir.max(tir.Cast(dtype, max_abs_value(*idx)), tir.const(1e-4, dtype)) + return max_value / tir.const(max_int_value, dtype) + + scale = te.compute(shape=scale_min_shape, fcompute=f_compute_scale, name="scale") + storage_dtype = "int" + str(storage_nbit) + + def f_scale_weight(i, j): + w_scaled = tir.round(tir.Cast(dtype, weight[i, j]) / scale[j // cur_group_size, i]) + w_scaled = T.min( + T.max(w_scaled, tir.const(-max_int_value - 1, dtype)), + tir.const(max_int_value, dtype), + ).astype(storage_dtype) + if n_float_per_int == 1: + return w_scaled + return w_scaled & tir.const((1 << nbit) - 1, storage_dtype) + + n_i32 = tir.ceildiv(weight.shape[0], n_float_per_int) + + if n_float_per_int == 1: + w_gathered = te.compute( + shape=(weight.shape[1], n_i32), + fcompute=lambda j, i: f_scale_weight(i, j), + name="w_gathered", + ) + else: + k = te.reduce_axis((0, n_float_per_int), name="k") + reducer = te.comm_reducer( + fcombine=lambda x, y: tir.bitwise_or(x, y), + fidentity=lambda dtype: tir.const(0, storage_dtype), + name="bitwise_or", + ) + w_gathered = te.compute( + shape=(weight.shape[1], n_i32), + fcompute=lambda j, i: reducer( + tir.if_then_else( + i * n_float_per_int + k < weight.shape[0], + f_scale_weight(i * n_float_per_int + k, j) + << (k.astype(storage_dtype) * tir.const(nbit, storage_dtype)), + tir.const(0, 
storage_dtype), + ), + axis=k, + ), + name="w_gathered", + ) + + return w_gathered, topi.cast(scale, "float16") + + return te_encode_sym + + +def decoding_func(nbit: int, storage_nbit: int, group_size: int): + def te_decode_sym(data, scale): + n_float_per_int = storage_nbit // nbit + cur_group_size = data.shape[0] if group_size == -1 else group_size + + def f_decode_sym(i, j): + if n_float_per_int == 1: + data_float = tir.Cast("float16", data[i, j]) + else: + f_convert = tir_utils._tir_packed_int_to_int_to_float(storage_nbit) + data_float = f_convert( + nbit, data[i, j // n_float_per_int], j % n_float_per_int, dtype="float16" + ) + + scale_float = scale[i // cur_group_size, j] + return data_float * scale_float + + shape = (data.shape[0], data.shape[1] * n_float_per_int) + w = te.compute(shape=shape, fcompute=f_decode_sym, name="decode") + # Dummy transpose for FuseDecodeTranspose + return topi.transpose(w) + + return te_decode_sym + + +@visitor +class FTQuantizeUpdater(QuantSpecUpdater._cls): + def visit_call_(self, call: relax.Call): + if call.op != tvm.ir.Op.get("relax.matmul"): + return + rhs = self.lookup_binding(call.args[1]) + assert rhs is not None + if ( + rhs.op != tvm.ir.Op.get("relax.permute_dims") + or rhs.attrs.axes is not None + or rhs.args[0].struct_info.ndim != 2 + ): + return + + if rhs.args[0] not in self.param_map: + return + + param = self.param_map[rhs.args[0]] + + if call.struct_info.dtype == "float32" or rhs.struct_info.shape[-1] % 8 != 0: + # FT requires N to be a multiple of 8 + # FT does not support fp32 output dtype + # TODO(masahi): If `matmul(..., out_dtype="float32")` is immediately followed + # by `cast(..., "float16")`, `matmul -> cast` can be offloaded. + param.quant_spec = GroupQuantizationSpec( + param.param_info.dtype, + mode="int4", + sym=True, + storage_nbit=32, + group_size=32, + transpose=False, + ) diff --git a/mlc_llm/quantization/group_quantization.py b/mlc_llm/quantization/group_quantization.py new file mode 100644 index 0000000..7603ad2 --- /dev/null +++ b/mlc_llm/quantization/group_quantization.py @@ -0,0 +1,214 @@ +from dataclasses import dataclass +from typing import List, Literal, Optional + +import tvm +from tvm import relax, te, tir, topi +from tvm.script import tir as T +from tvm.relax.expr_functor import visitor + +from . 
import tir_utils +from .quantization import QuantizationSpec, QuantSpecUpdater +from .quantization import NoQuantizationSpec +from .quantization import FQuantize, FTEQuantize, FTEDequantize, convert_TE_func + + +@dataclass +class GroupQuantizationSpec(QuantizationSpec): + """The quantization specification for group quantization algorithm.""" + + mode: Literal["int3", "int4"] + sym: bool + storage_nbit: int + group_size: int + transpose: bool + + def get_quantize_func(self, param_info: relax.TensorStructInfo) -> Optional[FQuantize]: + return convert_TE_func( + encoding_func( + sym=self.sym, + group_size=self.group_size, + nbit=int(self.mode[-1]), + mode=self.mode, + storage_nbit=self.storage_nbit, + transpose=self.transpose, + dtype=self.dtype, + ), + func_name="encode", + ) + + def get_dequantize_func( + self, + param_info: relax.TensorStructInfo, + qparam_info: List[relax.TensorStructInfo], + ) -> Optional[FQuantize]: + return convert_TE_func( + decoding_func( + sym=self.sym, + group_size=self.group_size, + nbit=int(self.mode[-1]), + mode=self.mode, + storage_nbit=self.storage_nbit, + dim_length=param_info.shape.values[-1], + data_transposed=self.transpose, + transpose_output=self.transpose, + dtype=self.dtype, + ), + func_name="decode", + ) + + +# fmt: off +def encoding_func(sym: bool, group_size: int, nbit: int, mode: str, storage_nbit: int, transpose: bool=True, dtype: str = "float32") -> FTEQuantize: + def te_encode_asym(weight: te.Tensor): + assert weight.shape[1] % group_size == 0 + n_group = weight.shape[1] // group_size + n_float_per_u32 = 32 // nbit + + scale_min_shape = (weight.shape[0], n_group) + k = te.reduce_axis((0, group_size), name="k") + min_value = te.compute(shape=scale_min_shape, fcompute=lambda i, j: te.min(weight[i, j * group_size + k], axis=k), name="min_value") + max_value = te.compute(shape=scale_min_shape, fcompute=lambda i, j: te.max(weight[i, j * group_size + k], axis=k), name="max_value") + scale = te.compute(shape=scale_min_shape, fcompute=lambda i, j: (max_value[i, j] - min_value[i, j]) / tir.const((1 << nbit) - 1, dtype), name="scale") + + def f_scale_weight(i, j): + group_idx = j // group_size + w_scaled = tir.round((weight[i, j] - min_value[i, group_idx]) / scale[i, group_idx]).astype("int32") + w_scaled = T.min(T.max(w_scaled, tir.const(0, "int32")), tir.const((1 << nbit) - 1, "int32")) + w_scaled = w_scaled.astype("uint32") + return w_scaled + + k = te.reduce_axis((0, n_float_per_u32), name="k") + reducer = te.comm_reducer(fcombine=lambda x, y: tir.bitwise_or(x, y), fidentity=lambda dtype: tir.const(0, dtype), name="bitwise_or") + if dtype == "float32": + if transpose: + w_gathered = te.compute(shape=(weight.shape[1] // n_float_per_u32, weight.shape[0]), fcompute=lambda j, i: reducer(f_scale_weight(i, j * n_float_per_u32 + k) << (k * nbit).astype("uint32"), axis=k), name="w_gathered") + scale_bias = te.compute(shape=(n_group, weight.shape[0]), fcompute=lambda j, i: tir_utils._tir_f32x2_to_bf16x2_to_u32(scale[i, j], min_value[i, j], round_to_even=True), name="scale_min") + else: + w_gathered = te.compute(shape=(weight.shape[0], weight.shape[1] // n_float_per_u32), fcompute=lambda i, j: reducer(f_scale_weight(i, j * n_float_per_u32 + k) << (k * nbit).astype("uint32"), axis=k), name="w_gathered") + scale_bias = te.compute(shape=(weight.shape[0], n_group), fcompute=lambda i, j: tir_utils._tir_f32x2_to_bf16x2_to_u32(scale[i, j], min_value[i, j], round_to_even=True), name="scale_min") + return w_gathered, scale_bias + else: + if transpose: + w_gathered = 
te.compute(shape=(weight.shape[1] // n_float_per_u32, weight.shape[0]), fcompute=lambda j, i: reducer(f_scale_weight(i, j * n_float_per_u32 + k) << (k * nbit).astype("uint32"), axis=k), name="w_gathered") + scale = te.compute(shape=(n_group, weight.shape[0]), fcompute=lambda j, i: scale[i, j], name="scale_transpose") + min_value = te.compute(shape=(n_group, weight.shape[0]), fcompute=lambda j, i: min_value[i, j], name="min_transpose") + else: + w_gathered = te.compute(shape=(weight.shape[0], weight.shape[1] // n_float_per_u32), fcompute=lambda i, j: reducer(f_scale_weight(i, j * n_float_per_u32 + k) << (k * nbit).astype("uint32"), axis=k), name="w_gathered") + return w_gathered, scale, min_value + + def te_encode_sym(weight: te.Tensor): + n_group = tir.ceildiv(weight.shape[1], group_size) + n_float_per_int = storage_nbit // nbit + max_int_value = (1 << (nbit - 1)) - 1 + assert group_size % n_float_per_int == 0 + + scale_min_shape = (weight.shape[0], n_group) + k = te.reduce_axis((0, group_size), name="k") + max_abs_value = te.compute(shape=scale_min_shape, fcompute=lambda i, j: te.max(tir.if_then_else(j * group_size + k < weight.shape[1], te.abs(weight[i, j * group_size + k]), tir.min_value(dtype)), axis=k), name="max_abs_value") + + def f_compute_scale(i, j): + max_value = tir.max(max_abs_value[i, j], tir.const(1e-4, dtype)) + return (max_value / tir.const(max_int_value, dtype)) if mode.startswith("int") else max_value + + scale = te.compute(shape=scale_min_shape, fcompute=f_compute_scale, name="scale") + storage_dtype = ("uint" + str(storage_nbit)) if mode.startswith("int") else "uint32" + + def f_scale_weight(i, j): + group_idx = j // group_size + if mode.startswith("int"): + w_scaled = tir.round(weight[i, j] / scale[i, group_idx] + tir.const(max_int_value, dtype)) + w_scaled = T.min(T.max(w_scaled, tir.const(0, dtype)), tir.const(max_int_value * 2, dtype)).astype(storage_dtype) + return w_scaled + else: + f_convert = tir_utils._tir_f32_to_uint_to_f4 if dtype == "float32" else tir_utils._tir_f16_to_uint_to_f4 + return f_convert(weight[i, j] / scale[i, group_idx]) + + k = te.reduce_axis((0, n_float_per_int), name="k") + reducer = te.comm_reducer(fcombine=lambda x, y: tir.bitwise_or(x, y), fidentity=lambda dtype: tir.const(0, dtype), name="bitwise_or") + n_i32 = tir.ceildiv(group_size, n_float_per_int) * n_group + if transpose: + w_gathered = te.compute(shape=(n_i32, weight.shape[0]), fcompute=lambda j, i: reducer(tir.if_then_else(j * n_float_per_int + k < weight.shape[1], f_scale_weight(i, j * n_float_per_int + k) << (k.astype(storage_dtype) * tir.const(nbit, storage_dtype)), tir.const(0, storage_dtype)), axis=k), name="w_gathered") + scale = te.compute(shape=(n_group, weight.shape[0]), fcompute=lambda j, i: scale[i, j]) + else: + w_gathered = te.compute(shape=(weight.shape[0], n_i32), fcompute=lambda i, j: reducer(tir.if_then_else(j * n_float_per_int + k < weight.shape[1], f_scale_weight(i, j * n_float_per_int + k) << (k.astype(storage_dtype) * tir.const(nbit, storage_dtype)), tir.const(0, storage_dtype)), axis=k), name="w_gathered") + return w_gathered, scale + + return te_encode_sym if sym else te_encode_asym + + +def decoding_func(sym: bool, group_size: int, nbit: int, mode: str, storage_nbit: int, dim_length: tir.PrimExpr, data_transposed: bool=True, transpose_output: bool=False, dtype: str = "float32") -> FTEDequantize: + def te_decode_asym(*args): + n_float_per_u32 = 32 // nbit + data = args[0] + if dtype == "float32": + scale_bias_bf16x2 = args[1] + else: + scale, min_value = 
args[1], args[2] + + def f_decode_asym(i, j): + if data_transposed: + data_float = tir_utils._tir_u32_to_int_to_float(nbit, data[i // n_float_per_u32, j], i % n_float_per_u32, dtype=dtype) + if dtype == "float32": + scale_float, bias_float = tir_utils._tir_u32_to_bf16x2_to_f32x2(scale_bias_bf16x2[i // group_size, j]) + else: + scale_float, bias_float = scale[i // group_size, j], min_value[i // group_size, j] + else: + data_float = tir_utils._tir_u32_to_int_to_float(nbit, data[i, j // n_float_per_u32], j % n_float_per_u32, dtype=dtype) + if dtype == "float32": + scale_float, bias_float = tir_utils._tir_u32_to_bf16x2_to_f32x2(scale_bias_bf16x2[i, j // group_size]) + else: + scale_float, bias_float = scale[i, j // group_size], min_value[i, j // group_size] + w = data_float * scale_float + bias_float + return w + + shape = (dim_length, data.shape[1]) if data_transposed else (data.shape[0], dim_length) + w = te.compute(shape=shape, fcompute=f_decode_asym, name="decode") + if transpose_output: + w = topi.transpose(w) + return w + + def te_decode_sym(data, scale): + n_float_per_int = storage_nbit // nbit + + def f_decode_sym(i, j): + f_convert = tir_utils._tir_packed_uint_to_uint_to_float(storage_nbit) if mode.startswith("int") else (tir_utils._tir_u32_to_f4_to_f32 if dtype == "float32" else tir_utils._tir_u32_to_f4_to_f16) + if data_transposed: + data_float = f_convert(nbit, data[i // n_float_per_int, j], i % n_float_per_int, dtype=dtype) + scale_float = scale[i // group_size, j] + else: + data_float = f_convert(nbit, data[i, j // n_float_per_int], j % n_float_per_int, dtype=dtype) + scale_float = scale[i, j // group_size] + return data_float * scale_float + + shape = (dim_length, data.shape[1]) if data_transposed else (data.shape[0], dim_length) + w = te.compute(shape=shape, fcompute=f_decode_sym, name="decode") + if transpose_output: + w = topi.transpose(w) + return w + + return te_decode_sym if sym else te_decode_asym +# fmt: on + + +# A simple example demo showing how QuantSpecUpdater is used. +# NOTE: This visitor is only for demo purpose and should not be put into real use. +@visitor +class GroupQuantDemoUpdater(QuantSpecUpdater._cls): + def visit_call_(self, call: relax.Call): + if call.op != tvm.ir.Op.get("relax.matmul"): + return + rhs = self.lookup_binding(call.args[1]) + assert rhs is not None + if ( + rhs.op != tvm.ir.Op.get("relax.permute_dims") + or rhs.attrs.axes is not None + or rhs.args[0].struct_info.ndim != 2 + ): + return + + if rhs.args[0] not in self.param_map: + return + param = self.param_map[rhs.args[0]] + # Update to no quantization for matmul with float32 output dtype. + if call.struct_info.dtype == "float32": + param.quant_spec = NoQuantizationSpec(param.param_info.dtype) diff --git a/mlc_llm/quantization/quantization.py b/mlc_llm/quantization/quantization.py new file mode 100644 index 0000000..2922c93 --- /dev/null +++ b/mlc_llm/quantization/quantization.py @@ -0,0 +1,217 @@ +import enum +from dataclasses import dataclass +from typing import Any, Callable, List, Literal, Optional, Tuple, Type, Union + +import tvm +from tvm import relax, te +from tvm.relax.expr_functor import PyExprVisitor, visitor + +FQuantize = Callable[[relax.BlockBuilder, List[relax.Expr]], relax.Var] +FTEQuantize = Callable[[te.Tensor], List[te.Tensor]] +FTEDequantize = Callable[[List[te.Tensor]], te.Tensor] + + +@dataclass +class QuantizationSpec: + """The base dataclass of quantization specification. + A specification describes how a parameter is quantized and dequantized. 
+ + A subclass of QuantizationSpec + - contains more data fields (e.g., the "group size" in group quantization) + which guide the quantization/dequantization, + - defines the `get_quantize_func` method, which returns a function + (`Callable[[relax.BlockBuilder, List[relax.Expr]], relax.Var]`) that takes a + Relax BlockBuilder and the weight relax Var to be quantized, computes + the quantization and returns the relax Var of quantized results. + - defines the `get_dequantize_func` method, which returns a function + (`Callable[[relax.BlockBuilder, List[relax.Expr]], relax.Var]`) that takes + the quantized results, computes and returns the dequantization result. + - optionally overrides the `get_loaded_tensor_info` method when the parameter is + pre-quantized, so that we know how many quantized data tensors there are, + and the dtype and shape of each quantized data tensor. + """ + + dtype: str + + def get_loaded_tensor_info( + self, pname: str, param_info: relax.TensorStructInfo + ) -> Tuple[List[str], List[relax.TensorStructInfo]]: + """Returns the names, shapes and dtypes of the tensors that need to + be loaded from the disk. + + It is useful when the parameter is pre-quantized. In such cases, we need + to know how many tensors the parameter is quantized into, together + with the dtype and shape of each tensor, so that we can load the + pre-quantized tensors in. + """ + return [pname], [param_info] + + def get_quantize_func(self, param_info: relax.TensorStructInfo) -> Optional[FQuantize]: + """Returns the function which computes quantization. + Returning `None` means the parameter does not need quantization or is + pre-quantized. + + The returned function takes a Relax BlockBuilder and a (list of) weight + relax Var to be quantized, computes the quantization and returns the + quantization result Relax Var(s). + + You can use `convert_TE_func` to convert a TE function to the function + of the desired return format. See `group_quantization.py` for examples. + """ + raise NotImplementedError() + + def get_dequantize_func( + self, + param_info: relax.TensorStructInfo, + qparam_info: List[relax.TensorStructInfo], + ) -> Optional[FQuantize]: + """Returns the function which computes dequantization. + Returning `None` means the parameter does not need dequantization. + + The returned function takes a Relax BlockBuilder and a (list of) + quantized weight relax Var, computes the dequantization and returns the + result Relax Var(s). + + You can use `convert_TE_func` to convert a TE function to the function + of the desired return format. See `group_quantization.py` for examples. + """ + raise NotImplementedError() + + +@dataclass +class NoQuantizationSpec(QuantizationSpec): + """The quantization specification that describes doing no quantization.""" + + def get_quantize_func(self, param_info: relax.TensorStructInfo) -> Optional[FQuantize]: + return None + + def get_dequantize_func( + self, + param_info: relax.TensorStructInfo, + qparam_info: List[relax.TensorStructInfo], + ) -> Optional[FQuantize]: + return None + + +class ParamQuantKind(enum.IntEnum): + """The parameter quantization kind class.
+ + We categorize all the parameters in a model into four kinds: + - the weights of the internal linear layers, which are the main targets of quantization, + - the embedding table of every token, + - the weight of the fully-connected layer at the end of the model, which is + used to compute the logits of each input token, + - other parameters (e.g., the weight of layer normalization, etc.). + """ + + linear_weight = 0 + embedding_table = 1 + final_fc_weight = 2 + others = 3 + + +class QuantizationScheme: + """The quantization scheme class describes how an entire model is quantized. + It contains the quantization specification for each parameter quantization kind. + + In addition, it has an optional field for a visitor class that takes the + constructed model (an IRModule) as input, traverses it, and updates the + QuantizationSpec of certain parameters. + """ + + name: str + linear_weight: QuantizationSpec + embedding_table: QuantizationSpec + final_fc_weight: QuantizationSpec + others: QuantizationSpec + + qspec_updater_class: Optional[Type["QuantSpecUpdater"]] + f_convert_param_bkwd: Optional[Callable[[str, Any], Optional[List[Tuple[str, Any]]]]] + f_compute_relax_param: Optional[Callable[[str, List[Any]], Any]] + f_run_prequantize: Optional[Callable[[str], str]] + + def __init__( + self, + name: str, + linear_weight: QuantizationSpec, + *, + embedding_table: Optional[Union[QuantizationSpec, Literal["same_as_linear_weight"]]] = None, + final_fc_weight: Optional[Union[QuantizationSpec, Literal["same_as_linear_weight"]]] = None, + others: Optional[QuantizationSpec] = None, + qspec_updater_class: Optional[Type["QuantSpecUpdater"]] = None, + ) -> None: + self.name = name + self.linear_weight = linear_weight + self.others = others if others is not None else NoQuantizationSpec(self.model_dtype) + + if embedding_table is None: + self.embedding_table = self.others + elif embedding_table == "same_as_linear_weight": + self.embedding_table = self.linear_weight + else: + self.embedding_table = embedding_table + + if final_fc_weight is None: + self.final_fc_weight = self.others + elif final_fc_weight == "same_as_linear_weight": + self.final_fc_weight = self.linear_weight + else: + self.final_fc_weight = final_fc_weight + + self.qspec_updater_class = qspec_updater_class + self.f_convert_param_bkwd = None + self.f_compute_relax_param = None + self.f_run_prequantize = None + + for spec in [self.linear_weight, self.embedding_table, self.final_fc_weight, self.others]: + if hasattr(spec, "convert_param_bkwd"): + self.f_convert_param_bkwd = spec.convert_param_bkwd + if hasattr(spec, "compute_relax_param"): + self.f_compute_relax_param = spec.compute_relax_param + if hasattr(spec, "run_prequantize"): + self.f_run_prequantize = spec.run_prequantize + + @property + def model_dtype(self) -> str: + """Returns the overall model dtype, which is defined as the dtype of + the linear layers.
+ """ + return self.linear_weight.dtype + + +def convert_TE_func(te_func: Union[FTEQuantize, FTEDequantize], func_name: str) -> FQuantize: + def func(bb: relax.BlockBuilder, inputs: List[relax.Expr]) -> relax.Var: + return bb.call_te(te_func, *inputs, primfunc_name_hint=func_name) + + return func + + +@visitor +class QuantSpecUpdater(PyExprVisitor): + def __init__(self, param_manager) -> None: + super().__init__() + self.param_manager = param_manager + self.param_map = None + self.builder = relax.BlockBuilder() + + def lookup_binding(self, var: relax.Var): + return self.builder.lookup_binding(var) + + def visit_module(self, mod: tvm.IRModule): + for gv, func in mod.functions.items(): + if not isinstance(func, relax.Function): + continue + if func.attrs is None or not "num_input" in func.attrs: + continue + + self.param_map = dict() + num_input = int(func.attrs["num_input"]) + params_in_func = self.param_manager.params_in_func[gv.name_hint] + assert len(func.params) - num_input == len(params_in_func) + for i, relax_param in enumerate(func.params[num_input:]): + self.param_map[relax_param] = params_in_func[i] + + self.builder.normalize(func) + self.visit_expr(func) diff --git a/mlc_llm/quantization/tir_utils.py b/mlc_llm/quantization/tir_utils.py new file mode 100644 index 0000000..02d4c72 --- /dev/null +++ b/mlc_llm/quantization/tir_utils.py @@ -0,0 +1,106 @@ +"""TIR computation utilities for quantization.""" + +import tvm +from tvm import tir + +# fmt: off +def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_even: bool=True): + mask = tir.const((1 << 16) - 1, "uint32") + res = [] + for data in [v0, v1]: + u32_val = tir.reinterpret("uint32", data) + if round_to_even: + rounding_bias = ((u32_val >> tir.const(16, "uint32")) & tir.const(1, "uint32")) + tir.const(0x7FFF, "uint32") + u32_val += rounding_bias + res.append((u32_val >> tir.const(16, "uint32")) & mask) + return res[0] | (res[1] << tir.const(16, "uint32")) + + +def _tir_u32_to_bf16x2_to_f32x2(x: tir.PrimExpr): + mask = tir.const((1 << 16) - 1, "uint32") + x0 = x & mask + x1 = (x >> 16) & mask + return (tir.reinterpret("float32", x << tir.const(16, "uint32")) for x in [x0, x1]) + + +def _tir_u32_to_int_to_float(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert val.dtype == "uint32" + mask = tvm.tir.const((1 << nbit) - 1, "uint32") + return tir.Cast(dtype, (val >> (pos * nbit).astype("uint32")) & mask) + + +def _tir_packed_uint_to_uint_to_float(storage_nbit: int): + storage_dtype = "uint" + str(storage_nbit) + + def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype + max_int_value = (1 << (nbit - 1)) - 1 + return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const((1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype) + + return f_convert + + +def _tir_packed_int_to_int_to_float(storage_nbit: int): + storage_dtype = "int" + str(storage_nbit) + + def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype + mask = tir.const((1 << nbit) - 1, "int32") + unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask + return tir.Cast(dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32")) + + return f_convert + + +def _tir_f32_to_uint_to_f4(val: tir.PrimExpr): + assert val.dtype == "float32" + val_u32 = tir.reinterpret("uint32", val) + # e_f32 > 120 -> e_f4 = min(e_f32 - 120 + M_h, 7) + # e_f32 == 
120 -> e_f4 = 1 + # e_f32 < 120 -> e_f4 = 0 + m_h = (val_u32 >> tir.const(22, "uint32")) & tir.const(1, "uint32") + e_f32 = (val_u32 >> tir.const(23, "uint32")) & tir.const(255, "uint32") + s = (val_u32 >> tir.const(31, "uint32")) + e_f4 = tir.Select(e_f32 > tir.const(120, "uint32"), tir.Min(e_f32 - tir.const(120, "uint32") + m_h, tir.const(7, "uint32")), tir.Select(e_f32 == tir.const(120, "uint32"), tir.const(1, "uint32"), tir.const(0, "uint32"))) + return (s << tir.const(3, "uint32")) | e_f4 + + +def _tir_f16_to_uint_to_f4(val: tir.PrimExpr): + assert val.dtype == "float16" + val_u32 = tir.Cast("uint32", tir.reinterpret("uint16", val)) + m_h = (val_u32 >> tir.const(9, "uint32")) & tir.const(1, "uint32") + e_f16 = (val_u32 >> tir.const(10, "uint32")) & tir.const(31, "uint32") + s = (val_u32 >> tir.const(15, "uint32")) + e_f4 = tir.Select(e_f16 > tir.const(8, "uint32"), tir.Min(e_f16 - tir.const(8, "uint32") + m_h, tir.const(7, "uint32")), tir.Select(e_f16 == tir.const(8, "uint32"), tir.const(1, "uint32"), tir.const(0, "uint32"))) + return (s << tir.const(3, "uint32")) | e_f4 + + +def _tir_u32_to_f4_to_f32(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert nbit == 4 + assert dtype == "float32" + assert val.dtype == "uint32" + # e_f4 == 0 -> e_f32 = 0 + # e_f4 != 0 -> e_f32 = e_f4 + 120 = e_f4 | (1111000)_2 + mask = tvm.tir.const((1 << nbit) - 1, "uint32") + f4 = (val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & mask + s = f4 >> tir.const(3, "uint32") + e_f4 = f4 & tir.const(7, "uint32") + e_f32 = e_f4 | tir.const(120, "uint32") + val_f32 = tir.reinterpret("float32", (e_f32 | (s << tir.const(8, "uint32"))) << tir.const(23, "uint32")) + return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float32"), val_f32) + + +def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert nbit == 4 + assert dtype == "float16" + assert val.dtype == "uint32" + # e_f4 == 0 -> e_f16 = 0 + # e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2 + mask = tvm.tir.const((1 << nbit) - 1, "uint32") + f4 = (val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & mask + s = f4 >> tir.const(3, "uint32") + e_f4 = f4 & tir.const(7, "uint32") + e_f16 = e_f4 | tir.const(8, "uint32") + val_f16 = tir.reinterpret("float16", (e_f16 | (s << tir.const(5, "uint32"))) << tir.const(10, "uint32")) + return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16) +# fmt: on diff --git a/mlc_llm/relax_model/__init__.py b/mlc_llm/relax_model/__init__.py new file mode 100644 index 0000000..9ee3d0d --- /dev/null +++ b/mlc_llm/relax_model/__init__.py @@ -0,0 +1 @@ +from . 
import llama diff --git a/mlc_llm/relax_model/chatglm.py b/mlc_llm/relax_model/chatglm.py new file mode 100644 index 0000000..9a2afdf --- /dev/null +++ b/mlc_llm/relax_model/chatglm.py @@ -0,0 +1,797 @@ +import argparse +import math +from dataclasses import dataclass +from typing import List, Tuple + +import tvm +from tvm import relax, te, tir +from tvm.relax.op import ( + astype, + broadcast_to, + expand_dims, + matmul, + maximum, + minimum, + permute_dims, + repeat, + reshape, + split, + squeeze, +) +from tvm.relax.op.nn import silu, softmax +from tvm.relax.testing import nn +from tvm.script import relax as R + +from ..quantization import ParamQuantKind, QuantizationScheme +from .commons import create_metadata_func +from .modules import Embedding, Linear, ModuleList, RotaryEmbedding +from .param_manager import ParamManager + + +@dataclass +class ChatGLMConfig: + def __init__( + self, + add_bias_linear: bool = False, + add_qkv_bias: bool = True, + ffn_hidden_size: int = 13696, + hidden_size: int = 4096, + kv_channels: int = 128, + layernorm_epsilon: float = 1e-05, + multi_query_group_num: int = 2, + num_attention_heads: int = 32, + num_layers: int = 28, + max_sequence_length: int = 2048, + padded_vocab_size: int = 65024, + eos_token_id: int = 2, + bos_token_id: int = 0, + dtype: str = "float32", + **kwargs, + ): + self.add_bias_linear = add_bias_linear + self.add_qkv_bias = add_qkv_bias + self.ffn_hidden_size = ffn_hidden_size + self.hidden_size = hidden_size + self.kv_channels = kv_channels + self.layernorm_epsilon = layernorm_epsilon + self.multi_query_group_num = multi_query_group_num + self.num_attention_heads = num_attention_heads + self.num_layers = num_layers + self.max_sequence_length = min(2048, max_sequence_length) + self.padded_vocab_size = padded_vocab_size + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.dtype = dtype + self.kwargs = kwargs + + +def _repeat_kv(k: relax.Expr, v: relax.Expr, n_rep: int, shape: relax.Expr): + k = nn.emit(reshape(repeat(k, n_rep, 1), shape)) + v = nn.emit(reshape(repeat(v, n_rep, 1), shape)) + return k, v + + +def _reshape(x: relax.Expr, shape: Tuple[int]): + x = nn.emit(reshape(x, R.shape(shape))) + return x + + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, dtype, eps=1e-5): + self.weight = nn.Parameter((hidden_size,), dtype=dtype, name="rms_norm_weight") + self.eps = tvm.tir.const(eps, dtype) + + def forward(self, hidden_states): + def f_rms_norm(x, weight): + is_float32 = x.dtype == "float32" + + def f_square(x): + return tir.Cast("float32", x) * tir.Cast("float32", x) if not is_float32 else x * x + + k = te.reduce_axis((0, x.shape[2]), name="k") + square_sum = te.compute( + (x.shape[0], x.shape[1]), + lambda bsz, i: te.sum(f_square(x[bsz, i, k]), axis=k), + name=x.op.name + "red_temp", + ) + + def f_div_cast(bsz, i, k): + x_val = x[bsz, i, k] + if not is_float32: + x_val = tir.Cast("float32", x_val) + return x_val / tir.sqrt(square_sum[bsz, i] / x.shape[2] + self.eps) + + def f_mul_cast(x, y): + value = x * y + if not is_float32: + value = tir.Cast(x.dtype, value) + return value + + return te.compute( + x.shape, + lambda bsz, i, k: f_mul_cast(weight(k), f_div_cast(bsz, i, k)), + name="rms_norm", + ) + + return nn.emit_te( + f_rms_norm, + hidden_states, + self.weight, + primfunc_name_hint="rms_norm", + ) + + +class CoreAttention(nn.Module): + def __init__(self, config: ChatGLMConfig): + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. 
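# "Per partition" mirrors the naming of the upstream ChatGLM implementation; with a single
# partition, the values below are simply the full projection size and head count, and
# norm_factor = sqrt(head_dim) gives the usual 1/sqrt(d) attention scaling used in forward().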
+ self.hidden_size_per_partition = projection_size + self.hidden_size_per_attention_head = projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + + self.dtype = config.dtype + + def forward( + self, + q: relax.Expr, + k: relax.Expr, + v: relax.Expr, + attention_mask: relax.Expr, + ) -> relax.Expr: + bsz, sl, nh, hd = q.struct_info.shape + kv_sl = k.struct_info.shape[1] + + # [bsz, nh, sl, hd] + q = nn.emit(permute_dims(q, [0, 2, 1, 3])) + + # [bsz, nh, kv_sl, hd] + k = nn.emit(permute_dims(k, [0, 2, 1, 3])) + v = nn.emit(permute_dims(v, [0, 2, 1, 3])) + + # Calculate Q.K: [bsz, nh, sl, kv_sl] + matmul_result = nn.emit( + matmul(q, permute_dims(k, [0, 1, 3, 2])) + / relax.const(self.norm_factor, q.struct_info.dtype) + ) + attention_scores = _reshape(matmul_result, (bsz, nh, sl, kv_sl)) + + # Apply attention mask: [bsz, nh, sl, kv_sl] + attention_scores = nn.emit( + maximum( + attention_scores, + relax.const( + tvm.tir.min_value(attention_scores.struct_info.dtype).value, + attention_scores.struct_info.dtype, + ), + ) + ) + attention_scores = nn.emit(minimum(attention_scores, attention_mask)) + + # Calculate Softmax(Q.K) + if attention_scores.struct_info.dtype != "float32": + attention_scores = astype(attention_scores, "float32") + attention_probs = nn.emit(softmax(attention_scores, axis=-1)) + if attention_probs.struct_info.dtype != q.struct_info.dtype: + attention_probs = astype(attention_probs, q.struct_info.dtype) + + # Calculate Softmax(Q.K).V + context = nn.emit(matmul(attention_probs, v)) + context = nn.emit(permute_dims(context, [0, 2, 1, 3])) + context = _reshape(context, (bsz, sl, nh * hd)) + + return context + + +class SelfAttention(nn.Module): + def __init__( + self, + config: ChatGLMConfig, + rotary_pos_emb: RotaryEmbedding, + ): + self.projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. 
+ self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + # Multi-query attention config + self.num_multi_query_groups_per_partition = config.multi_query_group_num + self.qkv_hidden_size = ( + self.projection_size + + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num + ) + + self.query_key_value = Linear( + config.hidden_size, + self.qkv_hidden_size, + config.dtype, + bias=config.add_bias_linear or config.add_qkv_bias, + ) + + self.rotary_pos_emb = rotary_pos_emb + + self.core_attention = CoreAttention(config) + + self.dense = Linear( + self.projection_size, + config.hidden_size, + config.dtype, + bias=config.add_bias_linear, + ) + + self.dtype = config.dtype + + def forward( + self, + hidden_states: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_value: Tuple[relax.Expr, relax.Expr], + attention_mask: relax.Expr, + ) -> Tuple[relax.Expr, Tuple[relax.Expr, relax.Expr]]: + # hidden_states: [bsz, sl, hs] + if hidden_states.struct_info.dtype != self.dtype: + hidden_states = nn.emit(astype(hidden_states, self.dtype)) + + bsz, sl, _ = hidden_states.struct_info.shape + kv_sl = all_seq_len_shape.struct_info.values[0] + + mixed_x_layer = nn.emit( + split( + self.query_key_value(hidden_states), + indices_or_sections=[ + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + ( + self.num_attention_heads_per_partition + + self.num_multi_query_groups_per_partition + ) + * self.hidden_size_per_attention_head, + ], + axis=-1, + ) + ) + + q_shape = ( + bsz, + sl, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + kv_shape = ( + bsz, + sl, + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + ) + + # queries: [bsz, sl, nh, hd] + q = _reshape(relax.TupleGetItem(mixed_x_layer, 0), q_shape) + + # keys: [bsz, sl, ng, hd] + k = _reshape(relax.TupleGetItem(mixed_x_layer, 1), kv_shape) + + # values: [bsz, sl, ng, hd] + v = _reshape(relax.TupleGetItem(mixed_x_layer, 2), kv_shape) + + # apply rotary embeddings + q, k = self.rotary_pos_emb(q, k, kv_sl - sl) + + assert k.struct_info.shape[0] == 1 and v.struct_info.shape[0] == 1 + squeezed_k, squeezed_v = nn.emit(squeeze(k, axis=0)), nn.emit(squeeze(v, axis=0)) + + k_cache, v_cache = past_key_value + f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append") + k_cache = nn.emit( + relax.op.call_inplace_packed( + f_kv_cache_append, + args=[k_cache, squeezed_k], + inplace_indices=[0], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + v_cache = nn.emit( + relax.op.call_inplace_packed( + f_kv_cache_append, + args=[v_cache, squeezed_v], + inplace_indices=[0], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + past_key_value = (k_cache, v_cache) + + kv_sl = all_seq_len_shape.struct_info.values[0] + bsz, _, n_groups, head_dim = k.struct_info.shape + kv_cache_shape = R.shape([kv_sl, n_groups, head_dim]) + f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view") + k = nn.emit( + relax.call_pure_packed( + f_kv_cache_view, + args=[k_cache, kv_cache_shape], + sinfo_args=[R.Tensor(kv_cache_shape, k.struct_info.dtype)], + ) + ) + v = nn.emit( + relax.call_pure_packed( + f_kv_cache_view, + args=[v_cache, kv_cache_shape], + sinfo_args=[R.Tensor(kv_cache_shape, v.struct_info.dtype)], + ) + ) + + n_rep = self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition + kv_attn_shape = R.shape( + [ + bsz, + kv_sl, 
+ self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ] + ) + k, v = _repeat_kv(k, v, n_rep, kv_attn_shape) + + # core attention computation + context_layer = self.core_attention(q, k, v, attention_mask) + + # apply output projection + output = self.dense(context_layer) + + return output, past_key_value + + +class MLP(nn.Module): + def __init__(self, config: ChatGLMConfig): + super().__init__() + self.dtype = config.dtype + + self.dense_h_to_4h = Linear( + config.hidden_size, + config.ffn_hidden_size * 2, + config.dtype, + bias=config.add_bias_linear, + ) + + def swiglu(x: relax.Expr): + x = nn.emit(split(x, 2, axis=-1)) + return nn.emit(silu(x[0]) * x[1]) + + self.activation_func = swiglu + + self.dense_4h_to_h = Linear( + config.ffn_hidden_size, + config.hidden_size, + config.dtype, + bias=config.add_bias_linear, + ) + + def forward(self, hidden_states): + if hidden_states.struct_info.dtype != self.dtype: + hidden_states = nn.emit(astype(hidden_states, self.dtype)) + + hidden_states = self.dense_h_to_4h(hidden_states) + hidden_states = self.activation_func(hidden_states) + hidden_states = self.dense_4h_to_h(hidden_states) + + return hidden_states + + +class GLMBlock(nn.Module): + def __init__(self, config: ChatGLMConfig, rotary_pos_emb: RotaryEmbedding): + self.input_layernorm = RMSNorm( + hidden_size=config.hidden_size, + dtype=config.dtype, + eps=config.layernorm_epsilon, + ) + self.post_attention_layernorm = RMSNorm( + hidden_size=config.hidden_size, + dtype=config.dtype, + eps=config.layernorm_epsilon, + ) + + self.self_attention = SelfAttention(config, rotary_pos_emb) + self.mlp = MLP(config) + + self.dtype = config.dtype + + def forward( + self, + hidden_states: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_value: Tuple[relax.Expr], + attention_mask: relax.Expr, + ): + layernorm_output = self.input_layernorm(hidden_states) + attention_output, present_key_value = self.self_attention( + layernorm_output, all_seq_len_shape, past_key_value, attention_mask + ) + + # residual connection + layernorm_input = nn.emit(attention_output + hidden_states) + + layernorm_output = self.post_attention_layernorm(layernorm_input) + mlp_output = self.mlp(layernorm_output) + + # residual connection + output = nn.emit(mlp_output + layernorm_input) + + return output, present_key_value + + +class GLMTransformer(nn.Module): + def __init__(self, config: ChatGLMConfig, rotary_pos_emb: RotaryEmbedding): + self.num_layers = config.num_layers + + self.layers = ModuleList([GLMBlock(config, rotary_pos_emb) for _ in range(self.num_layers)]) + self.final_layernorm = RMSNorm( + hidden_size=config.hidden_size, + dtype=config.dtype, + eps=config.layernorm_epsilon, + ) + + def forward( + self, + hidden_states: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_values: relax.Expr, + attention_mask: relax.Expr, + ): + present_kv_cache = [] + for i, block in enumerate(self.layers): + past_key_value = past_key_values[i * 2], past_key_values[i * 2 + 1] + hidden_states, (present_k_cache, present_v_cache) = block( + hidden_states, + all_seq_len_shape=all_seq_len_shape, + past_key_value=past_key_value, + attention_mask=attention_mask, + ) + present_kv_cache.append(present_k_cache) + present_kv_cache.append(present_v_cache) + hidden_states = self.final_layernorm(hidden_states) + return hidden_states, present_kv_cache + + +class ChatGLMModel(nn.Module): + def __init__(self, config: ChatGLMConfig): + self.num_layers = config.num_layers + + self.embedding = Embedding( + 
num_embeddings=config.padded_vocab_size, + embedding_dim=config.hidden_size, + dtype=config.dtype, + ) + + self.seq_length = config.max_sequence_length + rotary_dim = config.kv_channels // 2 + + self.rotary_pos_emb = RotaryEmbedding( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + position_embedding_base=10000, + max_sequence_length=config.max_sequence_length, + rotary_dim=rotary_dim, + swizzle_style="glm", + dtype=config.dtype, + ) + self.encoder = GLMTransformer(config, self.rotary_pos_emb) + self.output_layer = Linear( + in_features=config.hidden_size, + out_features=config.padded_vocab_size, + bias=False, + dtype=config.dtype, + ) + + self.dtype = config.dtype + + def _prepare_decoder_attention_mask(self, input_shape, kv_sl, dtype): + # create causal mask + # [bsz, sl] -> [bsz, 1, sl, kv_sl] + if isinstance(input_shape[-1], tvm.tir.SizeVar) or input_shape[-1] > 1: + bsz, sl = input_shape + + def min_max_triu_te(): + return te.compute( + (sl, sl), + lambda i, j: tvm.tir.Select( + j > i, tvm.tir.min_value(dtype), tvm.tir.max_value(dtype) + ), + name="make_diag_mask_te", + ) + + mask = nn.emit_te(min_max_triu_te) + mask = nn.emit(expand_dims(mask, 0)) + diag_mask = nn.emit(broadcast_to(mask, (bsz, 1, sl, sl))) + if kv_sl == sl: + return diag_mask + + def extend_te(x, sl, kv_sl): + return te.compute( + (bsz, 1, sl, kv_sl), + lambda b, _, i, j: te.if_then_else( + j < kv_sl - sl, + tvm.tir.max_value(dtype), + x[b, _, i, j - (kv_sl - sl)], + ), + name="concat_te", + ) + + return nn.emit_te(extend_te, diag_mask, sl, kv_sl) + else: + # Get kv_sl from input parameters + # [bsz, sl=1] -> [bsz, 1, sl=1, kv_sl] + bsz, sl = input_shape + mask = relax.op.full( + (bsz, 1, sl, kv_sl), + relax.const(tvm.tir.max_value(dtype).value, dtype), + dtype, + ) + return nn.emit(mask) + + def forward( + self, + input_ids: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_values: relax.Expr, + ): + batch_size, seq_length = input_ids.struct_info.shape + seq_length_with_past = all_seq_len_shape.struct_info.values[0] + + # Token Embeddings + inputs_embeds = self.embedding(input_ids) + + attention_mask = self._prepare_decoder_attention_mask( + (batch_size, seq_length), + seq_length_with_past, + dtype=self.dtype, + ) + + hidden_states, present_kv_cache = self.encoder( + inputs_embeds, + all_seq_len_shape=all_seq_len_shape, + past_key_values=past_key_values, + attention_mask=attention_mask, + ) + + return hidden_states, present_kv_cache + + +class ChatGLMForCausalLM(nn.Module): + def __init__(self, config: ChatGLMConfig): + self.transformer = ChatGLMModel(config) + + self.dtype = config.dtype + + def forward( + self, + input_ids: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_values: relax.Expr, + ): + hidden_states, key_value_cache = self.transformer( + input_ids=input_ids, + all_seq_len_shape=all_seq_len_shape, + past_key_values=past_key_values, + ) + + def te_slice_last(x: te.Tensor): + _, sl, hs = x.shape + return te.compute( + shape=(1, 1, hs), + fcompute=lambda i, _, k: x[i, sl - 1, k], + name="slice_last", + ) + + hidden_states = nn.emit_te( + te_slice_last, + hidden_states, + primfunc_name_hint="slice_last", + ) + if hidden_states.struct_info.dtype != self.dtype: + hidden_states = nn.emit(astype(hidden_states, self.dtype)) + + lm_logits = self.transformer.output_layer(hidden_states) + + if lm_logits.struct_info.dtype != "float32": + lm_logits = nn.emit(astype(lm_logits, "float32")) + + return lm_logits, key_value_cache + + +def get_param_quant_kind(name: str, 
param_info: relax.TensorStructInfo) -> ParamQuantKind: + if "embedding.weight" in name: + return ParamQuantKind.embedding_table + elif "transformer.output_layer.weight" in name: + return ParamQuantKind.final_fc_weight + elif param_info.ndim == 2 and name.endswith(".weight"): + return ParamQuantKind.linear_weight + else: + return ParamQuantKind.others + + +def create_encoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: ChatGLMConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "prefill" + + bsz = tvm.tir.IntImm("int64", 1) + sl = tvm.tir.SizeVar("n", "int64") + all_seq_len = tvm.tir.SizeVar("m", "int64") + with bb.function(func_name): + model = ChatGLMForCausalLM(config) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids = nn.Placeholder((bsz, sl), dtype="int32", name="input_ids") + all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo([relax.ObjectStructInfo() for _ in range(config.num_layers * 2)]), + ) + + with bb.dataflow(): + logits, key_value_cache = model( + input_ids=input_ids, + all_seq_len_shape=all_seq_len_shape, + past_key_values=past_key_values, + ) + params = [ + input_ids, + all_seq_len_shape, + past_key_values, + ] + model.parameters() + + gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) + bb.emit_func_output(gv, params) + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 3)) + + +def create_decoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: ChatGLMConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "decode" + + bsz = 1 + all_seq_len = tvm.tir.SizeVar("m", "int64") + + with bb.function(func_name): + model = ChatGLMForCausalLM(config) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids = nn.Placeholder((bsz, 1), dtype="int32", name="input_ids") + all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo([relax.ObjectStructInfo() for _ in range(config.num_layers * 2)]), + ) + with bb.dataflow(): + logits, key_value_cache = model( + input_ids=input_ids, + all_seq_len_shape=all_seq_len_shape, + past_key_values=past_key_values, + ) + params = [ + input_ids, + all_seq_len_shape, + past_key_values, + ] + model.parameters() + gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 3)) + + +def create_kv_cache_func(bb: relax.BlockBuilder, config: ChatGLMConfig) -> None: + init_shape = relax.ShapeExpr( + ( + config.max_sequence_length, + config.multi_query_group_num, + config.hidden_size // config.num_attention_heads, + ) + ) + with bb.function("create_kv_cache", []): + with bb.dataflow(): + zeros = bb.emit(relax.op.zeros(init_shape, config.dtype)) + caches = [] + f_kv_cache_create = relax.extern("vm.builtin.attention_kv_cache_create") + for _ in range(config.num_layers * 2): + caches.append( + bb.emit( + relax.call_pure_packed( + f_kv_cache_create, + args=[zeros, init_shape, relax.PrimValue(0)], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + ) + gv = bb.emit_output(caches) + bb.emit_func_output(gv) + + +def create_softmax_func(bb: relax.BlockBuilder, config: ChatGLMConfig) -> None: + with 
bb.function("softmax_with_temperature"): + logits = nn.Placeholder((1, 1, config.padded_vocab_size), dtype="float32", name="logits") + temperature = nn.Placeholder((), dtype="float32", name="temperature") + with bb.dataflow(): + div = bb.emit(relax.op.divide(logits, temperature)) + softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) + gv = bb.emit_output(softmax) + bb.emit_func_output(gv, [logits, temperature]) + + +def get_model(args: argparse.Namespace, hf_config): + model = args.model + dtype = args.quantization.model_dtype + + if model.startswith("chatglm2") or model.startswith("codegeex2") or model.startswith("chatglm3"): + config = ChatGLMConfig( + **hf_config, + dtype=dtype, + ) + + param_manager = ParamManager() + bb = relax.BlockBuilder() + create_encoding_func(bb, param_manager, config, args.quantization) + create_decoding_func(bb, param_manager, config, args.quantization) + create_kv_cache_func(bb, config) + create_softmax_func(bb, config) + create_metadata_func( + bb, + model_name=model, + max_window_size=config.max_sequence_length, + stop_tokens=[0], + add_prefix_space=False, + prefill_chunk_size=args.prefill_chunk_size, + ) + + mod = bb.get() + + tir_bound_map = dict() + tir_bound_map["n"] = ( + args.prefill_chunk_size if args.prefill_chunk_size > 0 else config.max_sequence_length + ) + tir_bound_map["m"] = config.max_sequence_length + for gv in mod.functions: + func = mod[gv] + if isinstance(func, relax.Function): + mod[gv] = func.with_attr("tir_var_upper_bound", tir_bound_map) + + if args.build_model_only: + return mod, param_manager, None, config + + def f_convert_pname_fwd(pname: str) -> List[str]: + if "transformer.embedding" in pname: + return [ + pname.replace("transformer.embedding", "transformer.embedding.word_embeddings") + ] + else: + return [pname] + + def f_convert_param_bkwd(torch_pname: str, torch_param): + if "transformer.embedding.word_embeddings" in torch_pname: + return [ + ( + torch_pname.replace( + "transformer.embedding.word_embeddings", + "transformer.embedding", + ), + torch_param.astype(dtype), + ) + ] + else: + return [(torch_pname, torch_param.astype(dtype))] + + param_manager.set_param_loading_func( + args.model_path, args.use_safetensors, f_convert_pname_fwd, f_convert_param_bkwd + ) + return mod, param_manager, [None] * len(param_manager.param_names), config + + raise ValueError(f"Unsupported model {model}") diff --git a/mlc_llm/relax_model/commons.py b/mlc_llm/relax_model/commons.py new file mode 100644 index 0000000..be0c477 --- /dev/null +++ b/mlc_llm/relax_model/commons.py @@ -0,0 +1,363 @@ +import json +from typing import Dict, List, Optional + +import mlc_llm +import tvm +from tvm import relax, te, tir, topi + + +def create_metadata_func( + bb: relax.BlockBuilder, + model_name: str, + max_window_size: int, + stop_tokens: List[int], + add_prefix_space: bool, + prefill_chunk_size: int = -1, + sliding_window: int = -1, +): + metadata = json.dumps( + { + "model_name": model_name, + "max_window_size": max_window_size, + "stop_tokens": stop_tokens, + "add_prefix_space": add_prefix_space, + "prefill_chunk_size": prefill_chunk_size, + "sliding_window": sliding_window, + } + ) + with bb.function("get_metadata", params=[]): + bb.emit_func_output(relax.StringImm(metadata)) + + +def _get_shard_strategies( + model_config, num_shards: int, param_shape_is_already_sharded: bool +) -> Dict[str, tvm.tir.PrimFunc]: + head_dim = model_config.hidden_size // model_config.num_attention_heads + q_heads = model_config.num_attention_heads + kv_heads = 
model_config.get_num_key_value_heads() + + # pylint: disable=invalid-name + def shard_qkv_weight_scale(weight: relax.TensorStructInfo): + (spatial, red), dtype = weight.shape, weight.dtype + spatial, red = int(spatial), int(red) + if param_shape_is_already_sharded: + spatial *= num_shards + a = te.placeholder((spatial, red), dtype=dtype) + w = topi.reshape(a, (spatial // head_dim, head_dim, red)) + q = te.compute((q_heads, head_dim, red), lambda i, j, k: w[i, j, k]) + k = te.compute((kv_heads, head_dim, red), lambda i, j, k: w[q_heads + i, j, k]) + v = te.compute((kv_heads, head_dim, red), lambda i, j, k: w[q_heads + kv_heads + i, j, k]) + q = topi.reshape(q, (num_shards, q_heads // num_shards, head_dim, red)) + k = topi.reshape(k, (num_shards, kv_heads // num_shards, head_dim, red)) + v = topi.reshape(v, (num_shards, kv_heads // num_shards, head_dim, red)) + w = topi.concatenate((q, k, v), axis=1) + w = topi.reshape(w, (num_shards, (q_heads + kv_heads * 2) // num_shards * head_dim, red)) + func = te.create_prim_func([a, w]) + return func + + def shard_k_weight_scale(weight: relax.TensorStructInfo): + (spatial, red), dtype = weight.shape, weight.dtype + spatial, red = int(spatial), int(red) + if param_shape_is_already_sharded: + red *= num_shards + a = te.placeholder((spatial, red), dtype=dtype) + w = topi.reshape(a, (spatial, num_shards, red // num_shards)) + w = topi.transpose(w, (1, 0, 2)) + func = te.create_prim_func([a, w]) + return func + + def shard_axis_0(weight: relax.TensorStructInfo): + (red, spatial), dtype = weight.shape, weight.dtype + red, spatial = int(red), int(spatial) + if param_shape_is_already_sharded: + red *= num_shards + a = te.placeholder((red, spatial), dtype=dtype) + w = topi.reshape(a, (num_shards, red // num_shards, spatial)) + func = te.create_prim_func([a, w]) + return func + + def shard_axis_1(weight: relax.TensorStructInfo): + (spatial, red), dtype = weight.shape, weight.dtype + spatial, red = int(spatial), int(red) + if param_shape_is_already_sharded: + red *= num_shards + a = te.placeholder((spatial, red), dtype=dtype) + w = topi.reshape(a, (spatial, num_shards, red // num_shards)) + w = topi.transpose(w, (1, 0, 2)) + func = te.create_prim_func([a, w]) + return func + + def shard_gate_up_weight_scale(weight: relax.TensorStructInfo): + (spatial, red), dtype = weight.shape, weight.dtype + spatial, red = int(spatial), int(red) + if param_shape_is_already_sharded: + spatial *= num_shards + a = te.placeholder((spatial, red), dtype=dtype) + g = te.compute((spatial // 2, red), lambda i, j: a[i, j]) + u = te.compute((spatial // 2, red), lambda i, j: a[spatial // 2 + i, j]) + g = topi.reshape(g, (num_shards, spatial // 2 // num_shards, red)) + u = topi.reshape(u, (num_shards, spatial // 2 // num_shards, red)) + w = topi.concatenate((g, u), axis=1) + w = topi.reshape(w, (num_shards, spatial // num_shards, red)) + func = te.create_prim_func([a, w]) + return func + + # pylint: enable=invalid-name + + return { + "shard_qkv": shard_qkv_weight_scale, + "shard_mlp_k": shard_k_weight_scale, + "shard_o_proj_k": shard_k_weight_scale, + "shard_gate_up": shard_gate_up_weight_scale, + "shard_axis_0": shard_axis_0, + "shard_axis_1": shard_axis_1, + } + + +def _get_shard_strategies_ft( + model_config, num_shards: int, param_shape_is_already_sharded: bool +) -> Dict[str, tvm.tir.PrimFunc]: + q_heads = model_config.num_attention_heads + kv_heads = model_config.get_num_key_value_heads() + + def shard_qkv_weight_scale(x: relax.TensorStructInfo): + (red, spatial), dtype = x.shape, 
x.dtype + red, spatial = int(red), int(spatial) + if param_shape_is_already_sharded: + spatial *= num_shards + head_dim = spatial // (q_heads + 2 * kv_heads) + a = te.placeholder((red, spatial), dtype=dtype) + w = topi.reshape(a, (red, spatial // head_dim, head_dim)) + q = te.compute((red, q_heads, head_dim), lambda i, j, k: w[i, j, k]) + k = te.compute((red, kv_heads, head_dim), lambda i, j, k: w[i, q_heads + j, k]) + v = te.compute((red, kv_heads, head_dim), lambda i, j, k: w[i, q_heads + kv_heads + j, k]) + q = topi.reshape(q, (red, num_shards, q_heads // num_shards, head_dim)) + k = topi.reshape(k, (red, num_shards, kv_heads // num_shards, head_dim)) + v = topi.reshape(v, (red, num_shards, kv_heads // num_shards, head_dim)) + w = topi.concatenate((q, k, v), axis=2) + w = topi.reshape(w, (red, num_shards, (q_heads + kv_heads * 2) // num_shards * head_dim)) + w = topi.transpose(w, (1, 0, 2)) + func = te.create_prim_func([a, w]) + return func + + def shard_k_weight(weight: relax.TensorStructInfo): + (red, spatial), dtype = weight.shape, weight.dtype + red, spatial = int(red), int(spatial) + if param_shape_is_already_sharded: + red *= num_shards + a = te.placeholder((red, spatial), dtype=dtype) + w = topi.reshape(a, (num_shards, red // num_shards, spatial)) + func = te.create_prim_func([a, w]) + return func + + def shard_axis_0(weight: relax.TensorStructInfo): + (red, spatial), dtype = weight.shape, weight.dtype + red, spatial = int(red), int(spatial) + if param_shape_is_already_sharded: + red *= num_shards + a = te.placeholder((red, spatial), dtype=dtype) + w = topi.reshape(a, (num_shards, red // num_shards, spatial)) + func = te.create_prim_func([a, w]) + return func + + def shard_axis_1(weight: relax.TensorStructInfo): + (spatial, red), dtype = weight.shape, weight.dtype + spatial, red = int(spatial), int(red) + if param_shape_is_already_sharded: + red *= num_shards + a = te.placeholder((spatial, red), dtype=dtype) + w = topi.reshape(a, (spatial, num_shards, red // num_shards)) + w = topi.transpose(w, (1, 0, 2)) + func = te.create_prim_func([a, w]) + return func + + def shard_gate_up_weight_scale(x: relax.TensorStructInfo): + (red, spatial), dtype = x.shape, x.dtype + red, spatial = int(red), int(spatial) + if param_shape_is_already_sharded: + spatial *= num_shards + a = te.placeholder((red, spatial), dtype=dtype) + g = te.compute((red, spatial // 2), lambda i, j: a[i, j]) + u = te.compute((red, spatial // 2), lambda i, j: a[i, spatial // 2 + j]) + g = topi.reshape(g, (red, num_shards, spatial // 2 // num_shards)) + u = topi.reshape(u, (red, num_shards, spatial // 2 // num_shards)) + w = topi.concatenate((g, u), axis=2) + w = topi.reshape(w, (red, num_shards, spatial // num_shards)) + w = topi.transpose(w, (1, 0, 2)) + func = te.create_prim_func([a, w]) + return func + + return { + "shard_qkv": shard_qkv_weight_scale, + "shard_mlp_k": shard_k_weight, + "shard_o_proj_k": shard_k_weight, + "shard_gate_up": shard_gate_up_weight_scale, + "shard_axis_0": shard_axis_0, + "shard_axis_1": shard_axis_1, + } + + +def create_shard_info_func(param_manager, args, model_config) -> tvm.IRModule: + shard_strategy_to_func = _get_shard_strategies( + model_config, + num_shards=args.num_shards, + param_shape_is_already_sharded=args.build_model_only, + ) + + shard_info_dict = {} + shard_funcs = {} + + def add_to_shard_info(param_name: str, func_name: Optional[str]): + shard_info = [] + if func_name is not None: + func = shard_funcs[func_name] + buffer = func.buffer_map[func.params[-1]] + shape = [int(i) for 
i in buffer.shape] + dtype = str(buffer.dtype) + shard_info.append((func_name, [shape, dtype])) + + shard_info_dict[param_name] = shard_info + + q_params = [param.struct_info for param in param_manager.get_quantized_params("prefill")] + for _, param in param_manager.params.items(): + if param.shard_strategy is None: + pass + elif param.shard_strategy in shard_strategy_to_func: + for i, weight in enumerate(param_manager.param2qrange[param]): + if args.use_presharded_weights: + sharding_func_name = None + else: + sharding_func_name = f"{param.shard_strategy}_{i}" + if sharding_func_name not in shard_funcs: + shard_funcs[sharding_func_name] = shard_strategy_to_func[ + param.shard_strategy + ](q_params[weight]) + add_to_shard_info(f"param_{weight}", sharding_func_name) + else: + raise NotImplementedError(f"Shard strategy not implemented: {param.shard_strategy}") + + bb = relax.BlockBuilder() # pylint: disable=invalid-name + + for name, func in shard_funcs.items(): + func = func.with_attr({"global_symbol": name}) + bb.add_func(func, name) + + with bb.function("get_shard_info", params=[]): + bb.emit_func_output(relax.StringImm(json.dumps(shard_info_dict))) + + return bb.get() + + +def create_shard_transformation_func(param_manager, args, model_config) -> tvm.IRModule: + use_ft_quant = args.quantization.name in [ + "q4f16_ft", + "q8f16_ft", + "q4f16_ft_group", + "q8f16_ft_group", + ] + + if use_ft_quant: + shard_strategy_to_func = _get_shard_strategies_ft( + model_config, + num_shards=args.num_shards, + param_shape_is_already_sharded=args.build_model_only, + ) + else: + shard_strategy_to_func = _get_shard_strategies( + model_config, + num_shards=args.num_shards, + param_shape_is_already_sharded=args.build_model_only, + ) + + q_params = [param.struct_info for param in param_manager.get_quantized_params("prefill")] + + # The order of the quantized parameters must be preserved. + # Therefore, we need to loop over q_params and look up information + # as needed, rather than looping over original parameters and + # looking up the quantized parameters as needed. + orig_param_lookup = {} + for param in param_manager.params_in_func["prefill"]: + qrange = param_manager.param2qrange[param] + for i_orig_part, i_qparam in enumerate(qrange): + orig_param_lookup[i_qparam] = ( + param, + i_orig_part, + len(qrange), + ) + + bb = relax.BlockBuilder() # pylint: disable=invalid-name + with bb.function("transform_params"): + rank = tir.SizeVar("rank", "int64") + # TODO(Lunderberg): Support primitive inputs to relax + # functions. Currently, using a PrimStructInfo as the + # argument results in an error thrown during + # `vm_shape_lower.cc`, due to BindParams failing to replace + # the symbolic variable "rank" when defined in a R.PrimValue. 
+ # + # rank_arg = relax.Var("rank", relax.PrimStructInfo(value=rank)) + rank_arg = relax.Var("rank_arg", relax.ShapeStructInfo([rank])) + + args = [rank_arg] + output = [] + + for i_qparam, qparam_sinfo in enumerate(q_params): + param, i_orig_part, num_orig_parts = orig_param_lookup[i_qparam] + + if isinstance(param.quant_spec, mlc_llm.quantization.NoQuantizationSpec): + arg_name = param.name + elif num_orig_parts == 1: + arg_name = f"{param.name}.quantized" + else: + arg_name = f"{param.name}.quantized_{i_orig_part}" + + arg = relax.Var(arg_name, qparam_sinfo) + + if param.shard_strategy is None or ( + use_ft_quant + and param.shard_strategy in ["shard_mlp_k", "shard_o_proj_k"] + and qparam_sinfo.shape[0] == 1 + ): + sharded = arg + else: + strategy_func = shard_strategy_to_func[param.shard_strategy]( + qparam_sinfo + ).without_attr("global_symbol") + strategy_gvar = bb.add_func( + strategy_func, + func_name=f"{arg_name}.sharding_func", + ) + + # TODO(Lunderberg): Write the strategies as relax + # functions, so the sharded shapes can be inferred. + reordered_buffer = strategy_func.buffer_map[strategy_func.params[-1]] + reordered_sinfo = relax.TensorStructInfo( + reordered_buffer.shape, reordered_buffer.dtype + ) + reordered = relax.op.call_tir( + strategy_gvar, relax.Tuple([arg]), out_sinfo=reordered_sinfo + ) + + # TODO(Lunderberg): Allow relax.PrimValue as the index + # in a TupleGetItem. This would allow all of the + # splits to be generated at once in the merged + # function, and could be optimized to an in-place view. + # + # split = relax.op.split(reordered, indices_or_sections=num_shards, axis=0)[rank] + split = relax.op.strided_slice( + reordered, + axes=[0], + begin=[rank], + end=[rank + 1], + assume_inbound=True, + ) + + sharded = relax.op.squeeze(split, axis=0) + + args.append(arg) + output.append(sharded) + + with bb.dataflow(): + gv = bb.emit_output(output) + bb.emit_func_output(output=gv, params=args) + + return bb.get() diff --git a/mlc_llm/relax_model/gpt_bigcode.py b/mlc_llm/relax_model/gpt_bigcode.py new file mode 100644 index 0000000..a089390 --- /dev/null +++ b/mlc_llm/relax_model/gpt_bigcode.py @@ -0,0 +1,661 @@ +import argparse +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import tvm +from tvm import relax, te +from tvm.relax.op import ( + astype, + broadcast_to, + expand_dims, + matmul, + maximum, + minimum, + permute_dims, + reshape, + squeeze, +) +from tvm.relax.op.nn import gelu, layer_norm, softmax +from tvm.relax.testing import nn +from tvm.script import relax as R + +from ..quantization import ParamQuantKind, QuantizationScheme +from .commons import create_metadata_func +from .modules import Embedding, Linear, ModuleList +from .param_manager import ParamManager + + +@dataclass +class GPTBigCodeConfig: + def __init__( + self, + bos_token_id: int = 0, + eos_token_id: int = 0, + initializer_range: float = 0.02, + layer_norm_epsilon: float = 1e-05, + max_sequence_length: int = 2048, + n_embd: int = 6144, + n_head: int = 48, + n_inner: int = 24576, + n_layer: int = 40, + n_positions: int = 8192, + scale_attn_weights: bool = True, + vocab_size: int = 49152, + dtype: str = "float32", + **kwargs, + ): + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.initializer_range = initializer_range + self.layer_norm_epsilon = layer_norm_epsilon + self.max_sequence_length = max_sequence_length + self.n_embd = n_embd + self.n_head = n_head + self.n_inner = n_inner + self.n_layer = n_layer + 
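# --- Illustrative aside (not part of the patch): a NumPy sketch of the layout
# transform that `shard_qkv_weight_scale` in commons.py above describes with TE.
# The helper name and the use of NumPy are assumptions for illustration only; the
# real code lowers the same computation through te.create_prim_func.
import numpy as np

def shard_qkv_weight(weight: np.ndarray, num_shards: int, q_heads: int,
                     kv_heads: int, head_dim: int) -> np.ndarray:
    # weight: fused QKV projection of shape [(q_heads + 2 * kv_heads) * head_dim, red].
    spatial, red = weight.shape
    w = weight.reshape(spatial // head_dim, head_dim, red)
    q = w[:q_heads]                                   # per-head query rows
    k = w[q_heads:q_heads + kv_heads]                 # per-head key rows
    v = w[q_heads + kv_heads:]                        # per-head value rows
    # Regroup the heads so each shard receives a contiguous [q_i; k_i; v_i] block.
    q = q.reshape(num_shards, q_heads // num_shards, head_dim, red)
    k = k.reshape(num_shards, kv_heads // num_shards, head_dim, red)
    v = v.reshape(num_shards, kv_heads // num_shards, head_dim, red)
    w = np.concatenate((q, k, v), axis=1)
    return w.reshape(num_shards, (q_heads + 2 * kv_heads) // num_shards * head_dim, red)

# Example: 4 query heads + 2 KV heads of size 2, hidden size 8, split over 2 shards.
_w = np.arange((4 + 2 * 2) * 2 * 8, dtype=np.float32).reshape(16, 8)
assert shard_qkv_weight(_w, num_shards=2, q_heads=4, kv_heads=2, head_dim=2).shape == (2, 8, 8)
# --- end of aside ---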
self.n_positions = n_positions + self.scale_attn_weights = scale_attn_weights + self.vocab_size = vocab_size + self.dtype = dtype + self.kwargs = kwargs + + +def _prepare_decoder_attention_mask(input_shape, src_len, dtype): + # create causal mask + # [bsz, seq_len] -> [bsz, tgt_seq_len, 1, src_seq_len] + if isinstance(input_shape[-1], tvm.tir.SizeVar) or input_shape[-1] > 1: + bsz, tgt_len = input_shape + + def min_max_triu_te(): + return te.compute( + (tgt_len, tgt_len), + lambda i, j: tvm.tir.Select( + j > i, tvm.tir.min_value(dtype), tvm.tir.max_value(dtype) + ), + name="make_diag_mask_te", + ) + + mask = nn.emit_te(min_max_triu_te) + mask = nn.emit(expand_dims(mask, 1)) + diag_mask = nn.emit(broadcast_to(mask, (bsz, tgt_len, 1, tgt_len))) + if src_len == tgt_len: + return diag_mask + + def extend_te(x, tgt_len, src_len): + return te.compute( + (bsz, tgt_len, 1, src_len), + lambda b, i, _, j: te.if_then_else( + j < src_len - tgt_len, + tvm.tir.max_value(dtype), + x[b, i, _, j - (src_len - tgt_len)], + ), + name="concat_te", + ) + + return nn.emit_te(extend_te, diag_mask, tgt_len, src_len) + else: + # Get src_len from input parameters + # [bsz, seq_len] -> [bsz, tgt_seq_len, 1, src_seq_len] + bsz, tgt_len = input_shape + mask = relax.op.full( + (bsz, tgt_len, 1, src_len), + relax.const(tvm.tir.max_value(dtype).value, dtype), + dtype, + ) + return nn.emit(mask) + + +def apply_position_embedding(t_embd, weight, offset: int = 0): + def f_position_embedding(tensor, weight, offset): + def position_compute(*idx): + b, s, e = idx + return weight[s + offset, e] + tensor[b, s, e] + + return tvm.te.compute(tensor.shape, position_compute, name="position") + + hidden_states = nn.emit_te( + f_position_embedding, + t_embd, + weight, + offset, + primfunc_name_hint="position_embedding", + ) + return hidden_states + + +class LayerNorm(nn.Module): + def __init__( + self, + hidden_size, + dtype, + eps=1e-5, + ): + super().__init__() + self.dtype = dtype + + self.eps = eps + self.weight = nn.Parameter((hidden_size,), dtype=dtype, name="weight") + self.bias = nn.Parameter((hidden_size,), dtype=dtype, name="bias") + + def forward(self, x: relax.Expr) -> relax.Var: + if x.struct_info.dtype != self.dtype: + x = nn.emit(relax.op.astype(x, self.dtype)) + x = nn.emit( + layer_norm( + x, + gamma=self.weight, + beta=self.bias, + axes=-1, + epsilon=self.eps, + ) + ) + return x + + +class GPTBigCodeAttention(nn.Module): + """Multi-query attention from 'Fast Transformer Decoding: One Write-Head is All You Need'""" + + def __init__(self, config: GPTBigCodeConfig): + if config.n_embd % config.n_head != 0: + raise ValueError( + f"hidden_size must be divisible by n_head (got `hidden_size`: {config.n_embd}" + f" and `n_head`: {config.n_head})." 
+ ) + self.n_embd = config.n_embd + self.n_head = config.n_head + self.head_dim = config.n_embd // config.n_head + + self.c_attn = Linear(self.n_embd, self.n_embd + 2 * self.head_dim, config.dtype, bias=True) + self.c_proj = Linear(self.n_embd, self.n_embd, config.dtype, bias=True) + + self.dtype = config.dtype + + def forward( + self, + hidden_states: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_value: Optional[Tuple[relax.Expr, relax.Expr]] = None, + attention_mask: Optional[relax.Expr] = None, + ) -> Tuple[relax.Expr, Union[Tuple[None, None], Tuple[relax.Expr, relax.Expr]]]: + # hidden_states: [batch_size, seq_len, n_embd] + if hidden_states.struct_info.dtype != self.dtype: + hidden_states = nn.emit(astype(hidden_states, self.dtype)) + + batch_size, seq_len, _ = hidden_states.struct_info.shape + kv_seq_len = all_seq_len_shape.struct_info.values[0] + + def te_slice(x: te.Tensor, start: int, end: int): + batch_size, seq_len, _ = x.shape + return te.compute( + shape=(batch_size, seq_len, end - start), + fcompute=lambda i, j, k: x[i, j, start + k], + name="slice", + ) + + query_key_value = self.c_attn(hidden_states) + # queries: [batch_size, seq_len, n_embd] + q = nn.emit_te(te_slice, query_key_value, 0, self.n_embd, primfunc_name_hint="slice") + # keys: [batch_size, seq_len, head_dim] + k = nn.emit_te( + te_slice, + query_key_value, + self.n_embd, + self.n_embd + self.head_dim, + primfunc_name_hint="slice", + ) + # values: [batch_size, seq_len, head_dim] + v = nn.emit_te( + te_slice, + query_key_value, + self.n_embd + self.head_dim, + self.n_embd + 2 * self.head_dim, + primfunc_name_hint="slice", + ) + + squeezed_k = nn.emit(squeeze(k, axis=0)) + squeezed_v = nn.emit(squeeze(v, axis=0)) + + assert k.struct_info.shape[0] == 1 and v.struct_info.shape[0] == 1 + + k_cache, v_cache = past_key_value + f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append") + k_cache = nn.emit( + relax.op.call_inplace_packed( + f_kv_cache_append, + args=[k_cache, squeezed_k], + inplace_indices=[0], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + v_cache = nn.emit( + relax.op.call_inplace_packed( + f_kv_cache_append, + args=[v_cache, squeezed_v], + inplace_indices=[0], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + past_key_value = (k_cache, v_cache) + + batch_size, _, head_size = k.struct_info.shape + kv_cache_shape = R.shape([kv_seq_len, head_size]) + kv_states_shape = R.shape([batch_size, kv_seq_len, head_size]) + f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view") + k = nn.emit( + relax.call_pure_packed( + f_kv_cache_view, + args=[k_cache, kv_cache_shape], + sinfo_args=[R.Tensor(kv_cache_shape, k.struct_info.dtype)], + ) + ) + v = nn.emit( + relax.call_pure_packed( + f_kv_cache_view, + args=[v_cache, kv_cache_shape], + sinfo_args=[R.Tensor(kv_cache_shape, v.struct_info.dtype)], + ) + ) + + k = nn.emit(reshape(k, kv_states_shape)) + v = nn.emit(reshape(v, kv_states_shape)) + + q_state_shape = R.shape([batch_size, seq_len * self.n_head, self.head_dim]) + q = nn.emit(reshape(q, q_state_shape)) + + # Calculate Q.K + attn_weights = nn.emit( + matmul(q, permute_dims(k, [0, 2, 1])) + / relax.const(math.sqrt(self.head_dim), q.struct_info.dtype) + ) + + # Apply attention mask + attn_weights = nn.emit( + maximum( + attn_weights, + relax.const( + tvm.tir.min_value(attn_weights.struct_info.dtype).value, + attn_weights.struct_info.dtype, + ), + ) + ) + attn_shape = R.shape([batch_size, seq_len, self.n_head, kv_seq_len]) + attn_view = R.shape([batch_size, seq_len * 
self.n_head, kv_seq_len]) + attn_weights = nn.emit(reshape(attn_weights, attn_shape)) + attn_weights = nn.emit(minimum(attn_weights, attention_mask)) + attn_weights = nn.emit(reshape(attn_weights, attn_view)) + + # Calculate Softmax(Q.K) + if attn_weights.struct_info.dtype != "float32": + attn_weights = astype(attn_weights, "float32") + attn_weights = nn.emit(softmax(attn_weights, axis=-1)) + if attn_weights.struct_info.dtype != q.struct_info.dtype: + attn_weights = astype(attn_weights, q.struct_info.dtype) + + # Calculate Softmax(Q.K).V + attn_output = nn.emit(matmul(attn_weights, v)) + + # Apply output projection + attn_output = self.c_proj( + reshape( + attn_output, + (batch_size, seq_len, self.n_embd), + ) + ) + + return attn_output, past_key_value + + +class GPTBigCodeMLP(nn.Module): + def __init__(self, config: GPTBigCodeConfig): + super().__init__() + self.dtype = config.dtype + + self.c_fc = Linear(config.n_embd, config.n_inner, config.dtype, bias=True) + self.c_proj = Linear(config.n_inner, config.n_embd, config.dtype, bias=True) + + def forward(self, hidden_states): + if hidden_states.struct_info.dtype != self.dtype: + hidden_states = nn.emit(astype(hidden_states, self.dtype)) + + hidden_states = self.c_fc(hidden_states) + hidden_states = nn.emit(gelu(hidden_states)) + hidden_states = self.c_proj(hidden_states) + + return hidden_states + + +class GPTBigCodeBlock(nn.Module): + def __init__(self, config: GPTBigCodeConfig): + self.dtype = config.dtype + + self.ln_1 = LayerNorm( + hidden_size=config.n_embd, dtype=config.dtype, eps=config.layer_norm_epsilon + ) + self.ln_2 = LayerNorm( + hidden_size=config.n_embd, dtype=config.dtype, eps=config.layer_norm_epsilon + ) + + self.attn = GPTBigCodeAttention(config) + self.mlp = GPTBigCodeMLP(config) + + def forward( + self, + hidden_states, + all_seq_len_shape: relax.Expr, + past_key_value: Tuple[relax.Expr], + attention_mask: Optional[relax.Expr] = None, + ): + attn_input = self.ln_1(hidden_states) + attn_output, present_key_value = self.attn( + attn_input, all_seq_len_shape, past_key_value, attention_mask + ) + + # residual connection + attn_output = nn.emit(attn_output + hidden_states) + + mlp_input = self.ln_2(attn_output) + mlp_output = self.mlp(mlp_input) + + # residual connection + hidden_states = nn.emit(astype(mlp_output, self.dtype) + attn_output) + + return hidden_states, present_key_value + + +class GPTBigCodeModel(nn.Module): + def __init__(self, config: GPTBigCodeConfig): + self.wte = Embedding( + num_embeddings=config.vocab_size, + embedding_dim=config.n_embd, + dtype=config.dtype, + ) + self.wpe = Embedding( + num_embeddings=config.n_positions, + embedding_dim=config.n_embd, + dtype=config.dtype, + ) + + self.h = ModuleList([GPTBigCodeBlock(config) for _ in range(config.n_layer)]) + self.ln_f = LayerNorm( + hidden_size=config.n_embd, dtype=config.dtype, eps=config.layer_norm_epsilon + ) + + def forward( + self, + input_ids: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_values: relax.Expr, + ): + batch_size, seq_length = input_ids.struct_info.shape + seq_length_with_past = all_seq_len_shape.struct_info.values[0] + + # Token Embeddings + t_embd = self.wte(input_ids) + + # Position Embeddings + offset = seq_length_with_past - seq_length + hidden_states = apply_position_embedding(t_embd, self.wpe.weight, offset=offset) + + attention_mask = _prepare_decoder_attention_mask( + (batch_size, seq_length), + seq_length_with_past, + dtype=hidden_states.struct_info.dtype, + ) + + present_kv_cache = [] + for i, block in 
enumerate(self.h): + past_key_value = ( + (past_key_values[i * 2], past_key_values[i * 2 + 1]) + if past_key_values is not None + else None + ) + hidden_states, (present_k_cache, present_v_cache) = block( + hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + all_seq_len_shape=all_seq_len_shape, + ) + present_kv_cache.append(present_k_cache) + present_kv_cache.append(present_v_cache) + hidden_states = self.ln_f(hidden_states) + return hidden_states, present_kv_cache + + +class GPTBigCodeForCausalLM(nn.Module): + def __init__(self, config: GPTBigCodeConfig): + self.dtype = config.dtype + + self.transformer = GPTBigCodeModel(config) + self.lm_head = Linear( + in_features=config.n_embd, + out_features=config.vocab_size, + bias=False, + dtype=config.dtype, + ) + + def forward( + self, + input_ids: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_values: relax.Expr, + ): + hidden_states, key_value_cache = self.transformer( + input_ids=input_ids, + all_seq_len_shape=all_seq_len_shape, + past_key_values=past_key_values, + ) + + def te_slice_last(x: te.Tensor): + _, seq_len, n_embd = x.shape + return te.compute( + shape=(1, 1, n_embd), + fcompute=lambda i, _, k: x[i, seq_len - 1, k], + name="slice_last", + ) + + hidden_states = nn.emit_te( + te_slice_last, + hidden_states, + primfunc_name_hint="slice_last", + ) + if hidden_states.struct_info.dtype != self.dtype: + hidden_states = nn.emit(astype(hidden_states, self.dtype)) + + logits = self.lm_head(hidden_states) + + if logits.struct_info.dtype != "float32": + logits = nn.emit(astype(logits, "float32")) + + return logits, key_value_cache + + +def get_param_quant_kind(name: str, param_info: relax.TensorStructInfo) -> ParamQuantKind: + if "wte.weight" in name: + return ParamQuantKind.embedding_table + elif "lm_head.weight" in name: + return ParamQuantKind.final_fc_weight + elif "wpe" not in name and param_info.ndim == 2 and name.endswith(".weight"): + return ParamQuantKind.linear_weight + else: + return ParamQuantKind.others + + +def create_encoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: GPTBigCodeConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "prefill" + + batch_size = tvm.tir.IntImm("int64", 1) + seq_len = tvm.tir.SizeVar("n", "int64") + all_seq_len = tvm.tir.SizeVar("m", "int64") + with bb.function(func_name): + model = GPTBigCodeForCausalLM(config) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids = nn.Placeholder((batch_size, seq_len), dtype="int32", name="input_ids") + all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo([relax.ObjectStructInfo() for _ in range(config.n_layer * 2)]), + ) + + with bb.dataflow(): + logits, key_value_cache = model( + input_ids=input_ids, + all_seq_len_shape=all_seq_len_shape, + past_key_values=past_key_values, + ) + params = [ + input_ids, + all_seq_len_shape, + past_key_values, + ] + model.parameters() + + gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) + bb.emit_func_output(gv, params) + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 3)) + + +def create_decoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: GPTBigCodeConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "decode" + + bsz = tvm.tir.IntImm("int64", 1) + seq_len = tvm.tir.IntImm("int64", 1) + 
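# --- Illustrative aside (not part of the patch): a NumPy sketch of the multi-query
# attention computed by GPTBigCodeAttention above, where every query head shares a
# single key/value head. The function name is hypothetical; the model code builds the
# same dataflow with relax ops (matmul, minimum, softmax) instead of NumPy.
import numpy as np

def multi_query_attention(q, k, v, mask):
    # q: [bsz, seq_len, n_head, head_dim]; k, v: [bsz, kv_len, head_dim] (one shared KV head)
    # mask: [bsz, seq_len, 1, kv_len] holding dtype-max where attention is allowed and
    #       dtype-min where it is masked, so np.minimum floors the disallowed scores.
    bsz, seq_len, n_head, head_dim = q.shape
    kv_len = k.shape[1]
    q = q.reshape(bsz, seq_len * n_head, head_dim)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)              # [bsz, seq*heads, kv_len]
    scores = np.minimum(scores.reshape(bsz, seq_len, n_head, kv_len), mask)
    scores = scores.reshape(bsz, seq_len * n_head, kv_len)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))      # row-wise softmax
    weights /= weights.sum(axis=-1, keepdims=True)
    out = weights @ v                                                  # [bsz, seq*heads, head_dim]
    return out.reshape(bsz, seq_len, n_head * head_dim)                # input to c_proj

# Tiny self-contained check with a causal mask over 4 tokens and 2 query heads.
_rng = np.random.default_rng(0)
_q = _rng.standard_normal((1, 4, 2, 8)).astype(np.float32)
_k = _rng.standard_normal((1, 4, 8)).astype(np.float32)
_v = _rng.standard_normal((1, 4, 8)).astype(np.float32)
_lo, _hi = np.finfo(np.float32).min, np.finfo(np.float32).max
_mask = np.where(np.triu(np.ones((4, 4)), k=1), _lo, _hi)[None, :, None, :].astype(np.float32)
assert multi_query_attention(_q, _k, _v, _mask).shape == (1, 4, 16)
# --- end of aside ---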
all_seq_len = tvm.tir.SizeVar("m", "int64") + + with bb.function(func_name): + model = GPTBigCodeForCausalLM(config) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") + all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo([relax.ObjectStructInfo() for _ in range(config.n_layer * 2)]), + ) + with bb.dataflow(): + logits, key_value_cache = model( + input_ids=input_ids, + all_seq_len_shape=all_seq_len_shape, + past_key_values=past_key_values, + ) + params = [ + input_ids, + all_seq_len_shape, + past_key_values, + ] + model.parameters() + gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 3)) + + +def create_kv_cache_func(bb: relax.BlockBuilder, config: GPTBigCodeConfig) -> None: + init_shape = relax.ShapeExpr( + ( + config.max_sequence_length, + config.n_embd // config.n_head, + ) + ) + with bb.function("create_kv_cache", []): + with bb.dataflow(): + zeros = bb.emit(relax.op.zeros(init_shape, config.dtype)) + caches = [] + f_kv_cache_create = relax.extern("vm.builtin.attention_kv_cache_create") + for _ in range(config.n_layer * 2): + caches.append( + bb.emit( + relax.call_pure_packed( + f_kv_cache_create, + args=[zeros, init_shape, relax.PrimValue(0)], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + ) + gv = bb.emit_output(caches) + bb.emit_func_output(gv) + + +def create_softmax_func(bb: relax.BlockBuilder, config: GPTBigCodeConfig) -> None: + with bb.function("softmax_with_temperature"): + logits = nn.Placeholder((1, 1, config.vocab_size), dtype="float32", name="logits") + temperature = nn.Placeholder((), dtype="float32", name="temperature") + with bb.dataflow(): + div = bb.emit(relax.op.divide(logits, temperature)) + softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) + gv = bb.emit_output(softmax) + bb.emit_func_output(gv, [logits, temperature]) + + +def get_model(args: argparse.Namespace, hf_config): + model = args.model + dtype = args.quantization.model_dtype + max_seq_len = args.max_seq_len + + if ( + model.startswith("starcoder") + or model.startswith("WizardCoder-") + or model.startswith("gpt_bigcode") + ): + config = GPTBigCodeConfig( + **hf_config, + dtype=dtype, + ) + if max_seq_len != -1: + config.max_sequence_length = max_seq_len + elif config.max_sequence_length is None: + config.max_sequence_length = 2048 + + param_manager = ParamManager() + bb = relax.BlockBuilder() + create_encoding_func(bb, param_manager, config, args.quantization) + create_decoding_func(bb, param_manager, config, args.quantization) + create_kv_cache_func(bb, config) + create_softmax_func(bb, config) + create_metadata_func( + bb, + model_name=model, + max_window_size=config.max_sequence_length, + stop_tokens=[0], + add_prefix_space=False, + prefill_chunk_size=args.prefill_chunk_size, + ) + + mod = bb.get() + + tir_bound_map = dict() + tir_bound_map["n"] = ( + args.prefill_chunk_size if args.prefill_chunk_size > 0 else config.max_sequence_length + ) + tir_bound_map["m"] = config.max_sequence_length + for gv in mod.functions: + func = mod[gv] + if isinstance(func, relax.Function): + mod[gv] = func.with_attr("tir_var_upper_bound", tir_bound_map) + + if args.build_model_only: + return mod, param_manager, None, config + + 
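# --- Illustrative aside (not part of the patch): a rough Python analogy, under
# assumptions, of how these model files drive the vm.builtin.attention_kv_cache_*
# externs: `create` pre-allocates a [max_sequence_length, ...] buffer with a fill
# counter of 0, `append` writes the new keys/values of each prefill/decode step behind
# the filled region, and `view` returns the filled prefix of length kv_seq_len. The
# class below is a hypothetical stand-in, not the TVM runtime implementation.
import numpy as np

class AttentionKVCacheSketch:
    def __init__(self, init: np.ndarray, fill: int = 0):
        self.buffer = init.copy()        # e.g. zeros([max_sequence_length, head_dim])
        self.fill = fill                 # number of rows currently occupied

    def append(self, chunk: np.ndarray) -> "AttentionKVCacheSketch":
        # chunk corresponds to squeeze(k, axis=0) in the attention layers above.
        n = chunk.shape[0]
        self.buffer[self.fill:self.fill + n] = chunk
        self.fill += n
        return self

    def view(self, kv_seq_len: int) -> np.ndarray:
        # kv_seq_len is what the model reads from all_seq_len_shape.
        return self.buffer[:kv_seq_len]

_cache = AttentionKVCacheSketch(np.zeros((8, 4), dtype=np.float32))
_cache.append(np.ones((3, 4), dtype=np.float32))         # prefill with 3 tokens
_cache.append(np.full((1, 4), 2.0, dtype=np.float32))    # one decode step
assert _cache.view(4).shape == (4, 4)
# --- end of aside ---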
param_manager.set_param_loading_func( + args.model_path, + args.use_safetensors, + f_convert_param_bkwd=lambda torch_pname, torch_param: [ + (torch_pname, torch_param.astype(dtype)) + ], + ) + return mod, param_manager, [None] * len(param_manager.param_names), config + + raise ValueError(f"Unsupported model {model}") diff --git a/mlc_llm/relax_model/gpt_neox.py b/mlc_llm/relax_model/gpt_neox.py new file mode 100644 index 0000000..cdf80d1 --- /dev/null +++ b/mlc_llm/relax_model/gpt_neox.py @@ -0,0 +1,733 @@ +# pylint: disable=missing-docstring,too-few-public-methods,too-many-instance-attributes,invalid-name,too-many-locals,too-many-arguments +import argparse +import math +from typing import List, Optional, Tuple, Union + +import tvm +from tvm import relax, te +from tvm.relax.op import ( + astype, + broadcast_to, + matmul, + maximum, + minimum, + permute_dims, + reshape, + squeeze, +) +from tvm.relax.op.nn import gelu, softmax +from tvm.relax.testing import nn +from tvm.script import relax as R + +from ..quantization import ParamQuantKind, QuantizationScheme +from .commons import create_metadata_func +from .modules import Embedding, LayerNorm, Linear, ModuleList, RotaryEmbedding +from .param_manager import ParamManager + + +class GPTNeoXConfig: # pylint: disable=too-many-instance-attributes + def __init__( + self, + use_parallel_residual, + hidden_size, + intermediate_size, + num_attention_heads, + num_hidden_layers, + vocab_size, + rotary_pct, + rotary_emb_base, + layer_norm_eps, + max_sequence_length, + dtype, + ffn_out_dtype, + **kwargs, + ): + self.use_parallel_residual = use_parallel_residual + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_hidden_layers + self.vocab_size = vocab_size + self.rotary_pct = rotary_pct + self.rotary_emb_base = rotary_emb_base + self.layer_norm_eps = layer_norm_eps + self.max_sequence_length = max_sequence_length + self.dtype = dtype + self.ffn_out_dtype = ffn_out_dtype + self.kwargs = kwargs + + +class GPTNeoXAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + rotary_embedding: RotaryEmbedding, + dtype: str, + ): + if hidden_size % num_heads != 0: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {hidden_size}" + f" and `num_heads`: {num_heads})." 
+ ) + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.rotary_embedding = rotary_embedding + self.query_key_value = Linear(hidden_size, hidden_size * 3, dtype, bias=True) + self.dense = Linear(hidden_size, hidden_size, dtype, bias=True) + self.dtype = dtype + + def forward( + self, + hidden_states: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_value: Optional[Tuple[relax.Expr, relax.Expr]] = None, + attention_mask: Optional[relax.Expr] = None, + ) -> Tuple[relax.Expr, Union[Tuple[None, None], Tuple[relax.Expr, relax.Expr]]]: + # hidden_states: [batch_size, seq_len, hidden_size] + if hidden_states.struct_info.dtype != self.dtype: + hidden_states = nn.emit(astype(hidden_states, self.dtype)) + batch_size, seq_len, _ = hidden_states.struct_info.shape + kv_seq_len = all_seq_len_shape.struct_info.values[0] + + # qkv_states: [batch_size, seq_len, hidden_size * 3] + qkv_states = nn.emit( + relax.op.split( + reshape( + self.query_key_value(hidden_states), + (batch_size, seq_len, self.num_heads, 3 * self.head_dim), + ), + indices_or_sections=3, + axis=-1, + ) + ) + + # q/k/v states: [batch_size, seq_len, num_attention_heads, head_size] + q, k, v = [relax.TupleGetItem(qkv_states, idx) for idx in range(3)] + q, k = self.rotary_embedding(q, k, kv_seq_len - seq_len) + + if past_key_value is not None: + f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append") + f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view") + k_cache, v_cache = past_key_value + k_cache = nn.emit( + relax.op.call_inplace_packed( + f_kv_cache_append, + args=[k_cache, squeeze(k, axis=0)], + inplace_indices=[0], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + v_cache = nn.emit( + relax.op.call_inplace_packed( + f_kv_cache_append, + args=[v_cache, squeeze(v, axis=0)], + inplace_indices=[0], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + batch_size, _, num_heads, head_size = k.struct_info.shape + kv_cache_shape = R.shape([kv_seq_len, num_heads, head_size]) + kv_states_shape = R.shape([batch_size, kv_seq_len, num_heads, head_size]) + k = nn.emit( + relax.call_pure_packed( + f_kv_cache_view, + args=[k_cache, kv_cache_shape], + sinfo_args=[R.Tensor(kv_cache_shape, k.struct_info.dtype)], + ) + ) + v = nn.emit( + relax.call_pure_packed( + f_kv_cache_view, + args=[v_cache, kv_cache_shape], + sinfo_args=[R.Tensor(kv_cache_shape, v.struct_info.dtype)], + ) + ) + k = nn.emit(reshape(k, kv_states_shape)) + v = nn.emit(reshape(v, kv_states_shape)) + past_key_value = (k_cache, v_cache) + else: + past_key_value = (None, None) + + q = nn.emit(permute_dims(q, [0, 2, 1, 3])) + k = nn.emit(permute_dims(k, [0, 2, 1, 3])) + v = nn.emit(permute_dims(v, [0, 2, 1, 3])) + + # Calculate QK + attn_weights = nn.emit( + matmul(q, permute_dims(k, [0, 1, 3, 2])) + / relax.const( + math.sqrt(self.head_dim), + q.struct_info.dtype, + ) + ) + # Apply attention mask + attn_weights = nn.emit( + maximum( + attn_weights, + relax.const( + tvm.tir.min_value(attn_weights.struct_info.dtype).value, + attn_weights.struct_info.dtype, + ), + ) + ) + attn_weights = nn.emit(minimum(attn_weights, attention_mask)) + # Calculate Softmax(QK) + if attn_weights.struct_info.dtype != "float32": + attn_weights = astype(attn_weights, "float32") + attn_weights = nn.emit(softmax(attn_weights, axis=-1)) + if attn_weights.struct_info.dtype != q.struct_info.dtype: + attn_weights = astype(attn_weights, q.struct_info.dtype) + # Calculate Softmax(QK)V + attn_output = 
nn.emit(matmul(attn_weights, v)) + # Apply output projection + attn_output = self.dense( + reshape( + permute_dims(attn_output, [0, 2, 1, 3]), + (batch_size, seq_len, self.hidden_size), + ) + ) + return attn_output, past_key_value + + +class GPTNeoXMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + dtype: str, + out_dtype: Optional[str], + ): + super().__init__() + if out_dtype is None: + out_dtype = dtype + self.dense_h_to_4h = Linear( + hidden_size, + intermediate_size, + dtype=dtype, + out_dtype=out_dtype, + ) + self.dense_4h_to_h = Linear( + intermediate_size, + hidden_size, + dtype=dtype, + out_dtype=out_dtype, + ) + self.dtype = dtype + + def forward(self, hidden_states): + if hidden_states.struct_info.dtype != self.dtype: + hidden_states = nn.emit(astype(hidden_states, self.dtype)) + hidden_states = self.dense_h_to_4h(hidden_states) + hidden_states = nn.emit(gelu(hidden_states)) + if hidden_states.struct_info.dtype != self.dtype: + hidden_states = nn.emit(astype(hidden_states, self.dtype)) + hidden_states = self.dense_4h_to_h(hidden_states) + if hidden_states.struct_info.dtype != self.dtype: + hidden_states = nn.emit(astype(hidden_states, self.dtype)) + return hidden_states + + +class GPTNeoXLayer(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + layer_norm_eps: float, + num_heads: int, + use_parallel_residual: bool, + rotary_embedding: RotaryEmbedding, + dtype: str, + ffn_out_dtype: Optional[str], + ): + self.input_layernorm = LayerNorm( + hidden_size, + eps=layer_norm_eps, + dtype=dtype, + ) + self.post_attention_layernorm = LayerNorm( + hidden_size, + eps=layer_norm_eps, + dtype=dtype, + ) + self.attention = GPTNeoXAttention( + hidden_size, + num_heads=num_heads, + rotary_embedding=rotary_embedding, + dtype=dtype, + ) + self.mlp = GPTNeoXMLP( + hidden_size, + intermediate_size=intermediate_size, + dtype=dtype, + out_dtype=ffn_out_dtype, + ) + self.use_parallel_residual = use_parallel_residual + self.dtype = dtype + + def forward( + self, + hidden_states, + all_seq_len_shape: relax.Expr, + past_key_value: Optional[Tuple[relax.Expr]] = None, + attention_mask: Optional[relax.Expr] = None, + ): + attn_input = self.input_layernorm(hidden_states) + attn_output, present_key_value = self.attention( + attn_input, + all_seq_len_shape, + past_key_value, + attention_mask, + ) + if self.use_parallel_residual: + mlp_input = self.post_attention_layernorm(hidden_states) + mlp_output = self.mlp(mlp_input) + hidden_states = nn.emit(mlp_output + attn_output + hidden_states) + else: + attn_output = nn.emit(attn_output + hidden_states) + mlp_input = self.post_attention_layernorm(attn_output) + mlp_output = self.mlp(mlp_input) + hidden_states = nn.emit(astype(mlp_output, self.dtype) + attn_output) + return hidden_states, present_key_value + + +def _prepare_decoder_attention_mask(input_shape, src_len, dtype): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + if isinstance(input_shape[-1], tvm.tir.SizeVar) or input_shape[-1] > 1: + bsz, tgt_len = input_shape + + def min_max_triu_te(): + return te.compute( + (tgt_len, tgt_len), + lambda i, j: tvm.tir.Select( + j > i, tvm.tir.min_value(dtype), tvm.tir.max_value(dtype) + ), + name="make_diag_mask_te", + ) + + mask = nn.emit_te(min_max_triu_te) + diag_mask = nn.emit(broadcast_to(mask, (bsz, 1, tgt_len, tgt_len))) + if src_len == tgt_len: + return diag_mask + + def extend_te(x, tgt_len, src_len): + return te.compute( + (bsz, 1, tgt_len, src_len), + lambda 
b, _, i, j: te.if_then_else( + j < src_len - tgt_len, + tvm.tir.max_value(dtype), + x[b, _, i, j - (src_len - tgt_len)], + ), + name="concat_te", + ) + + return nn.emit_te(extend_te, diag_mask, tgt_len, src_len) + else: + # Get src_len from input parameters + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + bsz, tgt_len = input_shape + mask = relax.op.full( + (bsz, 1, tgt_len, src_len), + relax.const(tvm.tir.max_value(dtype).value, dtype), + dtype, + ) + return nn.emit(mask) + + +class GPTNeoXEmbedTokens(nn.Module): + def __init__(self, config: GPTNeoXConfig): + self.embed_in = Embedding( + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + dtype=config.dtype, + ) + + def forward(self, input_ids: relax.Expr): + return self.embed_in(input_ids) + + +class GPTNeoXEmbedTokensWrapper(nn.Module): + def __init__(self, config: GPTNeoXConfig): + # build a wrapper to ensure that the naming of the embed_in parameter is consistent + self.gpt_neox = GPTNeoXEmbedTokens(config) + + def forward(self, input_ids: relax.Expr): + return self.gpt_neox(input_ids) + + +class GPTNeoXModel(nn.Module): + def __init__( + self, + config: GPTNeoXConfig, + sep_embed: bool = False, + ): + rotary_embedding = RotaryEmbedding( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + position_embedding_base=config.rotary_emb_base, + max_sequence_length=config.max_sequence_length, + rotary_pct=config.rotary_pct, + dtype=config.dtype, + ) + + self.embed_in = None + if not sep_embed: + self.embed_in = Embedding( + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + dtype=config.dtype, + ) + + self.layers = ModuleList( + [ + GPTNeoXLayer( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + layer_norm_eps=config.layer_norm_eps, + num_heads=config.num_attention_heads, + rotary_embedding=rotary_embedding, + use_parallel_residual=config.use_parallel_residual, + dtype=config.dtype, + ffn_out_dtype=config.ffn_out_dtype, + ) + for _ in range(config.num_hidden_layers) + ] + ) + self.final_layer_norm = LayerNorm( + hidden_size=config.hidden_size, + eps=config.layer_norm_eps, + dtype=config.dtype, + ) + + def forward( + self, + inputs: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_values: Optional[Tuple[relax.Expr, relax.Expr]], + ): + # embed positions + hidden_states = self.embed_in(inputs) if self.embed_in else inputs + + batch_size, seq_length, _ = hidden_states.struct_info.shape + seq_length_with_past = all_seq_len_shape.struct_info.values[0] + attention_mask = _prepare_decoder_attention_mask( + (batch_size, seq_length), + seq_length_with_past, + dtype=hidden_states.struct_info.dtype, + ) + present_kv_cache = [] + for i, layer in enumerate(self.layers): + past_key_value = ( + (past_key_values[i * 2], past_key_values[i * 2 + 1]) + if past_key_values is not None + else None + ) + hidden_states, (present_k_cache, present_v_cache) = layer( + hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + all_seq_len_shape=all_seq_len_shape, + ) + present_kv_cache.append(present_k_cache) + present_kv_cache.append(present_v_cache) + hidden_states = self.final_layer_norm(hidden_states) + return hidden_states, present_kv_cache + + +class GPTNeoXForCausalLM(nn.Module): + def __init__( + self, + config: GPTNeoXConfig, + sep_embed: bool = False, + ): + self.gpt_neox = GPTNeoXModel(config, sep_embed) + self.embed_out = Linear( + in_features=config.hidden_size, + out_features=config.vocab_size, + bias=False, + 
dtype="float32", + ) + + def forward( + self, + inputs: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_values: Optional[List[relax.Expr]], + ): + hidden_states, key_value_cache = self.gpt_neox( + inputs=inputs, + all_seq_len_shape=all_seq_len_shape, + past_key_values=past_key_values, + ) + + def _slice(x: te.Tensor): + _, seq_len, hidden_dim = x.shape + return te.compute( + shape=(1, 1, hidden_dim), + fcompute=lambda i, _, k: x[i, seq_len - 1, k], + name="slice", + ) + + hidden_states = nn.emit_te( + _slice, + hidden_states, + primfunc_name_hint="slice", + ) + hidden_states = astype(hidden_states, "float32") + logits = self.embed_out(hidden_states) + return logits, key_value_cache + + +def get_param_quant_kind(name: str, param_info: relax.TensorStructInfo) -> ParamQuantKind: + if "embed_in.weight" in name: + return ParamQuantKind.embedding_table + elif "embed_out.weight" in name: + return ParamQuantKind.final_fc_weight + elif param_info.ndim == 2 and name.endswith(".weight"): + return ParamQuantKind.linear_weight + else: + return ParamQuantKind.others + + +def create_embed_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: GPTNeoXConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "embed" + + bsz = 1 + seq_len = tvm.tir.SizeVar("m", "int64") + with bb.function(func_name): + model = GPTNeoXEmbedTokensWrapper(config) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") + with bb.dataflow(): + inputs_embeds = model(input_ids) + params = [input_ids] + model.parameters() + gv = bb.emit_output(inputs_embeds) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var("embed") + bb.update_func(gv, mod[gv].with_attr("num_input", 1)) + + +def create_encoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: GPTNeoXConfig, + quant_scheme: QuantizationScheme, + sep_embed: bool = False, +) -> None: + func_name = "prefill_with_embed" if sep_embed else "prefill" + + batch_size = tvm.tir.IntImm("int64", 1) + seq_len = tvm.tir.SizeVar("n", "int64") + all_seq_len = tvm.tir.SizeVar("m", "int64") + hidden_size = config.hidden_size + with bb.function(func_name): + model = GPTNeoXForCausalLM(config, sep_embed) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + inputs = ( + nn.Placeholder( + (batch_size, seq_len, hidden_size), + dtype=config.dtype, + name="input_embeds", + ) + if sep_embed + else nn.Placeholder((batch_size, seq_len), dtype="int32", name="input_ids") + ) + all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo( + [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] + ), + ) + with bb.dataflow(): + logits, key_value_cache = model( + inputs=inputs, + all_seq_len_shape=all_seq_len_shape, + past_key_values=past_key_values, + ) + params = [ + inputs, + all_seq_len_shape, + past_key_values, + ] + model.parameters() + gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) + bb.emit_func_output(gv, params) + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 3)) + + +def create_decoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: GPTNeoXConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "decode" + + batch_size = tvm.tir.IntImm("int64", 1) + 
seq_len = tvm.tir.IntImm("int64", 1) + all_seq_len = tvm.tir.SizeVar("m", "int64") + with bb.function(func_name): + model = GPTNeoXForCausalLM(config) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids = nn.Placeholder((batch_size, seq_len), dtype="int32", name="input_ids") + all_seq_len_shape = relax.Var( + "all_seq_len", + relax.ShapeStructInfo((all_seq_len,)), + ) + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo( + [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] + ), + ) + with bb.dataflow(): + logits, key_value_cache = model( + inputs=input_ids, + all_seq_len_shape=all_seq_len_shape, + past_key_values=past_key_values, + ) + params = [ + input_ids, + all_seq_len_shape, + past_key_values, + ] + model.parameters() + gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) + bb.emit_func_output(gv, params) + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 3)) + + +def create_kv_cache_func( + bb: relax.BlockBuilder, + config: GPTNeoXConfig, +) -> None: + init_shape = relax.ShapeExpr( + ( + config.max_sequence_length, + config.num_attention_heads, + config.hidden_size // config.num_attention_heads, + ) + ) + with bb.function("create_kv_cache", []): + with bb.dataflow(): + zeros = bb.emit(relax.op.zeros(init_shape, config.dtype)) + caches = [] + f_kv_cache_create = relax.extern("vm.builtin.attention_kv_cache_create") + for _ in range(config.num_hidden_layers * 2): + caches.append( + bb.emit( + relax.call_pure_packed( + f_kv_cache_create, + args=[zeros, init_shape, relax.PrimValue(0)], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + ) + gv = bb.emit_output(caches) + bb.emit_func_output(gv) + + +def create_softmax_func(bb: relax.BlockBuilder, config: GPTNeoXConfig) -> None: + with bb.function("softmax_with_temperature"): + logits = nn.Placeholder((1, 1, config.vocab_size), dtype="float32", name="logits") + temperature = nn.Placeholder((), dtype="float32", name="temperature") + with bb.dataflow(): + div = bb.emit(relax.op.divide(logits, temperature)) + softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) + gv = bb.emit_output(softmax) + bb.emit_func_output(gv, [logits, temperature]) + + +def get_model( + args: argparse.Namespace, + hf_config, +): + model = args.model + dtype = args.quantization.model_dtype + ffn_out_dtype = "float32" + sep_embed = args.sep_embed + + if model.startswith("dolly-"): + stop_tokens = [2] + ffn_out_dtype = "float16" + elif model.startswith("stablelm-"): + stop_tokens = [50278, 50279, 50277, 1, 0] + ffn_out_dtype = "float16" + elif model.lower().startswith("stablecode-"): + stop_tokens = [0] + elif model.lower().startswith("redpajama-"): + stop_tokens = [0] + else: + raise ValueError(f"Unsupported model {model}") + + config = GPTNeoXConfig( + **hf_config, + max_sequence_length=args.max_seq_len if args.max_seq_len != -1 else 2048, + dtype=dtype, + ffn_out_dtype=ffn_out_dtype, + ) + + param_manager = ParamManager() + bb = relax.BlockBuilder() + if sep_embed: + create_embed_func(bb, param_manager, config, args.quantization) + create_encoding_func(bb, param_manager, config, args.quantization, sep_embed) + create_decoding_func(bb, param_manager, config, args.quantization) + create_kv_cache_func(bb, config) + create_softmax_func(bb, config) + create_metadata_func( + bb, + model_name=model, + max_window_size=config.max_sequence_length, + stop_tokens=stop_tokens, + add_prefix_space=False, + 
prefill_chunk_size=args.prefill_chunk_size, + ) + mod = bb.get() + + tir_bound_map = dict() + tir_bound_map["n"] = ( + args.prefill_chunk_size if args.prefill_chunk_size > 0 else config.max_sequence_length + ) + tir_bound_map["m"] = config.max_sequence_length + for gv in mod.functions: + func = mod[gv] + if isinstance(func, relax.Function): + mod[gv] = func.with_attr("tir_var_upper_bound", tir_bound_map) + + if args.build_model_only: + return mod, param_manager, None, config + + def f_convert_pname_fwd(pname: str) -> List[str]: + return [pname] + + def f_convert_param_bkwd(torch_pname: str, torch_param): + # torch_param: numpy.ndarray + if "layernorm" in torch_pname or "layer_norm" in torch_pname or "embed_out" in torch_pname: + return [(torch_pname, torch_param.astype("float32"))] + elif ".dense_h_to_4h.bias" in torch_pname or ".dense_4h_to_h.bias" in torch_pname: + return [(torch_pname, torch_param.astype(ffn_out_dtype))] + else: + return [(torch_pname, torch_param.astype(dtype))] + + param_manager.set_param_loading_func( + args.model_path, args.use_safetensors, f_convert_pname_fwd, f_convert_param_bkwd + ) + return mod, param_manager, [None] * len(param_manager.param_names), config diff --git a/mlc_llm/relax_model/gptj.py b/mlc_llm/relax_model/gptj.py new file mode 100644 index 0000000..9096583 --- /dev/null +++ b/mlc_llm/relax_model/gptj.py @@ -0,0 +1,688 @@ +import math +from dataclasses import dataclass +from typing import Any, List, Optional, Tuple, Union + +import tvm +from tvm import relax, te +from tvm.relax.op import ( + astype, + broadcast_to, + full, + matmul, + maximum, + minimum, + permute_dims, + reshape, + squeeze, + triu, +) +from tvm.relax.op.nn import gelu, softmax +from tvm.relax.testing import nn +from tvm.script import relax as R + +from ..quantization import ParamQuantKind, QuantizationScheme +from .commons import create_metadata_func +from .gpt_neox import create_kv_cache_func +from .modules import Embedding, LayerNorm, Linear, ModuleList, RotaryEmbedding +from .param_manager import ParamManager + + +def _min_value(dtype) -> relax.Expr: + v = tvm.tir.min_value(dtype).value + if dtype == "float16": + v = -55504.0 + return relax.const(v, dtype) + + +def _max_value(dtype) -> relax.Expr: + v = tvm.tir.max_value(dtype).value + if dtype == "float16": + v = 55504.0 + return relax.const(v, dtype) + + +@dataclass +class GPTJConfig: # pylint: disable=too-many-instance-attributes + def __init__( + self, + vocab_size, + n_embd, + n_inner, + n_head, + n_layer, + bos_token_id, + eos_token_id, + rotary_dim, + tie_word_embeddings, + dtype="float32", + layer_norm_eps=1e-5, + max_sequence_length=2048, + rotary_emb_base=10000, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = n_embd + self.intermediate_size = n_inner if n_inner is not None else 4 * n_embd + self.num_attention_heads = n_head + self.num_hidden_layers = n_layer + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.rotary_dim = rotary_dim + self.tie_word_embeddings = tie_word_embeddings + self.dtype = dtype + self.layer_norm_eps = layer_norm_eps + self.max_sequence_length = max_sequence_length + self.rotary_emb_base = rotary_emb_base + self.kwargs = kwargs + + +class GPTJMLP(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, dtype: str): + super().__init__() + self.fc_in = Linear(hidden_size, intermediate_size, dtype, bias=True) + self.fc_out = Linear(intermediate_size, hidden_size, dtype, bias=True) + self.dtype = dtype + + def forward(self, 
hidden_states): + if hidden_states.struct_info.dtype != self.dtype: + hidden_states = nn.emit(astype(hidden_states, self.dtype)) + hidden_states = self.fc_in(hidden_states) + hidden_states = nn.emit(gelu(hidden_states)) + if hidden_states.struct_info.dtype != self.dtype: + hidden_states = nn.emit(astype(hidden_states, self.dtype)) + hidden_states = self.fc_out(hidden_states) + return nn.emit(hidden_states) + + +class GPTJAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + rotary_embedding: RotaryEmbedding, + dtype: str, + ): + if hidden_size % num_heads != 0: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {hidden_size}" + f" and `num_heads`: {num_heads})." + ) + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.rotary_embedding = rotary_embedding + self.q_proj = Linear(hidden_size, hidden_size, dtype, bias=False) + self.k_proj = Linear(hidden_size, hidden_size, dtype, bias=False) + self.v_proj = Linear(hidden_size, hidden_size, dtype, bias=False) + self.out_proj = Linear(hidden_size, hidden_size, dtype, bias=False) + self.dtype = dtype + + def forward( + self, + hidden_states: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_value: Optional[Tuple[relax.Expr, relax.Expr]] = None, + attention_mask: Optional[relax.Expr] = None, + ) -> Tuple[relax.Expr, Union[Tuple[None, None], Tuple[relax.Expr, relax.Expr]]]: + # hidden_states: [batch_size, seq_len, hidden_size] + if hidden_states.struct_info.dtype != self.dtype: + hidden_states = nn.emit(astype(hidden_states, self.dtype)) + batch_size, seq_len, _ = hidden_states.struct_info.shape + kv_seq_len = all_seq_len_shape.struct_info.values[0] + + def _project(proj): + return nn.emit( + reshape( + proj(hidden_states), + (batch_size, seq_len, self.num_heads, self.head_dim), + ) + ) + + # q/k/v states: [batch_size, seq_len, num_attention_heads, head_size] + q, k, v = ( + _project(self.q_proj), + _project(self.k_proj), + _project(self.v_proj), + ) + q, k = self.rotary_embedding(q, k, kv_seq_len - seq_len) + + if past_key_value is not None: + f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append") + f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view") + k_cache, v_cache = past_key_value + k_cache = nn.emit( + relax.op.call_inplace_packed( + f_kv_cache_append, + args=[k_cache, squeeze(k, axis=0)], + inplace_indices=[0], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + v_cache = nn.emit( + relax.op.call_inplace_packed( + f_kv_cache_append, + args=[v_cache, squeeze(v, axis=0)], + inplace_indices=[0], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + batch_size, _, num_heads, head_size = k.struct_info.shape + kv_cache_shape = R.shape([kv_seq_len, num_heads, head_size]) + kv_states_shape = R.shape([batch_size, kv_seq_len, num_heads, head_size]) + k = nn.emit( + relax.call_pure_packed( + f_kv_cache_view, + args=[k_cache, kv_cache_shape], + sinfo_args=[R.Tensor(kv_cache_shape, k.struct_info.dtype)], + ) + ) + v = nn.emit( + relax.call_pure_packed( + f_kv_cache_view, + args=[v_cache, kv_cache_shape], + sinfo_args=[R.Tensor(kv_cache_shape, v.struct_info.dtype)], + ) + ) + k = nn.emit(reshape(k, kv_states_shape)) + v = nn.emit(reshape(v, kv_states_shape)) + past_key_value = (k_cache, v_cache) + else: + past_key_value = (None, None) + + q = nn.emit(permute_dims(q, [0, 2, 1, 3])) + k = nn.emit(permute_dims(k, [0, 2, 1, 
3])) + v = nn.emit(permute_dims(v, [0, 2, 1, 3])) + + # Calculate QK + attn_weights = nn.emit( + matmul(q, permute_dims(k, [0, 1, 3, 2])) + / relax.const( + math.sqrt(self.head_dim), + q.struct_info.dtype, + ) + ) + # Apply attention mask + attn_weights = nn.emit(attn_weights + attention_mask) + attn_weights = nn.emit( + minimum( + maximum( + attn_weights, + _min_value(attn_weights.struct_info.dtype), + ), + _max_value(attn_weights.struct_info.dtype), + ) + ) + # Calculate Softmax(QK) + if attn_weights.struct_info.dtype != "float32": + attn_weights = astype(attn_weights, "float32") + attn_weights = nn.emit(softmax(attn_weights, axis=-1)) + if attn_weights.struct_info.dtype != q.struct_info.dtype: + attn_weights = astype(attn_weights, q.struct_info.dtype) + # Calculate Softmax(QK)V + attn_output = nn.emit(matmul(attn_weights, v)) + # Apply output projection + attn_output = self.out_proj( + reshape( + permute_dims(attn_output, [0, 2, 1, 3]), + (batch_size, seq_len, self.hidden_size), + ) + ) + return attn_output, past_key_value + + +class GPTJLayer(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + layer_norm_eps: float, + num_heads: int, + rotary_embedding: RotaryEmbedding, + dtype: str, + ): + self.ln_1 = LayerNorm( + hidden_size, + eps=layer_norm_eps, + dtype=dtype, + ) + self.attn = GPTJAttention( + hidden_size, + num_heads=num_heads, + rotary_embedding=rotary_embedding, + dtype=dtype, + ) + self.mlp = GPTJMLP( + hidden_size, + intermediate_size=intermediate_size, + dtype=dtype, + ) + self.dtype = dtype + + def forward( + self, + hidden_states, + all_seq_len_shape: relax.Expr, + past_key_value: Optional[Tuple[relax.Expr]] = None, + attention_mask: Optional[relax.Expr] = None, + ): + normalized_input = self.ln_1(hidden_states) + attn_output, present_key_value = self.attn( + normalized_input, + all_seq_len_shape, + past_key_value, + attention_mask, + ) + mlp_output = self.mlp(normalized_input) + hidden_states = nn.emit(mlp_output + attn_output + hidden_states) + return hidden_states, present_key_value + + +def _prepare_decoder_attention_mask(input_shape, src_len, dtype): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + if isinstance(input_shape[-1], tvm.tir.SizeVar) or input_shape[-1] > 1: + bsz, tgt_len = input_shape + mask = full((tgt_len, tgt_len), _min_value(dtype)) + mask = triu(mask, k=1) + diag_mask = nn.emit(broadcast_to(mask, (bsz, 1, tgt_len, tgt_len))) + if src_len == tgt_len: + return diag_mask + + def extend_te(x, tgt_len, src_len): + return te.compute( + (bsz, 1, tgt_len, src_len), + lambda b, _, i, j: te.if_then_else( + j < src_len - tgt_len, 0, x[b, _, i, j - (src_len - tgt_len)] + ), + name="concat_te", + ) + + return nn.emit_te(extend_te, diag_mask, tgt_len, src_len) + else: + # Get src_len from input parameters + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + bsz, tgt_len = input_shape + mask = relax.op.zeros((bsz, 1, tgt_len, src_len), dtype) + return nn.emit(mask) + + +class GPTJEmbedTokens(nn.Module): + def __init__(self, config: GPTJConfig): + self.wte = Embedding( + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + dtype=config.dtype, + ) + + def forward(self, input_ids: relax.Expr): + return self.wte(input_ids) + + +class GPTJEmbedTokensWrapper(nn.Module): + def __init__(self, config: GPTJConfig): + # build a wrapper to ensure that the naming of the embed_in parameter is consistent + self.gptj = GPTJEmbedTokens(config) + + def forward(self, input_ids: relax.Expr): + 
return self.gptj(input_ids) + + +class GPTJModel(nn.Module): + def __init__( + self, + config: GPTJConfig, + sep_embed: bool = False, + ): + rotary_embedding = RotaryEmbedding( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + position_embedding_base=config.rotary_emb_base, + max_sequence_length=config.max_sequence_length, + rotary_dim=config.rotary_dim, + swizzle_style="gptj", + dtype=config.dtype, + ) + self.wte = None + if not sep_embed: + self.wte = Embedding( + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + dtype=config.dtype, + ) + self.h = ModuleList( + [ + GPTJLayer( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + layer_norm_eps=config.layer_norm_eps, + num_heads=config.num_attention_heads, + rotary_embedding=rotary_embedding, + dtype=config.dtype, + ) + for _ in range(config.num_hidden_layers) + ] + ) + self.ln_f = LayerNorm( + hidden_size=config.hidden_size, + eps=config.layer_norm_eps, + dtype=config.dtype, + ) + + def forward( + self, + inputs: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_values: Optional[Tuple[relax.Expr, relax.Expr]], + ): + batch_size, seq_length = inputs.struct_info.shape + seq_length_with_past = all_seq_len_shape.struct_info.values[0] + # embed positions + hidden_states = self.wte(inputs) if self.wte is not None else inputs + attention_mask = _prepare_decoder_attention_mask( + (batch_size, seq_length), + seq_length_with_past, + dtype=hidden_states.struct_info.dtype, + ) + present_kv_cache = [] + for i, layer in enumerate(self.h): + past_key_value = ( + (past_key_values[i * 2], past_key_values[i * 2 + 1]) + if past_key_values is not None + else None + ) + hidden_states, (present_k_cache, present_v_cache) = layer( + hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + all_seq_len_shape=all_seq_len_shape, + ) + present_kv_cache.append(present_k_cache) + present_kv_cache.append(present_v_cache) + hidden_states = self.ln_f(hidden_states) + return hidden_states, present_kv_cache + + +class GPTJForCausalLM(nn.Module): + def __init__( + self, + config: GPTJConfig, + sep_embed: bool = False, + ): + self.transformer = GPTJModel(config, sep_embed) + self.lm_head = Linear( + in_features=config.hidden_size, + out_features=config.vocab_size, + bias=True, + dtype=config.dtype, + ) + self.dtype = config.dtype + + def forward( + self, + inputs: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_values: Optional[List[relax.Expr]], + ): + hidden_states, key_value_cache = self.transformer( + inputs=inputs, + all_seq_len_shape=all_seq_len_shape, + past_key_values=past_key_values, + ) + if hidden_states.struct_info.dtype != self.dtype: + hidden_states = nn.emit(astype(hidden_states, self.dtype)) + + def _slice(x: te.Tensor): + _, seq_len, hidden_dim = x.shape + return te.compute( + shape=(1, 1, hidden_dim), + fcompute=lambda i, _, k: x[i, seq_len - 1, k], + name="slice", + ) + + hidden_states = nn.emit_te( + _slice, + hidden_states, + primfunc_name_hint="slice", + ) + logits = self.lm_head(hidden_states) + if logits.struct_info.dtype != "float32": + logits = nn.emit(astype(logits, "float32")) + + return logits, key_value_cache + + +def check_parameters(param_dict, param_list): + relax_shape_to_list = lambda _: [s.value for s in _.values] + shape_dict_0 = {k: relax_shape_to_list(v.struct_info.shape) for k, v in param_dict.items()} + shape_dict_1 = {k: list(v.shape) for (k, v) in param_list} + assert len(shape_dict_0) == len(shape_dict_1) + 
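# Illustrative sketch (NumPy only, not part of the model code): GPTJLayer.forward
# above uses a *parallel* residual -- the attention block and the MLP both read the
# same ln_1 output and their results are summed with the layer input, i.e.
#   out = x + attn(ln(x)) + mlp(ln(x))
# The callables below are trivial stand-ins, used only to show the wiring.
import numpy as np

def parallel_residual_block(x, ln, attn, mlp):
    normalized = ln(x)
    return x + attn(normalized) + mlp(normalized)

x = np.random.randn(1, 4, 8).astype("float32")
out = parallel_residual_block(x, ln=lambda t: t, attn=lambda t: 0.1 * t, mlp=lambda t: 0.2 * t)
assert out.shape == x.shape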
for k, v in shape_dict_0.items(): + assert k in shape_dict_1, "{}".format(k) + assert v == shape_dict_1[k], "key={}, shape_0={}, shape_1={}".format(k, v, shape_dict_1[k]) + + +def get_param_quant_kind(name: str, param_info: relax.TensorStructInfo) -> ParamQuantKind: + if "wte.weight" in name: + return ParamQuantKind.embedding_table + elif "lm_head.weight" in name: + return ParamQuantKind.final_fc_weight + elif param_info.ndim == 2 and name.endswith(".weight"): + return ParamQuantKind.linear_weight + else: + return ParamQuantKind.others + + +def create_embed_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: GPTJConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "embed" + + bsz = 1 + seq_len = tvm.tir.SizeVar("m", "int64") + with bb.function(func_name): + model = GPTJEmbedTokensWrapper(config) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") + with bb.dataflow(): + inputs_embeds = model(input_ids) + params = [input_ids] + model.parameters() + gv = bb.emit_output(inputs_embeds) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var("embed") + bb.update_func(gv, mod[gv].with_attr("num_input", 1)) + + +def create_encoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: GPTJConfig, + quant_scheme: QuantizationScheme, + sep_embed: bool = False, +) -> None: + func_name = "prefill_with_embed" if sep_embed else "prefill" + + batch_size = tvm.tir.IntImm("int64", 1) + seq_len = tvm.tir.SizeVar("n", "int64") + all_seq_len = tvm.tir.SizeVar("m", "int64") + hidden_size = config.hidden_size + with bb.function(func_name): + model = GPTJForCausalLM(config, sep_embed) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + inputs = ( + nn.Placeholder( + (batch_size, seq_len, hidden_size), + dtype=config.dtype, + name="input_embeds", + ) + if sep_embed + else nn.Placeholder((batch_size, seq_len), dtype="int32", name="input_ids") + ) + all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo( + [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] + ), + ) + with bb.dataflow(): + logits, key_value_cache = model( + inputs=inputs, + all_seq_len_shape=all_seq_len_shape, + past_key_values=past_key_values, + ) + params = [ + inputs, + all_seq_len_shape, + past_key_values, + ] + model.parameters() + gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) + bb.emit_func_output(gv, params) + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 3)) + + +def create_decoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: GPTJConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "decode" + + batch_size = tvm.tir.IntImm("int64", 1) + seq_len = tvm.tir.IntImm("int64", 1) + all_seq_len = tvm.tir.SizeVar("m", "int64") + with bb.function(func_name): + model = GPTJForCausalLM(config) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids = nn.Placeholder((batch_size, seq_len), dtype="int32", name="input_ids") + all_seq_len_shape = relax.Var( + "all_seq_len", + relax.ShapeStructInfo((all_seq_len,)), + ) + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo( + [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 
2)] + ), + ) + with bb.dataflow(): + logits, key_value_cache = model( + inputs=input_ids, + all_seq_len_shape=all_seq_len_shape, + past_key_values=past_key_values, + ) + params = [ + input_ids, + all_seq_len_shape, + past_key_values, + ] + model.parameters() + gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) + bb.emit_func_output(gv, params) + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 3)) + + +def create_softmax_func(bb: relax.BlockBuilder, config: GPTJConfig) -> None: + with bb.function("softmax_with_temperature"): + logits = nn.Placeholder((1, 1, config.vocab_size), dtype="float32", name="logits") + temperature = nn.Placeholder((), dtype="float32", name="temperature") + with bb.dataflow(): + div = bb.emit(relax.op.divide(logits, temperature)) + softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) + gv = bb.emit_output(softmax) + bb.emit_func_output(gv, [logits, temperature]) + + +def get_model(args, hf_config): + model_name = args.model + dtype = args.quantization.model_dtype + max_seq_len = args.max_seq_len + sep_embed = args.sep_embed + + if model_name.startswith("gpt-j-"): + stop_tokens = [50256] + elif model_name.startswith("moss-"): + stop_tokens = [106068] + + config = GPTJConfig(**hf_config, dtype=dtype) + if max_seq_len != -1: + config.max_sequence_length = max_seq_len + + param_manager = ParamManager() + bb = relax.BlockBuilder() + if sep_embed: + create_embed_func(bb, param_manager, config, args.quantization) + create_encoding_func(bb, param_manager, config, args.quantization, sep_embed) + create_decoding_func(bb, param_manager, config, args.quantization) + create_kv_cache_func(bb, config) + create_softmax_func(bb, config) + create_metadata_func( + bb, + model_name=model_name, + max_window_size=config.max_sequence_length, + stop_tokens=stop_tokens, + add_prefix_space=True, + prefill_chunk_size=args.prefill_chunk_size, + ) + mod = bb.get() + + tir_bound_map = dict() + tir_bound_map["n"] = ( + args.prefill_chunk_size if args.prefill_chunk_size > 0 else config.max_sequence_length + ) + tir_bound_map["m"] = config.max_sequence_length + for gv in mod.functions: + func = mod[gv] + if isinstance(func, relax.Function): + mod[gv] = func.with_attr("tir_var_upper_bound", tir_bound_map) + + if args.build_model_only: + return mod, param_manager, None, config + + def f_convert_pname_fwd(pname: str) -> List[str]: + import re + + str_pattern = re.compile(r"(q|k|v)_proj") + if re.search(str_pattern, pname) is not None: + return [str_pattern.sub("qkv_proj", pname)] + else: + return [pname] + + hidden_size = config.hidden_size + + def f_convert_param_bkwd(torch_pname: str, torch_param) -> Optional[List[Tuple[str, Any]]]: + # torch_param: numpy.ndarray + if torch_pname.endswith("qkv_proj.weight"): + assert torch_param.ndim == 2 + mp_num = 4 + torch_param = torch_param.astype(dtype).reshape(mp_num, 3, -1, hidden_size) + q_weight = torch_param[:, 0, :, :].reshape(hidden_size, hidden_size) + k_weight = torch_param[:, 2, :, :].reshape(hidden_size, hidden_size) + v_weight = torch_param[:, 1, :, :].reshape(hidden_size, hidden_size) + return [ + (torch_pname.replace("qkv_proj", "q_proj"), q_weight), + (torch_pname.replace("qkv_proj", "k_proj"), k_weight), + (torch_pname.replace("qkv_proj", "v_proj"), v_weight), + ] + if "ln_1" in torch_pname or "ln_f" in torch_pname: + return [(torch_pname, torch_param.astype("float32"))] + else: + return [(torch_pname, torch_param.astype(dtype))] + + param_manager.set_param_loading_func( + 
args.model_path, args.use_safetensors, f_convert_pname_fwd, f_convert_param_bkwd + ) + return mod, param_manager, [None] * len(param_manager.param_names), config diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py new file mode 100644 index 0000000..06272e3 --- /dev/null +++ b/mlc_llm/relax_model/llama.py @@ -0,0 +1,1507 @@ +import math +from dataclasses import dataclass +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +import tvm +from tvm import relax, te, tir +from tvm.relax.op import ccl +from tvm.relax.testing import nn +from tvm.script import relax as R + +from ..quantization import ParamQuantKind, QuantizationScheme +from .commons import create_metadata_func +from .modules import ModuleList +from .param_manager import ParamManager + + +@dataclass +class LlamaConfig: + def __init__( + self, + dtype="float32", + max_sequence_length=2048, + vocab_size=32000, # some models like WizardMath can have 32001 + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + initializer_range=0.02, + rms_norm_eps=1e-6, + pad_token_id=-1, + bos_token_id=0, + eos_token_id=1, + tie_word_embeddings=False, + position_embedding_base=10000, + combine_matmul=True, + build_model_only=False, + num_shards=1, + sliding_window=None, + target_kind=None, + **kwargs, + ): + self.dtype = dtype + self.max_sequence_length = max_sequence_length + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.tie_word_embeddings = tie_word_embeddings + self.position_embedding_base = position_embedding_base + self.combine_matmul = combine_matmul + self.sliding_window = sliding_window + self.target_kind = target_kind + + if build_model_only and num_shards > 1: + self.num_shards = num_shards + else: + self.num_shards = 1 + self.kwargs = kwargs + + def get_num_key_value_heads(self): + if self.num_key_value_heads is None: + return self.num_attention_heads + + return self.num_key_value_heads + + +class Linear(nn.Module): + def __init__(self, in_features, out_features, dtype: str, bias=True): + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter((out_features, in_features), dtype=dtype, name="linear_weight") + if bias: + self.bias = nn.Parameter((out_features,), dtype=dtype, name="linear_bias") + else: + self.bias = None + + def forward(self, input: relax.Expr) -> relax.Var: + return nn.emit(relax.op.linear(input, self.weight, self.bias)) + + +class Embedding(nn.Module): + def __init__(self, num_embeddings, embedding_dim, dtype: str): + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.weight = nn.Parameter( + (num_embeddings, embedding_dim), dtype=dtype, name="embedding_weight" + ) + + def forward(self, x: relax.Expr) -> relax.Var: + from tvm.relax.op import reshape, take + + ndim = x.struct_info.ndim + if ndim == 1: + return nn.emit(take(self.weight, x, axis=0)) + else: + x_shape = x.struct_info.shape.values + emb_size = self.weight.struct_info.shape.values[-1] + x = nn.emit(reshape(x, shape=[-1])) + embedding = 
nn.emit(take(self.weight, x, axis=0)) + return nn.emit(reshape(embedding, [*x_shape, emb_size])) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, dtype, eps=1e-6): + self.weight = nn.Parameter((hidden_size,), dtype=dtype, name="rms_norm_weight") + self.variance_epsilon = tvm.tir.const(eps, dtype) + + def forward(self, hidden_states): + from tvm import te, tir + + def f_rms_norm(x, weight): + is_float32 = x.dtype == "float32" + + def f_square(x): + return tir.Cast("float32", x) * tir.Cast("float32", x) if not is_float32 else x * x + + def f_mul_cast(x, y): + value = x * y + if not is_float32: + value = tir.Cast(x.dtype, value) + return value + + def f_div_cast_2d(i, k): + x_val = x[i, k] + if not is_float32: + x_val = tir.Cast("float32", x_val) + return x_val / tir.sqrt(square_sum[i] / x.shape[1] + self.variance_epsilon) + + def f_div_cast_3d(bsz, i, k): + x_val = x[bsz, i, k] + if not is_float32: + x_val = tir.Cast("float32", x_val) + return x_val / tir.sqrt(square_sum[bsz, i] / x.shape[2] + self.variance_epsilon) + + k = te.reduce_axis((0, x.shape[-1]), name="k") + + if len(x.shape) == 2: + square_sum = te.compute( + (x.shape[0],), + lambda i: te.sum(f_square(x[i, k]), axis=k), + name=x.op.name + "red_temp", + ) + + return te.compute( + x.shape, + lambda i, k: f_mul_cast(weight(k), f_div_cast_2d(i, k)), + name="rms_norm", + ) + else: + square_sum = te.compute( + (x.shape[0], x.shape[1]), + lambda bsz, i: te.sum(f_square(x[bsz, i, k]), axis=k), + name=x.op.name + "red_temp", + ) + + return te.compute( + x.shape, + lambda bsz, i, k: f_mul_cast(weight(k), f_div_cast_3d(bsz, i, k)), + name="rms_norm", + ) + + return nn.emit_te(f_rms_norm, hidden_states, self.weight, primfunc_name_hint="rms_norm") + + +class LlamaMLP(nn.Module): + def __init__(self, config: LlamaConfig): + self.combine_matmul = config.combine_matmul + self.num_shards = config.num_shards + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size // self.num_shards + dtype = config.dtype + if self.combine_matmul: + self.gate_up_proj = Linear(hidden_size, 2 * intermediate_size, dtype=dtype, bias=False) + self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False) + self.gate_up_proj.weight.shard_dim = 0 + self.gate_up_proj.weight.shard_strategy = "shard_gate_up" + self.down_proj.weight.shard_dim = 1 + self.down_proj.weight.shard_strategy = "shard_mlp_k" + else: + self.gate_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False) + self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False) + self.up_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False) + self.gate_proj.weight.shard_dim = 0 + self.gate_proj.weight.shard_strategy = "shard_axis_0" + self.down_proj.weight.shard_dim = 1 + self.down_proj.weight.shard_strategy = "shard_axis_1" + self.up_proj.weight.shard_dim = 0 + self.up_proj.weight.shard_strategy = "shard_axis_0" + + def forward(self, x): + if self.combine_matmul: + gate_up_results = nn.emit( + relax.op.split( + self.gate_up_proj(x), + indices_or_sections=2, + axis=-1, + ) + ) + gate_result = relax.TupleGetItem(gate_up_results, 0) + up_result = relax.TupleGetItem(gate_up_results, 1) + else: + gate_result = self.gate_proj(x) + up_result = self.up_proj(x) + + result = self.down_proj(relax.op.nn.silu(gate_result) * up_result) + return result + + +def rotary_modulate_by_freq(tensor, idx, pos, position_embedding_base): + head_dim = tensor.shape[-1] + dtype = tensor.dtype + n_feat_half = head_dim // 2 + 
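# Illustrative sketch (NumPy only): the TE expression below implements the
# "rotate-half" RoPE variant -- feature f is paired with feature f +/- head_dim/2,
# both sharing the frequency base ** (-((2*f) % head_dim) / head_dim), and the
# output is cos(pos*freq) * x + sin(pos*freq) * rotate_half(x). The arrays here are
# placeholders standing in for a single [seq_len, head_dim] head slice.
import numpy as np

def rope_rotate_half(x, pos, base=10000.0):
    head_dim = x.shape[-1]
    half = head_dim // 2
    inv_freq = 1.0 / base ** ((2 * np.arange(head_dim)) % head_dim / head_dim)
    angles = pos[:, None] * inv_freq[None, :]                       # [seq_len, head_dim]
    rotated = np.concatenate([-x[:, half:], x[:, :half]], axis=-1)  # rotate_half(x)
    return np.cos(angles) * x + np.sin(angles) * rotated

q = np.random.randn(5, 8).astype("float32")
q_rot = rope_rotate_half(q, pos=np.arange(5, dtype="float32"))
assert q_rot.shape == q.shape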
feat_idx = idx[-1] + inv_freq = te.const(1, "float32") / ( + te.power( + te.const(position_embedding_base, "float32"), + ((2 * feat_idx) % head_dim).astype("float32") / head_dim.astype("float32"), + ) + ) + freq = pos * inv_freq + left_indices = idx[:-1] + (feat_idx - n_feat_half,) + right_indices = idx[:-1] + (feat_idx + n_feat_half,) + return te.cos(freq).astype(dtype) * tensor(*idx) + te.sin(freq).astype(dtype) * tvm.tir.Select( + feat_idx >= n_feat_half, + tensor[(*left_indices,)], + -tensor[(*right_indices,)], + ) + + +def apply_rotary_pos_emb(q, k, position_embedding_base, offset: int = 0): + def f_rotary_embedding(tensor, offset): + def rotary_compute(*idx): + pos = (offset + idx[-3]).astype("float32") + return rotary_modulate_by_freq( + tensor, + idx, + pos, + position_embedding_base, + ) + + return tvm.te.compute(tensor.shape, rotary_compute, name="rotary") + + q_embed = nn.emit_te(f_rotary_embedding, q, offset, primfunc_name_hint="rotary_embedding") + k_embed = nn.emit_te(f_rotary_embedding, k, offset, primfunc_name_hint="rotary_embedding") + return q_embed, k_embed + + +class LlamaAttentionBase(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig): + dtype = config.dtype + self.num_shards = config.num_shards + self.hidden_size = config.hidden_size + self.num_key_value_heads = config.get_num_key_value_heads() // config.num_shards + self.num_query_heads = config.num_attention_heads // self.num_shards + self.head_dim = self.hidden_size // config.num_attention_heads + self.position_embedding_base = config.position_embedding_base + + self.combine_matmul = config.combine_matmul + if self.combine_matmul: + self.query_key_value_proj = Linear( + self.hidden_size, + (self.num_query_heads + 2 * self.num_key_value_heads) * self.head_dim, + dtype=dtype, + bias=False, + ) + self.query_key_value_proj.weight.shard_dim = 0 + self.query_key_value_proj.weight.shard_strategy = "shard_qkv" + else: + self.q_proj = Linear( + self.hidden_size, + self.num_query_heads * self.head_dim, + dtype=dtype, + bias=False, + ) + self.k_proj = Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + dtype=dtype, + bias=False, + ) + self.v_proj = Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + dtype=dtype, + bias=False, + ) + self.q_proj.weight.shard_dim = 0 + self.k_proj.weight.shard_dim = 0 + self.v_proj.weight.shard_dim = 0 + self.q_proj.weight.shard_strategy = "shard_axis_0" + self.k_proj.weight.shard_strategy = "shard_axis_0" + self.v_proj.weight.shard_strategy = "shard_axis_0" + + self.o_proj = Linear( + self.head_dim * self.num_query_heads, self.hidden_size, dtype=dtype, bias=False + ) + self.o_proj.weight.shard_dim = 1 + self.o_proj.weight.shard_strategy = "shard_o_proj_k" + + def project_qkv(self, hidden_states, query_output_shape, kv_output_shape): + from tvm.relax.op import reshape, split + + if self.combine_matmul: + qkv_states = nn.emit( + split( + self.query_key_value_proj(hidden_states), + indices_or_sections=[ + self.num_query_heads * self.head_dim, + (self.num_query_heads + self.num_key_value_heads) * self.head_dim, + ], + axis=-1, + ) + ) + query_states = relax.TupleGetItem(qkv_states, 0) + key_states = relax.TupleGetItem(qkv_states, 1) + value_states = relax.TupleGetItem(qkv_states, 2) + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = nn.emit( + reshape(query_states, 
query_output_shape), + ) + key_states = nn.emit( + reshape(key_states, kv_output_shape), + ) + value_states = nn.emit( + reshape(value_states, kv_output_shape), + ) + + return query_states, key_states, value_states + + def forward( + self, + hidden_states: relax.Expr, + all_seq_len_shape: Optional[relax.Expr], + past_key_values: Union[relax.Expr, Tuple[relax.Expr]], + layer_id: int, + attention_mask: Optional[relax.Expr] = None, + ) -> Tuple[relax.Expr, Union[relax.Expr, Tuple[relax.Expr]]]: + bsz, q_len, _ = hidden_states.struct_info.shape + + query_states, key_states, value_states = self.project_qkv( + hidden_states, + (bsz, q_len, self.num_query_heads, self.head_dim), + (bsz, q_len, self.num_key_value_heads, self.head_dim), + ) + + from tvm.relax.op import reshape + + attn_output, past_key_values = self.attention_fwd( + query_states, + key_states, + value_states, + past_key_values, + bsz, + q_len, + layer_id=layer_id, + all_seq_len_shape=all_seq_len_shape, + attention_mask=attention_mask, + ) + + attn_output = nn.emit( + reshape(attn_output, (bsz, q_len, self.head_dim * self.num_query_heads)) + ) + attn_output = self.o_proj(attn_output) + return attn_output, past_key_values + + def attention_fwd( + self, + query_states: relax.Expr, + key_states: relax.Expr, + value_states: relax.Expr, + past_key_values: relax.Expr, + batch_size: tir.PrimExpr, + q_len: tir.PrimExpr, + **kwargs, + ): + raise NotImplementedError() + + +class LlamaPagedAttention(LlamaAttentionBase): + def __init__(self, config: LlamaConfig): + super().__init__(config) + + def attention_fwd( + self, + query_states: relax.Expr, + key_states: relax.Expr, + value_states: relax.Expr, + past_key_values: relax.Expr, + batch_size: tir.PrimExpr, + q_len: tir.PrimExpr, + **kwargs, + ) -> Tuple[relax.Expr, relax.Expr]: + assert "layer_id" in kwargs and isinstance(kwargs["layer_id"], int) + layer_id = kwargs["layer_id"] + + f_kv_cache_attention = relax.extern("vm.builtin.paged_attention_kv_cache_attention") + attn_output = nn.emit( + relax.call_dps_packed( + f_kv_cache_attention, + [ + past_key_values, + relax.PrimValue(layer_id), + query_states, + key_states, + value_states, + ], + out_sinfo=relax.TensorStructInfo( + ((batch_size, q_len, self.num_query_heads, self.head_dim)), + query_states.struct_info.dtype, + ), + ) + ) + return attn_output, past_key_values + + +class LlamaAttention(LlamaAttentionBase): + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.config = config + + def attention_fwd( + self, + query_states: relax.Expr, + key_states: relax.Expr, + value_states: relax.Expr, + past_key_values: relax.Expr, + batch_size: tir.PrimExpr, + q_len: tir.PrimExpr, + **kwargs, + ) -> Tuple[relax.Expr, Tuple[relax.Expr]]: + assert "attention_mask" in kwargs + assert "all_seq_len_shape" in kwargs + attention_mask = kwargs["attention_mask"] + kv_seq_len = kwargs["all_seq_len_shape"].struct_info.values[0] + + from tvm.relax.op import astype, matmul, maximum, permute_dims, reshape, squeeze + from tvm.relax.op.nn import softmax + + offset = kv_seq_len - q_len + query_states, key_states = apply_rotary_pos_emb( + query_states, + key_states, + self.position_embedding_base, + offset=offset, + ) + # [bsz, t, nh, hd] + + kv_states_shape = key_states.struct_info.shape + kv_states_dtype = key_states.struct_info.dtype + assert kv_states_shape[0] == 1 # bsz + kv_states_shape = R.shape( + [kv_states_shape[0], kv_seq_len, kv_states_shape[2], kv_states_shape[3]] + ) + kv_cache_shape = R.shape([kv_seq_len, kv_states_shape[2], 
kv_states_shape[3]]) + + squeezed_key = nn.emit(squeeze(key_states, axis=0)) + squeezed_value = nn.emit(squeeze(value_states, axis=0)) + k_cache, v_cache = past_key_values + f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append") + k_cache = nn.emit( + relax.op.call_inplace_packed( + f_kv_cache_append, + k_cache, + squeezed_key, + inplace_indices=[0], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + v_cache = nn.emit( + relax.op.call_inplace_packed( + f_kv_cache_append, + v_cache, + squeezed_value, + inplace_indices=[0], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + past_key_values = (k_cache, v_cache) + f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view") + k_cache = nn.emit( + relax.call_pure_packed( + f_kv_cache_view, + k_cache, + kv_cache_shape, + sinfo_args=[R.Tensor(kv_cache_shape, kv_states_dtype)], + ) + ) + v_cache = nn.emit( + relax.call_pure_packed( + f_kv_cache_view, + v_cache, + kv_cache_shape, + sinfo_args=[R.Tensor(kv_cache_shape, kv_states_dtype)], + ) + ) + key_states = nn.emit(reshape(k_cache, kv_states_shape)) + value_states = nn.emit(reshape(v_cache, kv_states_shape)) + if self.num_key_value_heads != self.num_query_heads: + n_rep = self.num_query_heads // self.num_key_value_heads + key_states = nn.emit(relax.op.repeat(key_states, n_rep, axis=2)) + value_states = nn.emit(relax.op.repeat(value_states, n_rep, axis=2)) + + if self.config.target_kind == "android": + attn_weights = nn.emit( + matmul( + permute_dims(query_states, [0, 2, 1, 3]), permute_dims(key_states, [0, 2, 3, 1]) + ) + / relax.const(math.sqrt(self.head_dim), query_states.struct_info.dtype) + ) + else: + query_states = nn.emit(permute_dims(query_states, [0, 2, 1, 3])) + key_states = nn.emit(permute_dims(key_states, [0, 2, 1, 3])) + value_states = nn.emit(permute_dims(value_states, [0, 2, 1, 3])) + + attn_weights = nn.emit( + matmul(query_states, permute_dims(key_states, [0, 1, 3, 2])) + / relax.const(math.sqrt(self.head_dim), query_states.struct_info.dtype) + ) + + tvm.ir.assert_structural_equal( + attention_mask.struct_info.shape.values, + (batch_size, tvm.tir.IntImm("int64", 1), q_len, kv_seq_len), + ) + + attn_weights = nn.emit( + maximum( + attn_weights, + relax.const( + tvm.tir.min_value(attn_weights.struct_info.dtype).value, + attn_weights.struct_info.dtype, + ), + ) + ) + attn_weights = nn.emit(relax.op.minimum(attn_weights, attention_mask)) + + # upcast attention to fp32 + if attn_weights.struct_info.dtype != "float32": + attn_weights = astype(attn_weights, "float32") + attn_weights = nn.emit(softmax(attn_weights, axis=-1)) + if attn_weights.struct_info.dtype != query_states.struct_info.dtype: + attn_weights = astype(attn_weights, query_states.struct_info.dtype) + if self.config.target_kind == "android": + attn_output = nn.emit(matmul(attn_weights, permute_dims(value_states, [0, 2, 1, 3]))) + else: + attn_output = nn.emit(matmul(attn_weights, value_states)) + attn_output = nn.emit(permute_dims(attn_output, [0, 2, 1, 3])) + return attn_output, past_key_values + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, enable_batching: bool): + attn_class = LlamaPagedAttention if enable_batching else LlamaAttention + self.hidden_size = config.hidden_size + self.self_attn = attn_class(config) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm( + config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = LlamaRMSNorm( + config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps 
+ ) + + def post_self_attn(self, hidden_states, residual): + if self.self_attn.num_shards > 1: + residual = nn.emit( + residual / R.const(self.self_attn.num_shards, dtype=residual.struct_info.dtype) + ) + hidden_states = nn.emit(residual + hidden_states) + if self.self_attn.num_shards > 1: + hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + if self.mlp.num_shards > 1: + residual = nn.emit( + residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype) + ) + hidden_states = nn.emit(residual + hidden_states) + if self.mlp.num_shards > 1: + hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) + + return hidden_states + + def forward( + self, + hidden_states: relax.Expr, + all_seq_len_shape: Optional[relax.Expr], + past_key_values: Union[relax.Expr, Tuple[relax.Expr]], + layer_id: int, + attention_mask: Optional[relax.Expr] = None, + ) -> Tuple[relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + all_seq_len_shape=all_seq_len_shape, + layer_id=layer_id, + ) + hidden_states = self.post_self_attn(hidden_states, residual) + return hidden_states, present_key_value + + +def _make_causal_mask(input_ids_shape, dtype, src_len): + from tvm.relax.op import broadcast_to + + bsz, tgt_len = input_ids_shape + + def min_max_triu_te(): + return te.compute( + (tgt_len, tgt_len), + lambda i, j: tvm.tir.Select(j > i, tvm.tir.min_value(dtype), tvm.tir.max_value(dtype)), + name="make_diag_mask_te", + ) + + mask = nn.emit_te(min_max_triu_te) + diag_mask = nn.emit(broadcast_to(mask, (bsz, 1, tgt_len, tgt_len))) + if src_len == tgt_len: + return diag_mask + + def extend_te(x, tgt_len, src_len): + return te.compute( + (bsz, 1, tgt_len, src_len), + lambda b, _, i, j: te.if_then_else( + j < src_len - tgt_len, + tvm.tir.max_value(dtype), + x[b, _, i, j - (src_len - tgt_len)], + ), + name="concat_te", + ) + + return nn.emit_te(extend_te, diag_mask, tgt_len, src_len) + + +class LlamaEmbedTokens(nn.Module): + def __init__(self, config: LlamaConfig, vocab_size_var: tvm.tir.SizeVar): + self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) + + def forward(self, input_ids: relax.Expr): + inputs_embeds = self.embed_tokens(input_ids) + return inputs_embeds + + +class LlamaEmbedTokensWrapper(nn.Module): + def __init__(self, config: LlamaConfig, vocab_size_var: tvm.tir.SizeVar): + # build a wrapper to ensure that the naming of the embed_tokens parameter is consistent + self.model = LlamaEmbedTokens(config, vocab_size_var) + + def forward(self, input_ids: relax.Expr): + inputs_embeds = self.model(input_ids) + return inputs_embeds + + +class LlamaModelBase(nn.Module): + def __init__( + self, + config: LlamaConfig, + vocab_size_var: tir.SizeVar, + sep_embed: bool = False, + enable_batching: bool = False, + ): + self.num_shards = config.num_shards + self.padding_idx = config.pad_token_id + self.embed_tokens = None + + if not sep_embed: + self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) + + self.layers = ModuleList( + [LlamaDecoderLayer(config, enable_batching) for _ in range(config.num_hidden_layers)] + ) + self.norm = 
LlamaRMSNorm(config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps) + + def forward( + self, + inputs: relax.Expr, + all_seq_len_shape: Optional[relax.Expr], + past_key_values: relax.Expr, + ): + raise NotImplementedError() + + +class LlamaModelForSingleSequence(LlamaModelBase): + def __init__( + self, config: LlamaConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool = False + ): + super().__init__(config, vocab_size_var, sep_embed, enable_batching=False) + + def _prepare_decoder_attention_mask(self, input_shape, src_len, dtype): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if isinstance(input_shape[-1], tvm.tir.SizeVar) or input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask(input_shape, dtype, src_len) + else: + # Get src_len from input parameters + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + bsz, tgt_len = input_shape + combined_attention_mask = nn.emit( + relax.op.full( + (bsz, 1, tgt_len, src_len), + relax.const(tvm.tir.max_value(dtype).value, dtype), + dtype, + ) + ) + return combined_attention_mask + + def forward( + self, + inputs: relax.Expr, + all_seq_len_shape: Optional[relax.Expr], + past_key_values: relax.Expr, + ): + if self.num_shards > 1: + inputs = nn.emit(ccl.broadcast_from_worker0(inputs)) + if self.embed_tokens: + inputs_embeds = self.embed_tokens(inputs) + else: + inputs_embeds = inputs + # retrieve input_ids + batch_size, seq_length, _ = inputs_embeds.struct_info.shape + seq_length_with_past = all_seq_len_shape.struct_info.values[0] + # embed positions + attention_mask = self._prepare_decoder_attention_mask( + (batch_size, seq_length), + seq_length_with_past, + inputs_embeds.struct_info.dtype, + ) + + hidden_states = inputs_embeds + + # decoder layers + next_decoder_cache = () + + for idx, decoder_layer in enumerate(self.layers): + assert past_key_values is not None + past_key_value = (past_key_values[idx * 2], past_key_values[idx * 2 + 1]) + + hidden_states, key_value_cache = decoder_layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_value, + all_seq_len_shape=all_seq_len_shape, + layer_id=idx, + ) + next_decoder_cache += key_value_cache + + hidden_states = self.norm(hidden_states) + + assert len(next_decoder_cache) == len(self.layers) * 2 + return hidden_states, next_decoder_cache + + +class LlamaModelForBatching(LlamaModelBase): + def __init__(self, config: LlamaConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool): + assert sep_embed + super().__init__(config, vocab_size_var, sep_embed=True, enable_batching=True) + + def forward( + self, + inputs: relax.Expr, + all_seq_len_shape: Optional[relax.Expr], + past_key_values: relax.Expr, + ): + assert all_seq_len_shape is None + if self.num_shards > 1: + inputs = nn.emit(ccl.broadcast_from_worker0(inputs)) + if self.embed_tokens: + inputs_embeds = self.embed_tokens(inputs) + else: + inputs_embeds = inputs + + hidden_states = inputs_embeds + + for idx, decoder_layer in enumerate(self.layers): + assert past_key_values is not None + hidden_states, past_key_values = decoder_layer( + hidden_states, + attention_mask=None, + past_key_values=past_key_values, + all_seq_len_shape=all_seq_len_shape, + layer_id=idx, + ) + + hidden_states = self.norm(hidden_states) + return hidden_states, past_key_values + + +class LlamaForCausalLM(nn.Module): + def __init__( + self, + config: LlamaConfig, + vocab_size_var: tvm.tir.SizeVar, + sep_embed: bool = False, + enable_batching: bool = False, + 
output_all_logits: bool = False, + ): + model_class = LlamaModelForBatching if enable_batching else LlamaModelForSingleSequence + self.model = model_class(config, vocab_size_var, sep_embed) + self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False) + + ############ Rotary embedding constants ############ + assert config.hidden_size % config.num_attention_heads == 0 + head_dim = config.hidden_size // config.num_attention_heads + + # Set the cached sin/cos to the maximum of 2048 and max seq len. + # This will be eliminated further with online rotary embedding calculation. + cache_len = te.var("cached_rotary_embedding_len", "int64") + self.cos_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="cos_cached") + self.sin_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="sin_cached") + + # Mark if output_all_logits is True + self.output_all_logits = output_all_logits + ############ End ############ + + def forward( + self, + inputs: relax.Expr, + all_seq_len_shape: Optional[relax.Expr], + past_key_values: relax.Expr, + logit_positions: Optional[relax.Expr] = None, + ): + hidden_states, key_value_cache = self.model( + inputs=inputs, + all_seq_len_shape=all_seq_len_shape, + past_key_values=past_key_values, + ) + + def te_slicing(x: te.Tensor): + assert x.ndim == 3 + return te.compute( + shape=(x.shape[0], 1, x.shape[2]), + fcompute=lambda i, j, k: x[i, x.shape[1] - 1, k], + name="slice", + ) + + if not self.output_all_logits and hidden_states.struct_info.shape[1] != 1: + if logit_positions is None: + hidden_states = nn.emit_te(te_slicing, hidden_states, primfunc_name_hint="slice") + else: + hidden_states = relax.op.take(hidden_states, logit_positions, axis=1) + logits = self.lm_head(hidden_states) + + if logits.struct_info.dtype != "float32": + logits = nn.emit(relax.op.astype(logits, "float32")) + + return logits, key_value_cache + + +def get_param_quant_kind(name: str, param_info: relax.TensorStructInfo) -> ParamQuantKind: + if "embed_tokens" in name: + return ParamQuantKind.embedding_table + elif "lm_head.weight" in name: + return ParamQuantKind.final_fc_weight + elif param_info.ndim == 2 and name.endswith(".weight"): + return ParamQuantKind.linear_weight + else: + return ParamQuantKind.others + + +def create_embed_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: LlamaConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "embed" + + seq_len = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64") + with bb.function(func_name): + model = LlamaEmbedTokensWrapper(config, tvm.tir.SizeVar("vocab_size", "int64")) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids = nn.Placeholder((1, seq_len), dtype="int32", name="input_ids") + with bb.dataflow(): + inputs_embeds = model(input_ids) + params = [input_ids] + model.parameters() + gv = bb.emit_output(inputs_embeds) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 1)) + + +def create_prefill_func_for_single_seq( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: LlamaConfig, + quant_scheme: QuantizationScheme, + sep_embed: bool = False, +) -> None: + func_name = "prefill_with_embed" if sep_embed else "prefill" + + bsz = 1 + seq_len = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64") + all_seq_len = tvm.tir.SizeVar("num_tokens_including_cache", "int64") + hidden_size = 
config.hidden_size + with bb.function(func_name): + model = LlamaForCausalLM( + config, tvm.tir.SizeVar("vocab_size", "int64"), sep_embed, enable_batching=False + ) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + inputs = ( + nn.Placeholder((bsz, seq_len, hidden_size), dtype=config.dtype, name="inputs_embeds") + if sep_embed + else nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") + ) + all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo( + [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] + ), + ) + with bb.dataflow(): + logits, key_value_cache = model( + inputs, all_seq_len_shape, past_key_values=past_key_values + ) + params = [ + inputs, + all_seq_len_shape, + past_key_values, + ] + model.parameters() + gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 3)) + + +def create_prefill_func_for_batching( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: LlamaConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "prefill_with_embed" + + bsz = tir.SizeVar("batch_size", "int64") + total_seq_len = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64") + hidden_size = config.hidden_size + with bb.function(func_name): + model = LlamaForCausalLM( + config, tvm.tir.SizeVar("vocab_size", "int64"), sep_embed=True, enable_batching=True + ) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + inputs = nn.Placeholder( + (1, total_seq_len, hidden_size), dtype=config.dtype, name="inputs_embeds" + ) + logit_pos = nn.Placeholder((bsz,), dtype="int32", name="logit_positions") + past_key_values = relax.Var("kv_cache", relax.ObjectStructInfo()) + with bb.dataflow(): + logits, key_value_cache = model( + inputs, + all_seq_len_shape=None, + past_key_values=past_key_values, + logit_positions=logit_pos, + ) + params = [inputs, logit_pos, past_key_values] + model.parameters() + gv = bb.emit_output((logits, key_value_cache)) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 3)) + + +def create_decoding_func_for_single_seq( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: LlamaConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "decode" + + bsz = 1 + all_seq_len = tvm.tir.SizeVar("num_tokens_including_cache", "int64") + + with bb.function(func_name): + model = LlamaForCausalLM(config, tvm.tir.SizeVar("vocab_size", "int64")) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids = nn.Placeholder((bsz, 1), dtype="int32", name="input_ids") + all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo( + [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] + ), + ) + with bb.dataflow(): + logits, key_value_cache = model( + input_ids, all_seq_len_shape, past_key_values=past_key_values + ) + params = [ + input_ids, + all_seq_len_shape, + past_key_values, + ] + model.parameters() + gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, 
mod[gv].with_attr("num_input", 3)) + + +def create_decoding_func_for_batching( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: LlamaConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "decode_with_embed" + + bsz = tir.SizeVar("batch_size", "int64") + hidden_size = config.hidden_size + with bb.function(func_name): + model = LlamaForCausalLM( + config, tvm.tir.SizeVar("vocab_size", "int64"), sep_embed=True, enable_batching=True + ) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + inputs = nn.Placeholder((bsz, 1, hidden_size), dtype=config.dtype, name="inputs_embeds") + past_key_values = relax.Var("kv_cache", relax.ObjectStructInfo()) + with bb.dataflow(): + logits, key_value_cache = model( + inputs, all_seq_len_shape=None, past_key_values=past_key_values + ) + params = [inputs, past_key_values] + model.parameters() + gv = bb.emit_output((logits, key_value_cache)) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 2)) + + +def create_verification_func_for_batching( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: LlamaConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "verify_with_embed" + + total_seq_len = tvm.tir.SizeVar("num_tokens_including_cache", "int64") + hidden_size = config.hidden_size + with bb.function(func_name): + model = LlamaForCausalLM( + config, + tvm.tir.SizeVar("vocab_size", "int64"), + sep_embed=True, + enable_batching=True, + output_all_logits=True, + ) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + inputs = nn.Placeholder( + (1, total_seq_len, hidden_size), dtype=config.dtype, name="inputs_embeds" + ) + past_key_values = relax.Var("kv_cache", relax.ObjectStructInfo()) + with bb.dataflow(): + logits, key_value_cache = model( + inputs, + all_seq_len_shape=None, + past_key_values=past_key_values, + ) + params = [inputs, past_key_values] + model.parameters() + gv = bb.emit_output((logits, key_value_cache)) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 2)) + + +def create_kv_cache_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None: + num_key_value_heads = config.get_num_key_value_heads() // config.num_shards + init_shape = relax.ShapeExpr( + ( + config.max_sequence_length, + num_key_value_heads, + config.hidden_size // config.num_attention_heads, # head_dim + ) + ) + with bb.function("create_kv_cache", []): + with bb.dataflow(): + zeros = bb.emit(relax.op.zeros(init_shape, config.dtype)) + caches = [] + f_kv_cache_create = relax.extern("vm.builtin.attention_kv_cache_create") + for _ in range(config.num_hidden_layers * 2): + caches.append( + bb.emit( + relax.call_pure_packed( + f_kv_cache_create, + zeros, + init_shape, + relax.PrimValue(0), + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + ) + gv = bb.emit_output(caches) + bb.emit_func_output(gv) + + +def create_paged_kv_cache_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None: + head_dim = config.hidden_size // config.num_attention_heads + num_qo_heads = config.num_attention_heads // config.num_shards + num_kv_heads = config.get_num_key_value_heads() // config.num_shards + + page_size = tir.SizeVar("page_size", "int64") + total_seq_len = tir.SizeVar("total_seq_len", "int64") + reserved_nseq = tir.SizeVar("reserved_nseq", "int64") + cache_config = relax.Var( + 
"cache_config", + relax.ShapeStructInfo([reserved_nseq, total_seq_len, page_size]), + ) + + with bb.function("create_kv_cache", [cache_config]): + with bb.dataflow(): + zeros = bb.emit(relax.op.zeros((), config.dtype)) + f_kv_cache_create = relax.extern("vm.builtin.paged_attention_kv_cache_create") + cache = bb.emit_output( + relax.call_pure_packed( + f_kv_cache_create, + args=[ + cache_config, + relax.PrimValue(config.num_hidden_layers), + relax.PrimValue(num_qo_heads), + relax.PrimValue(num_kv_heads), + relax.PrimValue(head_dim), + relax.PrimValue(1), + relax.PrimValue(config.position_embedding_base), + zeros, + bb.get().get_global_var("kv_cache_transpose_append"), + bb.get().get_global_var("attention_prefill"), + bb.get().get_global_var("attention_decode"), + bb.get().get_global_var("attention_prefill_ragged"), + bb.get().get_global_var("attention_prefill_ragged_begin_forward"), + bb.get().get_global_var("attention_prefill_ragged_end_forward"), + bb.get().get_global_var("attention_prefill_begin_forward"), + bb.get().get_global_var("attention_prefill_end_forward"), + bb.get().get_global_var("attention_decode_begin_forward"), + bb.get().get_global_var("attention_decode_end_forward"), + bb.get().get_global_var("attention_rope_in_place"), + bb.get().get_global_var("attention_merge_state"), + bb.get().get_global_var("kv_cache_debug_get_kv"), + ], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + bb.emit_func_output(cache) + + +def create_softmax_func_for_single_seq(bb: relax.BlockBuilder, config: LlamaConfig) -> None: + with bb.function("softmax_with_temperature"): + logits = nn.Placeholder( + (1, 1, tvm.tir.SizeVar("vocab_size", "int64")), dtype="float32", name="logits" + ) + temperature = nn.Placeholder((), dtype="float32", name="temperature") + with bb.dataflow(): + div = bb.emit(relax.op.divide(logits, temperature)) + softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) + gv = bb.emit_output(softmax) + bb.emit_func_output(gv, [logits, temperature]) + + +def create_softmax_func_for_batching(bb: relax.BlockBuilder, config: LlamaConfig) -> None: + with bb.function("softmax_with_temperature"): + bsz = tvm.tir.SizeVar("batch_size", "int64") + logits = nn.Placeholder( + (bsz, 1, tvm.tir.SizeVar("vocab_size", "int64")), + dtype="float32", + name="logits", + ) + temperature = nn.Placeholder((bsz,), dtype="float32", name="temperature") + with bb.dataflow(): + t_reshaped = bb.emit(relax.op.reshape(temperature, (bsz, 1, 1))) + div = bb.emit(relax.op.divide(logits, t_reshaped)) + softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) + gv = bb.emit_output(softmax) + bb.emit_func_output(gv, [logits, temperature]) + + +def emit_paged_kv_cache_op(bb: relax.BlockBuilder, config: LlamaConfig) -> None: + from tvm.script import tir as T + + num_kv_heads = config.get_num_key_value_heads() // config.num_shards + head_dim = config.hidden_size // config.num_attention_heads + + @T.prim_func + def kv_cache_transpose_append( + var_pages: T.handle, + var_k_data: T.handle, + var_v_data: T.handle, + var_position_map: T.handle, + ): + ntoken = T.SizeVar("num_tokens_excluding_cache", "int64") + page_size = T.SizeVar("page_size", "int64") + num_pages = T.int64() + + pages = T.match_buffer( + var_pages, (num_pages, 2, num_kv_heads, page_size, head_dim), config.dtype + ) + k_data = T.match_buffer(var_k_data, (ntoken, num_kv_heads, head_dim), config.dtype) + v_data = T.match_buffer(var_v_data, (ntoken, num_kv_heads, head_dim), config.dtype) + position_map = T.match_buffer(var_position_map, (ntoken,), "int32") + + for 
global_pos, h, f in T.grid(ntoken, num_kv_heads, head_dim): + with T.block("k_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + position: T.int64 = T.Cast("int64", position_map[vgpos]) + pages[ + T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vf + ] = k_data[vgpos, vh, vf] + with T.block("v_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + position: T.int64 = T.Cast("int64", position_map[vgpos]) + pages[ + T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vf + ] = v_data[vgpos, vh, vf] + + @T.prim_func + def kv_cache_debug_get_kv( + var_pages: T.handle, + var_position_map: T.handle, + var_k_data: T.handle, + var_v_data: T.handle, + layer_id: T.int64, + ): + seqlen = T.SizeVar("seqlen", "int64") + page_size = T.SizeVar("page_size", "int64") + num_pages = T.int64() + + pages = T.match_buffer( + var_pages, (num_pages, 2, num_kv_heads, page_size, head_dim), config.dtype + ) + position_map = T.match_buffer(var_position_map, (seqlen,), "int32") + k_data = T.match_buffer( + var_k_data, (config.num_hidden_layers, seqlen, num_kv_heads, head_dim), config.dtype + ) + v_data = T.match_buffer( + var_v_data, (config.num_hidden_layers, seqlen, num_kv_heads, head_dim), config.dtype + ) + + for p, h, d in T.grid(seqlen, num_kv_heads, head_dim): + with T.block("copy0"): + vp, vh, vd = T.axis.remap("SSS", [p, h, d]) + position: T.int64 = T.Cast("int64", position_map[vp]) + k_data[layer_id, vp, vh, vd] = pages[ + T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vd + ] + v_data[layer_id, vp, vh, vd] = pages[ + T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vd + ] + + bb.add_func(kv_cache_transpose_append, "kv_cache_transpose_append") + bb.add_func(kv_cache_debug_get_kv, "kv_cache_debug_get_kv") + bb.add_func(relax.extern("paged_kv_cache.attention_kernel_prefill"), "attention_prefill") + bb.add_func(relax.extern("paged_kv_cache.attention_kernel_decode"), "attention_decode") + bb.add_func( + relax.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache"), + "attention_prefill_ragged", + ) + bb.add_func( + relax.extern("paged_kv_cache.attention_kernel_prefill_begin_forward"), + "attention_prefill_begin_forward", + ) + bb.add_func( + relax.extern("paged_kv_cache.attention_kernel_prefill_end_forward"), + "attention_prefill_end_forward", + ) + bb.add_func( + relax.extern("paged_kv_cache.attention_kernel_decode_begin_forward"), + "attention_decode_begin_forward", + ) + bb.add_func( + relax.extern("paged_kv_cache.attention_kernel_decode_end_forward"), + "attention_decode_end_forward", + ) + bb.add_func( + relax.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_begin_forward"), + "attention_prefill_ragged_begin_forward", + ) + bb.add_func( + relax.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_end_forward"), + "attention_prefill_ragged_end_forward", + ) + bb.add_func( + relax.extern("flashinfer.merge_state_in_place"), + "attention_merge_state", + ) + bb.add_func( + relax.extern("flashinfer.batch_qk_apply_rotary_in_place"), + "attention_rope_in_place", + ) + + +def setup_params(mod, param_manager, dtype, config, args): + def f_convert_pname_fwd(pname: str) -> List[str]: + if not config.combine_matmul: + return [pname] + + qkv_str = "query_key_value_proj" + gate_up_str = "gate_up_proj" + if qkv_str in pname: + return [ + pname.replace(qkv_str, "q_proj"), + pname.replace(qkv_str, "k_proj"), + pname.replace(qkv_str, 
"v_proj"), + ] + elif gate_up_str in pname: + return [ + pname.replace(gate_up_str, "gate_proj"), + pname.replace(gate_up_str, "up_proj"), + ] + else: + return [pname] + + def f_convert_param_bkwd(torch_pname: str, torch_param): + if not config.combine_matmul: + return [(torch_pname, torch_param.astype(dtype))] + + combined_layers = ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"] + if any([name in torch_pname for name in combined_layers]): + return None + return [(torch_pname, torch_param.astype(dtype))] + + def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): + # Expected to enter this function only for the combined linear matmul weights. + # Other weights are supposed to be loaded in `f_convert_param_bkwd` since + # each other relax param has a unique corresponding torch param. + if not config.combine_matmul: + # When matmul combination is not turned on, each relax param has a unique + # corresponding torch param, and this function is not expected to be entered. + raise NotImplementedError( + "Matmul combination is not turned on, and the function " + "is not expected to be entered" + ) + hidden_size = config.hidden_size + head_dim = config.hidden_size // config.num_attention_heads + + if "query_key_value_proj" in relax_pname: + q_heads = config.num_attention_heads + kv_heads = config.get_num_key_value_heads() + q, k, v = torch_params + assert q.shape == (q_heads * head_dim, hidden_size) + assert k.shape == (kv_heads * head_dim, hidden_size) + assert v.shape == (kv_heads * head_dim, hidden_size) + qkv = np.concatenate([q, k, v], axis=0).astype(dtype) + return qkv + if "gate_up_proj" in relax_pname: + gate, up = torch_params + gate_up = np.concatenate([gate, up], axis=0).astype(dtype) + return gate_up + raise ValueError("Unexpected param loading") + + param_manager.set_param_loading_func( + args.model_path, + args.use_safetensors, + f_convert_pname_fwd, + f_convert_param_bkwd, + f_compute_relax_param, + ) + + device = tvm.cpu() + param_list = [None] * param_manager.nparam_to_load + + head_dim = config.hidden_size / config.num_attention_heads + inv_freq = 1.0 / ( + config.position_embedding_base ** (np.arange(0, head_dim, 2).astype("float32") / head_dim) + ) + + # The following cos/sin values can be removed but **are kept for compatibility issues**. + t = np.arange(2048, dtype=inv_freq.dtype) + freqs = np.einsum("i,j->ij", t, inv_freq) + emb = np.concatenate((freqs, freqs), axis=-1) + param_list[-2] = tvm.nd.array(np.cos(emb).astype(config.dtype), device) + param_list[-1] = tvm.nd.array(np.sin(emb).astype(config.dtype), device) + + return mod, param_manager, param_list, config + + +def get_model(args, hf_config): + model_name = args.model + dtype = args.quantization.model_dtype + enable_batching = args.enable_batching + sep_embed = args.sep_embed + + if enable_batching and not sep_embed: + raise ValueError("`sep_embed` is required when batching is enabled.") + + position_embedding_base = 10000 + + if "rope_theta" in hf_config: + position_embedding_base = hf_config["rope_theta"] + + # Llama-2 variants use `max_position_embeddings` to encode maximum sequence length in their hf model cards, + # while Llama-1 variants use `max_sequence_length`. + # Thus, use `max_sequence_length` if defined. Otherwise, use `max_position_embeddings`. + # If none of them is defined, throw an error. 
+ if "max_sequence_length" in hf_config: + config = LlamaConfig( + **hf_config, + dtype=dtype, + position_embedding_base=position_embedding_base, + combine_matmul=True, + num_shards=args.num_shards, + build_model_only=args.build_model_only, + target_kind=args.target_kind, + ) + elif "max_position_embeddings" in hf_config: + config = LlamaConfig( + **hf_config, + dtype=dtype, + max_sequence_length=hf_config["max_position_embeddings"], + position_embedding_base=position_embedding_base, + combine_matmul=True, + num_shards=args.num_shards, + build_model_only=args.build_model_only, + target_kind=args.target_kind, + ) + else: + raise Exception( + "The model config should contain information about maximum sequence length." + ) + + # If there is a user-provided maximum sequence length, override hf config. + if args.max_seq_len != -1: + config.max_sequence_length = args.max_seq_len + + param_manager = ParamManager() + bb = relax.BlockBuilder() + + if sep_embed: + create_embed_func(bb, param_manager, config, args.quantization) + + if enable_batching: + emit_paged_kv_cache_op(bb, config) + create_prefill_func_for_batching(bb, param_manager, config, args.quantization) + create_decoding_func_for_batching(bb, param_manager, config, args.quantization) + create_verification_func_for_batching(bb, param_manager, config, args.quantization) + create_paged_kv_cache_func(bb, config) + create_softmax_func_for_batching(bb, config) + else: + create_prefill_func_for_single_seq(bb, param_manager, config, args.quantization, sep_embed) + create_decoding_func_for_single_seq(bb, param_manager, config, args.quantization) + create_kv_cache_func(bb, config) + create_softmax_func_for_single_seq(bb, config) + + create_metadata_func( + bb, + model_name=model_name, + max_window_size=config.max_sequence_length, + stop_tokens=[2], + add_prefix_space=False, + prefill_chunk_size=args.prefill_chunk_size, + ) + + mod = bb.get() + + tir_bound_map = dict() + tir_bound_map["num_tokens_without_cache"] = ( + args.prefill_chunk_size if args.prefill_chunk_size > 0 else config.max_sequence_length + ) + tir_bound_map["num_tokens_with_cache"] = config.max_sequence_length + tir_bound_map["vocab_size"] = args.max_vocab_size + if enable_batching: + tir_bound_map["nseq"] = args.max_batch_size + for gv in mod.functions: + func = mod[gv] + if isinstance(func, relax.Function): + mod[gv] = func.with_attr("tir_var_upper_bound", tir_bound_map) + + if args.build_model_only: + return mod, param_manager, None, config + + return setup_params(mod, param_manager, dtype, config, args) diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py new file mode 100644 index 0000000..365500b --- /dev/null +++ b/mlc_llm/relax_model/llama_batched_vllm.py @@ -0,0 +1,658 @@ +from typing import Optional, Tuple + +import numpy as np +import tvm +from tvm import relax, te +from tvm.relax.op import ccl, reshape, expand_dims, concat, zeros, repeat, take +from tvm.relax.op.nn import attention_var_len +from tvm.relax.testing import nn +from tvm.ir import VDevice +from tvm.script import relax as R +from tvm.script.ir_builder import tir as T + +from ..quantization import QuantizationScheme +from .modules import ModuleList +from .param_manager import ParamManager +from .llama import ( + LlamaConfig, + Linear, + Embedding, + LlamaRMSNorm, + LlamaAttentionBase, + LlamaDecoderLayer, + get_param_quant_kind, + setup_params, + rotary_modulate_by_freq, +) + + +def apply_rotary_pos_emb(q, k, positions, position_embedding_base): + def 
f_rotary_embedding(tensor, pos_tensor): + def rotary_compute(*idx): + pos = pos_tensor[idx[0]].astype("float32") + return rotary_modulate_by_freq( + tensor, + idx, + pos, + position_embedding_base, + ) + + return tvm.te.compute(tensor.shape, rotary_compute, name="rotary") + + q_embed = nn.emit_te(f_rotary_embedding, q, positions, primfunc_name_hint="rotary_embedding") + k_embed = nn.emit_te(f_rotary_embedding, k, positions, primfunc_name_hint="rotary_embedding") + return q_embed, k_embed + + +class LlamaAttentionBatched(LlamaAttentionBase): + def __init__(self, config: LlamaConfig, head_mapping: relax.Constant): + super().__init__(config) + self.head_mapping = head_mapping # (num_heads,), used by vLLM for multi-query attention + self.sliding_window = None + + if config.sliding_window: + self.sliding_window = T.IntImm("int32", config.sliding_window) + + def forward( + self, + hidden_states: relax.Expr, # (num_token, hidden_size) + positions: relax.Expr, # (num_token,), for batched RoPE + seq_lens: relax.Expr, # (num_seq,) + kv_cache: Optional[Tuple[relax.Expr, relax.Expr]], + slot_mapping: Optional[relax.Expr], # (num_token,) + max_seqlen: Optional[relax.Expr], # (), must be on CPU + seqstart: Optional[relax.Expr], # (num_seq + 1,), for prefill + block_tables: Optional[relax.Expr], # (num_seq, max_num_blocks_per_seq), for decode + indices_within_window: Optional[ + relax.Expr + ], # (num_cached_total,), for prefill with sliding-window attention + ): + num_tokens, _ = hidden_states.struct_info.shape + + queries, keys, values = self.project_qkv( + hidden_states, + (num_tokens, self.num_query_heads, self.head_dim), + (num_tokens, self.num_key_value_heads, self.head_dim), + ) + + queries, keys = apply_rotary_pos_emb(queries, keys, positions, self.position_embedding_base) + + if kv_cache: + # Paged KV cache update + k_cache, v_cache = kv_cache + + if self.sliding_window is None or block_tables: + # For decode or prefill without sliding window, cache all keys / values. + keys_to_cache = keys + values_to_cache = values + else: + # Cache only the most recent keys and values within the window. 
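# indices_within_window selects, per sequence, at most the last `sliding_window` token positions
# of this prefill (all positions if the sequence is shorter than the window). Each surviving
# token's slot_mapping entry is a flat index into the (num_blocks, block_size) grid of the paged
# cache, i.e. block = slot // block_size and offset = slot % block_size.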
+ keys_to_cache = nn.emit(take(keys, indices_within_window, axis=0)) + values_to_cache = nn.emit(take(values, indices_within_window, axis=0)) + slot_mapping = nn.emit(take(slot_mapping, indices_within_window, axis=0)) + + # kv caches are updated inplace, takes ownership of the arguments + kv = nn.emit( + relax.op.call_inplace_packed( + "tvm.contrib.vllm.reshape_and_cache", + args=[keys_to_cache, values_to_cache, k_cache, v_cache, slot_mapping], + inplace_indices=[2, 3], + sinfo_args=[k_cache.struct_info, v_cache.struct_info], + ) + ) + + k_cache, v_cache = kv[0], kv[1] + else: + k_cache = v_cache = None + + if seqstart: + # Prefill, batched attention over variable sequence lengths + attn_output = nn.emit( + attention_var_len( + nn.emit(expand_dims(queries, axis=0)), + nn.emit(expand_dims(keys, axis=0)), + nn.emit(expand_dims(values, axis=0)), + seqstart_q=seqstart, + max_seqlen_q=max_seqlen, + causal_mask="BottomRight", + window_size=self.sliding_window, + ) + ) + else: + # Decode, using vLLM kernel + attn_output = nn.emit( + relax.op.call_dps_packed( + "tvm.contrib.vllm.single_query_cached_kv_attention", + [ + queries, + k_cache, + v_cache, + self.head_mapping, + block_tables, + seq_lens, + 16, # block_size + max_seqlen, + ], + out_sinfo=queries.struct_info, + ) + ) + + attn_output = nn.emit( + reshape(attn_output, (num_tokens, self.num_query_heads * self.head_dim)) + ) + attn_output = self.o_proj(attn_output) + + return attn_output, (k_cache, v_cache) + + +class LlamaDecoderLayerBatched(LlamaDecoderLayer): + def __init__(self, config: LlamaConfig, head_mapping: relax.Constant): + super().__init__(config, False) + self.self_attn = LlamaAttentionBatched(config, head_mapping) + + def forward( + self, + hidden_states: relax.Expr, + positions: relax.Expr, + seq_lens: relax.Expr, + kv_cache: Optional[Tuple[relax.Expr, relax.Expr]], + slot_mapping: Optional[relax.Expr], + max_seqlen: Optional[relax.Expr], + seqstart: Optional[relax.Expr], + block_tables: Optional[relax.Expr], + indices_within_window: Optional[relax.Expr], + ) -> Tuple[relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, new_kv = self.self_attn( + hidden_states=hidden_states, + positions=positions, + seq_lens=seq_lens, + kv_cache=kv_cache, + slot_mapping=slot_mapping, + max_seqlen=max_seqlen, + seqstart=seqstart, + block_tables=block_tables, + indices_within_window=indices_within_window, + ) + + hidden_states = self.post_self_attn(hidden_states, residual) + + return hidden_states, new_kv + + +class LlamaModel(nn.Module): + def __init__( + self, + config: LlamaConfig, + cpu_device: VDevice, + vocab_size_var: tvm.tir.SizeVar, + sep_embed: bool = False, + ): + self.padding_idx = config.pad_token_id + self.embed_tokens = None + + num_query_heads = config.num_attention_heads // config.num_shards + num_key_value_heads = config.get_num_key_value_heads() // config.num_shards + num_queries_per_kv = num_query_heads // num_key_value_heads + head_mapping = relax.const( + tvm.nd.array( + np.repeat(np.arange(num_key_value_heads, dtype="int32"), num_queries_per_kv) + ) + ) + + if not sep_embed: + self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) + + self.layers = ModuleList( + [ + LlamaDecoderLayerBatched(config, head_mapping) + for _ in range(config.num_hidden_layers) + ] + ) + self.norm = LlamaRMSNorm(config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps) + + self.cpu_device = 
cpu_device + + def forward( + self, + inputs: relax.Expr, + positions: relax.Expr, + seq_lens: relax.Expr, + kv_caches: Optional[relax.Expr], + slot_mapping: Optional[relax.Expr], + seqstart: Optional[relax.Expr], + block_tables: Optional[relax.Expr], + indices_within_window: Optional[relax.Expr], + ): + if self.embed_tokens: + inputs_embeds = self.embed_tokens(inputs) + else: + inputs_embeds = inputs + + hidden_states = inputs_embeds + + # max_seqlen needs to be on CPU, so that vLLM and Flash Attention can directly get the + # integer length by max_seqlen->data[0]. Otherwise, we need to repeatedly do cudaMemcpy + # of a single int32. + max_seqlen = R.to_vdevice(R.max(seq_lens), self.cpu_device) + + new_kvs = () + + for idx, decoder_layer in enumerate(self.layers): + if kv_caches: + cache = (kv_caches[2 * idx], kv_caches[2 * idx + 1]) + else: + cache = None + + hidden_states, new_kv = decoder_layer( + hidden_states, + positions, + seq_lens, + cache, + slot_mapping, + max_seqlen, + seqstart, + block_tables, + indices_within_window, + ) + new_kvs += new_kv + + return self.norm(hidden_states), new_kvs + + +class LlamaForCausalLM(nn.Module): + def __init__( + self, + config: LlamaConfig, + cpu_device: VDevice, + vocab_size_var: tvm.tir.SizeVar, + sep_embed: bool = False, + ): + self.num_shards = config.num_shards + self.model = LlamaModel(config, cpu_device, vocab_size_var, sep_embed) + self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False) + + ############ Rotary embedding constants ############ + assert config.hidden_size % config.num_attention_heads == 0 + head_dim = config.hidden_size // config.num_attention_heads + + # Set the cached sin/cos to the maximum of 2048 and max seq len. + # This will be eliminated further with online rotary embedding calculation. + cache_len = te.var("cached_rotary_embedding_len", "int64") + self.cos_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="cos_cached") + self.sin_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="sin_cached") + ############ End ############ + + def forward( + self, + input_ids: relax.Expr, # (num_token,) + positions: relax.Expr, # (num_token,), for batched RoPE + seq_lens: relax.Expr, # (num_seq,) + kv_caches: Optional[relax.Expr], # For prefill and decode, not needed for evaluate + slot_mapping: Optional[ + relax.Expr + ], # (num_token,), for prefill and decode, not needed for evaluate + block_tables: Optional[relax.Expr], # (num_seq, max_num_blocks_per_seq), for decode + indices_within_window: Optional[ + relax.Expr + ], # (num_cached_total,), for prefill with sliding-window attention + ): + """ + In vLLM, the paged KV cache is simply a pair of tensors, one for keys and the other + for values. The tensor has shape (num_blocks, num_kv_heads, head_size, block_size). + (In practice, the key cache has a slightly different shape for an efficiency reason, + but that's not important.) + + The mapping between sequences / tokens to blocks is specified by two inputs. + - block_tables: A list of block IDs allocated for the sequence. + - slot_mapping: A linear index into the 2D grid (num_blocks, block_size), for each token. + + Support for sliding-window attention is realized by making a block table a circular buffer. + So the length of a block table for each sequence is at most ceil(window_size / block_size). + + With sliding window, not all past K / V values need to be cached during prefill. 
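        For example (illustrative numbers), with a 4096-token window a 6000-token prompt only needs
        its most recent 4096 positions written to the cache.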
+ The last input, indices_within_window, tells which tokens among (num_token,) need to have + their K / V values cached. + """ + if self.num_shards > 1: + input_ids = nn.emit(ccl.broadcast_from_worker0(input_ids)) + positions = nn.emit(ccl.broadcast_from_worker0(positions)) + seq_lens = nn.emit(ccl.broadcast_from_worker0(seq_lens)) + + if slot_mapping: + slot_mapping = nn.emit(ccl.broadcast_from_worker0(slot_mapping)) + + if block_tables: + block_tables = nn.emit(ccl.broadcast_from_worker0(block_tables)) + + if indices_within_window: + indices_within_window = nn.emit(ccl.broadcast_from_worker0(indices_within_window)) + + is_prompt = block_tables is None + + if is_prompt: # prefill and evaluate + # https://github.com/apache/tvm/issues/15851 for why we need to use Thrust + cumsum = nn.emit( + relax.op.call_dps_packed( + "tvm.contrib.thrust.sum_scan", seq_lens, out_sinfo=seq_lens.struct_info + ) + ) + seqstart = nn.emit(concat([zeros((1,), "int32"), cumsum])) + else: + seqstart = None + + hidden_states, new_kvs = self.model( + input_ids, + positions, + seq_lens, + kv_caches, + slot_mapping, + seqstart, + block_tables, + indices_within_window, + ) + + if is_prompt: + # Extract logits for the last token in each sequence + + def get_logits_last_tokens(x, seq_len_tensor, seqstart): + return te.compute( + shape=(seq_len_tensor.shape[0], x.shape[-1]), + fcompute=lambda i, j: x[seqstart[i] + seq_len_tensor[i] - 1, j], + name="get_logits_last_tokens", + ) + + logits = self.lm_head( + nn.emit_te( + get_logits_last_tokens, + hidden_states, + seq_lens, + seqstart, + primfunc_name_hint="get_logits_last_tokens", + ) + ) + else: + logits = self.lm_head(hidden_states) + + if logits.struct_info.dtype != "float32": + logits = nn.emit(relax.op.astype(logits, "float32")) + + return logits, new_kvs + + +def get_inputs( + num_token, num_seq, config, max_num_blocks_per_seq=None, sep_embed=False, need_cache=True +): + hidden_size = config.hidden_size + + inputs = ( + nn.Placeholder((num_token, hidden_size), dtype=config.dtype, name="inputs_embeds") + if sep_embed + else nn.Placeholder((num_token,), dtype="int32", name="input_ids") + ) + + seq_lens = nn.Placeholder((num_seq,), dtype="int32", name="seq_lens") + positions = nn.Placeholder((num_token,), dtype="int32", name="positions") + + if need_cache: + num_blocks = tvm.tir.SizeVar("num_blocks", "int64") + block_size = 16 + + vec_size = 8 # 128 bit, fp16 x 8 + num_key_value_heads = config.get_num_key_value_heads() // config.num_shards + head_size = hidden_size // config.num_attention_heads + + k_cache_shape = ( + num_blocks, + num_key_value_heads, + head_size // vec_size, + block_size, + vec_size, + ) + v_cache_shape = (num_blocks, num_key_value_heads, head_size, block_size) + + get_cache_sinfo = lambda i: relax.TensorStructInfo( + k_cache_shape if i % 2 == 0 else v_cache_shape, dtype="float16" + ) + + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo( + [get_cache_sinfo(i) for i in range(config.num_hidden_layers * 2)] + ), + ) + slot_mapping = nn.Placeholder((num_token,), dtype="int32", name="slot_mapping") + else: + past_key_values = None + slot_mapping = None + block_tables = None + + if max_num_blocks_per_seq is None: + block_tables = None + else: + block_tables = nn.Placeholder( + (num_seq, max_num_blocks_per_seq), dtype="int32", name="block_tables" + ) + + return inputs, positions, seq_lens, past_key_values, slot_mapping, block_tables + + +def create_evaluate_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: LlamaConfig, 
+ cpu_dev: VDevice, + quant_scheme: QuantizationScheme, + sep_embed: bool = False, +) -> None: + """Evaluate logits for the last token in each sequence. Same as prefill but without KV cache.""" + func_name = "evaluate" + + num_token = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64") + num_seq = tvm.tir.SizeVar("batch_size", "int64") + + with bb.function(func_name): + model = LlamaForCausalLM(config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64"), sep_embed) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + inputs, positions, seq_lens, _, _, _ = get_inputs( + num_token, num_seq, config, sep_embed=sep_embed + ) + + with bb.dataflow(): + logits, _ = model( + inputs, + positions, + seq_lens, + kv_caches=None, + slot_mapping=None, + block_tables=None, + indices_within_window=None, + ) + params = [ + inputs, + positions, + seq_lens, + ] + model.parameters() + gv = bb.emit_output(logits) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 3)) + + +def create_encoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: LlamaConfig, + cpu_dev: VDevice, + quant_scheme: QuantizationScheme, + sep_embed: bool = False, +) -> None: + """Batched prefill with vLLM paged KV cache. + + The batched attention op is intended to be offloaded to CUTLASS or Flash Attention + via BYOC. + """ + func_name = "prefill_with_embed" if sep_embed else "prefill" + + num_token = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64") + num_seq = tvm.tir.SizeVar("batch_size", "int64") + + num_inputs = 5 + + with bb.function(func_name): + model = LlamaForCausalLM(config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64"), sep_embed) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids, positions, seq_lens, past_key_values, slot_mapping, _ = get_inputs( + num_token, num_seq, config, sep_embed=sep_embed + ) + + with bb.dataflow(): + params = [ + input_ids, + positions, + seq_lens, + past_key_values, + slot_mapping, + ] + + inputs = [ + input_ids, + positions, + seq_lens, + past_key_values, + slot_mapping, + None, # block_tables + ] + + if config.sliding_window: + num_inputs += 1 + # The value of num_cached_total is between + # num_token (if seq_len < sliding_window for all seq) and + # num_seq * config.sliding_window (if seq_len > sliding_window for all seq) + num_cached_total = tvm.tir.SizeVar("num_cached_total", "int64") + indices_within_window = nn.Placeholder( + (num_cached_total,), dtype="int32", name="indices_within_window" + ) + inputs.append(indices_within_window) + params.append(indices_within_window) + else: + inputs.append(None) + + logits, new_kvs = model(*inputs) + gv = bb.emit_output((logits, relax.Tuple(new_kvs))) + + bb.emit_func_output(gv, params + model.parameters()) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", num_inputs)) + + +def create_decoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: LlamaConfig, + cpu_dev: VDevice, + quant_scheme: QuantizationScheme, +) -> None: + """Batched decoding with vLLM paged KV cache.""" + func_name = "decode" + + num_seq = tvm.tir.SizeVar("batch_size", "int64") + max_num_blocks_per_seq = tvm.tir.SizeVar("max_num_blocks_per_seq", "int64") + + with bb.function(func_name): + inputs, positions, seq_lens, past_key_values, slot_mapping, block_tables = get_inputs( + num_seq, num_seq, config, 
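# During decode every sequence contributes exactly one new token, so the num_token and
# num_seq arguments of get_inputs coincide here.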
max_num_blocks_per_seq + ) + + with bb.dataflow(): + model = LlamaForCausalLM(config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64")) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + logits, new_kvs = model( + inputs, positions, seq_lens, past_key_values, slot_mapping, block_tables, None + ) + params = [ + inputs, + positions, + seq_lens, + past_key_values, + slot_mapping, + block_tables, + ] + model.parameters() + gv = bb.emit_output((logits, relax.Tuple(new_kvs))) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 6)) + + +def get_model(args, hf_config): + dtype = args.quantization.model_dtype + sep_embed = False + + position_embedding_base = 10000 + + if "rope_theta" in hf_config: + position_embedding_base = hf_config["rope_theta"] + + # Llama-2 variants use `max_position_embeddings` to encode maximum sequence length in their hf model cards, + # while Llama-1 variants use `max_sequence_length`. + # Thus, use `max_sequence_length` if defined. Otherwise, use `max_position_embeddings`. + # If none of them is defined, throw an error. + if "max_sequence_length" in hf_config: + config = LlamaConfig( + **hf_config, + dtype=dtype, + position_embedding_base=position_embedding_base, + combine_matmul=True, + num_shards=args.num_shards, + build_model_only=args.build_model_only, + ) + elif "max_position_embeddings" in hf_config: + config = LlamaConfig( + **hf_config, + dtype=dtype, + max_sequence_length=hf_config["max_position_embeddings"], + position_embedding_base=position_embedding_base, + combine_matmul=True, + num_shards=args.num_shards, + build_model_only=args.build_model_only, + ) + else: + raise Exception( + "The model config should contain information about maximum sequence length." + ) + + # If there is a user-provided maximum sequence length, override hf config. + if args.max_seq_len != -1: + config.max_sequence_length = args.max_seq_len + + param_manager = ParamManager() + bb = relax.BlockBuilder() + + # The CPU device to copy the result of relax.op.max(seq_lens) to CPU. 
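# "llvm" denotes the host CPU target; the VDevice is registered on the IRModule below via
# mod.update_global_info("vdevice", [cpu_dev]) so that the R.to_vdevice(...) call inside
# LlamaModel.forward can resolve this device.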
+ cpu_dev = VDevice("llvm", 0, "global") + + create_evaluate_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed) + create_encoding_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed) + create_decoding_func(bb, param_manager, config, cpu_dev, args.quantization) + + mod = bb.get() + + mod.update_global_info("vdevice", [cpu_dev]) + + if args.build_model_only: + return mod, param_manager, None, config + + return setup_params(mod, param_manager, dtype, config, args) diff --git a/mlc_llm/relax_model/minigpt.py b/mlc_llm/relax_model/minigpt.py new file mode 100644 index 0000000..96126bb --- /dev/null +++ b/mlc_llm/relax_model/minigpt.py @@ -0,0 +1,627 @@ +import math +import os +from dataclasses import dataclass + +import torch +import tvm +from tvm import relax +from tvm.relax.testing import nn + + +from ..quantization import ParamQuantKind, QuantizationScheme +from .modules import ModuleList, TransformImage +from .param_manager import ParamManager + + +@dataclass +class MiniGPTConfig: + dtype: str = "float16" + in_chan: int = 4 # represent rgba + image_size: int = 224 + num_query_token: int = 32 + max_txt_len: int = 160 + vocab_size: int = 32000 + patch_size: int = 14 + word_embed: int = 768 + visual_encoder_embed_dim: int = 1408 + visual_encoder_attn_heads: int = 16 + visual_encoder_attn_hidden_dim: int = 257 + visual_encoder_fc_hidden_dim: int = 6144 + visual_encoder_num_blocks: int = 39 + bert_hidden_layers: int = 12 + bert_num_attn_heads: int = 12 + bert_attn_head_size: int = 64 + bert_interm_query: int = 3072 + llama_proj_size: int = 4096 + + +MODEL_CONFIG = { + "minigpt4-7b": {}, +} + + +class MiniGPTPatchEmbed(nn.Module): + def __init__( + self, image_size, patch_size, embed_dim, dtype: str, in_chans=3, bias=True + ): + self.strides = (patch_size, patch_size) + self.embed_dim = embed_dim + self.out_shape = image_size // patch_size + + bs = 1 + self.cls_token = nn.Parameter((bs, 1, embed_dim), dtype=dtype, name="cls_token") + self.pos_embed = nn.Parameter( + (1, self.out_shape * self.out_shape + 1, embed_dim), + dtype=dtype, + name="pos_embed", + ) + self.weight = nn.Parameter( + (embed_dim, in_chans, patch_size, patch_size), + dtype=dtype, + name="patch_embed_weight", + ) + if bias: + self.bias = nn.Parameter((embed_dim,), dtype=dtype, name="patch_embed_bias") + else: + self.bias = None + + def forward(self, input: relax.Expr) -> relax.Var: + bs = 1 + x = nn.emit(relax.op.nn.conv2d(input, self.weight, self.strides)) + if self.bias: + bias = relax.op.reshape(self.bias, [1, self.embed_dim, 1, 1]) + x = relax.op.add(x, bias) + x = relax.op.reshape(x, (bs, self.embed_dim, self.out_shape * self.out_shape)) + x = relax.op.permute_dims(x, [0, 2, 1]) + # concatenate with cls_tokens + x_concat = relax.op.concat([self.cls_token, x], axis=1) + # add with pos_embed + res = relax.op.add(x_concat, self.pos_embed) + return res + + +class MiniGPTVisualEncoderAttention(nn.Module): + def __init__(self, config: MiniGPTConfig): + self.embed_dim = config.visual_encoder_embed_dim + self.num_heads = config.visual_encoder_attn_heads + self.head_dim = self.embed_dim // self.num_heads + self.scale = self.head_dim ** (-0.5) + self.dtype = config.dtype + self.N = config.visual_encoder_attn_hidden_dim + + self.q_bias = nn.Parameter((self.embed_dim,), dtype=self.dtype, name="q_bias") + self.v_bias = nn.Parameter((self.embed_dim,), dtype=self.dtype, name="v_bias") + self.qkv_weight = nn.Parameter( + (self.embed_dim * 3, self.embed_dim), dtype=self.dtype, name="qkv_weight" + ) + 
self.proj_weight = nn.Parameter( + (self.embed_dim, self.embed_dim), dtype=self.dtype, name="proj_weight" + ) + self.proj_bias = nn.Parameter( + (self.embed_dim,), dtype=self.dtype, name="proj_bias" + ) + + def forward(self, input: relax.Expr): + from tvm.relax.op import ( + concat, + linear, + matmul, + permute_dims, + reshape, + squeeze, + strided_slice, + zeros, + ) + + bs = 1 + k_bias = zeros((self.embed_dim,), self.dtype) + qkv_bias = concat([self.q_bias, k_bias, self.v_bias], axis=0) + x = linear(input, self.qkv_weight, qkv_bias) + x = reshape(x, (bs, self.N, 3, self.num_heads, self.head_dim)) + x = permute_dims(x, [2, 0, 3, 1, 4]) + q = squeeze(strided_slice(x, axes=[0], begin=[0], end=[1]), [0]) + k = squeeze(strided_slice(x, axes=[0], begin=[1], end=[2]), [0]) + v = squeeze(strided_slice(x, axes=[0], begin=[2], end=[3]), [0]) + q = q * relax.const(self.scale, self.dtype) + attn = matmul(q, permute_dims(k, [0, 1, 3, 2])) + attn = relax.op.nn.softmax(attn, -1) + res = permute_dims(matmul(attn, v), [0, 2, 1, 3]) + res = reshape(res, (bs, self.N, self.embed_dim)) + res = linear(res, self.proj_weight, self.proj_bias) + return res + + +class MiniGPTMLP(nn.Module): + def __init__(self, config: MiniGPTConfig): + self.hidden_dim = config.visual_encoder_fc_hidden_dim + self.embed_dim = config.visual_encoder_embed_dim + self.dtype = config.dtype + + self.fc1_weight = nn.Parameter( + (self.hidden_dim, self.embed_dim), dtype=self.dtype, name="fc1_weight" + ) + self.fc1_bias = nn.Parameter( + (self.hidden_dim,), dtype=self.dtype, name="fc1_bias" + ) + self.fc2_weight = nn.Parameter( + (self.embed_dim, self.hidden_dim), dtype=self.dtype, name="fc2_weight" + ) + self.fc2_bias = nn.Parameter( + (self.embed_dim,), dtype=self.dtype, name="fc2_bias" + ) + + def forward(self, input: relax.Expr): + res = relax.op.linear(input, self.fc1_weight, self.fc1_bias) + res = relax.op.nn.gelu(res) + res = relax.op.linear(res, self.fc2_weight, self.fc2_bias) + return res + + +class MiniGPTVisualEncoderBlock(nn.Module): + def __init__(self, config: MiniGPTConfig): + embed_dim = config.visual_encoder_embed_dim + dtype = config.dtype + self.norm1_weight = nn.Parameter((embed_dim,), dtype=dtype, name="norm1_weight") + self.norm1_bias = nn.Parameter((embed_dim,), dtype=dtype, name="norm1_bias") + self.attn = MiniGPTVisualEncoderAttention(config) + self.norm2_weight = nn.Parameter((embed_dim,), dtype=dtype, name="norm2_weight") + self.norm2_bias = nn.Parameter((embed_dim,), dtype=dtype, name="norm2_bias") + self.mlp = MiniGPTMLP(config) + + def forward(self, input: relax.Expr): + x = relax.op.nn.layer_norm(input, self.norm1_weight, self.norm1_bias, axes=[-1]) + proj = self.attn(x) + proj = relax.op.add(input, proj) + res = relax.op.nn.layer_norm( + proj, self.norm2_weight, self.norm2_bias, axes=[-1] + ) + res = self.mlp(res) + res = relax.op.add(proj, res) + return res + + +class MiniGPTVisualEncoder(nn.Module): + def __init__(self, config: MiniGPTConfig): + self.embed_dim = config.visual_encoder_embed_dim + self.dtype = config.dtype + self.transform = TransformImage(config.dtype, config.in_chan) + self.patch_embed = MiniGPTPatchEmbed( + config.image_size, + config.patch_size, + config.visual_encoder_embed_dim, + config.dtype, + ) + self.num_blocks = config.visual_encoder_num_blocks + self.blocks = ModuleList( + [MiniGPTVisualEncoderBlock(config) for _ in range(self.num_blocks)] + ) + + self.ln_vision_weight = nn.Parameter( + (self.embed_dim,), dtype=self.dtype, name="ln_vision_weight" + ) + self.ln_vision_bias = 
nn.Parameter( + (self.embed_dim,), dtype=self.dtype, name="ln_vision_bias" + ) + + def forward(self, input_image: relax.Expr): + res = self.transform(input_image) + res = self.patch_embed(res) + for block in self.blocks: + res = block(res) + res = relax.op.nn.layer_norm( + res, self.ln_vision_weight, self.ln_vision_bias, axes=[-1] + ) + return res + + +class MiniGPTEmbedding(nn.Module): + def __init__(self, config: MiniGPTConfig): + self.word_embed = config.word_embed + self.dtype = config.dtype + self.eps = 1e-12 + + self.norm_weight = nn.Parameter( + (self.word_embed,), dtype=self.dtype, name="norm_weight" + ) + self.norm_bias = nn.Parameter( + (self.word_embed,), dtype=self.dtype, name="norm_bias" + ) + + def forward(self, embedding: relax.Expr): + res = relax.op.nn.layer_norm( + embedding, self.norm_weight, self.norm_bias, axes=[-1], epsilon=self.eps + ) + return res + + +class MiniGPTBertAttention(nn.Module): + def __init__(self, config: MiniGPTConfig, hidden_dim: int): + self.word_embed = config.word_embed + self.num_query_token = config.num_query_token + self.num_attn_heads = config.bert_num_attn_heads + self.attn_head_size = config.bert_attn_head_size + self.visual_encoder_attn_hidden_dim = config.visual_encoder_attn_hidden_dim + self.dtype = config.dtype + self.eps = 1e-12 + + self.query_weight = nn.Parameter( + (self.word_embed, self.word_embed), dtype=self.dtype, name="query_weight" + ) + self.query_bias = nn.Parameter( + (self.word_embed,), dtype=self.dtype, name="query_bias" + ) + self.key_weight = nn.Parameter( + (self.word_embed, hidden_dim), dtype=self.dtype, name="key_weight" + ) + self.key_bias = nn.Parameter( + (self.word_embed,), dtype=self.dtype, name="key_bias" + ) + self.value_weight = nn.Parameter( + (self.word_embed, hidden_dim), dtype=self.dtype, name="value_weight" + ) + self.value_bias = nn.Parameter( + (self.word_embed,), dtype=self.dtype, name="value_bias" + ) + self.dense_weight = nn.Parameter( + (self.word_embed, self.word_embed), dtype=self.dtype, name="dense_weight" + ) + self.dense_bias = nn.Parameter( + (self.word_embed,), dtype=self.dtype, name="dense_bias" + ) + self.norm_weight = nn.Parameter( + (self.word_embed,), dtype=self.dtype, name="norm_weight" + ) + self.norm_bias = nn.Parameter( + (self.word_embed,), dtype=self.dtype, name="norm_bias" + ) + + def forward( + self, + hidden_states: relax.Expr, + attention_mask: relax.Expr, + encoder_hidden_states=None, + encoder_extend_attention_mask=None, + ): + from tvm.relax.op import add, linear, matmul, permute_dims, reshape + + bs = 1 + states = ( + encoder_hidden_states + if encoder_hidden_states is not None + else hidden_states + ) + mask = ( + encoder_extend_attention_mask + if encoder_extend_attention_mask is not None + else attention_mask + ) + hidden_dim = ( + self.visual_encoder_attn_hidden_dim + if encoder_hidden_states is not None + else self.num_query_token + ) + key = linear(states, self.key_weight, self.key_bias) + value = linear(states, self.value_weight, self.value_bias) + key = reshape(key, [bs, hidden_dim, self.num_attn_heads, self.attn_head_size]) + key = permute_dims(key, [0, 2, 1, 3]) + value = reshape( + value, [bs, hidden_dim, self.num_attn_heads, self.attn_head_size] + ) + value = permute_dims(value, [0, 2, 1, 3]) + query = linear(hidden_states, self.query_weight, self.query_bias) + query = reshape( + query, [bs, self.num_query_token, self.num_attn_heads, self.attn_head_size] + ) + query = permute_dims(query, [0, 2, 1, 3]) + scores = matmul(query, permute_dims(key, [0, 1, 3, 2])) + 
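# Standard scaled dot-product attention: scale the raw scores by 1/sqrt(attn_head_size),
# add the additive attention mask, then softmax over the key axis.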
scores = scores / relax.const(math.sqrt(self.attn_head_size), dtype=self.dtype) + scores = add(scores, mask) + probs = relax.op.nn.softmax(scores, axis=-1) + context = matmul(probs, value) + context = permute_dims(context, [0, 2, 1, 3]) + context = reshape(context, [bs, self.num_query_token, self.word_embed]) + # calculate the output + context = linear(context, self.dense_weight, self.dense_bias) + context = add(context, hidden_states) + res = relax.op.nn.layer_norm( + context, self.norm_weight, self.norm_bias, axes=[-1], epsilon=self.eps + ) + return res, key, value + + +class MiniGPTBertLayer(nn.Module): + def __init__(self, config: MiniGPTConfig, use_cross_attention=False): + self.word_embed = config.word_embed + self.embed_dim = config.visual_encoder_embed_dim + self.interm_query = config.bert_interm_query + self.dtype = config.dtype + self.eps = 1e-12 + + self.attention = MiniGPTBertAttention(config, self.word_embed) + if use_cross_attention: + self.cross_attention = MiniGPTBertAttention(config, self.embed_dim) + else: + self.cross_attention = None + self.interm_query_weight = nn.Parameter( + (self.interm_query, self.word_embed), + dtype=self.dtype, + name="interm_query_weight", + ) + self.interm_query_bias = nn.Parameter( + (self.interm_query,), dtype=self.dtype, name="interm_query_bias" + ) + self.output_query_weight = nn.Parameter( + (self.word_embed, self.interm_query), + dtype=self.dtype, + name="output_query_weight", + ) + self.output_query_bias = nn.Parameter( + (self.word_embed,), dtype=self.dtype, name="output_query_bias" + ) + self.norm_weight = nn.Parameter( + (self.word_embed,), dtype=self.dtype, name="norm_weight" + ) + self.norm_bias = nn.Parameter( + (self.word_embed,), dtype=self.dtype, name="norm_bias" + ) + + def forward( + self, + embedding: relax.Expr, + extend_attention_mask: relax.Expr, + encoder_hidden_states: relax.Expr, + encoder_extend_attention_mask: relax.Expr, + ): + attn_output, key, value = self.attention(embedding, extend_attention_mask) + if self.cross_attention: + attn_output, _, _ = self.cross_attention( + attn_output, + extend_attention_mask, + encoder_hidden_states, + encoder_extend_attention_mask, + ) + res = relax.op.linear( + attn_output, self.interm_query_weight, self.interm_query_bias + ) + res = relax.op.nn.gelu(res) + res = relax.op.linear(res, self.output_query_weight, self.output_query_bias) + res = relax.op.add(res, attn_output) + res = relax.op.nn.layer_norm( + res, self.norm_weight, self.norm_bias, axes=[-1], epsilon=self.eps + ) + return res, key, value + + +class MiniGPTQFormer(nn.Module): + def __init__(self, config: MiniGPTConfig): + self.N = config.visual_encoder_attn_hidden_dim + self.num_query_token = config.num_query_token + self.word_embed = config.word_embed + self.num_layers = config.bert_hidden_layers + self.dtype = config.dtype + + bs = 1 + self.query_tokens = nn.Parameter( + (bs, self.num_query_token, self.word_embed), + dtype=self.dtype, + name="query_tokens", + ) + self.embedding = MiniGPTEmbedding(config) + self.bert_layers = ModuleList( + [MiniGPTBertLayer(config, i % 2 == 0) for i in range(self.num_layers)] + ) + + def forward(self, image_embeds: relax.Expr): + from tvm.relax.op import expand_dims, ones + + bs = 1 + image_attns = ones((bs, self.N), self.dtype) + embedding = self.embedding(self.query_tokens) + attention_mask = ones((bs, self.num_query_token), self.dtype) + extend_attention_mask = expand_dims(attention_mask, [1, 2]) + extend_attention_mask = ( + relax.const(1.0, self.dtype) - extend_attention_mask + ) * 
relax.const(-10000.0, self.dtype) + encoder_extend_attention_mask = expand_dims(image_attns, [1, 2]) + encoder_extend_attention_mask = ( + relax.const(1.0, self.dtype) - encoder_extend_attention_mask + ) + for layer in self.bert_layers: + embedding, _, _ = layer( + embedding, + extend_attention_mask, + image_embeds, + encoder_extend_attention_mask, + ) + return embedding + + +class MiniGPTLLaMAProj(nn.Module): + def __init__(self, config: MiniGPTConfig): + self.proj_size = config.llama_proj_size + self.word_embed = config.word_embed + self.dtype = config.dtype + + self.weight = nn.Parameter( + (self.proj_size, self.word_embed), dtype=self.dtype, name="weight" + ) + self.bias = nn.Parameter((self.proj_size,), dtype=self.dtype, name="bias") + + def forward(self, embedding: relax.Expr): + return relax.op.linear(embedding, self.weight, self.bias) + + +class MiniGPTModel(nn.Module): + def __init__(self, config: MiniGPTConfig): + self.visual_encoder = MiniGPTVisualEncoder(config) + self.q_former = MiniGPTQFormer(config) + self.llama_proj = MiniGPTLLaMAProj(config) + + def forward(self, input_image: relax.Expr): + output = self.visual_encoder(input_image) + output = self.q_former(output) + output = self.llama_proj(output) + return output + + +def get_param_quant_kind( + name: str, param_info: relax.TensorStructInfo +) -> ParamQuantKind: + """No quantization for MiniGPT. Use q0f16 or q0f32 when building it.""" + return ParamQuantKind.others + + +def create_embed_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: MiniGPTConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "embed" + + bs = 1 + with bb.function(func_name): + model = MiniGPTModel(config) + param_manager.register_params( + model, func_name, quant_scheme, get_param_quant_kind + ) + + input_image = nn.Placeholder( + (bs, config.image_size, config.image_size, config.in_chan), + dtype="uint8", + name="input_image", + ) + with bb.dataflow(): + output = model(input_image) + params = [input_image] + model.parameters() + gv = bb.emit_output(output) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 1)) + + +def get_model(args, _config): + model_name = args.model + model_path = args.model_path + + if model_name.startswith("minigpt"): + config = MiniGPTConfig(**MODEL_CONFIG[model_name]) + config.dtype = args.quantization.model_dtype + # build the relax model + param_manager = ParamManager() + bb = relax.BlockBuilder() + create_embed_func(bb, param_manager, config, args.quantization) + mod = bb.get() + + if args.build_model_only: + return mod, param_manager, None, config + + param_manager.set_param_loading_func( + args.model_path, args.use_safetensors, no_lazy_param_loading=True + ) + + # load visual encoder weights + visual_encoder_url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth" + visual_encoder_cached_file = download_cached_file( + visual_encoder_url, check_hash=False, progress=True + ) + visual_encoder_state_dict = torch.load( + visual_encoder_cached_file, map_location="cpu" + ) + + # load QFormer weights + q_former_url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth" + q_former_cached_file = download_cached_file( + q_former_url, check_hash=False, progress=True + ) + q_former_state_dict = torch.load(q_former_cached_file, map_location="cpu")[ + "model" + ] + + # load llama and llama proj weights + if 
os.path.isdir(model_path): + raise ValueError( + "MiniGPT model path should be a single file instead of a directory." + ) + llama_state_dict = torch.load(model_path + ".pth", map_location="cpu")["model"] + + param_list = [] + device = tvm.cpu() + visual_encoder_key_list = list(visual_encoder_state_dict.keys())[ + : 4 + 13 * config.visual_encoder_num_blocks + ] + for key in visual_encoder_key_list: + param_list.append( + tvm.nd.array( + visual_encoder_state_dict[key].numpy().astype(config.dtype), device + ) + ) + q_former_key_list = ( + list(q_former_state_dict.keys())[1:3] + + [list(q_former_state_dict.keys())[0]] + + list(q_former_state_dict.keys())[ + 6 : 8 + (26 + 16) * config.bert_hidden_layers // 2 + ] + ) + for key in q_former_key_list: + param_list.append( + tvm.nd.array( + q_former_state_dict[key].numpy().astype(config.dtype), device + ) + ) + llama_key_list = list(llama_state_dict.keys())[-2:] + for key in llama_key_list: + param_list.append( + tvm.nd.array(llama_state_dict[key].numpy().astype(config.dtype), device) + ) + + return mod, param_manager, param_list, config + + raise ValueError(f"Unsupported model: {model_name}") + + +# helper functions for distributed download of model weights from URL +# source: https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/common/dist_utils.py (originally credit to Salesforce) + + +def download_cached_file(url, check_hash=True, progress=False): + import timm.models.hub as timm_hub + import torch.distributed as dist + + def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + def is_main_process(): + return get_rank() == 0 + + """ + Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again. + If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded. 
+ """ + + def get_cached_file_path(): + # a hack to sync the file path across processes + parts = torch.hub.urlparse(url) + filename = os.path.basename(parts.path) + cached_file = os.path.join(timm_hub.get_cache_dir(), filename) + + return cached_file + + if is_main_process(): + timm_hub.download_cached_file(url, check_hash, progress) + + if is_dist_avail_and_initialized(): + dist.barrier() + + return get_cached_file_path() diff --git a/mlc_llm/relax_model/mistral.py b/mlc_llm/relax_model/mistral.py new file mode 100644 index 0000000..e08495f --- /dev/null +++ b/mlc_llm/relax_model/mistral.py @@ -0,0 +1,1125 @@ +# pylint: disable=too-many-lines, missing-class-docstring, missing-function-docstring +"""Implements the mistal model with sliding window attention.""" + +import math +from dataclasses import dataclass +from typing import Any, List, Optional, Tuple + +import numpy as np +import tvm +from tvm import relax, te +from tvm.relax.op import ccl +from tvm.relax.testing import nn +from tvm.script import relax as R + +from ..quantization import ParamQuantKind, QuantizationScheme +from .commons import create_metadata_func +from .modules import ModuleList +from .param_manager import ParamManager + + +@dataclass +class MistralConfig: + """Configuration for mistral model.""" + + def __init__( + self, + bos_token_id=1, + eos_token_id=2, + pad_token_id=-1, + hidden_act="silu", + hidden_size=4096, + initializer_range=0.02, + intermediate_size=14336, + max_position_embeddings=32768, + num_attention_heads=32, + num_hidden_layers=32, + num_key_value_heads=8, + rms_norm_eps=1e-5, + rope_theta=10000.0, + sliding_window=4096, + attention_sink_size=0, + tie_word_embeddings=False, + vocab_size=32000, + dtype="float32", + max_sequence_length=16384, + combine_matmul=True, + build_model_only=False, + num_shards=1, + **kwargs, + ): + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.hidden_act = hidden_act + self.hidden_size = hidden_size + self.initializer_range = initializer_range + self.intermediate_size = intermediate_size + self.max_position_embeddings = max_position_embeddings + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_hidden_layers + self.num_key_value_heads = num_key_value_heads + self.rms_norm_eps = rms_norm_eps + self.rope_theta = rope_theta + self.sliding_window = sliding_window + self.attention_sink_size = attention_sink_size + self.tie_word_embeddings = tie_word_embeddings + self.vocab_size = vocab_size + self.dtype = dtype + self.max_sequence_length = sliding_window * 4 + self.combine_matmul = combine_matmul + if build_model_only and num_shards > 1: + self.num_shards = num_shards + else: + self.num_shards = 1 + self.kwargs = kwargs + + def get_num_key_value_heads(self): + if self.num_key_value_heads is None: + return self.num_attention_heads + + return self.num_key_value_heads + + +class Linear(nn.Module): + def __init__(self, in_features, out_features, dtype: str, bias=True): + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter((out_features, in_features), dtype=dtype, name="linear_weight") + if bias: + self.bias = nn.Parameter((out_features,), dtype=dtype, name="linear_bias") + else: + self.bias = None + + def forward(self, input: relax.Expr) -> relax.Var: + return nn.emit(relax.op.linear(input, self.weight, self.bias)) + + +class Embedding(nn.Module): + def __init__(self, num_embeddings, embedding_dim, dtype: str): + self.num_embeddings = num_embeddings + 
self.embedding_dim = embedding_dim + self.weight = nn.Parameter( + (num_embeddings, embedding_dim), dtype=dtype, name="embedding_weight" + ) + + def forward(self, x: relax.Expr) -> relax.Var: + from tvm.relax.op import ( # pylint: disable=import-outside-toplevel + reshape, + take, + ) + + ndim = x.struct_info.ndim + if ndim == 1: + return nn.emit(take(self.weight, x, axis=0)) + else: + x_shape = x.struct_info.shape.values + emb_size = self.weight.struct_info.shape.values[-1] + x = nn.emit(reshape(x, shape=[-1])) + embedding = nn.emit(take(self.weight, x, axis=0)) + return nn.emit(reshape(embedding, [*x_shape, emb_size])) + + +class MistralRMSNorm(nn.Module): + def __init__(self, hidden_size, dtype, eps=1e-6): + self.weight = nn.Parameter((hidden_size,), dtype=dtype, name="rms_norm_weight") + self.variance_epsilon = tvm.tir.const(eps, dtype) + + def forward(self, hidden_states): + from tvm import te, tir + + def f_rms_norm(x, weight): + is_float32 = x.dtype == "float32" + + def f_square(x): + return tir.Cast("float32", x) * tir.Cast("float32", x) if not is_float32 else x * x + + k = te.reduce_axis((0, x.shape[2]), name="k") + square_sum = te.compute( + (x.shape[0], x.shape[1]), + lambda bsz, i: te.sum(f_square(x[bsz, i, k]), axis=k), + name=x.op.name + "red_temp", + ) + + def f_div_cast(bsz, i, k): + x_val = x[bsz, i, k] + if not is_float32: + x_val = tir.Cast("float32", x_val) + return x_val / tir.sqrt(square_sum[bsz, i] / x.shape[2] + self.variance_epsilon) + + def f_mul_cast(x, y): + value = x * y + if not is_float32: + value = tir.Cast(x.dtype, value) + return value + + return te.compute( + x.shape, + lambda bsz, i, k: f_mul_cast(weight(k), f_div_cast(bsz, i, k)), + name="rms_norm", + ) + + return nn.emit_te(f_rms_norm, hidden_states, self.weight, primfunc_name_hint="rms_norm") + + +class MistralMLP(nn.Module): + def __init__(self, config: MistralConfig): + self.combine_matmul = config.combine_matmul + self.num_shards = config.num_shards + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size // self.num_shards + dtype = config.dtype + if self.combine_matmul: + self.gate_up_proj = Linear(hidden_size, 2 * intermediate_size, dtype=dtype, bias=False) + self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False) + self.gate_up_proj.weight.shard_dim = 0 + self.gate_up_proj.weight.shard_strategy = "shard_gate_up" + self.down_proj.weight.shard_dim = 1 + self.down_proj.weight.shard_strategy = "shard_mlp_k" + else: + self.gate_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False) + self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False) + self.up_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False) + + def forward(self, x): + if self.combine_matmul: + gate_up_results = nn.emit( + relax.op.split( + self.gate_up_proj(x), + indices_or_sections=2, + axis=-1, + ) + ) + gate_result = relax.TupleGetItem(gate_up_results, 0) + up_result = relax.TupleGetItem(gate_up_results, 1) + else: + gate_result = self.gate_proj(x) + up_result = self.up_proj(x) + + result = self.down_proj(relax.op.nn.silu(gate_result) * up_result) + return result + + +def apply_rotary_pos_emb(q, k, base, q_offset): + def f_rotary_embedding(tensor, offset): + dtype = tensor.dtype + head_dim = tensor.shape[-1] + n_feat_half = tensor.shape[-1] // 2 + + def rotary_compute(*idx): + i, j = idx[-3], idx[-1] + pos = (offset + i).astype("float32") + inv_freq = te.const(1, "float32") / ( + te.power( + te.const(base, "float32"), + ((2 * j) 
% head_dim).astype("float32") / head_dim.astype("float32"), + ) + ) + freq = pos * inv_freq + return te.cos(freq).astype(dtype) * tensor(*idx) + te.sin(freq).astype( + dtype + ) * tvm.tir.Select( + j >= n_feat_half, + tensor[idx[0], i, idx[2], j - n_feat_half], + -tensor[idx[0], i, idx[2], j + n_feat_half], + ) + + return tvm.te.compute(tensor.shape, rotary_compute, name="rotary") + + q_embed = nn.emit_te(f_rotary_embedding, q, q_offset, primfunc_name_hint="rotary_embedding") + k_embed = nn.emit_te(f_rotary_embedding, k, 0, primfunc_name_hint="rotary_embedding") + return q_embed, k_embed + + +class MistralAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: MistralConfig): + dtype = config.dtype + self.num_shards = config.num_shards + self.hidden_size = config.hidden_size + self.num_key_value_heads = config.get_num_key_value_heads() // config.num_shards + self.num_query_heads = config.num_attention_heads // self.num_shards + self.head_dim = self.hidden_size // config.num_attention_heads + self.rope_theta = config.rope_theta + self.sliding_window = config.sliding_window + self.attention_sink_size = config.attention_sink_size + + self.combine_matmul = config.combine_matmul + if self.combine_matmul: + self.query_key_value_proj = Linear( + self.hidden_size, + (self.num_query_heads + 2 * self.num_key_value_heads) * self.head_dim, + dtype=dtype, + bias=False, + ) + self.query_key_value_proj.weight.shard_dim = 0 + self.query_key_value_proj.weight.shard_strategy = "shard_qkv" + else: + self.q_proj = Linear( + self.hidden_size, + self.num_query_heads * self.head_dim, + dtype=dtype, + bias=False, + ) + self.k_proj = Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + dtype=dtype, + bias=False, + ) + self.v_proj = Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + dtype=dtype, + bias=False, + ) + self.q_proj.weight.shard_dim = 0 + self.k_proj.weight.shard_dim = 0 + self.v_proj.weight.shard_dim = 0 + + self.o_proj = Linear( + self.head_dim * self.num_query_heads, self.hidden_size, dtype=dtype, bias=False + ) + self.o_proj.weight.shard_dim = 1 + self.o_proj.weight.shard_strategy = "shard_o_proj_k" + + def interleave_kv( + self, + key_cur: relax.Expr, + value_cur: relax.Expr, + kv_seq_len: int, + rolling_cache_len: int, + cache_offset: int, + attention_sink_size: int, + past_key_value: Tuple[relax.Expr], + ): + from tvm.relax.op import reshape + + def te_cache_unrotate(x_cached, cache_offset, rolling_cache_len): + return te.compute( + (kv_cur_shape[0], rolling_cache_len, kv_cur_shape[2], kv_cur_shape[3]), + lambda b, s, h, d: te.if_then_else( + s < attention_sink_size, + x_cached[b, s, h, d], + te.if_then_else( + s < rolling_cache_len - cache_offset + attention_sink_size, + x_cached[b, s + cache_offset - attention_sink_size, h, d], + x_cached[b, s + cache_offset - rolling_cache_len, h, d], + ), + ), + name="te_cache_unrotate", + ) + + def te_cache_cur_concat(x, x_cached, kv_seq_len, rolling_cache_len): + return te.compute( + (kv_cur_shape[0], kv_seq_len, kv_cur_shape[2], kv_cur_shape[3]), + lambda b, s, h, d: te.if_then_else( + s < rolling_cache_len, + x_cached[b, s, h, d], + x[b, s - rolling_cache_len, h, d], + ), + name="te_cache_cur_concat", + ) + + def te_squeeze(x): + return te.compute( + x.shape[1:], + lambda s, h, d: x[0, s, h, d], + name="squeeze_te", + ) + + # [bsz, t, nh, hd] + kv_cur_shape = key_cur.struct_info.shape + kv_cur_dtype = key_cur.struct_info.dtype + assert 
kv_cur_shape[0] == 1 # bsz + kv_batched_cache_shape = R.shape( + [kv_cur_shape[0], rolling_cache_len, kv_cur_shape[2], kv_cur_shape[3]] + ) + kv_cache_shape = R.shape([rolling_cache_len, kv_cur_shape[2], kv_cur_shape[3]]) + + # fecth past keys and values from cache + k_cache, v_cache = past_key_value + + f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view") + key_cached = nn.emit( + relax.call_pure_packed( + f_kv_cache_view, + args=[k_cache, kv_cache_shape], + sinfo_args=[R.Tensor(kv_cache_shape, kv_cur_dtype)], + ) + ) + value_cached = nn.emit( + relax.call_pure_packed( + f_kv_cache_view, + args=[v_cache, kv_cache_shape], + sinfo_args=[R.Tensor(kv_cache_shape, kv_cur_dtype)], + ) + ) + key_cached = nn.emit(reshape(key_cached, kv_batched_cache_shape)) + value_cached = nn.emit(reshape(value_cached, kv_batched_cache_shape)) + + key_cached = nn.emit_te( + te_cache_unrotate, + key_cached, + cache_offset, + rolling_cache_len, + primfunc_name_hint="te_cache_unrotate_key", + ) + key = nn.emit_te( + te_cache_cur_concat, + key_cur, + key_cached, + kv_seq_len, + rolling_cache_len, + primfunc_name_hint="te_cache_cur_concat_key", + ) + + value_cached = nn.emit_te( + te_cache_unrotate, + value_cached, + cache_offset, + rolling_cache_len, + primfunc_name_hint="te_cache_unrotate_value", + ) + value = nn.emit_te( + te_cache_cur_concat, + value_cur, + value_cached, + kv_seq_len, + rolling_cache_len, + primfunc_name_hint="te_cache_cur_concat_value", + ) + + # update cache + squeezed_key = nn.emit_te(te_squeeze, key_cur) + squeezed_value = nn.emit_te(te_squeeze, value_cur) + + assert attention_sink_size >= 0 + f_kv_cache_override = relax.extern( + "vm.builtin.attention_kv_cache_window_override_with_sinks" + ) + k_cache = nn.emit( + relax.op.call_inplace_packed( + f_kv_cache_override, + args=[ + k_cache, + squeezed_key, + relax.PrimValue(self.sliding_window), + relax.PrimValue(attention_sink_size), + ], + inplace_indices=[0], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + v_cache = nn.emit( + relax.op.call_inplace_packed( + f_kv_cache_override, + args=[ + v_cache, + squeezed_value, + relax.PrimValue(self.sliding_window), + relax.PrimValue(attention_sink_size), + ], + inplace_indices=[0], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + + return key, value, (k_cache, v_cache) + + def forward( + self, + hidden_states: relax.Expr, + cache_len_shape: relax.Expr, + kv_seq_len_shape: relax.Expr, + cache_offset_shape: relax.Expr, + past_key_value: Tuple[relax.Expr], + attention_mask: Optional[relax.Expr] = None, + ) -> Tuple[relax.Expr, Optional[relax.Expr], Optional[Tuple[relax.Expr]]]: + # pylint: disable=import-outside-toplevel + from tvm.relax.op import astype, matmul, maximum, permute_dims, reshape, split + from tvm.relax.op.nn import softmax + + bsz, q_len, _ = hidden_states.struct_info.shape + assert bsz == 1, "Only support batch size 1 at this moment." 
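+ + # High-level flow of this forward pass: + # 1. project hidden_states to Q/K/V (optionally through one fused matmul); + # 2. interleave_kv() reads the rolling sliding-window KV cache, un-rotates it, + # concatenates the current K/V, and writes the new entries back via + # vm.builtin.attention_kv_cache_window_override_with_sinks; + # 3. rotary position embeddings are applied after caching, with q_offset equal + # to the current cache length so that positions are cache-relative; + # 4. for grouped-query attention, K/V heads are repeated to match the query heads.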
+ + if self.combine_matmul: + qkv_cur = nn.emit( + split( + self.query_key_value_proj(hidden_states), + indices_or_sections=[ + self.num_query_heads * self.head_dim, + (self.num_query_heads + self.num_key_value_heads) * self.head_dim, + ], + axis=-1, + ) + ) + query = relax.TupleGetItem(qkv_cur, 0) + key_cur = relax.TupleGetItem(qkv_cur, 1) + value_cur = relax.TupleGetItem(qkv_cur, 2) + else: + query = self.q_proj(hidden_states) + key_cur = self.k_proj(hidden_states) + value_cur = self.v_proj(hidden_states) + + query = nn.emit( + reshape( + query, + (bsz, q_len, self.num_query_heads, self.head_dim), + ), + ) + key_cur = nn.emit( + reshape( + key_cur, + (bsz, q_len, self.num_key_value_heads, self.head_dim), + ), + ) + value_cur = nn.emit( + reshape( + value_cur, + (bsz, q_len, self.num_key_value_heads, self.head_dim), + ), + ) + + # concat current kv with cached kv (unrotating the cache) + rolling_cache_len = cache_len_shape.struct_info.values[0] + kv_seq_len = kv_seq_len_shape.struct_info.values[0] + cache_offset = cache_offset_shape.struct_info.values[0] + key, value, updated_key_value = self.interleave_kv( + key_cur, + value_cur, + kv_seq_len, + rolling_cache_len, + cache_offset, + self.attention_sink_size, + past_key_value, + ) + + # cache relative position embeddings (after KV Cache) + query, key = apply_rotary_pos_emb( + query, + key, + self.rope_theta, + q_offset=rolling_cache_len, + ) + + if self.num_key_value_heads != self.num_query_heads: + n_rep = self.num_query_heads // self.num_key_value_heads + key = nn.emit(relax.op.repeat(key, n_rep, axis=2)) + value = nn.emit(relax.op.repeat(value, n_rep, axis=2)) + + query = nn.emit(permute_dims(query, [0, 2, 1, 3])) + key = nn.emit(permute_dims(key, [0, 2, 1, 3])) + value = nn.emit(permute_dims(value, [0, 2, 1, 3])) + + attn_weights = nn.emit( + matmul(query, permute_dims(key, [0, 1, 3, 2])) + / relax.const(math.sqrt(self.head_dim), query.struct_info.dtype) + ) + + tvm.ir.assert_structural_equal( + attention_mask.struct_info.shape.values, + (bsz, tvm.tir.IntImm("int64", 1), q_len, kv_seq_len), + ) + + attn_weights = nn.emit( + maximum( + attn_weights, + relax.const( + tvm.tir.min_value(attn_weights.struct_info.dtype).value, + attn_weights.struct_info.dtype, + ), + ) + ) + attn_weights = nn.emit(relax.op.minimum(attn_weights, attention_mask)) + + # upcast attention to fp32 + if attn_weights.struct_info.dtype != "float32": + attn_weights = astype(attn_weights, "float32") + attn_weights = nn.emit(softmax(attn_weights, axis=-1)) + if attn_weights.struct_info.dtype != query.struct_info.dtype: + attn_weights = astype(attn_weights, query.struct_info.dtype) + attn_output = nn.emit(matmul(attn_weights, value)) + + attn_output = nn.emit(permute_dims(attn_output, [0, 2, 1, 3])) + attn_output = nn.emit( + reshape(attn_output, (bsz, q_len, self.head_dim * self.num_query_heads)) + ) + + attn_output = self.o_proj(attn_output) + + return attn_output, ((None, None) if updated_key_value is None else updated_key_value) + + +class MistralDecoderLayer(nn.Module): + def __init__(self, config: MistralConfig): + self.hidden_size = config.hidden_size + self.self_attn = MistralAttention(config) + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm( + config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = MistralRMSNorm( + config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: relax.Expr, + cache_len_shape: relax.Expr, + kv_seq_len_shape: 
relax.Expr, + cache_offset_shape: relax.Expr, + past_key_value: Tuple[relax.Expr], + attention_mask: Optional[relax.Expr] = None, + ) -> Tuple[relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + cache_len_shape=cache_len_shape, + kv_seq_len_shape=kv_seq_len_shape, + cache_offset_shape=cache_offset_shape, + ) + if self.self_attn.num_shards > 1: + residual = nn.emit( + residual / R.const(self.self_attn.num_shards, dtype=residual.struct_info.dtype) + ) + hidden_states = nn.emit(residual + hidden_states) + if self.self_attn.num_shards > 1: + hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + if self.mlp.num_shards > 1: + residual = nn.emit( + residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype) + ) + hidden_states = nn.emit(residual + hidden_states) + if self.mlp.num_shards > 1: + hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) + return hidden_states, present_key_value + + +def _make_sliding_window_mask(input_shape, kv_seq_len, sliding_window, dtype): + # See `tests/python/test_sliding_window_mask.py` for more on its behavior. + # [bsz, tgt_len] -> [bsz, 1, tgt_len, kv_seq_len] + + bsz, tgt_len = input_shape # TODO: only support batch size of 1 for now + cache_len = kv_seq_len - tgt_len # number of elements in cache + + if isinstance(tgt_len, tvm.tir.SizeVar) or tgt_len > 1: + # Either 1. First prefill, or 2. Subsequent prefill + from tvm.relax.op import broadcast_to # pylint: disable=import-outside-toplevel + + def sliding_window_min_max_te(sliding_window): + return te.compute( + (tgt_len, kv_seq_len), + lambda i, j: tvm.tir.Select( + tvm.tir.all(i + cache_len >= j, i + cache_len - j < sliding_window), + tvm.tir.max_value(dtype), + tvm.tir.min_value(dtype), + ), + name="make_diag_mask_sliding_window_te", + ) + + mask = nn.emit_te(sliding_window_min_max_te, sliding_window) + return nn.emit(broadcast_to(mask, (bsz, 1, tgt_len, kv_seq_len))) + + else: + # 3. 
Decode (equivalent to prefilling a chunk of size 1) + # Mask nothing here since WS == cache_size + bsz, tgt_len = input_shape + return nn.emit( + relax.op.full( + (bsz, 1, tgt_len, kv_seq_len), + relax.const(tvm.tir.max_value(dtype).value, dtype), + dtype, + ) + ) + + +class MistralEmbedTokens(nn.Module): + def __init__(self, config: MistralConfig, vocab_size_var: tvm.tir.SizeVar): + self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) + + def forward(self, input_ids: relax.Expr): + inputs_embeds = self.embed_tokens(input_ids) + return inputs_embeds + + +class MistralEmbedTokensWrapper(nn.Module): + def __init__(self, config: MistralConfig, vocab_size_var: tvm.tir.SizeVar): + # build a wrapper to ensure that the naming of the embed_tokens parameter is consistent + self.model = MistralEmbedTokens(config, vocab_size_var) + + def forward(self, input_ids: relax.Expr): + inputs_embeds = self.model(input_ids) + return inputs_embeds + + +class MistralModel(nn.Module): + def __init__( + self, config: MistralConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool = False + ): + self.num_shards = config.num_shards + self.padding_idx = config.pad_token_id + self.embed_tokens = None + + if not sep_embed: + self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) + + self.layers = ModuleList( + [MistralDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = MistralRMSNorm(config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps) + self.sliding_window = config.sliding_window + + def forward( + self, + inputs: relax.Expr, + cache_len_shape: relax.Expr, + kv_seq_len_shape: relax.Expr, + cache_offset_shape: relax.Expr, + past_key_values: relax.Expr, + ): + if self.num_shards > 1: + inputs = nn.emit(ccl.broadcast_from_worker0(inputs)) + if self.embed_tokens: + inputs_embeds = self.embed_tokens(inputs) + else: + inputs_embeds = inputs + # retrieve input_ids + batch_size, seq_length, _ = inputs_embeds.struct_info.shape + kv_seq_len = kv_seq_len_shape.struct_info.values[0] + + # embed positions + attention_mask = _make_sliding_window_mask( + (batch_size, seq_length), + kv_seq_len, + self.sliding_window, + inputs_embeds.struct_info.dtype, + ) + + hidden_states = inputs_embeds + + # decoder layers + next_decoder_cache = () + + for idx, decoder_layer in enumerate(self.layers): + assert past_key_values is not None + past_key_value = (past_key_values[idx * 2], past_key_values[idx * 2 + 1]) + + hidden_states, key_value_cache = decoder_layer( + hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + cache_len_shape=cache_len_shape, + kv_seq_len_shape=kv_seq_len_shape, + cache_offset_shape=cache_offset_shape, + ) + next_decoder_cache += key_value_cache + + hidden_states = self.norm(hidden_states) + + assert len(next_decoder_cache) == len(self.layers) * 2 + return hidden_states, next_decoder_cache + + +class MistralForCausalLM(nn.Module): + def __init__( + self, config: MistralConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool = False + ): + self.model = MistralModel(config, vocab_size_var, sep_embed) + self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False) + + ############ Rotary embedding constants ############ + assert config.hidden_size % config.num_attention_heads == 0 + head_dim = config.hidden_size // config.num_attention_heads + + # Set the cached sin/cos to the maximum of 2048 and max seq len. 
+ # This will be eliminated further with online rotary embedding calculation. + rope_cache_len = te.var("rope_cache_len", "int64") + self.cos_cached = nn.Parameter( + (rope_cache_len, head_dim), dtype=config.dtype, name="cos_cached" + ) + self.sin_cached = nn.Parameter( + (rope_cache_len, head_dim), dtype=config.dtype, name="sin_cached" + ) + ############ End ############ + + def forward( + self, + inputs: relax.Expr, + cache_len_shape: relax.Expr, + kv_seq_len_shape: relax.Expr, + cache_offset_shape: relax.Expr, + past_key_values: relax.Expr, + ): + hidden_states, key_value_cache = self.model( + inputs=inputs, + cache_len_shape=cache_len_shape, + kv_seq_len_shape=kv_seq_len_shape, + cache_offset_shape=cache_offset_shape, + past_key_values=past_key_values, + ) + + def te_slicing(x: te.Tensor): + return te.compute( + shape=(1, 1, x.shape[-1]), + fcompute=lambda i, j, k: x[i, x.shape[1] - 1, k], + name="slice", + ) + + logits = self.lm_head(nn.emit_te(te_slicing, hidden_states, primfunc_name_hint="slice")) + if logits.struct_info.dtype != "float32": + logits = nn.emit(relax.op.astype(logits, "float32")) + + return logits, key_value_cache + + +def get_param_quant_kind(name: str, param_info: relax.TensorStructInfo) -> ParamQuantKind: + if "embed_tokens" in name: + return ParamQuantKind.embedding_table + elif "lm_head.weight" in name: + return ParamQuantKind.final_fc_weight + elif param_info.ndim == 2 and name.endswith(".weight"): + return ParamQuantKind.linear_weight + else: + return ParamQuantKind.others + + +def create_embed_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: MistralConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "embed" + + bsz = 1 + seq_len = tvm.tir.SizeVar("n", "int64") + with bb.function(func_name): + model = MistralEmbedTokensWrapper(config, tvm.tir.SizeVar("vocab_size", "int64")) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") + with bb.dataflow(): + inputs_embeds = model(input_ids) + params = [input_ids] + model.parameters() + gv = bb.emit_output(inputs_embeds) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 1)) + + +def create_encoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: MistralConfig, + quant_scheme: QuantizationScheme, + sep_embed: bool = False, +) -> None: + func_name = "prefill_with_embed" if sep_embed else "prefill" + + bsz = 1 + seq_len = tvm.tir.SizeVar("n", "int64") # number of tokens for the input + rolling_cache_len = tvm.tir.SizeVar( + "c", "int64" + ) # rolling_cache_len captures number of elements in the cache + kv_seq_len = tvm.tir.SizeVar( + "k", "int64" + ) # kv_seq_len captures number of elements in cache + seq_len + cache_offset = tvm.tir.SizeVar("o", "int64") # slidinf window kv cache offset + + hidden_size = config.hidden_size + with bb.function(func_name): + model = MistralForCausalLM(config, tvm.tir.SizeVar("vocab_size", "int64"), sep_embed) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + inputs = ( + nn.Placeholder((bsz, seq_len, hidden_size), dtype=config.dtype, name="inputs_embeds") + if sep_embed + else nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") + ) + cache_len_shape = relax.Var( + "rolling_cache_len", relax.ShapeStructInfo((rolling_cache_len,)) + ) + kv_seq_len_shape = 
relax.Var("kv_seq_len", relax.ShapeStructInfo((kv_seq_len,))) + cache_offset_shape = relax.Var("cache_offset", relax.ShapeStructInfo((cache_offset,))) + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo( + [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] + ), + ) + with bb.dataflow(): + logits, key_value_cache = model( + inputs, + cache_len_shape, + kv_seq_len_shape, + cache_offset_shape, + past_key_values=past_key_values, + ) + params = [ + inputs, + cache_len_shape, + kv_seq_len_shape, + cache_offset_shape, + past_key_values, + ] + model.parameters() + gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 5)) + + +def create_decoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: MistralConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "decode" + + bsz = 1 + rolling_cache_len = tvm.tir.SizeVar( + "c", "int64" + ) # rolling_cache_len captures number of elements in the cache + kv_seq_len = tvm.tir.SizeVar( + "k", "int64" + ) # kv_seq_len captures number of elements in cache + seq_len + cache_offset = tvm.tir.SizeVar("o", "int64") # sliding window kv cache offset + + with bb.function(func_name): + model = MistralForCausalLM(config, tvm.tir.SizeVar("vocab_size", "int64")) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids = nn.Placeholder((bsz, 1), dtype="int32", name="input_ids") + cache_len_shape = relax.Var( + "rolling_cache_len", relax.ShapeStructInfo((rolling_cache_len,)) + ) + kv_seq_len_shape = relax.Var("kv_seq_len", relax.ShapeStructInfo((kv_seq_len,))) + cache_offset_shape = relax.Var("cache_offset", relax.ShapeStructInfo((cache_offset,))) + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo( + [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] + ), + ) + with bb.dataflow(): + logits, key_value_cache = model( + input_ids, + cache_len_shape, + kv_seq_len_shape, + cache_offset_shape, + past_key_values=past_key_values, + ) + params = [ + input_ids, + cache_len_shape, + kv_seq_len_shape, + cache_offset_shape, + past_key_values, + ] + model.parameters() + gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 5)) + + +def create_kv_cache_func(bb: relax.BlockBuilder, config: MistralConfig) -> None: + num_key_value_heads = config.get_num_key_value_heads() // config.num_shards + init_shape = relax.ShapeExpr( + ( + config.sliding_window, + num_key_value_heads, + config.hidden_size // config.num_attention_heads, # head_dim + ) + ) + with bb.function("create_kv_cache", []): + with bb.dataflow(): + zeros = bb.emit(relax.op.zeros(init_shape, config.dtype)) + caches = [] + f_kv_cache_create = relax.extern("vm.builtin.attention_kv_cache_create") + for _ in range(config.num_hidden_layers * 2): + caches.append( + bb.emit( + relax.call_pure_packed( + f_kv_cache_create, + args=[zeros, init_shape, relax.PrimValue(0)], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + ) + gv = bb.emit_output(caches) + bb.emit_func_output(gv) + + +def create_softmax_func(bb: relax.BlockBuilder, config: MistralConfig) -> None: + with bb.function("softmax_with_temperature"): + logits = nn.Placeholder( + (1, 1, tvm.tir.SizeVar("vocab_size", "int64")), 
dtype="float32", name="logits" + ) + temperature = nn.Placeholder((), dtype="float32", name="temperature") + with bb.dataflow(): + div = bb.emit(relax.op.divide(logits, temperature)) + softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) + gv = bb.emit_output(softmax) + bb.emit_func_output(gv, [logits, temperature]) + + +def get_model(args, hf_config): + model_name = args.model + dtype = args.quantization.model_dtype + sep_embed = args.sep_embed + assert not sep_embed, "Mistral does not support separate embedding." + + if args.sliding_window != -1: + hf_config["sliding_window"] = args.sliding_window + if args.attention_sink_size > 0: + hf_config["attention_sink_size"] = args.attention_sink_size + if args.max_seq_len != -1: + hf_config["max_sequence_length"] = args.max_seq_len + + config = MistralConfig( + **hf_config, + dtype=dtype, + combine_matmul=True, + num_shards=args.num_shards, + build_model_only=args.build_model_only, + ) + + # prefill chunk size same as sliding window by default + if args.prefill_chunk_size < 1: + args.prefill_chunk_size = config.sliding_window - config.attention_sink_size + + assert config.sliding_window != -1 + assert args.prefill_chunk_size <= config.sliding_window - config.attention_sink_size + + param_manager = ParamManager() + bb = relax.BlockBuilder() + + create_encoding_func(bb, param_manager, config, args.quantization, sep_embed) + create_decoding_func(bb, param_manager, config, args.quantization) + create_kv_cache_func(bb, config) + create_softmax_func(bb, config) + create_metadata_func( + bb, + model_name=model_name, + max_window_size=config.max_sequence_length, + stop_tokens=[2], + add_prefix_space=False, + sliding_window=config.sliding_window, + prefill_chunk_size=args.prefill_chunk_size, + ) + + mod = bb.get() + for gv in mod.functions: + func = mod[gv] + if isinstance(func, relax.Function): + mod[gv] = func.with_attr( + "tir_var_upper_bound", + { + "n": args.prefill_chunk_size, + "c": config.sliding_window, + "k": config.sliding_window + args.prefill_chunk_size, + }, + ) + + if args.build_model_only: + return mod, param_manager, None, config + + def f_convert_pname_fwd(pname: str) -> List[str]: + if not config.combine_matmul: + return [pname] + + qkv_str = "query_key_value_proj" + gate_up_str = "gate_up_proj" + if qkv_str in pname: + return [ + pname.replace(qkv_str, "q_proj"), + pname.replace(qkv_str, "k_proj"), + pname.replace(qkv_str, "v_proj"), + ] + elif gate_up_str in pname: + return [ + pname.replace(gate_up_str, "gate_proj"), + pname.replace(gate_up_str, "up_proj"), + ] + else: + return [pname] + + def f_convert_param_bkwd(torch_pname: str, torch_param): + if not config.combine_matmul: + return [(torch_pname, torch_param.astype(dtype))] + + combined_layers = ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"] + if any([name in torch_pname for name in combined_layers]): + return None + return [(torch_pname, torch_param.astype(dtype))] + + def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): + # Expected to enter this function only for the combined linear matmul weights. + # Other weights are supposed to be loaded in `f_convert_param_bkwd` since + # each other relax param has a unique corresponding torch param. + if not config.combine_matmul: + # When matmul combination is not turned on, each relax param has a unique + # corresponding torch param, and this function is not expected to be entered. 
+ raise NotImplementedError( + "Matmul combination is not turned on, and the function " + "is not expected to be entered" + ) + hidden_size = config.hidden_size + head_dim = config.hidden_size // config.num_attention_heads + + if "query_key_value_proj" in relax_pname: + q_heads = config.num_attention_heads + kv_heads = config.get_num_key_value_heads() + q, k, v = torch_params + assert q.shape == (q_heads * head_dim, hidden_size) + assert k.shape == (kv_heads * head_dim, hidden_size) + assert v.shape == (kv_heads * head_dim, hidden_size) + qkv = np.concatenate([q, k, v], axis=0).astype(dtype) + return qkv + if "gate_up_proj" in relax_pname: + gate, up = torch_params + gate_up = np.concatenate([gate, up], axis=0).astype(dtype) + return gate_up + raise ValueError("Unexpected param loading") + + param_manager.set_param_loading_func( + args.model_path, + args.use_safetensors, + f_convert_pname_fwd, + f_convert_param_bkwd, + f_compute_relax_param, + ) + + device = tvm.cpu() + param_list = [None] * param_manager.nparam_to_load + + head_dim = config.hidden_size / config.num_attention_heads + inv_freq = 1.0 / (config.rope_theta ** (np.arange(0, head_dim, 2).astype("float32") / head_dim)) + + # The following cos/sin values can be removed but **are kept for compatibility issues**. + t = np.arange(2048, dtype=inv_freq.dtype) + freqs = np.einsum("i,j->ij", t, inv_freq) + emb = np.concatenate((freqs, freqs), axis=-1) + param_list[-2] = tvm.nd.array(np.cos(emb).astype(config.dtype), device) + param_list[-1] = tvm.nd.array(np.sin(emb).astype(config.dtype), device) + + return mod, param_manager, param_list, config diff --git a/mlc_llm/relax_model/modules.py b/mlc_llm/relax_model/modules.py new file mode 100644 index 0000000..e506938 --- /dev/null +++ b/mlc_llm/relax_model/modules.py @@ -0,0 +1,280 @@ +# pylint: disable=missing-docstring,invalid-name +from typing import Dict, List, Tuple, Optional + +import numpy as np +from tvm import relax, te, tir +from tvm.relax.op import matmul, permute_dims, reshape, take +from tvm.relax.op.nn import layer_norm +from tvm.relax.testing import nn +from tvm.runtime.ndarray import array as tvm_array + + +class ModuleList(nn.Module): + def __init__(self, modules: List[nn.Module]): + self.modules = modules + + def __iter__(self): + return iter(self.modules) + + def __getitem__(self, idx): + return self.modules[idx] + + def __len__(self): + return len(self.modules) + + def forward(self, x: relax.Expr) -> relax.Var: + for module in self.modules: + x = module(x) + return x + + +class Linear(nn.Module): + def __init__( + self, + in_features, + out_features, + dtype, + bias=True, + out_dtype=None, + ): + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter( + (out_features, in_features), + dtype=dtype, + name="linear_weight", + ) + if bias: + self.bias = nn.Parameter( + (out_features,), + dtype=dtype if out_dtype is None else out_dtype, + name="linear_bias", + ) + else: + self.bias = None + self.dtype = dtype + self.out_dtype = out_dtype + + def forward(self, x: relax.Expr) -> relax.Var: + x = nn.emit(x) + weight = permute_dims(self.weight, axes=None) + x = nn.emit(matmul(x, weight, out_dtype=self.out_dtype)) + if self.bias is not None: + x = nn.emit(x + self.bias) + return x + + +class Embedding(nn.Module): + def __init__(self, num_embeddings, embedding_dim, dtype): + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.weight = nn.Parameter( + (num_embeddings, embedding_dim), dtype=dtype, name="weight" + ) + 
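+ # Embedding lookup: 1-D inputs index the table directly; higher-rank inputs are + # flattened, looked up with take(), and reshaped back to [*input_shape, embedding_dim].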
+ def forward(self, x: relax.Expr) -> relax.Var: + ndim = x.struct_info.ndim + if ndim == 1: + return nn.emit(take(self.weight, x, axis=0)) + x_shape = x.struct_info.shape.values + emb_size = self.weight.struct_info.shape.values[-1] + x = nn.emit(reshape(x, shape=[-1])) + embedding = nn.emit(take(self.weight, x, axis=0)) + return nn.emit(reshape(embedding, [*x_shape, emb_size])) + + +class LayerNorm(nn.Module): + def __init__( + self, + hidden_size, + dtype, + eps=1e-5, + ): + super().__init__() + self.eps = eps + self.weight = nn.Parameter((hidden_size,), dtype="float32", name="weight") + self.bias = nn.Parameter((hidden_size,), dtype="float32", name="bias") + + def forward(self, x: relax.Expr) -> relax.Var: + if x.struct_info.dtype != "float32": + x = nn.emit(relax.op.astype(x, "float32")) + x = nn.emit( + layer_norm( + x, + gamma=self.weight, + beta=self.bias, + axes=-1, + epsilon=self.eps, + ) + ) + return x + + +class RotaryEmbedding(nn.Module): + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + position_embedding_base: int, + max_sequence_length: int, + rotary_pct: Optional[float] = None, + rotary_dim: Optional[int] = None, + swizzle_style: str = "neox", + dtype: str = "float32", + ): + super().__init__() + head_dim = hidden_size // num_attention_heads + if rotary_dim is not None: + rotary_ndim = rotary_dim + else: + rotary_ndim = int(head_dim * rotary_pct) + inv_freq = 1.0 / ( + position_embedding_base + ** (np.arange(0, rotary_ndim, 2).astype("float32") / rotary_ndim) + ) + t = np.arange(max_sequence_length, dtype=inv_freq.dtype) + freq = np.einsum("i,j->ij", t, inv_freq) + if swizzle_style == "neox": + emb = np.concatenate((freq, freq), axis=-1) + elif swizzle_style in ("gptj", "glm"): + emb = np.repeat(freq, repeats=2, axis=-1) + else: + raise KeyError("Unrecognized swizzle style {}".format(swizzle_style)) + self.swizzle_style = swizzle_style + self.rotary_ndim = rotary_ndim + self.cos_cached = relax.const(tvm_array(np.cos(emb).astype(dtype))) + self.sin_cached = relax.const(tvm_array(np.sin(emb).astype(dtype))) + + def get_x_swizzle(self, x, i_batch_size, i_seq_len, i_num_heads, i_head_dim): + if self.swizzle_style == "neox": + n_feat_half = self.rotary_ndim // 2 + return tir.Select( + i_head_dim < n_feat_half, + -x[ + i_batch_size, + i_seq_len, + i_num_heads, + i_head_dim + n_feat_half, + ], + x[ + i_batch_size, + i_seq_len, + i_num_heads, + i_head_dim - n_feat_half, + ], + ) + elif self.swizzle_style in ("gptj", "glm"): + return tir.Select( + i_head_dim % 2 == 0, + -x[i_batch_size, i_seq_len, i_num_heads, i_head_dim + 1], + x[i_batch_size, i_seq_len, i_num_heads, i_head_dim - 1], + ) + else: + raise KeyError("Unrecognized swizzle style: {}.".format(self.swizzle_style)) + + def forward( + self, + q: relax.Expr, + k: relax.Expr, + offset: relax.Expr, + ) -> Tuple[relax.Expr, relax.Expr]: + def rotary_embedding(x, cos, sin, offset): + def compute( + i_batch_size, + i_seq_len, + i_num_heads, + i_head_dim, + ): + return tir.Select( + i_head_dim < self.rotary_ndim, + cos[ + offset + i_seq_len, + i_head_dim, + ] + * x(i_batch_size, i_seq_len, i_num_heads, i_head_dim) + + sin[ + offset + i_seq_len, + i_head_dim, + ] + * self.get_x_swizzle( + x, i_batch_size, i_seq_len, i_num_heads, i_head_dim + ), + x(i_batch_size, i_seq_len, i_num_heads, i_head_dim), + ) + + return te.compute(x.shape, compute, name="rotary") + + cos, sin = self.cos_cached, self.sin_cached + q_embed = nn.emit_te( + rotary_embedding, + q, + cos, + sin, + offset, + 
primfunc_name_hint="rotary_embedding", + ) + k_embed = nn.emit_te( + rotary_embedding, + k, + cos, + sin, + offset, + primfunc_name_hint="rotary_embedding", + ) + return q_embed, k_embed + + +class TransformImage(nn.Module): + def __init__(self, dtype: str, in_chans: int = 4): + self.in_chans = in_chans + self.dtype = dtype + + # used in normalization, assume channels are RGB + self.r_mean = relax.const(0.48145466, "float32") + self.g_mean = relax.const(0.4578275, "float32") + self.b_mean = relax.const(0.40821073, "float32") + self.r_std = relax.const(0.26862954, "float32") + self.g_std = relax.const(0.26130258, "float32") + self.b_std = relax.const(0.27577711, "float32") + + def forward(self, input: relax.Expr) -> relax.Expr: + from tvm.relax.op import astype, concat, permute_dims, strided_slice + + assert input.struct_info.ndim == 4 + # perform torch.ToTensor on input of shape (bs, height, width, in_chans) + input = permute_dims(input, [0, 3, 1, 2]) + x = astype(input, "float32") / relax.const(255.0, "float32") + r = strided_slice(x, axes=[1], begin=[0], end=[1]) + g = strided_slice(x, axes=[1], begin=[1], end=[2]) + b = strided_slice(x, axes=[1], begin=[2], end=[3]) + + # normalize rgba to rgb + if self.in_chans == 4: + a = strided_slice(x, axes=[1], begin=[3], end=[4]) + r /= a + g /= a + b /= a + + # perform torch.Normalize + r = (r - self.r_mean) / self.r_std + g = (g - self.g_mean) / self.g_std + b = (b - self.b_mean) / self.b_std + res = concat([r, g, b], axis=1) + res = astype(res, self.dtype) + + return res + + +def named_parameters(model: nn.Module) -> Dict[str, nn.Parameter]: + params: Dict[str, nn.Parameter] = {} + for name, module in model.__dict__.items(): + if isinstance(module, nn.Parameter): + params[name] = module + elif isinstance(module, ModuleList): + for i, m in enumerate(module): + for param_name, param in named_parameters(m).items(): + params[f"{name}.{i}.{param_name}"] = param + elif isinstance(module, nn.Module): + for param_name, param in named_parameters(module).items(): + params[f"{name}.{param_name}"] = param + return params diff --git a/mlc_llm/relax_model/param_manager.py b/mlc_llm/relax_model/param_manager.py new file mode 100644 index 0000000..f776db3 --- /dev/null +++ b/mlc_llm/relax_model/param_manager.py @@ -0,0 +1,1209 @@ +import json +import os +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +import tvm +from torch import Tensor as torchTensor +from tvm import relax, tir +from tvm._ffi.runtime_ctypes import Device +from tvm.relax.analysis import remove_all_unused +from tvm.relax.expr import Expr, Function, Var +from tvm.relax.expr_functor import PyExprMutator, mutator +from tvm.relax.testing import nn + +from .. import quantization +from .modules import named_parameters +from ..transform import ReorderTransformFunc + + +def f_default_compute_relax_param(relax_pname: str, torch_params: List[Any]) -> Any: + """The default `f_compute_relax_param` for ParamManager. + See ParamManager for more details. + """ + raise NotImplementedError() + + +class Parameter: + """The abstraction of weight tensors (e.g., linear layer weight, embedding + table, etc.) in a model. + + Attributes + ---------- + name : str + The name of the parameter. + The name of a weight is obtained by the `named_parameters()` method, similar to + PyTorch's `named_parameters()` function. + An example name is `model.layers.11.self_attn.k_proj.weight`. + In a model, the name is the **unique** identifier of a parameter.
+ + param_info_dict : Dict[str, relax.TensorStructInfo] + The shape and dtype of the parameter in each function. + The shape can be accessed by `param_info_dict[func_name].shape`, which is + a relax.ShapeExpr instance. + And the dtype can be accessed by `param_info_dict[func_name].dtype`, + which is a Python string. + + quant_spec : quantization.QuantizationSpec + The quantization specification of this parameter. + It specifies the algorithm to quantize and dequantize this parameter (or + indicates that this parameter does not need quantization). + + shard_dim : Optional[int] + The dimension to be sharded. + + shard_strategy : Optional[str] + The strategy to shard the parameter. + """ + + name: str + param_info_dict: Dict[str, relax.TensorStructInfo] + quant_spec: quantization.QuantizationSpec + shard_dim: Optional[int] + shard_strategy: Optional[str] + + def __init__( + self, + name: str, + quant_spec: quantization.QuantizationSpec, + shard_dim: Optional[int], + shard_strategy: Optional[str], + ) -> None: + self.name = name + self.param_info_dict = dict() + self.quant_spec = quant_spec + self.shard_dim = shard_dim + self.shard_strategy = shard_strategy + + def register_func(self, func_name: str, param_info: relax.TensorStructInfo): + self.param_info_dict[func_name] = param_info + + @property + def param_info(self): + """Return the shape and dtype of the parameter (in some arbitrary function).""" + return next(iter(self.param_info_dict.values())) + + +class ParamManager: + """The model-wise data structure which contains the information of every + weight in the model and is in charge of applying quantization and dequantization + to the parameters at the entire model level. + + Attributes + ---------- + params : Dict[str, Parameter] + The mapping from parameter names to parameters. + + param_names : List[str] + The name list of all the parameters. + To enforce a unique order of all the parameters for determinism, the + parameter names are kept in the list, and the parameter order is + uniquely determined by the parameter name list. + + func_raw_param_map : Dict[relax.Var, Tuple[str, Parameter]] + The mapping from each relax.Var that denotes a weight parameter to the + name of the function the var is in (e.g., "prefill" or "decode"), and + the Parameter it corresponds to. + This mapping is used for applying quantization transformation to the + Relax functions (e.g., the "prefill", "decode", etc.) in the model. + + param2qrange : Dict[Parameter, range] + The mapping from each parameter to the range of its quantized tensors + in the list of quantized tensors of all parameters. + Each parameter may be quantized into one or more tensors. + For example, assume we have parameters `p0`, `p1`, `p2`. + - `p0` is quantized into `t0_0`, `t0_1`, + - `p1` is quantized into `t1_0`, and + - `p2` is quantized into `t2_0`, `t2_1` and `t2_2`. + Then the list of all quantized tensors is `[t0_0, t0_1, t1_0, t2_0, t2_1, t2_2]`, + and the dict `param2qrange` is + `{p0: range(0, 2), p1: range(2, 3), p2: range(3, 6)}`. + + f_convert_pname_fwd : Callable[[str], List[str]] + The function which converts Relax parameter name (ours) to torch's + parameter names, suggesting "to load this Relax parameter, which torch + parameter(s) are needed". + - Usually, the function maps a name to itself. For example, in LLaMA we + map `lm_head.weight` to itself, as the parameter has the same name on both + the Relax side and the torch side. + - In some cases we map a name to multiple names.
For example, if we + support combined QKV computing when the torch side separates them, on + the Relax side we only have one QKV weight, while on the torch side we have + one weight for each of Q, K, V. In this case, we map one name to three + names. + - In some cases we map a name to a single name which is other than + itself. This can happen either when the Relax nn.Module has different + param names than the torch implementation, so we need to map names + for connection, or when a Relax parameter is computed from a torch + parameter. For example, if the torch implementation supports combined + QKV while the Relax one does not, we need to compute the Relax parameter + from the torch parameter. In this case we map the Relax parameter + name to the torch parameter name. + + f_convert_param_bkwd : Callable[[str, Any], Optional[List[Tuple[str, Any]]]] + The function which converts torch parameter and param name back to + Relax parameters with names. `Any` here stands for numpy.ndarray. + - Usually, the function just returns the input torch parameter and + the corresponding Relax parameter's name. + - In some cases, we return multiple Relax parameters. For example, if + the torch implementation supports combined QKV while the Relax one does + not, the function takes torch's combined QKV weight, and returns the + separated Q, K, V weights with their corresponding names. + - In some cases we return `None`. This happens when the input torch + parameter itself does not determine any Relax parameter. For example, + if we support combined QKV computing when the torch side separates them, + we return `None` here for the single Q, K, V weights, as by only having + a Q (or K, V) weight we cannot compute the combined QKV weight. + + f_compute_relax_param : Callable[[str, List[Any]], Any] + The function which computes a Relax parameter from a list of torch + parameters. `Any` here stands for numpy.ndarray. In the case when one + Relax parameter is computed from multiple torch parameters, this + function is used. + For example, if we support combined QKV computing when the torch side + separates them, we use this function to combine the torch's Q, K, V + weights into one. + In the usual case, this function is not needed and by default it is + implemented by raising `NotImplementedError` (see f_default_compute_relax_param). + + model_path : str + The path of the Hugging Face model on disk. + + use_safetensors: bool + Whether to use `.safetensors` instead of `.bin` to load the model. + + safetensors_load_func: Callable[[Union[str, os.PathLike], str], Dict[str, torch.Tensor]] + A reference to the function `load_file` imported from `safetensors.torch`. + The goal is to prevent repeatedly importing it in a TVM-registered function. + + pidx2pname : Dict[int, str] + The dictionary from each Relax parameter's index in `param_names` to + the Relax parameter's name. + + torch_pname2binname : Dict[str, str] + The dictionary from each torch parameter's name to the name of the + binary shard where the torch parameter is saved.
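+ + As an illustration, a model that fuses the Q/K/V projection on the Relax + side while the torch checkpoint keeps the projections separate could provide + conversion functions along these lines (the `query_key_value_proj` name here + follows the Mistral model definition in this diff and is only illustrative): + + .. code-block:: python + + def f_convert_pname_fwd(pname: str) -> List[str]: + # One fused Relax weight needs three torch weights. + if "query_key_value_proj" in pname: + return [ + pname.replace("query_key_value_proj", "q_proj"), + pname.replace("query_key_value_proj", "k_proj"), + pname.replace("query_key_value_proj", "v_proj"), + ] + return [pname] + + def f_convert_param_bkwd(torch_pname, torch_param): + # A lone Q/K/V torch weight cannot determine the fused weight by itself. + if any(n in torch_pname for n in ("q_proj", "k_proj", "v_proj")): + return None + return [(torch_pname, torch_param)]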
+ """ + + params: Dict[str, Parameter] + param_names: List[str] + func_raw_param_map: Dict[relax.Var, Tuple[str, Parameter]] + param2qrange: Dict[Parameter, range] + + qspec_updater_classes: List[quantization.QuantSpecUpdater] + + nparam_to_load: int + f_convert_pname_fwd: Callable[[str], List[str]] + f_convert_param_bkwd: Callable[[str, Any], Optional[List[Tuple[str, Any]]]] + f_compute_relax_param: Callable[[str, List[Any]], Any] + f_run_prequantize: Optional[Callable[[str], str]] + + model_path: str + use_safetensors: bool + safetensors_load_func: Callable[[Union[str, os.PathLike], str], Dict[str, torchTensor]] + pidx2pname: Dict[int, str] + torch_pname2binname: Dict[str, str] + + def __init__(self) -> None: + self.params = {} + self.param_names = [] + self.params_in_func = {} + + self.func_raw_param_map = {} + self.param2qrange = None + + self.nparam_to_load = None + self.f_convert_pname_fwd = None + self.f_convert_param_bkwd = None + self.f_compute_relax_param = None + self.f_run_prequantize = None + + self.qspec_updater_classes = [] + + def register_params( + self, + model: nn.Module, + func_name: str, + quantization_scheme: quantization.QuantizationScheme, + f_get_param_quant_kind: Callable[ + [str, relax.TensorStructInfo], quantization.ParamQuantKind + ], + ) -> None: + """Register the parameters of the input model (within the context of the + input function) in the parameter manager. + + Parameters + ---------- + model : nn.Module + The input model whose parameters are registered. + + func_name : str + The name of the function the input model is in. + For example, the "prefill" function or the "decode" function. + + quantization_scheme : quantization.QuantizationScheme + The quantization scheme of the input model, which describes how + to quantize the model. + + f_get_param_quant_kind: Callable[[str, relax.TensorStructInfo], quantization.ParamQuantKind] + A function which takes the name and StructInfo (effectively shape + and dtype) of a parameter, and returns which quantization kind this + parameter uses. + This is used for applying quantization to the parameters. + """ + if quantization_scheme.qspec_updater_class is not None: + self.qspec_updater_classes.append(quantization_scheme.qspec_updater_class) + if quantization_scheme.f_convert_param_bkwd is not None: + self.f_convert_param_bkwd = quantization_scheme.f_convert_param_bkwd + if quantization_scheme.f_compute_relax_param is not None: + self.f_compute_relax_param = quantization_scheme.f_compute_relax_param + if quantization_scheme.f_run_prequantize is not None: + self.f_run_prequantize = quantization_scheme.f_run_prequantize + + self.params_in_func[func_name] = [] + # For each parameter in the input model, get its quantization kind and + # register the parameter with its name and quantization kind. 
+ for name, relax_param in named_parameters(model).items(): + quant_kind = f_get_param_quant_kind(name, relax_param.struct_info) + param = self._register_param( + name, + relax_param, + getattr(quantization_scheme, quant_kind.name), + func_name, + relax_param.__dict__.get("shard_dim", None), + relax_param.__dict__.get("shard_strategy", None), + ) + + self.params_in_func[func_name].append(param) + + def run_pre_quantize(self, model_path: str): + if self.f_run_prequantize is not None: + model_path = self.f_run_prequantize(model_path) + + self.model_path = model_path + return model_path + + def init_torch_pname_to_bin_name(self, use_safetensors: bool): + assert hasattr(self, "model_path"), ( + "Must call either set_param_loading_func or run_pre_quantize " + "before init_torch_pname_to_bin_name" + ) + + if self.pidx2pname: + mapping = load_torch_pname2binname_map( + self.model_path, + use_safetensors, + set(self.pidx2pname.values()), + self.f_convert_pname_fwd, + ) + else: + mapping = {} + + self.torch_pname2binname = mapping + + def set_param_loading_func( + self, + model_path: str, + use_safetensors: bool, + f_convert_pname_fwd: Callable[[str], List[str]] = lambda pname: [pname], + f_convert_param_bkwd: Callable[ + [str, Any], Optional[List[Tuple[str, Any]]] + ] = lambda pname, torch_param: [(pname, torch_param)], + f_compute_relax_param: Callable[[str, List[Any]], Any] = f_default_compute_relax_param, + *, + no_lazy_param_loading: bool = False, + ) -> None: + """Set the parameter loading functions. + + Parameters + ---------- + model_path : str + The path of the Hugging Face model on disk. + + use_safetensors : bool + Whether to use ``.safetensors`` instead of ``.bin`` to load model. + + f_convert_pname_fwd : Callable[[str], List[str]] + The function which converts Relax parameter name (ours) to torch's + parameter names. See the document of ParamManager for more details. + + f_convert_param_bkwd : Callable[[str, Any], Optional[List[Tuple[str, Any]]]] + The function which converts torch parameter and param name back to + Relax parameters with names. `Any` here stands for numpy.ndarray. + See the document of ParamManager for more details. + + f_compute_relax_param : Callable[[str, List[Any]], Any] + The function which computes a Relax parameter from a list of torch + parameters. `Any` here stands for numpy.ndarray. + See the document of ParamManager for more details. + + no_lazy_param_loading : bool + A boolean indicating that no lazy parameter loading from torch is needed. + This needs to be set as True when all the model weights are loaded + at the time of constructing the model. 
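+ + A typical call site, mirroring how the Mistral model definition in this diff + invokes it (the argument names are those used there and are illustrative): + + .. code-block:: python + + param_manager.set_param_loading_func( + args.model_path, + args.use_safetensors, + f_convert_pname_fwd, + f_convert_param_bkwd, + f_compute_relax_param, + )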
+ """ + self.f_convert_pname_fwd = f_convert_pname_fwd + if self.f_convert_param_bkwd is None: + self.f_convert_param_bkwd = f_convert_param_bkwd + if self.f_compute_relax_param is None: + self.f_compute_relax_param = f_compute_relax_param + + self.model_path = model_path + self.use_safetensors = use_safetensors + if self.use_safetensors: + # Use a pointer here to prevent repeated import in tvm registered function + from safetensors.torch import ( + load_file, # pylint: disable=import-outside-toplevel + ) + + def load_safetensors_func(*args): + params = load_file(*args) + for name, param in params.items(): + dtype = str(param.dtype) + if dtype == "torch.bfloat16": + param = param.float() + params[name] = param + return params + + self.safetensors_load_func = load_safetensors_func + + pnames_to_load = [] + for param_name in self.param_names: + param = self.params[param_name] + loaded_names, _ = param.quant_spec.get_loaded_tensor_info(param_name, param.param_info) + pnames_to_load += loaded_names + + self.nparam_to_load = len(pnames_to_load) + if not no_lazy_param_loading: + self.pidx2pname = {pidx: pname for pidx, pname in enumerate(pnames_to_load)} + else: + self.pidx2pname = dict() + + def transform_dequantize(self) -> tvm.ir.transform.Pass: + """Apply dequantization to the input IRModule. + + Parameters + ---------- + mod : tvm.IRModule + The input IRModule to be applied dequantization. + The IRModule contains all the constructed Relax functions + (e.g., the "prefill"/"decode" functions) and is expected to + have all of its parameters registered in the ParamManager. + + Returns + ------- + updated_mod : tvm.IRModule + The IRModule updated with the dequantization computation. + """ + + @tvm.ir.transform.module_pass(opt_level=0, name="ParamManager.transform_dequantize") + def transform_func(mod: tvm.IRModule, _context) -> tvm.IRModule: + # For each Relax function in the input IRModule (e.g., "prefill"), + # we create its input relax.Var of all the quantized data, and + # store the mapping from function name to the var. + func_name_to_quantized_params: Dict[str, List[relax.Var]] = {} + + for gv, func in mod.functions.items(): + if isinstance(func, relax.Function) and func.attrs and "num_input" in func.attrs: + func_name_to_quantized_params[gv.name_hint] = self.get_quantized_params( + gv.name_hint + ) + + # Cache mapping to avoid duplicate dequantization. + dequantized_cache: Dict[relax.Var, relax.Var] = {} + + # Define a var replacement function for applying dequantization. + def f_replace(var: relax.Var, bb: relax.BlockBuilder) -> relax.Var: + if var in dequantized_cache: + return dequantized_cache[var] + assert var in self.func_raw_param_map + + func_name, param = self.func_raw_param_map[var] + quantized_params = func_name_to_quantized_params[func_name] + relevant_quantized_params = [quantized_params[i] for i in self.param2qrange[param]] + + dequantized = self._dequantize(param, relevant_quantized_params, bb, func_name) + + dequantized_cache[var] = dequantized + return dequantized + + # Create the function mutator for applying dequantization. + replacer = ParamReplacer(mod, func_name_to_quantized_params, f_replace) + # Update the input IRModule with dequantization. 
+ mod = replacer.transform() + + return mod + + return transform_func + + def get_quantized_params(self, func_name: str) -> List[relax.Var]: + quantized_params: List[relax.Var] = [] + + bb = relax.BlockBuilder() + with bb.function("main", []): + self.param2qrange = dict() + + for name in self.param_names: + param = self.params[name] + param_info = None + if func_name in param.param_info_dict: + param_info = param.param_info_dict[func_name] + else: + param_info = relax.TensorStructInfo( + tvm.ir.load_json(tvm.ir.save_json(param.param_info.shape)), + param.param_info.dtype, + ) + + loaded_tensor_names, loaded_tensor_info = param.quant_spec.get_loaded_tensor_info( + name, param_info + ) + + provided_tensor_vars: List[relax.Var] = [ + relax.Var(name, sinfo) + for name, sinfo in zip(loaded_tensor_names, loaded_tensor_info) + ] + + # Get the quantization function of this parameter. + f_quantize = param.quant_spec.get_quantize_func(param_info) + if f_quantize is None: + # If the parameter does not have a quantization function, either it + # does not need quantization or it is pre-quantized. + self.param2qrange[param] = range( + len(quantized_params), + len(quantized_params) + len(provided_tensor_vars), + ) + quantized_params.extend(provided_tensor_vars) + else: + # If the parameter has a quantization function, it is not expected + # to be pre-quantized. + assert len(provided_tensor_vars) == 1, ( + "A parameter with quantization function is not expected " + "to be pre-quantized." + ) + + # Apply the quantization function. + quantized_data = bb.normalize(f_quantize(bb, provided_tensor_vars)) + if isinstance(quantized_data.struct_info, relax.TupleStructInfo): + fields = quantized_data.struct_info.fields + n_tensor = len(fields) + assert n_tensor > 1 + # Record the range of quantized tensors of this parameter. + self.param2qrange[param] = range( + len(quantized_params), + len(quantized_params) + n_tensor, + ) + # Collect the quantized tensors to return. + quantized_params.extend( + relax.Var(f"{name}.{field.dtype}.{i}", field) + for i, field in enumerate(fields) + ) + + else: + field = quantized_data.struct_info + assert isinstance(field, relax.TensorStructInfo) + self.param2qrange[param] = range( + len(quantized_params), len(quantized_params) + 1 + ) + quantized_params.append(relax.Var(f"{name}.{field.dtype}", field)) + bb.emit_func_output(relax.const(0, "int64")) + + return quantized_params + + def get_param_get_item( + self, device: Device, model_params: List[Optional[tvm.nd.NDArray]] = [] + ) -> Callable: + """A wrapper function which returns the `get_item` + functions for parameter lazy loading. + + The return value of this function is intended to be registered + as `"get_item"`, for use in a module built with + `LazyTransformParams`. + + .. code-block:: python + + get_item = manager.get_param_get_item(tvm.cuda()) + tvm.register_func(func_name="get_item", f=get_item, override=True) + compiled_function() + + Parameters + ---------- + device : Device + + The device onto which tensor parameters should be loaded. + + model_params : List[Optional[tvm.nd.NDArray]] + + Any pre-loaded model parameters. For parameter at index + `i`, if `model_params[i]` already contains an array, that + array will be returned from `get_item`. Otherwise, the + parameter will be loaded either from disk, or from an + internal cache. + + Returns + ------- + get_item: Callable[[int], tvm.nd.NDArray] + + A function that accepts an index, and returns the tensor + parameter located at that index, loaded onto `device`. 
+ + """ + import torch # pylint: disable=import-outside-toplevel + + assert self.f_convert_pname_fwd is not None + assert self.f_convert_param_bkwd is not None + assert self.f_compute_relax_param is not None + pname2pidx: Dict[str, int] = {pname: pidx for pidx, pname in self.pidx2pname.items()} + + # The set of indices of loaded parameters, serving for + # robustness guarantee to avoid one parameter being loaded for + # multiple times. + loaded_idx_set: Set[int] = set() + + # The set of torch binary filenames, serving for robustness guarantee + # to avoid one torch binary file being loaded for multiple times. + loaded_torch_bins: Set[str] = set() + + # The set of cached Relax parameters. + cached_relax_params: Dict[int, tvm.nd.NDArray] = {} + + # The set of cached torch parameters. `Any` here stands for + # numpy.ndarray. + cached_torch_params: Dict[str, Any] = {} + + device_cpu = tvm.cpu() + + def fetch_torch_param(torch_param): + if str(torch_param.dtype) == "torch.bfloat16": + # Convert to float32 first. + return torch_param.detach().cpu().float().numpy() + else: + return torch_param.detach().cpu().numpy() + + def load_torch_params_from_bin(torch_binname: str): + torch_binpath = os.path.join(self.model_path, torch_binname) + torch_params = None + if self.use_safetensors: + torch_params = self.safetensors_load_func(torch_binpath) + else: + torch_params = torch.load( + torch_binpath, + map_location=torch.device("cpu"), + ) + torch_param_names = list(torch_params.keys()) + for torch_param_name in torch_param_names: + torch_param = fetch_torch_param(torch_params[torch_param_name]) + del torch_params[torch_param_name] + + relax_params = self.f_convert_param_bkwd(torch_param_name, torch_param) + if relax_params is not None: + for param_name, param in relax_params: + if param_name not in pname2pidx.keys(): + continue + pidx = pname2pidx[param_name] + assert pidx not in cached_relax_params + cached_relax_params[pidx] = tvm.nd.array(param, device_cpu) + else: + assert torch_param_name not in cached_torch_params + cached_torch_params[torch_param_name] = torch_param + del torch_param + + def get_item(i): + # If the weight is already provided by `model_params`, directly use it + # and no need to load from binary file. + if model_params and len(model_params) > i and model_params[i] is not None: + assert i not in cached_relax_params + return tvm.nd.array(model_params[i], device=device) + + # Otherwise, we load the weight from its corresponding binary file. 
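+ # Each torch binary shard is loaded at most once (tracked in loaded_torch_bins); + # parameters that are not needed yet are kept in cached_torch_params / cached_relax_params + # and are freed as soon as they have been combined or copied onto `device`. +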
assert i in self.pidx2pname + relax_pname = self.pidx2pname[i] + torch_pnames = self.f_convert_pname_fwd(relax_pname) + + if i not in cached_relax_params: + for torch_binname in [ + self.torch_pname2binname[torch_pname] for torch_pname in torch_pnames + ]: + if torch_binname in loaded_torch_bins: + continue + load_torch_params_from_bin(torch_binname) + loaded_torch_bins.add(torch_binname) + + if i not in cached_relax_params: + assert len(torch_pnames) > 1 + assert all(torch_pname in cached_torch_params for torch_pname in torch_pnames) + cached_relax_params[i] = self.f_compute_relax_param( + relax_pname, + [cached_torch_params[torch_pname] for torch_pname in torch_pnames], + ) + for torch_pname in torch_pnames: + del cached_torch_params[torch_pname] + + assert i in cached_relax_params + assert i not in loaded_idx_set + param_on_device = tvm.nd.array(cached_relax_params[i], device=device) + loaded_idx_set.add(i) + del cached_relax_params[i] + return param_on_device + + return get_item + + def get_param_set_item(self) -> Tuple[Callable, List[tvm.nd.NDArray]]: + """A wrapper function which returns the `set_item` + function for parameter lazy loading. + + The return value of this function is intended to be registered + as `"set_item"`, for use in a module built with + `LazyTransformParams`. + + .. code-block:: python + + set_item, loaded_params = manager.get_param_set_item() + tvm.register_func(func_name="set_item", f=set_item, override=True) + compiled_function() + # `loaded_params` is now fully populated + + Returns + ------- + set_item: Callable[[int, tvm.nd.NDArray], None] + + A function that accepts an index and the computed parameter + value at that index, and records it in `loaded_params`. + + loaded_params: List[tvm.nd.NDArray] + + A list of loaded parameters, populated by `set_item`. + When initially returned, this list is empty. After + executing the compiled function with + `LazyTransformParams`, `loaded_params` will be + populated. + """ + device_cpu = tvm.cpu() + loaded_params: List[tvm.nd.NDArray] = [] + + def set_item(i: int, computed_param: tvm.nd.NDArray): + if len(loaded_params) <= i: + loaded_params.extend([None for _ in range(i - len(loaded_params) + 1)]) + loaded_params[i] = tvm.nd.array(computed_param, device=device_cpu) + + return set_item, loaded_params + + #################### Below are internally called methods #################### + + def _register_param( + self, + name: str, + var: relax.Var, + quant_spec: quantization.QuantizationSpec, + func_name: str, + shard_dim: Optional[int], + shard_strategy: Optional[str], + ) -> Parameter: + """Register a single parameter in the parameter manager. + In most cases, this method is not directly used outside this class: + it is called by `register_params` above. + + Parameters + ---------- + name : str + The name of the parameter to register. + Name serves as the unique identifier of the parameter. + + var : relax.Var + The parameter relax.Var on the nn.Module side. + + quant_spec : quantization.QuantizationSpec + The quantization specification of the parameter. + + func_name : str + The name of the function the input var is in. + For example, the "prefill" function or the "decode" function. + + shard_dim : Optional[int] + The dimension along which the parameter is sharded. + + shard_strategy : Optional[str] + The strategy of sharding the parameter. + + Returns + ------- + param : Parameter + The registered Parameter. + """ + assert ( + var not in self.func_raw_param_map + ), "The input var is not supposed to be already registered."
+ assert isinstance( + var.struct_info.shape, relax.ShapeExpr + ), "The parameter to register is expected to have shape as a tuple" + + if name in self.params: + # When the input name appears in `self.params`, it means the input + # parameter has been previously registered in some other function. + # Thus, we check if the dtype, shape and the quantization specification + # of both sides are consistent. + param = self.params[name] + assert ( + param.quant_spec == quant_spec + ), "One parameter is expected to be quantized by single specification in all functions." + assert ( + param.param_info.dtype == var.struct_info.dtype + ), "Dtype mismatch of one parameter in two functions." + assert ( + param.param_info.ndim == var.struct_info.ndim + ), "Shape mismatch of one parameter in two functions." + for len0, len1 in zip(param.param_info.shape.values, var.struct_info.shape.values): + if isinstance(len0, tir.IntImm) and isinstance(len1, tir.IntImm): + assert ( + len0.value == len1.value + ), "Shape mismatch of one parameter in two functions." + else: + # Otherwise, the parameter is registered for the first time. + param = Parameter(name, quant_spec, shard_dim, shard_strategy) + self.params[name] = param + self.param_names.append(name) + + param.register_func(func_name, var.struct_info) + # Record the mapping from the input relax.Var to the function name and + # the parameter in the manager. + self.func_raw_param_map[var] = (func_name, param) + return param + + def _dequantize( + self, + param: Parameter, + qparams: List[relax.Var], + bb: relax.BlockBuilder, + func_name: str, + ) -> relax.Var: + """Applying dequantization to the input parameter. + This method is called by `transform_module` below, and is not + directly invoked outside the class. + + Parameters + ---------- + param : Parameter + The parameter whose quantized tensors are to be dequantized. + + qparams : List[relax.Var] + The relax.Var of the quantized tensors of all parameters in the model. + + Returns + ------- + The dequantized parameter, in the form of a relax.Var. + """ + # Get the dequantization function of this parameter. + f_dequantize = param.quant_spec.get_dequantize_func( + param_info=param.param_info_dict[func_name], + qparam_info=[qparam.struct_info for qparam in qparams], + ) + if f_dequantize is None: + # If the parameter does not have a dequantization function, its "quantized + # data" is expected to have only one element. + assert len(qparams) == 1, ( + "A parameter without dequantization function is expected not to have " + 'more than one "quantized data".' + ) + return qparams[0] + else: + # Apply the dequantization function. + return bb.emit(f_dequantize(bb, qparams)) + + def create_parameter_transformation(self, optimize_parameter_order: bool = True): + """Produce an IRModule that can transform the parameters + + Parameters + ---------- + optimize_parameter_order: bool + + If true, reorder the parameter transformations to + prioritize operations that use a currently-open file. If + false, transform the parameters in their default order. 
+ + Returns + ------- + tvm.IRModule + The transformation module + + """ + mod = _create_quantize_func(self) + if optimize_parameter_order: + mod = self.optimize_transform_param_order()(mod) + return mod + + def optimize_transform_param_order(self) -> tvm.transform.Pass: + """Produce an transformation that optimizes for minimal memory footprint + + Returns + ------- + tvm.transform.Pass + The transformation + """ + return ReorderTransformFunc( + self.pidx2pname, + self.torch_pname2binname, + self.f_convert_pname_fwd, + ) + + +@mutator +class ParamReplacer(PyExprMutator): + """The function mutator that updates the model with dequantization. + + Attributes + ---------- + mod : tvm.IRModule + The IRModule of the model to be updated. + + func_name_to_quantized_params : Dict[str, List[relax.Var]] + The mapping from each function name to its input var of quantized data tuple. + + f_replace : Callable[[relax.Var, relax.BlockBuilder], relax.Var] + The function for updating a previous parameter in functions with dequantization. + + param_set : Set[relax.Var] + The set of previous parameters (before applying quantization and dequantization) + in the relax functions. + """ + + mod: tvm.IRModule + func_name_to_quantized_params: Dict[str, List[relax.Var]] + f_replace: Callable[[relax.Var, relax.BlockBuilder], relax.Var] + param_set: Set[relax.Var] + + cur_func_name: str + + def __init__( + self, + mod: tvm.IRModule, + func_name_to_quantized_params: Dict[str, relax.Var], + f_replace: Callable[[relax.Var, relax.BlockBuilder], relax.Var], + ): + super().__init__(mod) + self.mod = mod + self.func_name_to_quantized_params = func_name_to_quantized_params + self.f_replace = f_replace + self.cur_func_name = "" + + def transform(self) -> tvm.IRModule: + for gv, func in self.mod.functions.items(): + if not isinstance(func, relax.Function): + continue + if func.attrs is None or not "num_input" in func.attrs: + continue + + assert ( + gv.name_hint in self.func_name_to_quantized_params + ), f"{gv.name_hint} not in {self.func_name_to_quantized_params}" + updated_func = self.rewrite_func(func, self.func_name_to_quantized_params[gv.name_hint]) + updated_func = remove_all_unused(updated_func) + self.builder_.update_func(gv, updated_func) + return self.builder_.get() + + def rewrite_func(self, func: Function, quantized_params: List[relax.Var]) -> relax.Function: + num_input = int(func.attrs["num_input"]) + self.param_set = set(func.params[num_input:]) + + body = self.visit_expr(func.body) + return relax.Function( + params=func.params[:num_input] + quantized_params, + body=body, + ret_struct_info=func.ret_struct_info, + is_pure=func.is_pure, + attrs=func.attrs, + ) + + def visit_var_(self, var: Var) -> Expr: + if var in self.param_set: + return self.f_replace(var, self.builder_) + else: + return super().visit_var_(var) + + +################################################################## + + +def load_torch_pname2binname_map( + model_path: str, + use_safetensors: bool, + relax_pnames: Set[str], + f_convert_pname_fwd: Callable[[str], List[str]] = lambda pname: [pname], +) -> Dict[str, str]: + """Constructing the dictionary from each torch parameter's name to + the name of the binary shard where the torch parameter is saved. + + Parameters + ---------- + model_path : str + The path of the Hugging Face model on disk. + + use_safetensors: bool + Whether to use ``.safetensors`` instead of ``.bin`` to load model. + + relax_pnames: Set[str] + The name of the Relax parameters. 
+ + f_convert_pname_fwd: Callable[[str], List[str]] + The function which converts Relax parameter name to torch's + parameter names. See ParamManager for more details. + """ + bin_idx_path = None + single_shard_file_name = None + if use_safetensors: + bin_idx_path = os.path.join(model_path, "model.safetensors.index.json") + single_shard_file_name = "model.safetensors" + else: + bin_idx_path = os.path.join(model_path, "pytorch_model.bin.index.json") + single_shard_file_name = "pytorch_model.bin" + single_shard_path = os.path.join(model_path, single_shard_file_name) + + if os.path.isfile(bin_idx_path): + # Multiple weight shards. + with open(bin_idx_path, "r") as f_torch_json: + torch_bin_json = json.load(f_torch_json) + torch_pname2binname = torch_bin_json["weight_map"] + elif os.path.isfile(single_shard_path): + # Single weight shard. + torch_pname2binname = { + torch_pname: single_shard_file_name + for relax_pname in relax_pnames + for torch_pname in f_convert_pname_fwd(relax_pname) + } + else: + suffix = ".safetensors" if use_safetensors else ".bin" + shard_names = [] + # Collect Scan every single file with the suffix + for filename in os.listdir(model_path): + if filename.endswith(suffix): + shard_names.append(filename) + if len(shard_names) == 1: + torch_pname2binname = { + torch_pname: shard_names[0] + for relax_pname in relax_pnames + for torch_pname in f_convert_pname_fwd(relax_pname) + } + else: + raise ValueError("Multiple weight shard files without json map is not supported") + return torch_pname2binname + + +def _create_quantize_func(param_manager: ParamManager) -> tvm.IRModule: + """Construct the Relax function which computes quantization. + This method is called by `transform_module` below, and is not + directly invoked outside the class. + + Parameters + ---------- + param_manager : ParamManager + The parameter manager which has all the parameter information. + + Returns + ------- + The created function which computes quantization. + Precisely, an IRModule which contains the main quantization Relax function + and a series of TIR functions is returned. + """ + bb = relax.BlockBuilder() + param2qrange = dict() + + # Construct the input of the function. + # We need a list of ranges for each + # parameter to get its corresponding tensors loaded from disk. + input_tensor_info: List[relax.TensorStructInfo] = [] + loaded_tensor_ranges: List[range] = [] + for name in param_manager.param_names: + param = param_manager.params[name] + _, loaded_tensor_info = param.quant_spec.get_loaded_tensor_info(name, param.param_info) + loaded_tensor_ranges.append( + range( + len(input_tensor_info), + len(input_tensor_info) + len(loaded_tensor_info), + ) + ) + input_tensor_info += loaded_tensor_info + raw_param_tuple = relax.Var("params", relax.TupleStructInfo(input_tensor_info)) + + with bb.function("transform_params", params=[raw_param_tuple]): + with bb.dataflow(): + quantized_params: List[relax.Var] = [] + for pidx, name in enumerate(param_manager.param_names): + param = param_manager.params[name] + param_vars: List[relax.Var] = [] + # Emit relax.TupleGetItem to get the raw parameters or pre-quantized params. + for loaded_tensor_idx in loaded_tensor_ranges[pidx]: + param_vars.append( + bb.emit(relax.TupleGetItem(raw_param_tuple, loaded_tensor_idx)) + ) + + # Get the quantization function of this parameter. 
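+                # Two cases are handled below:
+                # * `f_quantize is None`: the parameter needs no on-the-fly
+                #   quantization (or is already pre-quantized), so the loaded
+                #   tensors are forwarded to the output tuple unchanged.
+                # * otherwise: the quantization function is emitted; its result
+                #   may be a single tensor or a tuple of tensors, and
+                #   `param2qrange` records which slots of the output tuple
+                #   belong to this parameter in either case.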
+ f_quantize = param.quant_spec.get_quantize_func(param.param_info) + if f_quantize is None: + # If the parameter does not have a quantization function, either it + # does not need quantization or it is pre-quantized. + param2qrange[param] = range( + len(quantized_params), + len(quantized_params) + len(param_vars), + ) + quantized_params += param_vars + else: + # If the parameter has a quantization function, it is not expected + # to be pre-quantized. + assert len(param_vars) == 1, ( + "A parameter with quantization function is not expected " + "to be pre-quantized." + ) + + # Apply the quantization function. + quantized_data = bb.emit(f_quantize(bb, param_vars)) + + if isinstance(quantized_data.struct_info, relax.TupleStructInfo): + n_tensor = len(quantized_data.struct_info.fields) + assert n_tensor > 1 + # Record the range of quantized tensors of this parameter. + param2qrange[param] = range( + len(quantized_params), len(quantized_params) + n_tensor + ) + # Collect the quantized tensors to return. + for i in range(n_tensor): + quantized_params.append(bb.emit(relax.TupleGetItem(quantized_data, i))) + else: + assert isinstance(quantized_data.struct_info, relax.TensorStructInfo) + param2qrange[param] = range( + len(quantized_params), len(quantized_params) + 1 + ) + quantized_params.append(quantized_data) + + output = bb.emit_output(relax.Tuple(quantized_params)) + bb.emit_func_output(output) + + mod = bb.get() + param_manager.param2qrange = param2qrange + # Return the created IRModule. + return bb.get() + + +def transform_params_for_each_rank( + mod: tvm.IRModule, num_shards: int, rank_argument_name: str = "rank_arg" +) -> tvm.IRModule: + """Update a parameter transform to apply across all ranks + + For use in generating a pre-sharded set of weights. Given a + parameter transformation that generates sharded model weights for + a single shard, produce a parameter transformation that generates + sharded model weights for each shard. + + Parameters + ---------- + mod: tvm.IRModule + + A module containing the parameter transformation function, + named "transform_params", along with any subroutines called by + the parameter transformation. + + num_shards: int + + The number of shards to generate. + + rank_argument_name: str + + The name of the argument that specifies the rank. Should be a + R.ShapeTuple with a single R.PrimStructInfo('int64'). + + Returns + ------- + tvm.IRModule + + The modified parameter transformation + """ + generic_transform = mod["transform_params"] + tensor_params = generic_transform.params[1:] + + bb = relax.BlockBuilder() + + with bb.function("transform_params", params=tensor_params): + output = [] + for rank in range(num_shards): + # TODO(Lunderberg): Implement this in terms of a + # generic utility that inlines local functions. + func = generic_transform + func = func.bind_params({rank_argument_name: relax.ShapeExpr([rank])}) + func = relax.utils.copy_with_new_vars(func) + func = func.bind_params( + {var: tensor_param for (var, tensor_param) in zip(func.params, tensor_params)} + ) + shard_tuple = func.body + output.extend([shard_tuple[i] for i in range(len(tensor_params))]) + + with bb.dataflow(): + gv = bb.emit_output(relax.Tuple(output)) + bb.emit_func_output(gv) + + mod["transform_params"] = bb.get()["transform_params"] + return mod + + +def chain_parameter_transforms(mod_a: tvm.IRModule, mod_b: tvm.IRModule) -> tvm.IRModule: + """Chain two sequential parameter transformations + + For use in manipulating sets of model weights. 
Given two + parameter transformations that could be applied sequentially, + produce a single parameter transformation whose output is the same + as applying the parameter transformations sequentially. + + + .. code-block:: python + + # Before + params_after_a = mod_a['transform_params'](orig_params) + params_after_b = mod_b['transform_params'](params_after_a) + + # After + mod_ab = chain_parameter_transforms(mod_a, mod_b) + params_after_b = mod_ab['transform_params'](orig_params) + + Parameters + ---------- + mod_a: tvm.IRModule + + The module containing the first parameter transformation. + + mod_b: tvm.IRModule + + The module containing the second parameter transformation. + + Returns + ------- + tvm.IRModule + + The module containing the output + + """ + func_a = mod_a["transform_params"] + func_b = mod_b["transform_params"] + + bb = relax.BlockBuilder() + + with bb.function("transform_params", params=func_a.params): + with bb.dataflow(): + # TODO(Lunderberg): Implement this in terms of a + # generic utility that inlines local functions. + func_a_output = bb.emit(func_a.body) + func_b_param_map = {param: expr for (param, expr) in zip(func_b.params, func_a_output)} + func_b_output = func_b.bind_params(func_b_param_map).body + gv = bb.emit_output(func_b_output) + bb.emit_func_output(gv) + + merged_transform_func = bb.get()["transform_params"] + + new_mod = { + **{ + gvar: func + for gvar, func in mod_a.functions.items() + if gvar.name_hint != "transform_params" + }, + **{ + gvar: func + for gvar, func in mod_b.functions.items() + if gvar.name_hint != "transform_params" + }, + "transform_params": merged_transform_func, + } + return tvm.IRModule(new_mod) diff --git a/mlc_llm/relax_model/rwkv.py b/mlc_llm/relax_model/rwkv.py new file mode 100644 index 0000000..5b47cc3 --- /dev/null +++ b/mlc_llm/relax_model/rwkv.py @@ -0,0 +1,641 @@ +# pylint: disable=missing-docstring,invalid-name +from dataclasses import dataclass +from typing import List, Literal, Tuple + +from tvm import relax, te, tir +from tvm.relax import Expr, op +from tvm.relax.testing import nn +from tvm.script import relax as R +from tvm.script import tir as T + +from ..quantization import ParamQuantKind, QuantizationScheme +from .commons import create_metadata_func +from .modules import ModuleList, Linear +from .param_manager import ParamManager + +# Reference: https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model_run.py + + +@dataclass +class RWKVConfig: + """The configuration class to store the configuration of a `RWKVModel`.""" + + num_hidden_layers: int + vocab_size: int + hidden_size: int + intermediate_size: int + rescale_every: int = 0 + layer_norm_epsilon: float = 1e-5 + max_sequence_length: int = 1024 + dtype: str = "float32" + + def __init__( + self, + num_hidden_layers: int, + vocab_size: int, + hidden_size: int, + intermediate_size: int, + rescale_every: int = 0, + layer_norm_epsilon: float = 1e-5, + context_length: int = 1024, + dtype: str = "float32", + **kwargs, + ) -> None: + self.num_hidden_layers = num_hidden_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.rescale_every = rescale_every + self.layer_norm_epsilon = layer_norm_epsilon + self.max_sequence_length = context_length + self.dtype = dtype + self.kwargs = kwargs + + +class State: + ATT_X = 0 + ATT_A = 1 + ATT_B = 2 + ATT_P = 3 + FFN_X = 4 + + +def _load_state(state: Expr, hidden_size: int, dtype: str) -> Expr: + # Reuse `attention_kv_cache_view` + f_load_cache = 
relax.extern("vm.builtin.attention_kv_cache_view") + cache = nn.emit( + relax.call_pure_packed( + f_load_cache, + args=[state, R.shape([1, hidden_size])], + sinfo_args=[R.Tensor((1, hidden_size), dtype)], + ) + ) + return cache + + +def _store_state(state: Expr, value: Expr): + # Reuse `attention_kv_cache_update` + f_store_cache = relax.extern("vm.builtin.attention_kv_cache_update") + + return nn.emit( + relax.op.call_inplace_packed( + f_store_cache, + args=[state, value], + inplace_indices=[0], + sinfo_args=[R.Object()], + ) + ) + + +def is_one(x: tir.PrimExpr) -> bool: + return isinstance(x, tir.IntImm) and x.value == 1 + + +def create_wkv_func(hidden_size: int, dtype: str, out_dtype: str): + @T.prim_func + def wkv_func( + k: T.handle, + v: T.handle, + time_decay: T.handle, + time_first: T.handle, + saved_a: T.handle, + saved_b: T.handle, + saved_p: T.handle, + wkv: T.handle, + out_a: T.handle, + out_b: T.handle, + out_p: T.handle, + ): + T.func_attr({"op_pattern": 8, "tir.noalias": True, "tir.is_scheduled": 1}) + context_length = T.int64() + K = T.match_buffer(k, (context_length, hidden_size), dtype=dtype) + V = T.match_buffer(v, (context_length, hidden_size), dtype=dtype) + TimeDecay = T.match_buffer(time_decay, (hidden_size,), dtype=dtype) + TimeFirst = T.match_buffer(time_first, (hidden_size,), dtype=dtype) + SavedA = T.match_buffer(saved_a, (1, hidden_size), dtype=dtype) + SavedB = T.match_buffer(saved_b, (1, hidden_size), dtype=dtype) + SavedP = T.match_buffer(saved_p, (1, hidden_size), dtype=dtype) + WKV = T.match_buffer(wkv, (context_length, hidden_size), dtype=out_dtype) + OutA = T.match_buffer(out_a, (1, hidden_size), dtype=dtype) + OutB = T.match_buffer(out_b, (1, hidden_size), dtype=dtype) + OutP = T.match_buffer(out_p, (1, hidden_size), dtype=dtype) + + P = T.alloc_buffer((hidden_size,), dtype=dtype, scope="local") + E1 = T.alloc_buffer((hidden_size,), dtype=dtype, scope="local") + E2 = T.alloc_buffer((hidden_size,), dtype=dtype, scope="local") + A_local = T.alloc_buffer((hidden_size,), dtype=dtype, scope="local") + B_local = T.alloc_buffer((hidden_size,), dtype=dtype, scope="local") + P_local = T.alloc_buffer((hidden_size,), dtype=dtype, scope="local") + + for bx in T.thread_binding(hidden_size // 32, thread="blockIdx.x"): + for tx in T.thread_binding(32, thread="threadIdx.x"): + with T.block("init"): + vi = T.axis.S(hidden_size, bx * 32 + tx) + A_local[vi] = SavedA[0, vi] + B_local[vi] = SavedB[0, vi] + P_local[vi] = SavedP[0, vi] + for j in range(context_length): + with T.block("main"): + vi = T.axis.S(hidden_size, bx * 32 + tx) + vj = T.axis.opaque(context_length, j) + P[vi] = T.max(P_local[vi], K[vj, vi] + TimeFirst[vi]) + E1[vi] = T.exp(P_local[vi] - P[vi]) + E2[vi] = T.exp(K[vj, vi] + TimeFirst[vi] - P[vi]) + WKV[vj, vi] = T.cast( + (E1[vi] * A_local[vi] + E2[vi] * V[vj, vi]) + / (E1[vi] * B_local[vi] + E2[vi]), + out_dtype, + ) + + P[vi] = T.max(P_local[vi] + TimeDecay[vi], K[vj, vi]) + E1[vi] = T.exp(P_local[vi] + TimeDecay[vi] - P[vi]) + E2[vi] = T.exp(K[vj, vi] - P[vi]) + A_local[vi] = E1[vi] * A_local[vi] + E2[vi] * V[vj, vi] + B_local[vi] = E1[vi] * B_local[vi] + E2[vi] + P_local[vi] = P[vi] + + with T.block("write_back"): + vi = T.axis.S(hidden_size, bx * 32 + tx) + OutA[0, vi] = A_local[vi] + OutB[0, vi] = B_local[vi] + OutP[0, vi] = P_local[vi] + + return wkv_func + + +def _te_concat_saved_x(saved_x: te.Tensor, x: te.Tensor): + return te.compute( + x.shape, + lambda i, j: tir.if_then_else(i == 0, saved_x[0, j], x[i - 1, j]), + ) + + +def _te_get_last_x(x: 
te.Tensor): + seq_len, hidden_size = x.shape + return te.compute((1, hidden_size), lambda _, j: x[seq_len - 1, j]) + + +class RWKV_Embedding(nn.Module): + def __init__(self, num_embeddings, embedding_dim, dtype): + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.weight = nn.Parameter( + (num_embeddings, embedding_dim), dtype=dtype, name="weight" + ) + + def forward(self, x: relax.Expr) -> relax.Var: + x = nn.emit(op.reshape(x, shape=[-1])) + return nn.emit(op.take(self.weight, x, axis=0)) + + +class RWKV_LayerNorm(nn.Module): + def __init__(self, intermediate_size, dtype, eps=1e-5, name_prefix=""): + super().__init__() + self.eps = eps + self.weight = nn.Parameter( + (intermediate_size,), dtype=dtype, name=f"{name_prefix}_ln_weight" + ) + self.bias = nn.Parameter( + (intermediate_size,), dtype=dtype, name=f"{name_prefix}_ln_bias" + ) + + def forward(self, x: relax.Expr) -> relax.Var: + x = nn.emit( + op.nn.layer_norm( + x, + gamma=self.weight, + beta=self.bias, + axes=-1, + epsilon=self.eps, + ) + ) + return x + + +class RWKV_FFN(nn.Module): + def __init__(self, config: RWKVConfig, index: int) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.dtype = config.dtype + self.index = index + self.time_mix_key = nn.Parameter( + (self.hidden_size,), dtype=config.dtype, name=f"ffn_{index}_time_mix_k" + ) + self.time_mix_receptance = nn.Parameter( + (self.hidden_size,), dtype=config.dtype, name=f"ffn_{index}_time_mix_r" + ) + self.key = Linear( + self.hidden_size, config.intermediate_size, dtype=config.dtype, bias=False + ) + self.receptance = Linear( + self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False + ) + self.value = Linear( + config.intermediate_size, self.hidden_size, dtype=config.dtype, bias=False + ) + + def forward(self, x: Expr, state: Expr) -> Expr: + offset = self.index * 5 + State.FFN_X + context_length = x.struct_info.shape[0] + hidden_size = self.hidden_size + + saved_x = _load_state(state[offset], hidden_size, self.dtype) + if not is_one(context_length): + saved_x = nn.emit_te(_te_concat_saved_x, saved_x, x) + ones = nn.emit(relax.op.ones((hidden_size,), self.dtype)) + xk = nn.emit(x * self.time_mix_key + saved_x * (ones - self.time_mix_key)) + xr = nn.emit( + x * self.time_mix_receptance + saved_x * (ones - self.time_mix_receptance) + ) + if not is_one(context_length): + x = nn.emit_te(_te_get_last_x, x) + assert is_one(x.struct_info.shape[0]) + saved_x = _store_state(state[offset], x) + + r = nn.emit(op.sigmoid(self.receptance(xr))) + xv = nn.emit(op.square(op.nn.relu(self.key(xk)))) + + return nn.emit(r * self.value(xv)), [saved_x] + + +class RWKV_Attention(nn.Module): + def __init__(self, config: RWKVConfig, index: int) -> None: + super().__init__() + self.index = index + self.dtype = config.dtype + self.hidden_size = config.hidden_size + self.time_decay = nn.Parameter( + (self.hidden_size,), dtype="float32", name=f"att_{index}_time_decay" + ) + self.time_first = nn.Parameter( + (self.hidden_size,), dtype="float32", name=f"att_{index}_time_first" + ) + self.time_mix_key = nn.Parameter( + (self.hidden_size,), dtype=config.dtype, name=f"att_{index}_time_mix_k" + ) + self.time_mix_value = nn.Parameter( + (self.hidden_size,), dtype=config.dtype, name=f"att_{index}_time_mix_v" + ) + self.time_mix_receptance = nn.Parameter( + (self.hidden_size,), dtype=config.dtype, name=f"att_{index}_time_mix_r" + ) + self.key = Linear( + self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False + ) + self.value = 
Linear( + self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False + ) + self.receptance = Linear( + self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False + ) + self.output = Linear( + self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False + ) + + def forward(self, x: Expr, state: Expr) -> Expr: + # Load current state + ones = nn.emit(relax.op.ones((self.hidden_size,), self.dtype)) + index = self.index + hidden_size = self.hidden_size + context_length = x.struct_info.shape[0] + bb = relax.BlockBuilder.current() + + saved_a = _load_state(state[index * 5 + State.ATT_A], hidden_size, "float32") + saved_b = _load_state(state[index * 5 + State.ATT_B], hidden_size, "float32") + saved_p = _load_state(state[index * 5 + State.ATT_P], hidden_size, "float32") + saved_x = _load_state(state[index * 5 + State.ATT_X], hidden_size, self.dtype) + if not is_one(context_length): + saved_x = nn.emit_te(_te_concat_saved_x, saved_x, x) + + xk = nn.emit(x * self.time_mix_key + saved_x * (ones - self.time_mix_key)) + xv = nn.emit(x * self.time_mix_value + saved_x * (ones - self.time_mix_value)) + xr = nn.emit( + x * self.time_mix_receptance + saved_x * (ones - self.time_mix_receptance) + ) + + r = nn.emit(op.sigmoid(self.receptance(xr))) + k = nn.emit(op.astype(self.key(xk), "float32")) + v = nn.emit(op.astype(self.value(xv), "float32")) + + gv = bb.add_func(create_wkv_func(hidden_size, "float32", self.dtype), "wkv") + ret = nn.emit( + relax.call_tir( + gv, + [k, v, self.time_decay, self.time_first, saved_a, saved_b, saved_p], + [ + R.Tensor((context_length, hidden_size), self.dtype), + R.Tensor((1, hidden_size), "float32"), + R.Tensor((1, hidden_size), "float32"), + R.Tensor((1, hidden_size), "float32"), + ], + ) + ) + if not is_one(context_length): + x = nn.emit_te(_te_get_last_x, x) + + assert is_one(x.struct_info.shape[0]) + saved_x = _store_state(state[self.index * 5 + State.ATT_X], x) + saved_a = _store_state(state[self.index * 5 + State.ATT_A], ret[1]) + saved_b = _store_state(state[self.index * 5 + State.ATT_B], ret[2]) + saved_p = _store_state(state[self.index * 5 + State.ATT_P], ret[3]) + + return nn.emit(self.output(r * ret[0])), [ + saved_x, + saved_a, + saved_b, + saved_p, + ] + + +class RWKVLayer(nn.Module): + def __init__(self, config: RWKVConfig, index: int) -> None: + super().__init__() + if index == 0: + self.pre_ln = RWKV_LayerNorm( + config.hidden_size, + config.dtype, + eps=config.layer_norm_epsilon, + name_prefix="pre_ln", + ) + self.ln1 = RWKV_LayerNorm( + config.hidden_size, + config.dtype, + eps=config.layer_norm_epsilon, + name_prefix=f"att_{index}", + ) + self.ln2 = RWKV_LayerNorm( + config.hidden_size, + config.dtype, + eps=config.layer_norm_epsilon, + name_prefix=f"ffn_{index}", + ) + self.attention = RWKV_Attention(config, index) + self.feed_forward = RWKV_FFN(config, index) + self.rescale_every = config.rescale_every + self.dtype = config.dtype + self.index = index + + def forward(self, x: Expr, state: Expr) -> Tuple[Expr, List[Expr]]: + if self.index == 0: + x = self.pre_ln(x) + att, att_state = self.attention(self.ln1(x), state) + x = nn.emit(x + att) + ffn, ffn_state = self.feed_forward(self.ln2(x), state) + x = nn.emit(x + ffn) + if self.rescale_every > 0 and (self.index + 1) % self.rescale_every == 0: + x = nn.emit(x / relax.const(2, dtype=self.dtype)) + return x, att_state + ffn_state + + +class RWKVModel(nn.Module): + def __init__(self, config: RWKVConfig) -> None: + super().__init__() + self.embeddings = RWKV_Embedding( + 
num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + dtype=config.dtype, + ) + self.blocks = ModuleList( + [RWKVLayer(config, i) for i in range(config.num_hidden_layers)] + ) + self.ln_out = RWKV_LayerNorm( + config.hidden_size, + config.dtype, + eps=config.layer_norm_epsilon, + name_prefix="out_ln", + ) + self.hidden_size = config.hidden_size + self.dtype = config.dtype + + def forward(self, input_ids: Expr, state: Expr) -> Tuple[Expr, List[Expr]]: + hidden_states = self.embeddings(input_ids) + states = [] + for _, layer in enumerate(self.blocks): + hidden_states, layer_states = layer(hidden_states, state) + states += layer_states + context_length = hidden_states.struct_info.shape[0] + if not is_one(context_length): + hidden_states = nn.emit_te(_te_get_last_x, hidden_states) + hidden_states = self.ln_out(hidden_states) + return hidden_states, states + + +class RWKVForCausalLM(nn.Module): + def __init__(self, config: RWKVConfig): + self.rwkv = RWKVModel(config) + self.head = Linear( + config.hidden_size, config.vocab_size, dtype=config.dtype, bias=False + ) + self.vocab_size = config.vocab_size + ############ End ############ + + def forward( + self, + input_ids: relax.Expr, + state: relax.Expr, + ): + hidden_states, key_value_cache = self.rwkv(input_ids, state) + logits = nn.emit(self.head(hidden_states)) + logits = nn.emit(op.reshape(logits, (1, 1, self.vocab_size))) + if logits.struct_info.dtype != "float32": + logits = nn.emit(relax.op.astype(logits, "float32")) + + return logits, key_value_cache + + +def get_param_quant_kind( + name: str, param_info: relax.TensorStructInfo +) -> ParamQuantKind: + if name.endswith("embeddings.weight"): + return ParamQuantKind.embedding_table + elif name == "head.weight": + return ParamQuantKind.final_fc_weight + elif param_info.ndim == 2 and name.endswith(".weight"): + return ParamQuantKind.linear_weight + else: + return ParamQuantKind.others + + +def create_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: RWKVConfig, + quant_scheme: QuantizationScheme, + func_name=Literal["prefill", "decode"], +): + if func_name not in ["prefill", "decode"]: + raise ValueError(f"func_name must be 'prefill' or 'decode', got {func_name}") + seq_len = 1 if func_name == "decode" else tir.SizeVar("n", "int64") + + with bb.function(func_name): + model = RWKVForCausalLM(config) + param_manager.register_params( + model, func_name, quant_scheme, get_param_quant_kind + ) + + input_ids = nn.Placeholder((1, seq_len), dtype="int32", name="input_ids") + # Placeholder for compatibility to LLAMA + all_seq_len_shape = relax.Var("place_holder", R.Object()) + state = relax.Var("state", R.Tuple([R.Object()] * config.num_hidden_layers * 5)) + with bb.dataflow(): + logits, states = model(input_ids, state) + params = [ + input_ids, + all_seq_len_shape, + state, + ] + model.parameters() + + gv = bb.emit_output((logits, relax.Tuple(states))) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + f = mod[gv].with_attr("num_input", 3) + if func_name == "prefill": + f = f.with_attr("tir_var_upper_bound", {"n": config.max_sequence_length}) + bb.update_func(gv, f) + + +def create_kv_cache_func(bb: relax.BlockBuilder, config: RWKVConfig) -> None: + """NOTE: It's not typical kv-cache, but try to reuse the logic for the quick hack.""" + init_shape = relax.ShapeExpr((1, config.hidden_size)) + with bb.function("create_kv_cache", []): + with bb.dataflow(): + input_dtype_zeros = bb.emit(relax.op.zeros(init_shape, 
config.dtype)) + fp32_zeros = bb.emit(relax.op.zeros(init_shape, "float32")) + fp32_neg_inf = bb.emit(fp32_zeros - relax.const(1e30, "float32")) + caches = [] + f_kv_cache_create = relax.extern("vm.builtin.attention_kv_cache_create") + conf = [ + ("att_x", input_dtype_zeros), + ("att_a", fp32_zeros), + ("att_b", fp32_zeros), + ("att_p", fp32_neg_inf), + ("ffn_x", input_dtype_zeros), + ] + for i in range(config.num_hidden_layers): + for name, init_value in conf: + caches.append( + bb.emit( + relax.call_pure_packed( + f_kv_cache_create, + args=[init_value, init_shape, relax.PrimValue(1)], + sinfo_args=[R.Object()], + ), + name_hint=f"{name}_state_{i}", + ) + ) + gv = bb.emit_output(caches) + bb.emit_func_output(gv) + + +def create_kv_cache_reset_func(bb: relax.BlockBuilder, config: RWKVConfig) -> None: + state = relax.Var("state", R.Tuple([R.Object()] * config.num_hidden_layers * 5)) + init_shape = relax.ShapeExpr((1, config.hidden_size)) + with bb.function("reset_kv_cache", [state]): + with bb.dataflow(): + input_dtype_zeros = bb.emit(relax.op.zeros(init_shape, config.dtype)) + fp32_zeros = bb.emit(relax.op.zeros(init_shape, "float32")) + fp32_neg_inf = bb.emit(fp32_zeros - relax.const(1e30, "float32")) + caches = [] + for i in range(config.num_hidden_layers): + caches.append( + _store_state(state[i * 5 + State.ATT_X], input_dtype_zeros) + ) + caches.append(_store_state(state[i * 5 + State.ATT_B], fp32_zeros)) + caches.append(_store_state(state[i * 5 + State.ATT_A], fp32_zeros)) + caches.append(_store_state(state[i * 5 + State.ATT_P], fp32_neg_inf)) + caches.append( + _store_state(state[i * 5 + State.FFN_X], input_dtype_zeros) + ) + gv = bb.emit_output(caches) + bb.emit_func_output(gv) + + +def create_softmax_func(bb: relax.BlockBuilder, config: RWKVConfig) -> None: + with bb.function("softmax_with_temperature"): + logits = nn.Placeholder( + (1, 1, config.vocab_size), dtype="float32", name="logits" + ) + temperature = nn.Placeholder((), dtype="float32", name="temperature") + with bb.dataflow(): + div = bb.emit(relax.op.divide(logits, temperature)) + softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) + gv = bb.emit_output(softmax) + bb.emit_func_output(gv, [logits, temperature]) + + +def get_model(args, hf_config): + model_name = args.model + max_seq_len = args.max_seq_len + dtype = args.quantization.model_dtype + + if not model_name.lower().startswith("rwkv-"): + raise ValueError(f"Unsupported model name: {model_name}") + + config = RWKVConfig(**hf_config, dtype=dtype) + if max_seq_len != -1: + config.max_sequence_length = max_seq_len + + param_manager = ParamManager() + bb = relax.BlockBuilder() + create_func(bb, param_manager, config, args.quantization, "prefill") + create_func(bb, param_manager, config, args.quantization, "decode") + create_kv_cache_func(bb, config) + create_softmax_func(bb, config) + create_metadata_func( + bb, + model_name=model_name, + # RNN model do not have window size limit + max_window_size=-1, + stop_tokens=[0], + add_prefix_space=False, + ) + create_kv_cache_reset_func(bb, config) + mod = bb.get() + + if args.build_model_only: + return mod, param_manager, None, config + + def f_convert_pname_fwd(pname: str) -> List[str]: + if ( + "key_weight" in pname + or "value_weight" in pname + or "receptance_weight" in pname + or "output_weight" in pname + or "head_weight" in pname + ): + return [pname.replace("_weight", ".weight")] + else: + return [pname] + + def f_convert_param_bkwd(torch_pname: str, torch_param): + # torch_param: numpy.ndarray + import numpy as np # 
pylint: disable=import-outside-toplevel + + # rescale_every + if config.rescale_every > 0 and "blocks." in torch_pname: + # based-on the assumption that the layer id is the second element in torch_pname + layer_id = int(torch_pname.split(".")[2]) + if ( + "attention.output.weight" in torch_pname + or "feed_forward.value.weight" in torch_pname + ): + torch_param = torch_param / (2 ** (layer_id // config.rescale_every)) + + # reshape + if "time_" in torch_pname: + torch_param = torch_param.squeeze() + + # convert dtype + if "time_decay" in torch_pname: # need fp32 for this + return [(torch_pname, -np.exp(torch_param.astype("float32")))] + elif "time_first" in torch_pname: + return [(torch_pname, torch_param.astype("float32"))] + else: + return [(torch_pname, torch_param.astype(config.dtype))] + + param_manager.set_param_loading_func( + args.model_path, args.use_safetensors, f_convert_pname_fwd, f_convert_param_bkwd + ) + return mod, param_manager, [None] * len(param_manager.param_names), config diff --git a/mlc_llm/relax_model/stablelm_3b.py b/mlc_llm/relax_model/stablelm_3b.py new file mode 100644 index 0000000..ac1c9a7 --- /dev/null +++ b/mlc_llm/relax_model/stablelm_3b.py @@ -0,0 +1,913 @@ +import math +from dataclasses import dataclass +from typing import Any, List, Optional, Tuple + +import numpy as np +import tvm +from tvm import relax, te +from tvm.relax.op import ccl +from tvm.relax.op.nn import layer_norm +from tvm.relax.testing import nn +from tvm.script import relax as R + +from ..quantization import ParamQuantKind, QuantizationScheme +from .commons import create_metadata_func +from .llama import Embedding, Linear +from .modules import ModuleList, RotaryEmbedding +from .param_manager import ParamManager + + +@dataclass +class StableLM3bConfig: + def __init__( + self, + dtype="float32", + max_sequence_length=4096, + vocab_size=50304, + hidden_size=2560, + intermediate_size=6912, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + initializer_range=0.02, + norm_eps=1e-5, + pad_token_id=-1, + bos_token_id=0, + eos_token_id=1, + tie_word_embeddings=False, + position_embedding_base=10000, + combine_matmul=True, + num_shards=1, + build_model_only=False, + convert_weights_only=False, + **kwargs, + ): + self.dtype = dtype + self.max_sequence_length = max_sequence_length + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.norm_eps = norm_eps + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.tie_word_embeddings = tie_word_embeddings + self.position_embedding_base = position_embedding_base + self.combine_matmul = combine_matmul + if build_model_only and num_shards > 1: + self.num_shards = num_shards + else: + self.num_shards = 1 + self.kwargs = kwargs + + def get_num_key_value_heads(self): + if self.num_key_value_heads is None: + return self.num_attention_heads + return self.num_key_value_heads + + +class LayerNorm(nn.Module): + def __init__( + self, + hidden_size, + dtype, + eps=1e-5, + ): + super().__init__() + self.eps = eps + self.weight = nn.Parameter((hidden_size,), dtype="float16", name="weight") + self.bias = nn.Parameter((hidden_size,), dtype="float16", name="bias") + + def forward(self, x: 
relax.Expr) -> relax.Var: + x = nn.emit( + layer_norm( + x, + gamma=self.weight, + beta=self.bias, + axes=-1, + epsilon=self.eps, + ) + ) + return x + + +class StableLM3bMLP(nn.Module): + def __init__(self, config: StableLM3bConfig): + self.combine_matmul = config.combine_matmul + self.num_shards = config.num_shards + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size // self.num_shards + dtype = config.dtype + if self.combine_matmul: + self.gate_up_proj = Linear(hidden_size, 2 * intermediate_size, dtype=dtype, bias=False) + self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False) + self.gate_up_proj.weight.shard_dim = 0 + self.down_proj.weight.shard_dim = 1 + else: + self.gate_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False) + self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False) + self.up_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False) + self.gate_proj.weight.shard_dim = 0 + self.up_proj.weight.shard_dim = 0 + self.down_proj.weight.shard_dim = 1 + + def forward(self, x): + if self.combine_matmul: + gate_up_results = nn.emit( + relax.op.split( + self.gate_up_proj(x), + indices_or_sections=2, + axis=-1, + ) + ) + gate_result = relax.TupleGetItem(gate_up_results, 0) + up_result = relax.TupleGetItem(gate_up_results, 1) + else: + gate_result = self.gate_proj(x) + up_result = self.up_proj(x) + + result = self.down_proj(relax.op.nn.silu(gate_result) * up_result) + return result + + +class StableLM3bAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: StableLM3bConfig, rotary_embedding: RotaryEmbedding): + dtype = config.dtype + self.num_shards = config.num_shards + self.hidden_size = config.hidden_size + self.num_key_value_heads = ( + config.num_key_value_heads is None + and config.num_attention_heads + or config.num_key_value_heads + ) // config.num_shards + self.num_query_heads = config.num_attention_heads // self.num_shards + self.head_dim = self.hidden_size // config.num_attention_heads + self.position_embedding_base = config.position_embedding_base + self.rotary_embedding = rotary_embedding + + self.combine_matmul = config.combine_matmul + if self.combine_matmul: + self.query_key_value_proj = Linear( + self.hidden_size, + (self.num_query_heads + 2 * self.num_key_value_heads) * self.head_dim, + dtype=dtype, + bias=False, + ) + self.query_key_value_proj.weight.shard_dim = 0 + else: + self.q_proj = Linear( + self.hidden_size, + self.num_query_heads * self.head_dim, + dtype=dtype, + bias=False, + ) + self.k_proj = Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + dtype=dtype, + bias=False, + ) + self.v_proj = Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + dtype=dtype, + bias=False, + ) + self.q_proj.weight.shard_dim = 0 + self.k_proj.weight.shard_dim = 0 + self.v_proj.weight.shard_dim = 0 + + self.o_proj = Linear( + self.head_dim * self.num_query_heads, self.hidden_size, dtype=dtype, bias=False + ) + self.o_proj.weight.shard_dim = 1 + + def forward( + self, + hidden_states: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_value: Tuple[relax.Expr], + attention_mask: Optional[relax.Expr] = None, + ) -> Tuple[relax.Expr, Optional[relax.Expr], Optional[Tuple[relax.Expr]]]: + from tvm.relax.op import ( + astype, + matmul, + maximum, + permute_dims, + reshape, + split, + squeeze, + ) + from tvm.relax.op.nn import softmax + + bsz, q_len, _ = 
hidden_states.struct_info.shape + assert bsz == 1, "Only support batch size 1 at this moment." + + if self.combine_matmul: + qkv_states = nn.emit( + split( + self.query_key_value_proj(hidden_states), + indices_or_sections=[ + self.num_query_heads * self.head_dim, + (self.num_query_heads + self.num_key_value_heads) * self.head_dim, + ], + axis=-1, + ) + ) + query_states = relax.TupleGetItem(qkv_states, 0) + key_states = relax.TupleGetItem(qkv_states, 1) + value_states = relax.TupleGetItem(qkv_states, 2) + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = nn.emit( + reshape( + query_states, + (bsz, q_len, self.num_query_heads, self.head_dim), + ), + ) + key_states = nn.emit( + reshape( + key_states, + (bsz, q_len, self.num_key_value_heads, self.head_dim), + ), + ) + value_states = nn.emit( + reshape( + value_states, + (bsz, q_len, self.num_key_value_heads, self.head_dim), + ), + ) + + kv_seq_len = all_seq_len_shape.struct_info.values[0] + offset = kv_seq_len - q_len + query_states, key_states = self.rotary_embedding(query_states, key_states, offset) + # [bsz, t, nh, hd] + + kv_states_shape = key_states.struct_info.shape + kv_states_dtype = key_states.struct_info.dtype + assert kv_states_shape[0] == 1 # bsz + kv_states_shape = R.shape( + [kv_states_shape[0], kv_seq_len, kv_states_shape[2], kv_states_shape[3]] + ) + kv_cache_shape = R.shape([kv_seq_len, kv_states_shape[2], kv_states_shape[3]]) + + squeezed_key = nn.emit(squeeze(key_states, axis=0)) + squeezed_value = nn.emit(squeeze(value_states, axis=0)) + k_cache, v_cache = past_key_value + f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append") + k_cache = nn.emit( + relax.op.call_inplace_packed( + f_kv_cache_append, + args=[k_cache, squeezed_key], + inplace_indices=[0], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + v_cache = nn.emit( + relax.op.call_inplace_packed( + f_kv_cache_append, + args=[v_cache, squeezed_value], + inplace_indices=[0], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + past_key_value = (k_cache, v_cache) + f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view") + k_cache = nn.emit( + relax.call_pure_packed( + f_kv_cache_view, + args=[k_cache, kv_cache_shape], + sinfo_args=[R.Tensor(kv_cache_shape, kv_states_dtype)], + ) + ) + v_cache = nn.emit( + relax.call_pure_packed( + f_kv_cache_view, + args=[v_cache, kv_cache_shape], + sinfo_args=[R.Tensor(kv_cache_shape, kv_states_dtype)], + ) + ) + key_states = nn.emit(reshape(k_cache, kv_states_shape)) + value_states = nn.emit(reshape(v_cache, kv_states_shape)) + if self.num_key_value_heads != self.num_query_heads: + n_rep = self.num_query_heads // self.num_key_value_heads + key_states = nn.emit(relax.op.repeat(key_states, n_rep, axis=2)) + value_states = nn.emit(relax.op.repeat(value_states, n_rep, axis=2)) + + query_states = nn.emit(permute_dims(query_states, [0, 2, 1, 3])) + key_states = nn.emit(permute_dims(key_states, [0, 2, 1, 3])) + value_states = nn.emit(permute_dims(value_states, [0, 2, 1, 3])) + + attn_weights = nn.emit( + matmul(query_states, permute_dims(key_states, [0, 1, 3, 2])) + / relax.const(math.sqrt(self.head_dim), query_states.struct_info.dtype) + ) + + tvm.ir.assert_structural_equal( + attention_mask.struct_info.shape.values, + (bsz, tvm.tir.IntImm("int64", 1), q_len, kv_seq_len), + ) + + attn_weights = nn.emit( + maximum( + attn_weights, + relax.const( + tvm.tir.min_value(attn_weights.struct_info.dtype).value, + 
attn_weights.struct_info.dtype, + ), + ) + ) + attn_weights = nn.emit(relax.op.minimum(attn_weights, attention_mask)) + + # upcast attention to fp32 + if attn_weights.struct_info.dtype != "float32": + attn_weights = astype(attn_weights, "float32") + attn_weights = nn.emit(softmax(attn_weights, axis=-1)) + if attn_weights.struct_info.dtype != query_states.struct_info.dtype: + attn_weights = astype(attn_weights, query_states.struct_info.dtype) + attn_output = nn.emit(matmul(attn_weights, value_states)) + + attn_output = nn.emit(permute_dims(attn_output, [0, 2, 1, 3])) + attn_output = nn.emit( + reshape(attn_output, (bsz, q_len, self.head_dim * self.num_query_heads)) + ) + + attn_output = self.o_proj(attn_output) + return attn_output, ((None, None) if past_key_value is None else past_key_value) + + +class StableLM3bDecoderLayer(nn.Module): + def __init__(self, config: StableLM3bConfig, rotary_embedding: RotaryEmbedding): + self.hidden_size = config.hidden_size + self.self_attn = StableLM3bAttention(config, rotary_embedding) + self.mlp = StableLM3bMLP(config) + self.input_layernorm = LayerNorm( + config.hidden_size, dtype=config.dtype, eps=config.norm_eps + ) + self.post_attention_layernorm = LayerNorm( + config.hidden_size, dtype=config.dtype, eps=config.norm_eps + ) + + def forward( + self, + hidden_states: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_value: Tuple[relax.Expr], + attention_mask: Optional[relax.Expr] = None, + ) -> Tuple[relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + all_seq_len_shape=all_seq_len_shape, + ) + if self.self_attn.num_shards > 1: + residual = nn.emit( + residual / R.const(self.self_attn.num_shards, dtype=residual.struct_info.dtype) + ) + hidden_states = nn.emit(residual + hidden_states) + if self.self_attn.num_shards > 1: + hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + if self.mlp.num_shards > 1: + residual = nn.emit( + residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype) + ) + hidden_states = nn.emit(residual + hidden_states) + if self.mlp.num_shards > 1: + hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) + return hidden_states, present_key_value + + +def _make_causal_mask(input_ids_shape, dtype, src_len): + from tvm.relax.op import broadcast_to + + bsz, tgt_len = input_ids_shape + + def min_max_triu_te(): + return te.compute( + (tgt_len, tgt_len), + lambda i, j: tvm.tir.Select(j > i, tvm.tir.min_value(dtype), tvm.tir.max_value(dtype)), + name="make_diag_mask_te", + ) + + mask = nn.emit_te(min_max_triu_te) + diag_mask = nn.emit(broadcast_to(mask, (bsz, 1, tgt_len, tgt_len))) + if src_len == tgt_len: + return diag_mask + + def extend_te(x, tgt_len, src_len): + return te.compute( + (bsz, 1, tgt_len, src_len), + lambda b, _, i, j: te.if_then_else( + j < src_len - tgt_len, + tvm.tir.max_value(dtype), + x[b, _, i, j - (src_len - tgt_len)], + ), + name="concat_te", + ) + + return nn.emit_te(extend_te, diag_mask, tgt_len, src_len) + + +class StableLM3bEmbedTokens(nn.Module): + def __init__(self, config: StableLM3bConfig, vocab_size_var: tvm.tir.SizeVar): + self.embed_tokens = Embedding(vocab_size_var, 
config.hidden_size, dtype=config.dtype) + + def forward(self, input_ids: relax.Expr): + inputs_embeds = self.embed_tokens(input_ids) + return inputs_embeds + + +class StableLM3bEmbedTokensWrapper(nn.Module): + def __init__(self, config: StableLM3bConfig, vocab_size_var: tvm.tir.SizeVar): + # build a wrapper to ensure that the naming of the embed_tokens parameter is consistent + self.model = StableLM3bEmbedTokens(config, vocab_size_var) + + def forward(self, input_ids: relax.Expr): + inputs_embeds = self.model(input_ids) + return inputs_embeds + + +class StableLM3bModell(nn.Module): + def __init__( + self, config: StableLM3bConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool = False + ): + rotary_embedding = RotaryEmbedding( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + position_embedding_base=config.position_embedding_base, + max_sequence_length=config.max_sequence_length, + rotary_pct=0.25, + dtype=config.dtype, + ) + self.num_shards = config.num_shards + self.padding_idx = config.pad_token_id + self.embed_tokens = None + + if not sep_embed: + self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) + + self.layers = ModuleList( + [ + StableLM3bDecoderLayer(config, rotary_embedding) + for _ in range(config.num_hidden_layers) + ] + ) + self.norm = LayerNorm(config.hidden_size, dtype=config.dtype, eps=config.norm_eps) + + def _prepare_decoder_attention_mask(self, input_shape, src_len, dtype): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if isinstance(input_shape[-1], tvm.tir.SizeVar) or input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask(input_shape, dtype, src_len) + else: + # Get src_len from input parameters + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + bsz, tgt_len = input_shape + combined_attention_mask = nn.emit( + relax.op.full( + (bsz, 1, tgt_len, src_len), + relax.const(tvm.tir.max_value(dtype).value, dtype), + dtype, + ) + ) + return combined_attention_mask + + def forward( + self, + inputs: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_values: relax.Expr, + ): + if self.num_shards > 1: + inputs = nn.emit(ccl.broadcast_from_worker0(inputs)) + if self.embed_tokens: + inputs_embeds = self.embed_tokens(inputs) + else: + inputs_embeds = inputs + # retrieve input_ids + batch_size, seq_length, _ = inputs_embeds.struct_info.shape + seq_length_with_past = all_seq_len_shape.struct_info.values[0] + # embed positions + attention_mask = self._prepare_decoder_attention_mask( + (batch_size, seq_length), + seq_length_with_past, + inputs_embeds.struct_info.dtype, + ) + + hidden_states = inputs_embeds + + # decoder layers + next_decoder_cache = () + + for idx, decoder_layer in enumerate(self.layers): + assert past_key_values is not None + past_key_value = (past_key_values[idx * 2], past_key_values[idx * 2 + 1]) + + hidden_states, key_value_cache = decoder_layer( + hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + all_seq_len_shape=all_seq_len_shape, + ) + next_decoder_cache += key_value_cache + + hidden_states = self.norm(hidden_states) + + assert len(next_decoder_cache) == len(self.layers) * 2 + return hidden_states, next_decoder_cache + + +class StableLM3bForCausalLM(nn.Module): + def __init__( + self, config: StableLM3bConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool = False + ): + self.model = StableLM3bModell(config, vocab_size_var, sep_embed) + self.lm_head = 
Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False) + + assert config.hidden_size % config.num_attention_heads == 0 + + def forward( + self, + inputs: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_values: relax.Expr, + ): + hidden_states, key_value_cache = self.model( + inputs=inputs, + all_seq_len_shape=all_seq_len_shape, + past_key_values=past_key_values, + ) + + def te_slicing(x: te.Tensor): + return te.compute( + shape=(1, 1, x.shape[-1]), + fcompute=lambda i, j, k: x[i, x.shape[1] - 1, k], + name="slice", + ) + + logits = self.lm_head(nn.emit_te(te_slicing, hidden_states, primfunc_name_hint="slice")) + if logits.struct_info.dtype != "float32": + logits = nn.emit(relax.op.astype(logits, "float32")) + + return logits, key_value_cache + + +def get_param_quant_kind(name: str, param_info: relax.TensorStructInfo) -> ParamQuantKind: + if "embed_tokens" in name: + return ParamQuantKind.embedding_table + elif "lm_head.weight" in name: + return ParamQuantKind.final_fc_weight + elif param_info.ndim == 2 and name.endswith(".weight"): + return ParamQuantKind.linear_weight + else: + return ParamQuantKind.others + + +def create_embed_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: StableLM3bConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "embed" + + bsz = 1 + seq_len = tvm.tir.SizeVar("m", "int64") + with bb.function(func_name): + model = StableLM3bEmbedTokensWrapper(config, tvm.tir.SizeVar("vocab_size", "int64")) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") + with bb.dataflow(): + inputs_embeds = model(input_ids) + params = [input_ids] + model.parameters() + gv = bb.emit_output(inputs_embeds) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 1)) + + +def create_encoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: StableLM3bConfig, + quant_scheme: QuantizationScheme, + sep_embed: bool = False, +) -> None: + func_name = "prefill_with_embed" if sep_embed else "prefill" + + bsz = 1 + seq_len = tvm.tir.SizeVar("n", "int64") + all_seq_len = tvm.tir.SizeVar("m", "int64") + hidden_size = config.hidden_size + with bb.function(func_name): + model = StableLM3bForCausalLM(config, tvm.tir.SizeVar("vocab_size", "int64"), sep_embed) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + inputs = ( + nn.Placeholder((bsz, seq_len, hidden_size), dtype=config.dtype, name="inputs_embeds") + if sep_embed + else nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") + ) + all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo( + [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] + ), + ) + with bb.dataflow(): + logits, key_value_cache = model( + inputs, all_seq_len_shape, past_key_values=past_key_values + ) + params = [ + inputs, + all_seq_len_shape, + past_key_values, + ] + model.parameters() + gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 3)) + + +def create_decoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: StableLM3bConfig, + quant_scheme: 
QuantizationScheme, +) -> None: + func_name = "decode" + + bsz = 1 + all_seq_len = tvm.tir.SizeVar("m", "int64") + + with bb.function(func_name): + model = StableLM3bForCausalLM(config, tvm.tir.SizeVar("vocab_size", "int64")) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids = nn.Placeholder((bsz, 1), dtype="int32", name="input_ids") + all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo( + [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] + ), + ) + with bb.dataflow(): + logits, key_value_cache = model( + input_ids, all_seq_len_shape, past_key_values=past_key_values + ) + params = [ + input_ids, + all_seq_len_shape, + past_key_values, + ] + model.parameters() + gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 3)) + + +def create_kv_cache_func(bb: relax.BlockBuilder, config: StableLM3bConfig) -> None: + num_key_value_heads = ( + config.num_attention_heads + if config.num_key_value_heads is None + else config.num_key_value_heads + ) // config.num_shards + init_shape = relax.ShapeExpr( + ( + config.max_sequence_length, + num_key_value_heads, + config.hidden_size // config.num_attention_heads, # head_dim + ) + ) + with bb.function("create_kv_cache", []): + with bb.dataflow(): + zeros = bb.emit(relax.op.zeros(init_shape, config.dtype)) + caches = [] + f_kv_cache_create = relax.extern("vm.builtin.attention_kv_cache_create") + for _ in range(config.num_hidden_layers * 2): + caches.append( + bb.emit( + relax.call_pure_packed( + f_kv_cache_create, + args=[zeros, init_shape, relax.PrimValue(0)], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + ) + gv = bb.emit_output(caches) + bb.emit_func_output(gv) + + +def create_softmax_func(bb: relax.BlockBuilder, config: StableLM3bConfig) -> None: + with bb.function("softmax_with_temperature"): + logits = nn.Placeholder( + (1, 1, tvm.tir.SizeVar("vocab_size", "int64")), dtype="float32", name="logits" + ) + temperature = nn.Placeholder((), dtype="float32", name="temperature") + with bb.dataflow(): + div = bb.emit(relax.op.divide(logits, temperature)) + softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) + gv = bb.emit_output(softmax) + bb.emit_func_output(gv, [logits, temperature]) + + +def emit_shard3d(bb: relax.BlockBuilder) -> None: + from tvm.script import tir as T + + def _emit(dtype: str, global_symbol: str): + @T.prim_func + def shard_3d(a: T.handle, num_shards: T.int64, b: T.handle): + T.func_attr( + { + "tir.noalias": T.bool(True), + "global_symbol": global_symbol, + } + ) + s_0, s_1, s_2 = T.int64(), T.int64(), T.int64() + # pylint: disable=invalid-name + A = T.match_buffer(a, (s_0, s_1, s_2), dtype) + B = T.match_buffer(b, (num_shards, s_0, s_1 // num_shards, s_2), dtype) + # pylint: enable=invalid-name + for j_o, i, j_i, k in T.grid(num_shards, s_0, s_1 // num_shards, s_2): + with T.block("B"): + v_j_o = T.axis.spatial(num_shards, j_o) + v_i = T.axis.spatial(s_0, i) + v_j_i = T.axis.spatial(s_1 // num_shards, j_i) + v_k = T.axis.spatial(s_2, k) + B[v_j_o, v_i, v_j_i, v_k] = A[v_i, v_j_o * (s_1 // num_shards) + v_j_i, v_k] + + bb.add_func(shard_3d, global_symbol) + + _emit("float32", "shard3d_fp32") + _emit("float16", "shard3d_fp16") + _emit("uint32", "shard3d_uint32") + + +def get_model(args, hf_config): + model_name = 
args.model + dtype = args.quantization.model_dtype + max_seq_len = args.max_seq_len + sep_embed = args.sep_embed + + position_embedding_base = 10000 + if "rope_theta" in hf_config: + position_embedding_base = hf_config["rope_theta"] + + config = StableLM3bConfig( + **hf_config, + dtype=dtype, + position_embedding_base=position_embedding_base, + combine_matmul=True, + num_shards=args.num_shards, + build_model_only=args.build_model_only, + convert_weights_only=args.convert_weights_only, + ) + if max_seq_len != -1: + config.max_sequence_length = max_seq_len + + param_manager = ParamManager() + bb = relax.BlockBuilder() + emit_shard3d(bb) + + if sep_embed: + create_embed_func(bb, param_manager, config, args.quantization) + create_encoding_func(bb, param_manager, config, args.quantization, sep_embed) + create_decoding_func(bb, param_manager, config, args.quantization) + create_kv_cache_func(bb, config) + create_softmax_func(bb, config) + create_metadata_func( + bb, + model_name=model_name, + max_window_size=config.max_sequence_length, + stop_tokens=[2], + add_prefix_space=False, + prefill_chunk_size=args.prefill_chunk_size, + ) + + mod = bb.get() + + tir_bound_map = dict() + tir_bound_map["n"] = ( + args.prefill_chunk_size if args.prefill_chunk_size > 0 else config.max_sequence_length + ) + tir_bound_map["m"] = config.max_sequence_length + for gv in mod.functions: + func = mod[gv] + if isinstance(func, relax.Function): + mod[gv] = func.with_attr("tir_var_upper_bound", tir_bound_map) + + if args.build_model_only: + return mod, param_manager, None, config + + def f_convert_pname_fwd(pname: str) -> List[str]: + if not config.combine_matmul: + return [pname] + + qkv_str = "query_key_value_proj" + gate_up_str = "gate_up_proj" + if qkv_str in pname: + return [ + pname.replace(qkv_str, "q_proj"), + pname.replace(qkv_str, "k_proj"), + pname.replace(qkv_str, "v_proj"), + ] + elif gate_up_str in pname: + return [ + pname.replace(gate_up_str, "gate_proj"), + pname.replace(gate_up_str, "up_proj"), + ] + else: + return [pname] + + def f_convert_param_bkwd(torch_pname: str, torch_param): + if not config.combine_matmul: + return [(torch_pname, torch_param.astype(dtype))] + + combined_layers = ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"] + if any([name in torch_pname for name in combined_layers]): + return None + return [(torch_pname, torch_param.astype(dtype))] + + def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): + # Expected to enter this function only for the combined linear matmul weights. + # Other weights are supposed to be loaded in `f_convert_param_bkwd` since + # each other relax param has a unique corresponding torch param. + if not config.combine_matmul: + # When matmul combination is not turned on, each relax param has a unique + # corresponding torch param, and this function is not expected to be entered. 
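+            # (Reaching this branch would mean the parameter-name mapping and the + # `combine_matmul` flag disagree, so fail loudly instead of silently mis-loading weights.)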
+ raise NotImplementedError( + "Matmul combination is not turned on, and the function " + "is not expected to be entered" + ) + num_shards = args.num_shards + hidden_size = config.hidden_size + head_dim = config.hidden_size // config.num_attention_heads + + if "query_key_value_proj" in relax_pname: + q_heads = config.num_attention_heads + kv_heads = config.num_key_value_heads + if kv_heads is None: + kv_heads = q_heads + q, k, v = torch_params + assert q.shape == (q_heads * head_dim, hidden_size) + assert k.shape == (kv_heads * head_dim, hidden_size) + assert v.shape == (kv_heads * head_dim, hidden_size) + q = q.reshape((num_shards, q_heads // num_shards, head_dim, hidden_size)) + k = k.reshape((num_shards, kv_heads // num_shards, head_dim, hidden_size)) + v = v.reshape((num_shards, kv_heads // num_shards, head_dim, hidden_size)) + qkv = np.concatenate([q, k, v], axis=1) + qkv = qkv.reshape((-1, hidden_size)).astype(dtype) + return qkv + if "gate_up_proj" in relax_pname: + intermediate_size = config.intermediate_size + gate, up = torch_params + gate = gate.reshape((num_shards, intermediate_size // num_shards, hidden_size)) + up = up.reshape((num_shards, intermediate_size // num_shards, hidden_size)) + gate_up = np.concatenate([gate, up], axis=1) + gate_up = gate_up.reshape((-1, hidden_size)).astype(dtype) + return gate_up + raise ValueError("Unexpected param loading") + + param_manager.set_param_loading_func( + args.model_path, + args.use_safetensors, + f_convert_pname_fwd, + f_convert_param_bkwd, + f_compute_relax_param, + ) + + param_list = [None] * param_manager.nparam_to_load + + return mod, param_manager, param_list, config diff --git a/mlc_llm/transform/__init__.py b/mlc_llm/transform/__init__.py new file mode 100644 index 0000000..2c67369 --- /dev/null +++ b/mlc_llm/transform/__init__.py @@ -0,0 +1,9 @@ +from .clean_up_tir_attrs import CleanUpTIRAttrs +from .decode_matmul_ewise import FuseDecodeMatmulEwise +from .decode_take import FuseDecodeTake +from .decode_transpose import FuseDecodeTranspose +from .fuse_split_rotary_embedding import fuse_split_rotary_embedding +from .lift_tir_global_buffer_alloc import LiftTIRGlobalBufferAlloc +from .reorder_transform_func import ReorderTransformFunc +from .rewrite_attention import rewrite_attention +from .transpose_matmul import FuseTransposeMatmul, FuseTranspose1Matmul, FuseTranspose2Matmul diff --git a/mlc_llm/transform/clean_up_tir_attrs.py b/mlc_llm/transform/clean_up_tir_attrs.py new file mode 100644 index 0000000..93a90f8 --- /dev/null +++ b/mlc_llm/transform/clean_up_tir_attrs.py @@ -0,0 +1,25 @@ +"""Clean up TIR attributes that may affect dispatching""" + +import tvm +from tvm.ir.module import IRModule + + +@tvm.transform.module_pass(opt_level=0, name="CleanUpTIRAttrs") +class CleanUpTIRAttrs: + def transform_module( + self, mod: IRModule, ctx: tvm.transform.PassContext + ) -> IRModule: + undesired_attrs = ["op_pattern"] + + for gv in list(mod.functions): + func = mod[gv] + changed = False + for attr in undesired_attrs: + if func.attrs is not None and attr in func.attrs: + func = func.without_attr(attr) + changed = True + break + + if changed: + mod[gv] = func + return mod diff --git a/mlc_llm/transform/decode_matmul_ewise.py b/mlc_llm/transform/decode_matmul_ewise.py new file mode 100644 index 0000000..7471848 --- /dev/null +++ b/mlc_llm/transform/decode_matmul_ewise.py @@ -0,0 +1,84 @@ +import tvm +from tvm import IRModule, relax, tir +from tvm.relax.dpl.pattern import GlobalVarPattern, TuplePattern, is_op, wildcard + + +def 
check_decoding(ctx: relax.transform.PatternCheckContext) -> bool: + call = ctx.annotated_expr["w"] + if not isinstance(call, relax.Call): + return False + gv = call.args[0] + if not isinstance(gv, relax.GlobalVar): + return False + return gv.name_hint.startswith("decode") or gv.name_hint.startswith("fused_decode") + + +def check_matmul(ctx: relax.transform.PatternCheckContext) -> bool: + call = ctx.annotated_expr["matmul"] + if not isinstance(call, relax.Call): + return False + gv = call.args[0] + if not isinstance(gv, relax.GlobalVar): + return False + return ( + gv.name_hint.startswith("matmul") + or gv.name_hint.startswith("fused_matmul") + or gv.name_hint.startswith("NT_matmul") + or gv.name_hint.startswith("fused_NT_matmul") + ) + + +def pattern_check(): + def f_pattern_check(ctx: relax.transform.PatternCheckContext) -> bool: + return check_decoding(ctx) and check_matmul(ctx) + + return f_pattern_check + + +def decode_matmul_pattern(match_ewise: int, n_aux_tensor: int): + assert n_aux_tensor == 1 or n_aux_tensor == 2 or n_aux_tensor == 3 or n_aux_tensor == 4 + + w_scaled = wildcard() + aux_tensors = [wildcard(), wildcard(), wildcard(), wildcard()] + x = wildcard() + w = is_op("relax.call_tir")( + GlobalVarPattern(), + TuplePattern([w_scaled, *aux_tensors[0:n_aux_tensor]]), + add_constraint=False, + ) + matmul_args = [x, w] + for _ in range(match_ewise): + matmul_args.append(wildcard()) + matmul = is_op("relax.call_tir")( + GlobalVarPattern(), TuplePattern(matmul_args), add_constraint=False + ) + + annotations = { + "matmul": matmul, + "w": w, + "x": x, + "w_scaled": w_scaled, + } + return matmul, annotations, pattern_check() + + +@tvm.transform.module_pass(opt_level=0, name="FuseDecodeMatmulEwise") +class FuseDecodeMatmulEwise: + def transform_module( + self, mod: IRModule, ctx: tvm.transform.PassContext # pylint: disable=unused-argument + ) -> IRModule: + for n_aux_tensor in [1, 2, 3, 4]: + for match_ewise in [0, 1, 2, 6]: + if match_ewise == 6 and n_aux_tensor != 4: + continue + mod = relax.transform.FuseOpsByPattern( + [ + ( + "decode_matmul", + *decode_matmul_pattern(match_ewise, n_aux_tensor), + ) + ] + )(mod) + mod = relax.transform.FuseTIR()(mod) + + return mod diff --git a/mlc_llm/transform/decode_take.py b/mlc_llm/transform/decode_take.py new file mode 100644 index 0000000..cd09771 --- /dev/null +++ b/mlc_llm/transform/decode_take.py @@ -0,0 +1,71 @@ +"""Fusing and inlining decode function into embedding table lookup.""" +import tvm +from tvm import relax, tir +from tvm.ir.module import IRModule +from tvm.relax.dpl.pattern import GlobalVarPattern, TuplePattern, is_const, is_op, wildcard + + +def pattern_check(ctx: relax.transform.PatternCheckContext) -> bool: + take = ctx.annotated_expr["take"] + decode = ctx.annotated_expr["decode"] + if not isinstance(decode, relax.expr.Call): + return False + if not isinstance(take.args[0], relax.GlobalVar) or not isinstance( + decode.args[0], relax.GlobalVar + ): + return False + return "take" in take.args[0].name_hint and "decode" in decode.args[0].name_hint + + +def decode_take_pattern(n_aux_tensor: int, match_tir_vars: bool): + aux_tensors = [wildcard(), wildcard(), wildcard()] + decode = is_op("relax.call_tir")( + GlobalVarPattern(), + TuplePattern([*aux_tensors[0:n_aux_tensor]]), + add_constraint=False, + ) + indices = ~is_const() + take_args = [decode, indices] + call_tir_args_take = [GlobalVarPattern(), TuplePattern(take_args)] + if match_tir_vars: + call_tir_args_take.append(wildcard()) + take = 
is_op("relax.call_tir")(*call_tir_args_take, add_constraint=False) + + annotations = { + "take": take, + "decode": decode, + "indices": indices, + } + + return take, annotations, pattern_check + + +@tvm.transform.module_pass(opt_level=0, name="FuseDecodeTake") +class FuseDecodeTake: + def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule: + for n_aux_tensor in [2, 3]: + for match_tir_vars in [False, True]: + mod = relax.transform.FuseOpsByPattern( + [ + ( + "decode_take", + *decode_take_pattern(n_aux_tensor, match_tir_vars), + ) + ] + )(mod) + mod = relax.transform.FuseTIR()(mod) + + for gv, func in mod.functions.items(): + if not isinstance(func, tir.PrimFunc): + continue + if "fused_decode" not in gv.name_hint or "take" not in gv.name_hint: + continue + + downcasted_mod = tir.transform.ForceNarrowIndexToInt32()(tvm.IRModule({"main": func}))[ + "main" + ] + sch = tir.Schedule(downcasted_mod) + sch.compute_inline("decode") + mod[gv] = sch.mod["main"] + + return mod diff --git a/mlc_llm/transform/decode_transpose.py b/mlc_llm/transform/decode_transpose.py new file mode 100644 index 0000000..be5dccd --- /dev/null +++ b/mlc_llm/transform/decode_transpose.py @@ -0,0 +1,113 @@ +"""Fusing and inlining transpose function into decode function.""" +import tvm +from tvm import relax, tir +from tvm.ir.module import IRModule +from tvm.relax.analysis import remove_all_unused +from tvm.relax.expr_functor import PyExprMutator, mutator + + +@tvm.transform.module_pass(opt_level=0, name="FuseDecodeTranspose") +class FuseDecodeTranspose: + def __init__(self, skip_gemm=True) -> None: + self.skip_gemm = skip_gemm + + def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule: + @mutator + class DecodeTransposeFusor(PyExprMutator): + def __init__(self, mod: IRModule, skip_gemm=True): + super().__init__(mod) + self.mod = mod + self.skip_gemm = skip_gemm + + def transform(self) -> IRModule: + for gv, func in self.mod.functions.items(): + if not isinstance(func, relax.Function): + continue + + updated_func = self.visit_expr(func) + updated_func = remove_all_unused(updated_func) + self.builder_.update_func(gv, updated_func) + + return self.builder_.get() + + def visit_call_(self, call: relax.Call) -> relax.Expr: + call = self.visit_expr_post_order(call) + + if call.op != tvm.ir.Op.get("relax.matmul"): + return call + + # Do not fuse decode-transpose for GeMM + if self.skip_gemm and ( + call.args[0].struct_info.ndim < 2 + or not isinstance(call.args[0].struct_info.shape[-2], tir.IntImm) + or call.args[0].struct_info.shape[-2].value != 1 + ): + return call + + matmul_rhs = self.lookup_binding(call.args[1]) + if ( + not isinstance(matmul_rhs, relax.Call) + or matmul_rhs.op != tvm.ir.Op.get("relax.permute_dims") + or matmul_rhs.args[0].struct_info.ndim != 2 + or matmul_rhs.attrs.axes is not None + ): + return call + + transpose_input = self.lookup_binding(matmul_rhs.args[0]) + if ( + not isinstance(transpose_input, relax.Call) + or transpose_input.op != tvm.ir.Op.get("relax.call_tir") + or not transpose_input.args[0].name_hint.startswith("decode") + or not isinstance( + transpose_input.struct_info, relax.TensorStructInfo + ) + ): + return call + + decode_tir_func = self.mod[transpose_input.args[0]] + assert isinstance(decode_tir_func, tir.PrimFunc) + if ( + len(decode_tir_func.body.block.alloc_buffers) != 1 + or not isinstance(decode_tir_func.body.block.body, tir.SeqStmt) + or len(decode_tir_func.body.block.body) != 2 + or not 
isinstance(decode_tir_func.body.block.body[1], tir.For) + or not isinstance( + decode_tir_func.body.block.body[1].body.body, tir.BlockRealize + ) + or decode_tir_func.body.block.body[1].body.body.block.name_hint + != "T_transpose" + ): + return call + + new_func_buffers = [ + decode_tir_func.buffer_map[var] for var in decode_tir_func.params + ] + new_func_buffers[-1] = decode_tir_func.body.block.alloc_buffers[0] + new_func = tir.PrimFunc( + params=new_func_buffers, + body=tir.BlockRealize( + iter_values=[], + predicate=True, + block=tir.Block( + iter_vars=[], + reads=[], + writes=[], + name_hint="root", + body=decode_tir_func.body.block.body[0], + ), + ), + ) + # Call `renew_defs` for deep-copy to avoid IR node duplication in + # different PrimFuncs of an IRModule. + new_func = tir.stmt_functor.renew_defs(new_func) + gv = self.builder_.add_func(new_func, func_name="decode") + decoded_matmul_rhs = self.builder_.emit( + relax.call_tir( + gv, transpose_input.args[1], out_sinfo=matmul_rhs.struct_info + ) + ) + return relax.op.matmul( + call.args[0], decoded_matmul_rhs, out_dtype=call.attrs.out_dtype + ) + + return DecodeTransposeFusor(mod, self.skip_gemm).transform() diff --git a/mlc_llm/transform/fuse_split_rotary_embedding.py b/mlc_llm/transform/fuse_split_rotary_embedding.py new file mode 100644 index 0000000..ed19a70 --- /dev/null +++ b/mlc_llm/transform/fuse_split_rotary_embedding.py @@ -0,0 +1,284 @@ +import tvm +from tvm import relax +from tvm.relax.dpl import ( + PatternContext, + is_op, + rewrite_bindings, + wildcard, + is_tuple_get_item, + GlobalVarPattern, + TuplePattern, + is_shape, +) +from tvm.script import relax as R, tir as T + + +def get_dynamic_split_rotary(): + """Implementation of R.split(rotary_embedding(fused_qkv)) + + Implementation is generic over the number of query heads, + key/value heads, sequence length, head dimension, and position + embedding base. These parameters can be replaced with static + values using `PrimFunc.specialize`. 
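+ + For example (illustrative values only): specializing with num_query_heads=32, + num_kv_heads=32 and head_dim=128 yields a kernel that takes a fused + (batch_size, seq_len, 96, 128) QKV tensor and writes rotary-embedded query and key + tensors plus an untouched value tensor in a single pass.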
+ """ + + @T.prim_func(private=True) + def split_rotary( + fused_qkv_handle: T.handle, + embedded_query_handle: T.handle, + embedded_key_handle: T.handle, + value_handle: T.handle, + rotary_offset: T.int64, + batch_size: T.int64, + seq_len: T.int64, + num_query_heads: T.int64, + num_kv_heads: T.int64, + head_dim: T.int64, + position_embedding_base: T.float32, + ): + Fused_QKV = T.match_buffer( + fused_qkv_handle, + [batch_size, seq_len, num_query_heads + num_kv_heads * 2, head_dim], + dtype="float16", + ) + EmbeddedQuery = T.match_buffer( + embedded_query_handle, + [batch_size, seq_len, num_query_heads, head_dim], + dtype="float16", + ) + EmbeddedKey = T.match_buffer( + embedded_key_handle, + [batch_size, seq_len, num_kv_heads, head_dim], + dtype="float16", + ) + Value = T.match_buffer( + value_handle, + [batch_size, seq_len, num_kv_heads, head_dim], + dtype="float16", + ) + + T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)}) + + for iters in T.grid(batch_size, seq_len, num_query_heads + num_kv_heads * 2, head_dim): + with T.block("FusedRotaryEmbeddingAndSplitQKV"): + batch_i, seq_i, head_num, head_i = T.axis.remap("SSSS", iters) + pos: T.float32 = T.Cast("float32", rotary_offset + seq_i - seq_len) + + inv_freq: T.float32 = T.float32(1) / T.pow( + position_embedding_base, + T.Cast("float32", (head_i * 2) % head_dim) / T.float32(head_dim), + ) + freq: T.float32 = pos * inv_freq + cos_value: T.float16 = T.Cast("float16", T.cos(freq)) + sin_value: T.float16 = T.Cast("float16", T.sin(freq)) + + input_value = Fused_QKV[batch_i, seq_i, head_num, head_i] + embedded_value = cos_value * input_value + sin_value * T.Select( + head_i < T.int64(head_dim // 2), + Fused_QKV[batch_i, seq_i, head_num, head_i + T.int64(head_dim // 2)] + * T.float16(-1), + Fused_QKV[batch_i, seq_i, head_num, head_i - T.int64(head_dim // 2)], + ) + if head_num < num_query_heads: + EmbeddedQuery[batch_i, seq_i, head_num, head_i] = embedded_value + elif head_num < num_query_heads + num_kv_heads: + EmbeddedKey[batch_i, seq_i, head_num - num_query_heads, head_i] = embedded_value + else: + Value[ + batch_i, seq_i, head_num - num_query_heads - num_kv_heads, head_i + ] = input_value + + param_sinfo = [] + for param in split_rotary.params: + if param in split_rotary.buffer_map: + buf = split_rotary.buffer_map[param] + sinfo = relax.TensorStructInfo(shape=buf.shape, dtype=buf.dtype) + else: + sinfo = relax.PrimStructInfo(param.dtype) + param_sinfo.append(sinfo) + + relax.expr._update_struct_info( + split_rotary, + tvm.relax.FuncStructInfo( + params=param_sinfo, + ret=relax.TupleStructInfo([]), + purity=False, + ), + ) + + return split_rotary + + +def fuse_split_rotary_embedding( + num_query_heads, num_kv_heads, hidden_size, position_embedding_base +): + @tvm.ir.transform.module_pass(opt_level=0, name="fuse_split_rotary_embedding") + def ir_module_pass(mod: tvm.IRModule, _pass_context) -> tvm.IRModule: + head_dim = hidden_size // num_query_heads + split_rotary = get_dynamic_split_rotary() + + ( + dyn_batch_size, + dyn_seq_len, + dyn_num_query_heads, + dyn_num_kv_heads, + dyn_head_dim, + dyn_position_embedding_base, + ) = split_rotary.params[-6:] + + split_rotary = split_rotary.specialize( + { + # Static model parameters + dyn_batch_size: T.int64(1), + dyn_num_query_heads: T.int64(num_query_heads), + dyn_num_kv_heads: T.int64(num_kv_heads), + dyn_head_dim: T.int64(head_dim), + dyn_position_embedding_base: T.float32(position_embedding_base), + # Dynamic parameters, to be inferred from TIR Buffer shapes + dyn_seq_len: 
tvm.tir.Var("query_sequence_length", "int64"), + } + ) + + mod["split_rotary"] = split_rotary + + split_rotary_gvar = mod.get_global_var("split_rotary") + relax.expr._update_struct_info(split_rotary_gvar, mod["split_rotary"].struct_info) + + with PatternContext() as ctx: + # flat_qkv_tuple: R.Tuple( + # R.Tensor((batch_size, seq_len, 4096), dtype="float16"), + # R.Tensor((batch_size, seq_len, 4096), dtype="float16"), + # R.Tensor((batch_size, seq_len, 4096), dtype="float16"), + # ) = R.split(flat_fused_qkv, indices_or_sections=[4096, 8192], axis=2) + # + # flat_query: R.Tensor((batch_size, seq_len, 4096), dtype="float16") = flat_qkv_tuple[0] + # query: R.Tensor((batch_size, seq_len, 32, 128), dtype="float16") = R.reshape( + # flat_query, R.shape([batch_size, seq_len, 32, 128]) + # ) + # flat_key: R.Tensor((batch_size, seq_len, 4096), dtype="float16") = flat_qkv_tuple[1] + # key: R.Tensor((batch_size, seq_len, 32, 128), dtype="float16") = R.reshape( + # flat_key, R.shape([batch_size, seq_len, 32, 128]) + # ) + # flat_value: R.Tensor((batch_size, seq_len, 4096), dtype="float16") = flat_qkv_tuple[2] + # value: R.Tensor((batch_size, seq_len, 32, 128), dtype="float16") = R.reshape( + # flat_value, R.shape([batch_size, seq_len, 32, 128]) + # ) + # embedded_query = R.call_tir( + # cls.rotary_embedding1, + # [query], + # out_sinfo=R.Tensor((batch_size, seq_len, 32, 128), dtype="float16"), + # tir_vars=R.shape([n]), + # ) + # embedded_key = R.call_tir( + # cls.rotary_embedding1, + # [key], + # out_sinfo=R.Tensor((batch_size, seq_len, 32, 128), dtype="float16"), + # tir_vars=R.shape([n]), + # ) + + pat_rotary_embedding_gvar = GlobalVarPattern() + + pat_flat_fused_qkv = wildcard() + pat_offset = wildcard() + + # query_shape = is_shape([1, seq_len, num_query_heads, head_dim]) + pat_query_shape = wildcard() + # value_shape = is_shape([1, seq_len, num_kv_heads, head_dim]) + pat_key_shape = wildcard() + # value_shape = is_shape([1, seq_len, num_kv_heads, head_dim]) + pat_value_shape = wildcard() + + pat_flat_qkv_tuple = is_op("relax.split")(pat_flat_fused_qkv) + pat_flat_query = is_tuple_get_item(pat_flat_qkv_tuple, 0) + pat_query = is_op("relax.reshape")( + pat_flat_query, pat_query_shape, add_constraint=False + ) + pat_flat_query.used_by(pat_query) + pat_flat_key = is_tuple_get_item(pat_flat_qkv_tuple, 1) + pat_key = is_op("relax.reshape")(pat_flat_key, pat_key_shape, add_constraint=False) + pat_flat_key.used_by(pat_key) + pat_flat_value = is_tuple_get_item(pat_flat_qkv_tuple, 2) + pat_value = is_op("relax.reshape")( + pat_flat_value, pat_value_shape, add_constraint=False + ) + pat_flat_value.used_by(pat_value) + + pat_embedded_query = is_op("relax.call_tir")( + pat_rotary_embedding_gvar, + TuplePattern([pat_query]), + pat_offset, + add_constraint=False, + ) + pat_embedded_key = is_op("relax.call_tir")( + pat_rotary_embedding_gvar, + TuplePattern([pat_key]), + pat_offset, + add_constraint=False, + ) + + pat_flat_qkv_tuple.used_by(pat_flat_query) + pat_flat_qkv_tuple.used_by(pat_flat_key) + pat_flat_qkv_tuple.used_by(pat_flat_value) + pat_query.used_by(pat_embedded_query) + pat_key.used_by(pat_embedded_key) + + def rewriter(matchings, bindings): + # Extracting all the relax and TIR variables that we'll need + flat_fused_qkv = matchings[pat_flat_fused_qkv] + flat_qkv_tuple = matchings[pat_flat_qkv_tuple] + + flat_query = matchings[pat_flat_query] + flat_key = matchings[pat_flat_key] + flat_value = matchings[pat_flat_value] + + query = matchings[pat_query] + key = matchings[pat_key] + value = 
matchings[pat_value] + + embedded_query = matchings[pat_embedded_query] + embedded_key = matchings[pat_embedded_key] + + # rotary_embedding_offset = bindings[query].args[-1][1] + rotary_embedding_offset = bindings[embedded_query].args[-1][0] + + batch_size, seq_len, num_query_heads, head_dim = query.struct_info.shape + _batch_size, _seq_len, num_kv_heads, _head_dim = key.struct_info.shape + + # Rewriting along the new path + + fused_qkv = relax.op.reshape( + flat_fused_qkv, [batch_size, seq_len, num_query_heads + 2 * num_kv_heads, head_dim] + ) + + split_rotary_sinfo = [ + R.Tensor((batch_size, seq_len, num_query_heads, head_dim), dtype="float16"), + R.Tensor((batch_size, seq_len, num_kv_heads, head_dim), dtype="float16"), + R.Tensor((batch_size, seq_len, num_kv_heads, head_dim), dtype="float16"), + ] + qkv_tuple_new = R.call_tir( + split_rotary_gvar, + (fused_qkv,), + out_sinfo=split_rotary_sinfo, + tir_vars=[rotary_embedding_offset], + ) + + embedded_query_new = qkv_tuple_new[0] + embedded_key_new = qkv_tuple_new[1] + value_new = qkv_tuple_new[2] + + return { + value: value_new, + embedded_query: embedded_query_new, + embedded_key: embedded_key_new, + } + + new_mod = {} + for gvar, func in mod.functions.items(): + if isinstance(func, relax.Function): + func = rewrite_bindings(ctx, rewriter, func) + new_mod[gvar] = func + + new_mod = tvm.IRModule(new_mod, mod.type_definitions, mod.attrs, mod.global_infos) + return new_mod + + return ir_module_pass diff --git a/mlc_llm/transform/lift_tir_global_buffer_alloc.py b/mlc_llm/transform/lift_tir_global_buffer_alloc.py new file mode 100644 index 0000000..5805e9f --- /dev/null +++ b/mlc_llm/transform/lift_tir_global_buffer_alloc.py @@ -0,0 +1,197 @@ +"""Lift global buffer allocation in TIR to graph level""" + +from typing import Dict, List, Tuple, Optional + +import tvm +from tvm import relax, tir +from tvm.ir.module import IRModule +from tvm.relax.analysis import remove_all_unused +from tvm.relax.expr_functor import PyExprMutator, mutator + + +def remove_global_buf_alloc( + func: tir.PrimFunc, +) -> Optional[Tuple[tir.PrimFunc, List[relax.TensorStructInfo]]]: + """Remove the global buffer allocation for a given TIR PrimFunc.""" + if not isinstance(func.body, tir.BlockRealize): + return None + + params = list(func.params) + buffer_map = dict(func.buffer_map) + tensor_sinfo = [] + alloc_buffers = [] + + insertion_point = len(params) + while params[insertion_point - 1].dtype != "handle": + insertion_point -= 1 + assert insertion_point >= 1 + + prev_root_block = func.body.block + for buf_alloc in func.body.block.alloc_buffers: + if buf_alloc.scope() == "global": + param = tir.Var("var_" + buf_alloc.name, "handle") + params.insert(insertion_point, param) + insertion_point += 1 + buffer_map[param] = buf_alloc + tensor_sinfo.append(relax.TensorStructInfo(buf_alloc.shape, buf_alloc.dtype)) + else: + alloc_buffers.append(buf_alloc) + + if len(tensor_sinfo) == 0: + return None + + assert len(prev_root_block.iter_vars) == 0 + assert len(prev_root_block.reads) == 0 + assert len(prev_root_block.writes) == 0 + assert len(prev_root_block.match_buffers) == 0 + assert prev_root_block.name_hint == "root" + assert prev_root_block.init is None + root_block = tir.Block( + iter_vars=[], + reads=[], + writes=[], + name_hint="root", + body=prev_root_block.body, + alloc_buffers=alloc_buffers, + annotations=prev_root_block.annotations, + ) + + updated_func = tir.PrimFunc( + params=params, + body=tir.BlockRealize(iter_values=[], predicate=True, block=root_block), + 
ret_type=func.ret_type, + buffer_map=buffer_map, + attrs=func.attrs, + ) + return updated_func, tensor_sinfo + + +def contain_symbolic_var(tensor_sinfo: relax.TensorStructInfo) -> bool: + assert isinstance(tensor_sinfo.shape, relax.ShapeExpr) + for v in tensor_sinfo.shape.values: + if not isinstance(v, tir.IntImm): + return True + return False + + +def resolve_tir_var_mapping( + func: tir.PrimFunc, call: relax.Call, tensor_sinfo: List[relax.TensorStructInfo] +) -> Tuple[List[relax.TensorStructInfo], bool]: + """Resolve the TIR symbolic var relationship across sides of PrimFunc and Relax Function""" + var_map: Dict[tir.Var, tir.PrimExpr] = dict() + + n_arg = len(call.args[1].fields) + for i in range(n_arg): + buffer_shape = func.buffer_map[func.params[i]].shape + arg_shape = call.args[1][i].struct_info.shape.values + assert len(buffer_shape) == len(arg_shape) + for vl, vr in zip(buffer_shape, arg_shape): + if isinstance(vl, tir.Var): + var_map[vl] = vr + elif not isinstance(vl, tir.IntImm): + return [], False + + ret_tensors = call.sinfo_args[0] + ret_tensors = ( + [ret_tensors] + if isinstance(ret_tensors, relax.TensorStructInfo) + else list(ret_tensors.fields) + ) + for i in range(len(ret_tensors)): + buffer_shape = func.buffer_map[func.params[n_arg + i]].shape + ret_tensor_shape = ret_tensors[i].shape.values + assert len(buffer_shape) == len(ret_tensor_shape) + for vl, vr in zip(buffer_shape, ret_tensor_shape): + if isinstance(vl, tir.Var): + var_map[vl] = vr + elif not isinstance(vl, tir.IntImm): + return [], False + + updated_tensor_sinfo = [] + for sinfo in tensor_sinfo: + if not contain_symbolic_var(sinfo): + updated_tensor_sinfo.append(sinfo) + continue + + new_shape = [] + for v in sinfo.shape.values: + new_shape.append(tir.stmt_functor.substitute(v, var_map)) + updated_tensor_sinfo.append(relax.TensorStructInfo(new_shape, sinfo.dtype)) + return updated_tensor_sinfo, True + + +def LiftTIRGlobalBufferAlloc(): + @mutator + class TIRGlobalAllocRewriter(PyExprMutator): + def __init__(self, mod: IRModule): + super().__init__(mod) + self.mod = mod + + def transform(self) -> IRModule: + self.mod = self.builder_.get() + for gv, func in self.mod.functions.items(): + if isinstance(func, relax.Function): + updated_func = self.visit_expr(func) + self.builder_.update_func(gv, updated_func) + return self.builder_.get() + + def visit_call_(self, call: relax.Call): + call = self.visit_expr_post_order(call) + if call.op != tvm.ir.Op.get("relax.call_tir"): + return call + + old_gvar = call.args[0] + + func_before_update = self.mod.functions[old_gvar] + updates = remove_global_buf_alloc(func_before_update) + if updates is None: + return call + updated_func, tensor_sinfo = updates + + assert len(call.sinfo_args) == 1 + if any(contain_symbolic_var(sinfo) for sinfo in tensor_sinfo): + tensor_sinfo, success = resolve_tir_var_mapping( + func_before_update, call, tensor_sinfo + ) + if not success: + # Cannot resolve TIR var mapping. Fall back to no lifting. 
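+                    # (Lifting the allocation would require expressing its buffer shape in terms + # of the caller's symbolic shape variables; when that substitution cannot be derived + # from the call's argument and return shapes, the global buffer stays inside the PrimFunc.)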
+ return call + + new_gvar = self.builder_.add_func(updated_func, old_gvar.name_hint) + new_args = [new_gvar, *call.args[1:]] + + if isinstance(call.sinfo_args[0], relax.TensorStructInfo): + new_call = relax.Call( + call.op, + args=new_args, + sinfo_args=[relax.TupleStructInfo(list(call.sinfo_args) + tensor_sinfo)], + attrs=call.attrs, + ) + emitted_tuple = self.builder_.emit(new_call) + return relax.TupleGetItem(emitted_tuple, 0) + elif isinstance(call.sinfo_args[0], relax.TupleStructInfo): + return relax.Call( + call.op, + args=new_args, + sinfo_args=[ + relax.TupleStructInfo(list(call.sinfo_args[0].fields) + tensor_sinfo) + ], + attrs=call.attrs, + ) + else: + raise TypeError( + f"Expected {call.op} to return either R.Tensor or R.Tuple, " + f"but instead returned {call.sinfo_args[0]}" + ) + + @tvm.transform.module_pass(opt_level=0, name="LiftTIRGlobalBufferAlloc.Inner") + def transform_module(mod: IRModule, _: tvm.transform.PassContext) -> IRModule: + return TIRGlobalAllocRewriter(mod).transform() + + return tvm.ir.transform.Sequential( + [ + transform_module, + tvm.relax.transform.DeadCodeElimination(), + ], + name="LiftTIRGlobalBufferAlloc", + ) diff --git a/mlc_llm/transform/reorder_transform_func.py b/mlc_llm/transform/reorder_transform_func.py new file mode 100644 index 0000000..40403c8 --- /dev/null +++ b/mlc_llm/transform/reorder_transform_func.py @@ -0,0 +1,231 @@ +from typing import Callable, Dict, List, Set, Tuple + +import tvm +from tvm import relax +from tvm.ir.module import IRModule + +""" +This pass in this file reorders the bindings of the weight transform function +according to the weight location in binary files. The goal of the reorder is to +reduce the memory pressure when loading the raw model weights and processing +them. In the ideal case, with this pass, the highest CPU memory usage will +around the size of the largest raw weight binary file. + +Regarding the implementation, the bindings of fetching a raw weight in the +weight transform function are all in the form of `lv = params[idx]`. Here, each +index specifies a raw weight tensor, and the raw weight tensor resides in a +binary file on the disk. + +We group such `lv = params[idx]` into multiple groups, such that all raw weight +tensors in a group come from a same binary file. We reorder the bindings +according to the grouping result based on topological sort. + +In ideal case, after reordering the weight transform function has the following +process during execution: +* load a weight binary file, +* process all weights in this file, +* load another weight binary file, +* process all weights in this file, +* ... + +So the maximum CPU memory usage will be the size of the largest raw weight +binary file, since we process and release all the raw weight tensors immediately +after loading them from the file. +""" + + +def analyze_func( + func: relax.Function, + pidx2binname: Dict[int, str], +) -> Tuple[ + List[relax.Binding], + Dict[relax.Var, List[relax.Binding]], + Dict[relax.Binding, int], +]: + """Binding grouping analysis function. + It takes the function to be analyzed, and mapping from each raw tensor index + to the name of the binary file where it resides. + + This analysis function + * computes a new order of weight fetching bindings (the bindings in form + `lv = params[idx]`) based on weight location on disk. 
+ * collects the dataflow def-use information of the given function for + topological sort (particularly, it collects the consumers of each binding + variables and the number of variables each binding depends on). + + Parameters + ---------- + func : relax.Function + The weight transform function to be analyzed. + + pidx2binname : Dict[int, str] + The mapping from each raw tensor index to the name of the binary + file where it resides. + + Returns + ------- + get_param_bindings : List[relax.Binding] + The weight fetching bindings (`lv = params[idx]`) in the new order. + + var_users : Dict[relax.Var, List[relax.Binding]] + The consumer bindings of each binding variable. + Used for topological sort. + + num_depending_vars : Dict[relax.Binding, int] + The number of variables each binding depends on. + Used for topological sort. + """ + + # The mapping of the weight fetching bindings in each binary file. + # Here empty string means the weight is not in any binary file (e.g., cached + # sin and cos values for rotary embeddings). + binname2get_param_bindings: Dict[str, List[relax.Binding]] = {"": []} + # The set of binding variables. + binding_var_set: Set[relax.Var] = set() + var_users: Dict[relax.Var, List[relax.Binding]] = {} + num_depending_vars: Dict[relax.Binding, int] = {} + + # Sanity check on the function pattern. + assert len(func.params) == 1 + assert isinstance(func.body, relax.SeqExpr) + assert len(func.body.blocks) == 1 + assert isinstance(func.body.blocks[0], relax.DataflowBlock) + assert func.body.blocks[0].bindings[-1].var.same_as(func.body.body) + + params = func.params[0] + bindings = func.body.blocks[0].bindings + + # Go through each binding except the last one. (The last one is the output + # binding `gv = (lv, lv1, ...)`) which we ignore for analysis. + for binding in bindings[:-1]: + value = binding.value + binding_var_set.add(binding.var) + var_users[binding.var] = [] + + if isinstance(value, relax.TupleGetItem) and value.tuple_value.same_as(params): + # For weight fetching bindings (`lv = params[idx]`), we group them + # according to the binary file name. + pidx = value.index + if pidx not in pidx2binname: + binname2get_param_bindings[""].append(binding) + continue + + binname = pidx2binname[pidx] + if binname in binname2get_param_bindings: + binname2get_param_bindings[binname].append(binding) + else: + binname2get_param_bindings[binname] = [binding] + else: + # For other bindings, we collect the use-def information for + # topological sort. + num_depending_vars[binding] = 0 + + def fvisit(obj): + if isinstance(obj, relax.Var) and obj in binding_var_set: + assert obj in var_users + var_users[obj].append(binding) + num_depending_vars[binding] += 1 + + relax.analysis.post_order_visit(value, fvisit) + + # Get the weight fetching bindings in new order according to the group results. + get_param_bindings: List[relax.Binding] = [] + for bindings in binname2get_param_bindings.values(): + get_param_bindings += bindings + + return get_param_bindings, var_users, num_depending_vars + + +def reorder_func( + func: relax.Function, + pidx2binname: Dict[int, str], +) -> relax.Function: + """Reorder the bindings of the input weight transform Relax function + according the weight location in binary files. + + This function first analyzes the input function and gets the reordered + weight fetching bindings and the use-def information for topological sort. + It then reorders all bindings in the function with topological sort. 
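+ + For example (file names here are illustrative): if `params[0]` and `params[2]` reside in + "model-00001-of-00002.bin" while `params[1]` resides in the second shard file, the two + fetch bindings for the first file are emitted together, ahead of the fetch binding for + the second file, so that in the ideal case only one binary file needs to be resident in + memory at a time.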
+ + Parameters + ---------- + func : relax.Function + The weight transform function to be analyzed. + + pidx2binname : Dict[int, str] + The mapping from each raw tensor index to the name of the binary + file where it resides. + + Returns + ------- + func_updated : relax.Function + The returned function where the bindings are updated with the new order. + """ + get_param_bindings, var_users, num_depending_vars = analyze_func(func, pidx2binname) + + # The bindings in the new order, output by the topological sort. + new_bindings: List[relax.Binding] = [] + # The queue used in the topological sort. + binding_queue: List[relax.Binding] = [] + + for binding, n_depending in list(num_depending_vars.items()): + if n_depending == 0: + binding_queue.append(binding) + del num_depending_vars[binding] + + # Start topological sort: + # each time we emit a weight fetching binding, and then adds all bindings + # that depend on it. + for get_param_binding in get_param_bindings: + binding_queue.append(get_param_binding) + + while len(binding_queue) > 0: + binding = binding_queue.pop(0) + new_bindings.append(binding) + for user_binding in var_users[binding.var]: + num_depending_vars[user_binding] -= 1 + if num_depending_vars[user_binding] == 0: + del num_depending_vars[user_binding] + binding_queue.append(user_binding) + + # Add the output binding. + new_bindings.append(func.body.blocks[0].bindings[-1]) + # Sanity check on the integrity. + assert len(new_bindings) == len(func.body.blocks[0].bindings) + assert len(num_depending_vars) == 0 + + return relax.Function( + func.params, + relax.SeqExpr(blocks=[relax.DataflowBlock(new_bindings)], body=func.body.body), + func.ret_struct_info, + func.is_pure, + func.attrs, + ) + + +@tvm.transform.module_pass(opt_level=0, name="ReorderTransformFunc") +class ReorderTransformFunc: + def __init__( + self, + pidx2pname: Dict[int, str], + pname2binname: Dict[str, str], + f_convert_pname_fwd: Callable[[str], List[str]], + ) -> None: + self.pidx2binname: Dict[int, str] = { + pidx: pname2binname[f_convert_pname_fwd(pname)[0]] + for pidx, pname in pidx2pname.items() + if f_convert_pname_fwd(pname)[0] in pname2binname + } + + def transform_module( + self, + mod: IRModule, + ctx: tvm.transform.PassContext, + ) -> IRModule: + mod = mod.clone() + for gv, func in list(mod.functions.items()): + if isinstance(func, relax.Function): + assert gv.name_hint.endswith("transform_params") + func_updated = reorder_func(func, self.pidx2binname) + mod[gv] = func_updated + return mod diff --git a/mlc_llm/transform/rewrite_attention.py b/mlc_llm/transform/rewrite_attention.py new file mode 100644 index 0000000..d6d5693 --- /dev/null +++ b/mlc_llm/transform/rewrite_attention.py @@ -0,0 +1,46 @@ +import tvm +from tvm.relax.dpl import PatternContext, is_const, is_op, rewrite_call, wildcard +from tvm.script import relax as R + + +def rewrite_attention(use_flash_mqa=False): + @tvm.ir.transform.module_pass(opt_level=0, name="mlc_llm.transform.rewrite_attention") + def ir_module_transform(mod: tvm.IRModule, context) -> tvm.IRModule: + Q = wildcard() + K = wildcard() + V = wildcard() + + Q_BNSH = is_op("relax.permute_dims")(Q) + + if use_flash_mqa: + K_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(K)) + V_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(V)) + else: + K_BNSH = is_op("relax.permute_dims")(K) + V_BNSH = is_op("relax.permute_dims")(V) + + K_BNSH_T = is_op("relax.permute_dims")(K_BNSH) + + matmul1 = is_op("relax.matmul")(Q_BNSH, K_BNSH_T) + divide = 
is_op("relax.divide")(matmul1, is_const()) + max = is_op("relax.maximum")(divide, is_const()) + min = is_op("relax.minimum")(max, wildcard()) + softmax = is_op("relax.nn.softmax")(is_op("relax.astype")(min)) + matmul2 = is_op("relax.matmul")(is_op("relax.astype")(softmax), V_BNSH) + + pattern = is_op("relax.permute_dims")(matmul2) + + def callback(_, matchings): + return R.nn.attention( + matchings[Q], matchings[K], matchings[V], causal_mask="BottomRight" + ) + + new_module = {} + for gvar, func in mod.functions.items(): + if isinstance(func, tvm.relax.Function): + func = rewrite_call(pattern, callback, func) + new_module[gvar] = func + + return tvm.IRModule(new_module, mod.type_definitions, mod.attrs, mod.global_infos) + + return ir_module_transform diff --git a/mlc_llm/transform/transpose_matmul.py b/mlc_llm/transform/transpose_matmul.py new file mode 100644 index 0000000..fd8a9ae --- /dev/null +++ b/mlc_llm/transform/transpose_matmul.py @@ -0,0 +1,349 @@ +import tvm +from tvm import IRModule, relax, te, tir +from tvm.relax.dpl.pattern import is_op, wildcard + + +@relax.expr_functor.mutator +class TransposeMatmulCodeGenerator(relax.PyExprMutator): + def __init__(self, mod): + super().__init__(mod) + + @staticmethod + def pattern(): + w = wildcard() + x = wildcard() + wT = is_op("relax.permute_dims")(w) + o = is_op("relax.matmul")(x, wT) + annotations = {"o": o, "w": w, "x": x, "wT": wT} + + def _check(context: relax.transform.PatternCheckContext) -> bool: + transpose_call = context.annotated_expr["wT"] + ndim = transpose_call.args[0].struct_info.ndim + if ndim == -1: + return False + if ndim == 2 and transpose_call.attrs.axes is None: + return True + axes = list(range(ndim)) + axes[-1], axes[-2] = axes[-2], axes[-1] + return list(transpose_call.attrs.axes) == axes + + return o, annotations, _check + + def visit_call_(self, call: relax.Call) -> relax.Expr: + out_dtype = None + + def te_transposed_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor: + nonlocal out_dtype + a_shape = list(a.shape) + b_shape = list(b.shape) + a_prepended = False + b_appended = False + if len(a_shape) == 1: + a_prepended = True + a_shape.insert(0, 1) + if len(b_shape) == 1: + b_appended = True + b_shape.append(1) + + is_a_larger = len(a_shape) > len(b_shape) + offset = len(a_shape) - len(b_shape) if is_a_larger else len(b_shape) - len(a_shape) + + a_relax = relax.Var("a", relax.TensorStructInfo(a.shape)) + bT_shape = list(b.shape) + bT_shape[-1], bT_shape[-2] = bT_shape[-2], bT_shape[-1] + bT_relax = relax.Var("b", relax.TensorStructInfo(bT_shape)) + output_shape = self.builder_.normalize( + relax.op.matmul(a_relax, bT_relax) + ).struct_info.shape + + def matmul_compute(*idx_spatial): + k = te.reduce_axis((0, a_shape[-1]), name="k") + + def multiply_compute(idx_reduce): + a_indices = [] + b_indices = [] + + for i in range(offset): + if is_a_larger: + a_indices.append(idx_spatial[i]) + else: + b_indices.append(idx_spatial[i]) + for i in range(offset, len(output_shape) - (2 - a_prepended - b_appended)): + a_dim = a_shape[i if is_a_larger else i - offset] + b_dim = b_shape[i if not is_a_larger else i - offset] + dim_equal = a_dim == b_dim + if not isinstance(dim_equal, tir.IntImm) or dim_equal == 0: + a_dim_is_one = isinstance(a_dim, tir.IntImm) and a_dim == 1 + b_dim_is_one = isinstance(b_dim, tir.IntImm) and b_dim == 1 + a_indices.append(0 if a_dim_is_one else idx_spatial[i]) + b_indices.append(0 if b_dim_is_one else idx_spatial[i]) + else: + a_indices.append(idx_spatial[i]) + b_indices.append(idx_spatial[i]) + + 
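+                    # Trailing indices: A is read as (..., row, k) while B is read as (..., col, k), + # i.e. B stays in its original (untransposed) layout, so the generated PrimFunc computes + # A x B^T ("NT_matmul") without materializing the transposed weight.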
if not a_prepended: + a_indices.append(idx_spatial[-2 + b_appended]) + a_indices.append(idx_reduce) + if not b_appended: + b_indices.append(idx_spatial[-1]) + b_indices.append(idx_reduce) + + dtype = out_dtype + if dtype != "": + return a(*a_indices).astype(dtype) * b(*b_indices).astype(dtype) + return a(*a_indices) * b(*b_indices) + + return te.sum(multiply_compute(k), axis=k) + + return te.compute( + output_shape, + lambda *idx: matmul_compute(*idx), # pylint: disable=unnecessary-lambda + name="NT_matmul", + ) + + if isinstance(call.op, relax.GlobalVar): + function = self.builder_.get()[call.op] + if ( + function.attrs + and "Composite" in function.attrs + and function.attrs["Composite"] == "transpose_matmul_fuse" + ): + out_dtype = function.ret_struct_info.dtype + return self.builder_.call_te( + te_transposed_matmul, + call.args[1], + call.args[0], + primfunc_name_hint="NT_matmul", + ) + + return super().visit_call_(call) + + +@tvm.transform.module_pass(opt_level=0, name="FuseTransposeMatmul") +class FuseTransposeMatmul: + def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule: + mod = relax.transform.FuseOpsByPattern( + [("transpose_matmul_fuse", *TransposeMatmulCodeGenerator.pattern())] + )(mod) + + transpose_matmul_codegen = TransposeMatmulCodeGenerator(mod) + for gv in mod.functions: + func = mod[gv] + if not isinstance(func, relax.Function): + continue + func = transpose_matmul_codegen.visit_expr(func) + transpose_matmul_codegen.builder_.update_func(gv, func) + + return transpose_matmul_codegen.builder_.get() + +@relax.expr_functor.mutator +class Transpose1MatmulCodeGenerator(relax.PyExprMutator): + def __init__(self, mod): + super().__init__(mod) + + @staticmethod + def pattern(): + w = wildcard() + x = wildcard() + xT = is_op("relax.permute_dims")(x) + wT = is_op("relax.permute_dims")(w) + o = is_op("relax.matmul")(xT, wT) + annotations = {"o": o, "w": w, "x": x, "xT": xT, "wT": wT} + + def _check(context: relax.transform.PatternCheckContext) -> bool: + x_transpose_call = context.annotated_expr["o"] + w_transpose_call = context.annotated_expr["o"] + x_shape = context.annotated_expr["x"].struct_info.shape + w_shape = context.annotated_expr["w"].struct_info.shape + xT_shape = x_transpose_call.args[0].struct_info.shape + wT_shape = w_transpose_call.args[1].struct_info.shape + + if not ( + xT_shape[0] == x_shape[0] and xT_shape[1] == x_shape[2] + and xT_shape[2] == x_shape[1] and xT_shape[3] == x_shape[3] + ): + return False + + if not ( + wT_shape[0] == w_shape[0] and wT_shape[1] == w_shape[2] + and wT_shape[2] == w_shape[3] and wT_shape[3] == w_shape[1] + ): + return False + + return True + + return o, annotations, _check + + def visit_call_(self, call: relax.Call) -> relax.Expr: + out_dtype = None + + def te_transposed_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor: + nonlocal out_dtype + a_shape = list(a.shape) + b_shape = list(b.shape) + + aT_shape = list(a.shape) + aT_shape[-2], aT_shape[-3] = aT_shape[-3], aT_shape[-2] + aT_relax = relax.Var("a", relax.TensorStructInfo(aT_shape)) + bT_shape = list(b.shape) + bT_shape[-1], bT_shape[-2], bT_shape[-3] = bT_shape[-3], bT_shape[-1], bT_shape[-2] + bT_relax = relax.Var("b", relax.TensorStructInfo(bT_shape)) + output_shape = self.builder_.normalize( + relax.op.matmul(aT_relax, bT_relax) + ).struct_info.shape + def matmul_compute(*idx_spatial): + k = te.reduce_axis((0, a_shape[-1]), name="k") + def multiply_compute(idx_reduce): + a_indices = [idx_spatial[0], idx_spatial[2], idx_spatial[1], idx_reduce] + 
b_indices = [idx_spatial[0], idx_spatial[3], idx_spatial[1], idx_reduce] + dtype = out_dtype + if dtype != "": + return a(*a_indices).astype(dtype) * b(*b_indices).astype(dtype) + return a(*a_indices) * b(*b_indices) + + return te.sum(multiply_compute(k), axis=k) + + return te.compute( + output_shape, + lambda *idx: matmul_compute(*idx), # pylint: disable=unnecessary-lambda + name="NT_matmul", + ) + + if isinstance(call.op, relax.GlobalVar): + function = self.builder_.get()[call.op] + if ( + "Composite" in function.attrs + and function.attrs["Composite"] == "transpose1_matmul_fuse" + ): + out_dtype = function.ret_struct_info.dtype + return self.builder_.call_te( + te_transposed_matmul, + call.args[0], + call.args[1], + primfunc_name_hint="NT_matmul", + ) + + return super().visit_call_(call) + + +@tvm.transform.module_pass(opt_level=0, name="FuseTranspose1Matmul") +class FuseTranspose1Matmul: + def transform_module( + self, mod: IRModule, ctx: tvm.transform.PassContext + ) -> IRModule: + mod = relax.transform.FuseOpsByPattern( + [("transpose1_matmul_fuse", *Transpose1MatmulCodeGenerator.pattern())] + )(mod) + + transpose_matmul_codegen = Transpose1MatmulCodeGenerator(mod) + for gv in mod.functions: + func = mod[gv] + if not isinstance(func, relax.Function): + continue + func = transpose_matmul_codegen.visit_expr(func) + transpose_matmul_codegen.builder_.update_func(gv, func) + + return transpose_matmul_codegen.builder_.get() + + +@relax.expr_functor.mutator +class Transpose2MatmulCodeGenerator(relax.PyExprMutator): + def __init__(self, mod): + super().__init__(mod) + + @staticmethod + def pattern(): + w = wildcard() + x = wildcard() + wT = is_op("relax.permute_dims")(w) + o = is_op("relax.permute_dims")(is_op("relax.matmul")(x, wT)) + #oT = is_op("relax.permute_dims")(o) + annotations = {"o": o, "w": w, "x": x, "wT": wT} + + def _check(context: relax.transform.PatternCheckContext) -> bool: + w_transpose_call = context.annotated_expr["wT"] + w_shape = w_transpose_call.args[0].struct_info.shape + wT_shape = w_transpose_call.struct_info.shape + oT_call = context.annotated_expr["o"] + o_shape = oT_call.args[0].struct_info.shape + oT_shape = oT_call.struct_info.shape + + if not ( + wT_shape[0] == w_shape[0] and wT_shape[1] == w_shape[2] + and wT_shape[2] == w_shape[1] and wT_shape[3] == w_shape[3] + ): + return False + + if not ( + oT_shape[0] == o_shape[0] and oT_shape[1] == o_shape[2] + and oT_shape[2] == o_shape[1] and oT_shape[3] == o_shape[3] + ): + return False + + return True + + return o, annotations, _check + + def visit_call_(self, call: relax.Call) -> relax.Expr: + out_dtype = None + + def te_transposed_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor: + nonlocal out_dtype + a_shape = list(a.shape) + b_shape = list(b.shape) + output_shape = [a_shape[0], b_shape[-2], a_shape[2], a_shape[3]] + def matmul_compute(*idx_spatial): + k = te.reduce_axis((0, b_shape[-1]), name="k") + def multiply_compute(idx_reduce): + a_indices = [idx_spatial[0], idx_reduce, idx_spatial[2], idx_spatial[3]] + b_indices = [idx_spatial[0], idx_spatial[2], idx_spatial[1], idx_reduce] + + dtype = out_dtype + if dtype != "": + return a(*a_indices).astype(dtype) * b(*b_indices).astype(dtype) + return a(*a_indices) * b(*b_indices) + + return te.sum(multiply_compute(k), axis=k) + + return te.compute( + output_shape, + lambda *idx: matmul_compute(*idx), # pylint: disable=unnecessary-lambda + name="NT_matmul", + ) + + if isinstance(call.op, relax.GlobalVar): + function = self.builder_.get()[call.op] + if ( + "Composite" 
in function.attrs + and function.attrs["Composite"] == "transpose2_matmul_fuse" + ): + out_dtype = function.ret_struct_info.dtype + #NT_output_shape = function.ret_struct_info.shape + return self.builder_.call_te( + te_transposed_matmul, + call.args[0], + call.args[1], + primfunc_name_hint="NT_matmul", + ) + + return super().visit_call_(call) + + +@tvm.transform.module_pass(opt_level=0, name="FuseTranspose2Matmul") +class FuseTranspose2Matmul: + def transform_module( + self, mod: IRModule, ctx: tvm.transform.PassContext + ) -> IRModule: + mod = relax.transform.FuseOpsByPattern( + [("transpose2_matmul_fuse", *Transpose2MatmulCodeGenerator.pattern())] + )(mod) + + transpose_matmul_codegen = Transpose2MatmulCodeGenerator(mod) + for gv in mod.functions: + func = mod[gv] + if not isinstance(func, relax.Function): + continue + func = transpose_matmul_codegen.visit_expr(func) + transpose_matmul_codegen.builder_.update_func(gv, func) + + return transpose_matmul_codegen.builder_.get() diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py new file mode 100644 index 0000000..3f2c5de --- /dev/null +++ b/mlc_llm/utils.py @@ -0,0 +1,741 @@ +# pylint: disable=missing-docstring,invalid-name +import argparse +import functools +import json +import math +import os +import shutil +from typing import Any, Dict, List, Optional, Set + +import numpy as np +import tvm +from tvm import relax + +from .quantization import quantization_schemes +from .relax_model import param_manager + +supported_model_types = set( + [ + "llama", + "gpt_neox", + "gpt_bigcode", + "minigpt", + "moss", + "rwkv", + "gptj", + "chatglm", + "mistral", + "stablelm_epoch", + "gpt2", + "qwen" + ] +) + + +def wrap_tqdm_counter(func, **tqdm_kwargs): + # tqdm isn't a hard requirement, so return the original function + # if it isn't available. 
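+    # The wrapper below advances a shared progress bar once per call and then delegates to + # the wrapped function; it is applied to the get_item/set_item parameter hooks during + # weight conversion.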
+ try: + from tqdm import tqdm + except ImportError: + return func + + pbar = tqdm(**tqdm_kwargs) + + @functools.wraps(func) + def inner(*args, **kwargs): + pbar.update(1) + return func(*args, **kwargs) + + return inner + + +def argparse_postproc_common(args: argparse.Namespace) -> None: + if hasattr(args, "device_name"): + if args.device_name == "auto": + if tvm.cuda().exist: + args.device_name = "cuda" + elif tvm.metal().exist: + args.device_name = "metal" + elif tvm.vulkan().exist: + args.device_name = "vulkan" + elif tvm.opencl().exist: + args.device_name = "opencl" + else: + raise ValueError("Cannot auto deduce device-name, please set it") + + model_category_override = { + "moss-moon-003-sft": "gptj", + "moss-moon-003-base": "gptj", + "rwkv-": "rwkv", + "rwkv_world": "rwkv_world", + "minigpt": "minigpt", + } + try: + with open(os.path.join(args.model_path, "config.json"), encoding="utf-8") as i_f: + config = json.load(i_f) + args.model_category = config["model_type"] + model_path_lower = args.model_path.lower() + if "rwkv" in model_path_lower and "world" in model_path_lower: + args.model_category = "rwkv_world" + except Exception: + args.model_category = "" + model = args.model.lower() + if "rwkv" in model and "world" in model: + model = "rwkv_world" + for prefix, override_category in model_category_override.items(): + if model.startswith(prefix): + args.model_category = override_category + break + assert args.model_category is not None + + model_conv_templates = { + "llama-2": "llama-2", + "llama-2-unconstrained": "llama-2-unconstrained", + "tinyllama": "chatml", + "codellama-7b-instruct": "codellama_instruct", + "codellama-13b-instruct": "codellama_instruct", + "codellama-34b-instruct": "codellama_instruct", + "codellama": "codellama_completion", + "gpt2": "gpt2", + "vicuna-": "vicuna_v1.1", + "dolly-": "dolly", + "stablelm-3b-": "stablelm-3b", + "stablelm-": "stablelm", + "redpajama-": "redpajama_chat", + "minigpt": "minigpt", + "moss-moon-003-sft": "moss", + "moss-moon-003-base": "LM", + "gpt-j-": "LM", + "open_llama": "LM", + "rwkv-": "rwkv", + "rwkv_world": "rwkv_world", + "gorilla-": "gorilla", + "guanaco": "guanaco", + "wizardlm-7b": "wizardlm_7b", # first get rid of 7b + "wizardlm-": "vicuna_v1.1", # all others use vicuna template + "wizardmath-": "wizard_coder_or_math", + "wizardcoder-": "wizard_coder_or_math", + "starcoder": "gpt_bigcode", + "starcoder-unconstrained": "gpt_bigcode-unconstrained", + "gpt_bigcode-santacoder": "gpt_bigcode", + "stablecode-completion": "stablecode_completion", + "stablecode-instruct": "stablecode_instruct", + "chatglm2": "glm", + "chatglm3": "glm", + "codegeex2": "glm", + "tinyllama": "chatml", + "openhermes-2.5-mistral": "open_hermes_mistral", + "neuralhermes-2.5-mistral": "neural_hermes_mistral", + "qwen": "qwen" + } + + for prefix, conv_template in model_conv_templates.items(): + if model.startswith(prefix): + args.conv_template = conv_template + break + else: + args.conv_template = f"{args.model_category}_default" + + if args.quantization not in quantization_schemes: + raise ValueError(f'Quantization "{args.quantization}" is not supported.') + + args.quantization = quantization_schemes[args.quantization] + + use_ft_quant = args.quantization in ["q4f16_ft", "q8f16_ft", "q4f16_ft_group", "q8f16_ft_group"] + + if use_ft_quant and args.num_shards > 1: + # Preprocess is done after sharding for this case. 
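+        # (Each shard's linear and final-fc weights are preprocessed individually after the + # split, so the eager per-weight preprocessing is disabled here.)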
+ args.quantization.linear_weight.do_preprocess = False + args.quantization.final_fc_weight.do_preprocess = False + + +def debug_dump_script(mod, name, args: argparse.Namespace, show_meta=True): + """Debug dump mode""" + if not args.debug_dump: + return + dump_path = os.path.join(args.artifact_path, "debug", name) + with open(dump_path, "w", encoding="utf-8") as outfile: + outfile.write(mod.script(show_meta=show_meta)) + print(f"Dump mod to {dump_path}") + + +def debug_dump_benchmark_script( + mod: tvm.ir.IRModule, + name: str, + args: argparse.Namespace, +) -> None: + """Extract model level benchmark workloads from relax model.""" + if not args.debug_dump: + return + + from tvm.dlight.benchmark import ( # pylint: disable=import-error,import-outside-toplevel + extract_all_func_info_from_relax, + ) + + dump_path = os.path.join(args.artifact_path, "debug", name + ".py") + with open(dump_path, "w", encoding="utf-8") as outfile: + outfile.write( + "# Please save this file to dlight_bench/models and add\n" + + f"# `from .{name} import *` to dlight_bench/models/__init__.py\n" + + "from dlight_bench import DlightBench\n" + + "from tvm.script import tir as T\n\n" + ) + + stmt = [] + try: + relax_funcs, _ = extract_all_func_info_from_relax(mod) + except NotImplementedError: + return + tvm_script_prefix = "# from tvm.script import tir as T" + for relax_func_gv in relax_funcs: # pylint: disable=consider-using-dict-items + for prim_func_gv in relax_funcs[relax_func_gv]: + # add global_symbol + func_body = ( + mod[prim_func_gv] + .with_attr("global_symbol", prim_func_gv.name_hint) + .script(name=prim_func_gv.name_hint) + ) + # remove prefix + if func_body.startswith(tvm_script_prefix + "\n"): + func_body = func_body[len(tvm_script_prefix) :] + # print out + outfile.write(func_body + "\n") + # register + stmt.append( + f"DlightBench.register_bench_workload({prim_func_gv.name_hint}, " + f"'{name}', '{prim_func_gv.name_hint}')" + ) + outfile.write("\n" + "\n".join(stmt) + "\n") + print(f"Dump benchmarking script to {dump_path}.") + + +def debug_load_script(name: str, args: argparse.Namespace): + input_path = os.path.join(args.artifact_path, "debug", name) + lib = {"__file__": input_path} + with open(input_path, "rb") as i_f: + exec(compile(i_f.read(), input_path, "exec"), lib, lib) # pylint: disable=exec-used + return lib["Module"] + + +def debug_dump_shader(ex: tvm.relax.Executable, name: str, args: argparse.Namespace): + """Debug dump mode""" + if not args.debug_dump: + return + target_kind = args.target.kind.default_keys[0] + suffix_map = { + "webgpu": ".wgsl", + "cuda": ".cu", + "metal": ".mtl", + "opencl": ".cl", + } + suffix = suffix_map.get(target_kind, ".txt") + dump_path = os.path.join(args.artifact_path, "debug", name + suffix) + source = ex.mod.imported_modules[0].imported_modules[0].get_source() + with open(dump_path, "w", encoding="utf-8") as outfile: + outfile.write(source) + print(f"Dump shader to {dump_path}") + + +def convert_weights( + mod_transform: tvm.IRModule, + param_mgr: param_manager.ParamManager, + model_params: List[Optional[tvm.nd.NDArray]], + args: argparse.Namespace, +): + # Save the number of parameters before we lower mod_transform, so + # we can use them in the progress bar. 
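+    # `num_original_params` counts the tensors in the tuple parameter of `transform_params`, + # while `num_transformed_params` counts the tensors it returns; both are only used as + # `total=` hints for the tqdm wrappers below.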
+ transform_func = mod_transform["transform_params"] + num_original_params = len(transform_func.params[0].struct_info.fields) + num_transformed_params = len(transform_func.struct_info.ret.fields) + + # Remove the dataflow block inside the param transform function, + # so that the LazyTransformParams pass can be applied. + mod_transform = relax.transform.ToNonDataflow()(mod_transform) + mod_transform = relax.transform.LazyTransformParams()(mod_transform) + mod_transform = tvm.tir.transform.ForceNarrowIndexToInt32()(mod_transform) + mod_transform = relax.transform.LegalizeOps()(mod_transform) + + debug_dump_script(mod_transform, "mod_convert_weights.py", args) + + target = detect_local_target() + print(f"Automatically using target for weight quantization: {target}") + device = tvm.device(target.kind.default_keys[0]) + + get_item = param_mgr.get_param_get_item( + device, + model_params, + ) + set_item, loaded_params = param_mgr.get_param_set_item() + + get_item = wrap_tqdm_counter( + get_item, desc="Get old param", position=0, unit="tensors", total=num_original_params + ) + set_item = wrap_tqdm_counter( + set_item, desc="Set new param", position=1, unit="tensors", total=num_transformed_params + ) + + tvm.register_func(func_name="get_item", f=get_item, override=True) + tvm.register_func(func_name="set_item", f=set_item, override=True) + + if target.kind.name != "llvm": + with tvm.target.Target(target): + mod_transform = tvm.tir.transform.DefaultGPUSchedule()(mod_transform) + + ex = relax.build(mod_transform, target=target) + vm = relax.vm.VirtualMachine(ex, device) + print("Start computing and quantizing weights... This may take a while.") + vm["transform_params"]() + print("Finish computing and quantizing weights.") + return loaded_params + + +def save_params(params: List[tvm.nd.NDArray], artifact_path: str, num_presharded: int = 1) -> None: + from tvm.contrib import tvmjs # pylint: disable=import-outside-toplevel + + assert len(params) % num_presharded == 0 + num_weights = len(params) // num_presharded + + meta_data = {} + param_dict = {} + meta_data["ParamSize"] = len(params) + for i, nd in enumerate(params): + if num_presharded == 1: + param_name = f"param_{i}" + else: + expected_worker_id = i // num_weights + orig_param_id = i % num_weights + param_name = f"param_{orig_param_id}_shard-{expected_worker_id+1}-of-{num_presharded}" + + param_dict[param_name] = nd + + total_size_bytes = sum( + math.prod(param.shape) * np.dtype(param.dtype).itemsize for param in params + ) + total_size_gb = total_size_bytes / (1024**3) + print(f"Total param size: {total_size_gb} GB") + tvmjs.dump_ndarray_cache( + param_dict, f"{artifact_path}/params", meta_data=meta_data, encode_format="raw" + ) + + +def load_params(artifact_path: str, device) -> List[tvm.nd.NDArray]: + from tvm.contrib import tvmjs # pylint: disable=import-outside-toplevel + + params, meta = tvmjs.load_ndarray_cache(f"{artifact_path}/params", device) + plist = [] + size = meta["ParamSize"] + for i in range(size): + plist.append(params[f"param_{i}"]) + return plist + + +def load_params_SLM( + model_weight_path: str, device, model_metadata: Dict[str, Any] +) -> List[tvm.nd.NDArray]: + from tvm.contrib import tvmjs # pylint: disable=import-outside-toplevel + + params, meta = tvmjs.load_ndarray_cache(model_weight_path, device) + param_names = [param["name"] for param in model_metadata["params"]] + assert len(param_names) == meta["ParamSize"] + + plist = [] + for param_name in param_names: + plist.append(params[param_name]) + return plist + + +def 
copy_tokenizer(args: argparse.Namespace) -> None:
+    for filename in os.listdir(args.model_path):
+        if filename in [
+            "tokenizer.model",
+            "tokenizer.json",
+            "vocab.json",
+            "merges.txt",
+            "added_tokens.json",
+            "tokenizer_config.json",
+        ]:
+            shutil.copy(
+                os.path.join(args.model_path, filename),
+                os.path.join(args.artifact_path, "params"),
+            )
+
+    # If we have `tokenizer.model` but not `tokenizer.json`, try to convert it to
+    # `tokenizer.json` with `transformers`.
+    tokenizer_json_path = os.path.join(args.model_path, "tokenizer.json")
+    tokenizer_model_path = os.path.join(args.model_path, "tokenizer.model")
+    if os.path.exists(tokenizer_model_path) and (not os.path.exists(tokenizer_json_path)):
+        print("Attempting to convert `tokenizer.model` to `tokenizer.json`.")
+        try:
+            # pylint: disable=import-outside-toplevel
+            from transformers import AutoTokenizer
+
+            tokenizer_json_save_dest = os.path.join(args.artifact_path, "params/tokenizer.json")
+            fast_tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=True)
+            fast_tokenizer.backend_tokenizer.save(tokenizer_json_save_dest)
+            print(f"Successfully converted `tokenizer.model` to: {tokenizer_json_save_dest}")
+        except ImportError:
+            print(
+                "WARNING: The model has `tokenizer.model` but not `tokenizer.json`. It is "
+                + "recommended to use `tokenizer.json`, so we tried to convert it with "
+                + "`transformers`.\nHowever, we were unable to import `transformers`, "
+                + "hence skipping this step."
+            )
+        except Exception as error:  # pylint: disable=broad-exception-caught
+            print(
+                "WARNING: The model has `tokenizer.model` but not `tokenizer.json`. It is "
+                + "recommended to use `tokenizer.json`, so we tried to convert it with "
+                + "`transformers`.\nHowever, we are skipping this due to an error:\n",
+                error,
+            )
+
+
+def get_tokenizer_files(path) -> List[str]:
+    tokenizer_set = {
+        "tokenizer.model",
+        "tokenizer.json",
+        "vocab.json",
+        "merges.txt",
+        "added_tokens.json",
+    }
+    return [x for x in os.listdir(path) if x in tokenizer_set]
+
+
+def _detect_local_metal_host():
+    target_triple = tvm._ffi.get_global_func("tvm.codegen.llvm.GetDefaultTargetTriple")()
+    process_triple = tvm._ffi.get_global_func("tvm.codegen.llvm.GetProcessTriple")()
+    host_cpu = tvm._ffi.get_global_func("tvm.codegen.llvm.GetHostCPUName")()
+    print(
+        f"Host CPU detection:\n  Target triple: {target_triple}\n  Process triple: {process_triple}\n  Host CPU: {host_cpu}"
+    )
+    if target_triple.startswith("x86_64-"):
+        return tvm.target.Target(
+            {
+                "kind": "llvm",
+                "mtriple": "x86_64-apple-macos",
+                "mcpu": host_cpu,
+            }
+        )
+    # should start with "arm64-"
+    return tvm.target.Target(
+        {
+            "kind": "llvm",
+            "mtriple": "arm64-apple-macos",
+            "mcpu": host_cpu,
+        }
+    )
+
+
+def _detect_local_metal():
+    dev = tvm.metal()
+    if not dev.exist:
+        return None
+
+    return tvm.target.Target(
+        {
+            "kind": "metal",
+            "max_shared_memory_per_block": 32768,
+            "max_threads_per_block": dev.max_threads_per_block,
+            "thread_warp_size": 32,
+        },
+        host=_detect_local_metal_host(),
+    )
+
+
+def _detect_local_cuda():
+    dev = tvm.cuda()
+    if not dev.exist:
+        return None
+    return tvm.target.Target(
+        {
+            "kind": "cuda",
+            "max_shared_memory_per_block": dev.max_shared_memory_per_block,
+            "max_threads_per_block": dev.max_threads_per_block,
+            "thread_warp_size": dev.warp_size,
+            "registers_per_block": 65536,
+            "arch": "sm_" + dev.compute_version.replace(".", ""),
+        }
+    )
+
+
+def _detect_local_rocm():
+    dev = tvm.rocm()
+    if not dev.exist:
+        return None
+    return tvm.target.Target(
+        {
+            "kind": "rocm",
"max_shared_memory_per_block": dev.max_shared_memory_per_block, + "max_threads_per_block": dev.max_threads_per_block, + "thread_warp_size": dev.warp_size, + } + ) + + +def _detect_local_vulkan(): + dev = tvm.vulkan() + if not dev.exist: + return None + return tvm.target.Target( + { + "kind": "vulkan", + "max_threads_per_block": dev.max_threads_per_block, + "max_shared_memory_per_block": dev.max_shared_memory_per_block, + "thread_warp_size": dev.warp_size, + "supports_float16": 1, + "supports_int16": 1, + "supports_int8": 1, + "supports_16bit_buffer": 1, + } + ) + + +def _detect_local_opencl(): + dev = tvm.opencl() + if not dev.exist: + return None + return tvm.target.Target("opencl") + + +def detect_local_target(): + for method in [ + _detect_local_metal, + _detect_local_rocm, + _detect_local_cuda, + _detect_local_vulkan, + _detect_local_opencl, + ]: + target = method() + if target is not None: + return target + + print("Failed to detect local GPU, falling back to CPU as a target") + return tvm.target.Target("llvm") + + +def parse_target(args: argparse.Namespace) -> None: + if not hasattr(args, "target"): + return + if args.target == "auto": + target = detect_local_target() + if target.host is None: + target = tvm.target.Target( + target, + host="llvm", # TODO: detect host CPU + ) + args.target = target + args.target_kind = args.target.kind.default_keys[0] + elif args.target == "cuda" or args.target == "cuda-multiarch": + target = _detect_local_cuda() + if target is None: + raise ValueError("Cannot detect local CUDA GPU target!") + multiarch = args.target == "cuda-multiarch" + args.target = target + args.target_kind = args.target.kind.default_keys[0] + if multiarch: + args.target_kind += "-multiarch" + elif args.target.startswith("nvidia/jetson"): + try: + args.target = tvm.target.Target(args.target) + except ValueError: + raise ValueError("Cannot find configuration of given nvidia/jetson board target!") + if not hasattr(args, "cc_path") or args.cc_path == "": + args.cc_path = "/usr/bin/aarch64-linux-gnu-g++" + from tvm.contrib.cc import ( # pylint: disable=import-outside-toplevel + cross_compiler, + ) + + args.export_kwargs = { + "fcompile": cross_compiler( + args.cc_path, + ), + } + args.target_kind = args.target.kind.default_keys[0] + elif args.target == "metal": + target = _detect_local_metal() + if target is None: + print("Cannot detect local Apple Metal GPU target! 
Falling back...") + target = tvm.target.Target( + tvm.target.Target( + { + "kind": "metal", + "max_threads_per_block": 256, + "max_shared_memory_per_block": 32768, + "thread_warp_size": 1, + } + ), + host=_detect_local_metal_host(), + ) + args.target = target + args.target_kind = args.target.kind.default_keys[0] + elif args.target == "metal_x86_64": + from tvm.contrib import xcode # pylint: disable=import-outside-toplevel + + args.target = tvm.target.Target( + tvm.target.Target( + { + "kind": "metal", + "max_threads_per_block": 256, + "max_shared_memory_per_block": 32768, + "thread_warp_size": 1, + } + ), + host="llvm -mtriple=x86_64-apple-darwin", + ) + args.target_kind = "metal_x86_64" + args.export_kwargs = { + "fcompile": xcode.create_dylib, + "sdk": "macosx", + "arch": "x86_64", + } + args.lib_format = "dylib" + elif args.target in ["iphone", "iphone-dylib", "iphone-tar"]: + from tvm.contrib import tar, xcode # pylint: disable=import-outside-toplevel + + if args.target == "iphone-dylib": + args.export_kwargs = { + "fcompile": xcode.create_dylib, + "sdk": "iphoneos", + "arch": "arm64", + } + args.lib_format = "dylib" + else: + args.export_kwargs = {"fcompile": tar.tar} + args.lib_format = "tar" + args.system_lib = True + args.system_lib_prefix = f"{args.model}_{args.quantization}_".replace("-", "_") + + @tvm.register_func("tvm_callback_metal_compile") + def compile_metal(src, target): + if target.libs: + return xcode.compile_metal(src, sdk=target.libs[0]) + return xcode.compile_metal(src) + + target = tvm.target.Target( + tvm.target.Target( + { + "kind": "metal", + "max_threads_per_block": 256, + "max_shared_memory_per_block": 32768, + "thread_warp_size": 1, + "libs": ["iphoneos"], + } + ), + host="llvm -mtriple=arm64-apple-darwin", + ) + args.target = target + args.target_kind = "iphone" + elif args.target == "vulkan": + target = tvm.target.Target( + tvm.target.Target( + { + "kind": "vulkan", + "max_threads_per_block": 256, + "max_shared_memory_per_block": 32768, + "thread_warp_size": 1, + "supports_float16": 1, + "supports_int16": 1, + "supports_int8": 1, + "supports_8bit_buffer": 1, + "supports_16bit_buffer": 1, + "supports_storage_buffer_storage_class": 1, + } + ), + host="llvm", + ) + args.target = target + args.target_kind = args.target.kind.default_keys[0] + elif args.target == "opencl": + target = tvm.target.Target( + "opencl", + host="llvm", + ) + args.target = target + args.target_kind = args.target.kind.default_keys[0] + elif args.target == "webgpu": + args.target = tvm.target.Target( + "webgpu", + host="llvm -mtriple=wasm32-unknown-unknown-wasm", + ) + args.target_kind = "webgpu" + args.lib_format = "wasm" + args.system_lib = True + if os.environ.get("TVM_HOME", "") == "": + raise RuntimeError( + "Please set TVM_HOME for webgpu build following scripts/prep_emcc_deps.sh" + ) + elif args.target in ["android", "android-dylib"]: # android-opencl + from tvm.contrib import ndk, tar + + if args.target == "android-dylib": + args.export_kwargs = { + "fcompile": ndk.create_shared, + } + args.lib_format = "so" + else: + args.export_kwargs = { + "fcompile": tar.tar, + } + args.lib_format = "tar" + args.system_lib = True + args.system_lib_prefix = f"{args.model}_{args.quantization}_".replace("-", "_") + args.target = tvm.target.Target( + "opencl", + host="llvm -mtriple=aarch64-linux-android", # TODO: Only support arm64 for now + ) + args.target_kind = "android" + elif args.target in ["mali"]: + if "TVM_NDK_CC" in os.environ: + from tvm.contrib import ndk + + args.export_kwargs = { + 
"fcompile": ndk.create_shared, + } + target = tvm.target.Target( + "opencl -device=mali", + host="llvm -mtriple=aarch64-linux-gnu", + ) + args.target = target + args.target_kind = "mali" + else: + args.target = tvm.target.Target(args.target, host="llvm") + args.target_kind = args.target.kind.default_keys[0] + + if args.target_kind == "cuda-multiarch": + from tvm.contrib import nvcc + + assert args.target.arch[3:] != "" + arch_list = os.getenv("CUDA_ARCH_LIST") or os.getenv("TORCH_CUDA_ARCH_LIST") + if arch_list: + compute_versions = [int(v) for v in arch_list.replace(" ", ";").split(";")] + elif int(args.target.arch[3:]) >= 70: + compute_versions = [70, 72, 75, 80, 86, 87, 89, 90] + else: + compute_versions = [60, 61, 62] + + args.target_kind = "cuda" + + @tvm.register_func("tvm_callback_cuda_compile", override=True) + def tvm_callback_cuda_compile(code, target): # pylint: disable=unused-argument + """use nvcc to generate fatbin code for better optimization""" + arch = [] + for compute_version in compute_versions: + arch += ["-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}"] + ptx = nvcc.compile_cuda(code, target_format="fatbin", arch=arch) + return ptx + + # use mingw to cross compile windows + if hasattr(args, "llvm_mingw") and args.llvm_mingw != "": + from tvm.contrib.cc import ( # pylint: disable=import-outside-toplevel + cross_compiler, + ) + + args.export_kwargs = { + "fcompile": cross_compiler( + os.path.join(args.llvm_mingw, "bin", "x86_64-w64-mingw32-clang++"), + output_format="dll", + ), + } + args.target = args.target.with_host("llvm -mtriple=x86_64-w64-windows-gnu") + args.lib_format = "dll" + + print(f"Target configured: {args.target}") diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..1ffd135 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +[tool.isort] +profile = "black" +src_paths = ["python/mlc_chat"] +known_third_party = ["numpy", "tvm", "tqdm", "torch", "transformers"] + +[tool.black] +line-length = 100 + +[tool.mypy] +ignore_missing_imports = true +show_column_numbers = true +show_error_context = true +follow_imports = "skip" +ignore_errors = false +strict_optional = false + +[tool.pylint.messages_control] +max-line-length = 100 +disable = """ +duplicate-code, +""" diff --git a/python/README.md b/python/README.md new file mode 100644 index 0000000..a1866ee --- /dev/null +++ b/python/README.md @@ -0,0 +1,5 @@ +# MLC-Chat Python Package + +This folder contains the source code of MLC-Chat python package, +please refer to the [REST API](https://llm.mlc.ai/docs/deploy/rest.html) +and [Python API](https://llm.mlc.ai/docs/deploy/python.html) documentation for usage. 
diff --git a/python/mlc_chat/__init__.py b/python/mlc_chat/__init__.py new file mode 100644 index 0000000..f577e03 --- /dev/null +++ b/python/mlc_chat/__init__.py @@ -0,0 +1,7 @@ +"""MLC Chat python package. + +MLC Chat is the app runtime of MLC LLM. +""" +from . import protocol, serve +from .chat_module import ChatConfig, ChatModule, ConvConfig, GenerationConfig +from .libinfo import __version__ diff --git a/python/mlc_chat/__main__.py b/python/mlc_chat/__main__.py new file mode 100644 index 0000000..8cb80a6 --- /dev/null +++ b/python/mlc_chat/__main__.py @@ -0,0 +1,47 @@ +"""Entrypoint of all CLI commands from MLC LLM""" +import sys + +from mlc_chat.support import logging +from mlc_chat.support.argparse import ArgumentParser + +logging.enable_logging() + + +def main(): + """Entrypoint of all CLI commands from MLC LLM""" + parser = ArgumentParser("MLC LLM Command Line Interface.") + parser.add_argument( + "subcommand", + type=str, + choices=["compile", "convert_weight", "gen_config", "chat", "bench"], + help="Subcommand to to run. (choices: %(choices)s)", + ) + parsed = parser.parse_args(sys.argv[1:2]) + # pylint: disable=import-outside-toplevel + if parsed.subcommand == "compile": + from mlc_chat.cli import compile as cli + + cli.main(sys.argv[2:]) + elif parsed.subcommand == "convert_weight": + from mlc_chat.cli import convert_weight as cli + + cli.main(sys.argv[2:]) + elif parsed.subcommand == "gen_config": + from mlc_chat.cli import gen_config as cli + + cli.main(sys.argv[2:]) + elif parsed.subcommand == "chat": + from mlc_chat.cli import chat as cli + + cli.main(sys.argv[2:]) + elif parsed.subcommand == "bench": + from mlc_chat.cli import bench as cli + + cli.main(sys.argv[2:]) + else: + raise ValueError(f"Unknown subcommand {parsed.subcommand}") + # pylint: enable=import-outside-toplevel + + +if __name__ == "__main__": + main() diff --git a/python/mlc_chat/_ffi_api.py b/python/mlc_chat/_ffi_api.py new file mode 100644 index 0000000..b0074ad --- /dev/null +++ b/python/mlc_chat/_ffi_api.py @@ -0,0 +1,6 @@ +"""FFI APIs for mlc_chat""" +import tvm._ffi + +# Exports functions registered via TVM_REGISTER_GLOBAL with the "mlc" prefix. +# e.g. TVM_REGISTER_GLOBAL("mlc.Tokenizer") +tvm._ffi._init_api("mlc", __name__) # pylint: disable=protected-access diff --git a/python/mlc_chat/base.py b/python/mlc_chat/base.py new file mode 100644 index 0000000..13c7ba9 --- /dev/null +++ b/python/mlc_chat/base.py @@ -0,0 +1,28 @@ +"""Load MLC LLM library and _ffi_api functions.""" +import ctypes +import os +import sys + +import tvm +import tvm._ffi.base + +from . 
import libinfo + +SKIP_LOADING_MLCLLM_SO = os.environ.get("SKIP_LOADING_MLCLLM_SO", "0") + + +def _load_mlc_llm_lib(): + """Load MLC LLM lib""" + if sys.platform.startswith("win32") and sys.version_info >= (3, 8): + for path in libinfo.get_dll_directories(): + os.add_dll_directory(path) + # pylint: disable=protected-access + lib_name = "mlc_llm" if tvm._ffi.base._RUNTIME_ONLY else "mlc_llm_module" + # pylint: enable=protected-access + lib_path = libinfo.find_lib_path(lib_name, optional=False) + return ctypes.CDLL(lib_path[0]), lib_path[0] + + +# only load once here +if SKIP_LOADING_MLCLLM_SO == "0": + _LIB, _LIB_PATH = _load_mlc_llm_lib() diff --git a/python/mlc_chat/callback.py b/python/mlc_chat/callback.py new file mode 100644 index 0000000..bf63c31 --- /dev/null +++ b/python/mlc_chat/callback.py @@ -0,0 +1,141 @@ +"""Namespace of callback functions in Python API.""" +# pylint: disable=unused-import, invalid-name, unnecessary-pass +from queue import Queue +from typing import Optional + + +def _get_delta_message(curr_message: str, new_message: str) -> str: + r"""Given the current message and the new message, compute the delta message + (the newly generated part, the diff of the new message from the current message). + + Parameters + ---------- + curr_message : str + The message generated in the previous round. + new_message : str + The message generated in the new round. + + Returns + ------- + delta_message : str + The diff of the new message from the current message (the newly generated part). + """ + from tvm._ffi import get_global_func # pylint: disable=import-outside-toplevel + + f_get_delta_message = get_global_func("mlc.get_delta_message") + return f_get_delta_message(curr_message, new_message) + + +class DeltaCallback: + """Base class that fetches delta callback""" + + def __init__(self): + r"""Initialize the callback class.""" + self.curr_message = "" + + def __call__(self, message: str = "", stopped: bool = False): + r"""Process newly generated message using callback functions. + + Parameters + ---------- + message : str + The newly generated message. + stopped : bool + Whether generation reaches an end. If True, clear the state of current message. + """ + if stopped: + self.stopped_callback() + self.curr_message = "" + else: + delta = _get_delta_message(self.curr_message, message) + self.curr_message = message + self.delta_callback(delta) + + def delta_callback(self, delta_message: str): + r"""Perform a callback action on the delta message. + This vary depending on the callback method. + + Parameters + ---------- + delta_message : str + The delta message. + """ + raise NotImplementedError + + def stopped_callback(self): + r"""Perform a callback action when we receive a "stop generating" signal. + Can optionally ignore this function if no action need to be done when + generation stops.""" + pass + + +class StreamToStdout(DeltaCallback): + """Stream the output of the chat module to stdout.""" + + def __init__(self, callback_interval: int = 2): + r"""Initialize the callback class with callback interval. + + Parameters + ---------- + callback_interval : int + The refresh rate of the streaming process. + """ + super().__init__() + self.callback_interval = callback_interval + + def delta_callback(self, delta_message: str): + r"""Stream the delta message directly to stdout. + + Parameters + ---------- + delta_message : str + The delta message (the part that has not been streamed to stdout yet). 
+ """ + print(delta_message, end="", flush=True) + + def stopped_callback(self): + r"""Stream an additional '\n' when generation ends.""" + print() + + +class StreamIterator(DeltaCallback): + """Stream the output using an iterator. + A queue stores the delta messages""" + + def __init__(self, callback_interval: int = 2, timeout: Optional[float] = None): + r"""Initialize the callback class with callback interval and queue timeout. + + Parameters + ---------- + callback_interval : int + The refresh rate of the streaming process. + timeout : Optional[float] + Timeout for put and get from the delta messages queue + """ + super().__init__() + self.delta_messages: Queue = Queue() + self.callback_interval = callback_interval + self.timeout = timeout + + def delta_callback(self, delta_message: str): + r"""Stream the delta message to iterator (adding). + + Parameters + ---------- + delta_message : str + The delta message (the part that has not been added to queue yet). + """ + self.delta_messages.put(delta_message, timeout=self.timeout) + + def stopped_callback(self): + """Using None as the stop signal for the iterator""" + self.delta_messages.put(None, timeout=self.timeout) + + def __iter__(self): + return self + + def __next__(self): + value = self.delta_messages.get(timeout=self.timeout) + if value: + return value + raise StopIteration() diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py new file mode 100644 index 0000000..79de775 --- /dev/null +++ b/python/mlc_chat/chat_module.py @@ -0,0 +1,1234 @@ +"""The Python API for MLC chat.""" + +#! pylint: disable=too-many-lines +import dataclasses +import inspect +import json +import os +import subprocess +import sys +import time +import warnings +from dataclasses import asdict, dataclass, fields +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import tvm +from tvm.runtime import disco # pylint: disable=unused-import + +from mlc_chat.support import logging +from mlc_chat.support.auto_device import detect_device +from mlc_chat.support.config import ConfigBase + +from . import base as _ + +if TYPE_CHECKING: + from mlc_chat.interface.openai_api import ChatMessage + +# pylint: disable=line-too-long +_PYTHON_GET_STARTED_TUTORIAL_URL = "https://github.com/mlc-ai/notebooks/blob/main/mlc-llm/tutorial_chat_module_getting_started.ipynb" +# pylint: enable=line-too-long + + +logger = logging.getLogger(__name__) + + +@dataclass +class ConvConfig: # pylint: disable=too-many-instance-attributes + r"""A dataclass that represents user-defined partial configuration for conversation template. + + This is an attribute of :class:`mlc_chat.ChatConfig`, which can then be passed in to the + instantiation of a :class:`mlc_chat.ChatModule` instance to override the default + setting in ``mlc-chat-config.json`` under the model folder. Note that we will + first load the predefined template with the name specified in ``conv_template``. + + Since the configuration is partial, everything will be ``Optional``. + + Parameters + ---------- + name : Optional[str] + Name of the conversation. + system : Optional[str] + The prompt encoded before starting the chat. + roles : Optional[List[str]] + An array that describes the role names of the user and the model. These + names are specific to the model being used. + messages : Optional[List[List[str]]] + The chat history represented as an array of string pairs in the following + format: ``[[role_0, msg_0], [role_1, msg_1], ...]``. 
+ offset : Optional[int] + The offset used to begin the chat from the chat history. When offset + is not ``0``, ``messages[0:offset-1]`` will be encoded. + separator_style : Optional[int] + Specifies whether we are in chat-bot mode (``0``) or pure LM prompt mode (``1``). + seps : Optional[List[str]] + An array of strings indicating the separators to be used after a user + message and a model message respectively. + role_msg_sep : Optional[str] + A string indicating the separator between a role and a message. + role_empty_sep : Optional[str] + A string indicating the separator to append to a role when there is no message yet. + stop_str : Optional[str] + When the ``stop_str`` is encountered, the model will stop generating output. + stop_tokens : Optional[List[int]] + A list of token IDs that act as stop tokens. + prefix_tokens : Optional[List[int]] + Token list prefixing the conversation. + add_bos : Optional[bool] + Determines whether a beginning-of-string (bos) token should be added + before the input tokens. + """ + + name: Optional[str] = None + system: Optional[str] = None + roles: Optional[List[str]] = None + messages: Optional[List[List[str]]] = None + offset: Optional[int] = None + separator_style: Optional[int] = None + seps: Optional[List[str]] = None + role_msg_sep: Optional[str] = None + role_empty_sep: Optional[str] = None + stop_str: Optional[str] = None + stop_tokens: Optional[List[int]] = None + prefix_tokens: Optional[List[int]] = None + add_bos: Optional[bool] = None + + def __post_init__(self): + if self.messages is not None and self.offset is None: + self.offset = len(self.messages) + + +@dataclass +class ChatConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + r"""A dataclass that represents user-defined partial configuration for the + chat config file. + + An instance of ``ChatConfig`` can be passed in to the instantiation of a + :class:`mlc_chat.ChatModule` instance to override the default setting in + ``mlc-chat-config.json`` under the model folder. + + Since the configuration is partial, everything will be ``Optional``. + + Note that we will exploit this class to also represent ``mlc-chat-config.json`` + during intermediate processing. + + Parameters + ---------- + model_lib : Optional[str] + The necessary model library to launch this model architecture. We recommend + reuse model library when possible. For example, all LLaMA-7B models can + use ``vicuna-v1-7b-{matching quantization scheme}``. So you can distribute + LLaMA-7B weight variants and still use them in prebuilt MLC chat apps. + local_id : Optional[str] + Uniquely identifying the model in application. This is also used by + command line interface app to specify which model to run. + conv_template : Optional[str] + The name of the conversation template that this chat uses. + temperature : Optional[float] + The temperature applied to logits before sampling. The default value is + ``0.7``. A higher temperature encourages more diverse outputs, while a + lower temperature produces more deterministic outputs. + repetition_penalty : Optional[float] + The repetition penalty controls the likelihood of the model generating + repeated texts. The default value is set to ``1.0``, indicating that no + repetition penalty is applied. Increasing the value reduces the + likelihood of repeat text generation. However, setting a high + ``repetition_penalty`` may result in the model generating meaningless + texts. The ideal choice of repetition penalty may vary among models. 
+ + For more details on how repetition penalty controls text generation, please + check out the CTRL paper (https://arxiv.org/pdf/1909.05858.pdf). + top_p : Optional[float] + This parameter determines the set of tokens from which we sample during + decoding. The default value is set to ``0.95``. At each step, we select + tokens from the minimal set that has a cumulative probability exceeding + the ``top_p`` parameter. + + For additional information on top-p sampling, please refer to this blog + post: https://huggingface.co/blog/how-to-generate#top-p-nucleus-sampling. + mean_gen_len : Optional[int] + The approximated average number of generated tokens in each round. Used + to determine whether the maximum window size would be exceeded. + max_gen_len : Optional[int] + The maximum number of tokens to be generated in each round. Would simply + stop generating after this number is exceeded. + shift_fill_factor : Optional[float] + The fraction of maximum window size to shift when it is exceeded. + tokenizer_files : Optional[List[str]] + List of tokenizer files of the model. + conv_config : Optional[ConvConfig] + The partial overriding configuration for conversation template. Will first + load the predefined template with the name specified in ``conv_template`` + and then override some of the configurations specified in ``conv_config``. + model_category : Optional[str] + The category of the model's architecture (e.g. ``llama``, ``gpt_neox``, ``rwkv``). + model_name : Optional[str] + Name of the model (e.g. ``Llama-2-7b-chat-hf``). + tensor_parallel_shards : Optional[str] + Tensor parallel degree. + use_presharded_weights : Optional[bool] + If True, the weights were saved with sharding already applied. + context_window_size : Optional[int] + Maximum kv cache window size. + prefill_chunk_size: Optional[int] + (Experimental) The chunk size during prefilling. By default, + the chunk size is the same as sliding window or max sequence length. + This flag subjects to future refactoring. + attention_sink_size : Optional[int] + (Experimental) The number of stored sinks. Only supported on Mistral yet. By default, + the number of sinks is 4. This flag subjects to future refactoring. + sliding_window_size : Optional[int] + (Experimental) The sliding window size in sliding window attention (SWA). + This optional field overrides the `sliding_window_size` in config.json for + those models that use SWA. Currently only useful when compiling Mistral. + This flag subjects to future refactoring. + opt : Optional[str] + Optimization flags. MLC LLM maintains a predefined set of optimization flags, + denoted as O0, O1, O2, O3, where O0 means no optimization, O2 means majority of them, + and O3 represents extreme optimization that could potentially break the system. + Meanwhile, optimization flags could be explicitly specified via details knobs, e.g. + --opt="cublas_gemm=1;cudagraph=0". 
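+
+    For example, a partial override such as ``ChatConfig(temperature=0.5, top_p=0.9)``
+    can be passed when constructing a :class:`mlc_chat.ChatModule`; any field left as
+    ``None`` keeps the value from ``mlc-chat-config.json``.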
+ """ + + model_lib: Optional[str] = None + local_id: Optional[str] = None + conv_template: Optional[str] = None + temperature: Optional[float] = None + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + repetition_penalty: Optional[float] = None + top_p: Optional[float] = None + mean_gen_len: Optional[int] = None + max_gen_len: Optional[int] = None + shift_fill_factor: Optional[float] = None + tokenizer_files: Optional[List[str]] = None + conv_config: Optional[ConvConfig] = None + model_category: Optional[str] = None + model_name: Optional[str] = None + tensor_parallel_shards: Optional[int] = None + use_presharded_weights: Optional[bool] = None + context_window_size: Optional[int] = None + sliding_window_size: Optional[int] = None + prefill_chunk_size: Optional[int] = None + attention_sink_size: Optional[int] = None + max_batch_size: Optional[int] = None + opt: Optional[str] = None + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + @classmethod + def _from_json(cls, json_obj: dict): + return cls(**{k: v for k, v in json_obj.items() if k in inspect.signature(cls).parameters}) + + +@dataclass +class GenerationConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + r"""A dataclass that represents user-defined generation configuration. + + An instance of ``GenerationConfig`` can be passed in to the generate function + of a :class:`mlc_chat.ChatModule` instance to override the default generation + setting in ``mlc-chat-config.json`` and ``ChatConfig`` under the model folder. + + Once the generation ends, ``GenerationConfig`` is discarded, since the values + will only override the ``ChatConfig`` generation settings during one generation, + unless it is recurrently passed to generate function. This allows changing generation + settings over time, without overriding ``ChatConfig`` permanently. + + Since the configuraiton is partial, everything will be ``Optional``. + + Parameters + ---------- + temperature : Optional[float] + The temperature applied to logits before sampling. The default value is + ``0.7``. A higher temperature encourages more diverse outputs, while a + lower temperature produces more deterministic outputs. + presence_penalty : Optional[float] + Number between -2.0 and 2.0. Positive values penalize new tokens based on + whether they appear in the text so far, increasing the model's likelihood + to talk about new topics. Negative values can increase the likelihood of + repetition. + frequency_penalty : Optional[float] + Number between -2.0 and 2.0. Positive values penalize new tokens based on their + existing frequency in the text so far, decreasing the model's likelihood to + repeat the same line verbatim. Negative values can increase the likelihood of + repetition. + repetition_penalty : Optional[float] + The repetition penalty controls the likelihood of the model generating + repeated texts. The default value is set to ``1.0``, indicating that no + repetition penalty is applied. Increasing the value reduces the + likelihood of repeat text generation. However, setting a high + ``repetition_penalty`` may result in the model generating meaningless + texts. The ideal choice of repetition penalty may vary among models. Only + Active when presence_penalty and frequency_penalty are both 0.0. + + For more details on how repetition penalty controls text generation, please + check out the CTRL paper (https://arxiv.org/pdf/1909.05858.pdf). 
+ top_p : Optional[float] + This parameter determines the set of tokens from which we sample during + decoding. The default value is set to ``0.95``. At each step, we select + tokens from the minimal set that has a cumulative probability exceeding + the ``top_p`` parameter. + + For additional information on top-p sampling, please refer to this blog + post: https://huggingface.co/blog/how-to-generate#top-p-nucleus-sampling. + mean_gen_len : Optional[int] + The approximated average number of generated tokens in each round. Used + to determine whether the maximum window size would be exceeded. + max_gen_len : Optional[int] + This parameter determines the maximum length of the generated text. If it is + not set, the model will generate text until it encounters a stop token. + n : Optional[int] + This parameter determines the number of text samples to generate. The default + value is ``1``. Note that this parameter is only used when ``stream`` is set to + ``False``. + stop : Optional[Union[str, List[str]]] + When ``stop`` is encountered, the model will stop generating output. + It can be a string or a list of strings. If it is a list of strings, the model + will stop generating output when any of the strings in the list is encountered. + Note that this parameter does not override the default stop string of the model. + """ + + temperature: Optional[float] = None + repetition_penalty: Optional[float] = None + top_p: Optional[float] = None + mean_gen_len: Optional[int] = None + max_gen_len: Optional[int] = None + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + n: Optional[int] = None # pylint: disable=invalid-name + stop: Optional[Union[str, List[str]]] = None + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + @classmethod + def _from_chat_config(cls, chat_config_obj: ChatConfig): + return cls( + **{ + f.name: getattr(chat_config_obj, f.name) + for f in fields(chat_config_obj) + if f.name in inspect.signature(cls).parameters + } + ) + + +class PlaceInPrompt(Enum): + """The place of an input message in a prompt.""" + + # The input message should have role names and corresponding seperators appended both prior to + # it and after it, making it a complete prompt. + All = 0 # pylint: disable=invalid-name + # The input message is only the beginning part of a prompt, no role name and separator should + # be appended after the message since there will be future messages appended after the message. + Begin = 1 # pylint: disable=invalid-name + # The input message is in the middle of a prompt, nothing should be appended before or after + # the message. + Middle = 2 # pylint: disable=invalid-name + # The input message is the ending part of a prompt, no role name and separator should be + # appended prior to it since the message is concatenated to some prior messages. + End = 3 # pylint: disable=invalid-name + + +def _get_model_path(model: str) -> Tuple[str, str]: + """Use user-provided argument ``model`` to search for a valid model path. + + We define "valid" as having an ``mlc-chat-config.json`` right under the folder. + + Parameters + ---------- + model : str + User's input; may be a compiled model's name, or a full path. + + Returns + ------ + model_path : str + A "valid" path to model folder, with ``os.isfile(os.path.join(model_path, + "mlc-chat-config.json"))`` being ``True``. + chat_file : str + Essentially ``os.path.join(model_path, "mlc-chat-config.json")``. + + Raises + ------ + FileNotFoundError: if we cannot find a valid `model_path`. 
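+
+    For example, ``model`` may be a model name such as ``"Llama-2-7b-chat-hf-q4f16_1"``
+    (searched for under ``dist/`` and ``dist/prebuilt/``), an ``HF://`` URL, or a path to
+    a folder that directly contains ``mlc-chat-config.json``.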
+ """ + if model.startswith("HF://"): + from mlc_chat.support.download import ( # pylint: disable=import-outside-toplevel + download_mlc_weights, + ) + + logger.info("Downloading model from HuggingFace: %s", model) + mlc_dir = download_mlc_weights(model) + cfg_dir = mlc_dir / "mlc-chat-config.json" + return str(mlc_dir), str(cfg_dir) + + # Note that the order of this list corresponds to our search priority + candidate_paths = [ + f"{model}", # full path, or just the name + f"dist/prebuilt/{model}", # Using prebuilt workflow + f"dist/{model}/params", # Default directory after mlc_llm.build_model() + f"dist/prebuilt/mlc-chat-{model}", # Also prebuilt workflow, but missed prefix + ] + + # Look for the first folder that has `mlc-chat-config.json` under it + for candidate in candidate_paths: + chat_file = os.path.join(candidate, "mlc-chat-config.json") + if os.path.isfile(chat_file): + logger.info("Using model folder: %s", os.path.abspath(candidate)) + logger.info("Using mlc chat config: %s", os.path.abspath(chat_file)) + return candidate, chat_file + + # Failed to find a valid model_path, analyzing error for user + + # First see if any candidate path is an actual folder + found_folder = False + valid_dir_str = "" + for candidate in candidate_paths: + if os.path.isdir(candidate): + valid_dir_str += f"- {os.path.abspath(candidate)}\n" + found_folder = True + + if found_folder: + # Error 1: there is a folder, but not an mlc-llm model folder (E1) + raise FileNotFoundError( + "The model folder provided does not seem to refer to a valid mlc-llm model folder.\n" + "Specifically, we cannot find `mlc-chat-config.json`, a required file. You should " + "provide a path that contains the file.\n" + "According to your input `model`, we looked at folder(s):\n" + f"{valid_dir_str}" + "MLC-Chat consumes models that are processed by the MLC-LLM build process.\n" + f"Please checkout {_PYTHON_GET_STARTED_TUTORIAL_URL} for an example on " + "how to load a model." + ) + # Error 2: cannot find a folder (E0) + all_paths_str = "".join(f"- {path}\n" for path in candidate_paths) + raise FileNotFoundError( + "Cannot find the model folder. We searched over the following possible paths:\n" + f"{all_paths_str}" + "You can try to pass in `model=/path/to/your-model-path`, and confirm " + "that it contains `mlc-chat-config.json`, among other essential files.\n" + f"Please checkout {_PYTHON_GET_STARTED_TUTORIAL_URL} for an " + "example on how to load a model." + ) + + +def _get_chat_config(config_file_path: str, user_chat_config: Optional[ChatConfig]) -> ChatConfig: + """Read in the config file in model path, then potentially override with user input. + + Parameters + ---------- + config_file_path : str + ``chat_file`` returned by ``_get_model_path()``. + user_chat_config : Optional[ChatConfig] + User's input, a partial ``ChatConfig`` to override the one in ``config_file_path``. + + Returns + ------ + final_chat_config : ChatConfig + ``ChatConfig`` corresponding to ``config_file_path``, overriden by ``user_chat_config``. 
+ """ + final_chat_config = None + with open(config_file_path, mode="rt", encoding="utf-8") as file: + json_object = json.load(file) + final_chat_config = ChatConfig._from_json(json_object) # pylint: disable=protected-access + if user_chat_config is not None: + # We override using user's chat config + for field in fields(user_chat_config): + field_name = field.name + field_value = getattr(user_chat_config, field_name) + if field_value is not None: + if field_name == "model_lib": + warn_msg = ( + 'WARNING: Do not override "model_lib" in ChatConfig. ' + "This override will be ignored. Please use ChatModule.model_lib_path to " + "override the full model library path instead." + ) + warnings.warn(warn_msg) + else: + setattr(final_chat_config, field_name, field_value) + return final_chat_config + + +def _get_generation_config( + user_chat_config: ChatConfig, user_generation_config: Optional[GenerationConfig] +) -> GenerationConfig: + """Read in the config file in model path, then potentially override with user input. + + Parameters + ---------- + user_chat_config : ChatConfig + ``ChatConfig`` that contain the generation settings to be overriden. + user_generation_config : Optional[GenerationConfig] + User's input, a partial ``GenerationConfig`` to override the ``ChatConfig``. + + Returns + ------ + final_generation_config : GenerationConfig + ``GenerationConfig`` corresponding to ``user_chat_config``, overriden by + ``user_generation_config``. + """ + # pylint: disable=protected-access + final_generation_config = GenerationConfig._from_chat_config(user_chat_config) + # pylint: enable=protected-access + if user_generation_config is not None: + # We override using user's chat config + for field in fields(user_generation_config): + field_name = field.name + field_value = getattr(user_generation_config, field_name) + if field_value is not None: + setattr(final_generation_config, field_name, field_value) + return final_generation_config + + +def _get_lib_module_path( # pylint: disable=too-many-arguments + model: str, + model_path: str, + chat_config: ChatConfig, + model_lib_path: Optional[str], + device_name: str, + config_file_path: str, +) -> str: + """Look up the model library. Then return a corresponding ``tvm`` runtime Module. + + Parameters + ---------- + model : str + User's input; may be a compiled model's name, or a full path. + model_path : str + Model path found by `_get_model_path`. + chat_config : ChatConfig + Chat config after potential overrides. Returned by ``_get_chat_config``. + model_lib_path : Optional[str] + User's input. Supposedly a full path to model library. Prioritized to use. + device_name : str + User's input. Used to construct the library model file name. + config_file_path : str + The path to ``mlc-chat-config.json``. Used for error message making. + + Returns + ------- + model_lib_path : str + The path pointing to the model library we find. + + Raises + ------ + FileNotFoundError: if we cannot find a valid model library file. + """ + # 1. Use user's model_lib_path if provided + if model_lib_path is not None: + if os.path.isfile(model_lib_path): + logger.info("Using library model: %s", model_lib_path) + return model_lib_path + raise FileNotFoundError( + f"The `model_lib_path` you passed in is not a file: {model_lib_path}.\n" + f"Please refer to {_PYTHON_GET_STARTED_TUTORIAL_URL} as tutorial on model loading." + ) + + # 2. 
Generate all possible file names according to OS
+    candidate_lib_names = []
+    if sys.platform.startswith("linux"):
+        candidate_lib_names = [f"{chat_config.model_lib}-{device_name}.so"]
+    elif sys.platform.startswith("darwin"):
+        # Note that `dylib` comes before `so` since we prioritize `dylib` for MacOS
+        candidate_lib_names = [
+            f"{chat_config.model_lib}-{device_name}.dylib",
+            f"{chat_config.model_lib}-{device_name}.so",
+        ]
+    elif sys.platform.startswith("win32"):
+        candidate_lib_names = [f"{chat_config.model_lib}-{device_name}.dll"]
+    else:
+        candidate_lib_names = [
+            f"{chat_config.model_lib}-{device_name}.dylib",
+            f"{chat_config.model_lib}-{device_name}.so",
+            f"{chat_config.model_lib}-{device_name}.dll",
+        ]
+
+    # 3. Generate possible model library paths
+    candidate_paths = []
+    for lib_name in candidate_lib_names:
+        # Equivalent to {model_path}/../
+        pardir_model_path = os.path.abspath(os.path.join(os.path.abspath(model_path), os.pardir))
+        candidate_paths.extend(
+            [
+                f"{lib_name}",
+                f"dist/prebuilt/lib/{lib_name}",  # Using prebuilt workflow
+                f"dist/{model}/{lib_name}",  # Default directory after mlc_llm.build_model()
+                os.path.join(model_path, lib_name),  # User put library inside `model_path`
+                os.path.join(pardir_model_path, lib_name),  # Under parent directory of `model_path`
+            ]
+        )
+
+    # 4. Search for model library
+    for candidate in candidate_paths:
+        if os.path.isfile(candidate):
+            logger.info("Using library model: %s", os.path.abspath(candidate))
+            return candidate
+
+    # 5. Error
+    err_msg = (
+        f"Cannot find the model library that corresponds to `{chat_config.model_lib}`.\n"
+        f"`{chat_config.model_lib}` is either provided in the `chat_config` "
+        f"you passed in, or specified in {config_file_path}.\n"
+        "We searched over the following possible paths: \n"
+    )
+    for candidate in candidate_paths:
+        err_msg += f"- {candidate}\n"
+    err_msg += (
+        "If you would like to directly specify the model library path, you may "
+        "consider passing in the `ChatModule.model_lib_path` parameter.\n"
+        f"Please checkout {_PYTHON_GET_STARTED_TUTORIAL_URL} for an example "
+        "on how to load a model."
+    )
+    raise FileNotFoundError(err_msg)
+
+
+def _convert_chat_config_to_json_str(
+    chat_config: Optional[ChatConfig], conv_template: Optional[str]
+) -> str:
+    """Convert user's input ChatConfig to a json string, omitting ``None`` fields.
+
+    Parameters
+    ----------
+    chat_config : Optional[ChatConfig]
+        User's input. A partial ChatConfig for overriding ``mlc-chat-config.json``.
+    conv_template : Optional[str]
+        The ``conv_template`` that will be used after considering potential override.
+
+    Returns
+    ------
+    json_str : str
+        A JSON string that corresponds to user's ``chat_config`` input.
+        Returns "" if ``chat_config`` unspecified.
+    """
+    if chat_config is None:
+        return ""
+    # Current logic does not allow partial ChatConfig without specifying the
+    # conv_template. Hence we use the conv_template after considering potential overrides.
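+    # The resulting JSON string is handed to the chat runtime as a partial override,
+    # so only fields the user explicitly set should appear in it.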
+ chat_config.conv_template = conv_template + # Only want to keep entries that are not None; otherwise, we would override things to None + assert hasattr(ChatConfig, "conv_config") # in case dataclass attribute name changes + chat_dict = {} + for key, value in asdict(chat_config).items(): + if key == "conv_config" and value is not None: + # conv template is another dict, do the same thing + conv_dict = {} + for conv_k, conv_v in value.items(): + if conv_v is not None: + conv_dict[conv_k] = conv_v + chat_dict[key] = conv_dict + continue + if value is not None: + chat_dict[key] = value + + return json.dumps(chat_dict) + + +def _convert_generation_config_to_json_str(generation_config: Optional[GenerationConfig]) -> str: + """Convert user's input GenerationConfig to a json string. + + Parameters + ---------- + generation_config : Optional[GenerationConfig] + User's input. A partial GenerationConfig for overriding ChatConfig generation settings. + + Returns + ------ + json_str : str + A JSON string that corresponds to user's ``generation_config`` input. + Returns "" if ``generation_config`` unspecified. + """ + if generation_config is None: + return "" + return json.dumps(asdict(generation_config)) + + +def _inspect_model_lib_metadata_memory_usage(model_lib_path, config_file_path): + cmd = [ + sys.executable, + "-m", + "mlc_chat.cli.model_metadata", + model_lib_path, + "--memory-only", + "--mlc-chat-config", + config_file_path, + ] + subprocess.run(cmd, check=False) + + +class ChatModule: # pylint: disable=too-many-instance-attributes + r"""The ChatModule for MLC LLM. + + Examples + -------- + + .. code:: python + + from mlc_chat import ChatModule + from mlc_chat.callback import StreamToStdout + + # Create a ChatModule instance + cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1") + + # Generate a response for a given prompt + output = cm.generate( + prompt="What is the meaning of life?", + progress_callback=StreamToStdout(callback_interval=2), + ) + + # Print prefill and decode performance statistics + print(f"Statistics: {cm.stats()}\n") + + output = cm.generate( + prompt="How many points did you list out?", + progress_callback=StreamToStdout(callback_interval=2), + ) + + + Parameters + ---------- + model: str + The model folder after compiling with MLC-LLM build process. The parameter + can either be the model name with its quantization scheme + (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a full path to the model + folder. In the former case, we will use the provided name to search + for the model folder over possible paths. + + device : str + The description of the device to run on. User should provide a string in the + form of 'device_name:device_id' or 'device_name', where 'device_name' is one of + 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto' (automatically detect the + local device), and 'device_id' is the device id to run on. If no 'device_id' + is provided, it will be set to 0 by default. + + chat_config : Optional[ChatConfig] + A ``ChatConfig`` instance partially filled. Will be used to override the + ``mlc-chat-config.json``. + + model_lib_path : Optional[str] + The full path to the model library file to use (e.g. a ``.so`` file). + If unspecified, we will use the provided ``model`` to search over + possible paths. + """ + + def __init__( # pylint: disable=too-many-arguments + self, + model: str, + device: str = "auto", + chat_config: Optional[ChatConfig] = None, + model_lib_path: Optional[str] = None, + ): + # 0. 
Get device: + # Retrieve device_name and device_id (if any, default 0) from device arg + self.device = detect_device(device) + device_type = self.device.device_type + device_id = self.device.device_id + + self.energy_events = {} + self.generate_counter = 0 + + # 1. Populate chat module and their functions + fcreate_chat_mod = tvm.get_global_func("mlc.llm_chat_create") + assert fcreate_chat_mod is not None + chat_mod = fcreate_chat_mod(device_type, device_id) + + # chat module related functions + self._reload_func = chat_mod["reload"] + self._unload_func = chat_mod["unload"] + self._prefill_func = chat_mod["prefill"] + self._embed_func = chat_mod["embed"] + self._prefill_with_embed_func = chat_mod["prefill_with_embed"] + self._decode_func = chat_mod["decode"] + self._raw_generate_func = chat_mod["raw_generate"] + self._reset_chat_func = chat_mod["reset_chat"] + self._load_json_override_func = chat_mod["load_json_override"] + self._stopped_func = chat_mod["stopped"] + self._get_message_func = chat_mod["get_message"] + self._runtime_stats_text_func = chat_mod["runtime_stats_text"] + self._verbose_runtime_stats_text_func = chat_mod["verbose_runtime_stats_text"] + self._reset_runtime_stats_func = chat_mod["reset_runtime_stats"] + self._get_config_json_func = chat_mod["get_config_json"] + self._process_system_prompts_func = chat_mod["process_system_prompts"] + self._evaluate_func = chat_mod["evaluate"] + self._get_role0_func = chat_mod["get_role0"] + self._get_role1_func = chat_mod["get_role1"] + + # 2. Look up model_path + self.model_path, self.config_file_path = _get_model_path(model) + + # 3. Instantiate chat_config + self.chat_config = _get_chat_config(self.config_file_path, chat_config) + + # 4. Look up model library + try: + self.model_lib_path = _get_lib_module_path( + model, + self.model_path, + self.chat_config, + model_lib_path, + self.device.MASK2STR[self.device.device_type], + self.config_file_path, + ) + except FileNotFoundError: + logger.info("Model lib not found. Now compiling model lib on device...") + from mlc_chat.interface import ( # pylint: disable=import-outside-toplevel + jit, + ) + + self.model_lib_path = str( + jit.jit( + model_path=Path(self.model_path), + chat_config=asdict(self.chat_config), + device=self.device, + ) + ) + _inspect_model_lib_metadata_memory_usage(self.model_lib_path, self.config_file_path) + + # 5. Call reload + user_chat_config_json_str = _convert_chat_config_to_json_str( + self.chat_config, self.chat_config.conv_template + ) + self._reload(self.model_lib_path, self.model_path, user_chat_config_json_str) + + def generate( + self, + prompt: Union[str, List["ChatMessage"]], + generation_config: Optional[GenerationConfig] = None, + progress_callback=None, + stateless=False, + ) -> Union[str, List[str]]: + r"""A high-level method that returns the full response from the chat module given a user + prompt. User can optionally specify which callback method to use upon receiving the + response. By default, no callback will be applied. + + Parameters + ---------- + prompt: Union[str, List[ChatMessage]] + The user input prompt, i.e. a question to ask the chat module. + It can also be the whole conversation history (list of messages with role and content) + eg: + + .. code:: + + [ + ChatMessage(role="user", content="Hello, how are you?"), + ChatMessage(role="assistant", content="I'm fine, thank you. 
How about you?"), + ChatMessage(role="user", content="I'm good too."), + ] + generation_config: Optional[GenerationConfig] + The generation config object to override the ChatConfig generation settings. + progress_callback: object + The optional callback method used upon receiving a newly generated message from the + chat module. See `mlc_chat/callback.py` for a full list of available callback classes. + Currently, only streaming to stdout callback method is supported, see `Examples` for + more detailed usage. + + Returns + ------- + output : string + The generated full output from the chat module. + + Examples + -------- + .. code-block:: python + + # Suppose we would like to stream the response of the chat module to stdout + # with a refresh interval of 2. Upon calling generate(), We will see the response of + # the chat module streaming to stdout piece by piece, and in the end we receive the + # full response as a single string `output`. + + from mlc_chat import ChatModule, GenerationConfig, callback + cm = ChatModule(xxx) + prompt = "what's the color of banana?" + output = cm.generate( + prompt, GenerationConfig(temperature=0.8), callback.StreamToStdout(callback_interval=2) + ) + print(output) + """ + new_msgs = [] + num_return_sequences = 1 + return_str = True + if (generation_config is not None) and (generation_config.n is not None): + num_return_sequences = generation_config.n + return_str = False + + for idx in range(num_return_sequences): + if stateless: + self.reset_chat() + self.energy_events[f"chat.{self.generate_counter}.{idx}.prefill.start"] = time.time_ns() + self._prefill(prompt, generation_config=generation_config) + self.energy_events[f"chat.{self.generate_counter}.{idx}.prefill.end"] = time.time_ns() + + if not progress_callback: + decode_counter = 0 + while not self._stopped(): + self.energy_events[f"chat.{self.generate_counter}.{idx}.decode.{decode_counter}.start"] = time.time_ns() + self._decode(generation_config=generation_config) + self.energy_events[f"chat.{self.generate_counter}.{idx}.decode.{decode_counter}.end"] = time.time_ns() + self.energy_events[f"chat.{self.generate_counter}.{idx}.get_message.start"] = time.time_ns() + new_msg = self._get_message() + self.energy_events[f"chat.{self.generate_counter}.{idx}.get_message.end"] = time.time_ns() + new_msgs.append(new_msg) + else: + # apply callback with a rate of callback_interval + i, new_msg = 0, "" + decode_counter = 0 + while not self._stopped(): + self.energy_events[f"chat.{self.generate_counter}.{idx}.decode.{decode_counter}.start"] = time.time_ns() + self._decode(generation_config=generation_config) + self.energy_events[f"chat.{self.generate_counter}.{idx}.decode.{decode_counter}.end"] = time.time_ns() + if i % progress_callback.callback_interval == 0 or self._stopped(): + self.energy_events[f"chat.{self.generate_counter}.{idx}.get_message.start"] = time.time_ns() + new_msg = self._get_message() + self.energy_events[f"chat.{self.generate_counter}.{idx}.get_message.end"] = time.time_ns() + progress_callback(new_msg) + i += 1 + progress_callback(stopped=True) + new_msgs.append(new_msg) + return new_msgs[0] if return_str else new_msgs + + def reset_chat(self, chat_config: Optional[ChatConfig] = None): + r"""Reset the chat session, clear all chat history, and potentially + override the original `mlc-chat-config.json`. + + Parameters + ---------- + chat_config : Optional[ChatConfig] + A ``ChatConfig`` instance partially filled. 
If specified, the chat + module will reload the `mlc-chat-config.json`, and override it with + ``chat_config``, just like in initialization. + + Note + ---- + The model remains the same after :func:`reset_chat`. + To reload module, please either re-initialize a :class:`ChatModule` instance + or use :func:`_reload` instead. + """ + self._reset_chat_func() + if chat_config is not None: + # Redo the overriding + self.chat_config = _get_chat_config(self.config_file_path, chat_config) + user_chat_config_json_str = _convert_chat_config_to_json_str( + chat_config, self.chat_config.conv_template + ) + # Second argument is `partial_update = True` + self._load_json_override_func(user_chat_config_json_str, True) + + def embed_text(self, input: str): # pylint: disable=redefined-builtin + r"""Given a text input, returns its embedding in the LLM. + + Parameters + ---------- + input : str + The user input string. + + Returns + ------- + embedding : tvm.runtime.NDArray + The embedding of the text. + + Note + ---- + This is a high-level method and is only used for retrieving text embeddings. Users are + not supposed to call :func:`generate` after calling this method in the same chat session, + since the input to this method is not prefilled and will cause error. If user needs to + call :func:`generate` later, please call :func:`reset_chat` first. + For a more fine-grained embedding API, see :func:`_embed`. + """ + return self._embed_func(input, PlaceInPrompt.Middle.value) + + def stats(self, verbose=False) -> str: + r"""Get the runtime stats of the encoding step, decoding step (and embedding step if exists) + of the chat module in text form. + + Returns + ------- + stats : str + The runtime stats text. + """ + if verbose: + return self._verbose_runtime_stats_text_func() + return self._runtime_stats_text_func() + + def benchmark_generate(self, prompt: str, generate_length: int) -> str: + r"""Controlled generation with input prompt and fixed number of + generated tokens, ignoring system prompt. For example, + + .. code:: python + + from mlc_chat import ChatModule + + cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1") + output = cm.benchmark_generate("What's the meaning of life?", generate_length=256) + print(f"Generated text:\n{output}\n") + print(f"Statistics: {cm.stats()}") + + will generate 256 tokens in total based on prompt "What's the meaning + of life?". After generation, you can use `cm.stats()` to print the + generation speed. + + Notes + ----- + 1. This function is typically used in controlled benchmarks. It generates + text without system prompt (i.e., it is pure text generation with no chat + style) and ignores the token stop model(s). + 2. To make the benchmark as accurate as possible, we first do a round of + warmup prefill and decode before text generation. + 3. This function resets the previous performance statistics. + + Parameters + ---------- + prompt : str + The prompt of the text generation. + + generate_length : int + The target length of generation. + + Returns + ------- + output : str + The generated text output. + """ + if generate_length < 0: + raise ValueError( + "The generation length is expected to be non-negative, " + f"while the given length is {generate_length}" + ) + + # warmup run + self.reset_chat() + self._prefill(prompt) + self._decode() + + return self._raw_generate_func(prompt, generate_length) + + def _reload( + self, + lib: str, + model_path: str, + app_config_json: str = "", + ): + r"""Reload the chat module from the given library and model path. 
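As a usage illustration of the public methods above (a minimal sketch, not part of this file): resetting the session with a partial ``ChatConfig`` override and then printing runtime statistics. It assumes ``cm`` is an already-constructed :class:`ChatModule` and that ``temperature`` is one of the overridable ``ChatConfig`` fields.

.. code-block:: python

    from mlc_chat import ChatConfig

    # Clear the chat history and partially override the chat configuration.
    cm.reset_chat(ChatConfig(temperature=0.7))  # `temperature` assumed overridable
    output = cm.generate("Summarize the plot of Hamlet in one sentence.")
    print(output)
    # Prefill/decode statistics for the round above.
    print(cm.stats(verbose=True))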
+ + Parameters + ---------- + lib : str + The library path. + model_path : str + The model path. + app_config_json: str + The partial config that is used to partially override the model configuration. + """ + self.energy_events[f"load_model.start"] = time.time_ns() + self._reload_func(lib, model_path, app_config_json) + self.energy_events[f"load_model.end"] = time.time_ns() + + def _unload(self): + r"""Unload the chat module and clear memory of all loaded models.""" + self.energy_events[f"unload_model.start"] = time.time_ns() + self._unload_func() + self.energy_events[f"unload_model.end"] = time.time_ns() + + def _prefill( + self, + input: Union[str, List["ChatMessage"]], # pylint: disable=redefined-builtin + decode_next_token: bool = True, + place_in_prompt: PlaceInPrompt = PlaceInPrompt.All, + generation_config: Optional[GenerationConfig] = None, + ): + r"""Run prefill stage for a given input and optionally decode the first output token. + User can decide where to place the input in the prompt. + + Parameters + ---------- + input : Union[str, List[ChatMessage]] + The user input prompt, i.e. a question to ask the chat module. + It can also be the whole conversation history (list of messages with role and content) + eg: + + .. code:: + + [ + ChatMessage(role="user", content="Hello, how are you?"), + ChatMessage(role="assistant", content="I'm fine, thank you. How about you?"), + ChatMessage(role="user", content="I'm good too."), + ] + decode_next_token : bool + Whether to decode the next token after prefilling. + place_in_prompt: PlaceInPrompt + The place of the input message in the prompt. See `class PlaceInPrompt` for details. + generation_config: Optional[GenerationConfig] + The generation config to override the ChatConfig generation settings. + """ + generation_config = _get_generation_config(self.chat_config, generation_config) + generation_config_str = _convert_generation_config_to_json_str(generation_config) + + if isinstance(input, list): + # Populate conversation.messages using load_json_override + if len(input) > 1: + conv_config = json.loads(self._get_config_json())["conv_config"] + messages = [] + role0 = self._get_role_0() + role1 = self._get_role_1() + for _, msg in enumerate(input[:-1]): + role = msg.role + content = msg.content + if role in ("user", "system"): + messages.append([role0, content]) + elif role == "assistant": + messages.append([role1, content]) + else: + raise ValueError("Only user and assistant roles are supported.") + if not input[-1].role == "user": + raise ValueError("Last message should be from user.") + conv_config["messages"] = messages + conv_config["offset"] = 0 + # Otherwise, the offset will be set to the length of the conversation, + # which means history will be retained even after calling reset_chat + self._load_json_override( + json.dumps({"conv_config": conv_config}), + partial_update=True, + ) + input_str = input[-1].content + else: + input_str = input + + self._prefill_func( + input_str, decode_next_token, place_in_prompt.value, generation_config_str + ) + + def _embed( + self, + input: str, # pylint: disable=redefined-builtin + place_in_prompt: PlaceInPrompt = PlaceInPrompt.All, + generation_config: Optional[GenerationConfig] = None, + ): + r"""A more fine-grained embedding API. Given a text input, get the embedding of the + tokenized prompt. User can decide where to place the input in the prompt. This functionality + usually aids the subsequent call to :func:`_prefill_with_embed`. 
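To make the intended call sequence concrete, here is a minimal sketch (not part of the module) that pairs :func:`_embed` with :func:`_prefill_with_embed` and then drives decoding. It assumes ``cm`` is an existing :class:`ChatModule` whose compiled model library exposes the ``embed`` function, and that ``PlaceInPrompt`` is importable from this module.

.. code-block:: python

    from mlc_chat.chat_module import PlaceInPrompt  # assumed location of the enum

    # Embed the user input, then prefill from the embedding instead of raw text.
    embedding = cm._embed("What is the capital of France?", PlaceInPrompt.All)
    cm._prefill_with_embed(embedding, decode_next_token=True)
    while not cm._stopped():
        cm._decode()
    print(cm._get_message())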
+ + Parameters + ---------- + input : str + The user input string. + place_in_prompt: PlaceInPrompt + The place of the input message in the prompt. See `class PlaceInPrompt` for details. + generation_config: Optional[GenerationConfig] + The generation config to override the ChatConfig generation settings. + + Returns + ------- + embedding : tvm.runtime.NDArray + The embedding of the text. + """ + generation_config = _get_generation_config(self.chat_config, generation_config) + generation_config_str = _convert_generation_config_to_json_str(generation_config) + + return self._embed_func(input, place_in_prompt.value, generation_config_str) + + def _prefill_with_embed( + self, + embedding: tvm.runtime.NDArray, + decode_next_token: bool = True, + generation_config: Optional[GenerationConfig] = None, + ): + r"""Given an embedding, run the prefill stage and optionally decode the first output token. + + Parameters + ---------- + embedding : tvm.runtime.NDArray + The embedding of user input. + decode_next_token : bool + Whether to decode the next token after prefilling. + generation_config: Optional[GenerationConfig] + The generation config to override the ChatConfig generation settings. + """ + generation_config = _get_generation_config(self.chat_config, generation_config) + generation_config_str = _convert_generation_config_to_json_str(generation_config) + + self._prefill_with_embed_func(embedding, decode_next_token, generation_config_str) + + def _decode(self, generation_config: Optional[GenerationConfig] = None): + r"""Decode the next token, the decoding result is stored in a buffer and + can be retrieved by :func:`get_message`. + + Parameters + ---------- + generation_config: Optional[GenerationConfig] + The generation config to override the ChatConfig generation settings. + """ + generation_config = _get_generation_config(self.chat_config, generation_config) + generation_config_str = _convert_generation_config_to_json_str(generation_config) + self._decode_func(generation_config_str) + + def _stopped(self) -> bool: + r"""Check if the stop condition is met for the current round. + + Returns + ------- + stopped : bool + """ + return self._stopped_func() != 0 + + def _get_message(self) -> str: + r"""Get the output message in the current round. + + Returns + ------- + message : str + + Note + ---- + This function returns the message that corresponds to + all the tokens decoded so far. + """ + return self._get_message_func() + + def _get_config_json(self): + r"""Get the configuration of the chat module in a single json string. + + Returns + ------- + config : str + The config json string. + """ + return self._get_config_json_func() + + def _load_json_override(self, config_str: str, partial_update: bool = False): + r"""Load JSON config and override existing configurations for the chat module. + + Parameters + ---------- + config_str : str + A json config string that partially specifies some of the options. + partial_update : bool + Whether it's a partial update or full update. If set to true, we perform a partial + update on some of the provided options; if set to false, all options must be provided. + """ + self._load_json_override_func(config_str, partial_update) + + def _get_role_0(self): + r"""Get the name of role 0 in the conversation. + + Returns + ------- + name : str + The name of role 0. + """ + return self._get_role0_func() + + def _get_role_1(self): + r"""Get the name of role 1 in the conversation. + + Returns + ------- + name : str + The name of role 1. 
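For illustration, the same low-level plumbing can be used to inject a conversation history by hand, mirroring what :func:`_prefill` does internally for multi-message prompts. This is a hedged sketch rather than an official usage pattern; ``cm`` is an existing :class:`ChatModule`.

.. code-block:: python

    import json

    conv_config = json.loads(cm._get_config_json())["conv_config"]
    conv_config["messages"] = [
        [cm._get_role_0(), "Hello, how are you?"],
        [cm._get_role_1(), "I'm fine, thank you. How about you?"],
    ]
    conv_config["offset"] = 0  # keep reset_chat() from retaining this history
    cm._load_json_override(json.dumps({"conv_config": conv_config}), partial_update=True)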
+ """ + return self._get_role1_func() + + def _reset_runtime_stats(self): + r"""Reset the runtime stats, clear all performance history.""" + self._reset_runtime_stats_func() + + def _process_system_prompts(self): + r"""Pre-process by prefilling the system prompts, running prior to any user input.""" + self.energy_events["prompt.system.start"] = time.time_ns() + self._process_system_prompts_func() + self.energy_events["prompt.system.end"] = time.time_ns() diff --git a/python/mlc_chat/cli/__init__.py b/python/mlc_chat/cli/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/mlc_chat/cli/bench.py b/python/mlc_chat/cli/bench.py new file mode 100644 index 0000000..4b9af7c --- /dev/null +++ b/python/mlc_chat/cli/bench.py @@ -0,0 +1,62 @@ +"""Command line entrypoint of benchmark.""" +from mlc_chat.help import HELP +from mlc_chat.interface.bench import bench +from mlc_chat.interface.chat import ChatConfigOverride +from mlc_chat.support.argparse import ArgumentParser + + +def main(argv): + """Parse command line arguments and call `mlc_llm.interface.bench`.""" + parser = ArgumentParser("MLC LLM Chat CLI") + + parser.add_argument( + "model", + type=str, + help=HELP["model"] + " (required)", + ) + parser.add_argument( + "--prompt", + type=str, + default="What is the meaning of life?", + help=HELP["prompt"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--opt", + type=str, + default="O2", + help=HELP["opt"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--device", + type=str, + default="auto", + help=HELP["device_deploy"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--overrides", + type=ChatConfigOverride.from_str, + default="", + help=HELP["chatconfig_overrides"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--generate-length", + type=int, + default=256, + help=HELP["generate_length"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--model-lib-path", + type=str, + default=None, + help=HELP["model_lib_path"] + ' (default: "%(default)s")', + ) + parsed = parser.parse_args(argv) + bench( + model=parsed.model, + prompt=parsed.prompt, + device=parsed.device, + opt=parsed.opt, + overrides=parsed.overrides, + generate_length=parsed.generate_length, + model_lib_path=parsed.model_lib_path, + ) diff --git a/python/mlc_chat/cli/benchmark.py b/python/mlc_chat/cli/benchmark.py new file mode 100644 index 0000000..e6014aa --- /dev/null +++ b/python/mlc_chat/cli/benchmark.py @@ -0,0 +1,86 @@ +"""A command line tool for benchmarking a chat model.""" +import argparse +from pathlib import Path + +from mlc_chat import ChatConfig, ChatModule + +parser = argparse.ArgumentParser(description="Benchmark an MLC LLM ChatModule.") +parser.add_argument( + "--model", + type=str, + help="""The model folder after compiling with MLC-LLM build process. The parameter can either + be the model name with its quantization scheme (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a + full path to the model folder. In the former case, we will use the provided name to search for + the model folder over possible paths.""", + required=True, +) +parser.add_argument( + "--model-lib", + type=str, + help="""The compiled model library. In MLC LLM, an LLM is compiled to a shared or static + library (.so or .a), which contains GPU computation to efficiently run the LLM. MLC Chat, + as the runtime of MLC LLM, depends on the compiled model library to generate tokens. 
+ """, + required=False, +) +parser.add_argument( + "--tensor-parallel-shards", + "--num-shards", + type=int, + help="Number of GPUs to be used.", + dest="tensor_parallel_shards", + required=False, +) +parser.add_argument( + "--device", + type=str, + help="""The description of the device to run on. User should provide a string in the form of + 'device_name:device_id' or 'device_name', where 'device_name' is one of 'cuda', 'metal', + 'vulkan', 'rocm', 'opencl', and 'device_id' is the device id to run on. If no 'device_id' is + provided, it will be set to 0 by default. + """, + required=True, +) +parser.add_argument( + "--prompt", + type=str, + help="The prompt to generate from.", + required=True, +) +parser.add_argument( + "--generate-length", + type=int, + help="The length (numer of tokens) of the generated text.", + required=True, +) + + +def _load_prompt(path_or_prompt: str) -> str: + """Load the prompt from a file or use the provided prompt.""" + try: + path = Path(path_or_prompt) + if path.is_file(): + with path.open("r", encoding="utf-8") as in_file: + return in_file.read() + except: # pylint: disable=bare-except + pass + return path_or_prompt + + +def main(): + """The main function that runs the benchmarking.""" + args = parser.parse_args() + chat_module = ChatModule( + model=args.model, + device=args.device, + chat_config=ChatConfig(tensor_parallel_shards=args.tensor_parallel_shards), + model_lib_path=args.model_lib, + ) + prompt = _load_prompt(args.prompt) + output = chat_module.benchmark_generate(prompt, generate_length=args.generate_length) + print(f"Generated text:\n{output}\n") + print(f"Statistics: {chat_module.stats(verbose=True)}") + + +if __name__ == "__main__": + main() diff --git a/python/mlc_chat/cli/chat.py b/python/mlc_chat/cli/chat.py new file mode 100644 index 0000000..96edef2 --- /dev/null +++ b/python/mlc_chat/cli/chat.py @@ -0,0 +1,54 @@ +"""Command line entrypoint of chat.""" +from mlc_chat.help import HELP +from mlc_chat.interface.chat import ChatConfigOverride, chat +from mlc_chat.support.argparse import ArgumentParser + + +def main(argv): + """Parse command line arguments and call `mlc_llm.interface.chat`.""" + parser = ArgumentParser("MLC LLM Chat CLI") + + parser.add_argument( + "model", + type=str, + help=HELP["model"] + " (required)", + ) + parser.add_argument( + "--opt", + type=str, + default="O2", + help=HELP["opt"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--device", + type=str, + default="auto", + help=HELP["device_deploy"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--overrides", + type=ChatConfigOverride.from_str, + default="", + help=HELP["chatconfig_overrides"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--model-lib-path", + type=str, + default=None, + help=HELP["model_lib_path"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--energy-events", + type=str, + default="energy_events.txt", + help="Energy events file to use for energy profiling (default: energy_events.txt)" + ) + parsed = parser.parse_args(argv) + chat( + model=parsed.model, + device=parsed.device, + opt=parsed.opt, + overrides=parsed.overrides, + model_lib_path=parsed.model_lib_path, + energy_events_filename=parsed.energy_events, + ) diff --git a/python/mlc_chat/cli/check_device.py b/python/mlc_chat/cli/check_device.py new file mode 100644 index 0000000..a78fd4d --- /dev/null +++ b/python/mlc_chat/cli/check_device.py @@ -0,0 +1,30 @@ +"""Check if a device exists.""" +import sys + +from tvm.runtime import Device +from 
tvm.runtime import device as as_device + + +def _check_device(device: Device) -> bool: + try: + return bool(device.exist) + except: # pylint: disable=bare-except + return False + + +def main(): + """Entrypoint for device check.""" + device_str = sys.argv[1] + device_ids = [] + i = 0 + while True: + if _check_device(as_device(device_str, i)): + device_ids.append(i) + i += 1 + else: + break + print(f"check_device:{','.join(str(i) for i in device_ids)}") + + +if __name__ == "__main__": + main() diff --git a/python/mlc_chat/cli/compile.py b/python/mlc_chat/cli/compile.py new file mode 100644 index 0000000..c56b404 --- /dev/null +++ b/python/mlc_chat/cli/compile.py @@ -0,0 +1,142 @@ +"""Command line entrypoint of compilation.""" +import argparse +import json +import re +from functools import partial +from pathlib import Path +from typing import Union + +from mlc_chat.help import HELP +from mlc_chat.interface.compile import ( # pylint: disable=redefined-builtin + ModelConfigOverride, + OptimizationFlags, + compile, +) +from mlc_chat.model import MODELS +from mlc_chat.quantization import QUANTIZATION +from mlc_chat.support.argparse import ArgumentParser +from mlc_chat.support.auto_config import ( + detect_mlc_chat_config, + detect_model_type, + detect_quantization, +) +from mlc_chat.support.auto_target import ( + detect_system_lib_prefix, + detect_target_and_host, +) + + +def main(argv): + """Parse command line argumennts and call `mlc_llm.compiler.compile`.""" + + def _parse_output(path: Union[str, Path]) -> Path: + path = Path(path) + if path.is_dir(): + raise argparse.ArgumentTypeError(f"Output cannot be a directory: {path}") + parent = path.parent + if not parent.is_dir(): + raise argparse.ArgumentTypeError(f"Directory does not exist: {parent}") + return path + + def _parse_dir(path: Union[str, Path], auto_create: bool = False) -> Path: + path = Path(path) + if not auto_create and not path.is_dir(): + raise argparse.ArgumentTypeError(f"Directory does not exist: {path}") + if auto_create and not path.is_dir(): + path.mkdir(parents=True) + return path + + def _check_system_lib_prefix(prefix: str) -> str: + pattern = r"^[a-zA-Z_][a-zA-Z0-9_]*$" + if prefix == "" or re.match(pattern, prefix): + return prefix + raise argparse.ArgumentTypeError( + "Invalid prefix. It should only consist of " + "numbers (0-9), alphabets (A-Z, a-z) and underscore (_)." 
+ ) + + parser = ArgumentParser("mlc_chat compile") + parser.add_argument( + "model", + type=detect_mlc_chat_config, + help=HELP["model"] + " (required)", + ) + parser.add_argument( + "--quantization", + type=str, + choices=list(QUANTIZATION.keys()), + help=HELP["quantization"] + + " (default: look up mlc-chat-config.json, choices: %(choices)s)", + ) + parser.add_argument( + "--model-type", + type=str, + default="auto", + choices=["auto"] + list(MODELS.keys()), + help=HELP["model_type"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--device", + type=str, + default="auto", + help=HELP["device_compile"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--host", + type=str, + default="auto", + help=HELP["host"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--opt", + type=OptimizationFlags.from_str, + default="O2", + help=HELP["opt"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--system-lib-prefix", + type=str, + default="auto", + help=HELP["system_lib_prefix"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--output", + "-o", + type=_parse_output, + required=True, + help=HELP["output_compile"] + " (required)", + ) + parser.add_argument( + "--overrides", + type=ModelConfigOverride.from_str, + default="", + help=HELP["overrides"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--debug-dump", + type=partial(_parse_dir, auto_create=True), + default=None, + help=HELP["debug_dump"] + " (default: %(default)s)", + ) + parsed = parser.parse_args(argv) + target, build_func = detect_target_and_host(parsed.device, parsed.host) + parsed.model_type = detect_model_type(parsed.model_type, parsed.model) + parsed.quantization = detect_quantization(parsed.quantization, parsed.model) + parsed.system_lib_prefix = detect_system_lib_prefix( + parsed.device, parsed.system_lib_prefix, parsed.model_type.name, parsed.quantization.name + ) + with open(parsed.model, "r", encoding="utf-8") as config_file: + config = json.load(config_file) + + compile( + config=config, + quantization=parsed.quantization, + model_type=parsed.model_type, + target=target, + opt=parsed.opt, + build_func=build_func, + system_lib_prefix=parsed.system_lib_prefix, + output=parsed.output, + overrides=parsed.overrides, + debug_dump=parsed.debug_dump, + ) diff --git a/python/mlc_chat/cli/convert_weight.py b/python/mlc_chat/cli/convert_weight.py new file mode 100644 index 0000000..5e97cc7 --- /dev/null +++ b/python/mlc_chat/cli/convert_weight.py @@ -0,0 +1,95 @@ +"""Command line entrypoint of weight conversion.""" +import argparse +from pathlib import Path +from typing import Union + +from mlc_chat.help import HELP +from mlc_chat.interface.convert_weight import convert_weight +from mlc_chat.model import MODELS +from mlc_chat.quantization import QUANTIZATION +from mlc_chat.support.argparse import ArgumentParser +from mlc_chat.support.auto_config import detect_config, detect_model_type +from mlc_chat.support.auto_device import detect_device +from mlc_chat.support.auto_weight import detect_weight + + +def main(argv): + """Parse command line argumennts and apply quantization.""" + + def _parse_source(path: Union[str, Path], config_path: Path) -> Path: + if path == "auto": + return config_path.parent + path = Path(path) + if not path.exists(): + raise argparse.ArgumentTypeError(f"Model source does not exist: {path}") + return path + + def _parse_output(path: Union[str, Path]) -> Path: + path = Path(path) + if not path.is_dir(): + path.mkdir(parents=True, exist_ok=True) + 
return path + + parser = ArgumentParser("MLC AutoLLM Quantization Framework") + parser.add_argument( + "config", + type=detect_config, + help=HELP["config"] + " (required)", + ) + parser.add_argument( + "--quantization", + type=str, + required=True, + choices=list(QUANTIZATION.keys()), + help=HELP["quantization"] + " (required, choices: %(choices)s)", + ) + parser.add_argument( + "--model-type", + type=str, + default="auto", + choices=["auto"] + list(MODELS.keys()), + help=HELP["model_type"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--device", + default="auto", + type=detect_device, + help=HELP["device_quantize"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--source", + type=str, + default="auto", + help=HELP["source"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--source-format", + type=str, + choices=["auto", "huggingface-torch", "huggingface-safetensor", "awq"], + default="auto", + help=HELP["source_format"] + ' (default: "%(default)s", choices: %(choices)s")', + ) + parser.add_argument( + "--output", + "-o", + type=_parse_output, + required=True, + help=HELP["output_quantize"] + " (required)", + ) + + parsed = parser.parse_args(argv) + parsed.source, parsed.source_format = detect_weight( + weight_path=_parse_source(parsed.source, parsed.config), + config_json_path=parsed.config, + weight_format=parsed.source_format, + ) + model = detect_model_type(parsed.model_type, parsed.config) + convert_weight( + config=parsed.config, + quantization=QUANTIZATION[parsed.quantization], + model=model, + device=parsed.device, + source=parsed.source, + source_format=parsed.source_format, + output=parsed.output, + ) diff --git a/python/mlc_chat/cli/delivery.py b/python/mlc_chat/cli/delivery.py new file mode 100644 index 0000000..cc5fd07 --- /dev/null +++ b/python/mlc_chat/cli/delivery.py @@ -0,0 +1,280 @@ +"""Continuous model delivery for MLC LLM models.""" +import argparse +import dataclasses +import json +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path +from typing import Any, Callable, Dict, List, Tuple, Union + +from huggingface_hub import HfApi # pylint: disable=import-error +from huggingface_hub.utils import HfHubHTTPError # pylint: disable=import-error + +from mlc_chat.support import logging +from mlc_chat.support.argparse import ArgumentParser +from mlc_chat.support.constants import MLC_TEMP_DIR +from mlc_chat.support.download import git_clone +from mlc_chat.support.style import bold, green, red + +logging.enable_logging() +logger = logging.getLogger(__name__) + +GEN_CONFIG_OPTIONAL_ARGS = [ + "context_window_size", + "sliding_window_size", + "prefill_chunk_size", + "attention_sink_size", + "tensor_parallel_shards", +] + + +@dataclasses.dataclass +class ModelInfo: # pylint: disable=too-many-instance-attributes + """Necessary information for the model delivery""" + + model_id: str + model: Path + conv_template: str + quantization: str + source_format: str = "auto" + # If unspecified in CLI, remains to be None and will not be + # passed to `gen_config` or `convert_weight` + context_window_size: int = None + sliding_window_size: int = None + prefill_chunk_size: int = None + attention_sink_size: int = None + tensor_parallel_shards: int = None + + +class DeferredScope: + """A context manager that defers execution of functions until exiting the scope.""" + + def __init__(self): + self.deferred_functions = [] + + def add(self, func: Callable[[], None]): + """Add a function to be executed when exiting the scope.""" 
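A brief usage sketch of ``DeferredScope`` (illustrative only, assuming the MLC temp directory exists): temporary directories created through the scope are removed when the ``with`` block exits, and deferred callbacks run in reverse registration order.

.. code-block:: python

    with DeferredScope() as deferred:
        staging_dir = deferred.create_temp_dir()
        deferred.add(lambda: print("cleanup ran"))
        # ... stage a model checkout or conversion output under staging_dir ...
    # At this point staging_dir has been deleted and the callback has run.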
+ self.deferred_functions.append(func) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + for func in reversed(self.deferred_functions): + func() + return False + + def create_temp_dir(self) -> Path: + """Create a temporary directory that will be deleted when exiting the scope.""" + temp_dir = tempfile.mkdtemp(dir=MLC_TEMP_DIR) + self.add(lambda: shutil.rmtree(temp_dir, ignore_errors=True)) + return Path(temp_dir) + + +def _clone_repo(model: Union[str, Path], deferred: DeferredScope) -> Path: + if isinstance(model, Path): + if not model.exists(): + raise ValueError(f"Invalid model source: {model}") + return model + if model.startswith("https://") or model.startswith("git://"): + result = deferred.create_temp_dir() / "repo" + git_clone(model, result, ignore_lfs=False) + return result + result = Path(model) + if result.exists(): + return result + raise ValueError(f"Invalid model source: {model}") + + +def _run_quantization( + model_info: ModelInfo, + repo: str, + api: HfApi, +) -> bool: + logger.info("[HF] Creating repo https://huggingface.co/%s", repo) + try: + api.create_repo(repo_id=repo, private=False) + except HfHubHTTPError as error: + if error.response.status_code != 409: + raise + logger.info("[HF] Repo already exists. Recreating...") + api.delete_repo(repo_id=repo) + api.create_repo(repo_id=repo, private=False) + logger.info("[HF] Repo recreated") + succeeded = True + with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as output_dir: + log_path = Path(output_dir) / "logs.txt" + with log_path.open("a", encoding="utf-8") as log_file: + assert isinstance(model_info.model, Path) + logger.info("[MLC] Processing in directory: %s", output_dir) + # Required arguments + cmd = [ + sys.executable, + "-m", + "mlc_chat", + "gen_config", + str(model_info.model), + "--quantization", + model_info.quantization, + "--conv-template", + model_info.conv_template, + "--output", + output_dir, + ] + # Optional arguments + for optional_arg in GEN_CONFIG_OPTIONAL_ARGS: + optional_arg_val = getattr(model_info, optional_arg, None) + if optional_arg_val is not None: + # e.g. --context-window-size 4096 + cmd += ["--" + optional_arg.replace("_", "-"), str(optional_arg_val)] + + print(" ".join(cmd), file=log_file, flush=True) + subprocess.run(cmd, check=True, stdout=log_file, stderr=subprocess.STDOUT) + cmd = [ + sys.executable, + "-m", + "mlc_chat", + "convert_weight", + str(model_info.model), + "--quantization", + model_info.quantization, + "--source-format", + model_info.source_format, + "--output", + output_dir, + ] + print(" ".join(cmd), file=log_file, flush=True) + subprocess.run(cmd, check=False, stdout=log_file, stderr=subprocess.STDOUT) + logger.info("[MLC] Complete!") + if not (Path(output_dir) / "ndarray-cache.json").exists(): + logger.error( + "[%s] Model %s. Quantization %s. No weights metadata found.", + red("FAILED"), + model_info.model_id, + model_info.quantization, + ) + succeeded = False + logger.info("[HF] Uploading to: https://huggingface.co/%s", repo) + for _retry in range(10): + try: + api.upload_folder( + folder_path=output_dir, + repo_id=repo, + commit_message="Initial commit", + ) + except Exception as exc: # pylint: disable=broad-except + logger.error("[%s] %s. 
Retrying...", red("FAILED"), exc) + else: + break + else: + raise RuntimeError("Failed to upload to HuggingFace Hub with 10 retries") + return succeeded + + +def _main( # pylint: disable=too-many-locals + username: str, + api: HfApi, + spec: Dict[str, Any], +): + failed_cases: List[Tuple[str, str]] = [] + for task_index, task in enumerate(spec["tasks"], 1): + with DeferredScope() as deferred: + logger.info( + bold("[{task_index}/{total_tasks}] Processing model: ").format( + task_index=task_index, + total_tasks=len(spec["tasks"]), + ) + + green(task["model_id"]) + ) + model = _clone_repo(task["model"], deferred) + for quantization in spec["default_quantization"] + task.get("quantization", []): + model_info = { + "model_id": task["model_id"], + "model": model, + "conv_template": task["conv_template"], + } + # Process optional arguments + for optional_arg in GEN_CONFIG_OPTIONAL_ARGS: + # e.g. "context_window_size": task.get("context_window_size", None) + model_info[optional_arg] = task.get(optional_arg, None) + if isinstance(quantization, str): + model_info["quantization"] = quantization + else: + model_info["quantization"] = quantization.pop("format") + model_info.update(quantization) + repo = spec.get("destination", "{username}/{model_id}-{quantization}-MLC").format( + username=username, + model_id=model_info["model_id"], + quantization=model_info["quantization"], + ) + logger.info( + "%s%s. %s%s. %s%s", + bold("Model: "), + green(task["model_id"]), + bold("Quantization: "), + green(model_info["quantization"]), + bold("Repo: "), + green(f"https://huggingface.co/{repo}"), + ) + with DeferredScope() as inner_deferred: + model_info["model"] = _clone_repo(model_info["model"], inner_deferred) + result = _run_quantization( + ModelInfo(**model_info), + repo=spec["destination"].format( + username=username, + model_id=model_info["model_id"], + quantization=model_info["quantization"], + ), + api=api, + ) + if not result: + failed_cases.append( + (task["model_id"], model_info["quantization"]), + ) + if failed_cases: + logger.info("Total %s %s:", len(failed_cases), red("failures")) + for model_id, quantization in failed_cases: + logger.info(" Model %s. 
Quantization %s.", model_id, quantization) + + +def main(): + """Entry point.""" + + def _load_spec(path_spec: str) -> Dict[str, Any]: + path = Path(path_spec) + if not path.exists(): + raise argparse.ArgumentTypeError(f"Spec file does not exist: {path}") + with path.open("r", encoding="utf-8") as i_f: + return json.load(i_f) + + parser = ArgumentParser("MLC LLM continuous model delivery") + parser.add_argument( + "--username", + type=str, + required=True, + help="HuggingFace username", + ) + parser.add_argument( + "--token", + type=str, + required=True, + help="HuggingFace access token, obtained under https://huggingface.co/settings/tokens", + ) + parser.add_argument( + "--spec", + type=_load_spec, + required=True, + help="Path to the spec file", + ) + parsed = parser.parse_args() + _main( + parsed.username, + spec=parsed.spec, + api=HfApi(token=parsed.token), + ) + + +if __name__ == "__main__": + main() diff --git a/python/mlc_chat/cli/gen_config.py b/python/mlc_chat/cli/gen_config.py new file mode 100644 index 0000000..dd68484 --- /dev/null +++ b/python/mlc_chat/cli/gen_config.py @@ -0,0 +1,106 @@ +"""Command line entrypoint of configuration generation.""" +from pathlib import Path +from typing import Union + +from mlc_chat.help import HELP +from mlc_chat.interface.gen_config import CONV_TEMPLATES, gen_config +from mlc_chat.model import MODELS +from mlc_chat.quantization import QUANTIZATION +from mlc_chat.support.argparse import ArgumentParser +from mlc_chat.support.auto_config import detect_config, detect_model_type + + +def main(argv): + """Parse command line argumennts and call `mlc_llm.compiler.gen_config`.""" + parser = ArgumentParser("MLC LLM Configuration Generator") + + def _parse_output(path: Union[str, Path]) -> Path: + path = Path(path) + if not path.is_dir(): + path.mkdir(parents=True, exist_ok=True) + return path + + parser.add_argument( + "config", + type=detect_config, + help=HELP["config"] + " (required)", + ) + parser.add_argument( + "--quantization", + type=str, + required=True, + choices=list(QUANTIZATION.keys()), + help=HELP["quantization"] + " (required, choices: %(choices)s)", + ) + parser.add_argument( + "--model-type", + type=str, + default="auto", + choices=["auto"] + list(MODELS.keys()), + help=HELP["model_type"] + ' (default: "%(default)s", choices: %(choices)s)', + ) + parser.add_argument( + "--conv-template", + type=str, + required=True, + choices=list(CONV_TEMPLATES), + help=HELP["conv_template"] + " (required, choices: %(choices)s)", + ) + parser.add_argument( + "--context-window-size", + type=int, + default=None, + help=HELP["context_window_size"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--sliding-window-size", + type=int, + default=None, + help=HELP["sliding_window_size"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--prefill-chunk-size", + type=int, + default=None, + help=HELP["prefill_chunk_size"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--attention-sink-size", + type=int, + default=None, + help=HELP["attention_sink_size"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--tensor-parallel-shards", + type=int, + default=None, + help=HELP["tensor_parallel_shards"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--max-batch-size", + type=int, + default=80, + help=HELP["max_batch_size"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--output", + "-o", + type=_parse_output, + required=True, + help=HELP["output_gen_mlc_chat_config"] + " (required)", + ) + parsed = 
parser.parse_args(argv) + model = detect_model_type(parsed.model_type, parsed.config) + gen_config( + config=parsed.config, + model=model, + quantization=QUANTIZATION[parsed.quantization], + conv_template=parsed.conv_template, + context_window_size=parsed.context_window_size, + sliding_window_size=parsed.sliding_window_size, + prefill_chunk_size=parsed.prefill_chunk_size, + attention_sink_size=parsed.attention_sink_size, + tensor_parallel_shards=parsed.tensor_parallel_shards, + max_batch_size=parsed.max_batch_size, + output=parsed.output, + ) diff --git a/python/mlc_chat/cli/model_metadata.py b/python/mlc_chat/cli/model_metadata.py new file mode 100644 index 0000000..9939476 --- /dev/null +++ b/python/mlc_chat/cli/model_metadata.py @@ -0,0 +1,182 @@ +"""A tool that inspects the metadata of a model lib.""" +import json +import math +from dataclasses import asdict +from pathlib import Path +from typing import Any, Dict, List, Union + +import numpy as np + +from mlc_chat.support import logging +from mlc_chat.support.argparse import ArgumentParser +from mlc_chat.support.config import ConfigBase +from mlc_chat.support.style import green, red + +logging.enable_logging() +logger = logging.getLogger(__name__) + + +def _extract_metadata(model_lib: Path) -> Dict[str, Any]: + # pylint: disable=import-outside-toplevel + from tvm.runtime import device, load_module + from tvm.runtime.relax_vm import VirtualMachine + + # pylint: enable=import-outside-toplevel + + return json.loads(VirtualMachine(load_module(model_lib), device("cpu"))["_metadata"]()) + + +def _report_all(metadata: Dict[str, Any]) -> None: + # Print JSON with aesthetic values that packs each parameter into one line, + # while keeping the rest indented. + indent = 2 + indents = " " * indent + params = metadata.pop("params") + params = indents * 2 + (",\n" + indents * 2).join(json.dumps(p) for p in params) + lines = json.dumps( + metadata, + sort_keys=True, + indent=indent, + ).splitlines() + lines.insert(1, indents + '"params": [\n' + params + "\n" + indents + "],") + beautified_json = "\n".join(lines) + print(beautified_json) + + +def _read_dynamic_shape(shape: List[Union[int, str]], config: Union[Dict, ConfigBase]) -> List[int]: + if isinstance(config, ConfigBase): + config = asdict(config) + param_shape = [] + for s in shape: + if isinstance(s, int): + param_shape.append(s) + else: + if config is None: + logger.error( + "%s: Encountered dynamic shape %s, need to specify `--mlc-chat-config` for " + + "memory usage calculation.", + red("FAILED"), + red(s), + ) + raise AttributeError + if not s in config: + logger.error( + "%s to retrieve concrete %s for dynamic shape from %s.", + red("FAILED"), + red(s), + config, + ) + raise KeyError + param_shape.append(config[s]) + return param_shape + + +def _compute_memory_usage(metadata: Dict[str, Any], config: Union[Dict, ConfigBase]): + params_bytes = 0.0 + for param in metadata["params"]: + if all(isinstance(v, int) for v in param["shape"]): + assert all(v > 0 for v in param["shape"]), "All shapes should be strictly positive." 
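To make the arithmetic in ``_compute_memory_usage`` concrete, a standalone worked example (not part of the tool): the byte count of one parameter is the product of its shape times the element size of its dtype.

.. code-block:: python

    import math

    import numpy as np

    shape = [4096, 4096]             # a hypothetical fp16 weight matrix
    dtype = np.dtype("float16")      # 2 bytes per element
    param_bytes = math.prod(shape) * dtype.itemsize
    print(param_bytes / 1024 / 1024)  # 32.0 MB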
+ param_shape = param["shape"] + else: + # Contains dynamic shape; use config to look up concrete values + param_shape = _read_dynamic_shape(param["shape"], config) + params_bytes += math.prod(param_shape) * np.dtype(param["dtype"]).itemsize + temp_func_bytes = 0.0 + for _func_name, func_bytes in metadata["memory_usage"].items(): + temp_func_bytes = max(temp_func_bytes, func_bytes) + kv_cache_bytes = metadata["kv_cache_bytes"] + + return params_bytes, temp_func_bytes, kv_cache_bytes + + +def _report_memory_usage(metadata: Dict[str, Any], config: Union[Dict, ConfigBase]) -> None: + params_bytes, temp_func_bytes, kv_cache_bytes = _compute_memory_usage(metadata, config) + total_size = params_bytes + temp_func_bytes + kv_cache_bytes + logger.info( + "%s: %.2f MB (Parameters: %.2f MB. KVCache: %.2f MB. Temporary buffer: %.2f MB)", + green("Total memory usage"), + total_size / 1024 / 1024, + params_bytes / 1024 / 1024, + kv_cache_bytes / 1024 / 1024, + temp_func_bytes / 1024 / 1024, + ) + + logger.info( + "To reduce memory usage, " + "tweak `prefill_chunk_size`, `context_window_size` and `sliding_window_size`" + ) + + +def _print_memory_usage_in_json(metadata: Dict[str, Any], config: Dict) -> None: + params_bytes, temp_func_bytes, kv_cache_bytes = _compute_memory_usage(metadata, config) + print( + json.dumps( + { + "params_bytes": params_bytes, + "temp_func_bytes": temp_func_bytes, + "kv_cache_bytes": kv_cache_bytes, + } + ) + ) + + +def main(): + """Entry point for the model metadata tool.""" + parser = ArgumentParser(description="A tool that inspects the metadata of a model lib.") + parser.add_argument( + "model_lib", + type=Path, + help="""The compiled model library. In MLC LLM, an LLM is compiled to a shared or static + library (.so or .a), which contains GPU computation to efficiently run the LLM. MLC Chat, + as the runtime of MLC LLM, depends on the compiled model library to generate tokens. + """, + ) + parser.add_argument( + "--mlc-chat-config", + type=Path, + help="""The `mlc-chat-config.json` file specific to a model variant. This is only required + when `memory-only` is true and `model_lib` contains a dynamic parameter shape (i.e. using + a variable to represent the shape). For instance, `model.embed_tokens.q_weight` can have + shape `["vocab_size", 512]`. In these cases, we look up the concrete value in + `mlc-chat-config.json`. + """, + ) + parser.add_argument( + "--memory-only", + action="store_true", + help="""If set, only inspect the metadata in memory usage and print richer analysis. + Otherwise, the tool will load all the metadata from the model library file but only print + the basic information in JSON. 
+ """, + ) + parser.add_argument( + "--print-memory-usage-in-json-only", + action="store_true", + help="""If set, only inspect the metadata in memory usage and print usage in raw JSON.""", + ) + parsed = parser.parse_args() + # Load metadata from model lib + try: + metadata = _extract_metadata(parsed.model_lib) + except: # pylint: disable=bare-except + logger.exception("%s to read metadata section in legacy model lib.", red("FAILED")) + return + # Load mlc_chat_config if provided + cfg = None + if parsed.mlc_chat_config: + mlc_chat_config_path = Path(parsed.mlc_chat_config) + if not mlc_chat_config_path.exists(): + raise ValueError(f"{mlc_chat_config_path} does not exist.") + with open(mlc_chat_config_path, "r", encoding="utf-8") as config_file: + cfg = json.load(config_file) + # Main body + if parsed.print_memory_usage_in_json_only: + _print_memory_usage_in_json(metadata, cfg) + elif parsed.memory_only: + _report_memory_usage(metadata, cfg) + else: + _report_all(metadata) + + +if __name__ == "__main__": + main() diff --git a/python/mlc_chat/cli/worker.py b/python/mlc_chat/cli/worker.py new file mode 100644 index 0000000..5f64e30 --- /dev/null +++ b/python/mlc_chat/cli/worker.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Internal DiscoWorker for Disco ProcessSession.""" +import os +import sys + +from tvm import runtime as _ # pylint: disable=unused-import +from tvm._ffi import get_global_func + +from .. import base # pylint: disable=unused-import, no-name-in-module + + +def main(): + """Main worker function""" + if len(sys.argv) != 5: + print("Usage: ") + return + + worker_id = int(sys.argv[1]) + num_workers = int(sys.argv[2]) + if sys.platform == "win32": + import msvcrt # pylint: disable=import-outside-toplevel,import-error + + reader = msvcrt.open_osfhandle(int(sys.argv[3]), os.O_BINARY) + writer = msvcrt.open_osfhandle(int(sys.argv[4]), os.O_BINARY) + else: + reader = int(sys.argv[3]) + writer = int(sys.argv[4]) + + worker_func = get_global_func("runtime.disco.WorkerProcess") + worker_func(worker_id, num_workers, reader, writer) + + +if __name__ == "__main__": + try: + main() + except (KeyboardInterrupt, IOError): + pass diff --git a/python/mlc_chat/compiler_pass/__init__.py b/python/mlc_chat/compiler_pass/__init__.py new file mode 100644 index 0000000..762ba8c --- /dev/null +++ b/python/mlc_chat/compiler_pass/__init__.py @@ -0,0 +1,2 @@ +"""Compiler passes used in MLC LLM.""" +from . 
import pipeline as _pipeline diff --git a/python/mlc_chat/compiler_pass/attach_to_ir_module.py b/python/mlc_chat/compiler_pass/attach_to_ir_module.py new file mode 100644 index 0000000..5850729 --- /dev/null +++ b/python/mlc_chat/compiler_pass/attach_to_ir_module.py @@ -0,0 +1,159 @@ +"""A couple of passes that simply attach additional information onto the IRModule.""" + +from typing import Dict + +import tvm +from tvm import IRModule, relax, tir +from tvm.script import tir as T + + +@tvm.transform.module_pass(opt_level=0, name="AttachVariableBounds") +class AttachVariableBounds: # pylint: disable=too-few-public-methods + """Attach variable bounds to each Relax function, which primarily helps with memory planning.""" + + def __init__(self, variable_bounds: Dict[str, int]): + self.variable_bounds = variable_bounds + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """Entrypoint""" + for g_var, func in mod.functions_items(): + if isinstance(func, relax.Function): + mod[g_var] = func.with_attr("tir_var_upper_bound", self.variable_bounds) + return mod + + +@tvm.transform.module_pass(opt_level=0, name="AttachAdditionalPrimFuncs") +class AttachAdditionalPrimFuncs: # pylint: disable=too-few-public-methods + """Attach extra TIR PrimFuncs to the IRModule""" + + def __init__(self, functions: Dict[str, tir.PrimFunc]): + self.functions = functions + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """Entrypoint""" + for func_name, func in self.functions.items(): + mod[func_name] = func.with_attr("global_symbol", func_name) + return mod + + +@tvm.transform.module_pass(opt_level=0, name="AttachMemoryPlanAttr") +class AttachMemoryPlanAttr: # pylint: disable=too-few-public-methods + """Attach memory planning attribute for dynamic function output planning to Relax functions.""" + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """Entrypoint""" + for g_var, func in mod.functions_items(): + if isinstance(func, relax.Function): + mod[g_var] = func.with_attr("relax.memory_plan_dynamic_func_output", True) + return mod + + +@tvm.transform.module_pass(opt_level=0, name="AttachLogitProcessFunc") +class AttachLogitProcessFunc: # pylint: disable=too-few-public-methods + """Attach logit processing TIR functions to IRModule.""" + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """Entrypoint""" + mod = mod.clone() + mod["apply_logit_bias_inplace"] = _apply_logit_bias_inplace + mod["apply_penalty_inplace"] = _apply_penalty_inplace + mod["apply_bitmask_inplace"] = _apply_bitmask_inplace + return mod + + +@T.prim_func +def _apply_logit_bias_inplace( + var_logits: T.handle, + var_pos2seq_id: T.handle, + var_token_ids: T.handle, + var_logit_bias: T.handle, +) -> None: + """Function that applies logit bias in place.""" + T.func_attr( + {"global_symbol": "apply_logit_bias_inplace", "tir.noalias": True, "tir.is_scheduled": True} + ) + batch_size = T.int32(is_size_var=True) + vocab_size = T.int32(is_size_var=True) + num_token = T.int32(is_size_var=True) + logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32") + # seq_ids + pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32") + token_ids = T.match_buffer(var_token_ids, (num_token,), "int32") + logit_bias = T.match_buffer(var_logit_bias, (num_token,), "float32") + + for p0 in T.thread_binding(0, (num_token + 1023) // 1024, "blockIdx.x"): + for p1 in T.thread_binding(0, 1024, 
"threadIdx.x"): + with T.block("block"): + vp = T.axis.spatial(num_token, p0 * 1024 + p1) + T.where(p0 * 1024 + p1 < num_token) + logits[pos2seq_id[vp], token_ids[vp]] += logit_bias[vp] + + +@T.prim_func +def _apply_penalty_inplace( # pylint: disable=too-many-arguments,too-many-locals + var_logits: T.handle, + var_seq_ids: T.handle, + var_pos2seq_id: T.handle, + var_token_ids: T.handle, + var_token_cnt: T.handle, + var_penalties: T.handle, +) -> None: + """Function that applies penalties in place.""" + T.func_attr( + {"global_symbol": "apply_penalty_inplace", "tir.noalias": True, "tir.is_scheduled": True} + ) + batch_size = T.int32(is_size_var=True) + vocab_size = T.int32(is_size_var=True) + num_token = T.int32(is_size_var=True) + num_seq = T.int32(is_size_var=True) + logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32") + seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32") + pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32") + token_ids = T.match_buffer(var_token_ids, (num_token,), "int32") + token_cnt = T.match_buffer(var_token_cnt, (num_token,), "int32") + penalties = T.match_buffer(var_penalties, (num_seq, 3), "float32") + + for p0 in T.thread_binding(0, (num_token + 1023) // 1024, "blockIdx.x"): + for p1 in T.thread_binding(0, 1024, "threadIdx.x"): + with T.block("block"): + vp = T.axis.spatial(num_token, p0 * 1024 + p1) + T.where(p0 * 1024 + p1 < num_token) + # Penalties: (presence_penalty, frequency_penalty, repetition_penalty) + logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] -= ( + penalties[pos2seq_id[vp], 0] + token_cnt[vp] * penalties[pos2seq_id[vp], 1] + ) + logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else( + logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] > 0, + logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] * penalties[pos2seq_id[vp], 2], + logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] / penalties[pos2seq_id[vp], 2], + ) + + +@T.prim_func +def _apply_bitmask_inplace( + var_logits: T.handle, + var_seq_ids: T.handle, + var_bitmask: T.handle, +) -> None: + """Function that applies vocabulary masking in place.""" + T.func_attr( + {"global_symbol": "apply_bitmask_inplace", "tir.noalias": True, "tir.is_scheduled": True} + ) + batch_size = T.int32(is_size_var=True) + vocab_size = T.int32(is_size_var=True) + num_seq = T.int32(is_size_var=True) + logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32") + seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32") + bitmask = T.match_buffer(var_bitmask, (num_seq, (vocab_size + 31 // 32)), "int32") + + for fused_s_v_0 in T.thread_binding(0, (num_seq * vocab_size + 1023) // 1024, "blockIdx.x"): + for fused_s_v_1 in T.thread_binding(0, 1024, "threadIdx.x"): + with T.block("block"): + vs = T.axis.spatial(num_seq, (fused_s_v_0 * 1024 + fused_s_v_1) // vocab_size) + vv = T.axis.spatial(vocab_size, (fused_s_v_0 * 1024 + fused_s_v_1) % vocab_size) + T.where(fused_s_v_0 * 1024 + fused_s_v_1 < num_seq * vocab_size) + logits[seq_ids[vs], vv] = T.if_then_else( + (bitmask[vs, vv // 32] >> (vv % 32)) & 1 == 1, + logits[seq_ids[vs], vv], + T.float32(-1e10), + ) diff --git a/python/mlc_chat/compiler_pass/clean_up_tir_attrs.py b/python/mlc_chat/compiler_pass/clean_up_tir_attrs.py new file mode 100644 index 0000000..f7c9ad2 --- /dev/null +++ b/python/mlc_chat/compiler_pass/clean_up_tir_attrs.py @@ -0,0 +1,30 @@ +"""A compiler pass that cleans up undesired TIR attrs.""" +from typing import List + +import tvm +from tvm.ir.module import IRModule + + 
+@tvm.transform.module_pass(opt_level=0, name="CleanUpTIRAttrs") +class CleanUpTIRAttrs: # pylint: disable=too-few-public-methods + """A compiler pass that cleans up undesired TIR attrs.""" + + def __init__(self, attrs: List[str]): + self.attrs = attrs + + def transform_module( + self, + mod: IRModule, + _ctx: tvm.transform.PassContext, + ) -> IRModule: + """IRModule-level transformation""" + for g_var, func in mod.functions_items(): + changed = False + for attr in self.attrs: + if func.attrs is not None and attr in func.attrs: + func = func.without_attr(attr) + changed = True + break + if changed: + mod[g_var] = func + return mod diff --git a/python/mlc_chat/compiler_pass/cublas_dispatch.py b/python/mlc_chat/compiler_pass/cublas_dispatch.py new file mode 100644 index 0000000..2310486 --- /dev/null +++ b/python/mlc_chat/compiler_pass/cublas_dispatch.py @@ -0,0 +1,31 @@ +"""A compiler pass that dispatches patterns to CUBLAS.""" +import tvm +import tvm.relax.backend.contrib.cublas as _cublas +from tvm import IRModule, relax +from tvm.relax.backend import get_patterns_with_prefix + + +@tvm.transform.module_pass(opt_level=0, name="CublasDispatch") +class CublasDispatch: # pylint: disable=too-few-public-methods,broad-exception-raised + """A compiler pass that dispatches patterns to CUBLAS.""" + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """IRModule-level transformation""" + has_cublas = tvm.get_global_func("relax.ext.cublas", True) + if not has_cublas: + raise Exception("CUBLAS is not enabled.") + + patterns = get_patterns_with_prefix("cublas") + + model_names = [ + gv.name_hint for gv, func in mod.functions.items() if isinstance(func, relax.Function) + ] + mod = tvm.transform.Sequential( + [ + relax.transform.FuseOpsByPattern( + patterns, bind_constants=False, annotate_codegen=True + ), + relax.transform.RunCodegen({}, entry_functions=model_names), + ] + )(mod) + return mod diff --git a/python/mlc_chat/compiler_pass/estimate_memory_usage.py b/python/mlc_chat/compiler_pass/estimate_memory_usage.py new file mode 100644 index 0000000..f3ac747 --- /dev/null +++ b/python/mlc_chat/compiler_pass/estimate_memory_usage.py @@ -0,0 +1,84 @@ +"""Memory usage estimation analysis function for Relax functions.""" +import json +from typing import Any, Dict + +import tvm +from tvm import relax, tir +from tvm.ir import IRModule, Op +from tvm.relax.expr_functor import PyExprVisitor, visitor + +from mlc_chat.support import logging + +logger = logging.getLogger(__name__) + + +@tvm.transform.module_pass(opt_level=0, name="AttachMetadata") +class AttachMetadataWithMemoryUsage: # pylint: disable=too-few-public-methods + """Attach a Relax function that returns metadata in a JSON string""" + + def __init__(self, metadata: Dict[str, Any]): + self.metadata = metadata + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """Entrypoint""" + + def _emit_metadata(metadata): + bb = relax.BlockBuilder() # pylint: disable=invalid-name + with bb.function("main", params=[]): + bb.emit_func_output(relax.StringImm(json.dumps(metadata))) + return bb.finalize()["main"] + + self.metadata["memory_usage"] = _MemoryEstimator().run(mod) + mod["_metadata"] = _emit_metadata(self.metadata) + return mod + + +@visitor +class _MemoryEstimator(PyExprVisitor): + """The IR visitor which estimates the memory usage of each Relax function.""" + + def __init__(self) -> None: + self.planned_alloc_mem = 0 + self.planned_mem_num = 0 + self._op_alloc_tensor = 
Op.get("relax.builtin.alloc_tensor") + self._op_alloc_storage = Op.get("relax.memory.alloc_storage") + + def run(self, mod: IRModule) -> Dict[str, int]: + """Entry point of the visitor.""" + result: Dict[str, int] = {} + for global_var, func in mod.functions_items(): + if isinstance(func, relax.Function): + self.planned_alloc_mem = 0 + self.planned_mem_num = 0 + self.visit_expr(func) + result[global_var.name_hint] = self.planned_alloc_mem + logger.info( + "[Memory usage] Function `%s`: %.2f MB", + global_var.name_hint, + self.planned_alloc_mem / 1024 / 1024, + ) + return result + + def visit_call_(self, call: relax.Call) -> None: # pylint: disable=arguments-renamed + if call.op == self._op_alloc_tensor: + self._builtin_tensor_alloc(shape=call.args[0], dtype_str=call.args[1].value) + elif call.op == self._op_alloc_storage: + self._storage_alloc(size=call.args[0]) + super().visit_call_(call) + + def _builtin_tensor_alloc(self, shape: relax.Expr, dtype_str: str) -> None: + assert isinstance(shape, relax.ShapeExpr) + size = 1 + for dim_len in shape.values: + if not isinstance(dim_len, tvm.tir.IntImm): + return + size *= dim_len.value + dtype = tvm.DataType(dtype_str) + self.planned_mem_num += 1 + self.planned_alloc_mem += size * ((dtype.bits + 7) // 8) * dtype.lanes + + def _storage_alloc(self, size: relax.Expr) -> None: + assert isinstance(size, relax.ShapeExpr) + if isinstance(size.values[0], tir.IntImm): + self.planned_mem_num += 1 + self.planned_alloc_mem += size.values[0].value diff --git a/python/mlc_chat/compiler_pass/fuse_add_norm.py b/python/mlc_chat/compiler_pass/fuse_add_norm.py new file mode 100644 index 0000000..88ed1dc --- /dev/null +++ b/python/mlc_chat/compiler_pass/fuse_add_norm.py @@ -0,0 +1,211 @@ +"""A compiler pass that fuses add + rms_norm.""" + +import tvm +from tvm import relax +from tvm.relax.dpl import PatternContext, rewrite_bindings +from tvm.relax.dpl.pattern import is_op, wildcard +from tvm.script import tir as T + +# mypy: disable-error-code="attr-defined,valid-type" +# pylint: disable=too-many-locals,invalid-name + + +def _get_add_rms_norm_decode(hidden_size: int, eps: float, TX: int): + inv_hidden_size = T.float32(1.0 / float(hidden_size)) + eps = T.float32(eps) + add_local_size = hidden_size // TX + + @T.prim_func(private=True) + def decode_add_rms(pA: T.handle, pB: T.handle, pC: T.handle, pO: T.handle, pAdd: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + batch_size = T.int32() + A = T.match_buffer(pA, (batch_size, 1, hidden_size), "float16") + B = T.match_buffer(pB, (batch_size, 1, hidden_size), "float16") + C = T.match_buffer(pC, (hidden_size,), "float16") + O = T.match_buffer(pO, (batch_size, 1, hidden_size), "float16") + add = T.match_buffer(pAdd, (batch_size, 1, hidden_size), "float16") + add_local = T.alloc_buffer((hidden_size // TX,), "float16", scope="local") + sum_shared = T.alloc_buffer((batch_size, 1), scope="shared") + sum_local = T.alloc_buffer((TX, batch_size, 1), scope="local") + for v_bx in T.thread_binding(batch_size, thread="blockIdx.x"): + for v_tx in T.thread_binding( + TX, + thread="threadIdx.x", + annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}, + ): + for i in range(add_local_size): + with T.block("T_add"): + bx = T.axis.spatial(batch_size, v_bx) + h = T.axis.spatial(hidden_size, i * TX + v_tx) + add_local[h // TX] = A[bx, 0, h] + B[bx, 0, h] + with T.block("T_write_back"): + bx = T.axis.spatial(batch_size, v_bx) + v_ax1 = T.axis.spatial(1, 0) + h = 
T.axis.spatial(hidden_size, i * TX + v_tx) + add[bx, v_ax1, h] = add_local[h // TX] + with T.block("T_multiply_red_rf_init"): + tx, bx = T.axis.remap("SS", [v_tx, v_bx]) + sum_local[tx, bx, 0] = T.float32(0) + for v_i, _j in T.grid(add_local_size, 1): + with T.block("T_multiply_red_rf_update"): + tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i]) + sum_local[tx, bx, 0] += T.float32(add_local[i]) * T.float32(add_local[i]) + for _j in range(1): + for v_tx_2 in T.thread_binding(TX, thread="threadIdx.x"): + with T.block("T_multiply_red"): + tx, bx = T.axis.remap("RS", [v_tx_2, v_bx]) + T.reads(sum_local[tx, bx, 0]) + T.writes(sum_shared[bx, 0]) + with T.init(): + sum_shared[bx, 0] = T.float32(0) + sum_shared[bx, 0] += sum_local[tx, bx, 0] + for i in range(add_local_size): + for v_tx_2 in T.thread_binding(TX, thread="threadIdx.x"): + with T.block("T_cast_2"): + bx = T.axis.spatial(batch_size, v_bx) + h = T.axis.spatial(hidden_size, i * TX + v_tx_2) + O[bx, 0, h] = T.float16( + T.rsqrt(sum_shared[bx, 0] * inv_hidden_size + eps) + * T.float32(add_local[h // TX]) + * T.float32(C[h]) + ) + + return decode_add_rms + + +def _get_add_rms_norm_prefill(hidden_size: int, eps: float, TX: int): + inv_hidden_size = T.float32(1.0 / float(hidden_size)) + eps = T.float32(eps) + add_local_size = hidden_size // TX + + @T.prim_func(private=True) + def prefill_add_rms(pA: T.handle, pB: T.handle, pC: T.handle, pO: T.handle, pAdd: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + seq_len = T.int32() + A = T.match_buffer(pA, (1, seq_len, hidden_size), "float16") + B = T.match_buffer(pB, (1, seq_len, hidden_size), "float16") + C = T.match_buffer(pC, (hidden_size,), "float16") + O = T.match_buffer(pO, (1, seq_len, hidden_size), "float16") + add = T.match_buffer(pAdd, (1, seq_len, hidden_size), "float16") + add_local = T.alloc_buffer((hidden_size // TX,), "float16", scope="local") + sum_shared = T.alloc_buffer((1, seq_len), scope="shared") + sum_local = T.alloc_buffer((TX, 1, seq_len), scope="local") + for v_bx in T.thread_binding(seq_len, thread="blockIdx.x"): + for v_tx in T.thread_binding( + TX, + thread="threadIdx.x", + annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}, + ): + for v_i in range(add_local_size): + with T.block("T_add"): + bx = T.axis.spatial(seq_len, v_bx) + h = T.axis.spatial(hidden_size, v_i * TX + v_tx) + add_local[h // TX] = A[0, bx, h] + B[0, bx, h] + with T.block("T_write_back"): + bx = T.axis.spatial(seq_len, v_bx) + h = T.axis.spatial(hidden_size, v_i * TX + v_tx) + add[0, bx, h] = add_local[h // TX] + with T.block("T_multiply_red_rf_init"): + tx, bx = T.axis.remap("SS", [v_tx, v_bx]) + sum_local[tx, 0, bx] = T.float32(0) + for v_i, _j in T.grid(add_local_size, 1): + with T.block("T_multiply_red_rf_update"): + tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i]) + sum_local[tx, 0, bx] += T.float32(add_local[i]) * T.float32(add_local[i]) + for _j in range(1): + for v_tx_2 in T.thread_binding(TX, thread="threadIdx.x"): + with T.block("T_multiply_red"): + tx, bx = T.axis.remap("RS", [v_tx_2, v_bx]) + with T.init(): + sum_shared[0, bx] = T.float32(0) + sum_shared[0, bx] = sum_shared[0, bx] + sum_local[tx, 0, bx] + for v_i in range(add_local_size): + for v_tx_2 in T.thread_binding(TX, thread="threadIdx.x"): + with T.block("T_cast_2"): + bx = T.axis.spatial(seq_len, v_bx) + v1 = T.axis.spatial(hidden_size, v_i * TX + v_tx_2) + O[0, bx, v1] = T.float16( + T.rsqrt(sum_shared[0, bx] * inv_hidden_size + eps) + * T.float32(add_local[v1 // TX]) + * 
T.float32(C[v1]) + ) + + return prefill_add_rms + + +@tvm.transform.module_pass(opt_level=0, name="FuseAddRMSNorm") +class FuseAddRMSNorm: # pylint: disable=too-few-public-methods + """A compiler pass that fuses add + rms_norm.""" + + def __init__(self, target: tvm.target.Target) -> None: + """Initializer. + + Parameters + ---------- + target : tvm.target.Target + Target device. + """ + self.TX = 1024 # default + + if target.max_num_threads < self.TX: + self.TX = target.max_num_threads + + def transform_module(self, mod: tvm.IRModule, _ctx: tvm.transform.PassContext) -> tvm.IRModule: + """IRModule-level transformation.""" + with PatternContext() as ctx: + pat_x1 = wildcard() + pat_x2 = wildcard() + pat_y = is_op("relax.add")(pat_x1, pat_x2) + pat_w = wildcard() + pat_o = is_op("relax.nn.rms_norm")(pat_y, pat_w) + + def rewriter(matchings, bindings): + x1 = matchings[pat_x1] + x2 = matchings[pat_x2] + weight = matchings[pat_w] + y = matchings[pat_y] + o = matchings[pat_o] + eps = bindings[o].attrs.epsilon + if x1.struct_info.dtype != "float16": + return {} + n, _, h = x1.struct_info.shape + func_name = "fuse_add_norm_prefill" if n == 1 else "fuse_add_norm_decode" + + if all(gv.name_hint != func_name for gv in mod.functions): + h = int(h) + if h % self.TX != 0: + return {} + if n == 1: + func = _get_add_rms_norm_prefill(h, eps, self.TX) + else: + func = _get_add_rms_norm_decode(h, eps, self.TX) + mod[func_name] = func + gvar = mod.get_global_var(func_name) + relax.expr._update_struct_info( # pylint: disable=protected-access + gvar, + relax.FuncStructInfo.opaque_func(ret=relax.ObjectStructInfo()), + ) + else: + gvar = mod.get_global_var(func_name) + o_y_tuple = relax.call_tir( + gvar, + [x1, x2, weight], + out_sinfo=[x1.struct_info, x1.struct_info], + ) + return { + o: relax.TupleGetItem(o_y_tuple, 0), + y: relax.TupleGetItem(o_y_tuple, 1), + } + + new_mod = {} + for gvar, func in mod.functions.items(): + if isinstance(func, relax.Function): + func = rewrite_bindings(ctx, rewriter, func) + new_mod[gvar] = func + + for gvar, func in mod.functions.items(): + if isinstance(func, tvm.tir.PrimFunc) and gvar not in new_mod: + new_mod[gvar] = func + + new_mod = tvm.IRModule(new_mod, mod.type_definitions, mod.attrs, mod.global_infos) + return new_mod diff --git a/python/mlc_chat/compiler_pass/fuse_dequantize_matmul_ewise.py b/python/mlc_chat/compiler_pass/fuse_dequantize_matmul_ewise.py new file mode 100644 index 0000000..f8a64c8 --- /dev/null +++ b/python/mlc_chat/compiler_pass/fuse_dequantize_matmul_ewise.py @@ -0,0 +1,86 @@ +"""A compiler pass that fuses dequantize + matmul + elementwise.""" +import tvm +from tvm import IRModule, relax +from tvm.relax.dpl.pattern import GlobalVarPattern, TuplePattern, is_op, wildcard + + +@tvm.transform.module_pass(opt_level=0, name="FuseDequantizeMatmulEwise") +class FuseDequantizeMatmulEwise: # pylint: disable=too-few-public-methods + """A compiler pass that fuses dequantize + matmul + elementwise.""" + + def transform_module( + self, + mod: IRModule, + _ctx: tvm.transform.PassContext, + ) -> IRModule: + """IRModule-level transformation""" + seq = [] + for n_aux_tensor in [1, 2, 3, 4]: + for match_ewise in [0, 1, 2, 6]: + if match_ewise == 6 and n_aux_tensor != 4: + continue + seq.append( + relax.transform.FuseOpsByPattern( + [ + ( + "dequantize_matmul", + *_pattern(match_ewise, n_aux_tensor), + ) + ] + ) + ) + seq.append(relax.transform.FuseTIR()) + return tvm.transform.Sequential(seq)(mod) + + +def _pattern(match_ewise: int, n_aux_tensor: int): + # pylint: 
disable=invalid-name + w_scaled = wildcard() + x = wildcard() + w = is_op("relax.call_tir")( + GlobalVarPattern(), + TuplePattern([w_scaled] + [wildcard() for _ in range(n_aux_tensor)]), + add_constraint=False, + ) + matmul = is_op("relax.call_tir")( + GlobalVarPattern(), + TuplePattern([x, w] + [wildcard() for _ in range(match_ewise)]), + add_constraint=False, + ) + # pylint: enable=invalid-name + annotations = { + "w_scaled": w_scaled, + "x": x, + "w": w, + "matmul": matmul, + } + + def _check_decoding(ctx: relax.transform.PatternCheckContext) -> bool: + call = ctx.annotated_expr["w"] + if not isinstance(call, relax.Call): + return False + g_var = call.args[0] + if not isinstance(g_var, relax.GlobalVar): + return False + return g_var.name_hint.startswith("dequantize") or g_var.name_hint.startswith( + "fused_dequantize" + ) + + def _check_matmul(ctx: relax.transform.PatternCheckContext) -> bool: + call = ctx.annotated_expr["matmul"] + if not isinstance(call, relax.Call): + return False + g_var = call.args[0] + if not isinstance(g_var, relax.GlobalVar): + return False + return ( + g_var.name_hint.startswith("matmul") + or g_var.name_hint.startswith("fused_matmul") + or g_var.name_hint.startswith("NT_matmul") + or g_var.name_hint.startswith("fused_NT_matmul") + ) + + def _check(ctx: relax.transform.PatternCheckContext) -> bool: + return _check_decoding(ctx) and _check_matmul(ctx) + + return matmul, annotations, _check diff --git a/python/mlc_chat/compiler_pass/fuse_dequantize_take.py b/python/mlc_chat/compiler_pass/fuse_dequantize_take.py new file mode 100644 index 0000000..8079215 --- /dev/null +++ b/python/mlc_chat/compiler_pass/fuse_dequantize_take.py @@ -0,0 +1,90 @@ +"""A compiler pass that fuses dequantize + take.""" +import tvm +from tvm import IRModule, relax, tir +from tvm.relax.dpl.pattern import ( + GlobalVarPattern, + TuplePattern, + is_const, + is_op, + wildcard, +) + + +@tvm.transform.module_pass(opt_level=0, name="FuseDequantizeTake") +class FuseDequantizeTake: # pylint: disable=too-few-public-methods + """A compiler pass that fuses dequantize + take.""" + + def transform_module( + self, + mod: IRModule, + _ctx: tvm.transform.PassContext, + ) -> IRModule: + """IRModule-level transformation""" + seq = [] + for n_aux_tensor in [2, 3]: + for match_tir_vars in [False, True]: + seq.append( + relax.transform.FuseOpsByPattern( + [ + ( + "dequantize_take", + *_pattern(n_aux_tensor, match_tir_vars), + ) + ] + ) + ) + seq.append(relax.transform.FuseTIR()) + mod = tvm.transform.Sequential(seq)(mod) + for g_var, func in mod.functions_items(): + name = g_var.name_hint + if isinstance(func, tir.PrimFunc) and ( + ("fused_dequantize" in name) and ("take" in name) + ): + sch_mod = tvm.IRModule({"main": func}) + sch_mod = tir.transform.ForceNarrowIndexToInt32()(sch_mod) + sch = tir.Schedule(sch_mod) + sch.compute_inline("dequantize") + mod[g_var] = sch.mod["main"] + return mod + + +def _pattern(n_aux_tensor: int, match_tir_vars: bool): + dequantize = is_op("relax.call_tir")( + GlobalVarPattern(), + TuplePattern([wildcard() for _ in range(n_aux_tensor)]), + add_constraint=False, + ) + indices = ~is_const() + if match_tir_vars: + call_tir_args_take = [ + GlobalVarPattern(), + TuplePattern([dequantize, indices]), + wildcard(), + ] + else: + call_tir_args_take = [ + GlobalVarPattern(), + TuplePattern([dequantize, indices]), + ] + take = is_op("relax.call_tir")( + *call_tir_args_take, + add_constraint=False, + ) + annotations = { + "take": take, + "dequantize": dequantize, + "indices": indices, + } 
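+
+    # The pattern requires `indices` to be non-constant; `_check` below additionally
+    # verifies that the matched call_tir callees are PrimFuncs whose name hints
+    # contain "dequantize" and "take" respectively.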
+ + def _check(ctx: relax.transform.PatternCheckContext) -> bool: + take = ctx.annotated_expr["take"] + dequantize = ctx.annotated_expr["dequantize"] + if not isinstance(dequantize, relax.expr.Call): + return False + if not isinstance(take.args[0], relax.GlobalVar) or not isinstance( + dequantize.args[0], relax.GlobalVar + ): + return False + return "take" in take.args[0].name_hint and "dequantize" in dequantize.args[0].name_hint + + return take, annotations, _check diff --git a/python/mlc_chat/compiler_pass/fuse_dequantize_transpose.py b/python/mlc_chat/compiler_pass/fuse_dequantize_transpose.py new file mode 100644 index 0000000..d89f62c --- /dev/null +++ b/python/mlc_chat/compiler_pass/fuse_dequantize_transpose.py @@ -0,0 +1,106 @@ +"""A compiler pass that fuses transpose + dequantize.""" +import tvm +from tvm import relax, tir +from tvm.ir.module import IRModule +from tvm.relax.analysis import remove_all_unused +from tvm.relax.expr_functor import PyExprMutator, mutator + + +@tvm.transform.module_pass(opt_level=0, name="FuseDequantizeTranspose") +class FuseDequantizeTranspose: # pylint: disable=too-few-public-methods + """A compiler pass that fuses transpose + dequantize.""" + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """IRModule-level transformation""" + return _DequantizeTransposeFuser(mod).transform() + + +@mutator +class _DequantizeTransposeFuser(PyExprMutator): # pylint: disable=abstract-method + def __init__( + self, + mod: IRModule, + ): + super().__init__(mod) + self.mod = mod + + def transform(self) -> IRModule: + """Entry point""" + for g_var, func in self.mod.functions_items(): + if isinstance(func, relax.Function): + updated_func = self.visit_expr(func) + updated_func = remove_all_unused(updated_func) + self.builder_.update_func(g_var, updated_func) + return self.builder_.get() + + def visit_call_( # pylint: disable=arguments-renamed + self, + call: relax.Call, + ) -> relax.Expr: + call = self.visit_expr_post_order(call) + if call.op != tvm.ir.Op.get("relax.matmul"): + return call + # Do not fuse dequantize-transpose for GeMM + if ( + call.args[0].struct_info.ndim < 2 + or not isinstance(call.args[0].struct_info.shape[-2], tir.IntImm) + or call.args[0].struct_info.shape[-2].value != 1 + ): + return call + + matmul_rhs = self.lookup_binding(call.args[1]) + if ( + not isinstance(matmul_rhs, relax.Call) + or matmul_rhs.op != tvm.ir.Op.get("relax.permute_dims") + or matmul_rhs.args[0].struct_info.ndim != 2 + or matmul_rhs.attrs.axes is not None + ): + return call + + transpose_input = self.lookup_binding(matmul_rhs.args[0]) + if ( + not isinstance(transpose_input, relax.Call) + or transpose_input.op != tvm.ir.Op.get("relax.call_tir") + or not transpose_input.args[0].name_hint.startswith("dequantize") + or not isinstance(transpose_input.struct_info, relax.TensorStructInfo) + ): + return call + + dequantize_tir_func = self.mod[transpose_input.args[0]] + assert isinstance(dequantize_tir_func, tir.PrimFunc) + if ( # pylint: disable=too-many-boolean-expressions + len(dequantize_tir_func.body.block.alloc_buffers) != 1 + or not isinstance(dequantize_tir_func.body.block.body, tir.SeqStmt) + or len(dequantize_tir_func.body.block.body) != 2 + or not isinstance(dequantize_tir_func.body.block.body[1], tir.For) + or not isinstance(dequantize_tir_func.body.block.body[1].body.body, tir.BlockRealize) + or dequantize_tir_func.body.block.body[1].body.body.block.name_hint != "T_transpose" + ): + return call + + new_func_buffers = [ + 
dequantize_tir_func.buffer_map[var] for var in dequantize_tir_func.params + ] + new_func_buffers[-1] = dequantize_tir_func.body.block.alloc_buffers[0] + new_func = tir.PrimFunc( + params=new_func_buffers, + body=tir.BlockRealize( + iter_values=[], + predicate=True, + block=tir.Block( + iter_vars=[], + reads=[], + writes=[], + name_hint="root", + body=dequantize_tir_func.body.block.body[0], + ), + ), + ) + # Call `renew_defs` for deep-copy to avoid IR node duplication in + # different PrimFuncs of an IRModule. + new_func = tir.stmt_functor.renew_defs(new_func) + g_var = self.builder_.add_func(new_func, func_name="dequantize") + dequantize_matmul_rhs = self.builder_.emit( + relax.call_tir(g_var, transpose_input.args[1], out_sinfo=matmul_rhs.struct_info) + ) + return relax.op.matmul(call.args[0], dequantize_matmul_rhs, out_dtype=call.attrs.out_dtype) diff --git a/python/mlc_chat/compiler_pass/fuse_ft_dequantize_matmul_epilogue.py b/python/mlc_chat/compiler_pass/fuse_ft_dequantize_matmul_epilogue.py new file mode 100644 index 0000000..c5a4094 --- /dev/null +++ b/python/mlc_chat/compiler_pass/fuse_ft_dequantize_matmul_epilogue.py @@ -0,0 +1,322 @@ +"""A compiler pass that fuses dequantize matmul + epilogue.""" +import operator +from functools import reduce + +import tvm +from tvm import IRModule, relax +from tvm.relax.dpl import rewrite_call +from tvm.relax.dpl.pattern import is_op, wildcard + + +@tvm.transform.module_pass(opt_level=0, name="FuseDequantizeEpilogue") +class FuseFTDequantizeEpilogue: # pylint: disable=too-few-public-methods + """A compiler pass that fuses FasterTransformer dequantize matmul + epilogue.""" + + def transform_module( + self, + mod: IRModule, + _ctx: tvm.transform.PassContext, + ) -> IRModule: + """IRModule-level transformation""" + for gv, func in mod.functions_items(): + if isinstance(func, relax.Function): + func = fuse_bias(func) + func = fuse_activation(func) + func = fuse_residual_binary(func) + func = fuse_residual_unary(func) + mod[gv] = func + return mod + + +def fuse_bias(func: relax.Function) -> relax.Function: + """ + Fuse following `relax.add` into fastertransformer.gemm_fp16_int as bias: + + Before: + ``` + lv1 = relax.call_dps_packed("fastertransformer.gemm_fp16_int", ...) + lv2 = relax.add(lv1, bias) + + ``` + After: + ``` + lv2 = relax.call_dps_packed("fastertransformer.gemm_fp16_int_bias", ..., bias, ...) + ``` + + Parameters + ---------- + func : relax.Function + The function before fusion. + + Returns + ------- + ret : relax.Function + The function after fusion. 
+ """ + decode_matmul = is_op("relax.call_dps_packed")(varg_default_wildcard=True) + bias = wildcard() + pattern = is_op("relax.add")(decode_matmul, bias) | is_op("relax.add")(bias, decode_matmul) + + def rewriter(expr, match): + if match[decode_matmul].args[0].global_symbol == "fastertransformer.gemm_fp16_int": + assert len(match[decode_matmul].args) == 2 + args_list = match[decode_matmul].args[1] + assert len(args_list) == 8 + if not args_list[3].value == "identity": + # bias cannot be fused after activation + return expr + matched_bias = match[bias] + bias_stride = ( + matched_bias.struct_info.shape[-1] + if bias + and not reduce(operator.mul, matched_bias.struct_info.shape, 1) + == matched_bias.struct_info.shape[-1] + else 0 + ) + return relax.call_dps_packed( + "fastertransformer.gemm_fp16_int_bias", + [ + args_list[0], # x + args_list[1], # weight + args_list[2], # scale + matched_bias, # bias + args_list[3], # activation + args_list[4], # m + args_list[5], # n + args_list[6], # k + args_list[7], # group_size + bias_stride, # bias_stride + ], + out_sinfo=match[decode_matmul].struct_info, + ) + return expr + + return rewrite_call(pattern, rewriter, func) + + +def fuse_activation(func: relax.Function) -> relax.Function: + """ + Fuse following `relax.nn.silu/relu/gelu` into fastertransformer.gemm_fp16_int_bias + as activation: + + Before: + ``` + lv1 = relax.call_dps_packed("fastertransformer.gemm_fp16_int_bias", ...) + lv2 = relax.silu(lv1) + + ``` + After: + ``` + lv2 = relax.call_dps_packed("fastertransformer.gemm_fp16_int_bias", ..., "silu", ...) + ``` + + Parameters + ---------- + func : relax.Function + The function before fusion. + + Returns + ------- + ret : relax.Function + The function after fusion. + """ + # pylint: disable=unsupported-binary-operation + decode_matmul = is_op("relax.call_dps_packed")(varg_default_wildcard=True) + pattern = ( + is_op("relax.nn.silu")(decode_matmul) + | is_op("relax.nn.gelu")(decode_matmul) + | is_op("relax.nn.relu")(decode_matmul) + ) + + def rewriter(expr, match): + if match[decode_matmul].args[0].global_symbol == "fastertransformer.gemm_fp16_int": + matched_activation = match[pattern] + assert matched_activation.op.name in ["relax.nn.silu", "relax.nn.gelu", "relax.nn.relu"] + assert len(match[decode_matmul].args) == 2 + args_list = match[decode_matmul].args[1] + assert len(args_list) == 8 + return relax.call_dps_packed( + "fastertransformer.gemm_fp16_int", + [ + args_list[0], # x + args_list[1], # weight + args_list[2], # scale + matched_activation.op.name[9:], # activation + args_list[4], # m + args_list[5], # n + args_list[6], # k + args_list[7], # group_size + ], + out_sinfo=match[decode_matmul].struct_info, + ) + if match[decode_matmul].args[0].global_symbol == "fastertransformer.gemm_fp16_int_bias": + matched_activation = match[pattern] + assert matched_activation.op.name in ["relax.nn.silu", "relax.nn.gelu", "relax.nn.relu"] + assert len(match[decode_matmul].args) == 2 + args_list = match[decode_matmul].args[1] + assert len(args_list) == 10 + return relax.call_dps_packed( + "fastertransformer.gemm_fp16_int_bias", + [ + args_list[0], # x + args_list[1], # weight + args_list[2], # scale + args_list[3], # bias + matched_activation.op.name[9:], # activation + args_list[5], # m + args_list[6], # n + args_list[7], # k + args_list[8], # group_size + args_list[9], # bias_stride + ], + out_sinfo=match[decode_matmul].struct_info, + ) + return expr + + return rewrite_call(pattern, rewriter, func) + + +def fuse_residual_binary(func: 
relax.Function) -> relax.Function: + """ + Fuse following `relax.add/multiply` into fastertransformer.gemm_fp16_int_bias as + residual binary operation: + + Before: + ``` + lv1 = relax.call_dps_packed("fastertransformer.gemm_fp16_int_bias", ...) + lv2 = relax.add(lv1, residual) + + ``` + After: + ``` + lv2 = relax.call_dps_packed( + "fastertransformer.gemm_fp16_int_bias_residual", + ..., + residual, + ..., + "plus", + ... + ) + ``` + + Parameters + ---------- + func : relax.Function + The function before fusion. + + Returns + ------- + ret : relax.Function + The function after fusion. + """ + # pylint: disable=unsupported-binary-operation + decode_matmul = is_op("relax.call_dps_packed")(varg_default_wildcard=True) + residual = wildcard() + pattern = ( + is_op("relax.add")(decode_matmul, residual) + | is_op("relax.add")(residual, decode_matmul) + | is_op("relax.multiply")(decode_matmul, residual) + | is_op("relax.multiply")(residual, decode_matmul) + ) + + def rewriter(expr, match): + if match[decode_matmul].args[0].global_symbol == "fastertransformer.gemm_fp16_int_bias": + matched_binary = match[pattern] + assert matched_binary.op.name in ["relax.add", "relax.multiply"] + binary_op = "plus" if matched_binary.op.name == "relax.add" else "multiply" + assert len(match[decode_matmul].args) == 2 + args_list = match[decode_matmul].args[1] + assert len(args_list) == 10 + matched_residual = match[residual] + if not args_list[9].value == 0: + # fastertransformer.gemm_fp16_int_bias_residual does not support + # bias_stride != 0 yet + return expr + return relax.call_dps_packed( + "fastertransformer.gemm_fp16_int_bias_residual", + [ + args_list[0], # x + args_list[1], # weight + args_list[2], # scale + args_list[3], # bias + matched_residual, # residual + args_list[4], # activation + binary_op, # binary_op + "identity", # unary_op + args_list[5], # m + args_list[6], # n + args_list[7], # k + args_list[8], # group_size + ], + out_sinfo=match[decode_matmul].struct_info, + ) + return expr + + return rewrite_call(pattern, rewriter, func) + + +def fuse_residual_unary(func: relax.Function) -> relax.Function: + """ + Fuse following `relax.nn.silu/relu/gelu` into fastertransformer.gemm_fp16_int_bias_residual + as residual unary operation: + + Before: + ``` + lv1 = relax.call_dps_packed("fastertransformer.gemm_fp16_int_bias_residual", ...) + lv2 = relax.silu(lv1) + + ``` + After: + ``` + lv2 = relax.call_dps_packed("fastertransformer.gemm_fp16_int_bias_residual", ..., "silu", ...) + ``` + + Parameters + ---------- + func : relax.Function + The function before fusion. + + Returns + ------- + ret : relax.Function + The function after fusion. 
+ """ + # pylint: disable=unsupported-binary-operation + decode_matmul = is_op("relax.call_dps_packed")(varg_default_wildcard=True) + pattern = ( + is_op("relax.nn.silu")(decode_matmul) + | is_op("relax.nn.gelu")(decode_matmul) + | is_op("relax.nn.relu")(decode_matmul) + ) + + def rewriter(expr, match): + if ( + match[decode_matmul].args[0].global_symbol + == "fastertransformer.gemm_fp16_int_bias_residual" + ): + matched_activation = match[pattern] + assert matched_activation.op.name in ["relax.nn.silu", "relax.nn.gelu", "relax.nn.relu"] + assert len(match[decode_matmul].args) == 2 + args_list = match[decode_matmul].args[1] + assert len(args_list) == 12 + return relax.call_dps_packed( + "fastertransformer.gemm_fp16_int_bias_residual", + [ + args_list[0], # x + args_list[1], # weight + args_list[2], # scale + args_list[3], # bias + args_list[4], # residual + args_list[5], # activation + args_list[6], # binary_op + matched_activation.op.name[9:], # activation + args_list[8], # m + args_list[9], # n + args_list[10], # k + args_list[11], # group_size + ], + out_sinfo=match[decode_matmul].struct_info, + ) + return expr + + return rewrite_call(pattern, rewriter, func) diff --git a/python/mlc_chat/compiler_pass/fuse_transpose_matmul.py b/python/mlc_chat/compiler_pass/fuse_transpose_matmul.py new file mode 100644 index 0000000..5b3ecec --- /dev/null +++ b/python/mlc_chat/compiler_pass/fuse_transpose_matmul.py @@ -0,0 +1,151 @@ +"""A compiler pass that fuses transpose + matmul.""" +import tvm +from tvm import IRModule, relax, te, tir +from tvm.relax.dpl.pattern import is_op, wildcard +from tvm.relax.expr_functor import PyExprMutator, mutator + + +@tvm.transform.module_pass(opt_level=0, name="FuseTransposeMatmul") +class FuseTransposeMatmul: # pylint: disable=too-few-public-methods + """A compiler pass that fuses transpose + matmul.""" + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """IRModule-level transformation""" + mod = relax.transform.FuseOpsByPattern( + [ + ( + "transpose_matmul_fuse", + *_pattern(), + ), + ] + )(mod) + transpose_matmul_codegen = _TransposeMatmulFuser(mod) + for g_var, func in mod.functions_items(): + if isinstance(func, relax.Function): + func = transpose_matmul_codegen.visit_expr(func) + transpose_matmul_codegen.builder_.update_func(g_var, func) + return transpose_matmul_codegen.builder_.get() + + +def _pattern(): + """Pattern for transpose + matmul.""" + # pylint: disable=invalid-name + w = wildcard() + x = wildcard() + wT = is_op("relax.permute_dims")(w) + o = is_op("relax.matmul")(x, wT) + # pylint: enable=invalid-name + annotations = {"o": o, "w": w, "x": x, "wT": wT} + + def _check(context: relax.transform.PatternCheckContext) -> bool: + transpose_call = context.annotated_expr["wT"] + ndim = transpose_call.args[0].struct_info.ndim + if ndim == -1: + return False + if ndim == 2 and transpose_call.attrs.axes is None: + return True + axes = list(range(ndim)) + axes[-1], axes[-2] = axes[-2], axes[-1] + return list(transpose_call.attrs.axes) == axes + + return o, annotations, _check + + +# pylint: disable=missing-docstring,invalid-name + + +@mutator +class _TransposeMatmulFuser(PyExprMutator): # pylint: disable=abstract-method + def __init__(self, mod): + super().__init__(mod) + + def visit_call_( # pylint: disable=arguments-renamed + self, + call: relax.Call, + ) -> relax.Expr: + out_dtype = None + + def te_transposed_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor: + nonlocal out_dtype + a_shape = list(a.shape) + b_shape = 
list(b.shape) + a_prepended = False + b_appended = False + if len(a_shape) == 1: + a_prepended = True + a_shape.insert(0, 1) + if len(b_shape) == 1: + b_appended = True + b_shape.append(1) + + is_a_larger = len(a_shape) > len(b_shape) + offset = len(a_shape) - len(b_shape) if is_a_larger else len(b_shape) - len(a_shape) + + a_relax = relax.Var("a", relax.TensorStructInfo(a.shape)) + bT_shape = list(b.shape) + bT_shape[-1], bT_shape[-2] = bT_shape[-2], bT_shape[-1] + bT_relax = relax.Var("b", relax.TensorStructInfo(bT_shape)) + output_shape = self.builder_.normalize( + relax.op.matmul(a_relax, bT_relax) + ).struct_info.shape + + def matmul_compute(*idx_spatial): + k = te.reduce_axis((0, a_shape[-1]), name="k") + + def multiply_compute(idx_reduce): + a_indices = [] + b_indices = [] + + for i in range(offset): + if is_a_larger: + a_indices.append(idx_spatial[i]) + else: + b_indices.append(idx_spatial[i]) + for i in range(offset, len(output_shape) - (2 - a_prepended - b_appended)): + a_dim = a_shape[i if is_a_larger else i - offset] + b_dim = b_shape[i if not is_a_larger else i - offset] + dim_equal = a_dim == b_dim + if not isinstance(dim_equal, tir.IntImm) or dim_equal == 0: + a_dim_is_one = isinstance(a_dim, tir.IntImm) and a_dim == 1 + b_dim_is_one = isinstance(b_dim, tir.IntImm) and b_dim == 1 + a_indices.append(0 if a_dim_is_one else idx_spatial[i]) + b_indices.append(0 if b_dim_is_one else idx_spatial[i]) + else: + a_indices.append(idx_spatial[i]) + b_indices.append(idx_spatial[i]) + + if not a_prepended: + a_indices.append(idx_spatial[-2 + b_appended]) + a_indices.append(idx_reduce) + if not b_appended: + b_indices.append(idx_spatial[-1]) + b_indices.append(idx_reduce) + + dtype = out_dtype + if dtype != "": + return a(*a_indices).astype(dtype) * b(*b_indices).astype(dtype) + return a(*a_indices) * b(*b_indices) + + return te.sum(multiply_compute(k), axis=k) + + return te.compute( + output_shape, + lambda *idx: matmul_compute(*idx), # pylint: disable=unnecessary-lambda + name="NT_matmul", + ) + + if isinstance(call.op, relax.GlobalVar): + function = self.builder_.get()[call.op] + if ( + "Composite" in function.attrs + and function.attrs["Composite"] == "transpose_matmul_fuse" + ): + out_dtype = function.ret_struct_info.dtype + return self.builder_.call_te( + te_transposed_matmul, + call.args[1], + call.args[0], + primfunc_name_hint="NT_matmul", + ) + + return super().visit_call_(call) diff --git a/python/mlc_chat/compiler_pass/lift_global_buffer_alloc.py b/python/mlc_chat/compiler_pass/lift_global_buffer_alloc.py new file mode 100644 index 0000000..bf709bc --- /dev/null +++ b/python/mlc_chat/compiler_pass/lift_global_buffer_alloc.py @@ -0,0 +1,195 @@ +"""A compiler pass that lifts TIR-level global allocation to Relax.""" +from typing import Dict, List, Tuple + +import tvm +from tvm import relax, tir +from tvm.ir.module import IRModule +from tvm.relax.analysis import remove_all_unused +from tvm.relax.expr_functor import PyExprMutator, mutator + + +@tvm.transform.module_pass(opt_level=0, name="LiftTIRGlobalBufferAlloc") +class LiftTIRGlobalBufferAlloc: # pylint: disable=too-few-public-methods + """A compiler pass that lifts TIR-level global allocation to Relax.""" + + def transform_module( + self, + mod: IRModule, + _ctx: tvm.transform.PassContext, + ) -> IRModule: + """IRModule-level transformation""" + return _TIRGlobalAllocRewriter(mod).transform() + + +@mutator +class _TIRGlobalAllocRewriter(PyExprMutator): # pylint: disable=abstract-method + def __init__(self, mod: IRModule): + 
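+        # `gv2new_tensor_sinfo` (set up below) maps the GlobalVar of each rewritten
+        # PrimFunc to the struct info of its lifted global buffers plus the original
+        # PrimFunc, so that `visit_call_` can patch the corresponding Relax call sites
+        # and fall back to the original function when TIR var mapping cannot be resolved.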
super().__init__(mod) + self.mod = mod + self.gv2new_tensor_sinfo: Dict[ + tvm.ir.GlobalVar, Tuple[List[relax.TensorStructInfo], tir.PrimFunc] + ] = {} + + def transform(self) -> IRModule: + """Entry point of the transformation""" + for g_var, func in self.mod.functions_items(): + if isinstance(func, tir.PrimFunc): + updated_func, tensor_sinfo_list = remove_global_buf_alloc(func) + if len(tensor_sinfo_list) > 0: + self.gv2new_tensor_sinfo[g_var] = (tensor_sinfo_list, func) + self.builder_.update_func(g_var, updated_func) + + self.mod = self.builder_.get() + for g_var, func in self.mod.functions_items(): + if isinstance(func, relax.Function): + updated_func = self.visit_expr(func) + updated_func = remove_all_unused(updated_func) + self.builder_.update_func(g_var, updated_func) + return self.builder_.get() + + def visit_call_(self, call: relax.Call): # pylint: disable=arguments-renamed + call = self.visit_expr_post_order(call) + if ( + call.op != tvm.ir.Op.get("relax.call_tir") + or call.args[0] not in self.gv2new_tensor_sinfo + ): + return call + + g_var = call.args[0] + tensor_sinfo, func_before_update = self.gv2new_tensor_sinfo[g_var] + + assert len(call.sinfo_args) == 1 + if any(_has_symbolic_var(sinfo) for sinfo in tensor_sinfo): + tensor_sinfo, success = _resolve_tir_var_mapping(func_before_update, call, tensor_sinfo) + if not success: + # Cannot resolve TIR var mapping. Fall back to no lifting. + self.builder_.update_func(g_var, func_before_update) + self.gv2new_tensor_sinfo.pop(g_var) + return call + + if isinstance(call.sinfo_args[0], relax.TensorStructInfo): + new_call = relax.Call( + call.op, + args=call.args, + sinfo_args=[relax.TupleStructInfo(list(call.sinfo_args) + tensor_sinfo)], + attrs=call.attrs, + ) + emitted_tuple = self.builder_.emit(new_call) + return relax.TupleGetItem(emitted_tuple, 0) + assert isinstance(call.sinfo_args[0], relax.TupleStructInfo) + return relax.Call( + call.op, + args=call.args, + sinfo_args=[relax.TupleStructInfo(list(call.sinfo_args[0].fields) + tensor_sinfo)], + attrs=call.attrs, + ) + + +def remove_global_buf_alloc( + func: tir.PrimFunc, +) -> Tuple[tir.PrimFunc, List[relax.TensorStructInfo]]: + """Remove the global buffer allocation for a given TIR PrimFunc.""" + assert isinstance(func.body, tir.BlockRealize) + params = list(func.params) + buffer_map = dict(func.buffer_map) + tensor_sinfo = [] + alloc_buffers = [] + + insertion_point = len(params) + while params[insertion_point - 1].dtype != "handle": + insertion_point -= 1 + assert insertion_point >= 1 + + prev_root_block = func.body.block + for buf_alloc in func.body.block.alloc_buffers: + if buf_alloc.scope() == "global": + param = tir.Var("var_" + buf_alloc.name, "handle") + params.insert(insertion_point, param) + insertion_point += 1 + buffer_map[param] = buf_alloc + tensor_sinfo.append(relax.TensorStructInfo(buf_alloc.shape, buf_alloc.dtype)) + else: + alloc_buffers.append(buf_alloc) + + if len(tensor_sinfo) == 0: + return func, [] + + assert len(prev_root_block.iter_vars) == 0 + assert len(prev_root_block.reads) == 0 + assert len(prev_root_block.writes) == 0 + assert len(prev_root_block.match_buffers) == 0 + assert prev_root_block.name_hint == "root" + assert prev_root_block.init is None + root_block = tir.Block( + iter_vars=[], + reads=[], + writes=[], + name_hint="root", + body=prev_root_block.body, + alloc_buffers=alloc_buffers, + annotations=prev_root_block.annotations, + ) + + updated_func = tir.PrimFunc( + params=params, + body=tir.BlockRealize(iter_values=[], predicate=True, 
block=root_block), + ret_type=func.ret_type, + buffer_map=buffer_map, + attrs=func.attrs, + ) + return updated_func, tensor_sinfo + + +def _has_symbolic_var(tensor_sinfo: relax.TensorStructInfo) -> bool: + assert isinstance(tensor_sinfo.shape, relax.ShapeExpr) + for dim in tensor_sinfo.shape.values: + if not isinstance(dim, tir.IntImm): + return True + return False + + +def _resolve_tir_var_mapping( # pylint: disable=too-many-locals + func: tir.PrimFunc, + call: relax.Call, + tensor_sinfo: List[relax.TensorStructInfo], +) -> Tuple[List[relax.TensorStructInfo], bool]: + """Resolve the TIR symbolic var relationship across sides of PrimFunc and Relax Function""" + var_map: Dict[tir.Var, tir.PrimExpr] = {} + + n_arg = len(call.args[1].fields) + for i in range(n_arg): + buffer_shape = func.buffer_map[func.params[i]].shape + arg_shape = call.args[1][i].struct_info.shape.values + assert len(buffer_shape) == len(arg_shape) + for v_l, v_r in zip(buffer_shape, arg_shape): + if isinstance(v_l, tir.Var): + var_map[v_l] = v_r + elif not isinstance(v_l, tir.IntImm): + return [], False + + ret_tensors = call.sinfo_args[0] + ret_tensors = ( + [ret_tensors] # type: ignore[assignment] + if isinstance(ret_tensors, relax.TensorStructInfo) + else list(ret_tensors.fields) + ) + for i, ret_tensor in enumerate(ret_tensors): + buffer_shape = func.buffer_map[func.params[n_arg + i]].shape + ret_tensor_shape = ret_tensor.shape.values + assert len(buffer_shape) == len(ret_tensor_shape) + for v_l, v_r in zip(buffer_shape, ret_tensor_shape): + if isinstance(v_l, tir.Var): + var_map[v_l] = v_r + elif not isinstance(v_l, tir.IntImm): + return [], False + + updated_tensor_sinfo = [] + for sinfo in tensor_sinfo: + if not _has_symbolic_var(sinfo): + updated_tensor_sinfo.append(sinfo) + continue + new_shape = [] + for dim in sinfo.shape.values: + new_shape.append(tir.stmt_functor.substitute(dim, var_map)) + updated_tensor_sinfo.append(relax.TensorStructInfo(new_shape, sinfo.dtype)) + return updated_tensor_sinfo, True diff --git a/python/mlc_chat/compiler_pass/pipeline.py b/python/mlc_chat/compiler_pass/pipeline.py new file mode 100644 index 0000000..98922c6 --- /dev/null +++ b/python/mlc_chat/compiler_pass/pipeline.py @@ -0,0 +1,160 @@ +"""The compilation pipeline for LLM applications.""" + +from pathlib import Path +from typing import Any, Dict, List, Optional + +import tvm +from tvm import IRModule +from tvm import dlight as dl +from tvm.relax import register_pipeline # pylint: disable=no-name-in-module +from tvm.relax.frontend import nn + +from mlc_chat.support import logging + +from .attach_to_ir_module import ( + AttachAdditionalPrimFuncs, + AttachLogitProcessFunc, + AttachMemoryPlanAttr, + AttachVariableBounds, +) +from .clean_up_tir_attrs import CleanUpTIRAttrs +from .cublas_dispatch import CublasDispatch +from .estimate_memory_usage import AttachMetadataWithMemoryUsage +from .fuse_add_norm import FuseAddRMSNorm +from .fuse_dequantize_matmul_ewise import FuseDequantizeMatmulEwise +from .fuse_dequantize_take import FuseDequantizeTake +from .fuse_dequantize_transpose import FuseDequantizeTranspose +from .fuse_ft_dequantize_matmul_epilogue import FuseFTDequantizeEpilogue +from .fuse_transpose_matmul import FuseTransposeMatmul +from .lift_global_buffer_alloc import LiftTIRGlobalBufferAlloc +from .rewrite_kv_cache_creation import RewriteKVCacheCreation +from .scatter_tuple_get_item import ScatterTupleGetItem + +logger = logging.getLogger(__name__) + + +@tvm.transform.module_pass(opt_level=0, name="_LogProgress") +class 
_LogProgress: # pylint: disable=too-few-public-methods + """A dummy compiler pass that does nothing but logging.""" + + def __init__(self, *args): + self.args = args + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """A dummy transformation""" + logger.info(*self.args) + return mod + + +@tvm.transform.module_pass(opt_level=0, name="DebugDump") +class _DebugDump: # pylint: disable=too-few-public-methods + """A dummy compiler pass that does nothing but logging. + Only enabled when debug_dump is not None""" + + def __init__(self, file_name: str, file_path: Optional[Path], show_meta: bool = False): + self.file_name = file_name + self.file_path = file_path + self.show_meta = show_meta + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """A dummy transformation that dumps the module to file""" + if self.file_path is not None: + # NOTE: We use debug level here to avoid spamming the console + logger.debug("Dumping IR to %s", self.file_path / self.file_name) + with open(self.file_path / self.file_name, "w", encoding="utf-8") as f: + f.write(mod.script(show_meta=self.show_meta)) + return mod + + +@register_pipeline("mlc_llm") +def _mlc_llm_pipeline( # pylint: disable=too-many-arguments + target: tvm.target.Target, + flashinfer: bool = False, + cublas_gemm: bool = False, + faster_transformer: bool = False, # pylint: disable=unused-argument + variable_bounds: Dict[str, int] = None, + additional_tirs: Dict[str, tvm.tir.PrimFunc] = None, + metadata: Dict[str, Any] = None, + ext_mods: List[nn.ExternModule] = None, + debug_dump: Optional[Path] = None, +): + variable_bounds = variable_bounds or {} + additional_tirs = additional_tirs or {} + metadata = metadata or {} + ext_mods = ext_mods or [] + + @tvm.transform.module_pass(opt_level=0) + def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule: + seq = tvm.transform.Sequential( + [ + # Phase 0. Add additional information for compilation and remove unused Relax func + RewriteKVCacheCreation(target, flashinfer, metadata), + AttachVariableBounds(variable_bounds), + AttachLogitProcessFunc(), + AttachAdditionalPrimFuncs(additional_tirs), + AttachMemoryPlanAttr(), + tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False)), + _DebugDump("debug-phase0.py", debug_dump, show_meta=False), + # Phase 1. Passes on high-level operator graph + _LogProgress("Running TVM Relax graph-level optimizations"), + FuseFTDequantizeEpilogue(), + FuseDequantizeTranspose(), + CublasDispatch() if cublas_gemm else tvm.transform.Sequential([]), + FuseAddRMSNorm(target=target), + FuseTransposeMatmul(), + _DebugDump("debug-phase1.py", debug_dump, show_meta=False), + # Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline + _LogProgress("Lowering to TVM TIR kernels"), + tvm.relax.transform.LegalizeOps(), + tvm.relax.transform.AnnotateTIROpPattern(), + tvm.relax.transform.FoldConstant(), + tvm.relax.transform.FuseOps(), + tvm.relax.transform.FuseTIR(), + _DebugDump("debug-phase2.py", debug_dump, show_meta=False), + # Phase 3. Passes on TIR + _LogProgress("Running TVM TIR-level optimizations"), + FuseDequantizeMatmulEwise(), + FuseDequantizeTake(), + tvm.relax.transform.DeadCodeElimination(), + CleanUpTIRAttrs(["op_pattern"]), + _DebugDump("debug-phase3.py", debug_dump, show_meta=False), + # Phase 4. 
Low-level Optimizations + _LogProgress("Running TVM Dlight low-level optimizations"), + dl.ApplyDefaultSchedule( + dl.gpu.Matmul(), + dl.gpu.GEMV(), + dl.gpu.Reduction(), + dl.gpu.GeneralReduction(), + dl.gpu.Fallback(), + ), + _DebugDump("debug-phase4.py", debug_dump, show_meta=False), + _LogProgress("Lowering to VM bytecode"), + LiftTIRGlobalBufferAlloc(), + ( + tvm.tir.transform.ForceNarrowIndexToInt32() + if target.kind.name != "cuda" + else tvm.transform.Sequential([]) + ), + ScatterTupleGetItem(), + tvm.relax.transform.RewriteDataflowReshape(), + tvm.relax.transform.ToNonDataflow(), + tvm.relax.transform.RemovePurityChecking(), + tvm.relax.transform.CallTIRRewrite(), + tvm.relax.transform.StaticPlanBlockMemory(), + AttachMetadataWithMemoryUsage(metadata), + tvm.relax.transform.RewriteCUDAGraph(), + tvm.relax.transform.LowerAllocTensor(), + tvm.relax.transform.KillAfterLastUse(), + tvm.relax.transform.VMBuiltinLower(), + tvm.relax.transform.VMShapeLower(), + tvm.relax.transform.AttachGlobalSymbol(), + _DebugDump("debug-final.py", debug_dump, show_meta=False), + _LogProgress("Compiling external modules"), + tvm.relax.transform.AttachExternModules(ext_mods), + _LogProgress("Compilation complete! Exporting to disk"), + ] + ) + mod = seq(mod) + return mod + + return _pipeline diff --git a/python/mlc_chat/compiler_pass/rewrite_kv_cache_creation.py b/python/mlc_chat/compiler_pass/rewrite_kv_cache_creation.py new file mode 100644 index 0000000..808969e --- /dev/null +++ b/python/mlc_chat/compiler_pass/rewrite_kv_cache_creation.py @@ -0,0 +1,154 @@ +"""A pass that rewrites KV cache creation functions in IRModule.""" + +from typing import Any, Dict + +import tvm +from tvm import IRModule, relax + +from mlc_chat.nn import RopeMode, kv_cache + + +def extract_creation_args(func: relax.Function) -> Dict[str, Any]: + """Extract the KV cache creation args from the given generic creation func.""" + assert isinstance(func.body, relax.SeqExpr) + assert len(func.body.blocks) == 1 + assert isinstance(func.body.blocks[0], relax.DataflowBlock) + assert len(func.body.blocks[0].bindings) == 2 + assert isinstance(func.body.blocks[0].bindings[0], relax.VarBinding) + assert isinstance(func.body.blocks[0].bindings[0].value, relax.Call) + assert isinstance(func.body.blocks[0].bindings[0].value.op, relax.ExternFunc) + assert ( + func.body.blocks[0].bindings[0].value.op.global_symbol + == "mlc.create_paged_kv_cache_generic" + ) + + args = func.body.blocks[0].bindings[0].value.args + assert len(args) == 10 + assert isinstance(args[0], relax.ShapeExpr) + assert len(args[0].values) == 4 + for i in range(1, 9): + assert isinstance(args[i], relax.PrimValue) + assert isinstance(args[i].value, (tvm.tir.IntImm, tvm.tir.FloatImm)) + assert isinstance(args[9], relax.DataTypeImm) + + return { + "max_batch_size": args[0].values[0], + "max_total_seq_len": args[0].values[1], + "prefill_chunk_size": args[0].values[2], + "page_size": args[0].values[3], + "num_hidden_layers": args[1].value.value, + "num_attention_heads": args[2].value.value, + "num_key_value_heads": args[3].value.value, + "head_dim": args[4].value.value, + "rope_mode": args[5].value.value, + "rope_scale": args[6].value.value, + "rope_theta": args[7].value.value, + "rotary_dim": args[8].value.value, + "dtype": args[9].value, + } + + +@tvm.transform.module_pass(opt_level=0, name="RewriteKVCacheCreation") +class RewriteKVCacheCreation: # pylint: disable=too-many-instance-attributes + """Rewrite KV cache creation functions to IRModule.""" + + def __init__( + self, 
target: tvm.target.Target, flashinfer: bool, metadata: Dict[str, Any]
+    ) -> None:
+        """Initializer.
+
+        Parameters
+        ----------
+        target : tvm.target.Target
+            The target of the model compilation.
+
+        flashinfer : bool
+            A boolean indicating if flashinfer is enabled.
+
+        metadata : Dict[str, Any]
+            The model metadata (e.g. the model type), used to decide whether the
+            FlashInfer-based KV cache can be created.
+        """
+        self.target = target
+        self.flashinfer = flashinfer
+        self.metadata = metadata
+
+    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
+        """Entrypoint"""
+        func_dict = {}
+        creation_func = None
+        for g_var, func in mod.functions_items():
+            # Try to find the `create_paged_kv_cache` func.
+            if g_var.name_hint == "create_paged_kv_cache":
+                creation_func = func
+            else:
+                func_dict[g_var] = func
+
+        if creation_func is None:
+            return mod
+
+        new_mod = IRModule(func_dict)
+        if mod.attrs is not None:
+            new_mod = new_mod.with_attrs(mod.attrs)
+
+        kwargs = extract_creation_args(creation_func)
+
+        bb = relax.BlockBuilder(new_mod)
+        self.create_tir_paged_kv_cache(bb, kwargs)
+        self.create_flashinfer_paged_kv_cache(bb, kwargs)
+        return bb.finalize()
+
+    def create_tir_paged_kv_cache(self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]) -> None:
+        """Create the TIR-based PagedKVCache"""
+        max_batch_size = relax.Var(
+            "max_batch_size_", relax.ShapeStructInfo([kwargs["max_batch_size"]])
+        )
+        max_total_seq_len = relax.Var(
+            "max_total_seq_len_", relax.ShapeStructInfo([kwargs["max_total_seq_len"]])
+        )
+        prefill_chunk_size = relax.Var(
+            "prefill_chunk_size_", relax.ShapeStructInfo([kwargs["prefill_chunk_size"]])
+        )
+        page_size = relax.Var("page_size_", relax.ShapeStructInfo([kwargs["page_size"]]))
+
+        with bb.function(
+            name="create_tir_paged_kv_cache",
+            params=[max_batch_size, max_total_seq_len, prefill_chunk_size, page_size],
+        ):
+            cache = kv_cache.TIRPagedKVCache(target=self.target, **kwargs)
+            bb.emit_func_output(cache._expr)  # pylint: disable=protected-access
+
+    def create_flashinfer_paged_kv_cache(
+        self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]
+    ) -> None:
+        """Create the FlashInfer-based PagedKVCache"""
+        # Filter the cases which FlashInfer does not support.
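+        # FlashInfer is only used when it is enabled for this compilation, the KV cache
+        # dtype is float16, and head_dim is 128; inline RoPE additionally requires
+        # rotary_dim == head_dim. GPT-2 models are bypassed since they use
+        # attn_score_scaling_factor.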
+ if ( # pylint: disable=too-many-boolean-expressions + not self.flashinfer + or str(kwargs["dtype"]) != "float16" + or kwargs["head_dim"] != 128 + or ( + kwargs["rope_mode"] == RopeMode.INLINE + and kwargs["rotary_dim"] != kwargs["head_dim"] + ) + or ( + # bypass GPT-2 since it uses attn_score_scaling_factor + "gpt2" + in self.metadata["model_type"] + ) + ): + return + + max_batch_size = relax.Var( + "max_batch_size_", relax.ShapeStructInfo([kwargs["max_batch_size"]]) + ) + max_total_seq_len = relax.Var( + "max_total_seq_len_", relax.ShapeStructInfo([kwargs["max_total_seq_len"]]) + ) + prefill_chunk_size = relax.Var( + "prefill_chunk_size_", relax.ShapeStructInfo([kwargs["prefill_chunk_size"]]) + ) + page_size = relax.Var("page_size_", relax.ShapeStructInfo([kwargs["page_size"]])) + + with bb.function( + name="create_flashinfer_paged_kv_cache", + params=[max_batch_size, max_total_seq_len, prefill_chunk_size, page_size], + ): + cache = kv_cache.FlashInferPagedKVCache(target=self.target, **kwargs) + bb.emit_func_output(cache._expr) # pylint: disable=protected-access diff --git a/python/mlc_chat/compiler_pass/scatter_tuple_get_item.py b/python/mlc_chat/compiler_pass/scatter_tuple_get_item.py new file mode 100644 index 0000000..281c6ec --- /dev/null +++ b/python/mlc_chat/compiler_pass/scatter_tuple_get_item.py @@ -0,0 +1,51 @@ +"""A compiler pass that scatters TupleGetItem for lazy TupleGetItems.""" + +from typing import Dict + +import tvm +from tvm import relax +from tvm.ir.module import IRModule +from tvm.relax.analysis import remove_all_unused +from tvm.relax.expr import Expr, Var +from tvm.relax.expr_functor import PyExprMutator, mutator + + +@tvm.transform.module_pass(opt_level=0, name="ScatterTupleGetItem") +class ScatterTupleGetItem: # pylint: disable=too-few-public-methods + """A compiler pass that scatters TupleGetItem for lazy TupleGetItems.""" + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """IRModule-level transformation""" + return _Scatter(mod).transform() + + +@mutator +class _Scatter(PyExprMutator): # pylint: disable=abstract-method + def __init__(self, mod: IRModule) -> None: + super().__init__(mod) + self.mod = mod + self.var_map: Dict[Var, Expr] = {} + + def transform(self) -> IRModule: + """Entry point""" + for g_var, func in self.mod.functions_items(): + if isinstance(func, relax.Function): + updated_func = self.visit_expr(func) + updated_func = remove_all_unused(updated_func) + self.builder_.update_func(g_var, updated_func) + return self.builder_.get() + + def visit_var_binding_(self, binding: relax.VarBinding): + super().visit_var_binding_(binding) + if isinstance(binding.value, relax.TupleGetItem): + self.var_map[binding.var] = binding.value + + def visit_dataflow_var_( # pylint: disable=arguments-renamed + self, var: relax.DataflowVar + ) -> Expr: + if var in self.var_map: + new_var = self.builder_.emit(self.var_map[var], name_hint=var.name_hint) + self.set_var_remap(var.vid, new_var) + self.var_map.pop(var) + return new_var + return var diff --git a/python/mlc_chat/conversation_template.py b/python/mlc_chat/conversation_template.py new file mode 100644 index 0000000..6ca148f --- /dev/null +++ b/python/mlc_chat/conversation_template.py @@ -0,0 +1,77 @@ +"""The conversation template registry and presets in MLC LLM""" + +from typing import Dict, Optional + +from .protocol.conversation_protocol import Conversation, MessagePlaceholders + + +class ConvTemplateRegistry: + """Global conversation template registry for preset 
templates.""" + + _conv_templates: Dict[str, Conversation] = {} + + @staticmethod + def register_conv_template(conv_template: Conversation, override: bool = False) -> None: + """Register a new conversation template in the global registry. + Using `override = True` to override the previously registered + template with the same name. + """ + name = conv_template.name + if name is None: + raise ValueError("The template to register should have non-None name.") + if name in ConvTemplateRegistry._conv_templates and not override: + raise ValueError( + "The name of the template has been registered " + f"for {ConvTemplateRegistry._conv_templates[name].model_dump_json()}" + ) + ConvTemplateRegistry._conv_templates[name] = conv_template + + @staticmethod + def get_conv_template(name: str) -> Optional[Conversation]: + """Return the conversation template specified by the given name, + or None if the template is not registered. + """ + return ConvTemplateRegistry._conv_templates.get(name, None) + + +############## Preset Conversation Templates ############## + +# Llama2 +ConvTemplateRegistry.register_conv_template( + Conversation( + name="llama-2", + system_template=f"[INST] <>\n{MessagePlaceholders.SYSTEM.value}\n<>\n\n ", + system_message="You are a helpful, respectful and honest assistant.", + roles={"user": "[INST]", "assistant": "[/INST]", "tool": "[INST]"}, + seps=[" "], + role_content_sep=" ", + role_empty_sep=" ", + stop_str=["[INST]"], + stop_token_ids=[2], + ) +) + +# Gorilla +ConvTemplateRegistry.register_conv_template( + Conversation( + name="gorilla", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message=( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant provides helpful, detailed, and " + "polite responses to the user's inquiries." + ), + role_templates={ + "user": ( + f"<> {MessagePlaceholders.USER.value} <> " + f"{MessagePlaceholders.FUNCTION.value}" + ), + }, + roles={"user": "USER", "assistant": "ASSISTANT", "tool": "USER"}, + seps=["\n", ""], + role_content_sep=": ", + role_empty_sep=":", + stop_str=[""], + stop_token_ids=[2], + ) +) diff --git a/python/mlc_chat/embeddings/__init__.py b/python/mlc_chat/embeddings/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/mlc_chat/embeddings/openai.py b/python/mlc_chat/embeddings/openai.py new file mode 100644 index 0000000..022d55b --- /dev/null +++ b/python/mlc_chat/embeddings/openai.py @@ -0,0 +1,245 @@ +# pylint: disable=missing-docstring +from __future__ import annotations + +from typing import Iterable, List, Optional, Sequence, Tuple + +import numpy as np +from langchain.embeddings import OpenAIEmbeddings # pylint: disable=import-error +from langchain_community.embeddings.openai import ( # pylint: disable=import-error + async_embed_with_retry, + embed_with_retry, +) + +from mlc_chat.support import logging + +logger = logging.getLogger(__name__) + + +class MLCEmbeddings(OpenAIEmbeddings): + def _chunk_tokens(self, texts: Sequence[str]) -> Tuple[List[List], List[int]]: + """Tokenize and chunk texts to fit in the model's context window.""" + if not self.embedding_ctx_length: + raise ValueError( + "embedding_ctx_length must be defined to use _get_len_safe_embeddings." + ) + + try: + import tiktoken # pylint: disable=import-outside-toplevel + except ImportError as err: + raise ImportError( + "Could not import tiktoken python package. " + "This is needed in order to for OpenAIEmbeddings. " + "Please install it with `pip install tiktoken`." 
+ ) from err + + tokens = [] + indices = [] + model_name = self.tiktoken_model_name or self.model + try: + encoding = tiktoken.encoding_for_model(model_name) + except KeyError: + logger.warning("Warning: model not found. Using cl100k_base encoding.") + model = "cl100k_base" + encoding = tiktoken.get_encoding(model) + for i, text in enumerate(texts): + if self.model.endswith("001"): + # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500 + # replace newlines, which can negatively affect performance. + text = text.replace("\n", " ") + token = encoding.encode( + text, + allowed_special=self.allowed_special, + disallowed_special=self.disallowed_special, + ) + for j in range(0, len(token), self.embedding_ctx_length): + tokens.append(token[j : j + self.embedding_ctx_length]) + indices.append(i) + return tokens, indices + + def _batch_embed( + self, inputs: Sequence, *, chunk_size: Optional[int] = None + ) -> List[List[float]]: + batched_embeddings: List[List[float]] = [] + _chunk_size = chunk_size or self.chunk_size + _iter: Iterable = range(0, len(inputs), _chunk_size) + if self.show_progress_bar: + try: + from tqdm import tqdm # pylint: disable=import-outside-toplevel + + _iter = tqdm(_iter) + except ImportError: + pass + + for i in _iter: + response = embed_with_retry( + self, + input=inputs[i : i + _chunk_size], + **self._invocation_params, + ) + batched_embeddings.extend(r["embedding"] for r in response["data"]) + return batched_embeddings + + async def _abatch_embed( + self, inputs: Sequence, *, chunk_size: Optional[int] = None + ) -> List[List[float]]: + batched_embeddings: List[List[float]] = [] + _chunk_size = chunk_size or self.chunk_size + _iter: Iterable = range(0, len(inputs), _chunk_size) + if self.show_progress_bar: + try: + from tqdm import tqdm # pylint: disable=import-outside-toplevel + + _iter = tqdm(_iter) + except ImportError: + pass + + for i in _iter: + response = await async_embed_with_retry( + self, + input=inputs[i : i + _chunk_size], + **self._invocation_params, + ) + batched_embeddings.extend(r["embedding"] for r in response["data"]) + return batched_embeddings + + # please refer to + # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb + def _get_len_safe_embeddings( # pylint: disable=too-many-locals,unused-argument + self, + texts: List[str], + *, + engine: str, + chunk_size: Optional[int] = None, + ) -> List[List[float]]: + tokens, indices = self._chunk_tokens(texts) + batched_embeddings = self._batch_embed(tokens, chunk_size=chunk_size) + results: List[List[List[float]]] = [[] for _ in range(len(texts))] + num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))] + for idx, tokens_i, batched_emb in zip(indices, tokens, batched_embeddings): + results[idx].append(batched_emb) + num_tokens_in_batch[idx].append(len(tokens_i)) + + embeddings = [] + empty_average = embed_with_retry( + self, + input="", + **self._invocation_params, + )["data"][ + 0 + ]["embedding"] + for _result, num_tokens in zip(results, num_tokens_in_batch): + if len(_result) == 0: + average = empty_average + else: + average = np.average(_result, axis=0, weights=num_tokens) + normalized = (average / np.linalg.norm(average)).tolist() + embeddings.append(normalized) + + return embeddings + + # please refer to + # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb + async def _aget_len_safe_embeddings( # pylint: disable=too-many-locals,unused-argument + self, + texts: List[str], + *, + 
engine: str, + chunk_size: Optional[int] = None, + ) -> List[List[float]]: + tokens, indices = self._chunk_tokens(texts) + batched_embeddings = await self._abatch_embed(tokens, chunk_size=chunk_size) + + results: List[List[List[float]]] = [[] for _ in range(len(texts))] + num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))] + for idx, tokens_i, batched_emb in zip(indices, tokens, batched_embeddings): + results[idx].append(batched_emb) + num_tokens_in_batch[idx].append(len(tokens_i)) + + embeddings = [] + empty_average = ( + await async_embed_with_retry( + self, + input="", + **self._invocation_params, + ) + )[ + "data" + ][0]["embedding"] + for _result, num_tokens in zip(results, num_tokens_in_batch): + if len(_result) == 0: + average = empty_average + else: + average = np.average(_result, axis=0, weights=num_tokens) + normalized = (average / np.linalg.norm(average)).tolist() + embeddings.append(normalized) + + return embeddings + + def embed_documents( + self, texts: List[str], chunk_size: Optional[int] = None + ) -> List[List[float]]: + """Call out to OpenAI's embedding endpoint for embedding search docs. + + Args: + texts: The list of texts to embed. + chunk_size: The chunk size of embeddings. If None, will use the chunk size + specified by the class. + + Returns: + List of embeddings, one for each text. + """ + # NOTE: to keep things simple, as long as the embedding_ctx_length is defined, + # we assume the list may contain texts longer than the maximum context and + # use length-safe embedding function. + if self.embedding_ctx_length: + return self._get_len_safe_embeddings( + texts, engine=self.deployment, chunk_size=chunk_size + ) + + embeddings = self._batch_embed(texts, chunk_size=chunk_size) + return [(np.array(e) / np.linalg.norm(e)).tolist() for e in embeddings] + + async def aembed_documents( + self, texts: List[str], chunk_size: Optional[int] = 0 + ) -> List[List[float]]: + """Call out to OpenAI's embedding endpoint async for embedding search docs. + + Args: + texts: The list of texts to embed. + chunk_size: The chunk size of embeddings. If None, will use the chunk size + specified by the class. + + Returns: + List of embeddings, one for each text. + """ + # NOTE: to keep things simple, as long as the embedding_ctx_length is defined, + # we assume the list may contain texts longer than the maximum context and + # use length-safe embedding function. + if self.embedding_ctx_length: + return await self._aget_len_safe_embeddings(texts, engine=self.deployment) + + embeddings = await self._abatch_embed(texts, chunk_size=chunk_size) + return [(np.array(e) / np.linalg.norm(e)).tolist() for e in embeddings] + + def embed_query(self, text: str) -> List[float]: + """Call out to OpenAI's embedding endpoint for embedding query text. + + Args: + text: The text to embed. + + Returns: + Embedding for the text. + """ + return self.embed_documents([text])[0] + + async def aembed_query(self, text: str) -> List[float]: + """Call out to OpenAI's embedding endpoint async for embedding query text. + + Args: + text: The text to embed. + + Returns: + Embedding for the text. 
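        Example (illustrative sketch only; the endpoint URL, API key, and context
        length below are placeholder values for a locally served OpenAI-compatible
        MLC endpoint, not documented defaults of this class):

            emb = MLCEmbeddings(
                openai_api_base="http://127.0.0.1:8000/v1",  # placeholder endpoint
                openai_api_key="EMPTY",                      # placeholder key
                embedding_ctx_length=512,
            )
            # inside an async function / running event loop:
            vector = await emb.aembed_query("What does MLC LLM do?")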
+ """ + embeddings = await self.aembed_documents([text]) + return embeddings[0] diff --git a/python/mlc_chat/gradio.py b/python/mlc_chat/gradio.py new file mode 100644 index 0000000..1ab6ae6 --- /dev/null +++ b/python/mlc_chat/gradio.py @@ -0,0 +1,247 @@ +"""Gradio interface for MLC Chat.""" +# pylint: disable=import-error,invalid-name,too-many-instance-attributes,too-many-locals +import argparse +import glob +import os +from typing import Dict, Optional + +import gradio as gr + +from .chat_module import ChatModule + + +def _parse_args(): + args = argparse.ArgumentParser("MLC-Chat Gradio Interface") + args.add_argument( + "--artifact-path", + type=str, + default="dist", + help="Please provide a path containing all the model folders you wish to use.", + ) + args.add_argument( + "--device", + type=str, + default="auto", + help="The description of the device to run on. User should provide a string in the \ + form of 'device_name:device_id' or 'device_name', where 'device_name' is one of \ + 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto' (automatically detect the \ + local device), and 'device_id' is the device id to run on. If no 'device_id' \ + is provided, it will be set to 0 by default.", + ) + args.add_argument("--port", type=int, default=7860, help="The port number to run gradio.") + args.add_argument("--host", type=str, default="127.0.0.1", help="The local host to run gradio.") + args.add_argument( + "--share", + action="store_true", + help="Whether to create a publicly shareable link for the interface.", + ) + parsed = args.parse_args() + return parsed + + +def _get_all_available_models_under_dir(artifact_path: str) -> Dict[str, str]: + r"""Given the artifact path storing all models, returns a dict mapping available model names + to the correct `model` args passed into ChatModule. + + Note + ---- + We only search for folders under the artifact_path, without recursive search for subfolders. + For each folder, we count it as a valid MLC model folder if either it contains an + `mlc-chat-config.json` file, or it contains a `params` folder which contains an + `mlc-chat-config.json` file. We will map the name of a valid folder to its full path to the + folder containing `mlc-chat-config.json`. + """ + + # step 0. retrieve the absolute path of artifact_path + search_dir = os.path.abspath(artifact_path) + if not os.path.exists(search_dir): + err_msg = ( + f"The artifact path {artifact_path} you provided is neither a valid full path nor a " + "valid path relative to the current working directory. Please provide a correct " + "artifact path.", + ) + raise FileNotFoundError(err_msg) + + # step 1. go through all the folders, build the model dict + model_dict = {} + for path in glob.glob(os.path.join(search_dir, "*")): + if os.path.isdir(path): + model_name = os.path.basename(os.path.normpath(path)) + # check if it contains `mlc-chat-config.json` + if os.path.exists(os.path.join(path, "mlc-chat-config.json")): + model_dict[model_name] = os.path.abspath(path) + # check if it contains `params/mlc-chat-config.json` + elif os.path.exists(os.path.join(path, "params", "mlc-chat-config.json")): + model_dict[model_name] = os.path.abspath(os.path.join(path, "params")) + + return model_dict + + +class GradioModule: + r"""The Gradio module for MLC Chat. Different from ChatModule Python API, Gradio module allows + users to load in a directory of models, watch the streaming in web browser, and switch between + models more easily to compare performance. + + Note: Multimodality will be supported soon, i.e. 
allowing users to upload an image to chat. + """ + + def __init__(self, artifact_path: str = "dist", device: str = "auto"): + self.artifact_path = artifact_path + self.device_str = device + self.chat_mod: Optional[ChatModule] = None + self.model_dict = _get_all_available_models_under_dir(artifact_path) + + def gradio_reload_model(self, model_name: str): + r"""Reload the model given the user-selected model name.""" + self.chat_mod = ChatModule(self.model_dict[model_name], self.device_str) + + updated_dict = { + "chatbot": None, + "chat_state": [], + "img_list": [], + "image_model": gr.update(interactive=False, visible=False), + "stream_interval": gr.update(interactive=True, visible=True), + "reset_llm_button": gr.update(interactive=True, visible=True), + "stats_button": gr.update(interactive=True, visible=True), + "stats_output": gr.update(placeholder="Click to get runtime statistics.", visible=True), + "text_input": gr.update(interactive=True, placeholder="Type and press enter"), + } + + return list(updated_dict.values()) + + def gradio_reset_model(self): + r"""Reset the current chat model.""" + self.chat_mod.reset_chat() + + updated_dict = { + "chatbot": None, + "chat_state": [], + "img_list": [], + "text_input": gr.update(interactive=True, placeholder="Type and press enter"), + } + + return list(updated_dict.values()) + + def gradio_ask(self, text_input, chatbot): + r"""Display user text input in the chatbot.""" + chatbot = chatbot + [[text_input, None]] + text_input = "" + return text_input, chatbot + + def gradio_answer(self, chatbot, stream_interval): + r"""Generate and display the chat module's response. + Note: Below is a low-level implementation of generate() API, since it's easier + to yield without delta callback.""" + prompt = chatbot[-1][0] + # pylint: disable=protected-access + self.chat_mod._prefill(prompt) + i, new_msg = 0, "" + while not self.chat_mod._stopped(): + self.chat_mod._decode() + if i % stream_interval == 0 or self.chat_mod._stopped(): + new_msg = self.chat_mod._get_message() + chatbot[-1][1] = new_msg + yield chatbot + i += 1 + # pylint: enable=protected-access + + def gradio_stats(self): + """Get runtime statistics.""" + return self.chat_mod.stats() + + +def launch_gradio( + artifact_path: str = "dist", + device: str = "auto", + port: int = 7860, + share: bool = False, + host: str = "127.0.0.1", +): + r"""Launch the gradio interface with a given port, creating a publically sharable link if + specified.""" + + # create a gradio module + mod = GradioModule(artifact_path, device) + + title = """

MLC Chat Gradio Interface
    """
    description = (
        """
Welcome to MLC Chat! Pick a model from your local ids to get started.
""" + ) + + with gr.Blocks() as demo: + gr.Markdown(title) + gr.Markdown(description) + + # ---------------------- user interface design ------------------------- + with gr.Row(): + with gr.Column(scale=0.3): + llm_model = gr.Dropdown(list(mod.model_dict.keys()), label="Language Model") + image_model = gr.Dropdown( + ["-None-"], + label="Do you wanna add an image model?", + visible=False, + interactive=False, + ) + image = gr.Image(type="pil", interactive=False, visible=False) + stream_interval = gr.Slider( + minimum=1.0, + maximum=5.0, + value=2.0, + step=1.0, + interactive=True, + visible=False, + label="Stream Interval", + ) + reset_llm_button = gr.Button("Reset chat", visible=False, interactive=False) + stats_button = gr.Button("Get Runtime Statistics", interactive=False, visible=False) + stats_output = gr.Textbox( + show_label=False, + placeholder="Click to get runtime statistics.", + interactive=False, + visible=False, + container=False, + ) + with gr.Column(): + chat_state = gr.State() + img_list = gr.State() + chatbot = gr.Chatbot(label="MLC Chat") + text_input = gr.Textbox( + show_label=False, + placeholder="Select a model to start chatting!", + interactive=False, + container=False, + ) + + # ---------------------- local variables --------------------------- + # type 1. buttons whose visibility change when llm reload + llm_buttons = [ + image_model, + stream_interval, + reset_llm_button, + stats_button, + stats_output, + text_input, + ] + # type 2. buttons whose visibility change when image model reload + # pylint: disable=unused-variable + image_model_buttons = [image, text_input] + # type 3. chatbot state variables + chatbot_vars = [chatbot, chat_state, img_list] + + # -------------------------- handle control -------------------------- + llm_model.change( + mod.gradio_reload_model, [llm_model], chatbot_vars + llm_buttons, queue=False + ) + text_input.submit(mod.gradio_ask, [text_input, chatbot], [text_input, chatbot]).then( + mod.gradio_answer, [chatbot, stream_interval], [chatbot] + ) + reset_llm_button.click(mod.gradio_reset_model, [], chatbot_vars + [text_input]) + stats_button.click(mod.gradio_stats, [], [stats_output]) + + # launch to the web + demo.launch(share=share, enable_queue=True, server_port=port, server_name=host) + + +if __name__ == "__main__": + ARGS = _parse_args() + launch_gradio(ARGS.artifact_path, ARGS.device, ARGS.port, ARGS.share, ARGS.host) diff --git a/python/mlc_chat/help.py b/python/mlc_chat/help.py new file mode 100644 index 0000000..0464bd0 --- /dev/null +++ b/python/mlc_chat/help.py @@ -0,0 +1,142 @@ +"""Help message for CLI arguments.""" +HELP = { + "config": ( + """ +1) Path to a HuggingFace model directory that contains a `config.json` or +2) Path to `config.json` in HuggingFace format, or +3) The name of a pre-defined model architecture. + +A `config.json` file in HuggingFace format defines the model architecture, including the vocabulary +size, the number of layers, the hidden size, number of attention heads, etc. +Example: https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json. + +A HuggingFace directory often contains a `config.json` which defines the model architecture, +the non-quantized model weights in PyTorch or SafeTensor format, tokenizer configurations, +as well as an optional `generation_config.json` provides additional default configuration for +text generation. +Example: https://huggingface.co/codellama/CodeLlama-7b-hf/tree/main. +""" + ).strip(), + "quantization": """ +The quantization mode we use to compile. 
If unprovided, will infer from `model`. +""".strip(), + "model": """ +A path to ``mlc-chat-config.json``, or an MLC model directory that contains `mlc-chat-config.json`. +""".strip(), + "model_lib_path": """ +The full path to the model library file to use (e.g. a ``.so`` file). If unspecified, we will use +the provided ``model`` to search over possible paths. +""".strip(), + "model_type": """ +Model architecture such as "llama". If not set, it is inferred from `mlc-chat-config.json`. +""".strip(), + "device_compile": """ +The GPU device to compile the model to. If not set, it is inferred from GPUs available locally. +""".strip(), + "device_quantize": """ +The device used to do quantization such as "cuda" or "cuda:0". Will detect from local available GPUs +if not specified. +""".strip(), + "device_deploy": """ +The device used to deploy the model such as "cuda" or "cuda:0". Will detect from local +available GPUs if not specified. +""".strip(), + "host": """ +The host LLVM triple to compile the model to. If not set, it is inferred from the local CPU and OS. +Examples of the LLVM triple: +1) iPhones: arm64-apple-ios; +2) ARM64 Android phones: aarch64-linux-android; +3) WebAssembly: wasm32-unknown-unknown-wasm; +4) Windows: x86_64-pc-windows-msvc; +5) ARM macOS: arm64-apple-darwin. +""".strip(), + "opt": """ +Optimization flags. MLC LLM maintains a predefined set of optimization flags, +denoted as O0, O1, O2, O3, where O0 means no optimization, O2 means majority of them, +and O3 represents extreme optimization that could potentially break the system. +Meanwhile, optimization flags could be explicitly specified via details knobs, e.g. +--opt="cublas_gemm=1;cudagraph=0". +""".strip(), + "system_lib_prefix": """ +Adding a prefix to all symbols exported. Similar to "objcopy --prefix-symbols". +This is useful when compiling multiple models into a single library to avoid symbol +conflicts. Different from objcopy, this takes no effect for shared library. +""".strip(), + "context_window_size": """ +Option to provide the maximum sequence length supported by the model. +This is usually explicitly shown as context length or context window in the model card. +If this option is not set explicitly, by default, +it will be determined by `context_window_size` or `max_position_embeddings` in `config.json`, +and the latter is usually inaccurate for some models. +""".strip(), + "output_compile": """ +The path to the output file. The suffix determines if the output file is a shared library or +objects. Available suffixes: +1) Linux: .so (shared), .tar (objects); +2) macOS: .dylib (shared), .tar (objects); +3) Windows: .dll (shared), .tar (objects); +4) Android, iOS: .tar (objects); +5) Web: .wasm (web assembly). +""".strip(), + "source": """ +The path to original model weight, infer from `config` if missing. +""".strip(), + "source_format": """ +The format of source model weight, infer from `config` if missing. +""".strip(), + "output_quantize": """ +The output directory to save the quantized model weight. Will create `params_shard_*.bin` and +`ndarray-cache.json` in this directory. +""".strip(), + "conv_template": """ +Conversation template. It depends on how the model is tuned. Use "LM" for vanilla base model +""".strip(), + "output_gen_mlc_chat_config": """ +The output directory for generated configurations, including `mlc-chat-config.json` and tokenizer +configuration. +""".strip(), + "sliding_window_size": """ +(Experimental) The sliding window size in sliding window attention (SWA). 
+This optional field overrides the `sliding_window_size` in config.json for +those models that use SWA. Currently only useful when compiling Mistral. +This flag subjects to future refactoring. +""".strip(), + "prefill_chunk_size": """ +(Experimental) The chunk size during prefilling. By default, +the chunk size is the same as sliding window or max sequence length. +This flag subjects to future refactoring. +""".strip(), + "attention_sink_size": """ +(Experimental) The number of stored sinks. Only supported on Mistral yet. By default, +the number of sinks is 4. This flag subjects to future refactoring. +""".strip(), + "max_batch_size": """ +The maximum allowed batch size set for batch prefill/decode function. +""".strip(), + """tensor_parallel_shards""": """ +Number of shards to split the model into in tensor parallelism multi-gpu inference. +""".strip(), + "overrides": """ +Model configuration override. Configurations to override `mlc-chat-config.json`. Supports +`context_window_size`, `prefill_chunk_size`, `sliding_window_size`, `attention_sink_size`, +`max_batch_size` and `tensor_parallel_shards`. Meanwhile, model config could be explicitly +specified via details knobs, e.g. --overrides "context_window_size=1024;prefill_chunk_size=128". +""".strip(), + "chatconfig_overrides": """ +Chat configuration override. Configurations to override ChatConfig. Supports `conv_template`, +`context_window_size`, `prefill_chunk_size`, `sliding_window_size`, `attention_sink_size`, +`max_batch_size` and `tensor_parallel_shards`. Meanwhile, model chat could be explicitly +specified via details knobs, e.g. --overrides "context_window_size=1024;prefill_chunk_size=128". +""".strip(), + "debug_dump": """ +Specifies the directory where the compiler will store its IRs for debugging purposes +during various phases of compilation. By default, this is set to `None`, indicating +that debug dumping is disabled. +""".strip(), + "prompt": """ +The prompt of the text generation. +""".strip(), + "generate_length": """ +The target length of the text generation. 
+""".strip(), +} diff --git a/python/mlc_chat/interface/__init__.py b/python/mlc_chat/interface/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/mlc_chat/interface/bench.py b/python/mlc_chat/interface/bench.py new file mode 100644 index 0000000..a1d4e27 --- /dev/null +++ b/python/mlc_chat/interface/bench.py @@ -0,0 +1,28 @@ +"""Python entrypoint of benchmark.""" +from typing import Optional + +from mlc_chat.chat_module import ChatConfig, ChatModule + +from .chat import ChatConfigOverride + + +def bench( # pylint: disable=too-many-arguments + model: str, + prompt: str, + device: str, + opt: str, + overrides: ChatConfigOverride, + generate_length: int, + model_lib_path: Optional[str], +): + """run the benchmarking""" + # Set up chat config + config = ChatConfig(opt=opt) + # Apply overrides + config = overrides.apply(config) + # Set up ChatModule + cm = ChatModule(model, device, chat_config=config, model_lib_path=model_lib_path) + + output = cm.benchmark_generate(prompt, generate_length=generate_length) + print(f"Generated text:\n{output}\n") + print(f"Statistics:\n{cm.stats(verbose=True)}") diff --git a/python/mlc_chat/interface/chat.py b/python/mlc_chat/interface/chat.py new file mode 100644 index 0000000..0df8bb1 --- /dev/null +++ b/python/mlc_chat/interface/chat.py @@ -0,0 +1,196 @@ +"""Python entrypoint of chat.""" +import dataclasses +import re +import json +from typing import List, Optional, Union + +from prompt_toolkit import prompt as get_prompt # pylint: disable=import-error +from prompt_toolkit.key_binding import KeyBindings # pylint: disable=import-error + +from mlc_chat.callback import StreamToStdout +from mlc_chat.chat_module import ChatConfig, ChatModule, GenerationConfig +from mlc_chat.support import argparse +from mlc_chat.support.config import ConfigOverrideBase + + +@dataclasses.dataclass +class ChatConfigOverride(ConfigOverrideBase): # pylint: disable=too-many-instance-attributes + """Flags for overriding chat config.""" + + conv_template: Optional[str] = None + context_window_size: Optional[int] = None + sliding_window_size: Optional[int] = None + prefill_chunk_size: Optional[int] = None + attention_sink_size: Optional[int] = None + max_batch_size: Optional[int] = None + tensor_parallel_shards: Optional[int] = None + + @staticmethod + def from_str(source: str) -> "ChatConfigOverride": + """Parse model config override values from a string.""" + parser = argparse.ArgumentParser(description="chat config override values") + parser.add_argument("--conv_template", type=str, default=None) + parser.add_argument("--tensor_parallel_shards", type=int, default=None) + parser.add_argument("--context_window_size", type=int, default=None) + parser.add_argument("--sliding_window_size", type=int, default=None) + parser.add_argument("--prefill_chunk_size", type=int, default=None) + parser.add_argument("--attention_sink_size", type=int, default=None) + parser.add_argument("--max_batch_size", type=int, default=None) + + results = parser.parse_args([f"--{i}" for i in source.split(";") if i]) + return ChatConfigOverride( + conv_template=results.conv_template, + tensor_parallel_shards=results.tensor_parallel_shards, + context_window_size=results.context_window_size, + sliding_window_size=results.sliding_window_size, + prefill_chunk_size=results.prefill_chunk_size, + attention_sink_size=results.attention_sink_size, + max_batch_size=results.max_batch_size, + ) + + +@dataclasses.dataclass +class GenerationConfigOverride(ConfigOverrideBase): # pylint: 
disable=too-many-instance-attributes + """Flags for overriding generation config.""" + + temperature: Optional[float] = None + repetition_penalty: Optional[float] = None + top_p: Optional[float] = None + mean_gen_len: Optional[int] = None + max_gen_len: Optional[int] = None + presence_penalty: Optional[float] = None + frequency_penalty: Optional[float] = None + n: Optional[int] = None # pylint: disable=invalid-name + stop: Optional[Union[str, List[str]]] = None + + @staticmethod + def from_str(source: str) -> "GenerationConfigOverride": + """Parse model config override values from a string.""" + parser = argparse.ArgumentParser(description="generation config override values") + parser.add_argument("--temperature", type=float, default=None) + parser.add_argument("--repetition_penalty", type=float, default=None) + parser.add_argument("--top_p", type=float, default=None) + parser.add_argument("--mean_gen_len", type=int, default=None) + parser.add_argument("--max_gen_len", type=int, default=None) + parser.add_argument("--presence_penalty", type=float, default=None) + parser.add_argument("--frequency_penalty", type=float, default=None) + parser.add_argument("--n", type=int, default=None) + parser.add_argument("--stop", type=str, default=None) + results = parser.parse_args([f"--{i}" for i in source.split(";") if i]) + return GenerationConfigOverride( + temperature=results.temperature, + repetition_penalty=results.repetition_penalty, + top_p=results.top_p, + mean_gen_len=results.mean_gen_len, + max_gen_len=results.max_gen_len, + presence_penalty=results.presence_penalty, + frequency_penalty=results.frequency_penalty, + n=results.n, + stop=results.stop.split(",") if results.stop is not None else None, + ) + + +def _print_help_str(): + help_str = """You can use the following special commands: + /help print the special commands + /exit quit the cli + /stats print out the latest stats (token/sec) + /reset restart a fresh chat + /set [overrides] override settings in the generation config. For example, + `/set temperature=0.5;max_gen_len=100;stop=end,stop` + Note: Separate stop words in the `stop` option with commas (,). + Multi-line input: Use escape+enter to start a new line. 
+""" + print(help_str) + + +def _set_up_key_bindings(): + kb = KeyBindings() + + @kb.add("escape", "enter") + def _(event): + event.current_buffer.insert_text("\n") + + @kb.add("enter") + def _(event): + event.current_buffer.validate_and_handle() + + return kb + + +def chat( + model: str, + device: str, + opt: str, + overrides: ChatConfigOverride, + model_lib_path: Optional[str], + energy_events_filename: str, +): + """chat with a model.""" + # Set up chat config and generate config + config = ChatConfig(opt=opt) + generate_config = GenerationConfig() + # Apply overrides + config = overrides.apply(config) + # Set up ChatModule + cm = ChatModule(model, device, chat_config=config, model_lib_path=model_lib_path) + _print_help_str() + cm._process_system_prompts() # pylint: disable=protected-access + + # Multi-line input support: set escape+enter as start a new line + kb = _set_up_key_bindings() + + while True: + prompt = get_prompt( + f"{cm._get_role_0()}: ", # pylint: disable=protected-access + key_bindings=kb, + multiline=True, + ) + if prompt[:6] == "/reset": + cm.reset_chat() + elif prompt[:5] == "/exit": + with open(energy_events_filename, 'w', encoding='utf-8') as f: + for event_key, event_value in cm.energy_events.items(): + f.write(f"{event_key} {event_value}\n") + break + elif prompt[:6] == "/stats": + # print(cm.stats(verbose=True), flush=True) + # ----------- prefill ----------- + # throughput: 87.899 tok/s + # total tokens: 10 tok + # total time: 0.114 s + # ------------ decode ------------ + # throughput: 54.603 tok/s + # total tokens: 18 tok + # total time: 0.330 s + # Parse the above metrics into json format + stats = cm.stats(verbose=True) + if stats.startswith("{"): # This is already handled by the backend + print(stats, flush=True) + else: # This is in case the backend has not been changed + stats = stats.strip().split("\n") + float_re = re.compile(r"\d+\.\d+") + int_re = re.compile(r"\d+") + stats_dict = {} + try: + for i in range(0, len(stats), 4): + stats_dict[stats[i].strip('-').strip()] = { + "throughput": f"{float(re.findall(float_re, stats[i + 1])[0])} tok/s", + "total_tokens": f"{int(re.findall(int_re, stats[i + 2])[0])} tok", + "total_time": f"{float(re.findall(float_re, stats[i + 3])[0])} s", + } + print(json.dumps(stats_dict, indent=4), flush=True) + except IndexError: + print(stats, flush=True) + elif prompt[:4] == "/set": + gen_config_overrides = GenerationConfigOverride.from_str(prompt.split()[1]) + generate_config = gen_config_overrides.apply(generate_config) + elif prompt[:5] == "/help": + _print_help_str() + else: + print(f"{cm._get_role_1()}: ") # pylint: disable=protected-access + cm.generate( + prompt, + progress_callback=StreamToStdout(callback_interval=2), + generation_config=generate_config, + ) diff --git a/python/mlc_chat/interface/compile.py b/python/mlc_chat/interface/compile.py new file mode 100644 index 0000000..7688715 --- /dev/null +++ b/python/mlc_chat/interface/compile.py @@ -0,0 +1,230 @@ +"""Python entrypoint of compilation.""" +import dataclasses +import math +from io import StringIO +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple + +import numpy as np +from tvm import IRModule, relax, tir +from tvm.ir.transform import Pass, PassContext +from tvm.relax.frontend import nn +from tvm.target import Target + +from mlc_chat import compiler_pass as _ +from mlc_chat import op as op_ext +from mlc_chat.cli.model_metadata import _report_memory_usage +from mlc_chat.model import Model +from mlc_chat.quantization 
import Quantization +from mlc_chat.support import logging +from mlc_chat.support.config import ConfigBase +from mlc_chat.support.style import bold + +from .compiler_flags import ModelConfigOverride, OptimizationFlags + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class CompileArgs: # pylint: disable=too-many-instance-attributes + """Arguments to MLC LLM's compiler.""" + + config: Path + quantization: Quantization + model: Model + target: Target + opt: OptimizationFlags + build_func: Callable[[IRModule, "CompileArgs", Pass], None] + system_lib_prefix: str + output: Path + overrides: ModelConfigOverride + debug_dump: Optional[Path] + + def __post_init__(self) -> None: + self.opt.update(self.target, self.quantization) + + def display(self) -> None: + """Display the arguments to stdout.""" + out = StringIO() + print(f"{bold('Compiling with arguments:')}", file=out) + print(f" {bold('--config'):<25} {self.config}", file=out) + print(f" {bold('--quantization'):<25} {self.quantization}", file=out) + print(f" {bold('--model-type'):<25} {self.model.name}", file=out) + print(f" {bold('--target'):<25} {self.target.export()}", file=out) + print(f" {bold('--opt'):<25} {self.opt}", file=out) + print(f" {bold('--system-lib-prefix'):<25} \"{self.system_lib_prefix}\"", file=out) + print(f" {bold('--output'):<25} {self.output}", file=out) + print(f" {bold('--overrides'):<25} {self.overrides}", file=out) + # As it's debug only, no need to display + # print(f" {bold('--debug-dump'):<25} {self.debug_dump}", file=out) + print(out.getvalue().rstrip()) + + +def _apply_preproc_to_params( + named_params: List[Tuple[str, nn.Parameter]], + model_config, +) -> Dict[str, tir.PrimFunc]: + extra_tirs: Dict[str, tir.PrimFunc] = {} + for _, param in named_params: + preprocs = param.attrs.get("preprocs", []) + shard_strategy = param.attrs.get("shard_strategy", None) + if shard_strategy is not None and model_config.tensor_parallel_shards > 1: + preprocs.append( + shard_strategy.gen_shard_info( + shards=model_config.tensor_parallel_shards, + weight=param, + ) + ) + if shard_strategy.name not in extra_tirs: + extra_tirs[shard_strategy.name] = shard_strategy.gen_tir( + shards=model_config.tensor_parallel_shards, + weight=param, + ) + param.attrs["preprocs"] = preprocs + return extra_tirs + + +def _compile(args: CompileArgs, model_config: ConfigBase): + def _get_variable_bounds(model_config) -> Dict[str, int]: + if hasattr(model_config, "sliding_window_size"): + return { + "rolling_cache_len": model_config.sliding_window_size, + "kv_seq_len": model_config.sliding_window_size + model_config.prefill_chunk_size, + "seq_len": model_config.prefill_chunk_size, + "batch_size": getattr(model_config, "max_batch_size", 1), + } + return { + "total_seq_len": model_config.context_window_size, + "seq_len": model_config.prefill_chunk_size, + "batch_size": getattr(model_config, "max_batch_size", 1), + } + + def _get_param_metadata(name: str, param: nn.Parameter) -> Dict[str, Any]: + return { + "name": name, + # Record dynamic shape as -1 (e.g. 
vocab_size) + "shape": [s if isinstance(s, int) else s.name for s in param.shape], + "dtype": param.dtype, + "preprocs": param.attrs["preprocs"], + } + + def _find_kv_cache_bytes(model: nn.Module, model_config) -> int: + all_kv_cache = nn.core._attribute_finder( # pylint: disable=protected-access + model, + prefix="", + condition_yield=lambda x: isinstance(x, nn.KVCache), + ) + result = 0 + for _, kv_cache in all_kv_cache: + result += math.prod(kv_cache.unit_shape) * np.dtype(kv_cache.dtype).itemsize + if getattr(model_config, "sliding_window_size", -1) > 0: + window_size = model_config.sliding_window_size + elif getattr(model_config, "context_window_size", -1) > 0: + window_size = model_config.context_window_size + else: + window_size = 0 + return result * window_size + + model_config = args.overrides.apply(model_config) + with args.target: + op_ext.enable( + target=args.target, + flashinfer=args.opt.flashinfer, + faster_transformer=args.opt.faster_transformer, + ) + # Step 1. Create the quantized model + logger.info("Creating model from: %s", args.config) + if ( + args.quantization.kind == "ft-quant" + and hasattr(model_config, "tensor_parallel_shards") + and model_config.tensor_parallel_shards > 1 + ): + raise NotImplementedError + if ( + hasattr(args.quantization, "linear_weight_layout") + and args.quantization.linear_weight_layout == "KN" + and hasattr(model_config, "tensor_parallel_shards") + and model_config.tensor_parallel_shards > 1 + ): + raise NotImplementedError( + "KN layout (q3f16_0 and q4f16_0) is not supported for tensor parallelism" + ) + model, _ = args.model.quantize[args.quantization.kind](model_config, args.quantization) + kv_cache_bytes = _find_kv_cache_bytes(model, model_config) + # Step 2. Exporting the model to TVM Unity + logger.info("Exporting the model to TVM Unity compiler") + mod, named_params, ext_mods = model.export_tvm( + spec=model.get_default_spec(), # type: ignore + allow_extern=True, + ) + # Step 3. 
Running relax compilation pipeline + logger.info("Running optimizations using TVM Unity") + additional_tirs = _apply_preproc_to_params(named_params, model_config) + variable_bounds = _get_variable_bounds(model_config) + metadata = { + "model_type": args.model.name, + "quantization": args.quantization.name, + "context_window_size": getattr(model_config, "context_window_size", -1), + "sliding_window_size": getattr(model_config, "sliding_window_size", -1), + "attention_sink_size": getattr(model_config, "attention_sink_size", -1), + "prefill_chunk_size": model_config.prefill_chunk_size, # type: ignore + "tensor_parallel_shards": model_config.tensor_parallel_shards, # type: ignore + "kv_cache_bytes": kv_cache_bytes, + } + logger.info("Registering metadata: %s", metadata) + metadata["params"] = [_get_param_metadata(name, param) for name, param in named_params] + with PassContext(config={"relax.backend.use_cuda_graph": args.opt.cudagraph}): + args.build_func( + mod, + args, + pipeline=relax.get_pipeline( # type: ignore + "mlc_llm", + target=args.target, + flashinfer=args.opt.flashinfer, + cublas_gemm=args.opt.cublas_gemm, + faster_transformer=args.opt.faster_transformer, + variable_bounds=variable_bounds, + additional_tirs=additional_tirs, + ext_mods=ext_mods, + metadata=metadata, + debug_dump=args.debug_dump, + ), + ) + _report_memory_usage(metadata=metadata, config=model_config) + logger.info("Generated: %s", bold(str(args.output))) + + +def compile( # pylint: disable=too-many-arguments,redefined-builtin + config: Dict[str, Any], + quantization: Quantization, + model_type: Model, + target: Target, + opt: OptimizationFlags, + build_func: Callable[[IRModule, CompileArgs, Pass], None], + system_lib_prefix: str, + output: Path, + overrides: ModelConfigOverride, + debug_dump: Optional[Path] = None, +): + """Compile a model given its configuration and quantization format to a specific target.""" + if "model_config" in config: + model_config = config.pop("model_config") + model_config.update(config) + model_config = model_type.config.from_dict(model_config) + else: + model_config = model_type.config.from_dict(config) + model_config.kwargs = {} + args = CompileArgs( + model_config, + quantization, + model_type, + target, + opt, + build_func, + system_lib_prefix, + output, + overrides, + debug_dump, + ) + args.display() + _compile(args, model_config) diff --git a/python/mlc_chat/interface/compiler_flags.py b/python/mlc_chat/interface/compiler_flags.py new file mode 100644 index 0000000..7eeedaf --- /dev/null +++ b/python/mlc_chat/interface/compiler_flags.py @@ -0,0 +1,158 @@ +"""Flags for overriding model config.""" +import dataclasses +from io import StringIO +from typing import Optional + +from mlc_chat.support import argparse, logging +from mlc_chat.support.config import ConfigOverrideBase + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class OptimizationFlags: + """Optimization flags""" + + flashinfer: bool = False + cublas_gemm: bool = False + faster_transformer: bool = False + cudagraph: bool = False + + def __repr__(self) -> str: + out = StringIO() + print(f"flashinfer={int(self.flashinfer)}", file=out, end="") + print(f";cublas_gemm={int(self.cublas_gemm)}", file=out, end="") + print(f";faster_transformer={int(self.faster_transformer)}", file=out, end="") + print(f";cudagraph={int(self.cudagraph)}", file=out, end="") + return out.getvalue().rstrip() + + @staticmethod + def from_str(source: str) -> "OptimizationFlags": + """Parse optimization flags from a string.""" + + if 
source in OPT_FLAG_PRESET: + return OPT_FLAG_PRESET[source] + + def boolean(value: str) -> bool: + if value == "0": + return False + if value == "1": + return True + raise ValueError(f"Invalid boolean value: {value}") + + parser = argparse.ArgumentParser(description="optimization flags") + parser.add_argument("--flashinfer", type=boolean, default=True) + parser.add_argument("--cublas_gemm", type=boolean, default=False) + parser.add_argument("--faster_transformer", type=boolean, default=False) + parser.add_argument("--cudagraph", type=boolean, default=False) + results = parser.parse_args([f"--{i}" for i in source.split(";") if i]) + return OptimizationFlags( + flashinfer=results.flashinfer, + cublas_gemm=results.cublas_gemm, + faster_transformer=results.faster_transformer, + cudagraph=results.cudagraph, + ) + + def update(self, target, quantization) -> None: + """Update optimization flags based on additional information.""" + + def _flashinfer(target) -> bool: + from mlc_chat.support.auto_target import ( # pylint: disable=import-outside-toplevel + detect_cuda_arch_list, + ) + + if not self.flashinfer: + return False + if target.kind.name != "cuda": + return False + arch_list = detect_cuda_arch_list(target) + for arch in arch_list: + if arch < 80: + logger.warning("flashinfer is not supported on CUDA arch < 80") + return False + return True + + def _cublas_gemm(target, quantization) -> bool: + """correct cublas_gemm flag""" + if not (target.kind.name == "cuda" and quantization.name in ["q0f16", "q0f32"]): + return False + return self.cublas_gemm + + def _faster_transformer(target) -> bool: + """correct faster_transformer flag""" + if not target.kind.name == "cuda": + return False + return self.faster_transformer + + self.flashinfer = _flashinfer(target) + self.cublas_gemm = _cublas_gemm(target, quantization) + self.faster_transformer = _faster_transformer(target) + + +@dataclasses.dataclass +class ModelConfigOverride(ConfigOverrideBase): + """Flags for overriding model config.""" + + context_window_size: Optional[int] = None + sliding_window_size: Optional[int] = None + prefill_chunk_size: Optional[int] = None + attention_sink_size: Optional[int] = None + max_batch_size: Optional[int] = None + tensor_parallel_shards: Optional[int] = None + + def __repr__(self) -> str: + out = StringIO() + print(f"context_window_size={self.context_window_size}", file=out, end="") + print(f";sliding_window_size={self.sliding_window_size}", file=out, end="") + print(f";prefill_chunk_size={self.prefill_chunk_size}", file=out, end="") + print(f";attention_sink_size={self.attention_sink_size}", file=out, end="") + print(f";max_batch_size={self.max_batch_size}", file=out, end="") + print(f";tensor_parallel_shards={self.tensor_parallel_shards}", file=out, end="") + return out.getvalue().rstrip() + + @staticmethod + def from_str(source: str) -> "ModelConfigOverride": + """Parse model config override values from a string.""" + parser = argparse.ArgumentParser(description="model config override values") + parser.add_argument("--context_window_size", type=int, default=None) + parser.add_argument("--sliding_window_size", type=int, default=None) + parser.add_argument("--prefill_chunk_size", type=int, default=None) + parser.add_argument("--attention_sink_size", type=int, default=None) + parser.add_argument("--max_batch_size", type=int, default=None) + parser.add_argument("--tensor_parallel_shards", type=int, default=None) + results = parser.parse_args([f"--{i}" for i in source.split(";") if i]) + return 
ModelConfigOverride( + context_window_size=results.context_window_size, + sliding_window_size=results.sliding_window_size, + prefill_chunk_size=results.prefill_chunk_size, + attention_sink_size=results.attention_sink_size, + max_batch_size=results.max_batch_size, + tensor_parallel_shards=results.tensor_parallel_shards, + ) + + +OPT_FLAG_PRESET = { + "O0": OptimizationFlags( + flashinfer=False, + cublas_gemm=False, + cudagraph=False, + ), + "O1": OptimizationFlags( + flashinfer=False, + cublas_gemm=True, + faster_transformer=True, + cudagraph=False, + ), + "O2": OptimizationFlags( + flashinfer=True, + cublas_gemm=True, + faster_transformer=True, + cudagraph=False, + ), + "O3": OptimizationFlags( + flashinfer=True, + cublas_gemm=True, + faster_transformer=True, + cudagraph=True, + ), +} diff --git a/python/mlc_chat/interface/convert_weight.py b/python/mlc_chat/interface/convert_weight.py new file mode 100644 index 0000000..1e28417 --- /dev/null +++ b/python/mlc_chat/interface/convert_weight.py @@ -0,0 +1,169 @@ +"""Python entrypoint of weight conversion.""" + +import dataclasses +import math +import os +from io import StringIO +from pathlib import Path + +import numpy as np +from tvm import tir +from tvm.contrib import tvmjs +from tvm.runtime import Device, NDArray +from tvm.runtime import cpu as cpu_device +from tvm.target import Target + +from mlc_chat.loader import LOADER +from mlc_chat.model import Model +from mlc_chat.quantization import Quantization +from mlc_chat.support import logging, tqdm +from mlc_chat.support.preshard import apply_preshard +from mlc_chat.support.style import bold, green + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class ConversionArgs: # pylint: disable=too-many-instance-attributes + """Arguments to MLC LLM's weight conversation and quantization flow.""" + + config: Path + quantization: Quantization + model: Model + device: Device + source: Path + source_format: str + output: Path + + def display(self) -> None: + """Display the arguments to stdout.""" + + def _device_to_str(device: Device) -> str: + return f"{Device.MASK2STR[device.device_type]}:{device.device_id}" + + out = StringIO() + print(f"{bold('Weight conversion with arguments:')}", file=out) + print(f" {bold('--config'):<25} {self.config}", file=out) + print(f" {bold('--quantization'):<25} {self.quantization}", file=out) + print(f" {bold('--model-type'):<25} {self.model.name}", file=out) + print(f" {bold('--device'):<25} {_device_to_str(self.device)}", file=out) + print(f" {bold('--source'):<25} {self.source}", file=out) + print(f" {bold('--source-format'):<25} {self.source_format}", file=out) + print(f" {bold('--output'):<25} {self.output}", file=out) + print(out.getvalue().rstrip()) + + +def _convert_args(args: ConversionArgs) -> None: # pylint: disable=too-many-locals + pre_shards_num = os.getenv("MLC_INTERNAL_PRESHARD_NUM") + # model config & quantization config + model_config = args.model.config.from_file(args.config) + if ( + args.quantization.kind == "ft-quant" + and hasattr(model_config, "tensor_parallel_shards") + and model_config.tensor_parallel_shards > 1 + ): + raise NotImplementedError + if pre_shards_num is not None: + model_config.tensor_parallel_shards = int(pre_shards_num) + model, quantize_map = args.model.quantize[args.quantization.kind]( + model_config, args.quantization + ) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), # type: ignore[attr-defined] + allow_extern=True, + ) + named_params = dict(_named_params) + 
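    # When MLC_INTERNAL_PRESHARD_NUM is set, pre-shard the weights at conversion time so
    # that each tensor-parallel worker can later load only its own slice; the returned
    # preshard_funcs are applied while the loader streams parameters below.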
+ if pre_shards_num is not None: + preshard_funcs = apply_preshard(quantize_map, named_params, int(pre_shards_num), args) + else: + preshard_funcs = None + + def _check_param(name: str, param: NDArray): + nonlocal named_params + if name not in named_params: + raise ValueError(f"Parameter not found in model: {name}") + if name in param_dict: + raise ValueError(f"Duplication: Parameter {name} already computed") + + # Check shape (possibly dynamic) + def _check_shape(actual: tuple, expect: tuple): # expect can have tir.Var + if len(actual) != len(expect): + return False + for actual_i, expect_i in zip(actual, expect): + assert isinstance(expect_i, (int, tir.Var)) + if isinstance(expect_i, int) and actual_i != expect_i: + return False + return True + + expect_shape = named_params[name].shape + actual_shape = param.shape + if not _check_shape(actual_shape, expect_shape): + raise ValueError( + f"Parameter {name} has shape {param.shape}, but expected {expect_shape}" + ) + # Check dtype + actual_dtype = param.dtype + expect_dtype = named_params[name].dtype + if actual_dtype != expect_dtype: + raise ValueError( + f"Parameter {name} has dtype {param.dtype}, but expected {expect_dtype}" + ) + del named_params[name] + + # load and quantize + param_dict = {} + total_bytes = 0.0 + with Target.from_device(args.device), tqdm.redirect(): + loader = LOADER[args.source_format]( + path=args.source, + extern_param_map=args.model.source[args.source_format](model_config, args.quantization), + quantize_param_map=quantize_map, + ) + for name, param in loader.load(device=args.device, preshard_funcs=preshard_funcs): + _check_param(name, param) + param = param.copyto(cpu_device()) + param_dict[name] = param + total_bytes += math.prod(param.shape) * np.dtype(param.dtype).itemsize + total_params = loader.stats.total_param_num + if named_params: + raise ValueError(f"Parameter not found in source: {', '.join(named_params.keys())}") + # Log necessary statistics + logger.info( + "%s after quantization: %.3f GB", + green("Parameter size"), + total_bytes / (1024**3), + ) + logger.info(f"%s: {total_params:,}", green("Total parameters")) + logger.info( + "%s: %.3f", + green("Bits per parameter"), + total_bytes * 8.0 / total_params, + ) + # dump to output directory + tvmjs.dump_ndarray_cache( + param_dict, + str(args.output), + meta_data={ + "ParamSize": len(param_dict), + "ParamBytes": total_bytes, + "BitsPerParam": total_bytes * 8.0 / total_params, + }, + encode_format="f32-to-bf16", + ) + logger.info("Saved to directory: %s", bold(str(args.output))) + + +def convert_weight( # pylint: disable=too-many-arguments + config: Path, + quantization: Quantization, + model: Model, + device: Device, + source: Path, + source_format: str, + output: Path, +): + """MLC LLM's weight conversation and quantization flow.""" + args = ConversionArgs(config, quantization, model, device, source, source_format, output) + args.display() + _convert_args(args) diff --git a/python/mlc_chat/interface/gen_config.py b/python/mlc_chat/interface/gen_config.py new file mode 100644 index 0000000..35592db --- /dev/null +++ b/python/mlc_chat/interface/gen_config.py @@ -0,0 +1,232 @@ +"""Generator of mlc-chat-config.json and tokenizer configuration.""" + +import dataclasses +import json +import shutil +from pathlib import Path +from typing import Any, Dict, List, Optional + +from mlc_chat.model import Model +from mlc_chat.quantization import Quantization +from mlc_chat.support import convert_tiktoken, logging +from mlc_chat.support.style import bold, green, red 
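# Illustrative CLI sketch (subcommand and flag spellings are assumptions; check
# `python -m mlc_chat gen_config --help` for the authoritative interface):
#   python -m mlc_chat gen_config ./path/to/config.json --quantization q4f16_0 \
#       --conv-template llama-2 --output ./dist/model-q4f16_0-MLC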
+ +from .compiler_flags import ModelConfigOverride + +logger = logging.getLogger(__name__) + +FOUND = green("Found") +NOT_FOUND = red("Not found") +FAILED = red("Failed") +VERSION = "0.1.0" + + +@dataclasses.dataclass +class MLCChatConfig: # pylint: disable=too-many-instance-attributes + """Fields in the dumped `mlc-chat-config.json` file.""" + + model_type: str + quantization: str + model_config: Dict[str, Any] + vocab_size: int + context_window_size: int + sliding_window_size: int + prefill_chunk_size: int + attention_sink_size: int + tensor_parallel_shards: int + # Control the behavior of the runtime + mean_gen_len: int = None + max_gen_len: int = None + shift_fill_factor: float = None + # Configuration of text generation + temperature: float = None + presence_penalty: float = None + frequency_penalty: float = None + repetition_penalty: float = None + top_p: float = None + # Conversation template + conv_template: str = None + pad_token_id: int = None + bos_token_id: int = None + eos_token_id: int = None + tokenizer_files: List[str] = dataclasses.field(default_factory=list) + # Version control + version: str = VERSION + + def apply_defaults(self) -> None: + """Apply system default value.""" + defaults = { + "pad_token_id": 0, + "bos_token_id": 1, + "eos_token_id": 2, + "temperature": 0.7, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "repetition_penalty": 1.0, + "top_p": 0.95, + "mean_gen_len": 128, + "max_gen_len": 512, + "shift_fill_factor": 0.3, + } + for key, value in defaults.items(): + if getattr(self, key) is None: + setattr(self, key, value) + logger.info("[System default] Setting %s: %s", bold(key), value) + + +def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-branches,too-many-statements + config: Path, + model: Model, + quantization: Quantization, + conv_template: str, + context_window_size: Optional[int], + sliding_window_size: Optional[int], + prefill_chunk_size: Optional[int], + attention_sink_size: Optional[int], + tensor_parallel_shards: Optional[int], + max_batch_size: int, + output: Path, +): + """Entrypoint of MLC Chat configuration generation.""" + # Step 1. Initialize `mlc-chat-config.json` using `config.json` + model_config = ModelConfigOverride( + context_window_size=context_window_size, + sliding_window_size=sliding_window_size, + prefill_chunk_size=prefill_chunk_size, + attention_sink_size=attention_sink_size, + max_batch_size=max_batch_size, + tensor_parallel_shards=tensor_parallel_shards, + ).apply(model.config.from_file(config)) + mlc_chat_config = MLCChatConfig( + model_type=model.name, + quantization=quantization.name, + model_config=model_config.asdict(), + vocab_size=model_config.vocab_size, + context_window_size=getattr(model_config, "context_window_size", -1), + sliding_window_size=getattr(model_config, "sliding_window_size", -1), + prefill_chunk_size=model_config.prefill_chunk_size, + attention_sink_size=getattr(model_config, "attention_sink_size", -1), + tensor_parallel_shards=model_config.tensor_parallel_shards, + conv_template=conv_template, + ) + # Step 2. 
Load `generation_config.json` and `config.json` for text-generation related configs + for generation_config_filename in ["generation_config.json", "config.json"]: + generation_config = config.parent / generation_config_filename + if generation_config.exists(): + with generation_config.open("r", encoding="utf-8") as in_file: + generation_config_json = json.load(in_file) + for key, value in generation_config_json.items(): + if hasattr(mlc_chat_config, key) and getattr(mlc_chat_config, key) is None: + setattr(mlc_chat_config, key, value) + logger.info("[%s] Setting %s: %s", generation_config_filename, bold(key), value) + else: + logger.info("%s %s: %s", NOT_FOUND, generation_config_filename, generation_config) + + # Step 3. Copy tokenizer configuration + # 3.1. Copy over the files and populate mlc_chat_config + for filename in TOKENIZER_FILES: + file = config.parent / filename + if file.exists(): + mlc_chat_config.tokenizer_files.append(filename) + dest = output / filename + shutil.copy(file, dest) + logger.info("%s tokenizer config: %s. Copying to %s", FOUND, file, bold(str(dest))) + else: + logger.info("%s tokenizer config: %s", NOT_FOUND, file) + # 3.2. If we have `tokenizer.model` but not `tokenizer.json`, try convert it to + # `tokenizer.json` with `transformers`. + tokenizer_json_file = config.parent / "tokenizer.json" + tokenizer_model_file = config.parent / "tokenizer.model" + if tokenizer_model_file.exists() and (not tokenizer_json_file.exists()): + logger.info( + "The model has `tokenizer.model` but not `tokenizer.json`. " + "It is always recommended to prefer JSON instead. " + "Attempting to convert using HuggingFace transformers library" + ) + try: + from transformers import ( # pylint: disable=import-error,import-outside-toplevel + AutoTokenizer, + ) + + tokenizer_json_save_dest = output / "tokenizer.json" + fast_tokenizer = AutoTokenizer.from_pretrained(str(config.parent), use_fast=True) + fast_tokenizer.backend_tokenizer.save(str(tokenizer_json_save_dest)) + mlc_chat_config.tokenizer_files.append("tokenizer.json") + logger.info("Succesfully converted `tokenizer.model` to: %s", tokenizer_json_save_dest) + except Exception: # pylint: disable=broad-exception-caught + logger.warning( + "Convertion to `tokenizer.json` %s with the exception below. " + "Skipping the conversion. Tokenizer will only use `tokenizer.model`", + FAILED, + exc_info=True, + ) + # 3.3. If we still don't have "tokenizer.json" at this point, try looking for "*.tiktoken" files + if (not tokenizer_json_file.exists()) and list(config.parent.glob("*.tiktoken")): + try: + logger.info( + "The model has tiktoken files but not `tokenizer.json`. " + "Attempting to convert from tiktoken files" + ) + convert_tiktoken.convert_tiktoken( + str(config.parent), str(output), mlc_chat_config.context_window_size + ) + mlc_chat_config.tokenizer_files.append("tokenizer.json") + mlc_chat_config.tokenizer_files.append("vocab.json") + mlc_chat_config.tokenizer_files.append("merges.txt") + mlc_chat_config.tokenizer_files.append("special_tokens_map.json") + logger.info("Succesfully converted from tiktoken files to: %s", str(output)) + except Exception: # pylint: disable=broad-exception-caught + logger.exception("%s with the exception below. Skipping", FAILED) + + # Step 4. Load system default value + mlc_chat_config.apply_defaults() + # Step 5. 
Dump the configuration file to output directory + with (output / "mlc-chat-config.json").open("w", encoding="utf-8") as out_file: + json.dump(dataclasses.asdict(mlc_chat_config), out_file, indent=2) + logger.info("Dumping configuration file to: %s", bold(out_file.name)) + + +TOKENIZER_FILES = [ + "tokenizer.model", + "tokenizer.json", + "vocab.json", + "merges.txt", + "added_tokens.json", + "tokenizer_config.json", +] + +CONV_TEMPLATES = { + "chatml", + "open_hermes_mistral", + "neural_hermes_mistral", + "llama_default", + "llama-2", + "mistral_default", + "gpt2", + "codellama_completion", + "codellama_instruct", + "vicuna_v1.1", + "conv_one_shot", + "redpajama_chat", + "rwkv_world", + "rwkv", + "gorilla", + "guanaco", + "dolly", + "oasst", + "stablelm", + "stablecode_completion", + "stablecode_instruct", + "minigpt", + "moss", + "LM", + "stablelm-3b", + "gpt_bigcode", + "wizardlm_7b", + "wizard_coder_or_math", + "glm", + "custom", # for web-llm only + "phi-2", + "stablelm-2", + "gemma_instruction", +} diff --git a/python/mlc_chat/interface/jit.py b/python/mlc_chat/interface/jit.py new file mode 100644 index 0000000..6d9b131 --- /dev/null +++ b/python/mlc_chat/interface/jit.py @@ -0,0 +1,128 @@ +"""Just-in-time compilation of MLC-Chat models.""" +import dataclasses +import hashlib +import json +import os +import shlex +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path +from typing import Any, Dict + +from tvm.runtime import Device + +from mlc_chat.model import MODELS +from mlc_chat.support import logging +from mlc_chat.support.auto_device import device2str +from mlc_chat.support.constants import ( + MLC_CACHE_DIR, + MLC_DSO_SUFFIX, + MLC_JIT_POLICY, + MLC_TEMP_DIR, +) +from mlc_chat.support.style import blue, bold + +from .compiler_flags import ModelConfigOverride, OptimizationFlags + +logger = logging.getLogger(__name__) + + +def jit(model_path: Path, chat_config: Dict[str, Any], device: Device) -> Path: + """Just-in-time compile a MLC-Chat model.""" + logger.info( + "%s = %s. 
Can be one of: ON, OFF, REDO, READONLY", + bold("MLC_JIT_POLICY"), + MLC_JIT_POLICY, + ) + if MLC_JIT_POLICY == "OFF": + raise RuntimeError("JIT is disabled by MLC_JIT_POLICY=OFF") + + with open(model_path / "mlc-chat-config.json", "r", encoding="utf-8") as in_file: + mlc_chat_config = json.load(in_file) + model_type = mlc_chat_config.pop("model_type") + quantization = mlc_chat_config.pop("quantization") + + def _get_optimization_flags() -> str: + opt = chat_config.pop("opt", None) + if opt is None: + opt = "O2" + return repr(OptimizationFlags.from_str(opt)) + + def _get_overrides() -> str: + forbid_list = ["context_window_size", "sliding_window_size", "attention_sink_size"] + result = [] + for field in dataclasses.fields(ModelConfigOverride): + value = chat_config.get(field.name, None) + if value is not None: + if field.name in forbid_list and value == -1: + continue + result.append(f"{field.name}={value}") + if not result: + result = ["tensor_parallel_shards=1"] + return ";".join(result) + + def _get_model_config() -> Dict[str, Any]: + model_config = mlc_chat_config.pop("model_config") + model_config.update(mlc_chat_config) + for field in dataclasses.fields(ModelConfigOverride): + value = chat_config.get(field.name, None) + if value is not None: + model_config[field.name] = value + return MODELS[model_type].config.from_dict(model_config).asdict() + + def _run_jit(opt: str, overrides: str, device: str, dst: str): + with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir: + dso_path = os.path.join(tmp_dir, f"lib.{MLC_DSO_SUFFIX}") + cmd = [ + sys.executable, + "-m", + "mlc_chat", + "compile", + str(model_path), + "--opt", + opt, + "--overrides", + overrides, + "--device", + device, + "--output", + dso_path, + ] + logger.info("Compiling using commands below:") + logger.info("%s", blue(shlex.join(cmd))) + subprocess.run(cmd, check=True) + shutil.move(dso_path, dst) + logger.info("Using compiled model lib: %s", bold(dst)) + + hash_key = { + "model_config": _get_model_config(), + "overrides": _get_overrides(), + "opt": _get_optimization_flags(), + "device": device2str(device), + "model_type": model_type, + "quantization": quantization, + } + hash_value = hashlib.md5( + json.dumps( + hash_key, + sort_keys=True, + indent=2, + ).encode("utf-8") + ).hexdigest() + dst = MLC_CACHE_DIR / "model_lib" / f"{hash_value}.so" + if dst.is_file() and MLC_JIT_POLICY in ["ON", "READONLY"]: + logger.info("Using cached model lib: %s", bold(str(dst))) + return dst + if MLC_JIT_POLICY == "READONLY": + raise RuntimeError( + "No cached model lib found, and JIT is disabled by MLC_JIT_POLICY=READONLY" + ) + _run_jit( + opt=hash_key["opt"], + overrides=hash_key["overrides"], + device=hash_key["device"], + dst=str(dst), + ) + return dst diff --git a/python/mlc_chat/interface/openai_api.py b/python/mlc_chat/interface/openai_api.py new file mode 100644 index 0000000..7c7797d --- /dev/null +++ b/python/mlc_chat/interface/openai_api.py @@ -0,0 +1,183 @@ +# pylint: disable=missing-docstring,fixme,too-few-public-methods +""" +Adapted from FastChat's OpenAI protocol: +https://github.com/lm-sys/FastChat/blob/main/fastchat/protocol/openai_api_protocol.py +""" + +import time +from typing import Any, Dict, List, Literal, Optional, Union + +import shortuuid +from pydantic import BaseModel, Field + + +class ToolCalls(BaseModel): + id: str = Field(default_factory=lambda: f"call_{shortuuid.random()}") + type: str = "function" + function: object + + +class ChatMessage(BaseModel): + role: str + content: Union[str, None] + name: 
Optional[str] = None + tool_calls: Optional[List[ToolCalls]] = None + + +class Function(BaseModel): + description: Optional[str] = None + name: str + parameters: object + + +class Tools(BaseModel): + type: Literal["function"] + function: Dict[str, Any] + + +class ToolChoice(BaseModel): + type: Literal["function"] + function: Dict[str, Any] + + +class ChatCompletionRequest(BaseModel): + model: str + messages: List[ChatMessage] + stream: Optional[bool] = False + temperature: float = None + top_p: float = None + # TODO: replace by presence_penalty and frequency_penalty + repetition_penalty: float = None + mean_gen_len: int = None + # TODO: replace by max_tokens + max_gen_len: int = None + presence_penalty: float = None + frequency_penalty: float = None + n: int = None + stop: Union[str, List[str]] = None + tools: Optional[List[Tools]] = None + tool_choice: Union[Literal["none", "auto"], ToolChoice] = "auto" + # TODO: Implement support for the OpenAI API parameters + # stop: Optional[Union[str, List[str]]] = None + # max_tokens: Optional[int] + # logit_bias + # user: Optional[str] = None + + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + completion_tokens: Optional[int] = 0 + total_tokens: int = 0 + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatMessage + finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None + + +class ChatCompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") + object: str = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + choices: List[ChatCompletionResponseChoice] + # TODO: Implement support for the following fields + usage: Optional[UsageInfo] = None + + +class DeltaMessage(BaseModel): + role: Optional[str] = None + content: Optional[str] = None + + +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + delta: DeltaMessage + finish_reason: Optional[Literal["stop", "length"]] = None + + +class ChatCompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") + object: str = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + choices: List[ChatCompletionResponseStreamChoice] + + +class CompletionRequest(BaseModel): + model: str + prompt: Union[str, List[str]] + stream: Optional[bool] = False + temperature: float = None + repetition_penalty: float = None + top_p: float = None + mean_gen_len: int = None + # TODO: replace by max_tokens + max_gen_len: int = None + presence_penalty: float = None + frequency_penalty: float = None + n: int = None + stop: Union[str, List[str]] = None + # TODO: Implement support for the OpenAI API parameters + # suffix + # logprobs + # echo + # best_of + # logit_bias + # user: Optional[str] = None + + +class CompletionResponseChoice(BaseModel): + index: int + text: str + finish_reason: Optional[Literal["stop", "length"]] = None + # TODO: logprobs support + logprobs: Optional[int] = None + + +class CompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") + object: str = "text.completion" + created: int = Field(default_factory=lambda: int(time.time())) + choices: List[CompletionResponseChoice] + usage: UsageInfo + + +class CompletionResponseStreamChoice(BaseModel): + index: int + text: str + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: 
f"cmpl-{shortuuid.random()}") + object: str = "text.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + choices: List[CompletionResponseStreamChoice] + + +class EmbeddingsRequest(BaseModel): + model: Optional[str] = None + input: Union[str, List[Any]] + user: Optional[str] = None + + +class EmbeddingsResponse(BaseModel): + object: str = "list" + data: List[Dict[str, Any]] + model: Optional[str] = None + usage: UsageInfo + + +class VisualStudioCodeCompletionParameters(BaseModel): + temperature: float = None + top_p: float = None + max_new_tokens: int = None + + +class VisualStudioCodeCompletionRequest(BaseModel): + inputs: str + parameters: VisualStudioCodeCompletionParameters + + +class VisualStudioCodeCompletionResponse(BaseModel): + generated_text: str diff --git a/python/mlc_chat/libinfo.py b/python/mlc_chat/libinfo.py new file mode 100644 index 0000000..4c36cab --- /dev/null +++ b/python/mlc_chat/libinfo.py @@ -0,0 +1,70 @@ +"""Library information. This is a standalone file that can be used to get various info""" +#! pylint: disable=protected-access +import os +import sys + +__version__ = "0.1.dev0" +MLC_LIBRARY_PATH = os.environ.get("MLC_LIBRARY_PATH", None) + + +def get_env_paths(env_var, splitter): + """Get path in env variable""" + if os.environ.get(env_var, None): + return [p.strip() for p in os.environ[env_var].split(splitter)] + return [] + + +def get_dll_directories(): + """Get extra mlc llm dll directories""" + curr_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) + source_dir = os.path.abspath(os.path.join(curr_dir, "..", "..")) + dll_path = [ + curr_dir, + os.path.join(source_dir, "build"), + os.path.join(source_dir, "build", "Release"), + ] + if MLC_LIBRARY_PATH: + dll_path.append(MLC_LIBRARY_PATH) + if "CONDA_PREFIX" in os.environ: + dll_path.append(os.path.join(os.environ["CONDA_PREFIX"], "lib")) + if sys.platform.startswith("linux") or sys.platform.startswith("freebsd"): + dll_path.extend(get_env_paths("LD_LIBRARY_PATH", ":")) + elif sys.platform.startswith("darwin"): + dll_path.extend(get_env_paths("DYLD_LIBRARY_PATH", ":")) + elif sys.platform.startswith("win32"): + dll_path.extend(get_env_paths("PATH", ";")) + return [os.path.abspath(p) for p in dll_path if os.path.isdir(p)] + + +def find_lib_path(name, optional=False): + """Find mlc llm library + + Parameters + ---------- + name : str + The name of the library + + optional: boolean + Whether the library is required + """ + if sys.platform.startswith("linux") or sys.platform.startswith("freebsd"): + lib_name = f"lib{name}.so" + elif sys.platform.startswith("win32"): + lib_name = f"{name}.dll" + elif sys.platform.startswith("darwin"): + lib_name = f"lib{name}.dylib" + else: + lib_name = f"lib{name}.so" + + dll_paths = get_dll_directories() + lib_dll_path = [os.path.join(p, lib_name) for p in dll_paths] + lib_found = [p for p in lib_dll_path if os.path.exists(p) and os.path.isfile(p)] + if not lib_found: + if not optional: + message = ( + f"Cannot find libraries: {lib_name}\n" + + "List of candidates:\n" + + "\n".join(lib_dll_path) + ) + raise RuntimeError(message) + return lib_found diff --git a/python/mlc_chat/loader/__init__.py b/python/mlc_chat/loader/__init__.py new file mode 100644 index 0000000..cc8ba9c --- /dev/null +++ b/python/mlc_chat/loader/__init__.py @@ -0,0 +1,7 @@ +""" +A subpackage of the compiler that represents mapping between external parameters, quantized +parameters and parameters in MLC-defined models. 
+""" +from .huggingface_loader import HuggingFaceLoader +from .loader import LOADER, Loader +from .mapping import ExternMapping, QuantizeMapping diff --git a/python/mlc_chat/loader/huggingface_loader.py b/python/mlc_chat/loader/huggingface_loader.py new file mode 100644 index 0000000..5334242 --- /dev/null +++ b/python/mlc_chat/loader/huggingface_loader.py @@ -0,0 +1,222 @@ +"""A weight loader for HuggingFace's PyTorch format""" +import gc +import json +from collections import OrderedDict, defaultdict +from pathlib import Path +from typing import Callable, Dict, Iterator, List, Optional, Tuple + +import numpy as np +from tqdm import tqdm +from tvm.runtime import Device, NDArray +from tvm.runtime.ndarray import array as as_ndarray + +from mlc_chat.support import logging +from mlc_chat.support.preshard import _sharded_param_name +from mlc_chat.support.style import bold + +from .mapping import ExternMapping, QuantizeMapping +from .stats import Stats +from .utils import check_parameter_usage, load_safetensor_shard, load_torch_shard + +logger = logging.getLogger(__name__) + + +class HuggingFaceLoader: # pylint: disable=too-few-public-methods + """A loader loading HuggingFace's PyTorch/SafeTensor format and converts them + to MLC's parameters. + + Attributes + ---------- + stats : Stats + Statistics of the loading process. + + extern_param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch/SafeTensor. + + torch_to_path : Dict[str, Path] + A mapping from PyTorch/SafeTensor parameter name to the path of the file containing it, + or the path meaning all parameters are stored in a single file. + + cached_files : Dict[Path, Dict[str, np.ndarray]] + A cache of the loaded files. The key is the path of the file, and the value is a mapping + from parameter name to the parameter value. + + quantize_param_map : Optional[QuantizeMapping] + The quantization mapping from MLC to quantized MLC parameters. + """ + + stats: Stats + cached_files: Dict[Path, Dict[str, np.ndarray]] + torch_to_path: Dict[str, Path] + extern_param_map: ExternMapping + quantize_param_map: Optional[QuantizeMapping] + + def __init__( + self, + path: Path, + extern_param_map: ExternMapping, + quantize_param_map: Optional[QuantizeMapping] = None, + ) -> None: + """Create a parameter loader from HuggingFace PyTorch format. + + Parameters + ---------- + path : pathlib.Path + Path to either a JSON indexing file, or a PyTorch bin file. + 1) For JSON indexing file, it is usually `pytorch_model.bin.index.json` + or `model.safetensors.index.json` in the repo, which contains a `weight_map` that + maps each PyTorch parameter to the file containing the weight. + 2) For PyTorch bin file, it is usually `pytorch_model.bin` in the repo, + which contains all the parameters. + 3) For safetensor file, it is usually `model.safetensors` in the repo, + which contains all the parameters. + + extern_param_map : ExternMapping + Maps an MLC parameter to a list of PyTorch/SafeTensor parameters. + + quantize_param_map: Optional[QuantizeMapping] + The quantization mapping from MLC to quantized MLC parameters, default to None, which + means no quantization. 
+ """ + assert path.is_file(), f"Path {path} is not a file" + self.stats = Stats() + self.extern_param_map = extern_param_map + self.cached_files = {} + self.torch_to_path = {} + self.quantize_param_map = quantize_param_map + if path.suffix in (".bin", ".safetensors", ".pt"): + self._load_file(path) + for name in self.cached_files[path].keys(): + self.torch_to_path[name] = path + elif path.suffix == ".json": + with path.open("r", encoding="utf-8") as in_file: + torch_weight_map = json.load(in_file)["weight_map"] + for torch_name, path_str in torch_weight_map.items(): + self.torch_to_path[torch_name] = path.parent / path_str + else: + raise FileNotFoundError(f"Unknown file suffix: {path}") + check_parameter_usage(extern_param_map, set(self.torch_to_path.keys())) + + def load( + self, device: Device, preshard_funcs: Dict[str, Callable] = None + ) -> Iterator[Tuple[str, NDArray]]: + """Load the parameters and yield the MLC parameter and its value. + + Parameters + ---------- + device : Optional[Device] + The device to store the parameter, default to None, which means using CPU. + + Yields + ------ + Tuple[str, NDArray] + The MLC parameter name and its value, quantized if quantization mapping is provided. + """ + mlc_names = _loading_order(self.extern_param_map, self.torch_to_path) + for mlc_name in tqdm(mlc_names): + param = self._load_mlc_param(mlc_name, device=device) + if preshard_funcs is not None and mlc_name in preshard_funcs: + sharded_params = preshard_funcs[mlc_name](param) + for i, sharded_param in enumerate(sharded_params): + sharded_name = _sharded_param_name(mlc_name, i) + yield from self._load_or_quantize(sharded_name, sharded_param, device) + else: + yield from self._load_or_quantize(mlc_name, param, device) + + cached_files = list(self.cached_files.keys()) + for path in cached_files: + self._unload_file(path) + self.stats.log_time_info("HF") + self.stats.log_mem_usage() + + def _load_mlc_param(self, mlc_name: str, device: Optional[Device]) -> NDArray: + torch_names = self.extern_param_map.param_map[mlc_name] + files_required = {self.torch_to_path[p] for p in torch_names} + files_existing = set(self.cached_files.keys()) + files_to_load = files_required - files_existing + files_to_unload = files_existing - files_required + + # Step 1. When there is some file to unloaded: + # - If no pending file load: unloading is deferred as there is no gain in peak memory usage; + # - Need to load files: unload immediately to save memory and make space for the new files. + if files_to_load: + for path in files_to_unload: + self._unload_file(path) + # Step 2. Load all the files needed + for path in files_to_load: + self._load_file(path) + # Step 3. Collect all torch parameters in order + torch_params = [self.cached_files[self.torch_to_path[i]][i] for i in torch_names] + # Step 4. 
Apply the mapping function + with self.stats.timer("map_time_sec"): + param = self.extern_param_map.map_func[mlc_name](*torch_params) + if device: + return as_ndarray(param, device=device) + return as_ndarray(param) + + def _load_or_quantize(self, mlc_name, param, device: Device): + if self.quantize_param_map and mlc_name in self.quantize_param_map.param_map: + with self.stats.timer("quant_time_sec"): + q_names = self.quantize_param_map.param_map[mlc_name] + q_params = self.quantize_param_map.map_func[mlc_name](param) + device.sync() + for q_name, q_param in zip(q_names, q_params): + logger.info( + '[Quantized] Parameter: "%s", shape: %s, dtype: %s', + bold(q_name), + q_param.shape, + q_param.dtype, + ) + yield q_name, q_param + else: + logger.info( + '[Not quantized] Parameter: "%s", shape: %s, dtype: %s', + bold(mlc_name), + param.shape, + param.dtype, + ) + device.sync() + yield mlc_name, param + + def _load_file(self, path: Path) -> None: + logger.info("Loading HF parameters from: %s", path) + load_func = load_safetensor_shard if path.suffix == ".safetensors" else load_torch_shard + with self.stats.timer("load_time_sec"): + result = {} + for name, param in load_func(path): + result[name] = param + self.stats.mem_add(param.nbytes) + if name not in self.extern_param_map.unused_params: + self.stats.total_param_num += param.size + self.cached_files[path] = result + + def _unload_file(self, path: Path) -> None: + logger.info("Unloading HF weight file: %s", path) + with self.stats.timer("load_time_sec"): + for _, param in self.cached_files[path].items(): + self.stats.mem_rm(param.nbytes) + del self.cached_files[path] + gc.collect() + + +def _loading_order(param_map: ExternMapping, torch_to_path: Dict[str, Path]) -> List[str]: + # Step 1. Build a map from path to torch parameters + path_to_torch: Dict[Path, List[str]] = defaultdict(list) + for torch_name, path in torch_to_path.items(): + path_to_torch[path].append(torch_name) + # Step 2. Build a map from torch parameters to MLC parameters + torch_to_mlc = defaultdict(list) + for mlc_name, torch_names in param_map.param_map.items(): + for torch_name in torch_names: + torch_to_mlc[torch_name].append(mlc_name) + # Step 3. 
Construct the ordering that ensures file locality + order = OrderedDict() + for _, torch_names in path_to_torch.items(): + for torch_name in torch_names: + for mlc_name in torch_to_mlc[torch_name]: + if mlc_name not in order: + order[mlc_name] = 1 + return list(order.keys()) + + +__all__ = ["HuggingFaceLoader"] diff --git a/python/mlc_chat/loader/loader.py b/python/mlc_chat/loader/loader.py new file mode 100644 index 0000000..e4c397c --- /dev/null +++ b/python/mlc_chat/loader/loader.py @@ -0,0 +1,12 @@ +"""A centralized registry of all existing loaders.""" +from typing import Any, Dict + +from .huggingface_loader import HuggingFaceLoader + +Loader = Any + +LOADER: Dict[str, Any] = { + "huggingface-torch": HuggingFaceLoader, + "huggingface-safetensor": HuggingFaceLoader, + "awq": HuggingFaceLoader, +} diff --git a/python/mlc_chat/loader/mapping.py b/python/mlc_chat/loader/mapping.py new file mode 100644 index 0000000..26d6811 --- /dev/null +++ b/python/mlc_chat/loader/mapping.py @@ -0,0 +1,101 @@ +"""Parameter mapping for converting different LLM implementations to MLC LLM.""" +import dataclasses +from typing import Callable, Dict, List, Set, Union + +import numpy as np +from tvm.runtime import NDArray + +MapFuncVariadic = Union[ + Callable[[], np.ndarray], + Callable[[np.ndarray], np.ndarray], + Callable[[np.ndarray, np.ndarray], np.ndarray], + Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray], + Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray], +] + + +@dataclasses.dataclass +class ExternMapping: + """Mapping from a parameter name in MLC LLM's model definition to its potential source, + for example, from MLC parameter "model.layers.2.post_attention_layernorm.weight" to PyTorch's + parameter correspondingly. + + Parameters + ---------- + param_map : Dict[str, List[str]] + A dictionary that maps the name of a parameter to its source. For example, + in Llama2, the source of MLC parameter "model.layers.0.self_attn.qkv_proj.weight" from + huggingface torch are: + + - "model.layers.0.self_attn.q_proj.weight" + - "model.layers.0.self_attn.k_proj.weight" + - "model.layers.0.self_attn.v_proj.weight" + + map_func : Dict[str, Callable[[np.ndarray, ...], np.ndarray]] + A dictionary that maps the name of a parameter to a function that combines the source + parameters into the MLC parameter. For example, for the above example, the function + would be: `lambda q, k, v: np.concatenate([q, k, v], axis=0)`. + + unused_params : Set[str] + Parameter names in the source weights that are not used in the MLC LLM model definition. + """ + + param_map: Dict[str, List[str]] = dataclasses.field(default_factory=dict) + map_func: Dict[str, MapFuncVariadic] = dataclasses.field(default_factory=dict) + unused_params: Set[str] = dataclasses.field(default_factory=set) + + def add_mapping( + self, + map_from: str, + map_to: List[str], + func: MapFuncVariadic, + ) -> None: + """Add a mapping from MLC parameters to source parametes as well as a mapping function.""" + self.param_map[map_from] = map_to + self.map_func[map_from] = func + + def add_unused(self, name: str): + """Add a parameter name in the source parameters to the set of unused parameters.""" + self.unused_params.add(name) + + +@dataclasses.dataclass +class QuantizeMapping: + """Mapping from a parameter in MLC LLM's model definition to its eventual names and values after + quantization. In certain group quantization, for example, `qkv_proj.weight` is mapped to + `qkv_proj.weight_quantized` and `qkv_proj.weight_scale` respectively. 
If a parameter's name is + not in the mapping, it is assumed to be unchanged, i.e. not quantized. + + Parameters + ---------- + param_map : Dict[str, List[str]] + A dictionary that maps the name of a parameter to its destination. For example, + in certain group quantization, the destinations of MLC parameter "qkv_proj.weight` are: + + - "qkv_proj.weight_quantized" + - "qkv_proj.weight_scale" + + map_func : Dict[str, Callable[NDArray, List[NDArray]]] + A dictionary that maps the name of a parameter to a function that splits the MLC parameter + into the destination parameters. + + Notes + ----- + There are two forms of weight conversion in MLC LLM, one is A) on-the-fly quantization to the + raw fp16/bf16/fp32 weights from HuggingFace, and the other is B) loading pre-quantized weights + from an external framework, e.g. AutoGPTQ, AutoAWQ. From the perspective of parameter + correspondence. + + - In case A), it is recommended that the weight loader take both `ExternMapping` and + `QuantizeMapping` as input, and do quantiaztion on the fly as a raw parameter being + loaded into RAM; + - In case B), a pass over `nn.Module` is recommended to take place first to converts parameters + from its non-quantized form to the quantized one, and then only `ExternMapping` is + used to convert the quantized parameters into the desired form. + """ + + param_map: Dict[str, List[str]] + map_func: Dict[str, Callable[[NDArray], List[NDArray]]] + + +__all__ = ["ExternMapping", "QuantizeMapping"] diff --git a/python/mlc_chat/loader/stats.py b/python/mlc_chat/loader/stats.py new file mode 100644 index 0000000..6a97cf9 --- /dev/null +++ b/python/mlc_chat/loader/stats.py @@ -0,0 +1,95 @@ +"""Statistics of the loading process of parameter loaders""" +import dataclasses +import time +from contextlib import contextmanager + +from mlc_chat.support import logging +from mlc_chat.support.style import green + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class Stats: + """Statistics of the loading process of parameter loaders. + + Attributes + ---------- + load_time_sec : float + Time used in loading the parameters. + + map_time_sec : float + Time used in applying the mapping function, i.e. `ExternMapping.map_func`. + + quant_time_sec : float + Time used in quantizing the parameters, i.e. `QuantizeMapping.quant_func`. + + current_memory_gb : float + The current RAM usage in GB. + + total_memory_gb : float + The total size data loaded from disk in GB. + + max_memory_gb : float + The maximum RAM usage in GB. + + total_param_num: int + Total number of parameters (original non-MLC model weights), excluding unused params. 
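The `timer()` helper below returns a context manager that adds the elapsed wall-clock time onto the named attribute, while `mem_add`/`mem_rm` maintain the peak-RAM watermark. A small self-contained sketch of the pattern:

```python
import time

from mlc_chat.loader.stats import Stats

stats = Stats()
with stats.timer("load_time_sec"):   # accumulates into stats.load_time_sec
    time.sleep(0.05)                 # stand-in for reading a weight shard
stats.mem_add(2 * 1024**3)           # a 2 GB shard enters RAM
stats.mem_rm(2 * 1024**3)            # ... and is released again
stats.log_time_info("HF")
stats.log_mem_usage()
```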
+ """ + + load_time_sec: float = 0.0 + map_time_sec: float = 0.0 + quant_time_sec: float = 0.0 + + current_memory_gb: float = 0.0 + total_memory_gb: float = 0.0 + max_memory_gb: float = 0.0 + + total_param_num: int = 0 + + def timer(self, attr): + """A context manager to time the scope and add the time to the attribute.""" + + @contextmanager + def timed_scope(): + start_time = time.time() + yield + elapsed_time = time.time() - start_time + setattr(self, attr, getattr(self, attr) + elapsed_time) + + return timed_scope() + + def mem_add(self, nbytes: int): + """Add the memory usage by the given number of bytes.""" + mem_gb = float(nbytes) / float(1024**3) + self.current_memory_gb += mem_gb + self.total_memory_gb += mem_gb + self.max_memory_gb = max(self.max_memory_gb, self.current_memory_gb) + + def mem_rm(self, nbytes: int): + """Remove the memory usage by the given number of bytes.""" + mem_gb = float(nbytes) / float(1024**3) + self.current_memory_gb -= mem_gb + + def log_time_info(self, weight_format: str): + """Log the time used in loading, pre-quantization and quantization.""" + logger.info( + "%s: " + "%s loading: %.3f sec; " + "Pre-quantization mapping: %.3f sec; " + "Quantization: %.3f sec", + green("Time usage"), + weight_format, + self.load_time_sec, + self.map_time_sec, + self.quant_time_sec, + ) + + def log_mem_usage(self): + """Log the Memory usage information.""" + logger.info( + "%s: Peak RAM: %.3f GB. Total bytes loaded from disk: %.3f GB", + green("RAM usage"), + self.max_memory_gb, + self.total_memory_gb, + ) diff --git a/python/mlc_chat/loader/utils.py b/python/mlc_chat/loader/utils.py new file mode 100644 index 0000000..b35f9a9 --- /dev/null +++ b/python/mlc_chat/loader/utils.py @@ -0,0 +1,66 @@ +"""Common utilities for loading parameters""" +# pylint: disable=too-few-public-methods +from pathlib import Path +from typing import TYPE_CHECKING, Iterator, Set, Tuple + +import numpy as np + +from mlc_chat.support import logging + +if TYPE_CHECKING: + from tvm.runtime import NDArray + + from .mapping import ExternMapping + + +logger = logging.getLogger(__name__) + + +def check_parameter_usage(param_map: "ExternMapping", extern_weights: Set[str]): + """Check that all external parameters have been used and are stored in the weights file.""" + used_extern_names = set(sum(param_map.param_map.values(), [])) + # Check 1. All extern parameters in the weight files are used unless explicitly specified + unused_extern_names = extern_weights - used_extern_names - param_map.unused_params + if unused_extern_names: + logger.warning( + "Unused extern parameters: %s", + ", ".join(sorted(unused_extern_names)), + ) + # Check 2. 
All extern parameters required are stored in the weight files + nonexistent_extern_names = used_extern_names - extern_weights + if nonexistent_extern_names: + raise ValueError( + "The following extern parameters do not exist in the weight files:\n " + + "\n ".join(sorted(nonexistent_extern_names)), + ) + + +def load_torch_shard(path: Path) -> Iterator[Tuple[str, np.ndarray]]: + """Load and yield PyTorch format parameters.""" + import torch # pylint: disable=import-outside-toplevel + + for name, param in torch.load(path, map_location=torch.device("cpu")).items(): + if param is None: + logger.warning("Encountered None param, skipping it: %s", name) + continue + param = param.detach().cpu() + dtype = str(param.dtype) + if dtype == "torch.bfloat16": + param = param.float() + param = param.numpy() + yield name, param + + +def load_safetensor_shard(path: Path) -> Iterator[Tuple[str, np.ndarray]]: + """Load and yield SafeTensor format parameters.""" + import safetensors # pylint: disable=import-outside-toplevel,import-error + + with safetensors.safe_open(path, framework="pt", device="cpu") as in_file: + for name in in_file.keys(): + param = in_file.get_tensor(name) + param = param.detach().cpu() + dtype = str(param.dtype) + if dtype == "torch.bfloat16": + param = param.float() + param = param.numpy() + yield name, param diff --git a/python/mlc_chat/model/__init__.py b/python/mlc_chat/model/__init__.py new file mode 100644 index 0000000..d7b0baa --- /dev/null +++ b/python/mlc_chat/model/__init__.py @@ -0,0 +1,3 @@ +"""Model definition for the compiler.""" +from .model import MODELS, Model +from .model_preset import MODEL_PRESETS diff --git a/python/mlc_chat/model/baichuan/__init__.py b/python/mlc_chat/model/baichuan/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/mlc_chat/model/baichuan/baichuan_loader.py b/python/mlc_chat/model/baichuan/baichuan_loader.py new file mode 100644 index 0000000..01b8528 --- /dev/null +++ b/python/mlc_chat/model/baichuan/baichuan_loader.py @@ -0,0 +1,70 @@ +""" +This file specifies how MLC's StableLM parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" + +import functools + +import numpy as np + +from mlc_chat.loader import ExternMapping +from mlc_chat.quantization import Quantization + +from .baichuan_model import BaichuanConfig, BaichuanForCausalLM + + +def huggingface(model_config: BaichuanConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : GPT2Config + The configuration of the GPT-2 model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. 
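A detail worth noting in the function body that follows: `dtype` is bound through `functools.partial` rather than captured directly in the lambda, because a bare closure over the loop variable would see only its final value. A minimal illustration of that late-binding pitfall (generic Python, not MLC-specific):

```python
import functools

funcs_late = [lambda: i for i in range(3)]                           # late binding
funcs_bound = [functools.partial(lambda i: i, i) for i in range(3)]  # value bound per iteration

assert [f() for f in funcs_late] == [2, 2, 2]
assert [f() for f in funcs_bound] == [0, 1, 2]
```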
+ """ + model = BaichuanForCausalLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + mlp = f"model.layers.{i}.mlp" + mlc_name = f"{mlp}.gate_up_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.weight", + f"{mlp}.up_proj.weight", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping diff --git a/python/mlc_chat/model/baichuan/baichuan_model.py b/python/mlc_chat/model/baichuan/baichuan_model.py new file mode 100644 index 0000000..5bcedd4 --- /dev/null +++ b/python/mlc_chat/model/baichuan/baichuan_model.py @@ -0,0 +1,252 @@ +""" +Implementation for BAICHUAN architecture. +TODO: add docstring +""" +import dataclasses +from typing import Any, Dict, Optional + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_chat import op as op_ext +from mlc_chat.support import logging +from mlc_chat.support.config import ConfigBase +from mlc_chat.support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class BaichuanConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the Baichuan model.""" + + vocab_size: int + hidden_size: int + num_hidden_layers: int + num_attention_heads: int + initializer_range: float + intermediate_size: int + rms_norm_eps: float + use_cache: bool + pad_token_id: int + bos_token_id: int + eos_token_id: int + tie_word_embeddings: bool + context_window_size: int = 0 + prefill_chunk_size: int = 0 + tensor_parallel_shards: int = 1 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.context_window_size == 0: + for name in ["max_position_embeddings", "max_sequence_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + else: + raise ValueError( + "Unable to determine the maxmimum sequence length, because none of " + "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " + "provided in `config.json`." + ) + if self.prefill_chunk_size == 0: + logger.info( + "%s defaults to %s (%d)", + bold("prefill_chunk_size"), + bold("context_window_size"), + self.context_window_size, + ) + self.prefill_chunk_size = self.context_window_size + elif self.prefill_chunk_size > self.context_window_size: + logger.info( + "Overriding %s from %d to %d (%s)", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + self.context_window_size, + bold("context_window_size"), + ) + self.prefill_chunk_size = self.context_window_size + assert self.tensor_parallel_shards == 1, "Baichuan currently does not support sharding." 
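Since `__post_init__` above does the work of filling `context_window_size` and `prefill_chunk_size`, here is a quick sketch of that fallback behaviour. Field values are illustrative rather than a real Baichuan checkpoint, and it assumes `ConfigBase.from_dict` routes unrecognised keys such as `max_position_embeddings` into `kwargs`, as the code above relies on.

```python
from mlc_chat.model.baichuan.baichuan_model import BaichuanConfig

cfg = BaichuanConfig.from_dict(
    {
        "vocab_size": 125696,
        "hidden_size": 4096,
        "num_hidden_layers": 32,
        "num_attention_heads": 32,
        "initializer_range": 0.02,
        "intermediate_size": 11008,
        "rms_norm_eps": 1e-6,
        "use_cache": True,
        "pad_token_id": 0,
        "bos_token_id": 1,
        "eos_token_id": 2,
        "tie_word_embeddings": False,
        "max_position_embeddings": 4096,  # picked up as context_window_size
    }
)
assert cfg.context_window_size == 4096
assert cfg.prefill_chunk_size == 4096   # defaults to context_window_size
```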
+ + +# pylint: disable=invalid-name,missing-docstring + + +class BaichuanAttention(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: BaichuanConfig): + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.max_position_embeddings = config.context_window_size + + self.W_pack = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.k_cache = nn.KVCache(config.context_window_size, [self.num_heads, self.head_dim]) + self.v_cache = nn.KVCache(config.context_window_size, [self.num_heads, self.head_dim]) + + def forward( # pylint: disable=too-many-locals + self, + hidden_states: Tensor, + attention_mask: Tensor, + total_seq_len: tir.Var, + ): + d, h, t = self.head_dim, self.num_heads, total_seq_len + b, s, _ = hidden_states.shape + assert b == 1, "Only support batch size 1 at this moment." + # Step 1. QKV Projection + qkv = self.W_pack(hidden_states) + qkv = op.reshape(qkv, (b, s, 3 * h, d)) + # Step 2. Apply QK rotary embedding + q, k, v = op_ext.llama_rope(qkv, t, 10000, h, h) + # Step 3. Query and update KVCache + self.k_cache.append(op.squeeze(k, axis=0)) + self.v_cache.append(op.squeeze(v, axis=0)) + k = self.k_cache.view(t) + v = self.v_cache.view(t) + # Step 4. Compute softmax(Q @ K^T / sqrt(d)) @ V + output = op_ext.attention(q, k, v, casual_mask=attention_mask) + # Step 5. Apply output projection + return self.o_proj(output) + + +class BaichuanMLP(nn.Module): + def __init__(self, config: BaichuanConfig): + self.gate_up_proj = nn.Linear( + in_features=config.hidden_size, + out_features=2 * config.intermediate_size, + bias=False, + ) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + + def forward(self, x): + concat_x1_x2 = self.gate_up_proj(x) + x1, x2 = op.split(concat_x1_x2, 2, axis=-1) + return self.down_proj(op.silu(x1) * x2) + + +class BaichuanDecoderLayer(nn.Module): + def __init__(self, config: BaichuanConfig): + norm_eps = config.rms_norm_eps + self.self_attn = BaichuanAttention(config=config) + self.mlp = BaichuanMLP(config) + self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, norm_eps, bias=False) + self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, norm_eps, bias=False) + + def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): + out = self.self_attn(self.input_layernorm(hidden_states), attention_mask, total_seq_len) + hidden_states = out + hidden_states + out = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = out + hidden_states + return hidden_states + + +class BaichuanModel(nn.Module): + def __init__(self, config: BaichuanConfig): + assert config.hidden_size % config.num_attention_heads == 0 + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [BaichuanDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) + + def forward(self, input_ids: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): + hidden_states = self.embed_tokens(input_ids) + for layer in self.layers: + hidden_states = layer(hidden_states, attention_mask, total_seq_len) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class BaichuanForCausalLM(nn.Module): + def __init__(self, config: 
BaichuanConfig): + self.model = BaichuanModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.vocab_size = config.vocab_size + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def forward(self, inputs: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): + def _index(x: te.Tensor): # x[:-1,:] + b, s, d = x.shape + return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") + + hidden_states = self.model(inputs, total_seq_len, attention_mask) + hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def prefill(self, inputs: Tensor, total_seq_len: tir.Var): + def _attention_mask(batch_size, seq_len, total_seq_len): + return te.compute( + (batch_size, 1, seq_len, total_seq_len), + lambda b, _, i, j: tir.if_then_else( + i < j - (total_seq_len - seq_len), + tir.min_value(self.dtype), + tir.max_value(self.dtype), + ), + name="attention_mask_prefill", + ) + + batch_size, seq_len = inputs.shape + attention_mask = op.tensor_expr_op( + _attention_mask, + name_hint="attention_mask_prefill", + args=[batch_size, seq_len, total_seq_len], + ) + return self.forward(inputs, total_seq_len, attention_mask) + + def decode(self, inputs: Tensor, total_seq_len: tir.Var): + batch_size, seq_len = inputs.shape + attention_mask = op.full( + shape=[batch_size, 1, seq_len, total_seq_len], + fill_value=tir.max_value(self.dtype), + dtype=self.dtype, + ) + return self.forward(inputs, total_seq_len, attention_mask) + + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): + return op.softmax(logits / temperature, axis=-1) + + def get_default_spec(self): + batch_size = 1 + mod_spec = { + "prefill": { + "inputs": nn.spec.Tensor([batch_size, "seq_len"], "int32"), + "total_seq_len": int, + "$": { + "param_mode": "packed", + "effect_mode": "packed", + }, + }, + "decode": { + "inputs": nn.spec.Tensor([batch_size, 1], "int32"), + "total_seq_len": int, + "$": { + "param_mode": "packed", + "effect_mode": "packed", + }, + }, + "softmax_with_temperature": { + "logits": nn.spec.Tensor([1, 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor([], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_chat/model/baichuan/baichuan_quantization.py b/python/mlc_chat/model/baichuan/baichuan_quantization.py new file mode 100644 index 0000000..2558942 --- /dev/null +++ b/python/mlc_chat/model/baichuan/baichuan_quantization.py @@ -0,0 +1,53 @@ +"""This file specifies how MLC's Baichuan parameters are quantized using group quantization +or other formats.""" +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_chat.loader import QuantizeMapping +from mlc_chat.quantization import FTQuantize, GroupQuantize, NoQuantize + +from .baichuan_model import BaichuanConfig, BaichuanForCausalLM + + +def group_quant( + model_config: BaichuanConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a BaichuanLM-architecture model using group quantization.""" + model: nn.Module = BaichuanForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + 
return model, quant_map + + +def ft_quant( + model_config: BaichuanConfig, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a BaichuanLM-architecture model using FasterTransformer quantization.""" + model: nn.Module = BaichuanForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: BaichuanConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a BaichuanLM model without quantization.""" + model: nn.Module = BaichuanForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_chat/model/gemma/__init__.py b/python/mlc_chat/model/gemma/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/mlc_chat/model/gemma/gemma_loader.py b/python/mlc_chat/model/gemma/gemma_loader.py new file mode 100644 index 0000000..c839978 --- /dev/null +++ b/python/mlc_chat/model/gemma/gemma_loader.py @@ -0,0 +1,121 @@ +""" +This file specifies how MLC's Gemma parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" + +import functools + +import numpy as np + +from mlc_chat.loader import ExternMapping +from mlc_chat.quantization import Quantization + +from .gemma_model import GemmaConfig, GemmaForCausalLM + + +def huggingface(model_config: GemmaConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : GemmaConfig + The configuration of the Gemma model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. 
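One Gemma-specific detail handled in the mapping body that follows: HuggingFace's Gemma RMSNorm stores a weight `w` and computes `(1 + w) * normalized(x)` at runtime, so the loader folds the `+1` into the stored weight and MLC can reuse its plain RMSNorm. A tiny numeric sketch with made-up values:

```python
import numpy as np

w_hf = np.array([0.1, -0.2, 0.0], dtype="float32")  # as stored in the HF checkpoint
w_mlc = (w_hf + 1).astype("float16")                # what the map_func below produces
# MLC then computes w_mlc * normalized(x), matching HF's (1 + w_hf) * normalized(x).
```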
+ """ + model = GemmaForCausalLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"model.layers.{i}.self_attn" + mlc_name = f"{attn}.qkv_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.weight", + f"{attn}.k_proj.weight", + f"{attn}.v_proj.weight", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # Add gates in MLP + mlp = f"model.layers.{i}.mlp" + mlc_name = f"{mlp}.gate_up_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.weight", + f"{mlp}.up_proj.weight", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # Modify RMS layernorm weights, since Gemma model adds 1 to the weights + # We add 1 to the weights here for efficiency purpose + mlc_name = f"model.layers.{i}.input_layernorm.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: (x + 1).astype(dtype), + dtype=named_parameters[mlc_name].dtype, + ), + ) + + mlc_name = f"model.layers.{i}.post_attention_layernorm.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: (x + 1).astype(dtype), + dtype=named_parameters[mlc_name].dtype, + ), + ) + + mlc_name = "model.norm.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: (x + 1).astype(dtype), + dtype=named_parameters[mlc_name].dtype, + ), + ) + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping diff --git a/python/mlc_chat/model/gemma/gemma_model.py b/python/mlc_chat/model/gemma/gemma_model.py new file mode 100644 index 0000000..0145589 --- /dev/null +++ b/python/mlc_chat/model/gemma/gemma_model.py @@ -0,0 +1,386 @@ +"""Implementation for Gemma architecture.""" + +import dataclasses +from typing import Any, Dict, Optional + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_chat import op as op_ext +from mlc_chat.nn import PagedKVCache, RopeMode +from mlc_chat.support import logging +from mlc_chat.support import tensor_parallel as tp +from mlc_chat.support.config import ConfigBase +from mlc_chat.support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class GemmaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the Gemma model.""" + + hidden_size: int + hidden_act: str + intermediate_size: int + attention_bias: bool + num_attention_heads: int + num_key_value_heads: int + head_dim: int + num_hidden_layers: int + rms_norm_eps: float + vocab_size: int + position_embedding_base: int = 0 + context_window_size: int = 0 + prefill_chunk_size: int = 0 + tensor_parallel_shards: int = 1 + max_batch_size: int = 
1 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.hidden_act != "gelu": + raise ValueError("Only GeLU is supported as the activation for gemma.") + if self.attention_bias: + raise ValueError('Only "False" attention_bias is supported for gemma') + if self.position_embedding_base == 0: + if "rope_theta" in self.kwargs: + self.position_embedding_base = self.kwargs.pop("rope_theta") + else: + self.position_embedding_base = 10000 + if self.context_window_size == 0: + for name in ["max_position_embeddings", "max_sequence_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + else: + raise ValueError( + "Unable to determine the maxmimum sequence length, because none of " + "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " + "provided in `config.json`." + ) + assert self.num_attention_heads % self.num_key_value_heads == 0 + if self.prefill_chunk_size == 0: + logger.info( + "%s defaults to %s (%d)", + bold("prefill_chunk_size"), + bold("context_window_size"), + self.context_window_size, + ) + self.prefill_chunk_size = self.context_window_size + elif self.prefill_chunk_size > self.context_window_size: + logger.info( + "Overriding %s from %d to %d (%s)", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + self.context_window_size, + bold("context_window_size"), + ) + self.prefill_chunk_size = self.context_window_size + + +# pylint: disable=invalid-name,missing-docstring + + +class GemmaEmbedding(nn.Embedding): + """The embedding module specialized for Gemma so that + it can be shared with the final lm_head. + """ + + def lm_head_forward(self, x: nn.Tensor): + """The lm_head forwarding, which transposes the weight and multiplies + with the input tensor. 
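Gemma ties the input embedding to the output projection, which is why the final `lm_head` is simply this embedding matrix used twice. A NumPy sketch of the idea, with toy sizes:

```python
import numpy as np

vocab_size, hidden_size = 8, 4
W = np.random.rand(vocab_size, hidden_size).astype("float32")

token_ids = np.array([3, 5])
x = W[token_ids]          # embedding lookup on the way in: [2, hidden_size]
logits = x @ W.T          # lm_head_forward on the way out: [2, vocab_size]
assert logits.shape == (2, vocab_size)
```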
+ """ + weight = nn.op.permute_dims(self.weight) + return nn.op.matmul(x, weight, out_dtype="float32") + + +class GemmaMLP(nn.Module): + def __init__(self, config: GemmaConfig): + super().__init__() + self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards + self.gate_up_proj = nn.Linear( + in_features=config.hidden_size, + out_features=2 * self.intermediate_size, + bias=False, + ) + self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False) + + def forward(self, x: Tensor): + concat_x1_x2 = self.gate_up_proj(x) + x1, x2 = op.split(concat_x1_x2, 2, axis=-1) + return self.down_proj(op.gelu(x1) * x2) + + +class GemmaAttention(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: GemmaConfig): + self.head_dim = config.head_dim + self.num_q_heads = config.num_attention_heads // config.tensor_parallel_shards + assert ( + config.num_key_value_heads % config.tensor_parallel_shards == 0 + ), f"num_kv_heads({config.num_key_value_heads}) must be divisible by tensor_parallel_shards" + assert ( + config.num_key_value_heads >= config.tensor_parallel_shards + ), f"Too large tensor_parallel_shards, must be smaller than {config.num_key_value_heads}" + self.num_kv_heads = config.num_key_value_heads // config.tensor_parallel_shards + self.qkv_proj = nn.Linear( + in_features=config.hidden_size, + out_features=(self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + in_features=self.num_q_heads * self.head_dim, + out_features=config.hidden_size, + bias=config.attention_bias, + ) + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads + b, s, _ = hidden_states.shape + # QKV Projection + qkv = self.qkv_proj(hidden_states) + qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) + # Attention + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_q_heads), + (b, s, h_q * d), + ) + return self.o_proj(output) + + +class GemmaDecoderLayer(nn.Module): + def __init__(self, config: GemmaConfig): + rms_norm_eps = config.rms_norm_eps + self.self_attn = GemmaAttention(config) + self.mlp = GemmaMLP(config) + # Gemma RMSNorm adds 1 to the weights. 
It is already fused in the loader + self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) + self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) + + def _set_tp(): + def _set(layer, hint): + layer.weight.attrs["shard_strategy"] = hint + + hd = config.head_dim + q = self.self_attn.num_q_heads * hd + k = self.self_attn.num_kv_heads * hd + v = self.self_attn.num_kv_heads * hd + i = self.mlp.intermediate_size + _set(self.self_attn.qkv_proj, tp.ShardSingleDim("_shard_qkv", segs=[q, k, v], dim=0)) + _set(self.self_attn.o_proj, tp.ShardSingleDim("_shard_o", dim=1)) + _set(self.mlp.gate_up_proj, tp.ShardSingleDim("_shard_mlp_up", segs=[i, i], dim=0)) + _set(self.mlp.down_proj, tp.ShardSingleDim("_shard_mlp_down", dim=1)) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id) + hidden_states = self._apply_residual(out, residual=hidden_states) + out = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = self._apply_residual(out, residual=hidden_states) + return hidden_states + + def _apply_residual(self, out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(out, "sum") + residual + return out + residual + + +class GemmaModel(nn.Module): + def __init__(self, config: GemmaConfig): + self.hidden_size = config.hidden_size + assert config.hidden_size % config.num_attention_heads == 0 + self.embed_tokens = GemmaEmbedding("vocab_size", config.hidden_size) + self.layers = nn.ModuleList( + [GemmaDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) + self.tensor_parallel_shards = config.tensor_parallel_shards + + def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + if self.tensor_parallel_shards > 1: + input_embed = op.ccl_broadcast_from_worker0(input_embed) + hidden_states = input_embed + hidden_states = hidden_states * (self.hidden_size**0.5) + for layer_id, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class GemmaForCausalLM(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: GemmaConfig): + self.model = GemmaModel(config) + self.num_hidden_layers = config.num_hidden_layers + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.head_dim + self.hidden_size = config.hidden_size + self.vocab_size = config.vocab_size + self.rope_theta = config.position_embedding_base + self.tensor_parallel_shards = config.tensor_parallel_shards + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def batch_forward( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None, + ): + op_ext.configure() + + hidden_states = self.model(input_embeds, paged_kv_cache) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) + logits = self.model.embed_tokens.lm_head_forward(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def embed(self, 
input_ids: Tensor): + return self.model.embed_tokens(input_ids) + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + def _index(x: te.Tensor): # x[:-1,:] + b, s, d = x.shape + return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") + + hidden_states = self.model(input_embed, paged_kv_cache) + hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + logits = self.model.embed_tokens.lm_head_forward(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.model(input_embed, paged_kv_cache) + logits = self.model.embed_tokens.lm_head_forward(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def batch_prefill( + self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache + ): + logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) + return logits, paged_kv_cache + + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) + + def create_paged_kv_cache( + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, + head_dim=self.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.rope_theta, + dtype=self.dtype, + ) + + def get_default_spec(self): + mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "prefill": { + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode": { + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, 
"seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "softmax_with_temperature": { + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_chat/model/gemma/gemma_quantization.py b/python/mlc_chat/model/gemma/gemma_quantization.py new file mode 100644 index 0000000..28b4234 --- /dev/null +++ b/python/mlc_chat/model/gemma/gemma_quantization.py @@ -0,0 +1,38 @@ +"""This file specifies how MLC's Gemma parameters are quantized using group quantization +or other formats.""" + +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_chat.loader import QuantizeMapping +from mlc_chat.quantization import GroupQuantize, NoQuantize + +from .gemma_model import GemmaConfig, GemmaForCausalLM + + +def group_quant( + model_config: GemmaConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Gemma-architecture model using group quantization.""" + model: nn.Module = GemmaForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: GemmaConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Llama2 model without quantization.""" + model: nn.Module = GemmaForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_chat/model/gpt2/__init__.py b/python/mlc_chat/model/gpt2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/mlc_chat/model/gpt2/gpt2_loader.py b/python/mlc_chat/model/gpt2/gpt2_loader.py new file mode 100644 index 0000000..43c4ff1 --- /dev/null +++ b/python/mlc_chat/model/gpt2/gpt2_loader.py @@ -0,0 +1,79 @@ +""" +This file specifies how MLC's GPT-2 parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" +import functools + +from mlc_chat.loader import ExternMapping +from mlc_chat.quantization import Quantization + +from .gpt2_model import GPT2Config, GPT2LMHeadModel + + +def huggingface(model_config: GPT2Config, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : GPT2Config + The configuration of the GPT-2 model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. 
+ """ + model = GPT2LMHeadModel(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + mapping.add_mapping( + "lm_head.weight", + ["wte.weight"], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=named_parameters["transformer.wte.weight"].dtype, + ), + ) + + for i in range(model_config.n_layer): + mapping.add_unused(f"h.{i}.attn.bias") + + # Transpose c_attn, c_proj and c_fc weights since GPT-2 uses Conv1D + for conv1d_weight_name in ["attn.c_attn", "attn.c_proj", "mlp.c_proj", "mlp.c_fc"]: + src_name = f"h.{i}.{conv1d_weight_name}.weight" + mlc_name = f"transformer.{src_name}" + mapping.add_mapping( + mlc_name, + [src_name], + functools.partial( + lambda x, dtype: x.transpose().astype(dtype), + dtype=named_parameters[mlc_name].dtype, + ), + ) + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + # transformer.h.0.attn.c_attn.weight --> h.0.attn.c_attn.weight + source_name = mlc_name.split(".", 1)[1] + mapping.add_mapping( + mlc_name, + [source_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + return mapping diff --git a/python/mlc_chat/model/gpt2/gpt2_model.py b/python/mlc_chat/model/gpt2/gpt2_model.py new file mode 100644 index 0000000..911f0dd --- /dev/null +++ b/python/mlc_chat/model/gpt2/gpt2_model.py @@ -0,0 +1,433 @@ +""" +Implementation for GPT-2 architecture. +TODO: add docstring +""" + +import dataclasses +from typing import Any, Dict, Optional + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_chat import op as op_ext +from mlc_chat.nn import PagedKVCache, RopeMode +from mlc_chat.support import logging +from mlc_chat.support import tensor_parallel as tp +from mlc_chat.support.config import ConfigBase +from mlc_chat.support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class GPT2Config(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the GPT-2 model.""" + + vocab_size: int + n_embd: int + n_layer: int + n_head: int + layer_norm_epsilon: int + n_inner: int = -1 + context_window_size: int = 0 + prefill_chunk_size: int = 0 + scale_attn_by_inverse_layer_idx: bool = False + tensor_parallel_shards: int = 1 + head_dim: int = 0 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.n_inner is None or self.n_inner == -1: + self.n_inner = 4 * self.n_embd + if self.context_window_size == 0: + for name in ["n_positions", "max_sequence_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + else: + raise ValueError( + "Unable to determine the maxmimum sequence length, because none of " + "`context_window_size`, `n_positions` or `max_sequence_length` is " + "provided in `config.json`." 
+ ) + if self.head_dim == 0: + self.head_dim = self.n_embd // self.n_head + assert self.head_dim * self.n_head == self.n_embd + if self.prefill_chunk_size == 0: + logger.info( + "%s defaults to %s (%d)", + bold("prefill_chunk_size"), + bold("context_window_size"), + self.context_window_size, + ) + self.prefill_chunk_size = self.context_window_size + elif self.prefill_chunk_size > self.context_window_size: + logger.info( + "Overriding %s from %d to %d (%s)", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + self.context_window_size, + bold("context_window_size"), + ) + self.prefill_chunk_size = self.context_window_size + + +# pylint: disable=invalid-name,missing-docstring,too-many-locals + + +class GPT2Attention(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: GPT2Config): + self.embed_dim = config.n_embd + self.num_heads = config.n_head // config.tensor_parallel_shards + self.head_dim = config.head_dim + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + + self.c_attn = nn.Linear( + in_features=self.embed_dim, + out_features=3 * self.num_heads * self.head_dim, + bias=True, + ) + self.c_proj = nn.Linear(self.num_heads * self.head_dim, self.embed_dim, bias=True) + + self.k_cache = nn.KVCache(config.context_window_size, [self.num_heads, self.head_dim]) + self.v_cache = nn.KVCache(config.context_window_size, [self.num_heads, self.head_dim]) + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h = self.head_dim, self.num_heads + b, s, _ = hidden_states.shape + + qkv = self.c_attn(hidden_states) + qkv = op.reshape(qkv, (b, s, 3 * h, d)) + + if self.scale_attn_by_inverse_layer_idx: + attn_score_scaling_factor = 1.0 / float(layer_id + 1) + else: + attn_score_scaling_factor = 1.0 + + # Attention + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv( + layer_id, qkv, self.num_heads, attn_score_scaling_factor + ), + (b, s, h * d), + ) + return self.c_proj(output) + + def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h = self.head_dim, self.num_heads + b, s, _ = hidden_states.shape + + qkv = self.c_attn(hidden_states) + qkv = op.reshape(qkv, (b, s, 3 * h, d)) + + if self.scale_attn_by_inverse_layer_idx: + attn_score_scaling_factor = 1.0 / float(layer_id + 1) + else: + attn_score_scaling_factor = 1.0 + + # Attention + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv( + layer_id, qkv, self.num_heads, attn_score_scaling_factor + ), + (b, s, h * d), + ) + return self.c_proj(output) + + +class GPT2MLP(nn.Module): + def __init__(self, config: GPT2Config): + embed_dim = config.n_embd + intermediate_size = config.n_inner // config.tensor_parallel_shards + self.c_fc = nn.Linear(embed_dim, intermediate_size) + self.c_proj = nn.Linear(intermediate_size, embed_dim) + + def forward(self, hidden_states: Tensor): + hidden_states = self.c_fc(hidden_states) + hidden_states = op.gelu(hidden_states, approximate="tanh") + hidden_states = self.c_proj(hidden_states) + return hidden_states + + +class GPT2Block(nn.Module): + def __init__(self, config: GPT2Config): + hidden_size = config.n_embd + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPT2Attention(config) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = GPT2MLP(config) + + def _set_tp(): + def _set(param, hint): + param.attrs["shard_strategy"] = hint + + hd = config.head_dim + q = k = v = self.attn.num_heads 
* hd + _set( + self.attn.c_attn.weight, + tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]), + ) + _set( + self.attn.c_attn.bias, + tp.ShardSingleDim("_shard_qkv_bias", dim=0, segs=[q, k, v]), + ) + _set(self.attn.c_proj.weight, tp.ShardSingleDim("_shard_attn_c_proj", dim=1)) + _set( + self.mlp.c_fc.weight, + tp.ShardSingleDim("_shard_c_fc_weight", dim=0), + ) + _set(self.mlp.c_fc.bias, tp.ShardSingleDim("_shard_c_fc_bias", dim=0)) + _set(self.mlp.c_proj.weight, tp.ShardSingleDim("_shard_mlp_c_proj", dim=1)) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + with tp.shard_bias(self.attn.c_proj, self.tensor_parallel_shards), tp.shard_bias( + self.mlp.c_proj, self.tensor_parallel_shards + ): + hidden_states = self._apply_residual( + self.attn(self.ln_1(hidden_states), paged_kv_cache, layer_id), hidden_states + ) + hidden_states = self._apply_residual(self.mlp(self.ln_2(hidden_states)), hidden_states) + + return hidden_states + + def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + with tp.shard_bias(self.attn.c_proj, self.tensor_parallel_shards), tp.shard_bias( + self.mlp.c_proj, self.tensor_parallel_shards + ): + hidden_states = self._apply_residual( + self.attn.batch_forward(self.ln_1(hidden_states), paged_kv_cache, layer_id), + hidden_states, + ) + hidden_states = self._apply_residual(self.mlp(self.ln_2(hidden_states)), hidden_states) + + return hidden_states + + def _apply_residual(self, out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(out + residual / self.tensor_parallel_shards, "sum") + return out + residual + + +class GPT2Model(nn.Module): + def __init__(self, config: GPT2Config): + assert config.n_embd % config.n_head == 0 + self.wte = nn.Embedding("vocab_size", config.n_embd) + self.wpe = nn.Embedding(config.context_window_size, config.n_embd) + self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.n_layer)]) + self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.tensor_parallel_shards = config.tensor_parallel_shards + + def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + if self.tensor_parallel_shards > 1: + inputs = op.ccl_broadcast_from_worker0(inputs) + hidden_states = inputs + + # Position Embeddings + # Generate np.arange(offset, offset+seq_len) + # shape[1] indicates the total query length in the batch + input_positions = paged_kv_cache.get_query_positions(inputs.shape[1]) + pos_embd = self.wpe(input_positions) + + # Pass through GPT2Block + hidden_states = inputs + pos_embd + for layer_id, layer in enumerate(self.h): + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) + hidden_states = self.ln_f(hidden_states) + return hidden_states + + def batch_forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + if self.tensor_parallel_shards > 1: + inputs = op.ccl_broadcast_from_worker0(inputs) + hidden_states = inputs + + # Position Embeddings + # Generate np.arange(offset, offset+seq_len) + # shape[1] indicates the total query length in the batch + input_positions = paged_kv_cache.get_query_positions(inputs.shape[1]) + pos_embd = self.wpe(input_positions) + + # Pass through GPT2Block + hidden_states = hidden_states + pos_embd + for layer_id, layer in enumerate(self.h): + hidden_states = layer.batch_forward(hidden_states, paged_kv_cache, layer_id) + hidden_states = self.ln_f(hidden_states) + return hidden_states + 
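+
+# Illustrative sketch (not a functional change): how this module can be traced to
+# Relax using the spec from ``GPT2LMHeadModel.get_default_spec`` below, mirroring
+# the ``export_tvm`` call used by the loaders in this patch. GPT-2-small sizes are
+# assumed for the config; any consistent values work.
+#
+#     config = GPT2Config(vocab_size=50257, n_embd=768, n_layer=12, n_head=12,
+#                         layer_norm_epsilon=1e-5, context_window_size=1024)
+#     model = GPT2LMHeadModel(config)
+#     model.to("float16")
+#     _, named_params, _ = model.export_tvm(
+#         spec=model.get_default_spec(), allow_extern=True
+#     )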
+ +class GPT2LMHeadModel(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: GPT2Config): + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, "vocab_size", bias=False) + self.n_layer = config.n_layer + self.n_embed = config.n_embd + self.n_head = config.n_head + self.head_dim = config.head_dim + self.tensor_parallel_shards = config.tensor_parallel_shards + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def batch_forward( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None, + ): + op_ext.configure() + + hidden_states = self.transformer.batch_forward(input_embeds, paged_kv_cache) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def embed(self, input_ids: Tensor): + return self.transformer.wte(input_ids) + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + def _index(x: te.Tensor): # x[:-1,:] + b, s, d = x.shape + return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") + + hidden_states = self.transformer(input_embed, paged_kv_cache) + hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.transformer(input_embed, paged_kv_cache) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def batch_prefill( + self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache + ): + logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) + return logits, paged_kv_cache + + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) + + def create_paged_kv_cache( + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + num_hidden_layers=self.n_layer, + num_attention_heads=self.n_head // self.tensor_parallel_shards, + num_key_value_heads=self.n_head // self.tensor_parallel_shards, + head_dim=self.head_dim, + rope_mode=RopeMode.NONE, + rope_scale=-1, + rope_theta=-1, + dtype=self.dtype, + ) + + def get_default_spec(self): + mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "prefill": { + "input_embed": nn.spec.Tensor([1, "seq_len", self.n_embed], 
self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode": { + "input_embed": nn.spec.Tensor([1, 1, self.n_embed], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.n_embed], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.n_embed], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.n_embed], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "softmax_with_temperature": { + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_chat/model/gpt2/gpt2_quantization.py b/python/mlc_chat/model/gpt2/gpt2_quantization.py new file mode 100644 index 0000000..b953d8c --- /dev/null +++ b/python/mlc_chat/model/gpt2/gpt2_quantization.py @@ -0,0 +1,69 @@ +"""This file specifies how MLC's GPT-2 parameters are quantized using group quantization +or other formats.""" +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_chat.loader import QuantizeMapping +from mlc_chat.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize + +from .gpt2_model import GPT2Config, GPT2LMHeadModel + + +def group_quant( + model_config: GPT2Config, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a GPT-2-architecture model using group quantization.""" + model: nn.Module = GPT2LMHeadModel(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: GPT2Config, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a GPT-2-architecture model using FasterTransformer quantization.""" + model: nn.Module = GPT2LMHeadModel(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def awq_quant( + model_config: GPT2Config, + quantization: AWQQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a GPT-2-architecture model using Activation-aware Weight Quantization(AWQ).""" + model: nn.Module = GPT2LMHeadModel(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: GPT2Config, + quantization: NoQuantize, +) -> 
Tuple[nn.Module, QuantizeMapping]: + """Quantize a GPT-2 model without quantization.""" + model: nn.Module = GPT2LMHeadModel(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_chat/model/gpt_bigcode/__init__.py b/python/mlc_chat/model/gpt_bigcode/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_loader.py b/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_loader.py new file mode 100644 index 0000000..8d479d3 --- /dev/null +++ b/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_loader.py @@ -0,0 +1,49 @@ +""" +This file specifies how MLC's GPTBigCode parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" +import functools + +from mlc_chat.loader import ExternMapping +from mlc_chat.quantization import Quantization + +from .gpt_bigcode_model import GPTBigCodeConfig, GPTBigCodeForCausalLM + + +def huggingface(model_config: GPTBigCodeConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : GPTBigCodeConfig + The configuration of the GPTBigCode model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. + """ + model = GPTBigCodeForCausalLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype), + ) + + return mapping diff --git a/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_model.py b/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_model.py new file mode 100644 index 0000000..10a0291 --- /dev/null +++ b/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_model.py @@ -0,0 +1,289 @@ +""" +Implementation for GPTBigCode architecture. +TODO: add docstring +""" +import dataclasses +from typing import Any, Dict, Optional + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_chat import op as op_ext +from mlc_chat.support import logging +from mlc_chat.support import tensor_parallel as tp +from mlc_chat.support.config import ConfigBase +from mlc_chat.support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class GPTBigCodeConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the GPTBigCode model.""" + + n_embd: int + n_inner: int + n_head: int + n_layer: int + n_positions: int + layer_norm_epsilon: float + vocab_size: int + context_window_size: int = 0 + prefill_chunk_size: int = 0 + tensor_parallel_shards: int = 1 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.context_window_size == 0: + if self.n_positions > 0: + self.context_window_size = self.n_positions + logger.info( + "%s not found in config.json. 
Falling back to %s (%d)", + bold("context_window_size"), + bold("n_positions"), + self.context_window_size, + ) + else: + raise ValueError( + "Unable to determine the maximum sequence length, because none of " + "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " + "provided in `config.json`." + ) + if self.prefill_chunk_size == 0: + logger.info( + "%s defaults to %s (%d)", + bold("prefill_chunk_size"), + bold("context_window_size"), + self.context_window_size, + ) + self.prefill_chunk_size = self.context_window_size + elif self.prefill_chunk_size > self.context_window_size: + logger.info( + "Overriding %s from %d to %d (%s)", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + self.context_window_size, + bold("context_window_size"), + ) + self.prefill_chunk_size = self.context_window_size + + +# pylint: disable=invalid-name,missing-docstring + + +class GPTBigCodeMLP(nn.Module): + def __init__(self, config: GPTBigCodeConfig): + super().__init__() + self.n_inner = config.n_inner // config.tensor_parallel_shards + self.c_fc = nn.Linear(in_features=config.n_embd, out_features=self.n_inner, bias=True) + self.c_proj = nn.Linear(in_features=self.n_inner, out_features=config.n_embd, bias=True) + + def forward(self, x: Tensor): + hidden_states = self.c_fc(x) + hidden_states = op.gelu(hidden_states) + hidden_states = self.c_proj(hidden_states) + return hidden_states + + +class GPTBigCodeAttention(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: GPTBigCodeConfig): + self.n_embd = config.n_embd + self.head_dim = config.n_embd // config.n_head + self.num_q_heads = config.n_head // config.tensor_parallel_shards + self.num_kv_heads = 1 + assert ( + config.tensor_parallel_shards == 1 + ), "GPT bigcode only support tensor parallel shards = 1" + self.c_attn = nn.Linear( + in_features=self.n_embd, + out_features=(self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim, + bias=True, + ) + self.c_proj = nn.Linear( + in_features=self.num_q_heads * self.head_dim, + out_features=config.n_embd, + bias=True, + ) + + self.k_cache = nn.KVCache(config.context_window_size, [self.num_kv_heads, self.head_dim]) + self.v_cache = nn.KVCache(config.context_window_size, [self.num_kv_heads, self.head_dim]) + + def forward( # pylint: disable=too-many-locals + self, + hidden_states: Tensor, + attention_mask: Tensor, + total_seq_len: tir.Var, + ): + d, h_q, h_kv, t = self.head_dim, self.num_q_heads, self.num_kv_heads, total_seq_len + b, s, _ = hidden_states.shape + assert b == 1, "Only support batch size 1 at this moment." 
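+        # Multi-query attention: the fused ``c_attn`` projection yields ``h_q`` query
+        # heads plus a single shared key head and a single shared value head, so its
+        # output is reshaped to (b, s, h_q + 2 * h_kv, d) before being split below.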
+ + qkv = self.c_attn(hidden_states) + qkv = op.reshape(qkv, (b, s, h_q + 2 * h_kv, d)) + q, k, v = op.split(qkv, indices_or_sections=[h_q, h_q + h_kv], axis=2) + + self.k_cache.append(op.squeeze(k, axis=0)) + self.v_cache.append(op.squeeze(v, axis=0)) + k = self.k_cache.view(t) + v = self.v_cache.view(t) + output = op_ext.attention(q, k, v, casual_mask=attention_mask) + return self.c_proj(output) + + +class GPTBigCodeBlock(nn.Module): + def __init__(self, config: GPTBigCodeConfig): + self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.attn = GPTBigCodeAttention(config) + self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.mlp = GPTBigCodeMLP(config) + + def _set_tp(): + def _set(layer, hint): + layer.weight.attrs["shard_strategy"] = hint + + hd = config.n_embd // config.n_head + q = config.n_head * hd + k = 1 * hd + v = 1 * hd + _set(self.attn.c_attn, tp.ShardSingleDim("_shard_c_attn", dim=0, segs=[q, k, v])) + _set(self.attn.c_proj, tp.ShardSingleDim("_shard_c_proj", dim=1)) + _set(self.mlp.c_fc, tp.ShardSingleDim("_shard_mlp_c_fc", dim=0)) + _set(self.mlp.c_proj, tp.ShardSingleDim("_shard_mlp_c_proj", dim=1)) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + + def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): + hidden_states = ( + self.attn(self.ln_1(hidden_states), attention_mask, total_seq_len) + hidden_states + ) + hidden_states = self.mlp(self.ln_2(hidden_states)) + hidden_states + return hidden_states + + +class GPTBigCodeModel(nn.Module): + def __init__(self, config: GPTBigCodeConfig): + assert config.n_embd % config.n_head == 0 + self.wte = nn.Embedding("vocab_size", config.n_embd) + self.wpe = nn.Embedding(config.n_positions, config.n_embd) + self.h = nn.ModuleList([GPTBigCodeBlock(config) for _ in range(config.n_layer)]) + self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.tensor_parallel_shards = config.tensor_parallel_shards + + def forward(self, inputs: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): + if self.tensor_parallel_shards > 1: + inputs = op.ccl_broadcast_from_worker0(inputs) + + # Token Embeddings + t_embd = self.wte(inputs) + + # Position Embeddings + # Generate np.arange(offset, offset+seq_len) + def _input_positions(inputs: te.Tensor, total_seq_len: tir.Var): + b, s = inputs.shape + offset = total_seq_len - s + return te.compute( + (b, s), lambda _, j: (offset + j).astype("int32"), name="input_positions" + ) + + input_positions = op.tensor_expr_op( + _input_positions, + name_hint="input_positions", + args=[inputs, total_seq_len], + ) + pos_embd = self.wpe(input_positions) + + # apply position embeddings + hidden_states = t_embd + pos_embd + for layer in self.h: + hidden_states = layer(hidden_states, attention_mask, total_seq_len) + hidden_states = self.ln_f(hidden_states) + + return hidden_states + + +class GPTBigCodeForCausalLM(nn.Module): + def __init__(self, config: GPTBigCodeConfig): + self.transformer = GPTBigCodeModel(config) + self.lm_head = nn.Linear(config.n_embd, "vocab_size", bias=False) + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def forward(self, inputs: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): + def _index(x: te.Tensor): # x[:-1,:] + b, s, d = x.shape + return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") + + hidden_states = self.transformer(inputs, 
total_seq_len, attention_mask) + hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def prefill(self, inputs: Tensor, total_seq_len: tir.Var): + def _attention_mask(batch_size, seq_len, total_seq_len): + return te.compute( + (batch_size, 1, seq_len, total_seq_len), + lambda b, _, i, j: tir.if_then_else( + i < j - (total_seq_len - seq_len), + tir.min_value(self.dtype), + tir.max_value(self.dtype), + ), + name="attention_mask_prefill", + ) + + batch_size, seq_len = inputs.shape + attention_mask = op.tensor_expr_op( + _attention_mask, + name_hint="attention_mask_prefill", + args=[batch_size, seq_len, total_seq_len], + ) + return self.forward(inputs, total_seq_len, attention_mask) + + def decode(self, inputs: Tensor, total_seq_len: tir.Var): + batch_size, seq_len = inputs.shape + attention_mask = op.full( + shape=[batch_size, 1, seq_len, total_seq_len], + fill_value=tir.max_value(self.dtype), + dtype=self.dtype, + ) + return self.forward(inputs, total_seq_len, attention_mask) + + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): + return op.softmax(logits / temperature, axis=-1) + + def get_default_spec(self): + batch_size = 1 + mod_spec = { + "prefill": { + "inputs": nn.spec.Tensor([batch_size, "seq_len"], "int32"), + "total_seq_len": int, + "$": { + "param_mode": "packed", + "effect_mode": "packed", + }, + }, + "decode": { + "inputs": nn.spec.Tensor([batch_size, 1], "int32"), + "total_seq_len": int, + "$": { + "param_mode": "packed", + "effect_mode": "packed", + }, + }, + "softmax_with_temperature": { + "logits": nn.spec.Tensor([1, 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor([], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_quantization.py b/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_quantization.py new file mode 100644 index 0000000..021cc08 --- /dev/null +++ b/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_quantization.py @@ -0,0 +1,70 @@ +"""This file specifies how MLC's GPTBigCode parameters are quantized using group quantization +or other formats.""" + +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_chat.loader import QuantizeMapping +from mlc_chat.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize + +from .gpt_bigcode_model import GPTBigCodeConfig, GPTBigCodeForCausalLM + + +def group_quant( + model_config: GPTBigCodeConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a GPTBigCode-architecture model using group quantization.""" + model: nn.Module = GPTBigCodeForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: GPTBigCodeConfig, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a GPTBigCode-architecture model using FasterTransformer quantization.""" + model: nn.Module = GPTBigCodeForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def awq_quant( + model_config: GPTBigCodeConfig, + quantization: 
AWQQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a GPTBigCode-architecture model using Activation-aware Weight Quantization(AWQ).""" + model: nn.Module = GPTBigCodeForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: GPTBigCodeConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a GPTBigCode model without quantization.""" + model: nn.Module = GPTBigCodeForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_chat/model/gpt_neox/__init__.py b/python/mlc_chat/model/gpt_neox/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/mlc_chat/model/gpt_neox/gpt_neox_loader.py b/python/mlc_chat/model/gpt_neox/gpt_neox_loader.py new file mode 100644 index 0000000..b7e4027 --- /dev/null +++ b/python/mlc_chat/model/gpt_neox/gpt_neox_loader.py @@ -0,0 +1,89 @@ +""" +This file specifies how MLC's GPTNeoX parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" +import functools + +import numpy as np + +from mlc_chat.loader import ExternMapping +from mlc_chat.quantization import Quantization + +from .gpt_neox_model import GPTNeoXConfig, GPTNeoXForCausalLM + + +def huggingface(model_config: GPTNeoXConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : GPTNeoXConfig + The configuration of the GPTNeoX model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. 
+ """ + model = GPTNeoXForCausalLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # inv_freq/masked_bias/bias is not used in the model + attn = f"gpt_neox.layers.{i}.attention" + mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + mapping.add_unused(f"{attn}.masked_bias") + mapping.add_unused(f"{attn}.bias") + + # change the layout of query_key_value + def transform_qkv_layout(w, dtype): # pylint: disable=invalid-name + num_attention_heads = model_config.num_attention_heads + head_dim = model_config.head_dim + + org_shape = w.shape + w = np.reshape(w, [num_attention_heads, 3 * head_dim, -1]) + qkv = np.split(w, indices_or_sections=3, axis=1) + w = np.concatenate(qkv, axis=0) + w = np.reshape(w, org_shape) + return w.astype(dtype) + + qkv_proj = f"{attn}.query_key_value" + for param_name in ["weight", "bias"]: + mlc_name = f"{qkv_proj}.{param_name}" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + transform_qkv_layout, + dtype=mlc_param.dtype, + ), + ) + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + if ".dense_h_to_4h.bias" in mlc_name or ".dense_4h_to_h.bias" in mlc_name: + param_dtype = model_config.ffn_out_dtype + else: + param_dtype = mlc_param.dtype + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=param_dtype, + ), + ) + return mapping diff --git a/python/mlc_chat/model/gpt_neox/gpt_neox_model.py b/python/mlc_chat/model/gpt_neox/gpt_neox_model.py new file mode 100644 index 0000000..130d824 --- /dev/null +++ b/python/mlc_chat/model/gpt_neox/gpt_neox_model.py @@ -0,0 +1,456 @@ +""" +Implementation for GPTNeoX architecture. +TODO: add docstring +""" + +import dataclasses +import logging +from typing import Any, Dict, Optional + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_chat import op as op_ext +from mlc_chat.nn import PagedKVCache, RopeMode +from mlc_chat.support import tensor_parallel as tp +from mlc_chat.support.config import ConfigBase +from mlc_chat.support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class GPTNeoXConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the GPTNeoX model.""" + + use_parallel_residual: bool + hidden_size: int + intermediate_size: int + num_attention_heads: int + num_hidden_layers: int + layer_norm_eps: float + vocab_size: int + rotary_pct: float + position_embedding_base: int = 0 + context_window_size: int = 0 + head_dim: int = 0 + prefill_chunk_size: int = 0 + tensor_parallel_shards: int = 1 + ffn_out_dtype: str = "float32" + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.context_window_size == 0: + for name in ["max_position_embeddings", "max_sequence_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. 
Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + else: + raise ValueError( + "Unable to determine the maxmimum sequence length, because none of " + "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " + "provided in `config.json`." + ) + if self.position_embedding_base == 0: + if "rope_theta" in self.kwargs: + self.position_embedding_base = self.kwargs.pop("rope_theta") + else: + self.position_embedding_base = 10000 + if self.head_dim == 0: + self.head_dim = self.hidden_size // self.num_attention_heads + assert self.head_dim * self.num_attention_heads == self.hidden_size + + if self.prefill_chunk_size == 0: + logger.info( + "%s defaults to %s (%d)", + bold("prefill_chunk_size"), + bold("context_window_size"), + self.context_window_size, + ) + self.prefill_chunk_size = self.context_window_size + elif self.prefill_chunk_size > self.context_window_size: + logger.info( + "Overriding %s from %d to %d (%s)", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + self.context_window_size, + bold("context_window_size"), + ) + self.prefill_chunk_size = self.context_window_size + + +# pylint: disable=invalid-name,missing-docstring + + +class GPTNeoXAttention(nn.Module): # pylint: disable=too-many-instance-attributes + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: GPTNeoXConfig): + self.rope_theta = config.position_embedding_base + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads // config.tensor_parallel_shards + self.head_dim = config.head_dim + self.query_key_value = nn.Linear( + in_features=self.hidden_size, + out_features=3 * self.num_attention_heads * self.head_dim, + bias=True, + ) + self.dense = nn.Linear( + self.num_attention_heads * self.head_dim, self.hidden_size, bias=True + ) + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + # hidden_states: [batch_size, seq_len, hidden_size] + batch_size, seq_len, _ = hidden_states.shape + + # q/k/v states: [batch_size, seq_len, hidden_size] + qkv = self.query_key_value(hidden_states) + qkv = op.reshape(qkv, (batch_size, seq_len, 3 * self.num_attention_heads, self.head_dim)) + + # Attention + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_attention_heads), + (batch_size, seq_len, self.head_dim * self.num_attention_heads), + ) + attn_output = self.dense(output) + return attn_output + + def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + # hidden_states: [batch_size, seq_len, hidden_size] + batch_size, seq_len, _ = hidden_states.shape + + # q/k/v states: [batch_size, seq_len, hidden_size] + qkv = self.query_key_value(hidden_states) + qkv = op.reshape(qkv, (batch_size, seq_len, 3 * self.num_attention_heads, self.head_dim)) + + # Attention + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_attention_heads), + (batch_size, seq_len, self.head_dim * self.num_attention_heads), + ) + attn_output = self.dense(output) + return attn_output + + +class GPTNeoXMLP(nn.Module): + def __init__(self, config: GPTNeoXConfig): + super().__init__() + out_dtype = config.ffn_out_dtype + self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards + self.dense_h_to_4h = nn.Linear( + config.hidden_size, + self.intermediate_size, + out_dtype=out_dtype, + ) + self.dense_4h_to_h = nn.Linear( + 
self.intermediate_size, + config.hidden_size, + out_dtype=out_dtype, + ) + + def forward(self, hidden_states: Tensor): + dtype = hidden_states.dtype + if hidden_states.dtype != dtype: + hidden_states = hidden_states.astype(dtype) + hidden_states = self.dense_h_to_4h(hidden_states) + hidden_states = op.gelu(hidden_states) + if hidden_states.dtype != dtype: + hidden_states = hidden_states.astype(dtype) + hidden_states = self.dense_4h_to_h(hidden_states) + if hidden_states.dtype != dtype: + hidden_states = hidden_states.astype(dtype) + return hidden_states + + +class GPTNeoXLayer(nn.Module): + def __init__(self, config: GPTNeoXConfig): + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attention = GPTNeoXAttention(config) + self.mlp = GPTNeoXMLP(config) + self.use_parallel_residual = config.use_parallel_residual + + def _set_tp(): + def _set(param, hint): + param.attrs["shard_strategy"] = hint + + hd = config.head_dim + q = k = v = self.attention.num_attention_heads * hd + _set( + self.attention.query_key_value.weight, + tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]), + ) + _set( + self.attention.query_key_value.bias, + tp.ShardSingleDim("_shard_qkv_bias", dim=0, segs=[q, k, v]), + ) + _set(self.attention.dense.weight, tp.ShardSingleDim("_shard_dense", dim=1)) + _set( + self.mlp.dense_h_to_4h.weight, + tp.ShardSingleDim("_shard_dense_h_to_4h_weight", dim=0), + ) + _set(self.mlp.dense_h_to_4h.bias, tp.ShardSingleDim("_shard_dense_h_to_4h_bias", dim=0)) + _set(self.mlp.dense_4h_to_h.weight, tp.ShardSingleDim("_shard_dense_4h_to_h", dim=1)) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + dtype = hidden_states.dtype + attn_input = self.input_layernorm(hidden_states) + with tp.shard_bias(self.attention.dense, self.tensor_parallel_shards): + attn_output = self.attention( + attn_input, + paged_kv_cache, + layer_id, + ) + if self.use_parallel_residual: + mlp_input = self.post_attention_layernorm(hidden_states) + mlp_output = self.mlp(mlp_input) + hidden_states = mlp_output + attn_output + hidden_states + else: + attn_output = self._apply_residual(attn_output, hidden_states) + mlp_input = self.post_attention_layernorm(attn_output) + with tp.shard_bias(self.mlp.dense_4h_to_h, self.tensor_parallel_shards): + mlp_output = self.mlp(mlp_input) + hidden_states = self._apply_residual(mlp_output.astype(dtype), attn_output) + return hidden_states + + def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + dtype = hidden_states.dtype + attn_input = self.input_layernorm(hidden_states) + with tp.shard_bias(self.attention.dense, self.tensor_parallel_shards): + attn_output = self.attention.batch_forward( + attn_input, + paged_kv_cache, + layer_id, + ) + if self.use_parallel_residual: + mlp_input = self.post_attention_layernorm(hidden_states) + mlp_output = self.mlp(mlp_input) + hidden_states = mlp_output + attn_output + hidden_states + else: + attn_output = self._apply_residual(attn_output, hidden_states) + mlp_input = self.post_attention_layernorm(attn_output) + with tp.shard_bias(self.mlp.dense_4h_to_h, self.tensor_parallel_shards): + mlp_output = self.mlp(mlp_input) + hidden_states = self._apply_residual(mlp_output.astype(dtype), attn_output) + return hidden_states + + def _apply_residual(self, out, 
residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(out + residual / self.tensor_parallel_shards, "sum") + return out + residual + + +class GPTNeoXModel(nn.Module): + def __init__(self, config: GPTNeoXConfig): + self.embed_in = nn.Embedding(num="vocab_size", dim=config.hidden_size) + self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)]) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.tensor_parallel_shards = config.tensor_parallel_shards + + def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + if self.tensor_parallel_shards > 1: + inputs = op.ccl_broadcast_from_worker0(inputs) + hidden_states = inputs + + for layer_id, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) + hidden_states = self.final_layer_norm(hidden_states) + return hidden_states + + def batch_forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + if self.tensor_parallel_shards > 1: + inputs = op.ccl_broadcast_from_worker0(inputs) + hidden_states = inputs + + for layer_id, layer in enumerate(self.layers): + hidden_states = layer.batch_forward(hidden_states, paged_kv_cache, layer_id) + hidden_states = self.final_layer_norm(hidden_states) + return hidden_states + + +class GPTNeoXForCausalLM(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: GPTNeoXConfig): + self.gpt_neox = GPTNeoXModel(config) + self.embed_out = nn.Linear( + in_features=config.hidden_size, + out_features="vocab_size", + bias=False, + dtype="float32", + ) + self.num_hidden_layers = config.num_hidden_layers + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.vocab_size = config.vocab_size + self.rope_theta = config.position_embedding_base + self.tensor_parallel_shards = config.tensor_parallel_shards + self.dtype = "float32" + self.rotary_pct = config.rotary_pct + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def batch_forward( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None, + ): + op_ext.configure() + + hidden_states = self.gpt_neox.batch_forward(input_embeds, paged_kv_cache) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) + logits = self.embed_out(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def embed(self, input_ids: Tensor): + return self.gpt_neox.embed_in(input_ids) + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + def _index(x: te.Tensor): # x[:-1,:] + b, s, d = x.shape + return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") + + hidden_states = self.gpt_neox(input_embed, paged_kv_cache) + hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + logits = self.embed_out(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.gpt_neox(input_embed, paged_kv_cache) + logits = self.embed_out(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def batch_prefill( + self, input_embeds: Tensor, 
logit_positions: Tensor, paged_kv_cache: PagedKVCache + ): + logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) + return logits, paged_kv_cache + + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) + + def create_paged_kv_cache( + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_attention_heads // self.tensor_parallel_shards, + head_dim=self.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.rope_theta, + dtype=self.dtype, + rotary_dim=int(self.head_dim * self.rotary_pct), + ) + + def get_default_spec(self): + mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "prefill": { + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode": { + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "softmax_with_temperature": { + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_chat/model/gpt_neox/gpt_neox_quantization.py b/python/mlc_chat/model/gpt_neox/gpt_neox_quantization.py new file mode 100644 index 0000000..9f1daaf --- /dev/null +++ b/python/mlc_chat/model/gpt_neox/gpt_neox_quantization.py @@ -0,0 +1,53 @@ 
+"""This file specifies how MLC's GPTNeoX parameters are quantized using group quantization +or other formats.""" +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_chat.loader import QuantizeMapping +from mlc_chat.quantization import FTQuantize, GroupQuantize, NoQuantize + +from .gpt_neox_model import GPTNeoXConfig, GPTNeoXForCausalLM + + +def group_quant( + model_config: GPTNeoXConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a GPTNeoX-architecture model using group quantization.""" + model: nn.Module = GPTNeoXForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: GPTNeoXConfig, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a GPTNeoX-architecture model using FasterTransformer quantization.""" + model: nn.Module = GPTNeoXForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: GPTNeoXConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a GPTNeoX model without quantization.""" + model: nn.Module = GPTNeoXForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_chat/model/llama/__init__.py b/python/mlc_chat/model/llama/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/mlc_chat/model/llama/llama_loader.py b/python/mlc_chat/model/llama/llama_loader.py new file mode 100644 index 0000000..5dd902d --- /dev/null +++ b/python/mlc_chat/model/llama/llama_loader.py @@ -0,0 +1,171 @@ +""" +This file specifies how MLC's Llama parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" +import functools + +import numpy as np + +from mlc_chat.loader import ExternMapping +from mlc_chat.quantization import Quantization + +from .llama_model import LlamaConfig, LlamaForCasualLM +from .llama_quantization import awq_quant + + +def huggingface(model_config: LlamaConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : LlamaConfig + The configuration of the Llama model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. 
+ """ + model = LlamaForCasualLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"model.layers.{i}.self_attn" + mlc_name = f"{attn}.qkv_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.weight", + f"{attn}.k_proj.weight", + f"{attn}.v_proj.weight", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # Add gates in MLP + mlp = f"model.layers.{i}.mlp" + mlc_name = f"{mlp}.gate_up_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.weight", + f"{mlp}.up_proj.weight", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # inv_freq is not used in the model + mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping + + +def awq(model_config: LlamaConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of AWQ parameters. + Parameters + ---------- + model_config : LlamaConfig + The configuration of the Llama model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to AWQ. 
+ """ + model, _ = awq_quant(model_config, quantization) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), # type: ignore[attr-defined] + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"model.layers.{i}.self_attn" + for quantize_suffix in ["qweight", "qzeros", "scales"]: + mlc_name = f"{attn}.qkv_proj.{quantize_suffix}" + assert mlc_name in named_parameters + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.{quantize_suffix}", + f"{attn}.k_proj.{quantize_suffix}", + f"{attn}.v_proj.{quantize_suffix}", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate( + [q, k, v], + axis=1, # AWQ GEMM would transpose the weight + ).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # Concat gate and up in MLP + mlp = f"model.layers.{i}.mlp" + for quantize_suffix in ["qweight", "qzeros", "scales"]: + mlc_name = f"{mlp}.gate_up_proj.{quantize_suffix}" + assert mlc_name in named_parameters + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.{quantize_suffix}", + f"{mlp}.up_proj.{quantize_suffix}", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate( + [gate, up], + axis=1, # AWQ GEMM would transpose the weight + ).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # inv_freq is not used in the model + mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype), + ) + return mapping diff --git a/python/mlc_chat/model/llama/llama_model.py b/python/mlc_chat/model/llama/llama_model.py new file mode 100644 index 0000000..6da1d42 --- /dev/null +++ b/python/mlc_chat/model/llama/llama_model.py @@ -0,0 +1,400 @@ +""" +Implementation for Llama2 architecture. 
+TODO: add docstring +""" + +import dataclasses +from typing import Any, Dict, Optional + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_chat import op as op_ext +from mlc_chat.nn import PagedKVCache, RopeMode +from mlc_chat.support import logging +from mlc_chat.support import tensor_parallel as tp +from mlc_chat.support.config import ConfigBase +from mlc_chat.support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class LlamaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the Llama model.""" + + hidden_size: int + intermediate_size: int + num_attention_heads: int + num_hidden_layers: int + rms_norm_eps: float + vocab_size: int + position_embedding_base: int = 0 + context_window_size: int = 0 + prefill_chunk_size: int = 0 + num_key_value_heads: int = 0 + head_dim: int = 0 + tensor_parallel_shards: int = 1 + max_batch_size: int = 1 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.position_embedding_base == 0: + if "rope_theta" in self.kwargs: + self.position_embedding_base = self.kwargs.pop("rope_theta") + else: + self.position_embedding_base = 10000 + if self.context_window_size == 0: + for name in ["max_position_embeddings", "max_sequence_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + else: + raise ValueError( + "Unable to determine the maxmimum sequence length, because none of " + "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " + "provided in `config.json`." 
+ ) + if self.num_key_value_heads == 0: + self.num_key_value_heads = self.num_attention_heads + if self.head_dim == 0: + self.head_dim = self.hidden_size // self.num_attention_heads + assert self.head_dim * self.num_attention_heads == self.hidden_size + assert self.num_attention_heads % self.num_key_value_heads == 0 + if self.prefill_chunk_size == 0: + logger.info( + "%s defaults to %s (%d)", + bold("prefill_chunk_size"), + bold("context_window_size"), + self.context_window_size, + ) + self.prefill_chunk_size = self.context_window_size + elif self.prefill_chunk_size > self.context_window_size: + logger.info( + "Overriding %s from %d to %d (%s)", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + self.context_window_size, + bold("context_window_size"), + ) + self.prefill_chunk_size = self.context_window_size + + +# pylint: disable=invalid-name,missing-docstring + + +class LlamaFFN(nn.Module): + def __init__(self, config: LlamaConfig): + super().__init__() + self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards + self.gate_up_proj = nn.Linear( + in_features=config.hidden_size, + out_features=2 * self.intermediate_size, + bias=False, + ) + self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False) + + def forward(self, x: Tensor): + concat_x1_x2 = self.gate_up_proj(x) + x1, x2 = op.split(concat_x1_x2, 2, axis=-1) + return self.down_proj(op.silu(x1) * x2) + + +class LlamaAttention(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: LlamaConfig): + self.head_dim = config.head_dim + self.num_q_heads = config.num_attention_heads // config.tensor_parallel_shards + assert ( + config.num_key_value_heads % config.tensor_parallel_shards == 0 + ), f"num_kv_heads({config.num_key_value_heads}) must be divisible by tensor_parallel_shards" + assert ( + config.num_key_value_heads >= config.tensor_parallel_shards + ), f"Too large tensor_parallel_shards, must be smaller than {config.num_key_value_heads}" + self.num_kv_heads = config.num_key_value_heads // config.tensor_parallel_shards + self.qkv_proj = nn.Linear( + in_features=config.hidden_size, + out_features=(self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim, + bias=False, + ) + self.o_proj = nn.Linear(self.num_q_heads * self.head_dim, config.hidden_size, bias=False) + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads + b, s, _ = hidden_states.shape + # QKV Projection + qkv = self.qkv_proj(hidden_states) + qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) + # Attention + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_q_heads), + (b, s, h_q * d), + ) + return self.o_proj(output) + + def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads + b, s, _ = hidden_states.shape + # QKV Projection + qkv = self.qkv_proj(hidden_states) + qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) + # Attention + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_q_heads), + (b, s, h_q * d), + ) + return self.o_proj(output) + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig): + rms_norm_eps = config.rms_norm_eps + self.self_attn = LlamaAttention(config) + self.mlp = LlamaFFN(config) + self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, 
bias=False) + self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) + + def _set_tp(): + def _set(layer, hint): + layer.weight.attrs["shard_strategy"] = hint + + hd = config.head_dim + q = self.self_attn.num_q_heads * hd + k = self.self_attn.num_kv_heads * hd + v = self.self_attn.num_kv_heads * hd + i = self.mlp.intermediate_size + _set(self.self_attn.qkv_proj, tp.ShardSingleDim("_shard_qkv", segs=[q, k, v], dim=0)) + _set(self.self_attn.o_proj, tp.ShardSingleDim("_shard_o", dim=1)) + _set(self.mlp.gate_up_proj, tp.ShardSingleDim("_shard_mlp_up", segs=[i, i], dim=0)) + _set(self.mlp.down_proj, tp.ShardSingleDim("_shard_mlp_down", dim=1)) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id) + hidden_states = self._apply_residual(out, residual=hidden_states) + out = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = self._apply_residual(out, residual=hidden_states) + return hidden_states + + def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + out = self.self_attn.batch_forward( + self.input_layernorm(hidden_states), paged_kv_cache, layer_id + ) + hidden_states = self._apply_residual(out, residual=hidden_states) + out = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = self._apply_residual(out, residual=hidden_states) + return hidden_states + + def _apply_residual(self, out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(out, "sum") + residual + return out + residual + + +class LlamaModel(nn.Module): + def __init__(self, config: LlamaConfig): + assert config.hidden_size % config.num_attention_heads == 0 + self.embed_tokens = nn.Embedding("vocab_size", config.hidden_size) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) + self.tensor_parallel_shards = config.tensor_parallel_shards + + def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + if self.tensor_parallel_shards > 1: + input_embed = op.ccl_broadcast_from_worker0(input_embed) + hidden_states = input_embed + for layer_id, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) + hidden_states = self.norm(hidden_states) + return hidden_states + + def batch_forward(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + if self.tensor_parallel_shards > 1: + input_embeds = op.ccl_broadcast_from_worker0(input_embeds) + hidden_states = input_embeds + for layer_id, layer in enumerate(self.layers): + hidden_states = layer.batch_forward(hidden_states, paged_kv_cache, layer_id) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class LlamaForCasualLM(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: LlamaConfig): + self.model = LlamaModel(config) + self.lm_head = nn.Linear(config.hidden_size, "vocab_size", bias=False) + self.num_hidden_layers = config.num_hidden_layers + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.head_dim + self.hidden_size = config.hidden_size + self.vocab_size = config.vocab_size + self.rope_theta = config.position_embedding_base + 
self.tensor_parallel_shards = config.tensor_parallel_shards + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def batch_forward( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None, + ): + op_ext.configure() + + hidden_states = self.model.batch_forward(input_embeds, paged_kv_cache) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def embed(self, input_ids: Tensor): + return self.model.embed_tokens(input_ids) + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + def _index(x: te.Tensor): # x[:-1,:] + b, s, d = x.shape + return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") + + hidden_states = self.model(input_embed, paged_kv_cache) + hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.model(input_embed, paged_kv_cache) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def batch_prefill( + self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache + ): + logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) + return logits, paged_kv_cache + + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) + + def create_paged_kv_cache( + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, + head_dim=self.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.rope_theta, + dtype=self.dtype, + ) + + def get_default_spec(self): + mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "prefill": { + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode": { + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + 
"effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "softmax_with_temperature": { + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_chat/model/llama/llama_quantization.py b/python/mlc_chat/model/llama/llama_quantization.py new file mode 100644 index 0000000..0460c98 --- /dev/null +++ b/python/mlc_chat/model/llama/llama_quantization.py @@ -0,0 +1,69 @@ +"""This file specifies how MLC's Llama parameters are quantized using group quantization +or other formats.""" +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_chat.loader import QuantizeMapping +from mlc_chat.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize + +from .llama_model import LlamaConfig, LlamaForCasualLM + + +def group_quant( + model_config: LlamaConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Llama-architecture model using group quantization.""" + model: nn.Module = LlamaForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: LlamaConfig, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Llama-architecture model using FasterTransformer quantization.""" + model: nn.Module = LlamaForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def awq_quant( + model_config: LlamaConfig, + quantization: AWQQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Llama-architecture model using Activation-aware Weight Quantization(AWQ).""" + model: nn.Module = LlamaForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: LlamaConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Llama2 model without quantization.""" + model: nn.Module = LlamaForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_chat/model/mistral/__init__.py 
b/python/mlc_chat/model/mistral/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/mlc_chat/model/mistral/mistral_loader.py b/python/mlc_chat/model/mistral/mistral_loader.py new file mode 100644 index 0000000..71a8f1a --- /dev/null +++ b/python/mlc_chat/model/mistral/mistral_loader.py @@ -0,0 +1,165 @@ +""" +This file specifies how MLC's Mistral parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" +import functools + +import numpy as np + +from mlc_chat.loader import ExternMapping +from mlc_chat.quantization import Quantization + +from .mistral_model import MistralConfig, MistralForCasualLM +from .mistral_quantization import awq_quant + + +def huggingface(model_config: MistralConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : MistralConfig + The configuration of the Mistral model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. + """ + model = MistralForCasualLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"model.layers.{i}.self_attn" + mlc_name = f"{attn}.qkv_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.weight", + f"{attn}.k_proj.weight", + f"{attn}.v_proj.weight", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # Add gates in MLP + mlp = f"model.layers.{i}.mlp" + mlc_name = f"{mlp}.gate_up_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.weight", + f"{mlp}.up_proj.weight", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # inv_freq is not used in the model + mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping + + +def awq(model_config: MistralConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of AWQ parameters. + Parameters + ---------- + model_config : MistralConfig + The configuration of the Mistral model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to AWQ. 
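For contrast with the AWQ case, the unquantized `huggingface` mapping above fuses plain weight matrices row-wise. A minimal sketch, assuming Mistral-7B-like sizes taken from the presets later in this diff (32 query heads, 8 KV heads, head_dim 128, hidden size 4096):

import numpy as np

hidden, head_dim, n_q_heads, n_kv_heads = 4096, 128, 32, 8
q = np.zeros((n_q_heads * head_dim, hidden), dtype="float16")   # q_proj.weight
k = np.zeros((n_kv_heads * head_dim, hidden), dtype="float16")  # k_proj.weight
v = np.zeros((n_kv_heads * head_dim, hidden), dtype="float16")  # v_proj.weight
qkv = np.concatenate([q, k, v], axis=0)                         # row-wise fusion, axis=0
assert qkv.shape == ((n_q_heads + 2 * n_kv_heads) * head_dim, hidden)  # (6144, 4096)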
+ """ + model, _ = awq_quant(model_config, quantization) + _, _named_params = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), # type: ignore[attr-defined] + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"model.layers.{i}.self_attn" + for quantize_suffix in ["qweight", "qzeros", "scales"]: + mlc_name = f"{attn}.qkv_proj.{quantize_suffix}" + assert mlc_name in named_parameters + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.{quantize_suffix}", + f"{attn}.k_proj.{quantize_suffix}", + f"{attn}.v_proj.{quantize_suffix}", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # Concat gate and up in MLP + mlp = f"model.layers.{i}.mlp" + for quantize_suffix in ["qweight", "qzeros", "scales"]: + mlc_name = f"{mlp}.gate_up_proj.{quantize_suffix}" + assert mlc_name in named_parameters + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.{quantize_suffix}", + f"{mlp}.up_proj.{quantize_suffix}", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # inv_freq is not used in the model + mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype), + ) + return mapping diff --git a/python/mlc_chat/model/mistral/mistral_model.py b/python/mlc_chat/model/mistral/mistral_model.py new file mode 100644 index 0000000..d2b5c57 --- /dev/null +++ b/python/mlc_chat/model/mistral/mistral_model.py @@ -0,0 +1,528 @@ +""" +Implementation for Mistral architecture. 
+""" +import dataclasses +from typing import Any, Dict, Optional + +from tvm import relax as rx +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_chat import op as op_ext +from mlc_chat.support import logging +from mlc_chat.support import tensor_parallel as tp +from mlc_chat.support.config import ConfigBase +from mlc_chat.support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class MistralConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the Mistral model.""" + + hidden_size: int + intermediate_size: int + num_attention_heads: int + num_hidden_layers: int + rms_norm_eps: float + vocab_size: int + position_embedding_base: int = 0 + num_key_value_heads: int = 0 + head_dim: int = 0 + sliding_window_size: int = 4096 + prefill_chunk_size: int = 0 + attention_sink_size: int = 4 + tensor_parallel_shards: int = 1 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.position_embedding_base == 0: + if "rope_theta" in self.kwargs: + self.position_embedding_base = self.kwargs.pop("rope_theta") + else: + self.position_embedding_base = 10000 + if self.num_key_value_heads == 0: + self.num_key_value_heads = self.num_attention_heads + if self.head_dim == 0: + self.head_dim = self.hidden_size // self.num_attention_heads + assert self.num_attention_heads % self.num_key_value_heads == 0 + assert self.head_dim * self.num_attention_heads == self.hidden_size + assert self.attention_sink_size >= 0 + if self.prefill_chunk_size == 0: + logger.info( + "%s defaults to %s (%d)", + bold("prefill_chunk_size"), + bold("sliding_window_size"), + self.sliding_window_size, + ) + self.prefill_chunk_size = self.sliding_window_size + elif self.prefill_chunk_size > self.sliding_window_size: + logger.info( + "Overriding %s from %d to %d (%s)", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + self.sliding_window_size, + bold("sliding_window_size"), + ) + self.prefill_chunk_size = self.sliding_window_size + + +# pylint: disable=invalid-name,missing-docstring + + +class RotaryEmbedding(nn.Module): + """Cache relative Rotary Embedding.""" + + def __init__(self, config: MistralConfig): + super().__init__() + self.head_dim = config.head_dim + self.position_embedding_base = config.position_embedding_base + + def forward(self, q: Tensor, k: Tensor, q_offset: tir.Var): + def te_op(x: te.Tensor, offset: tir.Var): + dtype = x.dtype + + def compute(b: tir.Var, s: tir.Var, h: tir.Var, d: tir.Var): + head_dim = tir.const(self.head_dim, "int32") + position_embedding_base = tir.const(self.position_embedding_base, "float32") + freq = tir.power( + position_embedding_base, + (d * 2 % head_dim).astype("float32") / head_dim, + ) + freq = (offset + s) / freq + cos = tir.cos(freq).astype(dtype) * x[b, s, h, d] + sin = tir.sin(freq).astype(dtype) * tir.if_then_else( + d < head_dim // 2, + -x[b, s, h, d + head_dim // 2], + x[b, s, h, d - head_dim // 2], + ) + return cos + sin + + return te.compute(x.shape, compute, name="rotary") + + q_embed = op.tensor_expr_op( + te_op, + "rotary_embedding", + args=[q, q_offset], + attrs={"mlc.rotary_embedding_to_all_dims": True}, + ) + k_embed = op.tensor_expr_op( + te_op, "rotary_embedding", args=[k, 0], attrs={"mlc.rotary_embedding_to_all_dims": True} + ) + return q_embed, k_embed + + +class MistralMLP(nn.Module): + """Same as in Llama architecture (LlamaFFN).""" + + def __init__(self, config: MistralConfig): + 
super().__init__() + self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards + self.gate_up_proj = nn.Linear( + in_features=config.hidden_size, + out_features=2 * self.intermediate_size, + bias=False, + ) + self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False) + + def forward(self, x: Tensor): + concat_x1_x2 = self.gate_up_proj(x) + x1, x2 = op.split(concat_x1_x2, 2, axis=-1) + return self.down_proj(op.silu(x1) * x2) + + +class MistralAttention(nn.Module): # pylint: disable=too-many-instance-attributes + """Same as LlamaAttention, but with sliding window attention using a rolling buffer cache.""" + + def __init__(self, config: MistralConfig, rotary_embedding: RotaryEmbedding): + self.rotary_embedding = rotary_embedding + self.hidden_size = config.hidden_size + self.head_dim = config.head_dim + self.num_q_heads = config.num_attention_heads // config.tensor_parallel_shards + self.num_kv_heads = config.num_key_value_heads // config.tensor_parallel_shards + self.sliding_window_size = config.sliding_window_size + self.attention_sink_size = config.attention_sink_size + self.qkv_proj = nn.Linear( + in_features=config.hidden_size, + out_features=(self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim, + bias=False, + ) + self.o_proj = nn.Linear(self.num_q_heads * self.head_dim, config.hidden_size, bias=False) + self.k_cache = RollingKVCacheWithSinks( + self.sliding_window_size, [self.num_kv_heads, self.head_dim] + ) + self.v_cache = RollingKVCacheWithSinks( + self.sliding_window_size, [self.num_kv_heads, self.head_dim] + ) + + def interleave_kv( # pylint: disable=too-many-arguments,too-many-locals + self, + k_cur: Tensor, + v_cur: Tensor, + kv_seq_len: tir.Var, + rolling_cache_len: tir.Var, + cache_offset: tir.Var, + ): + """Unrotate and concatenate currunt and cached k and v""" + h_kv, d = self.num_kv_heads, self.head_dim + kv_s, c, o = kv_seq_len, rolling_cache_len, cache_offset + b = k_cur.shape[0] + + k_cached = op.reshape(self.k_cache.view(c), (b, c, h_kv, d)) + v_cached = op.reshape(self.v_cache.view(c), (b, c, h_kv, d)) + + def _cache_unrotate(x_cached, rolling_cache_len, cache_offset): + return te.compute( + (b, kv_s, h_kv, d), + lambda xb, xs, xh, xd: te.if_then_else( + xs < self.attention_sink_size, + x_cached[xb, xs, xh, xd], + te.if_then_else( + xs < rolling_cache_len - cache_offset + self.attention_sink_size, + x_cached[xb, xs + cache_offset - self.attention_sink_size, xh, xd], + x_cached[xb, xs + cache_offset - rolling_cache_len, xh, xd], + ), + ), + name="cache_unrotate_te", + ) + + def _cache_cur_concat(x_cached, x_cur, rolling_cache_len): + return te.compute( + (b, kv_s, h_kv, d), + lambda xb, xs, xh, xd: te.if_then_else( + xs < rolling_cache_len, + x_cached[xb, xs, xh, xd], + x_cur[xb, xs - rolling_cache_len, xh, xd], + ), + name="cache_cur_concat_te", + ) + + k_cached = op.tensor_expr_op( + _cache_unrotate, + name_hint="te_cache_unrotate_key", + args=[k_cached, c, o], + ) + k = op.tensor_expr_op( + _cache_cur_concat, + name_hint="te_cache_cur_concat_key", + args=[k_cached, k_cur, c], + ) + + v_cached = op.tensor_expr_op( + _cache_unrotate, + name_hint="te_cache_unrotate_value", + args=[v_cached, c, o], + ) + v = op.tensor_expr_op( + _cache_cur_concat, + name_hint="te_cache_cur_concat_value", + args=[v_cached, v_cur, c], + ) + + self.k_cache.override( + op.squeeze(k_cur, axis=0), self.sliding_window_size, self.attention_sink_size + ) + self.v_cache.override( + op.squeeze(v_cur, axis=0), self.sliding_window_size, 
self.attention_sink_size + ) + + return k, v + + def forward( # pylint: disable=too-many-arguments, too-many-locals + self, + hidden_states: Tensor, + attention_mask: Tensor, + rolling_cache_len: tir.Var, # Number of elements currently in the cache. + kv_seq_len: tir.Var, # Equals to ``seq_len + rolling_cache_len``. + cache_offset: tir.Var, + ): + """Forward pass of MistralAttention, performing QKV.""" + d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads + b, s, _ = hidden_states.shape + assert b == 1, "Only support batch size 1 at this moment." + qkv_cur = self.qkv_proj(hidden_states) + qkv_cur = op.reshape(qkv_cur, (b, s, h_q + 2 * h_kv, d)) + q, k_cur, v_cur = op.split(qkv_cur, [h_q, h_q + h_kv], axis=2) + k, v = self.interleave_kv(k_cur, v_cur, kv_seq_len, rolling_cache_len, cache_offset) + q, k = self.rotary_embedding(q, k, rolling_cache_len) + output = op_ext.attention(q, k, v, attention_mask) + return self.o_proj(output) + + +class RollingKVCacheWithSinks(nn.KVCache): + """ + Rolling buffer cache implementation. + """ + + cache: Optional[rx.Var] + + def override(self, new_element: Tensor, max_cache_size: int, attention_sink_size: int) -> None: + """ + Override cache elements in RollingKVCacheWithSinks. + + Parameters + ---------- + new_element : Tensor + The new tensor to append. + + max_cache_size : int + Max size of the cache. + + attention_sink_size : int + Number of stored attention sinks. + """ + if new_element.dtype != self.dtype: + raise TypeError( + f'RollingKVCacheWithSinks has been set to use dtype "{self.dtype}", ' + f'but got "{new_element.dtype}"' + ) + self.cache = rx.BlockBuilder.current().emit( + rx.Call( + rx.extern("vm.builtin.attention_kv_cache_window_override_with_sinks"), + args=[ + self.cache, + new_element._expr, # pylint: disable=protected-access + rx.PrimValue(max_cache_size), + rx.PrimValue(attention_sink_size), + ], + sinfo_args=[rx.ObjectStructInfo()], + ) + ) + + +class MistralDecoderLayer(nn.Module): + """Exact same as LlamaDecoderLayer.""" + + def __init__(self, config: MistralConfig, rotary_embedding: RotaryEmbedding): + rms_norm_eps = config.rms_norm_eps + self.self_attn = MistralAttention(config, rotary_embedding) + self.mlp = MistralMLP(config) + self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) + self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) + + def _set_tp(): + def _set(layer, hint): + layer.weight.attrs["shard_strategy"] = hint + + hd = config.head_dim + q = self.self_attn.num_q_heads * hd + k = self.self_attn.num_kv_heads * hd + v = self.self_attn.num_kv_heads * hd + i = self.mlp.intermediate_size + _set(self.self_attn.qkv_proj, tp.ShardSingleDim("_shard_qkv", segs=[q, k, v], dim=0)) + _set(self.self_attn.o_proj, tp.ShardSingleDim("_shard_o", dim=1)) + _set(self.mlp.gate_up_proj, tp.ShardSingleDim("_shard_mlp_up", segs=[i, i], dim=0)) + _set(self.mlp.down_proj, tp.ShardSingleDim("_shard_mlp_down", dim=1)) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + + def forward( # pylint: disable=too-many-arguments + self, + hidden_states: Tensor, + attention_mask: Tensor, + rolling_cache_len: tir.Var, + kv_seq_len: tir.Var, + cache_offset: tir.Var, + ): + """Forward pass of a decoder layer; calculate attention, and add an residual connection.""" + + def _apply_residual(out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(out, "sum") + residual + return out + residual + + out = self.self_attn( + 
self.input_layernorm(hidden_states), + attention_mask, + rolling_cache_len, + kv_seq_len, + cache_offset, + ) + hidden_states = _apply_residual(out, residual=hidden_states) + out = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = _apply_residual(out, residual=hidden_states) + return hidden_states + + +class MistralModel(nn.Module): + """Exact same as LlamaModel.""" + + def __init__(self, config: MistralConfig): + assert config.hidden_size % config.num_attention_heads == 0 + rotary_embedding = RotaryEmbedding(config) + self.embed_tokens = nn.Embedding("vocab_size", config.hidden_size) + self.layers = nn.ModuleList( + [MistralDecoderLayer(config, rotary_embedding) for _ in range(config.num_hidden_layers)] + ) + self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) + self.tensor_parallel_shards = config.tensor_parallel_shards + + def forward( # pylint: disable=too-many-arguments + self, + inputs: Tensor, + rolling_cache_len: tir.Var, + kv_seq_len: tir.Var, + cache_offset: tir.Var, + attention_mask: Tensor, + ): + """Forward pass of the model, passing through all decoder layers.""" + if self.tensor_parallel_shards > 1: + inputs = op.ccl_broadcast_from_worker0(inputs) + hidden_states = self.embed_tokens(inputs) + for layer in self.layers: + hidden_states = layer( + hidden_states, attention_mask, rolling_cache_len, kv_seq_len, cache_offset + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class MistralForCasualLM(nn.Module): + """Same as LlamaForCausalLM, except for the use of sliding window attention.""" + + def __init__(self, config: MistralConfig): + self.model = MistralModel(config) + self.lm_head = nn.Linear(config.hidden_size, "vocab_size", bias=False) + self.sliding_window_size = config.sliding_window_size + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def forward( # pylint: disable=too-many-arguments + self, + inputs: Tensor, + rolling_cache_len: tir.Var, + kv_seq_len: tir.Var, + cache_offset: tir.Var, + attention_mask: Tensor, + ): + """Forward pass.""" + + def _index(x: te.Tensor): # x[:-1,:] + b, s, d = x.shape + return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") + + hidden_states = self.model( + inputs, rolling_cache_len, kv_seq_len, cache_offset, attention_mask + ) + hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def prefill( + self, + inputs: Tensor, + rolling_cache_len: tir.Var, + kv_seq_len: tir.Var, + cache_offset: tir.Var, + ): + """ + Prefilling the prompt. + + Parameters + ---------- + inputs: Tensor + Input tokens, having ``seq_len`` number of tokens. + + rolling_cache_len: tir.Var + Number of elements currently in the cache. + + kv_seq_len: tir.Var + Equals to ``seq_len + rolling_cache_len``. + + cache_offset: tir.Var + Next position to be overrided on the rolling kv cache. 
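A minimal numpy sketch of the sliding-window mask that the prefill step constructs below, using hypothetical sizes (3 cached tokens, 4 new prompt tokens, window of 5). A query at position i of the new chunk may attend to key position j only if `i + rolling_cache_len >= j` and the distance stays inside the window:

import numpy as np

rolling_cache_len, seq_len, sliding_window_size = 3, 4, 5
kv_seq_len = seq_len + rolling_cache_len
i = np.arange(seq_len)[:, None]     # query positions within the new chunk
j = np.arange(kv_seq_len)[None, :]  # key positions: cached tokens first, then new tokens
visible = (i + rolling_cache_len >= j) & (i + rolling_cache_len - j < sliding_window_size)
print(visible.astype(int))
# [[1 1 1 1 0 0 0]
#  [1 1 1 1 1 0 0]
#  [0 1 1 1 1 1 0]
#  [0 0 1 1 1 1 1]]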
+ """ + + def _sliding_window_attention_mask( + batch_size, seq_len, rolling_cache_len, kv_seq_len, sliding_window_size + ): + # See `tests/legacy-python/test_sliding_window_mask.py` for its behavior + return te.compute( + (batch_size, 1, seq_len, kv_seq_len), + lambda b, _, i, j: tir.Select( + tir.all( + i + rolling_cache_len >= j, i + rolling_cache_len - j < sliding_window_size + ), + tir.max_value(self.dtype), + tir.min_value(self.dtype), + ), + name="sliding_window_attention_mask_prefill", + ) + + batch_size, seq_len = inputs.shape + attention_mask = op.tensor_expr_op( + _sliding_window_attention_mask, + name_hint="sliding_window_attention_mask_prefill", + args=[ + batch_size, + seq_len, + rolling_cache_len, + kv_seq_len, + self.sliding_window_size, + ], + ) + return self.forward(inputs, rolling_cache_len, kv_seq_len, cache_offset, attention_mask) + + def decode( + self, + inputs: Tensor, + rolling_cache_len: tir.Var, + kv_seq_len: tir.Var, + cache_offset: tir.Var, + ): + """Decoding step.""" + batch_size, seq_len = inputs.shape + attention_mask = op.full( + shape=[batch_size, 1, seq_len, kv_seq_len], + fill_value=tir.max_value(self.dtype), + dtype=self.dtype, + ) + return self.forward(inputs, rolling_cache_len, kv_seq_len, cache_offset, attention_mask) + + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): + """Softmax.""" + return op.softmax(logits / temperature, axis=-1) + + def get_default_spec(self): + """Needed for ``export_tvm()``.""" + batch_size = 1 + mod_spec = { + "prefill": { + "inputs": nn.spec.Tensor([batch_size, "seq_len"], "int32"), + "rolling_cache_len": int, + "kv_seq_len": int, + "cache_offset": int, + "$": { + "param_mode": "packed", + "effect_mode": "packed", + }, + }, + "decode": { + "inputs": nn.spec.Tensor([batch_size, 1], "int32"), + "rolling_cache_len": int, + "kv_seq_len": int, + "cache_offset": int, + "$": { + "param_mode": "packed", + "effect_mode": "packed", + }, + }, + "softmax_with_temperature": { + "logits": nn.spec.Tensor([1, 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor([], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_chat/model/mistral/mistral_quantization.py b/python/mlc_chat/model/mistral/mistral_quantization.py new file mode 100644 index 0000000..e3622fd --- /dev/null +++ b/python/mlc_chat/model/mistral/mistral_quantization.py @@ -0,0 +1,69 @@ +"""This file specifies how MLC's Mistral parameters are quantized using group quantization +or other formats.""" +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_chat.loader import QuantizeMapping +from mlc_chat.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize + +from .mistral_model import MistralConfig, MistralForCasualLM + + +def group_quant( + model_config: MistralConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Mistral-architecture model using group quantization.""" + model: nn.Module = MistralForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: MistralConfig, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Mistral-architecture model using FasterTransformer quantization.""" + model: nn.Module = MistralForCasualLM(model_config) + 
model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def awq_quant( + model_config: MistralConfig, + quantization: AWQQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Mistral-architecture model using Activation-aware Weight Quantization(AWQ).""" + model: nn.Module = MistralForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: MistralConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Llama2 model without quantization.""" + model: nn.Module = MistralForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_chat/model/mixtral/__init__.py b/python/mlc_chat/model/mixtral/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/mlc_chat/model/mixtral/mixtral_loader.py b/python/mlc_chat/model/mixtral/mixtral_loader.py new file mode 100644 index 0000000..12e96eb --- /dev/null +++ b/python/mlc_chat/model/mixtral/mixtral_loader.py @@ -0,0 +1,129 @@ +""" +This file specifies how MLC's Mixtral parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" +import functools + +import numpy as np + +from mlc_chat.loader import ExternMapping +from mlc_chat.quantization import Quantization + +from .mixtral_model import MixtralConfig, MixtralForCasualLM + + +def huggingface(model_config: MixtralConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : MixtralConfig + The configuration of the Mixtral model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. 
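A hedged sketch of how the Mistral quantization entry points above are meant to be selected. The string keys mirror the `MODELS` registry that appears later in this diff; `model_config` and `quantization` are assumed to be a `MistralConfig` and a matching quantizer constructed elsewhere:

QUANT_ENTRY_POINTS = {
    "no-quant": no_quant,        # the entry points defined in mistral_quantization.py above
    "group-quant": group_quant,
    "ft-quant": ft_quant,
}
# model_config: MistralConfig, quantization: e.g. GroupQuantize (assumed constructed elsewhere)
model, quant_map = QUANT_ENTRY_POINTS["group-quant"](model_config, quantization)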
+ """ + model = MixtralForCasualLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"model.layers.{i}.self_attn" + mlc_name = f"{attn}.qkv_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.weight", + f"{attn}.k_proj.weight", + f"{attn}.v_proj.weight", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # Add gates in MLP (when MoE is enabled) + mlp = f"model.layers.{i}.block_sparse_moe" + mlc_mlp = f"model.layers.{i}.moe" + mlc_name = f"{mlc_mlp}.e1_e3.weight" + mlc_param = named_parameters[mlc_name] + + def combine_expert_gate_up(*hf_params, dtype): + stack = [] + for i in range(0, len(hf_params), 2): + stack.append(np.concatenate([hf_params[i], hf_params[i + 1]], axis=0)) + return np.stack(stack, axis=0).astype(dtype) + + mapping.add_mapping( + mlc_name, + functools.reduce( + lambda a, b: a + b, + [ + [ + f"{mlp}.experts.{expert_id}.w1.weight", + f"{mlp}.experts.{expert_id}.w3.weight", + ] + for expert_id in range(model_config.num_local_experts) + ], + ), + functools.partial( + combine_expert_gate_up, + dtype=mlc_param.dtype, + ), + ) + + mlc_name = f"{mlc_mlp}.e2.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.experts.{expert_id}.w2.weight" + for expert_id in range(model_config.num_local_experts) + ], + functools.partial( + lambda *hf_params, dtype: np.stack(hf_params, axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + mlc_name = f"{mlc_mlp}.gate.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [f"{mlp}.gate.weight"], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # inv_freq is not used in the model + mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping diff --git a/python/mlc_chat/model/mixtral/mixtral_model.py b/python/mlc_chat/model/mixtral/mixtral_model.py new file mode 100644 index 0000000..a2740f1 --- /dev/null +++ b/python/mlc_chat/model/mixtral/mixtral_model.py @@ -0,0 +1,176 @@ +"""Implementation for Mistral architecture.""" +import dataclasses + +from tvm import tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_chat import op as op_ext +from mlc_chat.model.llama.llama_model import ( + LlamaAttention, + LlamaConfig, + LlamaForCasualLM, + LlamaModel, +) +from mlc_chat.nn import PagedKVCache +from mlc_chat.nn.expert import MixtralExperts +from mlc_chat.support import logging +from mlc_chat.support import tensor_parallel as tp + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class MixtralConfig(LlamaConfig): # pylint: disable=too-many-instance-attributes + """Configuration of the Mixtral model.""" + + num_local_experts: int = 0 + num_experts_per_tok: int = 0 + + +# pylint: disable=invalid-name,missing-docstring,too-many-locals,fixme + + +class MixtralMoE(nn.Module): + 
"""Mixture of experts""" + + def __init__(self, config: MixtralConfig): + super().__init__() + self.num_experts_per_tok = config.num_experts_per_tok + self.num_local_experts = config.num_local_experts + self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards + self.gate = nn.Linear( + in_features=config.hidden_size, + out_features=config.num_local_experts, + bias=False, + ) + self.e1_e3 = MixtralExperts( + self.num_local_experts, + in_features=config.hidden_size, + out_features=2 * self.intermediate_size, + ) + self.e2 = MixtralExperts( + self.num_local_experts, + in_features=self.intermediate_size, + out_features=config.hidden_size, + ) + self.dtype = "float32" + + def forward(self, x: Tensor): + def _expert_forward(x: Tensor, indptr: Tensor): + x1_x3 = self.e1_e3(x, indptr) + x1, x3 = op.split(x1_x3, indices_or_sections=2, axis=-1) + x = self.e2(op.silu(x1) * x3, indptr) + return x + + experts_per_tok = self.num_experts_per_tok # activated experts per token + local_experts = self.num_local_experts # total number of experts + batch_size, seq_len, hidden_size = x.shape + num_tokens = batch_size * seq_len + x = x.reshape(num_tokens, hidden_size) + # gate: [num_tokens, local_experts] + gate: Tensor = self.gate(x) + # expert_weights: [num_tokens, experts_per_tok] + # expert_indices: [num_tokens, experts_per_tok] + expert_weights, expert_indices = op_ext.moe_misc.gating_softmax_topk(gate, experts_per_tok) + use_ft = op_ext.get_store().faster_transformer and self.dtype == "float16" + if num_tokens == 1: + # x: [num_tokens * experts_per_tok, hidden_size] + x = _expert_forward(x, expert_indices) + else: + # cumsum: [num_tokens * local_experts] + cumsum = op_ext.moe_misc.moe_cumsum(expert_indices, local_experts) + # indices: [num_tokens * experts_per_tok] + reverse_indices, token_indices = op_ext.moe_misc.get_indices(cumsum, expert_indices) + if use_ft: + # indptr: [num_local_experts] + indptr = op_ext.moe_misc.get_indptr( + cumsum, local_experts, num_tokens, inclusive=True, out_dtype="int64" + ) + else: + # indptr: [num_local_experts + 1] + indptr = op_ext.moe_misc.get_indptr( + cumsum, local_experts, num_tokens, inclusive=False, out_dtype="int32" + ) + # x: [num_tokens * experts_per_tok, hidden_size] + x = op.take(x, token_indices, axis=0) + x = _expert_forward(x, indptr) + x = op_ext.moe_misc.scatter_output(x, reverse_indices) + # x: [num_tokens, experts_per_tok, hidden_size] + x = x.reshape( # pylint: disable=too-many-function-args + num_tokens, experts_per_tok, hidden_size + ) * expert_weights.reshape( # pylint: disable=too-many-function-args + num_tokens, experts_per_tok, 1 + ) + # x: [num_tokens, hidden_size] + x = op_ext.moe_misc.moe_sum(x, dim=1) + x = x.reshape(batch_size, seq_len, hidden_size) # pylint: disable=too-many-function-args + return x + + +class MixtralDecoderLayer(nn.Module): + """Mixtral decoder layer""" + + def __init__(self, config: MixtralConfig): + eps = config.rms_norm_eps + self.self_attn = LlamaAttention(config) + self.moe = MixtralMoE(config) + self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, eps, bias=False) + self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, eps, bias=False) + + def _set_tp(): + def _set(layer, hint): + layer.weight.attrs["shard_strategy"] = hint + + hd = config.head_dim + q = self.self_attn.num_q_heads * hd + k = self.self_attn.num_kv_heads * hd + v = self.self_attn.num_kv_heads * hd + i = self.moe.intermediate_size + _set(self.self_attn.qkv_proj, tp.ShardSingleDim("_shard_qkv", segs=[q, k, v], 
dim=0)) + _set(self.self_attn.o_proj, tp.ShardSingleDim("_shard_o", dim=1)) + _set(self.moe.e1_e3, tp.ShardSingleDim("_shard_mlp_up", segs=[i, i], dim=1)) + _set(self.moe.e2, tp.ShardSingleDim("_shard_mlp_down", dim=2)) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + + def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): + """Forward pass of a decoder layer; calculate attention, and add an residual connection.""" + out = self.self_attn(self.input_layernorm(hidden_states), attention_mask, total_seq_len) + hidden_states = self._apply_residual(out, residual=hidden_states) + out = self.moe(self.post_attention_layernorm(hidden_states)) + hidden_states = self._apply_residual(out, residual=hidden_states) + return hidden_states + + def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + out = self.self_attn.batch_forward( + self.input_layernorm(hidden_states), paged_kv_cache, layer_id + ) + hidden_states = self._apply_residual(out, residual=hidden_states) + out = self.moe(self.post_attention_layernorm(hidden_states)) + hidden_states = self._apply_residual(out, residual=hidden_states) + return hidden_states + + def _apply_residual(self, out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(out, "sum") + residual + return out + residual + + +class MixtralModel(LlamaModel): + """Exact same as LlamaModel.""" + + def __init__(self, config: MixtralConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [MixtralDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + + +class MixtralForCasualLM(LlamaForCasualLM): + """Same as LlamaForCausalLM.""" + + def __init__(self, config: MixtralConfig): + super().__init__(config) + self.model = MixtralModel(config) diff --git a/python/mlc_chat/model/mixtral/mixtral_quantization.py b/python/mlc_chat/model/mixtral/mixtral_quantization.py new file mode 100644 index 0000000..37f7ad5 --- /dev/null +++ b/python/mlc_chat/model/mixtral/mixtral_quantization.py @@ -0,0 +1,61 @@ +"""This file specifies how MLC's Mistral parameters are quantized using group quantization +or other formats.""" +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_chat.loader import QuantizeMapping +from mlc_chat.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize + +from .mixtral_model import MixtralConfig, MixtralForCasualLM + + +def group_quant( + model_config: MixtralConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Mixtral-architecture model using group quantization.""" + model: nn.Module = MixtralForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: MixtralConfig, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Mixtral-architecture model using FasterTransformer quantization.""" + model: nn.Module = MixtralForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def awq_quant( + model_config: MixtralConfig, + quantization: AWQQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Mixtral-architecture model using Activation-aware Weight Quantization(AWQ).""" + raise 
NotImplementedError("AWQ is not implemented for Mixtral models.") + + +def no_quant( + model_config: MixtralConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Mixtral model without quantization.""" + model: nn.Module = MixtralForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_chat/model/model.py b/python/mlc_chat/model/model.py new file mode 100644 index 0000000..68d052c --- /dev/null +++ b/python/mlc_chat/model/model.py @@ -0,0 +1,251 @@ +"""A centralized registry of all existing model architures and their configurations.""" + +import dataclasses +from typing import Any, Callable, Dict, Tuple + +from tvm.relax.frontend import nn + +from mlc_chat.loader import ExternMapping, QuantizeMapping +from mlc_chat.quantization.quantization import Quantization + +from .baichuan import baichuan_loader, baichuan_model, baichuan_quantization +from .gemma import gemma_loader, gemma_model, gemma_quantization +from .gpt2 import gpt2_loader, gpt2_model, gpt2_quantization +from .gpt_bigcode import gpt_bigcode_loader, gpt_bigcode_model, gpt_bigcode_quantization +from .gpt_neox import gpt_neox_loader, gpt_neox_model, gpt_neox_quantization +from .llama import llama_loader, llama_model, llama_quantization +from .mistral import mistral_loader, mistral_model, mistral_quantization +from .mixtral import mixtral_loader, mixtral_model, mixtral_quantization +from .phi import phi_loader, phi_model, phi_quantization +from .qwen import qwen_loader, qwen_model, qwen_quantization +from .qwen2 import qwen2_loader, qwen2_model, qwen2_quantization +from .stable_lm import stablelm_loader, stablelm_model, stablelm_quantization + +ModelConfig = Any +"""A ModelConfig is an object that represents a model architecture. It is required to have +a class method `from_file` with the following signature: + + def from_file(cls, path: Path) -> ModelConfig: + ... +""" + +FuncGetExternMap = Callable[[ModelConfig, Quantization], ExternMapping] +FuncQuantization = Callable[[ModelConfig, Quantization], Tuple[nn.Module, QuantizeMapping]] + + +@dataclasses.dataclass +class Model: + """All about a model architecture: its configuration, its parameter loader and quantization. + + Parameters + ---------- + name : str + The name of the model. + + model : Callable[[ModelConfig], nn.Module] + A method that creates the `nn.Module` that represents the model from `ModelConfig`. + + config : ModelConfig + A class that has a `from_file` class method, whose signature is "Path -> ModelConfig". + + source : Dict[str, FuncGetExternMap] + A dictionary that maps the name of a source format to parameter mapping. + + quantize: Dict[str, FuncQuantization] + A dictionary that maps the name of a quantization method to quantized model and the + quantization parameter mapping. 
+ """ + + name: str + config: ModelConfig + model: Callable[[ModelConfig], nn.Module] + source: Dict[str, FuncGetExternMap] + quantize: Dict[str, FuncQuantization] + + +MODELS: Dict[str, Model] = { + "llama": Model( + name="llama", + model=llama_model.LlamaForCasualLM, + config=llama_model.LlamaConfig, + source={ + "huggingface-torch": llama_loader.huggingface, + "huggingface-safetensor": llama_loader.huggingface, + "awq": llama_loader.awq, + }, + quantize={ + "no-quant": llama_quantization.no_quant, + "group-quant": llama_quantization.group_quant, + "ft-quant": llama_quantization.ft_quant, + "awq": llama_quantization.awq_quant, + }, + ), + "mistral": Model( + name="mistral", + model=mistral_model.MistralForCasualLM, + config=mistral_model.MistralConfig, + source={ + "huggingface-torch": mistral_loader.huggingface, + "huggingface-safetensor": mistral_loader.huggingface, + "awq": mistral_loader.awq, + }, + quantize={ + "group-quant": mistral_quantization.group_quant, + "no-quant": mistral_quantization.no_quant, + "ft-quant": mistral_quantization.ft_quant, + }, + ), + "gemma": Model( + name="gemma", + model=gemma_model.GemmaForCausalLM, + config=gemma_model.GemmaConfig, + source={ + "huggingface-torch": gemma_loader.huggingface, + "huggingface-safetensor": gemma_loader.huggingface, + }, + quantize={ + "no-quant": gemma_quantization.no_quant, + "group-quant": gemma_quantization.group_quant, + }, + ), + "gpt2": Model( + name="gpt2", + model=gpt2_model.GPT2LMHeadModel, + config=gpt2_model.GPT2Config, + source={ + "huggingface-torch": gpt2_loader.huggingface, + "huggingface-safetensor": gpt2_loader.huggingface, + }, + quantize={ + "no-quant": gpt2_quantization.no_quant, + "group-quant": gpt2_quantization.group_quant, + "ft-quant": gpt2_quantization.ft_quant, + }, + ), + "mixtral": Model( + name="mixtral", + model=mixtral_model.MixtralForCasualLM, + config=mixtral_model.MixtralConfig, + source={ + "huggingface-torch": mixtral_loader.huggingface, + "huggingface-safetensor": mixtral_loader.huggingface, + }, + quantize={ + "no-quant": mixtral_quantization.no_quant, + "group-quant": mixtral_quantization.group_quant, + "ft-quant": mixtral_quantization.ft_quant, + }, + ), + "gpt_neox": Model( + name="gpt_neox", + model=gpt_neox_model.GPTNeoXForCausalLM, + config=gpt_neox_model.GPTNeoXConfig, + source={ + "huggingface-torch": gpt_neox_loader.huggingface, + "huggingface-safetensor": gpt_neox_loader.huggingface, + }, + quantize={ + "no-quant": gpt_neox_quantization.no_quant, + "group-quant": gpt_neox_quantization.group_quant, + "ft-quant": gpt_neox_quantization.ft_quant, + }, + ), + "gpt_bigcode": Model( + name="gpt_bigcode", + model=gpt_bigcode_model.GPTBigCodeForCausalLM, + config=gpt_bigcode_model.GPTBigCodeConfig, + source={ + "huggingface-torch": gpt_bigcode_loader.huggingface, + "huggingface-safetensor": gpt_bigcode_loader.huggingface, + }, + quantize={ + "no-quant": gpt_bigcode_quantization.no_quant, + "group-quant": gpt_bigcode_quantization.group_quant, + "ft-quant": gpt_bigcode_quantization.ft_quant, + }, + ), + "phi-msft": Model( + name="phi-msft", + model=phi_model.PhiForCausalLM, + config=phi_model.PhiConfig, + source={ + "huggingface-torch": phi_loader.huggingface, + "huggingface-safetensor": phi_loader.huggingface, + }, + quantize={ + "no-quant": phi_quantization.no_quant, + "group-quant": phi_quantization.group_quant, + "ft-quant": phi_quantization.ft_quant, + }, + ), + "phi": Model( + name="phi", + model=phi_model.PhiForCausalLM, + config=phi_model.Phi1Config, + source={ + 
"huggingface-torch": phi_loader.phi1_huggingface, + "huggingface-safetensor": phi_loader.phi1_huggingface, + }, + quantize={ + "no-quant": phi_quantization.no_quant, + "group-quant": phi_quantization.group_quant, + "ft-quant": phi_quantization.ft_quant, + }, + ), + "qwen": Model( + name="qwen", + model=qwen_model.QWenLMHeadModel, + config=qwen_model.QWenConfig, + source={ + "huggingface-torch": qwen_loader.huggingface, + "huggingface-safetensor": qwen_loader.huggingface, + }, + quantize={ + "no-quant": qwen_quantization.no_quant, + "group-quant": qwen_quantization.group_quant, + "ft-quant": qwen_quantization.ft_quant, + }, + ), + "qwen2": Model( + name="qwen2", + model=qwen2_model.QWen2LMHeadModel, + config=qwen2_model.QWen2Config, + source={ + "huggingface-torch": qwen2_loader.huggingface, + "huggingface-safetensor": qwen2_loader.huggingface, + }, + quantize={ + "no-quant": qwen2_quantization.no_quant, + "group-quant": qwen2_quantization.group_quant, + "ft-quant": qwen2_quantization.ft_quant, + }, + ), + "stablelm_epoch": Model( + name="stablelm_epoch", + model=stablelm_model.StableLMEpochForCausalLM, + config=stablelm_model.StableLMEpochConfig, + source={ + "huggingface-torch": stablelm_loader.huggingface, + "huggingface-safetensor": stablelm_loader.huggingface, + }, + quantize={ + "no-quant": stablelm_quantization.no_quant, + "group-quant": stablelm_quantization.group_quant, + "ft-quant": stablelm_quantization.ft_quant, + }, + ), + "baichuan": Model( + name="baichuan", + model=baichuan_model.BaichuanForCausalLM, + config=baichuan_model.BaichuanConfig, + source={ + "huggingface-torch": baichuan_loader.huggingface, + "huggingface-safetensor": baichuan_loader.huggingface, + }, + quantize={ + "no-quant": baichuan_quantization.no_quant, + "group-quant": baichuan_quantization.group_quant, + "ft-quant": baichuan_quantization.ft_quant, + }, + ), +} diff --git a/python/mlc_chat/model/model_preset.py b/python/mlc_chat/model/model_preset.py new file mode 100644 index 0000000..bacfd43 --- /dev/null +++ b/python/mlc_chat/model/model_preset.py @@ -0,0 +1,495 @@ +"""A builtin set of models available in MLC LLM.""" + +from typing import Any, Dict + +MODEL_PRESETS: Dict[str, Any] = { + "llama2_7b": { + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "pad_token_id": 0, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.31.0.dev0", + "use_cache": True, + "vocab_size": 32000, + "context_window_size": 2048, + "prefill_chunk_size": 2048, + }, + "llama2_13b": { + "_name_or_path": "meta-llama/Llama-2-13b-hf", + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 13824, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 40, + "num_hidden_layers": 40, + "num_key_value_heads": 40, + "pad_token_id": 0, + "pretraining_tp": 2, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.31.0.dev0", + "use_cache": True, + "vocab_size": 32000, + "context_window_size": 2048, + 
"prefill_chunk_size": 2048, + }, + "llama2_70b": { + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 8192, + "initializer_range": 0.02, + "intermediate_size": 28672, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 80, + "num_key_value_heads": 8, + "pad_token_id": 0, + "rms_norm_eps": 1e-05, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.31.0.dev0", + "use_cache": True, + "vocab_size": 32000, + "context_window_size": 2048, + "prefill_chunk_size": 2048, + }, + "codellama_7b": { + "_name_or_path": "codellama/CodeLlama-7b-hf", + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 16384, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "rope_theta": 1000000, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.33.0.dev0", + "use_cache": True, + "vocab_size": 32016, + "context_window_size": 2048, + "prefill_chunk_size": 2048, + }, + "codellama_13b": { + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 13824, + "max_position_embeddings": 16384, + "model_type": "llama", + "num_attention_heads": 40, + "num_hidden_layers": 40, + "num_key_value_heads": 40, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "rope_theta": 1000000, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.32.0.dev0", + "use_cache": True, + "vocab_size": 32016, + "context_window_size": 2048, + "prefill_chunk_size": 2048, + }, + "codellama_34b": { + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 8192, + "initializer_range": 0.02, + "intermediate_size": 22016, + "max_position_embeddings": 16384, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 48, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "rope_theta": 1000000, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.32.0.dev0", + "use_cache": True, + "vocab_size": 32016, + "context_window_size": 2048, + "prefill_chunk_size": 2048, + }, + "mistral_7b": { + "architectures": ["MistralForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 32768, + "model_type": "mistral", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-05, + "rope_theta": 10000.0, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.34.0.dev0", + "use_cache": True, + "vocab_size": 32000, + "sliding_window_size": 4096, + "prefill_chunk_size": 128, + "attention_sink_size": 4, + }, + "gpt2": { + "architectures": ["GPT2LMHeadModel"], + "bos_token_id": 50256, + "eos_token_id": 50256, + "hidden_act": "gelu_new", + "n_embd": 768, + "initializer_range": 0.02, + "n_positions": 1024, + "model_type": "gpt2", + 
"n_head": 12, + "n_layer": 12, + "layer_norm_epsilon": 1e-05, + "transformers_version": "4.26.0.dev0", + "use_cache": True, + "vocab_size": 50257, + "context_window_size": 2048, + "prefill_chunk_size": 2048, + }, + "gpt_bigcode": { + "activation_function": "gelu_pytorch_tanh", + "architectures": ["GPTBigCodeForCausalLM"], + "attention_softmax_in_fp32": True, + "multi_query": True, + "attn_pdrop": 0.1, + "bos_token_id": 49152, + "embd_pdrop": 0.1, + "eos_token_id": 49152, + "initializer_range": 0.02, + "layer_norm_epsilon": 1e-05, + "model_type": "gpt_bigcode", + "n_embd": 2048, + "n_head": 16, + "n_inner": 8192, + "n_layer": 24, + "n_positions": 2048, + "resid_pdrop": 0.1, + "runner_max_sequence_length": None, + "scale_attention_softmax_in_fp32": True, + "scale_attn_weights": True, + "summary_activation": None, + "summary_first_dropout": 0.1, + "summary_proj_to_labels": True, + "summary_type": "cls_index", + "summary_use_proj": True, + "transformers_version": "4.28.0.dev0", + "use_cache": True, + "vocab_size": 49280, + }, + "Mixtral-8x7B-v0.1": { + "architectures": ["MixtralForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 32768, + "model_type": "mixtral", + "num_attention_heads": 32, + "num_experts_per_tok": 2, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "num_local_experts": 8, + "output_router_logits": False, + "rms_norm_eps": 1e-05, + "rope_theta": 1000000.0, + "router_aux_loss_coef": 0.02, + "sliding_window": None, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.36.0.dev0", + "use_cache": True, + "vocab_size": 32000, + }, + "redpajama_3b_v1": { + "_name_or_path": "/root/fm/models/rp_3b_800b_real_fp16", + "architectures": ["GPTNeoXForCausalLM"], + "bos_token_id": 0, + "eos_token_id": 0, + "hidden_act": "gelu", + "hidden_size": 2560, + "initializer_range": 0.02, + "intermediate_size": 10240, + "layer_norm_eps": 1e-05, + "max_position_embeddings": 2048, + "model_type": "gpt_neox", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "rotary_emb_base": 10000, + "rotary_pct": 1.0, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.28.1", + "use_cache": True, + "use_parallel_residual": False, + "vocab_size": 50432, + }, + "phi-1_5": { + "_name_or_path": "microsoft/phi-1_5", + "activation_function": "gelu_new", + "architectures": ["PhiForCausalLM"], + "attn_pdrop": 0.0, + "auto_map": { + "AutoConfig": "configuration_phi.PhiConfig", + "AutoModelForCausalLM": "modeling_phi.PhiForCausalLM", + }, + "embd_pdrop": 0.0, + "flash_attn": False, + "flash_rotary": False, + "fused_dense": False, + "initializer_range": 0.02, + "layer_norm_epsilon": 1e-05, + "model_type": "phi-msft", + "n_embd": 2048, + "n_head": 32, + "n_head_kv": None, + "n_inner": None, + "n_layer": 24, + "n_positions": 2048, + "resid_pdrop": 0.0, + "rotary_dim": 32, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.34.1", + "vocab_size": 51200, + }, + "phi-2": { + "_name_or_path": "microsoft/phi-2", + "activation_function": "gelu_new", + "architectures": ["PhiForCausalLM"], + "attn_pdrop": 0.0, + "auto_map": { + "AutoConfig": "configuration_phi.PhiConfig", + "AutoModelForCausalLM": "modeling_phi.PhiForCausalLM", + }, + "embd_pdrop": 0.0, + "flash_attn": False, + "flash_rotary": False, + "fused_dense": False, + "img_processor": None, + 
"initializer_range": 0.02, + "layer_norm_epsilon": 1e-05, + "model_type": "phi-msft", + "n_embd": 2560, + "n_head": 32, + "n_head_kv": None, + "n_inner": None, + "n_layer": 32, + "n_positions": 2048, + "resid_pdrop": 0.1, + "rotary_dim": 32, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.35.2", + "vocab_size": 51200, + }, + "qwen": { + "architectures": ["QWenLMHeadModel"], + "auto_map": { + "AutoConfig": "configuration_qwen.QWenConfig", + "AutoModelForCausalLM": "modeling_qwen.QWenLMHeadModel", + }, + "attn_dropout_prob": 0.0, + "bf16": False, + "emb_dropout_prob": 0.0, + "hidden_size": 2048, + "intermediate_size": 11008, + "initializer_range": 0.02, + "kv_channels": 128, + "layer_norm_epsilon": 1e-06, + "max_position_embeddings": 8192, + "model_type": "qwen", + "no_bias": True, + "num_attention_heads": 16, + "num_hidden_layers": 24, + "rotary_emb_base": 10000, + "rotary_pct": 1.0, + "scale_attn_weights": True, + "seq_length": 8192, + "tie_word_embeddings": False, + "tokenizer_class": "QWenTokenizer", + "transformers_version": "4.32.0", + "use_cache": True, + "use_dynamic_ntk": True, + "use_flash_attn": "auto", + "use_logn_attn": True, + "vocab_size": 151936, + }, + "qwen2": { + "_name_or_path": "Qwen/Qwen1.5-1.8B-Chat", + "architectures": ["Qwen2ForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 5504, + "max_position_embeddings": 4096, + "max_window_layers": 21, + "model_type": "qwen2", + "num_attention_heads": 16, + "num_hidden_layers": 24, + "num_key_value_heads": 16, + "rms_norm_eps": 1e-06, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "tie_word_embeddings": True, + "torch_dtype": "bfloat16", + "transformers_version": "4.37.2", + "use_cache": True, + "use_sliding_window": False, + "vocab_size": 151936, + }, + "stablelm_epoch": { + "architectures": ["StableLMEpochForCausalLM"], + "auto_map": { + "AutoConfig": "configuration_stablelm_epoch.StableLMEpochConfig", + "AutoModelForCausalLM": "modeling_stablelm_epoch.StableLMEpochForCausalLM", + }, + "bos_token_id": 100257, + "eos_token_id": 100257, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 5632, + "max_position_embeddings": 4096, + "model_type": "stablelm_epoch", + "norm_eps": 1e-05, + "num_attention_heads": 32, + "num_heads": 32, + "num_hidden_layers": 24, + "num_key_value_heads": 32, + "rope_pct": 0.25, + "rope_theta": 10000, + "rotary_scaling_factor": 1.0, + "tie_word_embeddings": True, + "torch_dtype": "bfloat16", + "transformers_version": "4.36.2", + "use_cache": True, + "use_qkv_bias": True, + "vocab_size": 100352, + }, + "baichuan": { + "architectures": ["BaichuanForCausalLM"], + "auto_map": { + "AutoConfig": "configuration_baichuan.BaichuanConfig", + "AutoModelForCausalLM": "modeling_baichuan.BaichuanForCausalLM", + }, + "tokenizer_class": "BaichuanTokenizer", + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 4096, + "model_max_length": 4096, + "model_type": "baichuan", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "pad_token_id": 0, + "rms_norm_eps": 1e-06, + "_from_model_config": True, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.29.2", + "use_cache": True, + "vocab_size": 125696, + }, + # TODO(mlc-team): enable the model 
presets when stablized. + # "gemma_2b": { + # "architectures": ["GemmaForCausalLM"], + # "attention_bias": False, + # "bos_token_id": 2, + # "eos_token_id": 1, + # "head_dim": 256, + # "hidden_act": "gelu", + # "hidden_size": 2048, + # "initializer_range": 0.02, + # "intermediate_size": 16384, + # "max_position_embeddings": 8192, + # "model_type": "gemma", + # "num_attention_heads": 8, + # "num_hidden_layers": 18, + # "num_key_value_heads": 1, + # "pad_token_id": 0, + # "rms_norm_eps": 1e-06, + # "rope_theta": 10000.0, + # "torch_dtype": "bfloat16", + # "transformers_version": "4.38.0.dev0", + # "vocab_size": 256000, + # }, + # "gemma_7b": { + # "architectures": ["GemmaForCausalLM"], + # "attention_bias": False, + # "bos_token_id": 2, + # "eos_token_id": 1, + # "head_dim": 256, + # "hidden_act": "gelu", + # "hidden_size": 3072, + # "initializer_range": 0.02, + # "intermediate_size": 24576, + # "max_position_embeddings": 8192, + # "model_type": "gemma", + # "num_attention_heads": 16, + # "num_hidden_layers": 28, + # "num_key_value_heads": 16, + # "pad_token_id": 0, + # "rms_norm_eps": 1e-06, + # "rope_theta": 10000.0, + # "torch_dtype": "bfloat16", + # "transformers_version": "4.38.0.dev0", + # "vocab_size": 256000, + # }, +} diff --git a/python/mlc_chat/model/phi/__init__.py b/python/mlc_chat/model/phi/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/mlc_chat/model/phi/phi_loader.py b/python/mlc_chat/model/phi/phi_loader.py new file mode 100644 index 0000000..d393c61 --- /dev/null +++ b/python/mlc_chat/model/phi/phi_loader.py @@ -0,0 +1,162 @@ +""" +This file specifies how MLC's Phi parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" +import functools + +import numpy as np + +from mlc_chat.loader import ExternMapping +from mlc_chat.quantization import Quantization + +from .phi_model import Phi1Config, PhiConfig, PhiForCausalLM + + +def huggingface(model_config: PhiConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : PhiConfig + The configuration of the Phi model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. 
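+
+    Notes
+    -----
+    Rough usage sketch only (``PhiConfig.from_dict`` is assumed to be available
+    from ``ConfigBase``; passing ``quantization=None`` simply skips the dtype
+    cast)::
+
+        config = PhiConfig.from_dict({"model_type": "phi-msft"})
+        mapping = huggingface(config, quantization=None)
+        assert "transformer.embd.weight" in mapping.param_map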
+ """ + model = PhiForCausalLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params = model.export_tvm( # pylint: disable=W0632:unbalanced-tuple-unpacking + spec=model.get_default_spec() + ) + named_parameters = dict(_named_params) + mapping = ExternMapping() + + def _add(mlc_name, hf_name): + mapping.add_mapping( + mlc_name, + [hf_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=named_parameters[mlc_name].dtype, + ), + ) + + if model_config.model_type == "mixformer-sequential": + _add("transformer.embd.weight", "layers.0.wte.weight") + prefix = "transformer.h" + for i in range(model_config.n_layer): + _add(f"{prefix}.{i}.ln.weight", f"layers.{i + 1}.ln.weight") + _add(f"{prefix}.{i}.ln.bias", f"layers.{i + 1}.ln.bias") + _add(f"{prefix}.{i}.mixer.Wqkv.weight", f"layers.{i + 1}.mixer.Wqkv.weight") + _add(f"{prefix}.{i}.mixer.Wqkv.bias", f"layers.{i + 1}.mixer.Wqkv.bias") + _add(f"{prefix}.{i}.mixer.out_proj.weight", f"layers.{i + 1}.mixer.out_proj.weight") + _add(f"{prefix}.{i}.mixer.out_proj.bias", f"layers.{i + 1}.mixer.out_proj.bias") + _add(f"{prefix}.{i}.mlp.fc1.weight", f"layers.{i + 1}.mlp.fc1.weight") + _add(f"{prefix}.{i}.mlp.fc1.bias", f"layers.{i + 1}.mlp.fc1.bias") + _add(f"{prefix}.{i}.mlp.fc2.weight", f"layers.{i + 1}.mlp.fc2.weight") + _add(f"{prefix}.{i}.mlp.fc2.bias", f"layers.{i + 1}.mlp.fc2.bias") + mapping.add_unused(f"layers.{i + 1}.mixer.rotary_emb.inv_freq") + prefix = f"layers.{model_config.n_layer + 1}" + _add("lm_head.ln.weight", f"{prefix}.ln.weight") + _add("lm_head.ln.bias", f"{prefix}.ln.bias") + _add("lm_head.linear.weight", f"{prefix}.linear.weight") + _add("lm_head.linear.bias", f"{prefix}.linear.bias") + + elif model_config.model_type == "phi-msft": + _add("transformer.embd.weight", "transformer.embd.wte.weight") + for mlc_name, _ in named_parameters.items(): + if mlc_name not in mapping.param_map: + _add(mlc_name, mlc_name) + return mapping + + +def phi1_huggingface(model_config: Phi1Config, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of Phi-1/Phi-1.5 HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : PhiConfig + The configuration of the Phi model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. 
+ """ + model = PhiForCausalLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params = model.export_tvm( # pylint: disable=W0632:unbalanced-tuple-unpacking + spec=model.get_default_spec() + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + def _add(mlc_name, hf_name): + mapping.add_mapping( + mlc_name, + [hf_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=named_parameters[mlc_name].dtype, + ), + ) + + def _concat_add(mlc_name, hf_names): + mapping.add_mapping( + mlc_name, + hf_names, + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=named_parameters[mlc_name].dtype, + ), + ) + + _add("lm_head.linear.weight", "lm_head.weight") + _add("lm_head.linear.bias", "lm_head.bias") + _add("lm_head.ln.weight", "model.final_layernorm.weight") + _add("lm_head.ln.bias", "model.final_layernorm.bias") + _add("transformer.embd.weight", "model.embed_tokens.weight") + + prefix = "transformer.h" + hf_prefix = "model.layers" + for i in range(model_config.num_hidden_layers): + _add(f"{prefix}.{i}.ln.weight", f"{hf_prefix}.{i}.input_layernorm.weight") + _add(f"{prefix}.{i}.ln.bias", f"{hf_prefix}.{i}.input_layernorm.bias") + _concat_add( + f"{prefix}.{i}.mixer.Wqkv.weight", + [ + f"{hf_prefix}.{i}.self_attn.q_proj.weight", + f"{hf_prefix}.{i}.self_attn.k_proj.weight", + f"{hf_prefix}.{i}.self_attn.v_proj.weight", + ], + ) + _concat_add( + f"{prefix}.{i}.mixer.Wqkv.bias", + [ + f"{hf_prefix}.{i}.self_attn.q_proj.bias", + f"{hf_prefix}.{i}.self_attn.k_proj.bias", + f"{hf_prefix}.{i}.self_attn.v_proj.bias", + ], + ) + _add(f"{prefix}.{i}.mixer.out_proj.weight", f"{hf_prefix}.{i}.self_attn.dense.weight") + _add(f"{prefix}.{i}.mixer.out_proj.bias", f"{hf_prefix}.{i}.self_attn.dense.bias") + _add(f"{prefix}.{i}.mlp.fc1.weight", f"{hf_prefix}.{i}.mlp.fc1.weight") + _add(f"{prefix}.{i}.mlp.fc1.bias", f"{hf_prefix}.{i}.mlp.fc1.bias") + _add(f"{prefix}.{i}.mlp.fc2.weight", f"{hf_prefix}.{i}.mlp.fc2.weight") + _add(f"{prefix}.{i}.mlp.fc2.bias", f"{hf_prefix}.{i}.mlp.fc2.bias") + mapping.add_unused(f"{hf_prefix}.{i}.mixer.rotary_emb.inv_freq") + + return mapping diff --git a/python/mlc_chat/model/phi/phi_model.py b/python/mlc_chat/model/phi/phi_model.py new file mode 100644 index 0000000..421876d --- /dev/null +++ b/python/mlc_chat/model/phi/phi_model.py @@ -0,0 +1,404 @@ +""" +Implementation for Phi architecture. 
+TODO: add docstring +""" +import dataclasses +from typing import Any, Dict, Optional, Union + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_chat import op as op_ext +from mlc_chat.support import logging +from mlc_chat.support import tensor_parallel as tp +from mlc_chat.support.config import ConfigBase +from mlc_chat.support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class Phi1Config(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the Phi-1/Phi-1.5 model.""" + + vocab_size: int = 51200 + hidden_size: int = 2048 + intermediate_size: int = 8192 + num_hidden_layers: int = 24 + num_attention_heads: int = 32 + layer_norm_eps: float = 1e-5 + position_embedding_base: int = 0 + partial_rotary_factor: float = 0.5 + num_key_value_heads: int = 0 + context_window_size: int = 0 + prefill_chunk_size: int = 0 + head_dim: int = 0 + tensor_parallel_shards: int = 1 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.position_embedding_base == 0: + if "rope_theta" in self.kwargs: + self.position_embedding_base = self.kwargs.pop("rope_theta") + else: + self.position_embedding_base = 10000 + if self.context_window_size == 0: + for name in ["max_position_embeddings", "max_sequence_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + else: + raise ValueError( + "Unable to determine the maxmimum sequence length, because none of " + "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " + "provided in `config.json`." 
+ ) + if self.prefill_chunk_size == 0: + self.prefill_chunk_size = self.context_window_size + if self.prefill_chunk_size > self.context_window_size: + self.prefill_chunk_size = self.context_window_size + if self.num_key_value_heads == 0 or self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + if self.intermediate_size == 0 or self.intermediate_size is None: + self.intermediate_size = 4 * self.hidden_size + if self.head_dim == 0: + self.head_dim = self.hidden_size // self.num_attention_heads + assert self.head_dim * self.num_attention_heads == self.hidden_size + assert self.num_attention_heads % self.num_key_value_heads == 0 + + +@dataclasses.dataclass +class PhiConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the Phi-2 model.""" + + model_type: str # "phi", "phi-msft", "mixformer-sequential" + vocab_size: int = 51200 + n_positions: int = 2048 + n_embd: int = 2560 + n_layer: int = 32 + n_inner: int = 0 + n_head: int = 32 + rotary_dim: int = 32 + position_embedding_base: int = 0 + layer_norm_epsilon: float = 1e-5 + context_window_size: int = 0 + prefill_chunk_size: int = 0 + n_head_kv: int = 0 + head_dim: int = 0 + tensor_parallel_shards: int = 1 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.position_embedding_base == 0: + if "rope_theta" in self.kwargs: + self.position_embedding_base = self.kwargs.pop("rope_theta") + else: + self.position_embedding_base = 10000 + if self.context_window_size == 0: + for name in ["max_position_embeddings", "max_sequence_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + else: + self.context_window_size = self.n_positions + logger.info( + "%s not found in config.json. Falling back to %s (%d)", + bold("context_window_size"), + "n_positions", + self.context_window_size, + ) + if self.prefill_chunk_size == 0: + self.prefill_chunk_size = self.context_window_size + if self.prefill_chunk_size > self.context_window_size: + self.prefill_chunk_size = self.context_window_size + if self.n_head_kv == 0 or self.n_head_kv is None: + self.n_head_kv = self.n_head + if self.n_inner == 0 or self.n_inner is None: + self.n_inner = 4 * self.n_embd + if self.head_dim == 0: + self.head_dim = self.n_embd // self.n_head + assert self.head_dim * self.n_head == self.n_embd + assert self.n_head % self.n_head_kv == 0 + + @staticmethod + def from_phi1(config: Phi1Config) -> "PhiConfig": + "Build PhiConig from a Phi1Config." 
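+        # Translate the HF-style Phi-1/Phi-1.5 field names into the phi-msft style
+        # fields used by this implementation. With the defaults above
+        # (hidden_size=2048, num_attention_heads=32, partial_rotary_factor=0.5):
+        # head_dim = 2048 // 32 = 64, so rotary_dim = int(0.5 * 64) = 32, matching
+        # PhiConfig's default rotary_dim.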
+ return PhiConfig( + model_type="phi", + vocab_size=config.vocab_size, + n_positions=config.context_window_size, + n_embd=config.hidden_size, + n_layer=config.num_hidden_layers, + n_inner=config.intermediate_size, + n_head=config.num_attention_heads, + rotary_dim=int(config.partial_rotary_factor * config.head_dim), + position_embedding_base=config.position_embedding_base, + layer_norm_epsilon=config.layer_norm_eps, + context_window_size=config.context_window_size, + prefill_chunk_size=config.prefill_chunk_size, + n_head_kv=config.num_key_value_heads, + head_dim=config.head_dim, + tensor_parallel_shards=config.tensor_parallel_shards, + kwargs=config.kwargs, + ) + + +# pylint: disable=invalid-name,missing-docstring + + +class PhiMLP(nn.Module): + def __init__(self, config: PhiConfig): + super().__init__() + self.intermediate_size = config.n_inner // config.tensor_parallel_shards + self.fc1 = nn.Linear(config.n_embd, self.intermediate_size) + self.fc2 = nn.Linear(self.intermediate_size, config.n_embd) + + def forward(self, hidden_states: Tensor): + hidden_states = self.fc1(hidden_states) + hidden_states = op.gelu(hidden_states, approximate="tanh") + hidden_states = self.fc2(hidden_states) + + return hidden_states + + +class PhiCrossAttention(nn.Module): + def __init__(self, config: PhiConfig): # pylint: disable=unused-argument + super().__init__() + + def forward(self, q: Tensor, k: Tensor, v: Tensor, attention_mask: Tensor): + output = op_ext.attention(q, k, v, casual_mask=attention_mask, qk_dtype="float32") + return output + + +class PhiMHA(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: PhiConfig): + self.rope_theta = config.position_embedding_base + self.rotary_dim = config.rotary_dim + self.n_head = config.n_head // config.tensor_parallel_shards + assert ( + config.n_head % config.tensor_parallel_shards == 0 + ), f"n_head({config.n_head}) must be divisible by tensor_parallel_shards" + self.n_head_kv = config.n_head_kv // config.tensor_parallel_shards + assert ( + config.n_head_kv % config.tensor_parallel_shards == 0 + ), f"n_head({config.n_head_kv}) must be divisible by tensor_parallel_shards" + self.head_dim = config.head_dim + op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv) + hidden_size = config.n_embd + + self.Wqkv = nn.Linear(hidden_size, op_size, bias=True) + self.out_proj = nn.Linear(self.n_head * self.head_dim, hidden_size, bias=True) + self.inner_cross_attn = PhiCrossAttention(config) + self.k_cache = nn.KVCache(config.context_window_size, [self.n_head_kv, self.head_dim]) + self.v_cache = nn.KVCache(config.context_window_size, [self.n_head_kv, self.head_dim]) + + def forward(self, x: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): + d, h_q, h_kv, t = self.head_dim, self.n_head, self.n_head_kv, total_seq_len + b, s, _ = x.shape + assert b == 1, "Only support batch size 1 at this moment." + # Step 1. QKV Projection + qkv = self.Wqkv(x) + qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) + # Step 2. Apply QK rotary embedding + q, k, v = op_ext.llama_rope(qkv, t, self.rope_theta, h_q, h_kv, rotary_dim=self.rotary_dim) + # Step 3. Query and update KVCache + self.k_cache.append(op.squeeze(k, axis=0)) + self.v_cache.append(op.squeeze(v, axis=0)) + k = self.k_cache.view(t) + v = self.v_cache.view(t) + # Step 4. Compute softmax(Q @ K^T / sqrt(d)) @ V + output = self.inner_cross_attn(q, k, v, attention_mask) + # Step 5. 
Apply output projection + return self.out_proj(output) + + +class PhiParallelBlock(nn.Module): + def __init__(self, config: PhiConfig): + super().__init__() + + self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.mixer = PhiMHA(config) + self.mlp = PhiMLP(config) + + def _set_tp(): + def _set(param, hint): + param.attrs["shard_strategy"] = hint + + hd = config.head_dim + q = self.mixer.n_head * hd + k = self.mixer.n_head_kv * hd + v = self.mixer.n_head_kv * hd + _set( + self.mixer.Wqkv.weight, + tp.ShardSingleDim("_shard_qkv_weight", segs=[q, k, v], dim=0), + ) + _set(self.mixer.Wqkv.bias, tp.ShardSingleDim("_shard_qkv_bias", segs=[q, k, v], dim=0)) + _set(self.mixer.out_proj.weight, tp.ShardSingleDim("_shard_o_weight", dim=1)) + _set(self.mlp.fc1.weight, tp.ShardSingleDim("_shard_mlp_fc1_weight", dim=0)) + _set(self.mlp.fc1.bias, tp.ShardSingleDim("_shard_mlp_fc1_bias", dim=0)) + _set(self.mlp.fc2.weight, tp.ShardSingleDim("_shard_mlp_fc2_weight", dim=1)) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + + def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): + residual = hidden_states + hidden_states = self.ln(hidden_states) + + with tp.shard_bias(self.mixer.out_proj, self.tensor_parallel_shards), tp.shard_bias( + self.mlp.fc2, self.tensor_parallel_shards + ): + attn_outputs = self.mixer( + hidden_states, + attention_mask, + total_seq_len, + ) + + feed_forward_hidden_states = self.mlp(hidden_states) + + def _apply_parallel_residual(attn_out, mlp_out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce( + attn_out + mlp_out + residual / self.tensor_parallel_shards, "sum" + ) + return attn_out + mlp_out + residual + + hidden_states = _apply_parallel_residual(attn_outputs, feed_forward_hidden_states, residual) + + return hidden_states + + +class PhiCausalLMHead(nn.Module): + def __init__(self, config: PhiConfig) -> None: + super().__init__() + + self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.linear = nn.Linear(config.n_embd, "vocab_size") + + def forward(self, hidden_states: Tensor): + hidden_states = self.ln(hidden_states) + logits = self.linear(hidden_states) + + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + +class PhiModel(nn.Module): + def __init__(self, config: PhiConfig) -> None: + super().__init__() + self.embd = nn.Embedding("vocab_size", config.n_embd) + self.h = nn.ModuleList([PhiParallelBlock(config) for i in range(config.n_layer)]) + self.tensor_parallel_shards = config.tensor_parallel_shards + + def forward(self, input_ids: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) + hidden_states = self.embd(input_ids) + for layer in self.h: + hidden_states = layer(hidden_states, attention_mask, total_seq_len) + + return hidden_states + + +class PhiForCausalLM(nn.Module): + def __init__(self, config: Union[PhiConfig, Phi1Config]) -> None: + super().__init__() + + if isinstance(config, Phi1Config): + config = PhiConfig.from_phi1(config) + + self.transformer = PhiModel(config) + self.lm_head = PhiCausalLMHead(config) + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def forward(self, input_ids: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): + def _index(x: te.Tensor): # x[:-1,:] + b, s, d = x.shape + return 
te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") + + hidden_states = self.transformer(input_ids, total_seq_len, attention_mask) + hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + lm_logits = self.lm_head(hidden_states) + + return lm_logits + + def prefill(self, inputs: Tensor, total_seq_len: tir.Var): + def _attention_mask(batch_size, seq_len, total_seq_len): + return te.compute( + (batch_size, 1, seq_len, total_seq_len), + lambda b, _, i, j: tir.if_then_else( + i < j - (total_seq_len - seq_len), + tir.min_value(self.dtype), + tir.max_value(self.dtype), + ), + name="attention_mask_prefill", + ) + + batch_size, seq_len = inputs.shape + attention_mask = op.tensor_expr_op( + _attention_mask, + name_hint="attention_mask_prefill", + args=[batch_size, seq_len, total_seq_len], + ) + return self.forward(inputs, total_seq_len, attention_mask) + + def decode(self, inputs: Tensor, total_seq_len: tir.Var): + batch_size, seq_len = inputs.shape + attention_mask = op.full( + shape=[batch_size, 1, seq_len, total_seq_len], + fill_value=tir.max_value(self.dtype), + dtype=self.dtype, + ) + return self.forward(inputs, total_seq_len, attention_mask) + + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): + return op.softmax(logits / temperature, axis=-1) + + def get_default_spec(self): + batch_size = 1 + mod_spec = { + "prefill": { + "inputs": nn.spec.Tensor([batch_size, "seq_len"], "int32"), + "total_seq_len": int, + "$": { + "param_mode": "packed", + "effect_mode": "packed", + }, + }, + "decode": { + "inputs": nn.spec.Tensor([batch_size, 1], "int32"), + "total_seq_len": int, + "$": { + "param_mode": "packed", + "effect_mode": "packed", + }, + }, + "softmax_with_temperature": { + "logits": nn.spec.Tensor([1, 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor([], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_chat/model/phi/phi_quantization.py b/python/mlc_chat/model/phi/phi_quantization.py new file mode 100644 index 0000000..52089c2 --- /dev/null +++ b/python/mlc_chat/model/phi/phi_quantization.py @@ -0,0 +1,53 @@ +"""This file specifies how MLC's Llama parameters are quantized using group quantization +or other formats.""" +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_chat.loader import QuantizeMapping +from mlc_chat.quantization import FTQuantize, GroupQuantize, NoQuantize + +from .phi_model import PhiConfig, PhiForCausalLM + + +def group_quant( + model_config: PhiConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Phi-architecture model using group quantization.""" + model: nn.Module = PhiForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: PhiConfig, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Phi-architecture model using FasterTransformer quantization.""" + model: nn.Module = PhiForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: PhiConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Phi model without 
quantization.""" + model: nn.Module = PhiForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_chat/model/qwen/__init__.py b/python/mlc_chat/model/qwen/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/mlc_chat/model/qwen/qwen_loader.py b/python/mlc_chat/model/qwen/qwen_loader.py new file mode 100644 index 0000000..810efed --- /dev/null +++ b/python/mlc_chat/model/qwen/qwen_loader.py @@ -0,0 +1,70 @@ +""" +This file specifies how MLC's QWen parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" +import functools + +import numpy as np + +from mlc_chat.loader import ExternMapping +from mlc_chat.quantization import Quantization + +from .qwen_model import QWenConfig, QWenLMHeadModel + + +def huggingface(model_config: QWenConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : GPT2Config + The configuration of the GPT-2 model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. + """ + model = QWenLMHeadModel(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # Add gates in MLP + mlp = f"transformer.h.{i}.mlp" + mlc_name = f"{mlp}.gate_up_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.w1.weight", + f"{mlp}.w2.weight", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping diff --git a/python/mlc_chat/model/qwen/qwen_model.py b/python/mlc_chat/model/qwen/qwen_model.py new file mode 100644 index 0000000..ef4caca --- /dev/null +++ b/python/mlc_chat/model/qwen/qwen_model.py @@ -0,0 +1,254 @@ +""" +Implementation for QWEN architecture. 
+TODO: add docstring +""" +import dataclasses +from typing import Any, Dict, Optional + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_chat import op as op_ext +from mlc_chat.support import logging +from mlc_chat.support.config import ConfigBase +from mlc_chat.support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class QWenConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the QWen model.""" + + vocab_size: int + hidden_size: int + num_hidden_layers: int + num_attention_heads: int + layer_norm_epsilon: float + scale_attn_weights: bool + kv_channels: int + rotary_emb_base: int + intermediate_size: int + context_window_size: int = 0 + prefill_chunk_size: int = 0 + tensor_parallel_shards: int = 1 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.context_window_size == 0: + for name in ["max_position_embeddings", "max_sequence_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + else: + raise ValueError( + "Unable to determine the maxmimum sequence length, because none of " + "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " + "provided in `config.json`." + ) + if self.prefill_chunk_size == 0: + logger.info( + "%s defaults to %s (%d)", + bold("prefill_chunk_size"), + bold("context_window_size"), + self.context_window_size, + ) + self.prefill_chunk_size = self.context_window_size + elif self.prefill_chunk_size > self.context_window_size: + logger.info( + "Overriding %s from %d to %d (%s)", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + self.context_window_size, + bold("context_window_size"), + ) + self.prefill_chunk_size = self.context_window_size + assert self.tensor_parallel_shards == 1, "QWEN currently does not support sharding." + + +# pylint: disable=invalid-name,missing-docstring + + +class QWenAttention(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: QWenConfig): + self.hidden_size = config.hidden_size + self.rope_theta = config.rotary_emb_base + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.projection_size = config.kv_channels * config.num_attention_heads + + self.c_attn = nn.Linear( + in_features=config.hidden_size, + out_features=3 * self.projection_size, + bias=True, + ) + self.c_proj = nn.Linear(config.hidden_size, self.projection_size, bias=False) + + # KV cache for single sequence + self.k_cache = nn.KVCache(config.context_window_size, [self.num_heads, self.head_dim]) + self.v_cache = nn.KVCache(config.context_window_size, [self.num_heads, self.head_dim]) + + def forward( # pylint: disable=too-many-locals + self, + hidden_states: Tensor, + attention_mask: Tensor, + total_seq_len: tir.Var, + ): + d, h, t = self.head_dim, self.num_heads, total_seq_len + b, s, _ = hidden_states.shape + assert b == 1, "Only support batch size 1 at this moment." + # Step 1. QKV Projection + qkv = self.c_attn(hidden_states) + qkv = op.reshape(qkv, (b, s, 3 * h, d)) + # Step 2. Apply QK rotary embedding + q, k, v = op_ext.llama_rope(qkv, t, self.rope_theta, h, h) + # Step 3. 
Query and update KVCache + self.k_cache.append(op.squeeze(k, axis=0)) + self.v_cache.append(op.squeeze(v, axis=0)) + k = self.k_cache.view(t) + v = self.v_cache.view(t) + # Step 4. Compute softmax(Q @ K^T / sqrt(d)) @ V + output = op_ext.attention(q, k, v, casual_mask=attention_mask) + # Step 5. Apply output projection + return self.c_proj(output) + + +class QWenMLP(nn.Module): + def __init__(self, config: QWenConfig): + self.intermediate_size = config.intermediate_size + self.gate_up_proj = nn.Linear( + in_features=config.hidden_size, + out_features=self.intermediate_size, + bias=False, + ) + self.c_proj = nn.Linear(self.intermediate_size // 2, config.hidden_size, bias=False) + + def forward(self, x: Tensor): + concat_x1_x2 = self.gate_up_proj(x) + x1, x2 = op.split(concat_x1_x2, 2, axis=-1) + return self.c_proj(x1 * op.silu(x2)) + + +class QWenBlock(nn.Module): + def __init__(self, config: QWenConfig): + rms_norm_eps = config.layer_norm_epsilon + self.attn = QWenAttention(config) + self.mlp = QWenMLP(config) + self.ln_1 = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) + self.ln_2 = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) + + def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): + out = self.attn(self.ln_1(hidden_states), attention_mask, total_seq_len) + hidden_states = out + hidden_states + out = self.mlp(self.ln_2(hidden_states)) + hidden_states = out + hidden_states + return hidden_states + + +class QWenModel(nn.Module): + def __init__(self, config: QWenConfig): + assert config.hidden_size % config.num_attention_heads == 0 + self.wte = nn.Embedding(config.vocab_size, config.hidden_size) + self.h = nn.ModuleList([QWenBlock(config) for _ in range(config.num_hidden_layers)]) + self.ln_f = nn.RMSNorm(config.hidden_size, -1, config.layer_norm_epsilon, bias=False) + + def forward(self, input_ids: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): + hidden_states = self.wte(input_ids) + for layer in self.h: + hidden_states = layer(hidden_states, attention_mask, total_seq_len) + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class QWenLMHeadModel(nn.Module): + def __init__(self, config: QWenConfig): + self.transformer = QWenModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.vocab_size = config.vocab_size + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def forward(self, inputs: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): + def _index(x: te.Tensor): # x[:-1,:] + b, s, d = x.shape + return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") + + hidden_states = self.transformer(inputs, total_seq_len, attention_mask) + hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def prefill(self, inputs: Tensor, total_seq_len: tir.Var): + def _attention_mask(batch_size, seq_len, total_seq_len): + return te.compute( + (batch_size, 1, seq_len, total_seq_len), + lambda b, _, i, j: tir.if_then_else( + i < j - (total_seq_len - seq_len), + tir.min_value(self.dtype), + tir.max_value(self.dtype), + ), + name="attention_mask_prefill", + ) + + batch_size, seq_len = inputs.shape + attention_mask = op.tensor_expr_op( + _attention_mask, + name_hint="attention_mask_prefill", + args=[batch_size, 
seq_len, total_seq_len], + ) + return self.forward(inputs, total_seq_len, attention_mask) + + def decode(self, inputs: Tensor, total_seq_len: tir.Var): + batch_size, seq_len = inputs.shape + attention_mask = op.full( + shape=[batch_size, 1, seq_len, total_seq_len], + fill_value=tir.max_value(self.dtype), + dtype=self.dtype, + ) + return self.forward(inputs, total_seq_len, attention_mask) + + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): + return op.softmax(logits / temperature, axis=-1) + + def get_default_spec(self): + batch_size = 1 + mod_spec = { + "prefill": { + "inputs": nn.spec.Tensor([batch_size, "seq_len"], "int32"), + "total_seq_len": int, + "$": { + "param_mode": "packed", + "effect_mode": "packed", + }, + }, + "decode": { + "inputs": nn.spec.Tensor([batch_size, 1], "int32"), + "total_seq_len": int, + "$": { + "param_mode": "packed", + "effect_mode": "packed", + }, + }, + "softmax_with_temperature": { + "logits": nn.spec.Tensor([1, 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor([], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_chat/model/qwen/qwen_quantization.py b/python/mlc_chat/model/qwen/qwen_quantization.py new file mode 100644 index 0000000..c69f583 --- /dev/null +++ b/python/mlc_chat/model/qwen/qwen_quantization.py @@ -0,0 +1,53 @@ +"""This file specifies how MLC's QWen parameters are quantized using group quantization +or other formats.""" +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_chat.loader import QuantizeMapping +from mlc_chat.quantization import FTQuantize, GroupQuantize, NoQuantize + +from .qwen_model import QWenConfig, QWenLMHeadModel + + +def group_quant( + model_config: QWenConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a QWen-architecture model using group quantization.""" + model: nn.Module = QWenLMHeadModel(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: QWenConfig, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Qwen model using FasterTransformer quantization.""" + model: nn.Module = QWenLMHeadModel(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: QWenConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a QWen model without quantization.""" + model: nn.Module = QWenLMHeadModel(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_chat/model/qwen2/__init__.py b/python/mlc_chat/model/qwen2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/mlc_chat/model/qwen2/qwen2_loader.py b/python/mlc_chat/model/qwen2/qwen2_loader.py new file mode 100644 index 0000000..559a911 --- /dev/null +++ b/python/mlc_chat/model/qwen2/qwen2_loader.py @@ -0,0 +1,88 @@ +""" +This file specifies how MLC's QWen2 parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. 
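+
+The mapping fuses HuggingFace's separate q/k/v projections into MLC's ``c_attn``
+parameter and the separate gate/up projections into ``gate_up_proj``; in terms of
+shapes this is simply (illustrative)::
+
+    c_attn.{weight,bias} = concat([q_proj, k_proj, v_proj], axis=0)
+    gate_up_proj.weight  = concat([gate_proj, up_proj], axis=0)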
+""" + +import functools + +import numpy as np + +from mlc_chat.loader import ExternMapping +from mlc_chat.quantization import Quantization + +from .qwen2_model import QWen2Config, QWen2LMHeadModel + + +def huggingface(model_config: QWen2Config, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : GPT2Config + The configuration of the GPT-2 model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. + """ + model = QWen2LMHeadModel(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # map attention weight + attn = f"model.layers.{i}.self_attn" + for weight_type in ["weight", "bias"]: + mlc_name = f"{attn}.c_attn.{weight_type}" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.{weight_type}", + f"{attn}.k_proj.{weight_type}", + f"{attn}.v_proj.{weight_type}", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # map mlp weight + mlp = f"model.layers.{i}.mlp" + mlc_name = f"{mlp}.gate_up_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.weight", + f"{mlp}.up_proj.weight", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping diff --git a/python/mlc_chat/model/qwen2/qwen2_model.py b/python/mlc_chat/model/qwen2/qwen2_model.py new file mode 100644 index 0000000..f09ccee --- /dev/null +++ b/python/mlc_chat/model/qwen2/qwen2_model.py @@ -0,0 +1,271 @@ +""" +Implementation for QWEN2 architecture. 
+""" + +import dataclasses +from functools import partial +from typing import Any, Dict, Optional + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_chat import op as op_ext +from mlc_chat.support import logging +from mlc_chat.support.config import ConfigBase +from mlc_chat.support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class QWen2Config(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the QWen model.""" + + hidden_act: str + hidden_size: int + intermediate_size: int + num_attention_heads: int + num_hidden_layers: int + num_key_value_heads: int + rms_norm_eps: float + rope_theta: int + vocab_size: int + + context_window_size: int = 0 + prefill_chunk_size: int = 0 + tensor_parallel_shards: int = 1 + dtype: str = "float32" + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.context_window_size == 0: + for name in ["max_position_embeddings", "max_sequence_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + else: + raise ValueError( + "Unable to determine the maximum sequence length, because none of " + "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " + "provided in `config.json`." + ) + if self.prefill_chunk_size == 0: + logger.info( + "%s defaults to %s (%d)", + bold("prefill_chunk_size"), + bold("context_window_size"), + self.context_window_size, + ) + self.prefill_chunk_size = self.context_window_size + elif self.prefill_chunk_size > self.context_window_size: + logger.info( + "Overriding %s from %d to %d (%s)", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + self.context_window_size, + bold("context_window_size"), + ) + self.prefill_chunk_size = self.context_window_size + assert self.tensor_parallel_shards == 1, "QWEN currently does not support sharding." + + +# pylint: disable=invalid-name,missing-docstring,too-many-locals + + +class QWen2Attention(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: QWen2Config): + head_dim = config.hidden_size // config.num_attention_heads + + self.c_attn = nn.Linear( + in_features=config.hidden_size, + out_features=(2 * config.num_key_value_heads + config.num_attention_heads) * head_dim, + bias=True, + ) + self.o_proj = nn.Linear( + config.num_attention_heads * head_dim, config.hidden_size, bias=False + ) + # KV cache for single sequence + self.k_cache = nn.KVCache( + config.context_window_size, [config.num_key_value_heads, head_dim] + ) + self.v_cache = nn.KVCache( + config.context_window_size, [config.num_attention_heads, head_dim] + ) + + self.hidden_size = config.hidden_size + self.head_dim = head_dim + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.rope_theta = config.rope_theta + + def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): + bsz, sl, _ = hidden_states.shape + assert bsz == 1, "Only support batch size 1 at this moment." + # Step 1. QKV Projection + qkv = self.c_attn(hidden_states) + num_heads = 2 * self.num_key_value_heads + self.num_attention_heads + qkv = op.reshape(qkv, (bsz, sl, num_heads, self.head_dim)) + # Step 2. 
Apply QK rotary embedding + q, k, v = op_ext.llama_rope( + qkv, total_seq_len, self.rope_theta, self.num_attention_heads, self.num_key_value_heads + ) + # Step 3. Query and update KVCache + self.k_cache.append(op.squeeze(k, axis=0)) + self.v_cache.append(op.squeeze(v, axis=0)) + k = self.k_cache.view(total_seq_len) + v = self.v_cache.view(total_seq_len) + # Step 4. Compute softmax(Q @ K^T / sqrt(d)) @ V + output = op_ext.attention(q, k, v, casual_mask=attention_mask) + # Step 5. Apply output projection + return self.o_proj(output) + + +ACT2FN = { + "gelu": partial(nn.gelu, approximate=False), + "relu": nn.relu, + "silu": nn.silu, + "swish": nn.silu, + "gelu_new": partial(nn.gelu, approximate=True), +} + + +class QWen2MLP(nn.Module): + def __init__(self, config: QWen2Config): + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x: Tensor): + concat_x1_x2 = self.gate_up_proj(x) + x1, x2 = op.split(concat_x1_x2, 2, axis=-1) + return self.down_proj(self.act_fn(x1) * x2) + + +class QWen2DecoderLayer(nn.Module): + def __init__(self, config: QWen2Config): + self.self_attn = QWen2Attention(config) + self.mlp = QWen2MLP(config) + self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) + self.post_attention_layernorm = nn.RMSNorm( + config.hidden_size, -1, config.rms_norm_eps, bias=False + ) + + def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): + out = self.input_layernorm(hidden_states) + out = self.self_attn(out, attention_mask, total_seq_len) + hidden_states = out + hidden_states + + out = self.post_attention_layernorm(hidden_states) + out = self.mlp(out) + hidden_states = out + hidden_states + return hidden_states + + +class QWen2Model(nn.Module): + def __init__(self, config: QWen2Config): + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [QWen2DecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) + + def forward(self, input_ids: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): + hidden_states = self.embed_tokens(input_ids) + for layer in self.layers: + hidden_states = layer(hidden_states, attention_mask, total_seq_len) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class QWen2LMHeadModel(nn.Module): + def __init__(self, config: QWen2Config): + self.model = QWen2Model(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.dtype = config.dtype + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def forward(self, inputs: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): + def _index(x: te.Tensor): # x[:-1,:] + b, s, d = x.shape + return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") + + hidden_states = self.model(inputs, attention_mask, total_seq_len) + hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def prefill(self, inputs: Tensor, total_seq_len: tir.Var): + def _attention_mask(batch_size, seq_len, total_seq_len): + return te.compute( + (batch_size, 1, seq_len, 
total_seq_len), + lambda b, _, i, j: tir.if_then_else( + i < j - (total_seq_len - seq_len), + tir.min_value(self.dtype), + tir.max_value(self.dtype), + ), + name="attention_mask_prefill", + ) + + batch_size, seq_len = inputs.shape + attention_mask = op.tensor_expr_op( + _attention_mask, + name_hint="attention_mask_prefill", + args=[batch_size, seq_len, total_seq_len], + ) + return self.forward(inputs, attention_mask, total_seq_len) + + def decode(self, inputs: Tensor, total_seq_len: tir.Var): + batch_size, seq_len = inputs.shape + attention_mask = op.full( + shape=[batch_size, 1, seq_len, total_seq_len], + fill_value=tir.max_value(self.dtype), + dtype=self.dtype, + ) + return self.forward(inputs, attention_mask, total_seq_len) + + @staticmethod + def softmax_with_temperature(logits: Tensor, temperature: Tensor): + return op.softmax(logits / temperature, axis=-1) + + def get_default_spec(self): + batch_size = 1 + mod_spec = { + "prefill": { + "inputs": nn.spec.Tensor([batch_size, "seq_len"], "int32"), + "total_seq_len": int, + "$": { + "param_mode": "packed", + "effect_mode": "packed", + }, + }, + "decode": { + "inputs": nn.spec.Tensor([batch_size, 1], "int32"), + "total_seq_len": int, + "$": { + "param_mode": "packed", + "effect_mode": "packed", + }, + }, + "softmax_with_temperature": { + "logits": nn.spec.Tensor([1, 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor([], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_chat/model/qwen2/qwen2_quantization.py b/python/mlc_chat/model/qwen2/qwen2_quantization.py new file mode 100644 index 0000000..a59802d --- /dev/null +++ b/python/mlc_chat/model/qwen2/qwen2_quantization.py @@ -0,0 +1,54 @@ +"""This file specifies how MLC's QWen2 parameters are quantized using group quantization +or other formats.""" + +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_chat.loader import QuantizeMapping +from mlc_chat.quantization import FTQuantize, GroupQuantize, NoQuantize + +from .qwen2_model import QWen2Config, QWen2LMHeadModel + + +def group_quant( + model_config: QWen2Config, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a QWen-architecture model using group quantization.""" + model: nn.Module = QWen2LMHeadModel(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: QWen2Config, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Qwen model using FasterTransformer quantization.""" + model: nn.Module = QWen2LMHeadModel(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: QWen2Config, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a QWen model without quantization.""" + model: nn.Module = QWen2LMHeadModel(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_chat/model/stable_lm/__init__.py b/python/mlc_chat/model/stable_lm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/mlc_chat/model/stable_lm/stablelm_loader.py b/python/mlc_chat/model/stable_lm/stablelm_loader.py 
new file mode 100644 index 0000000..f635c0e --- /dev/null +++ b/python/mlc_chat/model/stable_lm/stablelm_loader.py @@ -0,0 +1,104 @@ +""" +This file specifies how MLC's StableLM parameters map from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" + +import functools + +import numpy as np + +from mlc_chat.loader import ExternMapping +from mlc_chat.quantization import Quantization + +from .stablelm_model import StableLMEpochConfig, StableLMEpochForCausalLM + + +def huggingface(model_config: StableLMEpochConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : StableLMEpochConfig + The configuration of the StableLM model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. + """ + model = StableLMEpochForCausalLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"model.layers.{i}.self_attn" + mlc_name = f"{attn}.qkv_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.weight", + f"{attn}.k_proj.weight", + f"{attn}.v_proj.weight", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + mlc_name = f"{attn}.qkv_proj.bias" + + # The old StableLM 3B model does not have bias term in q, k, v projection + if mlc_name in named_parameters: + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.bias", + f"{attn}.k_proj.bias", + f"{attn}.v_proj.bias", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # Add gates in MLP + mlp = f"model.layers.{i}.mlp" + mlc_name = f"{mlp}.gate_up_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.weight", + f"{mlp}.up_proj.weight", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping diff --git a/python/mlc_chat/model/stable_lm/stablelm_model.py b/python/mlc_chat/model/stable_lm/stablelm_model.py new file mode 100644 index 0000000..3a5ce65 --- /dev/null +++ b/python/mlc_chat/model/stable_lm/stablelm_model.py @@ -0,0 +1,266 @@ +""" +Implementation for StableLM architecture.
+TODO: add docstring +""" + +import dataclasses +from typing import Any, Dict, Optional + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_chat import op as op_ext +from mlc_chat.support import logging +from mlc_chat.support.config import ConfigBase +from mlc_chat.support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class StableLMEpochConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the StableLM model.""" + + vocab_size: int + hidden_size: int + num_hidden_layers: int + num_attention_heads: int + num_key_value_heads: int + norm_eps: float + rope_pct: float + rope_theta: int + intermediate_size: int + use_qkv_bias: bool = False # Default to False for Stable-LM 3B model + context_window_size: int = 0 + prefill_chunk_size: int = 0 + tensor_parallel_shards: int = 1 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.context_window_size == 0: + for name in ["max_position_embeddings", "max_sequence_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + else: + raise ValueError( + "Unable to determine the maximum sequence length, because none of " + "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " + "provided in `config.json`." + ) + if self.prefill_chunk_size == 0: + logger.info( + "%s defaults to %s (%d)", + bold("prefill_chunk_size"), + bold("context_window_size"), + self.context_window_size, + ) + self.prefill_chunk_size = self.context_window_size + elif self.prefill_chunk_size > self.context_window_size: + logger.info( + "Overriding %s from %d to %d (%s)", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + self.context_window_size, + bold("context_window_size"), + ) + self.prefill_chunk_size = self.context_window_size + assert self.tensor_parallel_shards == 1, "StableLM currently does not support sharding."
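The fallback rules implemented in `__post_init__` above can be summarized by a small standalone sketch (plain Python; the helper name `resolve_lengths` and the example dictionary are invented for illustration):

def resolve_lengths(config_json, context_window_size=0, prefill_chunk_size=0):
    # Precedence: an explicit context_window_size wins, otherwise fall back to
    # max_position_embeddings, then max_sequence_length; prefill_chunk_size
    # defaults to, and is capped at, the resolved context window.
    if context_window_size == 0:
        for name in ("max_position_embeddings", "max_sequence_length"):
            if name in config_json:
                context_window_size = config_json[name]
                break
        else:
            raise ValueError("Unable to determine the maximum sequence length.")
    if prefill_chunk_size == 0 or prefill_chunk_size > context_window_size:
        prefill_chunk_size = context_window_size
    return context_window_size, prefill_chunk_size

print(resolve_lengths({"max_position_embeddings": 4096}))  # -> (4096, 4096)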
+ + +# pylint: disable=invalid-name,missing-docstring + + +class StableLMAttention(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: StableLMEpochConfig): + self.hidden_size = config.hidden_size + self.rope_theta = config.rope_theta + self.rope_pct = config.rope_pct + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.rotary_ndims = int(self.head_dim * config.rope_pct) + + self.qkv_proj = nn.Linear( + in_features=config.hidden_size, + out_features=(self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, + bias=config.use_qkv_bias, + ) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + # KV cache for single sequence + self.k_cache = nn.KVCache( + config.context_window_size, [self.num_key_value_heads, self.head_dim] + ) + self.v_cache = nn.KVCache( + config.context_window_size, [self.num_key_value_heads, self.head_dim] + ) + + def forward( # pylint: disable=too-many-locals + self, + hidden_states: Tensor, + attention_mask: Tensor, + total_seq_len: tir.Var, + ): + d, h_q, h_kv, t = self.head_dim, self.num_heads, self.num_key_value_heads, total_seq_len + b, s, _ = hidden_states.shape + assert b == 1, "Only support batch size 1 at this moment." + # Step 1. QKV Projection + qkv = self.qkv_proj(hidden_states) + qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) + # Step 2. Apply QK rotary embedding + q, k, v = op_ext.llama_rope( + qkv, t, self.rope_theta, h_q, h_kv, rotary_dim=self.rotary_ndims + ) + # Step 3. Query and update KVCache + self.k_cache.append(op.squeeze(k, axis=0)) + self.v_cache.append(op.squeeze(v, axis=0)) + k = self.k_cache.view(t) + v = self.v_cache.view(t) + # Step 4. Compute softmax(Q @ K^T / sqrt(d)) @ V + output = op_ext.attention(q, k, v, casual_mask=attention_mask) + # Step 5. 
Apply output projection + return self.o_proj(output) + + +class StableLMMLP(nn.Module): + def __init__(self, config: StableLMEpochConfig): + self.intermediate_size = config.intermediate_size + self.gate_up_proj = nn.Linear( + in_features=config.hidden_size, + out_features=2 * self.intermediate_size, + bias=False, + ) + self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False) + + def forward(self, x: Tensor): + concat_x1_x2 = self.gate_up_proj(x) + x1, x2 = op.split(concat_x1_x2, 2, axis=-1) + return self.down_proj(op.silu(x1) * x2) + + +class StableLMDecoderLayer(nn.Module): + def __init__(self, config: StableLMEpochConfig): + norm_eps = config.norm_eps + self.self_attn = StableLMAttention(config) + self.mlp = StableLMMLP(config) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) + + def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): + out = self.self_attn(self.input_layernorm(hidden_states), attention_mask, total_seq_len) + hidden_states = out + hidden_states + out = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = out + hidden_states + return hidden_states + + +class StableLMEpochModel(nn.Module): + def __init__(self, config: StableLMEpochConfig): + assert config.hidden_size % config.num_attention_heads == 0 + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [StableLMDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps) + + def forward(self, input_ids: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): + hidden_states = self.embed_tokens(input_ids) + for layer in self.layers: + hidden_states = layer(hidden_states, attention_mask, total_seq_len) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class StableLMEpochForCausalLM(nn.Module): + def __init__(self, config: StableLMEpochConfig): + self.model = StableLMEpochModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.vocab_size = config.vocab_size + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def forward(self, inputs: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): + def _index(x: te.Tensor): # x[:, -1, :] + b, s, d = x.shape + return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") + + hidden_states = self.model(inputs, total_seq_len, attention_mask) + hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def prefill(self, inputs: Tensor, total_seq_len: tir.Var): + def _attention_mask(batch_size, seq_len, total_seq_len): + return te.compute( + (batch_size, 1, seq_len, total_seq_len), + lambda b, _, i, j: tir.if_then_else( + i < j - (total_seq_len - seq_len), + tir.min_value(self.dtype), + tir.max_value(self.dtype), + ), + name="attention_mask_prefill", + ) + + batch_size, seq_len = inputs.shape + attention_mask = op.tensor_expr_op( + _attention_mask, + name_hint="attention_mask_prefill", + args=[batch_size, seq_len, total_seq_len], + ) + return self.forward(inputs, total_seq_len, attention_mask) + + def decode(self, inputs: Tensor, total_seq_len: tir.Var): +
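        # Decode consumes a single new token per sequence (seq_len == 1) that may attend to
        # every position already in the KV cache, so the mask below is simply filled with the
        # dtype's maximum value ("no masking") over (batch_size, 1, seq_len, total_seq_len);
        # prefill() above builds a causal mask instead.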
batch_size, seq_len = inputs.shape + attention_mask = op.full( + shape=[batch_size, 1, seq_len, total_seq_len], + fill_value=tir.max_value(self.dtype), + dtype=self.dtype, + ) + return self.forward(inputs, total_seq_len, attention_mask) + + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): + return op.softmax(logits / temperature, axis=-1) + + def get_default_spec(self): + batch_size = 1 + mod_spec = { + "prefill": { + "inputs": nn.spec.Tensor([batch_size, "seq_len"], "int32"), + "total_seq_len": int, + "$": { + "param_mode": "packed", + "effect_mode": "packed", + }, + }, + "decode": { + "inputs": nn.spec.Tensor([batch_size, 1], "int32"), + "total_seq_len": int, + "$": { + "param_mode": "packed", + "effect_mode": "packed", + }, + }, + "softmax_with_temperature": { + "logits": nn.spec.Tensor([1, 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor([], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_chat/model/stable_lm/stablelm_quantization.py b/python/mlc_chat/model/stable_lm/stablelm_quantization.py new file mode 100644 index 0000000..0bb6047 --- /dev/null +++ b/python/mlc_chat/model/stable_lm/stablelm_quantization.py @@ -0,0 +1,53 @@ +"""This file specifies how MLC's StableLM parameters are quantized using group quantization +or other formats.""" +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_chat.loader import QuantizeMapping +from mlc_chat.quantization import FTQuantize, GroupQuantize, NoQuantize + +from .stablelm_model import StableLMEpochConfig, StableLMEpochForCausalLM + + +def group_quant( + model_config: StableLMEpochConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a StableLM-architecture model using group quantization.""" + model: nn.Module = StableLMEpochForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: StableLMEpochConfig, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a StableLM model using FasterTransformer quantization.""" + model: nn.Module = StableLMEpochForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: StableLMEpochConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a StableLM model without quantization.""" + model: nn.Module = StableLMEpochForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_chat/nn/__init__.py b/python/mlc_chat/nn/__init__.py new file mode 100644 index 0000000..fb1743f --- /dev/null +++ b/python/mlc_chat/nn/__init__.py @@ -0,0 +1,3 @@ +"""Common `nn.Modules` used to define LLMs in this project.""" +from .expert import MixtralExperts +from .kv_cache import FlashInferPagedKVCache, PagedKVCache, RopeMode, TIRPagedKVCache diff --git a/python/mlc_chat/nn/expert.py b/python/mlc_chat/nn/expert.py new file mode 100644 index 0000000..a4ff0cf --- /dev/null +++ b/python/mlc_chat/nn/expert.py @@ -0,0 +1,26 @@ +"""An nn.Module that represents MoE experts""" +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn 
import Tensor + +from mlc_chat.op import extern, ft_gemm, moe_matmul + + +class MixtralExperts(nn.Module): + """Mixtral experts""" + + def __init__(self, num_local_experts, in_features, out_features): + self.num_local_experts = num_local_experts + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter((num_local_experts, out_features, in_features)) + self.dtype = "float32" + + def forward(self, x: Tensor, indptr: Tensor): # pylint: disable=invalid-name,missing-docstring + assert x.ndim == 2 + if indptr.ndim == 2: + assert indptr.shape[0] == 1 + return moe_matmul.gemv(x, self.weight, indptr) + assert indptr.ndim == 1 + if extern.get_store().faster_transformer and self.dtype == "float16": + return ft_gemm.faster_transformer_moe_gemm(x, self.weight, indptr) + return moe_matmul.group_gemm(x, self.weight, indptr) diff --git a/python/mlc_chat/nn/kv_cache.py b/python/mlc_chat/nn/kv_cache.py new file mode 100644 index 0000000..e956037 --- /dev/null +++ b/python/mlc_chat/nn/kv_cache.py @@ -0,0 +1,1435 @@ +"""Attention KV cache modeling.""" + +# pylint: disable=too-many-statements,too-many-lines +import enum +import math +from typing import Optional, Tuple + +from tvm import relax as rx +from tvm import tir +from tvm.relax.frontend.nn import Object, Tensor +from tvm.runtime import DataType +from tvm.script import tir as T +from tvm.target import Target + +from mlc_chat.op.position_embedding import ( + llama_inplace_rope, + llama_rope_with_position_map, + rope_freq, +) + + +class RopeMode(enum.IntEnum): + """The RoPE mode of the Paged KV cache. + If it is none, the KV cache will not apply RoPE to q and k. + If it is normal, RoPE will be applied to k before adding k to cache. + Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly. + """ + + NONE = 0 + NORMAL = 1 + INLINE = 2 + + +class PagedKVCache(Object): # pylint: disable=too-few-public-methods + """The Paged KV Cache used in LLM batching for efficient attention computation.""" + + @staticmethod + def create_generic( # pylint: disable=too-many-arguments + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + num_hidden_layers: int, + num_attention_heads: int, + num_key_value_heads: int, + head_dim: int, + rope_mode: RopeMode, + rope_scale: int, + rope_theta: int, + dtype: str, + rotary_dim: Optional[int] = None, + name: str = "paged_kv_cache", + ) -> "PagedKVCache": + """The generic function of creating a PagedKVCache, + which will be rewritten by functions in compilation pipeline. + """ + if rotary_dim is None: + rotary_dim = head_dim + return PagedKVCache( + _expr=rx.Call( + rx.extern("mlc.create_paged_kv_cache_generic"), + args=[ + rx.ShapeExpr( + [max_batch_size, max_total_seq_len, prefill_chunk_size, page_size] + ), + rx.PrimValue(num_hidden_layers), + rx.PrimValue(num_attention_heads), + rx.PrimValue(num_key_value_heads), + rx.PrimValue(head_dim), + rx.PrimValue(rope_mode), + rx.PrimValue(rope_scale), + rx.PrimValue(rope_theta), + rx.PrimValue(rotary_dim), + rx.DataTypeImm(dtype), + ], + sinfo_args=[rx.ObjectStructInfo()], + ), + _name=name, + ) + + def attention( # pylint: disable=invalid-name, too-many-arguments + self, + layer_id: int, + q: Tensor, + k: Tensor, + v: Tensor, + attn_score_scaling_factor: float = 1.0, + ) -> Tensor: + """Compute attention with the given q/k/v data and in-cache k/v data + on the specified layer. Rotary position embeddings are applied to k/v + within this function. 
+ + - For prefill, the input q and output tensor have shape + (1, total_seq_len, num_attention_heads, head_dim), and the + k/v tensors have shape (1, total_seq_len, num_key_value_heads, head_dim). + - For decode, the input q and output tensor have shape + (batch_size, 1, num_attention_heads, head_dim), and the + k/v tensors have shape (batch_size, 1, num_key_value_heads, head_dim). + """ + # pylint: disable=protected-access + q_shape = q.shape + q = q.reshape(q.shape[0] * q.shape[1], q.shape[2], q.shape[3]) + k = k.reshape(k.shape[0] * k.shape[1], k.shape[2], k.shape[3]) + v = v.reshape(v.shape[0] * v.shape[1], v.shape[2], v.shape[3]) + return Tensor( + _expr=rx.BlockBuilder.current().emit( + rx.call_dps_packed( + "vm.builtin.paged_attention_kv_cache_attention", + [ + self._expr, + rx.PrimValue(layer_id), # type: ignore[arg-type] + rx.PrimValue(attn_score_scaling_factor), + q._expr, + k._expr, + v._expr, + ], + out_sinfo=q._expr.struct_info, + ) + ) + ).reshape(*q_shape) + # pylint: enable=protected-access + + def attention_with_fused_qkv( # pylint: disable=invalid-name + self, + layer_id: int, + qkv: Tensor, + num_qo_heads: int, + attn_score_scaling_factor: float = 1.0, + ) -> Tensor: + """Compute attention with the given fused q/k/v data and in-cache k/v data + on the specified layer. Rotary position embeddings are applied to k/v + within this function. + + - For prefill, the input qkv and output tensor have shape + (1, total_seq_len) for the first two dimensions. + - For decode, the input qkv and output tensor have shape + (batch_size, 1) for the first two dimensions. + - The input qkv have `2 * num_qo_heads + num_kv_heads` at the third dim. + - The output tensor have `num_qo_heads` at the third dim. + - The input qkv and output tensor have `head_dim` at the last dim. + """ + # pylint: disable=protected-access + b, s, _, d = qkv._expr.struct_info.shape + qkv = qkv.reshape(b * s, qkv.shape[2], d) + return Tensor( + _expr=rx.BlockBuilder.current().emit( + rx.call_dps_packed( + "vm.builtin.paged_attention_kv_cache_attention_with_fused_qkv", + [ + self._expr, + rx.PrimValue(layer_id), # type: ignore[arg-type] + rx.PrimValue(attn_score_scaling_factor), + qkv._expr, + ], + out_sinfo=rx.TensorStructInfo((b * s, num_qo_heads, d), qkv.dtype), + ) + ) + ).reshape(b, s, num_qo_heads, d) + + def get_query_positions(self, total_length: tir.PrimExpr) -> Tensor: + """Get the in-sequence positions of each slot in the query, + which are needed for applying positional embeddings in some models. + + Parameters + ---------- + total_length : tir.PrimExpr + The summed-up total sequence length of queries in + the batch being forwarded. 
+ + Returns + ------- + q_positions : Tensor + The in-sequence query positions, in shape `(total_length,)` + """ + return Tensor( + _expr=rx.BlockBuilder.current().emit( + rx.call_pure_packed( + "vm.builtin.paged_attention_kv_cache_get_query_positions", + self._expr, + sinfo_args=rx.TensorStructInfo((total_length,), "int32"), + ) + ) + ) + + # pylint: enable=protected-access + + +class FlashInferPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods + """Paged KV cache using FlashInfer (CUDA) kernels.""" + + def __init__( # pylint: disable=too-many-arguments,too-many-locals + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + num_hidden_layers: int, + num_attention_heads: int, + num_key_value_heads: int, + head_dim: int, + rope_mode: RopeMode, + rope_scale: int, + rope_theta: int, + rotary_dim: int, + dtype: str, + target: Target, + name: str = "paged_kv_cache", + ) -> None: + """Create a paged KV cache object with FlashInfer kernels. + + Parameters + ---------- + max_batch_size : tir.Var + The maximum allowed batch size of the KV cache. + It is a symbolic variable whose concrete value is specified + at runtime. + max_total_seq_len : tir.Var + The maximum allowed total sequence length of the KV cache. + It is a symbolic variable whose concrete value is specified + at runtime. + prefill_chunk_size : tir.Var + The maximum total sequence length in a prefill. + It is a symbolic variable whose concrete value is specified + at runtime. + page_size : tir.Var + The size (a.k.a. number of tokens) of each page. + It is a symbolic variable whose concrete value is specified + at runtime. + rope_mode : RopeMode + The RoPE mode of the Paged KV cache. + If it is normal, RoPE will be applied to k before adding k to cache. + Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly. + rope_scale : int + The scale of rotary position embedding. + rope_theta : int + The base of rotary position embedding. + rotary_dim : int + The number of dimensions in the embedding that RoPE is applied to. + """ + if rope_mode == RopeMode.INLINE: + assert rotary_dim == head_dim, "FlashInfer RoPE does not support partial rotary dim." 
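        # With RopeMode.INLINE the FlashInfer kernels rotate q/k on the fly and assume the
        # full head_dim takes part in the rotation; models that rotate only a prefix of the
        # head dimensions (rotary_dim < head_dim, e.g. StableLM's rope_pct) need
        # RopeMode.NORMAL, where rotation is applied to k before it enters the cache.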
+ + bb = rx.BlockBuilder.current() # pylint: disable=invalid-name + args = [ + rx.ShapeExpr([max_batch_size, max_total_seq_len, prefill_chunk_size, page_size]), + rx.PrimValue(num_hidden_layers), + rx.PrimValue(num_attention_heads), + rx.PrimValue(num_key_value_heads), + rx.PrimValue(head_dim), + rx.PrimValue(rope_mode), + rx.PrimValue(rope_scale), + rx.PrimValue(rope_theta), + rx.op.zeros((), dtype), + # pylint: disable=line-too-long + # fmt: off + bb.add_func(_kv_cache_transpose_append(num_key_value_heads, head_dim, dtype), "kv_cache_transpose_append"), + rx.extern("paged_kv_cache.attention_kernel_prefill"), + rx.extern("paged_kv_cache.attention_kernel_decode"), + rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache"), + rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_begin_forward"), + rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_end_forward"), + rx.extern("paged_kv_cache.attention_kernel_prefill_begin_forward"), + rx.extern("paged_kv_cache.attention_kernel_prefill_end_forward"), + rx.extern("paged_kv_cache.attention_kernel_decode_begin_forward"), + rx.extern("paged_kv_cache.attention_kernel_decode_end_forward"), + rx.extern("flashinfer.merge_state_in_place"), + bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"), + bb.add_func(llama_inplace_rope(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, target, rotary_dim), "tir_qk_rotary_inplace"), + bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), + # fmt: on + # pylint: enable=line-too-long + ] + super().__init__( + _expr=rx.Call( + rx.extern("vm.builtin.paged_attention_kv_cache_create"), + args=args, + sinfo_args=[rx.ObjectStructInfo()], + ), + _name=name, + ) + + +class TIRPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods + """Paged KV cache using TIR kernels.""" + + def __init__( # pylint: disable=too-many-arguments,too-many-locals + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + num_hidden_layers: int, + num_attention_heads: int, + num_key_value_heads: int, + rope_mode: RopeMode, + head_dim: int, + rope_scale: int, + rope_theta: int, + rotary_dim: int, + dtype: str, + target: Target, + name: str = "paged_kv_cache", + ) -> None: + """Create a paged KV cache object with TIR kernels. + + Parameters + ---------- + max_batch_size : tir.Var + The maximum allowed batch size of the KV cache. + It is a symbolic variable whose concrete value is specified + at runtime. + max_total_seq_len : tir.Var + The maximum allowed total sequence length of the KV cache. + It is a symbolic variable whose concrete value is specified + at runtime. + prefill_chunk_size : tir.Var + The maximum total sequence length in a prefill. + It is a symbolic variable whose concrete value is specified + at runtime. + page_size : tir.Var + The size (a.k.a. number of tokens) of each page. + It is a symbolic variable whose concrete value is specified + at runtime. + rope_mode : RopeMode + The RoPE mode of the Paged KV cache. + If it is normal, RoPE will be applied to k before adding k to cache. + Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly. + rope_scale : int + The scale of rotary position embedding. + rope_theta : int + The base of rotary position embedding. 
+ rotary_dim : int + The number of dimensions in the embedding that RoPE is applied to. + target : Target + The target to build the model to. + """ + + bb = rx.BlockBuilder.current() + args = [ + rx.ShapeExpr([max_batch_size, max_total_seq_len, prefill_chunk_size, page_size]), + rx.PrimValue(num_hidden_layers), + rx.PrimValue(num_attention_heads), + rx.PrimValue(num_key_value_heads), + rx.PrimValue(head_dim), + rx.PrimValue(rope_mode), + rx.PrimValue(rope_scale), + rx.PrimValue(rope_theta), + rx.op.zeros((), dtype), + # pylint: disable=line-too-long + # fmt: off + bb.add_func(_kv_cache_transpose_append(num_key_value_heads, head_dim, dtype), "kv_cache_transpose_append"), + bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_prefill"), + bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_decode"), + bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_prefill_ragged"), + bb.add_func(_merge_state_inplace(num_key_value_heads, head_dim, dtype, target), "tir_attention_merge_state"), + bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"), + bb.add_func(llama_inplace_rope(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, target, rotary_dim), "tir_qk_rotary_inplace"), + bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), + # fmt: on + # pylint: enable=line-too-long + ] + super().__init__( + _expr=rx.Call( + rx.extern("vm.builtin.paged_attention_kv_cache_create_reduced"), + args=args, + sinfo_args=[rx.ObjectStructInfo()], + ), + _name=name, + ) + + +# mypy: disable-error-code="attr-defined,valid-type,no-redef" +# pylint: disable=too-many-locals + + +def _kv_cache_transpose_append(num_key_value_heads, head_dim, dtype): + """Return the TIR function that appends new k/v data to PagedKVCache.""" + + # pylint: disable=line-too-long,invalid-name + # fmt: off + @T.prim_func + def tir_kv_cache_transpose_append( + var_pages: T.handle, + var_k_data: T.handle, + var_v_data: T.handle, + var_position_map: T.handle, + ): + T.func_attr({"tir.noalias": T.bool(True)}) + ntoken = T.SizeVar("num_tokens_excluding_cache", "int64") + num_pages = T.int64() + pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, 16, head_dim), dtype) + k_data = T.match_buffer(var_k_data, (ntoken, num_key_value_heads, head_dim), dtype) + v_data = T.match_buffer(var_v_data, (ntoken, num_key_value_heads, head_dim), dtype) + position_map = T.match_buffer(var_position_map, (ntoken,), "int32") + for global_pos, h, f in T.grid(ntoken, num_key_value_heads, head_dim): + with T.block("k_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) + T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf]) + position: T.int32 = position_map[vgpos] # type: ignore + pages[T.floordiv(position, 16), 0, vh, T.floormod(position, 16), vf] = k_data[vgpos, vh, vf] + with T.block("v_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) + T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf]) + position: T.int32 = position_map[vgpos] # type: 
ignore[name-defined,no-redef] + pages[T.floordiv(position, 16), 1, vh, T.floormod(position, 16), vf] = v_data[vgpos, vh, vf] + # fmt: on + # pylint: enable=line-too-long,invalid-name + + return tir_kv_cache_transpose_append + + +def _kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype): + """Return the TIR function that fetches the k/v data on given positions and layer.""" + + # pylint: disable=line-too-long,invalid-name + # fmt: off + @T.prim_func + def tir_kv_cache_debug_get_kv( + var_pages: T.handle, + var_position_map: T.handle, + var_k_data: T.handle, + var_v_data: T.handle, + layer_id: T.int64, + ): + T.func_attr({"tir.noalias": T.bool(True)}) + seqlen = T.SizeVar("num_tokens_including_cache", "int64") + page_size = T.SizeVar("page_size", "int64") + num_pages = T.int64() + pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype) + position_map = T.match_buffer(var_position_map, (seqlen,), "int32") + k_data = T.match_buffer(var_k_data, (num_hidden_layers, seqlen, num_key_value_heads, head_dim), dtype) + v_data = T.match_buffer(var_v_data, (num_hidden_layers, seqlen, num_key_value_heads, head_dim), dtype) + for p, h, d in T.grid(seqlen, num_key_value_heads, head_dim): + with T.block("copy0"): + vp, vh, vd = T.axis.remap("SSS", [p, h, d]) + T.reads(position_map[vp], pages[position_map[vp] // page_size, 0:2, vh, position_map[vp] % page_size, vd]) + T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd]) + position: T.int32 = position_map[vp] # type: ignore[name-defined] + k_data[layer_id, vp, vh, vd] = pages[T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vd] + v_data[layer_id, vp, vh, vd] = pages[T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vd] + # fmt: on + # pylint: enable=line-too-long,invalid-name + + return tir_kv_cache_debug_get_kv + + +def _rope( # pylint: disable=too-many-arguments + buffer: T.Buffer, + offset: tir.Var, + rotary_dim: int, + theta: tir.Var, + scale: tir.Var, + indices: Tuple[tir.Var, ...], + qkv_dtype="float16", +): + d = indices[-1] + cos_freq, sin_freq = rope_freq(offset * scale, d, rotary_dim, theta, qkv_dtype) + cos = cos_freq * buffer[indices] + sin = sin_freq * tir.if_then_else( + d < rotary_dim // 2, + -buffer[indices[:-1] + (d + rotary_dim // 2,)], + buffer[indices[:-1] + (d - rotary_dim // 2,)], + ) + return cos + sin + + +def _var(dtype): + return T.alloc_buffer((1,), dtype, scope="local") + + +def _attention_prefill(h_kv, h_q, d, dtype, target: Target): # pylint: disable=unused-argument + # pylint: disable=invalid-name + NUM_BLKS = 16 + LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes + group_size = h_q // h_kv + sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) + + num_warps = 4 + tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 + L_per_cta = tile_x // group_size + + def mask(causal, row, col, kv_len, qo_len): + return T.if_then_else( + causal > 0, + col < kv_len - qo_len + row + 1, + col < kv_len, + ) + + # pylint: disable=line-too-long,too-many-arguments,too-many-branches + # fmt: off + @T.prim_func + def batch_prefill_paged_kv( + _0: T.int32, # pylint: disable=unused-argument + var_q: T.handle, # [total_len, h_q, d] + var_q_indptr: T.handle, # [batch_size + 1] + var_pages: T.handle, # [max_num_pages, 2, h_kv, page_size, d] + var_page_indptr: T.handle, # [batch_size + 1] + var_page_values: T.handle, # [nnz_pages] + var_last_page_len: T.handle, # [b] + 
var_k_rope_pos_offset: T.handle, # [b] + var_q_rope_position: T.handle, # [total_len] + var_output: T.handle, # [total_len, h_q, d] + var_lse: T.handle, # [total_len, h_q] + causal: T.int32, + rotary_mode: T.int32, + rope_scale: T.float32, + rope_theta: T.float32, + attn_score_scaling_factor: T.float32, + ): + batch_size = T.int32(is_size_var=True) + total_len = T.int32(is_size_var=True) + nnz_pages = T.int32(is_size_var=True) + max_num_pages = T.int32(is_size_var=True) + + q = T.match_buffer(var_q, (total_len, h_q, d), dtype) + q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32") + pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype) + page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32") + page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32") + last_page_len = T.match_buffer(var_last_page_len, (batch_size,), "int32") + k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32") + q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32") + output = T.match_buffer(var_output, (total_len, h_q, d), dtype) + lse = T.match_buffer(var_lse, (total_len, h_q), "float32") # pylint: disable=unused-variable + + # kernel code + for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"): + for lby in T.thread_binding(h_kv, thread="blockIdx.y"): + for lty in T.thread_binding(num_warps, thread="threadIdx.y"): + for ltx in T.thread_binding(32, thread="threadIdx.x"): + with T.block("attn"): + bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) + T.reads() + T.writes() + tile_id = _var("int32") + batch_idx = _var("int32") + batch_tiles = _var("int32") + batch_rows = _var("int32") + iterator = _var("int32") + kv_chunk_len = _var("int32") + + Q_smem = T.alloc_buffer((tile_x, d), dtype, scope="shared") + K_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") + V_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") + S_smem = T.alloc_buffer((tile_x, tile_z), "float32", scope="shared") + + S_local = T.alloc_buffer((tile_x, tile_z), "float32", scope="local") + O_local = T.alloc_buffer((tile_x, d), "float32", scope="local") + + m_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + + m_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local") + m_prev = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local") + d_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local") + + ## get tile_no, batch_idx, batch_tiles, batch_rows + tile_id[0] = bx + batch_idx[0] = 0 + batch_rows[0] = (q_indptr[1] - q_indptr[0]) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) + while T.tvm_thread_invariant(batch_idx[0] < batch_size): + # advance to next tile + while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: + tile_id[0] -= batch_tiles[0] + batch_idx[0] += 1 + if batch_idx[0] < batch_size: + b_idx: T.int32 = batch_idx[0] + batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) + + if T.tvm_thread_invariant(batch_idx[0] < batch_size): + b_idx: T.int32 = batch_idx[0] + L_start: T.int32 = q_indptr[b_idx] + tile_id[0] * L_per_cta + H_qo_start: T.int32 = by * group_size + + cur_page_indptr_begin: T.int32 = page_indptr[b_idx] + cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1] + 
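                            # The KV length of this request is reconstructed from the page table:
                            # every page but the last holds 16 entries and the last page holds
                            # `last_page_len` entries. The tile loop below then streams K/V through
                            # shared memory while keeping a running row maximum (m_smem) and running
                            # softmax denominator (d_smem), rescaling the partial output O_local
                            # whenever the maximum grows (the usual online-softmax update).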
cur_last_page_len: T.int32 = last_page_len[b_idx] + kv_chunk_len[0] = T.if_then_else( + cur_page_indptr_begin != cur_page_indptr_end, + (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + cur_last_page_len, + 0 + ) + T.tvm_storage_sync("shared") + + # init states + for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)): + row: T.int32 = i * 32 * num_warps + ty * 32 + tx + if row < tile_x: + m_smem[row] = -5e4 + d_smem[row] = 1.0 + + for li, lj in T.grid(tile_x, tile_y): + with T.block("O_init"): + i, j = T.axis.remap("SS", [li, lj]) + O_local[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Load Q from gmem to smem + for li, lj in T.grid(tile_x, tile_y): + with T.block("Q_load"): + i, j = T.axis.remap("SS", [li, lj]) + T.reads() + T.writes() + cur_L = L_start + i // group_size + cur_H_qo = H_qo_start + i % group_size + if cur_L < q_indptr[b_idx + 1]: + Q_smem[i, j] = T.if_then_else( + rotary_mode == 1, + _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype), + q[cur_L, cur_H_qo, j] + ) + else: + Q_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_z)): + L_kv_start: T.int32 = iterator * tile_z + for lz, ly in T.grid(tile_z, tile_y): + with T.block("K_load"): + i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() + cur_L = L_kv_start + i + if cur_L < kv_chunk_len[0]: + page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(cur_L, 16)] # type: ignore + page_offset: T.int32(is_size_var=True) = T.floormod(cur_L, 16) # type: ignore + K_smem[i, j] = T.if_then_else( + rotary_mode == 1, + _rope(pages, k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (page_no, 0, by, page_offset, j), dtype), + pages[page_no, 0, by, page_offset, j] + ) + else: + K_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + for lz, ly in T.grid(tile_z, tile_y): + with T.block("V_load"): + i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() + cur_L = L_kv_start + i + if cur_L < kv_chunk_len[0]: + page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(cur_L, 16)] # type: ignore + page_offset: T.int32(is_size_var=True) = T.floormod(cur_L, 16) # type: ignore + V_smem[i, j] = pages[page_no, 1, by, page_offset, j] + else: + V_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Compute S + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_z, tile_y): + with T.block("S_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + S_local[i, j] = 0.0 + S_local[i, j] += Q_smem[i, k] * K_smem[j, k] * attn_score_scaling_factor * sm_scale + T.tvm_storage_sync("shared") + for li, lj in T.grid(tile_x, tile_z): + with T.block("S_store"): + i, j = T.axis.remap("SS", [li, lj]) + S_smem[i, j] = S_local[i, j] + T.tvm_storage_sync("shared") + + # Update S, m, d + for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)): + row: T.int32 = i * 32 * num_warps + ty * 32 + tx + if row < tile_x: + with T.block("update1"): + m_prev[i] = m_smem[row] + m_new[i] = m_smem[row] + # mask out of kv_chunk_len S + for j in T.serial(tile_z): + if mask(causal, + row=tile_id[0] * L_per_cta + row // group_size, + col=L_kv_start + j, + kv_len=kv_chunk_len[0], + qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): + m_new[i] = T.max(m_new[i], S_smem[row, j]) + d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) + + for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)): + row: T.int32 = i * 32 * num_warps + ty * 32 + tx + with T.block("update"): + for j in T.serial(tile_z): + # 
this is to avoid sync inside condition branch + if row < tile_x: + if mask(causal, + row=tile_id[0] * L_per_cta + row // group_size, + col=L_kv_start + j, + kv_len=kv_chunk_len[0], + qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): + S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) + else: + S_smem[row, j] = T.exp2(-5e4 - m_new[i]) + + for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)): + row: T.int32 = i * 32 * num_warps + ty * 32 + tx + if row < tile_x: + with T.block("update"): + for j in T.serial(tile_z): + d_new[i] += S_smem[row, j] + m_smem[row] = m_new[i] + d_smem[row] = d_new[i] + m_prev_smem[row] = m_prev[i] + T.tvm_storage_sync("shared") + + # Update O + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_y, tile_z): + with T.block("O_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i]) + O_local[i, j] += S_smem[i, k] * V_smem[k, j] + + # Store O from smem to gmem + for li, lj in T.grid(tile_x, tile_y): + with T.block("O_store"): + i, j = T.axis.remap("SS", [li, lj]) + if L_start + i // group_size < q_indptr[b_idx + 1]: + output[L_start + i // group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i] + + # Store LSE to gmem + for li in T.grid(tile_x): + with T.block("lse_store"): + i = T.axis.remap("S", [li]) + if L_start + i // group_size < q_indptr[b_idx + 1]: + lse[L_start + i // group_size, H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i]) + + # move to next tile + tile_id[0] += NUM_BLKS + # fmt: on + # pylint: enable=line-too-long,invalid-name,too-many-arguments,too-many-branches + sch = tir.Schedule(batch_prefill_paged_kv) + + def get_tile_size(x, y, t): + cnt = (x * y) // t + assert (x * y) % t == 0 + tile_y = (int)(math.ceil(math.sqrt(cnt))) + while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt: + tile_y += 1 + assert tile_y <= cnt + tile_x = cnt // tile_y + return tile_x, tile_y + + def apply_to_qkv_load(sch: tir.Schedule, block): + loop_x, loop_y = sch.get_loops(block)[-2:] + loop = sch.fuse(loop_x, loop_y) + _, ty, tx, vec = sch.split( + loop, factors=[None, num_warps, 32, LOAD_VEC], preserve_unit_iters=True + ) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + + def apply_to_so_ewise(sch: tir.Schedule, block, tile): + loop_x, loop_y = sch.get_loops(block)[-2:] + xo, xi = sch.split(loop_x, factors=[None, tile[0]]) + yo, yi = sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[num_warps, 32]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + def apply_to_gemm( # pylint: disable=too-many-arguments,unused-argument + sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False + ): + loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] + xo, xi = sch.split(loop_x, factors=[None, tile[0]]) + yo, yi = sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[num_warps, 32]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + ko, ki = sch.split(loop_z, factors=[None, r_len]) + if k_major: + sch.reorder(ko, xi, yi, ki) + else: + sch.reorder(ko, ki, xi, yi) + sch.decompose_reduction(block, ty) + + def apply_to_md(sch, block): + loop = sch.get_loops(block)[-1] + _, ty, tx = sch.split(loop, factors=[None, num_warps, 32]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + tile_s = get_tile_size(tile_x, tile_z, 32 * num_warps) + tile_o = 
get_tile_size(tile_x, tile_y, 32 * num_warps) + apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) + apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False) + apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s) + apply_to_so_ewise(sch, sch.get_block("O_init"), tile_o) + apply_to_so_ewise(sch, sch.get_block("O_store"), tile_o) + apply_to_qkv_load(sch, sch.get_block("Q_load")) + apply_to_qkv_load(sch, sch.get_block("K_load")) + apply_to_qkv_load(sch, sch.get_block("V_load")) + apply_to_md(sch, sch.get_block("lse_store")) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +def _attention_decode( + num_kv_heads, + num_qo_heads, + head_dim, + qkv_dtype, + target: Target, # pylint: disable=unused-argument +): + # pylint: disable=invalid-name + qkv_dtype_bytes = 2 + H_qo = num_qo_heads + H_kv = num_kv_heads + D = head_dim + + thread_limit = 512 if str(target.kind) != "webgpu" else 256 + + GROUP_SIZE = H_qo // H_kv + VEC_SIZE = min(max(8 // qkv_dtype_bytes, D // 32), 4) + bdx = D // VEC_SIZE + bdy = GROUP_SIZE + while bdx * bdy > thread_limit and bdy > 1: + bdy //= 2 + gdz = GROUP_SIZE // bdy + threads_per_CTA = max(thread_limit, bdx * bdy) + bdz = threads_per_CTA // (bdx * bdy) + tile_size_per_bdx = 2 if GROUP_SIZE == 1 else 1 + log2e = math.log2(math.exp(1)) + + # pylint: disable=line-too-long,too-many-arguments,too-many-branches + # fmt: off + @T.prim_func + def batch_decode_paged_kv( + _0: T.int32, # pylint: disable=unused-argument + Q_handle: T.handle, + pages_handle: T.handle, + page_table_indptr_handle: T.handle, + page_table_values_handle: T.handle, + last_page_len_handle: T.handle, + k_rope_pos_offset_handle: T.handle, + q_rope_position_handle: T.handle, + output_handle: T.handle, + lse_handle: T.handle, + rotary_mode: T.int32, + rope_scale: T.float32, + rope_theta: T.float32, + attn_score_scaling_factor: T.float32, + ): + T.func_attr({"tir.is_scheduled": 1}) + B = T.int32(is_size_var=True) + nnz_pages = T.int32(is_size_var=True) + max_num_pages = T.int32(is_size_var=True) + + Q = T.match_buffer(Q_handle, (B, H_qo, D), qkv_dtype) + pages = T.match_buffer( + pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype + ) + page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32") + page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32") + k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32") + q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32") + last_page_len = T.match_buffer(last_page_len_handle, (B,), "int32") + output = T.match_buffer(output_handle, (B, H_qo, D), qkv_dtype) + lse = T.match_buffer(lse_handle, (B, H_qo), "float32") # pylint: disable=unused-variable + + sm_scale = 1.0 / math.sqrt(float(D)) * log2e + + for bx in T.thread_binding(B, thread="blockIdx.x"): + for fused_by_bz in T.thread_binding(H_kv * gdz, thread="blockIdx.y"): + for ty in T.thread_binding(bdy, thread="threadIdx.y"): + for tx in T.thread_binding(bdx, thread="threadIdx.x"): + for tz in T.thread_binding(bdz, thread="threadIdx.z"): + with T.block("attn"): + Q_local = T.alloc_buffer((VEC_SIZE,), qkv_dtype, scope="local") + kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") + K_smem = T.alloc_buffer((bdz * bdy * tile_size_per_bdx, D), qkv_dtype, scope="shared") + V_smem = T.alloc_buffer((bdz * bdy * tile_size_per_bdx, D), qkv_dtype, scope="shared") + O_allreduce = T.alloc_buffer((bdz, bdy, D), "float32", scope="shared") + md_allreduce = T.alloc_buffer((bdz, bdy, 2), 
"float32", scope="shared") + S_reduce_local = T.alloc_buffer((1,), "float32", scope="local") + t0 = T.alloc_buffer((1,), "float32", scope="local") + + S_local = T.alloc_buffer((bdy * tile_size_per_bdx), "float32", scope="local") + K_local = T.alloc_buffer((VEC_SIZE,), qkv_dtype, scope="local") + V_local = T.alloc_buffer((VEC_SIZE,), qkv_dtype, scope="local") + m_prev = T.alloc_buffer((1,), "float32", scope="local") + d_prev = T.alloc_buffer((1,), "float32", scope="local") + other_m = T.alloc_buffer((1,), "float32", scope="local") + other_d = T.alloc_buffer((1,), "float32", scope="local") + other_o = T.alloc_buffer((VEC_SIZE,), "float32", scope="local") + st_m = T.alloc_buffer((1,), "float32", scope="local") + st_d = T.alloc_buffer((1,), "float32", scope="local") + O_local = T.alloc_buffer((VEC_SIZE,), "float32", scope="local") + + by: T.int32 = fused_by_bz % H_kv + bz: T.int32 = fused_by_bz // H_kv + batch_idx: T.int32 = bx + cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx] + cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1] + cur_last_page_len: T.int32 = last_page_len[batch_idx] + kv_chunk_len[0] = T.if_then_else( + cur_page_indptr_begin != cur_page_indptr_end, + (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + cur_last_page_len, + 0 + ) + + # init states + st_m[0] = -5e4 + st_d[0] = 1.0 + for vec in T.vectorized(VEC_SIZE): + O_local[vec] = 0.0 + + # load q + for vec in T.vectorized(VEC_SIZE): + Q_local[vec] = T.if_then_else( + rotary_mode == 1, + _rope(Q, q_rope_position[batch_idx], head_dim, rope_theta, rope_scale, (bx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec), qkv_dtype), + Q[bx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec] + ) + + for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_size_per_bdx * bdy * bdz)): + tile_start_s: T.int32(is_size_var=True) = (tz * bdy + ty) * tile_size_per_bdx # type: ignore + tile_start_g: T.int32(is_size_var=True) = ((iterator * bdz + tz) * bdy + ty) * tile_size_per_bdx # type: ignore + # load K from global memory to shared memory + for j in T.serial(tile_size_per_bdx): + row_g: T.int32(is_size_var=True) = tile_start_g + j # type: ignore + if row_g < kv_chunk_len[0]: + page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(row_g, 16)] # type: ignore + page_offset: T.int32(is_size_var=True) = T.floormod(row_g, 16) # type: ignore + for vec in T.vectorized(VEC_SIZE): + K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = T.if_then_else( + rotary_mode == 1, + _rope(pages, k_rope_pos_offset[batch_idx] + row_g, head_dim, rope_theta, rope_scale, (page_no, 0, by, page_offset, tx * VEC_SIZE + vec), qkv_dtype), + pages[page_no, 0, by, page_offset, tx * VEC_SIZE + vec] + ) + else: + for vec in T.vectorized(VEC_SIZE): + K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = 0.0 + T.tvm_storage_sync("shared") + # load V from global memory to shared memory + for j in T.serial(tile_size_per_bdx): + row_g: T.int32(is_size_var=True) = tile_start_g + j # type: ignore + if row_g < kv_chunk_len[0]: + page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(row_g, 16)] # type: ignore + page_offset: T.int32(is_size_var=True) = T.floormod(row_g, 16) # type: ignore + for vec in T.vectorized(VEC_SIZE): + V_smem[tile_start_s + j, tx * VEC_SIZE + vec] = pages[page_no, 1, by, page_offset, tx * VEC_SIZE + vec] + else: + for vec in T.vectorized(VEC_SIZE): + V_smem[tile_start_s + j, tx * VEC_SIZE + vec] = 0.0 + T.tvm_storage_sync("shared") + # compute QK + m_prev[0] = 
st_m[0] + for j in T.serial(bdy * tile_size_per_bdx): + # load K from shared memory to local memory + for vec in T.vectorized(VEC_SIZE): + K_local[vec] = K_smem[tz * bdy * tile_size_per_bdx + j, tx * VEC_SIZE + vec] + # compute S = Q * K * sm_scale + S_reduce_local[0] = 0 + for vec in T.serial(VEC_SIZE): + S_reduce_local[0] += Q_local[vec] * K_local[vec] * attn_score_scaling_factor * sm_scale + + with T.block("block_cross_thread"): + T.reads(S_reduce_local[0]) + T.writes(t0[0]) + T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], True, t0[0], tx, dtype="handle") + + if (iterator * bdz + tz) * bdy * tile_size_per_bdx + j < kv_chunk_len[0]: + S_local[j] = t0[0] + else: + S_local[j] = -5e4 + # update st_m + st_m[0] = T.max(st_m[0], S_local[j]) + + # update st_d, st_O + o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0]) + st_d[0] *= o_scale + for j in T.serial(bdy * tile_size_per_bdx): + S_local[j] = T.exp2(S_local[j] - st_m[0]) + st_d[0] += S_local[j] + for j in T.vectorized(VEC_SIZE): + O_local[j] *= o_scale + + # load V from shared memory to local memory + # compute O + for j in T.serial(bdy * tile_size_per_bdx): + for vec in T.vectorized(VEC_SIZE): + V_local[vec] = V_smem[tz * bdy * tile_size_per_bdx + j, tx * VEC_SIZE + vec] + for vec in T.vectorized(VEC_SIZE): + O_local[vec] += V_local[vec] * S_local[j] + + if bdz > 1: + # allreduce over bdz + for vec in T.vectorized(VEC_SIZE): + O_allreduce[tz, ty, tx * VEC_SIZE + vec] = O_local[vec] + md_allreduce[tz, ty, 0] = st_m[0] + md_allreduce[tz, ty, 1] = st_d[0] + T.tvm_storage_sync("shared") + + st_m[0] = -5e4 + st_d[0] = 1.0 + for vec in T.vectorized(VEC_SIZE): + O_local[vec] = 0.0 + + for j in T.serial(bdz): + m_prev[0] = st_m[0] + d_prev[0] = st_d[0] + other_m[0] = md_allreduce[j, ty, 0] + other_d[0] = md_allreduce[j, ty, 1] + for vec in T.vectorized(VEC_SIZE): + other_o[vec] = O_allreduce[j, ty, tx * VEC_SIZE + vec] + st_m[0] = T.max(st_m[0], other_m[0]) + st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0]) + for vec in T.serial(VEC_SIZE): + O_local[vec] = O_local[vec] * T.exp2(m_prev[0] - st_m[0]) + other_o[vec] * T.exp2(other_m[0] - st_m[0]) + + # normalize O + for vec in T.serial(VEC_SIZE): + O_local[vec] /= st_d[0] + + # store O to global memory + for vec in T.vectorized(VEC_SIZE): + output[batch_idx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec] = O_local[vec] + + # store lse to global memory + lse[batch_idx, by * GROUP_SIZE + bz * bdy + ty] = st_m[0] + T.log2(st_d[0]) + # fmt: on + # pylint: enable=line-too-long,invalid-name,too-many-arguments,too-many-branches + return batch_decode_paged_kv + + +def _merge_state_inplace( + num_heads, head_dim, v_dtype, target: Target +): # pylint: disable=unused-argument + # pylint: disable=invalid-name + v_dtype_bytes = 2 + VEC_SIZE = min(max(8 // v_dtype_bytes, head_dim // 32), 4) + bdx = head_dim // VEC_SIZE + bdy = num_heads + + @T.prim_func + def merge_state_inplace( + v: T.handle, + s: T.handle, + v_other: T.handle, + s_other: T.handle, + ): + T.func_attr({"tir.is_scheduled": 1}) + N = T.int32(is_size_var=True) + H = T.int32(is_size_var=True) + D = T.int32(is_size_var=True) + + V = T.match_buffer(v, (N, H, D), v_dtype) + S = T.match_buffer(s, (N, H), "float32") + V_other = T.match_buffer(v_other, (N, H, D), v_dtype) + S_other = T.match_buffer(s_other, (N, H), "float32") + + for bx in T.thread_binding(N, 
thread="blockIdx.x"): + for ty in T.thread_binding(bdy, thread="threadIdx.y"): + for tx in T.thread_binding(bdx, thread="threadIdx.x"): + with T.block("merge"): + s_val = _var("float32") + s_other_val = _var("float32") + s_max = _var("float32") + scale = _var("float32") + other_scale = _var("float32") + + v_vec = T.alloc_buffer((VEC_SIZE,), v_dtype, scope="local") + v_other_vec = T.alloc_buffer((VEC_SIZE,), v_dtype, scope="local") + + s_val[0] = S[bx, ty] + s_other_val[0] = S_other[bx, ty] + s_max[0] = T.max(s_val[0], s_other_val[0]) + s_val[0] = T.exp2(s_val[0] - s_max[0]) + s_other_val[0] = T.exp2(s_other_val[0] - s_max[0]) + scale[0] = s_val[0] / (s_val[0] + s_other_val[0]) + other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0]) + + # load v + for vec in T.vectorized(VEC_SIZE): + v_vec[vec] = V[bx, ty, tx * VEC_SIZE + vec] + # load v_other + for vec in T.vectorized(VEC_SIZE): + v_other_vec[vec] = V_other[bx, ty, tx * VEC_SIZE + vec] + + # merge + for vec in T.serial(VEC_SIZE): + v_vec[vec] = v_vec[vec] * scale[0] + v_other_vec[vec] * other_scale[0] + + # store v + for vec in T.vectorized(VEC_SIZE): + V[bx, ty, tx * VEC_SIZE + vec] = v_vec[vec] + + # store s + S[bx, ty] = T.log2(s_val[0] + s_other_val[0]) + s_max[0] + + # pylint: enable=invalid-name + return merge_state_inplace + + +def _attention_prefill_ragged( + h_kv, h_q, d, dtype, target: Target +): # pylint: disable=unused-argument + # pylint: disable=invalid-name,line-too-long + NUM_BLKS = 16 + LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes + group_size = h_q // h_kv + sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) + + num_warps = 4 + tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 + L_per_cta = tile_x // group_size + + def mask(causal, row, col, kv_len, qo_len): + return T.if_then_else( + causal > 0, + col < kv_len - qo_len + row + 1, + col < kv_len, + ) + + # fmt: off + @T.prim_func + def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-branches + var_q: T.handle, # [total_len, h_q, d] + var_q_indptr: T.handle, # [batch_size + 1] + var_k: T.handle, # [total_len, h_kv, d] + var_v: T.handle, # [total_len, h_kv, d] + var_kv_indptr: T.handle, # [batch_size + 1] + var_q_rope_position: T.handle, # [total_q_len] + var_k_rope_pos_offset: T.handle, # [b] + var_output: T.handle, # [total_len, h_q, d] + var_lse: T.handle, # [total_len, h_q] + causal: T.int32, + rotary_mode: T.int32, + rope_scale: T.float32, + rope_theta: T.float32, + attn_score_scaling_factor: T.float32 + ): + batch_size = T.int32(is_size_var=True) + qo_len = T.int32(is_size_var=True) + kv_len = T.int32(is_size_var=True) + + q = T.match_buffer(var_q, (qo_len, h_q, d), dtype) + q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32") + k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype) + v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype) + kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32") + q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32") + k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32") + output = T.match_buffer(var_output, (qo_len, h_q, d), dtype) + lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable + + # kernel code + for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"): + for lby in T.thread_binding(h_kv, thread="blockIdx.y"): + for lty in T.thread_binding(num_warps, thread="threadIdx.y"): + for ltx in T.thread_binding(32, 
thread="threadIdx.x"): + with T.block("attn"): + bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) + T.reads() + T.writes() + tile_id = _var("int32") + batch_idx = _var("int32") + batch_tiles = _var("int32") + batch_rows = _var("int32") + iterator = _var("int32") + kv_chunk_len = _var("int32") + + Q_smem = T.alloc_buffer((tile_x, d), dtype, scope="shared") + K_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") + V_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") + S_smem = T.alloc_buffer((tile_x, tile_z), "float32", scope="shared") + + S_local = T.alloc_buffer((tile_x, tile_z), "float32", scope="local") + O_local = T.alloc_buffer((tile_x, d), "float32", scope="local") + + m_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + + m_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local") + m_prev = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local") + d_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local") + + ## get tile_no, batch_idx, batch_tiles, batch_rows + tile_id[0] = bx + batch_idx[0] = 0 + batch_rows[0] = (q_indptr[1] - q_indptr[0]) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) + while T.tvm_thread_invariant(batch_idx[0] < batch_size): + # advance to next tile + while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: + tile_id[0] -= batch_tiles[0] + batch_idx[0] += 1 + if batch_idx[0] < batch_size: + b_idx: T.int32 = batch_idx[0] + batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) + + if T.tvm_thread_invariant(batch_idx[0] < batch_size): + b_idx: T.int32 = batch_idx[0] + L_start: T.int32 = q_indptr[b_idx] + tile_id[0] * L_per_cta + H_qo_start: T.int32 = by * group_size + + kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx] + T.tvm_storage_sync("shared") + + # init states + for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)): + row: T.int32 = i * 32 * num_warps + ty * 32 + tx + if row < tile_x: + m_smem[row] = -5e4 + d_smem[row] = 1.0 + + for li, lj in T.grid(tile_x, tile_y): + with T.block("O_init"): + i, j = T.axis.remap("SS", [li, lj]) + O_local[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Load Q from gmem to smem + for li, lj in T.grid(tile_x, tile_y): + with T.block("Q_load"): + i, j = T.axis.remap("SS", [li, lj]) + T.reads() + T.writes() + cur_L = L_start + i // group_size + cur_H_qo = H_qo_start + i % group_size + if cur_L < q_indptr[b_idx + 1]: + Q_smem[i, j] = T.if_then_else( + rotary_mode == 1, + _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype), + q[cur_L, cur_H_qo, j] + ) + else: + Q_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_z)): + L_kv_start: T.int32 = iterator * tile_z + L_kv_base: T.int32 = kv_indptr[b_idx] + for lz, ly in T.grid(tile_z, tile_y): + with T.block("K_load"): + i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() + cur_L = L_kv_start + i + if cur_L < kv_chunk_len[0]: + K_smem[i, j] = T.if_then_else( + rotary_mode == 1, + _rope(k, k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (L_kv_base + cur_L, by, j), dtype), + k[L_kv_base + cur_L, by, j] + ) + else: + K_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + for lz, ly in T.grid(tile_z, tile_y): + with 
T.block("V_load"): + i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() + cur_L = L_kv_start + i + if cur_L < kv_chunk_len[0]: + V_smem[i, j] = v[L_kv_base + cur_L, by, j] + else: + V_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Compute S + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_z, tile_y): + with T.block("S_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + S_local[i, j] = 0.0 + S_local[i, j] += Q_smem[i, k] * K_smem[j, k] * attn_score_scaling_factor * sm_scale + T.tvm_storage_sync("shared") + for li, lj in T.grid(tile_x, tile_z): + with T.block("S_store"): + i, j = T.axis.remap("SS", [li, lj]) + S_smem[i, j] = S_local[i, j] + T.tvm_storage_sync("shared") + + # Update S, m, d + for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)): + row: T.int32 = i * 32 * num_warps + ty * 32 + tx + if row < tile_x: + with T.block("update1"): + m_prev[i] = m_smem[row] + m_new[i] = m_smem[row] + # mask out of kv_chunk_len S + for j in T.serial(tile_z): + if mask(causal, + row=tile_id[0] * L_per_cta + row // group_size, + col=L_kv_start + j, + kv_len=kv_chunk_len[0], + qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): + m_new[i] = T.max(m_new[i], S_smem[row, j]) + d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) + + for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)): + row: T.int32 = i * 32 * num_warps + ty * 32 + tx + with T.block("update"): + for j in T.serial(tile_z): + # this is to avoid sync inside condition branch + if row < tile_x: + if mask(causal, + row=tile_id[0] * L_per_cta + row // group_size, + col=L_kv_start + j, + kv_len=kv_chunk_len[0], + qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): + S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) + else: + S_smem[row, j] = T.exp2(-5e4 - m_new[i]) + + for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)): + row: T.int32 = i * 32 * num_warps + ty * 32 + tx + if row < tile_x: + with T.block("update"): + for j in T.serial(tile_z): + d_new[i] += S_smem[row, j] + m_smem[row] = m_new[i] + d_smem[row] = d_new[i] + m_prev_smem[row] = m_prev[i] + T.tvm_storage_sync("shared") + + # Update O + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_y, tile_z): + with T.block("O_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i]) + O_local[i, j] += S_smem[i, k] * V_smem[k, j] + + # Store O from smem to gmem + for li, lj in T.grid(tile_x, tile_y): + with T.block("O_store"): + i, j = T.axis.remap("SS", [li, lj]) + if L_start + i // group_size < q_indptr[b_idx + 1]: + output[L_start + i // group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i] + + # Store LSE to gmem + for li in T.grid(tile_x): + with T.block("lse_store"): + i = T.axis.remap("S", [li]) + if L_start + i // group_size < q_indptr[b_idx + 1]: + lse[L_start + i // group_size, H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i]) + + # move to next tile + tile_id[0] += NUM_BLKS + # fmt: on + # pylint: enable=line-too-long,invalid-name,too-many-arguments,too-many-branches + sch = tir.Schedule(batch_prefill_ragged_kv) + + def get_tile_size(x, y, t): + cnt = (x * y) // t + assert (x * y) % t == 0 + tile_y = (int)(math.ceil(math.sqrt(cnt))) + while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt: + tile_y += 1 + assert tile_y <= cnt + tile_x = cnt // tile_y + return tile_x, tile_y + + def apply_to_qkv_load(sch: tir.Schedule, block): + loop_x, loop_y = sch.get_loops(block)[-2:] + loop = sch.fuse(loop_x, loop_y) + _, ty, tx, vec = sch.split( + 
loop, factors=[None, num_warps, 32, LOAD_VEC], preserve_unit_iters=True + ) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + + def apply_to_so_ewise(sch: tir.Schedule, block, tile): + loop_x, loop_y = sch.get_loops(block)[-2:] + xo, xi = sch.split(loop_x, factors=[None, tile[0]]) + yo, yi = sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[num_warps, 32]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + def apply_to_gemm( # pylint: disable=too-many-arguments,unused-argument + sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False + ): + loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] + xo, xi = sch.split(loop_x, factors=[None, tile[0]]) + yo, yi = sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[num_warps, 32]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + ko, ki = sch.split(loop_z, factors=[None, r_len]) + if k_major: + sch.reorder(ko, xi, yi, ki) + else: + sch.reorder(ko, ki, xi, yi) + sch.decompose_reduction(block, ty) + + def apply_to_md(sch, block): + loop = sch.get_loops(block)[-1] + _, ty, tx = sch.split(loop, factors=[None, num_warps, 32]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + tile_s = get_tile_size(tile_x, tile_z, 32 * num_warps) + tile_o = get_tile_size(tile_x, tile_y, 32 * num_warps) + apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) + apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False) + apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s) + apply_to_so_ewise(sch, sch.get_block("O_init"), tile_o) + apply_to_so_ewise(sch, sch.get_block("O_store"), tile_o) + apply_to_qkv_load(sch, sch.get_block("Q_load")) + apply_to_qkv_load(sch, sch.get_block("K_load")) + apply_to_qkv_load(sch, sch.get_block("V_load")) + + apply_to_md(sch, sch.get_block("lse_store")) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) diff --git a/python/mlc_chat/op/__init__.py b/python/mlc_chat/op/__init__.py new file mode 100644 index 0000000..3425686 --- /dev/null +++ b/python/mlc_chat/op/__init__.py @@ -0,0 +1,6 @@ +"""Extern module for compiler.""" +from . import moe_matmul, moe_misc +from .attention import attention +from .extern import configure, enable, get_store +from .ft_gemm import faster_transformer_dequantize_gemm +from .position_embedding import llama_rope diff --git a/python/mlc_chat/op/attention.py b/python/mlc_chat/op/attention.py new file mode 100644 index 0000000..02f21a6 --- /dev/null +++ b/python/mlc_chat/op/attention.py @@ -0,0 +1,182 @@ +"""Operators enabled by external modules.""" +import math + +from tvm import tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import op + +from mlc_chat.support import logging + +from . import extern as _extern + +logger = logging.getLogger(__name__) + + +WARN_FLASHINFER_GROUP_SIZE = False +WARN_FLASHINFER_HEAD_DIM = False + + +def attention( # pylint: disable=invalid-name,too-many-locals,too-many-statements,too-many-arguments + q: nn.Tensor, + k: nn.Tensor, + v: nn.Tensor, + casual_mask: nn.Tensor, + attn_score_scaling_factor: float = 1.0, + qk_dtype: str = None, +) -> nn.Tensor: + """Attention with casual mask. 
+ + --- Variables --- + s: sequence length of the current query + t: total sequence length + d: head dimension + h, h_q: number of heads in query + h_kv: number of heads in key and value + b: batch size = 1 + + --- Shapes --- + q: [b, s, h_q, d] + k: [t, h_kv, d] + v: [t, h_kv, d] + o: [1, s, hidden = h_q * d] + + --- Computation --- + + .. code-block:: python + + if h_kv != h_q: + k = k.repeat(h_q // h_kv, axis=1) + v = v.repeat(h_q // h_kv, axis=1) + q -> [b, h, s, d] + k, v -> [b, h, t, d] + attn = q @ k^T / sqrt(d) * attn_score_scaling_factor # [b, h, s, t] + attn = softmax_with_mask(attn, casual_mask, axis=-1) + o = attn @ v # [b, h, s, d] + o -> [b, s, h * d] + + --- Other params --- + qk_dtype: if set, `matmul(Q, K, out_dtype=qk_dtype)`, (otherwise use `q.dtype` as `out_dtype`). + For FlashInfer, if "float32", sets `allow_fp16_qk_reduction` to False; otherwise no effect. + """ + assert q.ndim == 4 and k.ndim in [3, 4] and v.ndim in [3, 4] + b, s, h_q, d = q.shape + t, h_kv, _ = k.shape[-3:] + group_size = h_q // h_kv + assert b == 1, "batch size must be 1" + + def _fallback(): + nonlocal q, k, v, qk_dtype + if k.ndim == 3: + k = op.reshape(k, [b, t, h_kv, d]) + if v.ndim == 3: + v = op.reshape(v, [b, t, h_kv, d]) + if h_kv != h_q: + k = k.repeat(h_q // h_kv, axis=2) + v = v.repeat(h_q // h_kv, axis=2) + q = op.permute_dims(q, [0, 2, 1, 3]) + k = op.permute_dims(k, [0, 2, 1, 3]) + v = op.permute_dims(v, [0, 2, 1, 3]) + model_dtype = q.dtype + if qk_dtype is None: + qk_dtype = model_dtype + attn_weights = op.matmul( # [b, h, s, t] + q, # [b, h, s, d] + op.permute_dims(k, [0, 1, 3, 2]), # [b, h, d, t] + out_dtype=qk_dtype, + ) / math.sqrt(d) + if attn_score_scaling_factor != 1.0: + attn_weights = attn_weights * attn_score_scaling_factor + attn_weights = attn_weights.maximum(tir.min_value(model_dtype)).minimum( + casual_mask.astype(qk_dtype) + ) + attn_weights = op.softmax(attn_weights.astype("float32"), axis=-1).astype(model_dtype) + output = op.matmul(attn_weights, v) # [b, h, s, d] <= [b, h, s, t] x [b, h, t, d] + output = op.permute_dims(output, [0, 2, 1, 3]) # [b, s, h, d] + output = op.reshape(output, [b, s, h_q * d]) # [b, s, h * d] + return output + + # FlashInfer Implementation + if ( + _extern.get_store().flashinfer + and attn_score_scaling_factor == 1.0 + and q.dtype == "float16" + and k.dtype == "float16" + and v.dtype == "float16" + ): + if group_size not in [1, 4, 8]: + global WARN_FLASHINFER_GROUP_SIZE # pylint: disable=global-statement + if not WARN_FLASHINFER_GROUP_SIZE: + WARN_FLASHINFER_GROUP_SIZE = True + logger.warning( + "FlashInfer only supports group size in [1, 4, 8], but got %d. Skip and " + "fallback to default implementation.", + group_size, + ) + return _fallback() + if d not in [128]: + global WARN_FLASHINFER_HEAD_DIM # pylint: disable=global-statement + if not WARN_FLASHINFER_HEAD_DIM: + WARN_FLASHINFER_HEAD_DIM = True + logger.warning( + "FlashInfer only supports head_dim in [128], but got %d. 
Skip and fallback to " + "default implementation.", + d, + ) + return _fallback() + rope_theta = 0.0 + rope_scale = 1.0 + qkv_layout = 0 # "NHD", N for seq_len, H for num_heads, D for head_dim + rotary_mode = 0 # "kNone" + casual = 1 # True + fp16_qk = 1 # True + if qk_dtype == "float32": + fp16_qk = 0 # False + + # 32MB scratchpad + scratch = op.empty([8192 * 1024], dtype="float32") # pylint: disable=no-member + + def _decode(): + return op.extern( + name="flashinfer.single_decode", + args=[ + q, + k, + v, + scratch, + qkv_layout, + rotary_mode, + rope_scale, + rope_theta, + ], + out=nn.Tensor.placeholder((b, s, h_q * d), dtype="float16"), + ) + + def _prefill(): + return op.extern( + name="flashinfer.single_prefill", + args=[ + q, + k, + v, + scratch, + casual, + qkv_layout, + rotary_mode, + fp16_qk, + rope_scale, + rope_theta, + ], + out=nn.Tensor.placeholder((b, s, h_q * d), dtype="float16"), + ) + + if isinstance(s, int) and s == 1: + func = "decode" + else: + func = "prefill" + return { + "decode": _decode, + "prefill": _prefill, + }[func]() + + # Fallback Implementation + return _fallback() diff --git a/python/mlc_chat/op/extern.py b/python/mlc_chat/op/extern.py new file mode 100644 index 0000000..5fa7e82 --- /dev/null +++ b/python/mlc_chat/op/extern.py @@ -0,0 +1,65 @@ +"""Potential externel modules managed by MLC compilation stack. + +An externl module could contain one or multiple handcrafted kernels, as long as it is provided as +an object file (`.o`), a C++ source file (`.cc`), or a CUDA source file (`.cu`). It can be +integrated into the system pretty smoothly. + +As examples, `flashinfer.py` contains such an example that instructs MLC to compile +"$tvm_home/3rdparty/flashinfer/src/tvm_wrapper.cu" with a specific set of compilation flags and then +link into the generated artifact of MLC LLM. TVM PR #16247 +(https://github.com/apache/tvm/pull/16247/) provides more details of using TVM's +`nn.SourceModule` to integrate C++ and CUDA files, and `nn.ObjectModule` to integrate object files. + +To conveniently use those externel modules, MLC LLM compilation pipeline manages an extra global +singleton `Store: ExternalModuleStore` to store the configured modules. It is supposed to be enabled +before any compilation happens, and configured during a model's `forward` method is invoked. +""" +import dataclasses +from typing import Optional + +from tvm.target import Target + + +@dataclasses.dataclass +class ExternModuleStore: + """Global store of external modules enabled during compilation.""" + + configured: bool = False + target: Optional[Target] = None + flashinfer: bool = False + faster_transformer: bool = False + + +STORE: ExternModuleStore = ExternModuleStore() +"""Singleton of `ExternModuleStore`.""" + + +def enable(target: Target, flashinfer: bool, faster_transformer: bool) -> None: + """Enable external modules. It should be called before any compilation happens.""" + global STORE # pylint: disable=global-statement + STORE = ExternModuleStore( + configured=False, + target=target, + flashinfer=flashinfer, + faster_transformer=faster_transformer, + ) + + +def get_store() -> ExternModuleStore: + """Get the global store of external modules.""" + return STORE + + +def configure() -> None: + """Configure external modules with extra parameters. It should be called during a model's + `forward` method is invoked. 
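+
+    A rough usage sketch (call sites assumed for illustration, not prescribed by
+    this change), assuming a CUDA `target`:
+
+    .. code-block:: python
+
+        from mlc_chat.op import extern
+
+        extern.enable(target, flashinfer=True, faster_transformer=False)
+        # ... later, while the model's `forward` method is being traced:
+        extern.configure()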
+ + Parameters + ---------- + """ + store = get_store() + if store.configured: + return + store.configured = True + if store.flashinfer or store.faster_transformer: + assert store.target.kind.name == "cuda" diff --git a/python/mlc_chat/op/ft_gemm.py b/python/mlc_chat/op/ft_gemm.py new file mode 100644 index 0000000..0a4edc6 --- /dev/null +++ b/python/mlc_chat/op/ft_gemm.py @@ -0,0 +1,135 @@ +"""Operators enabled by external modules.""" +import operator +from functools import reduce +from typing import Optional + +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import op + + +def faster_transformer_dequantize_gemm( # pylint: disable=too-many-arguments + x: nn.Tensor, + weight: nn.Tensor, + scale: nn.Tensor, + bias: Optional[nn.Tensor] = None, + activation: Optional[str] = None, + group_size: Optional[int] = None, +): + """ + Faster Transformer dequantize gemm inference with CutlassFpAIntB + + Parameters + ---------- + x : nn.Tensor + The input tensor, with shape of [*m, k]. + + weight : nn.Tensor + The quantized weight data tensor, with shape of [k, n // num_elem_per_storage]. + + scale : nn.Tensor + The quantized weight scale tensor, with shape of [k // group_size, n]. + + bias : Optional[nn.Tensor] + The optional bias for matmul, with shape broadcastable to [*m, n]. + + group_size : Optional[int] + The optional group size. If not set, then using k as group size. + + Returns + ------ + ret: nn.Tensor + The output tensor of deocde matmul, with shape of [*m, n]. + """ + assert x.dtype == "float16" and x.ndim >= 1 + assert weight.ndim == 2 + assert scale.dtype == "float16" and scale.ndim == 2 + assert x.shape[-1] == weight.shape[0], ( + "Reduction dimension mismatched between x and weight, " + f"{x.shape[-1]} vs {weight.shape[0]}." + ) + assert activation in [ + None, + "relu", + "gelu", + "silu", + "identity", + ], "Supported activations are [None, 'identity', 'gelu', 'silu', 'relu']." + activation = activation if activation else "identity" + m = reduce(operator.mul, x.shape[:-1], 1) + k = x.shape[-1] + n = scale.shape[1] + + if not group_size: + group_size = k + + if bias: + assert bias.dtype == "float16" and bias.ndim >= 1 + bias_stride = ( + bias.shape[-1] + if bias and not reduce(operator.mul, bias.shape, 1) == bias.shape[-1] + else 0 + ) + return op.extern( + name="fastertransformer.gemm_fp16_int_bias", + args=[ + x, + weight, + scale, + bias, + activation, + m, + n, + k, + group_size, + bias_stride, + ], + out=nn.Tensor.placeholder((*x.shape[:-1], scale.shape[1]), dtype="float16"), + ) + return op.extern( + name="fastertransformer.gemm_fp16_int", + args=[x, weight, scale, activation, m, n, k, group_size], + out=nn.Tensor.placeholder((*x.shape[:-1], scale.shape[1]), dtype="float16"), + ) + + +def faster_transformer_moe_gemm( # pylint: disable=too-many-arguments + x: nn.Tensor, + weight: nn.Tensor, + total_rows_before: nn.Tensor, +): + """ + Faster Transformer moe gemm inference with CutlassFpAIntB + + Parameters + ---------- + x : nn.Tensor + The input tensor, with shape of [*m, k]. + + weight : nn.Tensor + The weight data tensor, with shape of [num_experts, n, k]. + + total_rows_before : nn.Tensor + The total rows before tensor the current expert, with shape of [num_experts]. This is the + same as the indptr excluding the first zero element. + + Returns + ------ + ret: nn.Tensor + The output tensor of deocde matmul, with shape of [*m, n]. 
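+
+    For illustration only (values assumed, not taken from this change): if three
+    experts receive 2, 0, and 3 rows of `x` respectively, then
+
+    .. code-block:: python
+
+        rows_per_expert   = [2, 0, 3]
+        indptr            = [0, 2, 2, 5]
+        total_rows_before = [2, 2, 5]  # indptr without the leading zero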
+ """ + assert x.dtype == "float16" and x.ndim >= 1 + assert weight.dtype == "float16" and weight.ndim == 3 + assert x.shape[-1] == weight.shape[-1], ( + "Reduction dimension mismatched between x and weight, " + f"{x.shape[-1]} vs {weight.shape[-1]}." + ) + m = reduce(operator.mul, x.shape[:-1], 1) + num_experts = weight.shape[0] + n = weight.shape[1] + k = x.shape[-1] + + return op.extern( + name="fastertransformer.moe_gemm_fp16_fp16", + args=[x, weight, total_rows_before, m, n, k, num_experts], + out=nn.Tensor.placeholder((*x.shape[:-1], n), dtype="float16"), + ) diff --git a/python/mlc_chat/op/moe_matmul.py b/python/mlc_chat/op/moe_matmul.py new file mode 100644 index 0000000..169140a --- /dev/null +++ b/python/mlc_chat/op/moe_matmul.py @@ -0,0 +1,555 @@ +"""Mixture of Experts operators""" + +from tvm import DataType, tir +from tvm.relax.frontend.nn import Tensor, op +from tvm.script import tir as T + +# mypy: disable-error-code="attr-defined,valid-type,name-defined" +# pylint: disable=too-many-locals,invalid-name,too-many-arguments,too-many-statements + + +def gemv(x: Tensor, w: Tensor, indptr: Tensor) -> Tensor: + """GEMV for project-in (e1-e3) or project-out (e2) in MLP. + + Parameters + ---------- + x : Tensor + For project-in, the input tensor of shape (1, in_features); and for project-out, the input + shape is (experts_per_tok, in_features), where `experts_per_tok` is the number of activated + experts per token. + + w : Tensor + The weight tensor of shape (local_experts, out_features, in_features), where `local_experts` + is the total number of experts. + + indptr : Tensor + The index pointer tensor of shape (1, experts_per_tok), where `experts_per_tok` is the + number of activated experts per token. + + Returns + ------- + out : Tensor + The output tensor of shape (experts_per_tok, out_features), where `experts_per_tok` is the + number of activated experts per token. + """ + (local_experts, out_features, in_features), dtype = w.shape, w.dtype + _, experts_per_tok = indptr.shape + x_leading_dim, _ = x.shape + + def access_x(x, e, j): + return x[0, j] if x_leading_dim == 1 else x[e, j] + + # NOTE: Currently it assumes x.dtype == w.dtype, but the constraint can be relaxed easily. 
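+    # Shape summary for the checks below: w is (local_experts, out_features,
+    # in_features); x is (1, in_features) for project-in or
+    # (experts_per_tok, in_features) for project-out; indptr[0, e] holds the id
+    # of the expert whose weight slice produces output row e.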
+ assert w.shape == [local_experts, out_features, in_features] and w.dtype == dtype + assert x.shape == [x_leading_dim, in_features] and x.dtype == dtype + assert indptr.shape == [1, experts_per_tok] and indptr.dtype == "int32" + assert x_leading_dim in [1, experts_per_tok] + + @T.prim_func(private=True) + def _func( + x: T.Buffer((x_leading_dim, in_features), dtype), + w: T.Buffer((local_experts, out_features, in_features), dtype), + indptr: T.Buffer((1, experts_per_tok), "int32"), + o: T.Buffer((experts_per_tok, out_features), dtype), + ): + T.func_attr({"op_pattern": 4, "tir.noalias": True}) # kOutEWiseFusable + for e in T.thread_binding(experts_per_tok, thread="blockIdx.y"): + with T.block("gemv_o"): + e = T.axis.spatial(experts_per_tok, e) + T.reads(x[:, :], w[indptr[0, e], :, :], indptr[0, e]) + T.writes(o[e, :]) + for i1, i2 in T.grid(out_features, in_features): + with T.block("gemv"): + i, j = T.axis.remap("SR", [i1, i2]) + with T.init(): + o[e, i] = T.cast(T.float16(0), dtype) + o[e, i] += access_x(x, e, j) * w[indptr[0, e], i, j] + + return op.tensor_ir_op( + _func, + "moe_gemv", + args=[x, w, indptr], + out=Tensor.placeholder([experts_per_tok, out_features], dtype), + ) + + +def dequantize_gemv( # pylint: disable=too-many-arguments + x: Tensor, + w: Tensor, + scale: Tensor, + indptr: Tensor, + quantize_dtype: str, + group_size: int, +) -> Tensor: + """GEMV for project-in (e1-e3) or project-out (e2) in MLP but the weight is quantized. + It needs to be dequantized before the GEMV computation. + + Parameters + ---------- + x : Tensor + For project-in, the input tensor of shape (1, in_features); and for project-out, the input + shape is (experts_per_tok, in_features), where `experts_per_tok` is the number of activated + experts per token. + + w : Tensor + The quantized weight tensor of shape (local_experts, out_features, in_features // n), + where n is the number of elements per storage dtype, e.g. if the storage dtype is uint32, + and the quantize dtype is int4, then n is 8. + `local_experts` is the total number of experts including activated and non-active ones. + + scale : Tensor + The scale tensor of shape (local_experts, out_features, in_features // group_size), where + `local_experts` is the total number of experts including activated and non-active ones. + + indptr : Tensor + The index pointer tensor of shape (1, experts_per_tok), where `experts_per_tok` is the + number of activated experts per token. + + quantize_dtype : str + The quantize dtype of the weight tensor, which is usually int3, int4 or fp8, etc. + + group_size : int + The number of elements in each quantization group, e.g. 32 or 128. + + Returns + ------- + out : Tensor + The output tensor of shape (experts_per_tok, out_features), where `experts_per_tok` is the + number of activated experts per token. 
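+
+    For reference, a NumPy-style sketch of the per-element dequantization rule
+    (assuming int4 values packed into uint32 storage, i.e. 8 elements per
+    storage word, and a group size of 32; it mirrors the TIR helper in the
+    implementation and is not a separate code path):
+
+    .. code-block:: python
+
+        def dequantize(w, scale, e, i, j, bits=4, elems_per_word=8, group_size=32):
+            word = int(w[e, i, j // elems_per_word])
+            q = (word >> ((j % elems_per_word) * bits)) & ((1 << bits) - 1)
+            return (q - (2 ** (bits - 1) - 1)) * scale[e, i, j // group_size]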
+ """ + (x_leading_dim, in_features), model_dtype = x.shape, x.dtype + (local_experts, out_features, _), storage_dtype = w.shape, w.dtype + _, experts_per_tok = indptr.shape + quantize_dtype_bits = DataType(quantize_dtype).bits + num_elem_per_storage = DataType(storage_dtype).bits // quantize_dtype_bits + num_group = (in_features + group_size - 1) // group_size + num_storage = group_size // num_elem_per_storage * num_group + + def _dequantize(w, s, e, i, j): + tir_bin_mask = tir.const((2**quantize_dtype_bits) - 1, storage_dtype) + tir_max_int = tir.const((2 ** (quantize_dtype_bits - 1)) - 1, model_dtype) + w = w[e, i, j // num_elem_per_storage] + s = s[e, i, j // group_size] + shift = (j % num_elem_per_storage * quantize_dtype_bits).astype(storage_dtype) + w = tir.bitwise_and(tir.shift_right(w, shift), tir_bin_mask).astype(model_dtype) + return (w - tir_max_int) * s + + def access_x(x, e, j): + return x[0, j] if x_leading_dim == 1 else x[e, j] + + assert x.shape == [x_leading_dim, in_features] and x.dtype == model_dtype + assert w.shape == [local_experts, out_features, num_storage] and w.dtype == storage_dtype + assert scale.shape == [local_experts, out_features, num_group] and scale.dtype == model_dtype + assert indptr.shape == [1, experts_per_tok] and indptr.dtype == "int32" + assert x_leading_dim in [1, experts_per_tok] + + @T.prim_func(private=True) + def _func( + x: T.Buffer((x_leading_dim, in_features), model_dtype), + w: T.Buffer((local_experts, out_features, num_storage), storage_dtype), + scale: T.Buffer((local_experts, out_features, num_group), model_dtype), + indptr: T.Buffer((1, experts_per_tok), "int32"), + o: T.Buffer((experts_per_tok, out_features), model_dtype), + ): + T.func_attr({"op_pattern": 4, "tir.noalias": True}) # kOutEWiseFusable + for expert_id in T.thread_binding(experts_per_tok, thread="blockIdx.y"): + with T.block("gemv_o"): + e = T.axis.spatial(experts_per_tok, expert_id) + y = T.alloc_buffer((out_features, in_features), model_dtype) + for i1, i2 in T.grid(out_features, in_features): + with T.block("dequantize"): + i, j = T.axis.remap("SS", [i1, i2]) + y[i, j] = _dequantize(w, scale, indptr[0, e], i, j) + for i1, i2 in T.grid(out_features, in_features): + with T.block("gemv"): + i, j = T.axis.remap("SR", [i1, i2]) + with T.init(): + o[e, i] = T.cast(T.float16(0), model_dtype) + o[e, i] += access_x(x, e, j) * y[i, j] + + return op.tensor_ir_op( + _func, + "moe_dequantize_gemv", + args=[x, w, scale, indptr], + out=Tensor.placeholder([experts_per_tok, out_features], model_dtype), + ) + + +def group_gemm(x: Tensor, w: Tensor, indptr: Tensor): # pylint: disable=too-many-statements + """Group GEMM in MoE models. + + Parameters + ---------- + x : Tensor + Input tensor of shape (batch_size, in_features), where `batch_size` could be dynamic shape. + + w : Tensor + Weight tensor of shape (num_local_experts, out_features, in_features). + `w[i, :, :]` is the weight matrix for the `i`-th local expert. + + indptr : Tensor + Index pointer tensor of shape (num_local_experts + 1, ). + `x[indptr[a] : indptr[a + 1]]` is the input for the `i`-th local expert. + + Returns + ------- + out : Tensor + Output tensor of shape (batch_size, out_features). + """ + # NOTE: Currently it assumes x.dtype == w.dtype, but the constraint can be relaxed easily. 
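+    # The TIR kernel below uses a persistent-block scheme: CTA_COUNT thread
+    # blocks step through (expert, tile) pairs, where expert e owns rows
+    # x[indptr[e] : indptr[e + 1]] and its output is tiled into BLK_M x BLK_N
+    # blocks, reduced over K in chunks of BLK_K.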
+ (num_local_experts, out_features, in_features), dtype = w.shape, w.dtype + + assert x.shape[1:] == [in_features] and x.dtype == dtype + assert indptr.shape == [num_local_experts + 1] and indptr.dtype == "int32" + + Ne, N, K = num_local_experts, out_features, in_features + BLK_M, BLK_N, BLK_K = 8, 128, 32 + TX, TY, CTA_COUNT = 8, 32, 1024 + VEC_X, VEC_W, VEC_O, VEC_DOT = 1, 1, 1, 1 + UNROLL = 64 + STORAGE_ALIGN = False + assert BLK_K % 8 == 0 + tiles_per_row = (N + BLK_N - 1) // BLK_N + zero = tir.const(0, dtype) + + @T.prim_func(private=True) + def _func( # pylint: disable=too-many-statements + var_x: T.handle, + var_w: T.handle, + var_indptr: T.handle, + var_o: T.handle, + ): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": True}) + B = T.int32(is_size_var=True) + X = T.match_buffer(var_x, (B, K), dtype) + W = T.match_buffer(var_w, (Ne, N, K), dtype) + indptr = T.match_buffer(var_indptr, (Ne + 1,), "int32") + O = T.match_buffer(var_o, (B, N), dtype) + + for _bx in T.thread_binding(CTA_COUNT, thread="blockIdx.x"): + with T.block("CTA"): + bx = T.axis.spatial(CTA_COUNT, _bx) + T.reads(indptr[:], X[:, :], W[:, :, :]) + T.writes(O[:, :]) + # pylint: disable=redefined-builtin + sum = T.alloc_buffer((2,), "int32", scope="local") + row = T.alloc_buffer((2,), "int32", scope="local") + cur_e = T.alloc_buffer((1,), "int32", scope="local") + tile_id = T.alloc_buffer((1,), "int32", scope="local") + # pylint: enable=redefined-builtin + sum[0] = 0 + sum[1] = T.ceildiv(indptr[1] - indptr[0], BLK_M) * tiles_per_row + row[0] = 0 + row[1] = indptr[1] - indptr[0] + cur_e[0] = 0 + tile_id[0] = bx + while T.tvm_thread_invariant(cur_e[0] < Ne): # pylint: disable=no-member + # move to the current group + while sum[1] <= tile_id[0] and cur_e[0] < Ne: + cur_e[0] += 1 + if cur_e[0] < Ne: + e: T.int32 = cur_e[0] + delta: T.int32 = indptr[e + 1] - indptr[e] + sum[0] = sum[1] + sum[1] += T.ceildiv(delta, BLK_M) * tiles_per_row + row[0] = row[1] + row[1] += delta + # sync threads to make sure all threads have the same tile position + T.tvm_storage_sync("shared") + if T.tvm_thread_invariant(cur_e[0] < Ne): # pylint: disable=no-member + # fetch current tile position + e: T.int32 = cur_e[0] # type: ignore[no-redef] + num_tiles: T.int32 = tile_id[0] - sum[0] + m_offset: T.int32 = BLK_M * T.floordiv(num_tiles, tiles_per_row) + row[0] + n_offset: T.int32 = BLK_N * T.floormod(num_tiles, tiles_per_row) + with T.block("gemm"): + T.reads( + row[1], + X[m_offset : m_offset + BLK_M, :], + W[e, n_offset : n_offset + BLK_N, :], + ) + T.writes(O[m_offset : m_offset + BLK_M, n_offset : n_offset + BLK_N]) + X_tile = T.alloc_buffer((BLK_M, K), dtype, scope="shared") + W_tile = T.alloc_buffer((BLK_N, K), dtype, scope="shared") + O_tile = T.alloc_buffer((BLK_M, BLK_N), dtype, scope="local") + for a0, a1 in T.grid(BLK_M, K): + with T.block("X_shared"): + i, j = T.axis.remap("SS", [a0, a1]) + X_tile[i, j] = T.if_then_else( + m_offset + i < row[1], + X[m_offset + i, j], + zero, + ) + for a0, a1 in T.grid(BLK_N, K): + with T.block("W_shared"): + i, j = T.axis.remap("SS", [a0, a1]) + W_tile[i, j] = T.if_then_else( + n_offset + i < N, + W[e, n_offset + i, j], + zero, + ) + for a0, a1, a2 in T.grid(BLK_M, BLK_N, K): + with T.block("compute"): + i, j, k = T.axis.remap("SSR", [a0, a1, a2]) + with T.init(): + O_tile[i, j] = zero + O_tile[i, j] += X_tile[i, k] * W_tile[j, k] + for a0, a1 in T.grid(BLK_M, BLK_N): + with T.block("store"): + i, j = T.axis.remap("SS", [a0, a1]) + if m_offset + i < row[1] and n_offset + j < N: + O[m_offset + i, 
n_offset + j] = O_tile[i, j] + # move to next tile + tile_id[0] += CTA_COUNT + + def _schedule(): + sch = tir.Schedule(_func) + + def _cooperative_fetch(block, vec_len): + num_loops = len(sch.get_loops(block)) + sch.compute_at(block, ko, preserve_unit_loops=True) + loops = sch.get_loops(block)[-num_loops:] + ty, tx, _, vec = sch.split( + sch.fuse(*loops), + factors=[TY, TX, None, vec_len], + ) + sch.vectorize(vec) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + if STORAGE_ALIGN: + sch.storage_align(block, 0, axis=1, factor=8, offset=vec_len) + return block + + main_block = sch.get_block("compute") + x, y, k = sch.get_loops(main_block) + ty, yi = sch.split(y, [TY, None]) + tx, xi, vec_c = sch.split(x, [TX, None, VEC_DOT]) + ko, ki = sch.split(k, factors=[None, BLK_K]) + sch.reorder(ty, tx, ko, ki, yi, xi, vec_c) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec_c) + if UNROLL > 0: + sch.annotate(tx, ann_key="pragma_auto_unroll_max_step", ann_val=UNROLL) + sch.annotate(tx, ann_key="pragma_unroll_explicit", ann_val=1) + l2g = sch.get_block("store") + sch.reverse_compute_at(l2g, tx, preserve_unit_loops=True) + _, v = sch.split(sch.get_loops(l2g)[-1], [None, VEC_O]) + sch.vectorize(v) + _cooperative_fetch(sch.get_block("X_shared"), vec_len=VEC_X) + _cooperative_fetch(sch.get_block("W_shared"), vec_len=VEC_W) + sch.decompose_reduction(main_block, ko) + return sch.mod["main"] + + return op.tensor_ir_op( + _schedule(), + "group_gemm", + args=[x, w, indptr], + out=Tensor.placeholder([x.shape[0], out_features], dtype), + ) + + +def dequantize_group_gemm( + x: Tensor, + w: Tensor, + scale: Tensor, + indptr: Tensor, + quantize_dtype: str, + indptr_dtype: str, + group_size: int, +): + """Group GEMM in MoE models but the weight is quantized. + + Parameters + ---------- + x : Tensor + Input tensor of shape (batch_size, in_features), where `batch_size` could be dynamic shape. + + w : Tensor + Weight tensor of shape (num_local_experts, out_features, in_features // n), where n is the + number of elements per storage dtype, e.g. if the storage dtype is uint32, and the quantize + dtype is int4, then n is 8. + + scale : Tensor + The scale tensor of shape (num_local_experts, out_features, in_features // group_size). + + indptr : Tensor + Index pointer tensor of shape (num_local_experts + 1, ). `x[indptr[a] : indptr[a + 1]]` is + the input for the `i`-th local expert. + + group_size : int + The number of elements in each quantization group, e.g. 32 or 128. + + quantize_dtype : str + The quantize dtype of the weight tensor, which is usually int3, int4 or fp8, etc. + + indptr_dtype : str + The dtype of the index pointer tensor, which can be int32 or int64. + + Returns + ------- + out : Tensor + Output tensor of shape (batch_size, out_features). 
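+
+    For illustration only (values assumed): with three local experts and
+    `indptr = [0, 2, 2, 5]`, expert 0 consumes `x[0:2]`, expert 1 consumes no
+    rows, and expert 2 consumes `x[2:5]`. When `indptr_dtype` is "int64", the
+    implementation pads a leading zero onto `indptr` before use.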
+ """ + (_, in_features), model_dtype = x.shape, x.dtype + (num_local_experts, out_features, _), storage_dtype = w.shape, w.dtype + quantize_dtype_bits = DataType(quantize_dtype).bits + num_elem_per_storage = DataType(storage_dtype).bits // quantize_dtype_bits + num_group = (in_features + group_size - 1) // group_size + num_storage = group_size // num_elem_per_storage * num_group + + def _dequantize(w, s, e, i, j): + tir_bin_mask = tir.const((1 << quantize_dtype_bits) - 1, storage_dtype) + tir_max_int = tir.const((2 ** (quantize_dtype_bits - 1)) - 1, model_dtype) + w = w[e, i, j // num_elem_per_storage] + s = s[e, i, j // group_size] + shift = (j % num_elem_per_storage * quantize_dtype_bits).astype(storage_dtype) + w = tir.bitwise_and(tir.shift_right(w, shift), tir_bin_mask).astype(model_dtype) + return (w - tir_max_int) * s + + Ne, N, K = num_local_experts, out_features, in_features + BLK_M, BLK_N, BLK_K = 8, 128, 32 + TX, TY, CTA_COUNT = 8, 32, 1024 + VEC_X, VEC_W, VEC_O, VEC_DOT = 1, 1, 1, 1 + UNROLL = 64 + STORAGE_ALIGN = False + assert BLK_K % 8 == 0 + tiles_per_row = (N + BLK_N - 1) // BLK_N + zero = tir.const(0, model_dtype) + if indptr_dtype == "int64": + indptr = op.pad(indptr, [1, 0], "constant", 0) + + @T.prim_func(private=True) + def _func( + var_x: T.handle, + w: T.Buffer((Ne, N, num_storage), storage_dtype), + scale: T.Buffer((Ne, N, num_group), model_dtype), + indptr: T.Buffer((Ne + 1,), indptr_dtype), + var_o: T.handle, + ): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": True}) + B = T.int32(is_size_var=True) + X = T.match_buffer(var_x, (B, K), model_dtype) + O = T.match_buffer(var_o, (B, N), model_dtype) + for _bx in T.thread_binding(CTA_COUNT, thread="blockIdx.x"): + with T.block("CTA"): + bx = T.axis.spatial(CTA_COUNT, _bx) + T.reads(X[:, :], w[:, :, :], scale[:, :, :], indptr[:]) + T.writes(O[:, :]) + # pylint: disable=redefined-builtin + sum = T.alloc_buffer((2,), indptr_dtype, scope="local") + row = T.alloc_buffer((2,), indptr_dtype, scope="local") + cur_e = T.alloc_buffer((1,), indptr_dtype, scope="local") + tile_id = T.alloc_buffer((1,), indptr_dtype, scope="local") + # pylint: enable=redefined-builtin + sum[0] = 0 + sum[1] = T.ceildiv(indptr[1] - indptr[0], BLK_M) * tiles_per_row + row[0] = 0 + row[1] = indptr[1] - indptr[0] + cur_e[0] = 0 + tile_id[0] = bx + while T.tvm_thread_invariant(cur_e[0] < Ne): # pylint: disable=no-member + # move to the current group + while sum[1] <= tile_id[0] and cur_e[0] < Ne: + cur_e[0] += 1 + if cur_e[0] < Ne: + e = cur_e[0] + delta = indptr[e + 1] - indptr[e] + sum[0] = sum[1] + sum[1] += T.ceildiv(delta, BLK_M) * tiles_per_row + row[0] = row[1] + row[1] += delta + # sync threads to make sure all threads have the same tile position + T.tvm_storage_sync("shared") + if T.tvm_thread_invariant(cur_e[0] < Ne): # pylint: disable=no-member + # fetch current tile position + e = cur_e[0] # type: ignore[no-redef] + num_tiles = tile_id[0] - sum[0] + m_offset = T.floordiv(num_tiles, tiles_per_row) * BLK_M + row[0] + n_offset = T.floormod(num_tiles, tiles_per_row) * BLK_N + with T.block("gemm"): + T.reads( + row[1], + X[m_offset : m_offset + BLK_M, :], + w[e, n_offset : n_offset + BLK_N, :], + scale[e, n_offset : n_offset + BLK_N, :], + ) + T.writes(O[m_offset : m_offset + BLK_M, n_offset : n_offset + BLK_N]) + X_tile = T.alloc_buffer((BLK_M, K), model_dtype, scope="shared") + W_tile = T.alloc_buffer((BLK_N, K), model_dtype, scope="shared") + O_tile = T.alloc_buffer((BLK_M, BLK_N), "float32", scope="local") + for a0, a1 in T.grid(BLK_M, 
K): + with T.block("X_shared"): + i, j = T.axis.remap("SS", [a0, a1]) + X_tile[i, j] = T.if_then_else( + m_offset + i < row[1], + X[m_offset + i, j], + zero, + ) + for a0, a1 in T.grid(BLK_N, K): + with T.block("W_shared"): + i, j = T.axis.remap("SS", [a0, a1]) + W_tile[i, j] = T.if_then_else( + n_offset + i < N, + _dequantize(w, scale, e, n_offset + i, j), + zero, + ) + for a0, a1, a2 in T.grid(BLK_M, BLK_N, K): + with T.block("compute"): + i, j, k = T.axis.remap("SSR", [a0, a1, a2]) + with T.init(): + O_tile[i, j] = zero + O_tile[i, j] += X_tile[i, k] * W_tile[j, k] + for a0, a1 in T.grid(BLK_M, BLK_N): + with T.block("store"): + i, j = T.axis.remap("SS", [a0, a1]) + if m_offset + i < row[1] and n_offset + j < N: + O[m_offset + i, n_offset + j] = O_tile[i, j] + # move to next tile + tile_id[0] += CTA_COUNT + + def _schedule(): + sch = tir.Schedule(_func) + + def _cooperative_fetch(block, vec_len): + num_loops = len(sch.get_loops(block)) + sch.compute_at(block, ko, preserve_unit_loops=True) + loops = sch.get_loops(block)[-num_loops:] + ty, tx, _, vec = sch.split( + sch.fuse(*loops), + factors=[TY, TX, None, vec_len], + ) + sch.vectorize(vec) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + if STORAGE_ALIGN: + sch.storage_align(block, 0, axis=1, factor=8, offset=vec_len) + return block + + main_block = sch.get_block("compute") + x, y, k = sch.get_loops(main_block) + ty, yi = sch.split(y, [TY, None]) + tx, xi, vec_c = sch.split(x, [TX, None, VEC_DOT]) + ko, ki = sch.split(k, factors=[None, BLK_K]) + sch.reorder(ty, tx, ko, ki, yi, xi, vec_c) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec_c) + if UNROLL > 0: + sch.annotate(tx, ann_key="pragma_auto_unroll_max_step", ann_val=UNROLL) + sch.annotate(tx, ann_key="pragma_unroll_explicit", ann_val=1) + l2g = sch.get_block("store") + sch.reverse_compute_at(l2g, tx, preserve_unit_loops=True) + _, v = sch.split(sch.get_loops(l2g)[-1], [None, VEC_O]) + sch.vectorize(v) + _cooperative_fetch(sch.get_block("X_shared"), vec_len=VEC_X) + _cooperative_fetch(sch.get_block("W_shared"), vec_len=VEC_W) + sch.decompose_reduction(main_block, ko) + return sch.mod["main"] + + return op.tensor_ir_op( + _schedule(), + "dequantize_group_gemm", + args=[x, w, scale, indptr], + out=Tensor.placeholder([x.shape[0], out_features], model_dtype), + ) diff --git a/python/mlc_chat/op/moe_misc.py b/python/mlc_chat/op/moe_misc.py new file mode 100644 index 0000000..e97ef94 --- /dev/null +++ b/python/mlc_chat/op/moe_misc.py @@ -0,0 +1,408 @@ +"""Mixture of Experts operators""" +from functools import reduce +from typing import Tuple, Union + +from tvm import te, tir +from tvm.relax.frontend.nn import Tensor, op +from tvm.script import tir as T +from tvm.target import Target +from tvm.topi.cuda.scan import inclusive_scan +from tvm.topi.cuda.sort import topk as topi_topk + +# mypy: disable-error-code="attr-defined,name-defined" +# pylint: disable=line-too-long,too-many-locals,invalid-name + + +def moe_sum(x: Tensor, dim: int) -> Tensor: + """Compute the sum of the input tensor along the given axis. It is specialized for the MoE + case where `x.ndim == 3` and `x.shape[1] == num_experts_per_tok (which is 2)`. 
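+
+    For the specialized two-expert case this is equivalent to the NumPy-style
+    expression below; other shapes simply fall back to `op.sum`:
+
+    .. code-block:: python
+
+        out = x[:, 0, :] + x[:, 1, :]  # shape [batch_size, hidden_size]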
+ """ + if x.ndim == 3 and x.shape[1] == 2: + return op.tensor_expr_op( + lambda x: te.compute( + (x.shape[0], x.shape[2]), + lambda i, j: x[i, 0, j] + x[i, 1, j], + name="sum_2", + ), + "sum", + args=[x], + ) + return op.sum(x, axis=dim) + + +def gating_softmax_topk(x: Tensor, k: int) -> Tuple[Tensor, Tensor]: + """Compute the softmax score, choose the top-k experts, and renormalize the selected scores. + + Parameters + ---------- + x : Tensor + The input tensor with shape [batch_size, num_local_experts]. + + k : int + The number of top elements to be selected, which is `num_experts_per_tok` in MoE. + + Returns + ------- + expert_weights: Tensor + The renormalized top-k expert scores with shape [batch_size, k]. + + expert_indices: Tensor + The top-k expert indices with shape [batch_size, k]. + """ + (batch_size, num_local_experts), dtype = x.shape, x.dtype + index_dtype = "int32" + + TX = 1024 + SCAN_LEN = 2 + + # specialized kernel for top 2 case + @T.prim_func(private=True) + def topk_softmax_func( + var_x: T.handle, + var_out: T.handle, + var_out_index: T.handle, + ) -> None: + T.func_attr({"tir.noalias": True, "tir.is_scheduled": True}) + batch_size = T.int64() + x = T.match_buffer(var_x, (batch_size, num_local_experts), dtype) + out = T.match_buffer(var_out, (batch_size, SCAN_LEN), dtype) + out_index = T.match_buffer(var_out_index, (batch_size, SCAN_LEN), index_dtype) + local_top_k = T.alloc_buffer((SCAN_LEN,), dtype=dtype, scope="local") + local_top_k_index = T.alloc_buffer((SCAN_LEN,), dtype=index_dtype, scope="local") + local_top_k_f32 = T.alloc_buffer((SCAN_LEN,), dtype="float32", scope="local") + local_top_k_max = T.alloc_buffer((1,), dtype="float32", scope="local") + for io in T.thread_binding(0, T.ceildiv(batch_size, TX), "blockIdx.x"): + for ii in T.thread_binding(0, TX, "threadIdx.x"): + with T.block("top_k"): + vi = T.axis.spatial(batch_size, io * TX + ii) + T.where(io * TX + ii < batch_size) + with T.block("init"): + local_top_k[0] = T.min_value(dtype) + local_top_k_index[0] = 0 + for k in range(num_local_experts): + with T.block("update"): + vk = T.axis.remap("S", [k]) + # N.B. 
This snippet is specialized for k = 2 + if x[vi, vk] > local_top_k[0]: + local_top_k[1] = local_top_k[0] + local_top_k_index[1] = local_top_k_index[0] + local_top_k[0] = x[vi, vk] + local_top_k_index[0] = vk + elif x[vi, vk] > local_top_k[1]: + local_top_k[1] = x[vi, vk] + local_top_k_index[1] = vk + for j in T.unroll(SCAN_LEN): + with T.block("cast"): + vj = T.axis.remap("S", [j]) + local_top_k_f32[vj] = T.cast(local_top_k[vj], "float32") + with T.block("max"): + local_top_k_max[0] = T.max(local_top_k_f32[0], local_top_k_f32[1]) + for j in T.unroll(SCAN_LEN): + with T.block("output"): + vj = T.axis.remap("S", [j]) + out[vi, vj] = T.cast( + T.exp(local_top_k_f32[j] - local_top_k_max[0]) + / ( + T.exp(local_top_k_f32[0] - local_top_k_max[0]) + + T.exp(local_top_k_f32[1] - local_top_k_max[0]) + ), + dtype, + ) + out_index[vi, vj] = local_top_k_index[vj] + + if k == 2: + return op.tensor_ir_op( + topk_softmax_func, + "top2_softmax", + args=[x], + out=( + Tensor.placeholder([batch_size, 2], dtype), + Tensor.placeholder([batch_size, 2], index_dtype), + ), + ) + expert_score, expert_indices = op.tensor_expr_op(topi_topk, "topk", args=[x, k, -1, "both", False, index_dtype]) # type: ignore[list-item] + expert_score = op.softmax(expert_score.astype("float32"), axis=-1).astype(dtype) + return expert_score, expert_indices + + +def moe_cumsum(expert_indices: Tensor, num_local_experts: int) -> Tensor: + """An operator that returns the cumsum array in MoE. + + The input `expert_indices` of shape [batch_size, experts_per_tok] indicates the indices of + the activated experts for each instance in a batch. This operator first converts it to + `expert_mask`, a boolean mask with shape [batch_size, num_local_experts], and then computes + cumsum over the transpose-then-flattened array of `expert_mask`. + + A position `(e, b)` in the result `cumsum`, where `e` is the expert id and `b` is the batch id, + indicates a shuffling plan that moves the `b`-th instance that ensures the inputs to the `e`-th + expert is contiguous. + + Parameters + ---------- + expert_indices : Tensor + The topk indices with shape [batch_size, experts_per_tok], int32, where + `experts_per_tok` is the number of activated experts. + + num_local_experts : int + The number of totally experts. + + Returns + ------- + cumsum: Tensor + The cumsum result with shape [num_local_experts * batch_size], int32. + + Example + ------- + Suppose `batch_size` is 4, `experts_per_tok` is 2, the total number of experts is 6, and + `expert_indices` is the 2D tensor below: + + [ + [0, 1], + [1, 2], + [3, 4], + [2, 5], + ] + + , then the `expert_mask` is a tensor of shape [batch_size, num_local_experts] below: + + [ + [1, 1, 0, 0, 0, 0], + [0, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 1, 0], + [0, 0, 1, 0, 0, 1], + ] + + . 
The result cumsum of the transposed `expert_mask` is a flattened version of 2D tensor below: + + [ + [1, 1, 1, 1], + [2, 3, 3, 3], + [3, 4, 4, 5], + [5, 5, 6, 6], + [6, 6, 7, 7], + [7, 7, 7, 8], + ] + """ + batch_size, experts_per_tok = expert_indices.shape + expert_mask = ( + op.tensor_expr_op( # pylint: disable=too-many-function-args + lambda expert_indices: te.compute( + (batch_size, num_local_experts), + lambda i, j: tir.expr.Select( + reduce( + tir.Or, + [expert_indices[i, k] == j for k in range(experts_per_tok)], + ), + true_value=tir.const(1, "int32"), + false_value=tir.const(0, "int32"), + ), + ), + "expert_mask", + args=[expert_indices], + ) + .permute_dims(1, 0) + .reshape(batch_size * num_local_experts) + ) + with Target.current(allow_none=True) or Target( + { + "kind": "cuda", + "max_num_threads": 1024, + "arch": "sm_50", + } + ): + return op.tensor_expr_op(inclusive_scan, "cumsum", args=[expert_mask, 0, "int32"]) # type: ignore[list-item] + + +def get_indices(cumsum: Tensor, expert_indices: Tensor) -> Tuple[Tensor, Tensor]: + """Returns a 1D tensor of indices that represents the shuffling plan for each instance in a + batch, so that the inputs to each experts are contiguous and the indices for reverse permutation + (scatter) to the original order. + + If `reverse_indices[i] = (b, j)`, it means the `b`-th instance in the batch should be moved to the + `i`-th position in shuffling, and `j` doesn not matter only meaning `expert_indices[b, j]` + corresponds to the expert at position `i` in the shuffling plan. We also compute + `token_indices[i] = b` so that we can use `relax.op.take` for shuffling. + + Effectively it is equivalent to the following Python code: + + .. code-block:: python + + for b in range(batch_size): + for j in range(experts_per_tok): + e = expert_indices[b, j] + reverse_indices[cumsum[e * batch_size + b] - 1] = b * experts_per_tok + j + token_indices[cumsum[e * batch_size + b] - 1 + + Parameters + ---------- + cumsum : Tensor + A flattened 1D tensor whose original shape is [experts_per_tok, batch_size]. + + expert_indices : Tensor + The indices of the experts with shape [batch_size, experts_per_tok]. + + Returns + ------- + reverse_indices : Tensor + The indices for scattering with shape [batch_size * experts_per_tok]. + + token_indices : Tensor + The indices for shuffling with shape [batch_size * experts_per_tok]. 
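+
+    Worked example (derived by hand from the pseudo-code above, continuing the
+    example in `moe_cumsum` with 4 instances, 2 activated experts each, and 6
+    local experts):
+
+    .. code-block:: python
+
+        expert_indices  = [[0, 1], [1, 2], [3, 4], [2, 5]]
+        reverse_indices = [0, 1, 2, 3, 6, 4, 5, 7]
+        token_indices   = [0, 0, 1, 1, 3, 2, 2, 3]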
+ """ + TX = 1024 + batch_size, experts_per_tok = expert_indices.shape + + @T.prim_func(private=True) + def _func( + var_cumsum: T.handle, + var_expert_indices: T.handle, + var_reverse_indices: T.handle, + var_token_indices: T.handle, + ): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": True}) + batch_size = T.SizeVar("batch_size", "int32") + cumsum_len = T.SizeVar("cumsum_len", "int32") # [experts_per_tok * batch_size] + cumsum = T.match_buffer(var_cumsum, [cumsum_len], "int32") + expert_indices = T.match_buffer(var_expert_indices, [batch_size, experts_per_tok], "int32") + reverse_indices = T.match_buffer( + var_reverse_indices, [batch_size * experts_per_tok], "int32" + ) + token_indices = T.match_buffer(var_token_indices, [batch_size * experts_per_tok], "int32") + for bj_o in T.thread_binding(0, T.ceildiv(batch_size * experts_per_tok, TX), "blockIdx.x"): + for bj_i in T.thread_binding(0, TX, "threadIdx.x"): + with T.block("indices"): + T.reads(expert_indices[:, :], cumsum[:]) + T.writes(reverse_indices[:], token_indices[:]) + if bj_o * TX + bj_i < batch_size * experts_per_tok: + b: T.int32 = T.floordiv(bj_o * TX + bj_i, experts_per_tok) + j: T.int32 = T.floormod(bj_o * TX + bj_i, experts_per_tok) + e: T.int32 = expert_indices[b, j] + reverse_indices[cumsum[e * batch_size + b] - 1] = b * experts_per_tok + j + token_indices[cumsum[e * batch_size + b] - 1] = b + + return op.tensor_ir_op( + _func, + "get_indices", + args=[cumsum, expert_indices], + out=[Tensor.placeholder([batch_size * experts_per_tok], "int32") for _ in range(2)], + ) + + +def get_indptr( + cumsum: Tensor, + num_local_experts: int, + batch_size: Union[int, tir.Var], + inclusive: bool, + out_dtype: str, +) -> Tensor: + """Extract the `indptr` array from MoE cumsum array. The MoE cumsum array is a flattened tensor + whose original shape is [num_local_experts, batch_size], and the `indptr` array is a 1D tensor + of length `num_local_experts + 1`. The range `[indptr[i], indptr[i + 1])` indicates instances in + the batch that corresponds to the `i`-th expert. + + Effectively, this operator is equivalent to the following numpy code: + + .. code-block:: python + + indptr = np.zeros(num_local_experts + 1, dtype=np.int32) + indptr[0] = 0 + for i in range(1, num_local_experts + 1): + indptr[i] = cumsum[i * batch_size - 1] + return indptr + + Parameters + ---------- + cumsum : Tensor + The prefix sum of the sparse array with shape [batch_size * num_local_experts], int32. + + num_local_experts : int + The number of experts. + + batch_size : int | tir.Var + The batch size. Note that the batch size here refers to `batch_size * seq_len` in MoE, + and we name is `batch_size` for simplicity here only because the two dimensions are fused + in Mixtral. + + inclusive : bool + Whether to compute inclusive or exclusive prefix sum as the indptr. If `inclusive` is False, + the 0-th element of the `indptr` array, which always equals to 0, will be omitted. + + out_dtype : str + The output dtype. + + Returns + ------- + indptr : Tensor + The `indptr` array with shape [num_local_experts + 1] if `inclusive` is True, otherwise + [num_local_experts]. The `indptr` array is of type `out_dtype`. 
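+
+    Continuing the example in `moe_cumsum` (4 instances, 6 local experts), the
+    NumPy sketch above yields `indptr = [0, 1, 3, 5, 6, 7, 8]`; the variant
+    without the leading zero is `[1, 3, 5, 6, 7, 8]`.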
+ """ + + out_shape = [num_local_experts if inclusive else num_local_experts + 1] + + @T.prim_func(private=True) + def _func_exclusive(var_cumsum: T.handle, var_indptr: T.handle, batch_size: T.int32): + T.func_attr({"tir.noalias": True}) + cumsum = T.match_buffer(var_cumsum, shape=[batch_size * num_local_experts], dtype="int32") + indptr = T.match_buffer(var_indptr, shape=out_shape, dtype=out_dtype) + for vi in T.serial(0, out_shape[0]): + with T.block("indptr"): + i = T.axis.spatial(out_shape[0], vi) + indptr[i] = T.Select(i > 0, cumsum[i * batch_size - 1], T.int32(0)) + + @T.prim_func(private=True) + def _func_inclusive(var_cumsum: T.handle, var_indptr: T.handle, batch_size: T.int32): + T.func_attr({"tir.noalias": True}) + cumsum = T.match_buffer(var_cumsum, shape=[batch_size * num_local_experts], dtype="int32") + indptr = T.match_buffer(var_indptr, shape=out_shape, dtype=out_dtype) + for vi in T.serial(0, out_shape[0]): + with T.block("indptr"): + i = T.axis.spatial(out_shape[0], vi) + indptr[i] = cumsum[(i + 1) * batch_size - 1] + + assert cumsum.ndim == 1 + return op.tensor_ir_op( + _func_inclusive if inclusive else _func_exclusive, + "get_expert_instance_indptr", + args=[cumsum, batch_size], # type: ignore[list-item] + out=Tensor.placeholder(out_shape, out_dtype), + ) + + +def scatter_output(x: Tensor, indices: Tensor) -> Tensor: + """Scatter the output of MoE experts back to the original positions. + + Parameters + ---------- + x : Tensor + The output of MoE experts with shape [batch_size * num_experts_per_tok, hidden_size]. + + indices : Tensor + The indices of the experts with shape [batch_size * num_experts_per_tok]. + + Returns + ------- + out : Tensor + The output of MoE experts with shape [batch_size * num_experts_per_tok, hidden_size]. + """ + dtype = x.dtype + + @T.prim_func(private=True) + def _func(var_x: T.handle, var_indices: T.handle, var_out: T.handle): + T.func_attr({"tir.noalias": True}) + hidden_size = T.int64() + indices_len = T.int64() + x = T.match_buffer(var_x, [indices_len, hidden_size], dtype) + indices = T.match_buffer(var_indices, [indices_len], "int32") + out = T.match_buffer(var_out, [indices_len, hidden_size], dtype) + for i in T.serial(0, indices_len): + for j in T.serial(0, hidden_size): + with T.block("scatter"): + vi, vj = T.axis.remap("SS", [i, j]) + out[indices[vi], vj] = x[vi, vj] + + return op.tensor_ir_op( + _func, + "scatter_output", + args=[x, indices], + out=Tensor.placeholder(x.shape, dtype), + ) diff --git a/python/mlc_chat/op/position_embedding.py b/python/mlc_chat/op/position_embedding.py new file mode 100644 index 0000000..12bdaaa --- /dev/null +++ b/python/mlc_chat/op/position_embedding.py @@ -0,0 +1,376 @@ +"""Operators for positional embeddings, e.g. RoPE.""" + +from typing import Optional, Tuple + +from tvm import tir +from tvm.relax.frontend.nn import Tensor, op +from tvm.script import tir as T +from tvm.target import Target + +# pylint: disable=invalid-name + + +def rope_freq(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: str): + """Compute the inverse frequency of RoPE and then return the cosine and sine of it. + + Parameters + ---------- + s : tir.Var + The position index. + + d : tir.Var + The dimension index. + + d_range : int + The maximum dimension index. + + theta : float + The theta value in RoPE, which controls the frequency. + + dtype : str + The data type of the output. + + Returns + ------- + cos_freq : Tensor + The cosine of the inverse frequency. + + sin_freq : Tensor + The sine of the inverse frequency. 
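+
+    The callers in this file combine the two frequencies with the rotate-half
+    pairing: for rotary dimension `D`, the element at index `d` is rotated
+    against the element at index `d + D // 2` (or `d - D // 2` for the second
+    half). Schematically:
+
+    .. code-block:: python
+
+        out[d] = cos_freq * x[d] + sin_freq * (-x[d + D // 2] if d < D // 2 else x[d - D // 2])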
+ """ + freq = s / tir.power(theta, d * 2 % d_range / tir.const(d_range, "float32")) + cos_freq = tir.cos(freq).astype(dtype) + sin_freq = tir.sin(freq).astype(dtype) + return cos_freq, sin_freq + + +# mypy: disable-error-code="attr-defined" + + +def llama_rope( # pylint: disable=too-many-arguments + qkv: Tensor, + total_seq_len: tir.Var, + theta: float, + num_q_heads: int, + num_kv_heads: int, + scale: float = 1.0, + rotary_dim: Optional[int] = None, +) -> Tuple[Tensor, Tensor, Tensor]: + """Llama-style RoPE. Given a fused QKV tensor, it returns three tensors, Q, K, and V, where Q + and K are rotated by RoPE while V remains unchanged. + + Parameters + ---------- + qkv : Tensor + The fused QKV tensor of shape: [batch_size, seq_len, #q_heads + #kv_heads * 2, head_dim] + + total_seq_len : tir.Var + The total sequence length after being concatenated with KVCache. It is used to compute the + offset of RoPE. + + theta : float + The theta value, or "base" in RoPE, which controls the frequency. + + scale : float + The RoPE scaling factor. + + num_q_heads : int + The number of query heads. + + num_kv_heads : int + The number of key/value heads. It differs from `num_q_heads` in group-query attention. + + rotary_dim : Optional[int] + The number of dimensions in the embedding that RoPE is applied to. By default, the + rotary_dim is the same as head_dim. + + Returns + ------- + q : Tensor + The query tensor of shape [batch_size, seq_len, #q_heads, head_dim] w/ RoPE applied + + k : Tensor + The key tensor of shape [batch_size, seq_len, #kv_heads, head_dim] w/ RoPE applied + + v : Tensor + The value tensor of shape [batch_size, seq_len, #kv_heads, head_dim] w/o RoPE applied + """ + _, _, fused_heads, head_dim = qkv.shape + assert fused_heads == num_q_heads + num_kv_heads * 2 + if rotary_dim is None: + rotary_dim = head_dim + dtype = qkv.dtype + scale = tir.const(scale, dtype) + + def _rope( # pylint: disable=too-many-arguments + x: T.Buffer, + b: tir.Var, + s: tir.Var, + h: tir.Var, + d: tir.Var, + offset: tir.Var, + ): + cos_freq, sin_freq = rope_freq((s + offset) * scale, d, rotary_dim, theta, dtype) + cos = cos_freq * x[b, s, h, d] + sin = sin_freq * tir.if_then_else( + d < rotary_dim // 2, + -x[b, s, h, d + rotary_dim // 2], + x[b, s, h, d - rotary_dim // 2], + ) + return cos + sin + + @T.prim_func(private=True) + def fused_rope( # pylint: disable=too-many-locals + var_qkv: T.handle, + var_q: T.handle, + var_k: T.handle, + var_v: T.handle, + total_seq_len: T.int64, + ): + T.func_attr( + { + "op_pattern": 8, # 2 means injective, 8 means opaque + "tir.noalias": T.bool(True), + } + ) + batch_size = T.int64() + seq_len = T.int64() + qkv = T.match_buffer(var_qkv, (batch_size, seq_len, fused_heads, head_dim), dtype) + q = T.match_buffer(var_q, (batch_size, seq_len, num_q_heads, head_dim), dtype) + k = T.match_buffer(var_k, (batch_size, seq_len, num_kv_heads, head_dim), dtype) + v = T.match_buffer(var_v, (batch_size, seq_len, num_kv_heads, head_dim), dtype) + for iters in T.grid(batch_size, seq_len, fused_heads, head_dim): + with T.block("llama_fused_rope"): + b, s, h, d = T.axis.remap("SSSS", iters) + if h < num_q_heads: + q[b, s, h, d] = T.if_then_else( + d < rotary_dim, + _rope(qkv, b, s, h, d, total_seq_len - seq_len), + qkv[b, s, h, d], + ) + elif h < num_q_heads + num_kv_heads: + k[b, s, h - num_q_heads, d] = T.if_then_else( + d < rotary_dim, + _rope(qkv, b, s, h, d, total_seq_len - seq_len), + qkv[b, s, h, d], + ) + else: + v[b, s, h - (num_q_heads + num_kv_heads), d] = qkv[b, s, h, d] + + b, s, _, 
_ = qkv.shape + return op.tensor_ir_op( # pylint: disable=no-member + fused_rope, + "llama_rope", + args=[qkv, total_seq_len], + out=( + Tensor.placeholder((b, s, num_q_heads, head_dim), dtype), + Tensor.placeholder((b, s, num_kv_heads, head_dim), dtype), + Tensor.placeholder((b, s, num_kv_heads, head_dim), dtype), + ), + ) + + +def llama_rope_with_position_map( # pylint: disable=too-many-arguments + theta: float, + scale: float, + head_dim: int, + num_q_heads: int, + num_kv_heads: int, + dtype: str, + rotary_dim: int = None, +): + """Return the TIR function that computes Llama-style RoPE with q position map. + + Parameters + ---------- + theta : float + The theta value, or "base" in RoPE, which controls the frequency. + + scale : float + The RoPE scaling factor. + + head_dim : int + The number of features on each head. + + num_q_heads : int + The number of query heads. + + num_kv_heads : int + The number of key/value heads. It differs from `num_q_heads` in group-query attention. + + dtype : str + The dtype of qkv data. + + rotary_dim : int + The number of dimensions in the embedding that RoPE is applied to. By default, the + rotary_dim is the same as head_dim. + """ + fused_heads = num_q_heads + num_kv_heads * 2 + if rotary_dim is None: + rotary_dim = head_dim + scale = tir.const(scale, dtype) + + def _rope( # pylint: disable=too-many-arguments + x: T.Buffer, + s: tir.Var, + h: tir.Var, + d: tir.Var, + pos: tir.Var, + ): + cos_freq, sin_freq = rope_freq(pos * scale, d, rotary_dim, theta, dtype) + cos = cos_freq * x[s, h, d] + sin = sin_freq * tir.if_then_else( + d < rotary_dim // 2, + -x[s, h, d + rotary_dim // 2], + x[s, h, d - rotary_dim // 2], + ) + return cos + sin + + @T.prim_func + def fused_rope( # pylint: disable=too-many-locals + var_qkv: T.handle, + var_position_map: T.handle, + var_q: T.handle, + var_k: T.handle, + var_v: T.handle, + apply_rope: T.int32, + ): + T.func_attr( + { + "op_pattern": 8, # 2 means injective, 8 means opaque + "tir.noalias": T.bool(True), + } + ) + seq_len = T.int64() + qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype) + q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype) + k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype) + v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype) + position_map = T.match_buffer(var_position_map, (seq_len,), "int32") + for iters in T.grid(seq_len, fused_heads, head_dim): + with T.block("llama_fused_rope"): + s, h, d = T.axis.remap("SSS", iters) + if h < num_q_heads: + q[s, h, d] = T.if_then_else( + apply_rope > 0 and d < rotary_dim, + _rope(qkv, s, h, d, position_map[s]), + qkv[s, h, d], + ) + elif h < num_q_heads + num_kv_heads: + k[s, h - num_q_heads, d] = T.if_then_else( + apply_rope > 0 and d < rotary_dim, + _rope(qkv, s, h, d, position_map[s]), + qkv[s, h, d], + ) + else: + v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d] + + return fused_rope + + +# pylint: disable=line-too-long,too-many-arguments,too-many-nested-blocks,invalid-name + + +def llama_inplace_rope( + theta: float, + scale: float, + head_dim: int, + num_q_heads: int, + num_kv_heads: int, + dtype: str, + target: Target, # pylint: disable=unused-argument + rotary_dim: Optional[int] = None, +): + """Return the TIR function that inplace computes Llama-style RoPE with q position offset. + + Parameters + ---------- + theta : float + The theta value, or "base" in RoPE, which controls the frequency. + + scale : float + The RoPE scaling factor. 
+ + head_dim : int + The number of features on each head. + + num_q_heads : int + The number of query heads. + + num_kv_heads : int + The number of key/value heads. It differs from `num_q_heads` in group-query attention. + + dtype : str + The dtype of qkv data. + + target : Target + The target to build the model to. + + rotary_dim : Optional[int] + The number of dimensions in the embedding that RoPE is applied to. By default, the + rotary_dim is the same as head_dim. + """ + if rotary_dim is None: + rotary_dim = head_dim + + def _rope( + x: T.Buffer, + s: tir.Var, + h: tir.Var, + d: tir.Var, + rope_offset: tir.Var, + instance_offset: tir.Var, + ): + cos_freq, sin_freq = rope_freq((s + rope_offset) * scale, d, rotary_dim, theta, dtype) + cos = cos_freq * x[s + instance_offset, h, d] + sin = sin_freq * tir.if_then_else( + d < rotary_dim // 2, + -x[s + instance_offset, h, d + rotary_dim // 2], + x[s + instance_offset, h, d - rotary_dim // 2], + ) + return cos + sin + + # fmt: off + @T.prim_func + def tir_rotary( # pylint: disable=too-many-locals + var_q: T.handle, + var_k: T.handle, + var_append_len_indptr: T.handle, + var_rope_offsets: T.handle, + _0: T.int32, + _1: T.int32, + _2: T.int32, + _3: T.int32, + _4: T.int32, + _5: T.float32, + _6: T.float32, + ): + T.func_attr({"tir.is_scheduled": 1}) + total_len = T.int32() + batch_size = T.int32() + q = T.match_buffer(var_q, (total_len, num_q_heads, head_dim), dtype) + k = T.match_buffer(var_k, (total_len, num_kv_heads, head_dim), dtype) + rope_offsets = T.match_buffer(var_rope_offsets, (batch_size,), "int32") + append_len_indptr = T.match_buffer(var_append_len_indptr, (batch_size + 1,), "int32") + with T.block(): + for b_h in T.thread_binding(batch_size * (num_q_heads + num_kv_heads), thread="blockIdx.x"): + b: T.int32 = b_h // (num_q_heads + num_kv_heads) + h: T.int32 = b_h % (num_q_heads + num_kv_heads) + instance_offset: T.int32 = append_len_indptr[b] + rope_offset: T.int32 = rope_offsets[b] + append_len: T.int32 = append_len_indptr[b + 1] - append_len_indptr[b] + for s0 in range(T.ceildiv(append_len, 32)): + for s1 in T.thread_binding(32, thread="threadIdx.y"): + for d0 in T.thread_binding(T.ceildiv(head_dim, 4), thread="threadIdx.x"): + for d1 in T.vectorized(4): + s: T.int32 = s0 * 32 + s1 + d: T.int32 = d0 * 4 + d1 + if s < append_len and d < rotary_dim: + if h < num_q_heads: + q[s + instance_offset, h, d] = _rope(q, s, h, d, rope_offset, instance_offset) + else: + k[s + instance_offset, h - num_q_heads, d] = _rope(k, s, h - num_q_heads, d, rope_offset, instance_offset) + return tir_rotary + + +# pylint: enable=line-too-long,too-many-arguments,too-many-nested-blocks,invalid-name diff --git a/python/mlc_chat/protocol/__init__.py b/python/mlc_chat/protocol/__init__.py new file mode 100644 index 0000000..2776756 --- /dev/null +++ b/python/mlc_chat/protocol/__init__.py @@ -0,0 +1,4 @@ +"""The protocols for MLC LLM server""" +from . import openai_api_protocol + +RequestProtocol = openai_api_protocol.CompletionRequest diff --git a/python/mlc_chat/protocol/conversation_protocol.py b/python/mlc_chat/protocol/conversation_protocol.py new file mode 100644 index 0000000..01c145d --- /dev/null +++ b/python/mlc_chat/protocol/conversation_protocol.py @@ -0,0 +1,136 @@ +"""The standard conversation protocol in MLC LLM""" + +from enum import Enum +from typing import Dict, List, Optional, Tuple + +from pydantic import BaseModel, Field, field_validator + + +# The message placeholders in the message prompts according to roles. 
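+# For example, the default user role template is simply "{user_message}"; when a
+# conversation is rendered by `Conversation.as_prompt` below, that placeholder is
+# replaced with the actual content of the user message.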
+class MessagePlaceholders(Enum): + """The message placeholders in the message prompts according to roles.""" + + SYSTEM = "{system_message}" + USER = "{user_message}" + ASSISTANT = "{assistant_message}" + TOOL = "{tool_message}" + FUNCTION = "{function_string}" + + +class Conversation(BaseModel): + """Class that specifies the convention template of conversation + and contains the conversation history. + + Given a conversation template, the corresponding prompt generated out + from it is usually in the following format: + + <><><><><> + <><><><> + ... + <><><><> + <><> + """ + + # Optional name of the template. + name: Optional[str] = None + # The system prompt template, it optionally contains the system + # message placeholder, and the placeholder will be replaced with + # the system message below. + system_template: str = MessagePlaceholders.SYSTEM.value + # The content of the system prompt (without the template format). + system_message: str = "" + # The system token ids to be prepended at the beginning of tokenized + # generated prompt. + system_prefix_token_ids: Optional[List[int]] = None + + # The conversation roles + roles: Dict[str, str] + + # The roles prompt template, it optionally contains the defaults + # message placeholders and will be replaced by actual content + role_templates: Dict[str, str] + + # The conversation history messages. + # Each message is a pair of strings, denoting "(role, content)". + # The content can be None. + messages: List[Tuple[str, Optional[str]]] = Field(default_factory=lambda: []) + + # The separators between messages when concatenating into a single prompt. + # List size should be either 1 or 2. + # - When size is 1, the separator will be used between adjacent messages. + # - When size is 2, seps[0] is used after user message, and + # seps[1] is used after assistant message. + seps: List[str] + + # The separator between the role and the content in a message. + role_content_sep: str = "" + # The separator between the role and empty contents. + role_empty_sep: str = "" + + # The stop criteria + stop_str: List[str] = Field(default_factory=lambda: []) + stop_token_ids: List[int] = Field(default_factory=lambda: []) + + # Function call fields + function_string: str = "" + # whether using function calling or not, helps check for output message format in API call + use_function_calling: bool = False + + def __init__(self, role_templates: Optional[Dict[str, str]] = None, **kwargs): + # Defaults templates which would be overridden by model specific templates + _role_templates: Dict[str, str] = { + "user": MessagePlaceholders.USER.value, + "assistant": MessagePlaceholders.ASSISTANT.value, + "tool": MessagePlaceholders.TOOL.value, + } + if role_templates is not None: + _role_templates.update(role_templates) + super().__init__(role_templates=_role_templates, **kwargs) + + @field_validator("seps") + @classmethod + def check_message_seps(cls, seps: List[str]) -> List[str]: + """Check if the input message separators has size 1 or 2.""" + if len(seps) == 0 or len(seps) > 2: + raise ValueError("seps should have size 1 or 2.") + return seps + + def as_prompt(self) -> str: + """Convert the conversation template and history messages to + a single prompt. + """ + # - Get the system message. + system_msg = self.system_template.replace( + MessagePlaceholders.SYSTEM.value, self.system_message + ) + + # - Get the message strings. 
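+        #   Each entry with content is rendered as
+        #   "<role><role_content_sep><role template with its placeholder filled><separator>";
+        #   entries whose content is None only emit "<role><role_empty_sep>".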
+ message_list: List[str] = [] + separators = list(self.seps) + if len(separators) == 1: + separators.append(separators[0]) + for role, content in self.messages: # pylint: disable=not-an-iterable + if role not in self.roles.keys(): + raise ValueError(f'Role "{role}" is not a supported role in {self.roles.keys()}') + separator = separators[role == "assistant"] # check assistant role + if content is not None: + message_string = ( + self.roles[role] + + self.role_content_sep + + self.role_templates[role].replace( + MessagePlaceholders[role.upper()].value, content + ) + + separator + ) + else: + message_string = self.roles[role] + self.role_empty_sep + message_list.append(message_string) + + prompt = system_msg + separators[0] + "".join(message_list) + + # Replace the last function string placeholder with actual function string + prompt = self.function_string.join(prompt.rsplit(MessagePlaceholders.FUNCTION.value, 1)) + # Replace with remaining function string placeholders with empty string + prompt = prompt.replace(MessagePlaceholders.FUNCTION.value, "") + + return prompt diff --git a/python/mlc_chat/protocol/openai_api_protocol.py b/python/mlc_chat/protocol/openai_api_protocol.py new file mode 100644 index 0000000..2ae26bf --- /dev/null +++ b/python/mlc_chat/protocol/openai_api_protocol.py @@ -0,0 +1,329 @@ +"""Protocols in MLC LLM for OpenAI API. +Adapted from FastChat's OpenAI protocol: +https://github.com/lm-sys/FastChat/blob/main/fastchat/protocol/openai_api_protocol.py +""" + +# pylint: disable=missing-class-docstring +import time +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import shortuuid +from pydantic import BaseModel, Field, field_validator, model_validator + +################ Commons ################ + + +class ListResponse(BaseModel): + object: str = "list" + data: List[Any] + + +class TopLogProbs(BaseModel): + token: str + logprob: float + bytes: Optional[List[int]] + + +class LogProbsContent(BaseModel): + token: str + logprob: float + bytes: Optional[List[int]] + top_logprobs: List[TopLogProbs] = [] + + +class LogProbs(BaseModel): + content: List[LogProbsContent] + + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + + def __init__(self, prompt_tokens: int = 0, completion_tokens: int = 0): + super().__init__( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + + +################ v1/models ################ + + +class ModelResponse(BaseModel): + """OpenAI "v1/models" response protocol. + API reference: https://platform.openai.com/docs/api-reference/models/object + """ + + id: str + created: int = Field(default_factory=lambda: int(time.time())) + object: str = "model" + owned_by: str = "MLC-LLM" + + +################ v1/completions ################ + + +class CompletionRequest(BaseModel): + """OpenAI completion request protocol. 
+ API reference: https://platform.openai.com/docs/api-reference/completions/create + """ + + model: str + prompt: Union[str, List[int], List[Union[str, List[int]]]] + best_of: int = 1 + echo: bool = False + frequency_penalty: float = 0.0 + presence_penalty: float = 0.0 + logprobs: bool = False + top_logprobs: int = 0 + logit_bias: Optional[Dict[int, float]] = None + max_tokens: int = 16 + n: int = 1 + seed: Optional[int] = None + stop: Optional[Union[str, List[str]]] = None + stream: bool = False + suffix: Optional[str] = None + temperature: float = 1.0 + top_p: float = 1.0 + user: Optional[str] = None + ignore_eos: bool = False + + @field_validator("frequency_penalty", "presence_penalty") + @classmethod + def check_penalty_range(cls, penalty_value: float) -> float: + """Check if the penalty value is in range [-2, 2].""" + if penalty_value < -2 or penalty_value > 2: + raise ValueError("Penalty value should be in range [-2, 2].") + return penalty_value + + @field_validator("logit_bias") + @classmethod + def check_logit_bias( + cls, logit_bias_value: Optional[Dict[int, float]] + ) -> Optional[Dict[int, float]]: + """Check if the logit bias key is given as an integer.""" + if logit_bias_value is None: + return None + for token_id, bias in logit_bias_value.items(): + if abs(bias) > 100: + raise ValueError( + "Logit bias value should be in range [-100, 100], while value " + f"{bias} is given for token id {token_id}" + ) + return logit_bias_value + + @model_validator(mode="after") + def check_logprobs(self) -> "CompletionRequest": + """Check if the logprobs requirements are valid.""" + if self.top_logprobs < 0 or self.top_logprobs > 5: + raise ValueError('"top_logprobs" must be in range [0, 5]') + if not self.logprobs and self.top_logprobs > 0: + raise ValueError('"logprobs" must be True to support "top_logprobs"') + return self + + +class CompletionResponseChoice(BaseModel): + finish_reason: Optional[Literal["stop", "length"]] = None + index: int = 0 + logprobs: Optional[LogProbs] = None + text: str + + +class CompletionResponse(BaseModel): + """OpenAI completion response protocol. + API reference: https://platform.openai.com/docs/api-reference/completions/object + """ + + id: str + choices: List[CompletionResponseChoice] + created: int = Field(default_factory=lambda: int(time.time())) + model: str + object: str = "text_completion" + usage: UsageInfo = Field( + default_factory=lambda: UsageInfo() # pylint: disable=unnecessary-lambda + ) + + +################ v1/chat/completions ################ + + +class ChatFunction(BaseModel): + description: Optional[str] = None + name: str + parameters: Dict + + +class ChatTool(BaseModel): + type: Literal["function"] + function: ChatFunction + + +class ChatFunctionCall(BaseModel): + name: str + arguments: Union[None, Dict[str, Any]] = None + + +class ChatToolCall(BaseModel): + id: str = Field(default_factory=lambda: f"call_{shortuuid.random()}") + type: Literal["function"] + function: ChatFunctionCall + + +class ChatCompletionMessage(BaseModel): + content: Optional[Union[str, List[Dict[str, str]]]] = None + role: Literal["system", "user", "assistant", "tool"] + name: Optional[str] = None + tool_calls: Optional[List[ChatToolCall]] = None + tool_call_id: Optional[str] = None + + +class ChatCompletionRequest(BaseModel): + """OpenAI chat completion request protocol. 
+ API reference: https://platform.openai.com/docs/api-reference/chat/create + """ + + messages: List[ChatCompletionMessage] + model: str + frequency_penalty: float = 0.0 + presence_penalty: float = 0.0 + logprobs: bool = False + top_logprobs: int = 0 + logit_bias: Optional[Dict[int, float]] = None + max_tokens: Optional[int] = None + n: int = 1 + response_format: Literal["text", "json_object"] = "text" + seed: Optional[int] = None + stop: Optional[Union[str, List[str]]] = None + stream: bool = False + temperature: float = 1.0 + top_p: float = 1.0 + tools: Optional[List[ChatTool]] = None + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None + user: Optional[str] = None + ignore_eos: bool = False + + @field_validator("frequency_penalty", "presence_penalty") + @classmethod + def check_penalty_range(cls, penalty_value: float) -> float: + """Check if the penalty value is in range [-2, 2].""" + if penalty_value < -2 or penalty_value > 2: + raise ValueError("Penalty value should be in range [-2, 2].") + return penalty_value + + @field_validator("logit_bias") + @classmethod + def check_logit_bias( + cls, logit_bias_value: Optional[Dict[int, float]] + ) -> Optional[Dict[int, float]]: + """Check if the logit bias key is given as an integer.""" + if logit_bias_value is None: + return None + for token_id, bias in logit_bias_value.items(): + if abs(bias) > 100: + raise ValueError( + "Logit bias value should be in range [-100, 100], while value " + f"{bias} is given for token id {token_id}" + ) + return logit_bias_value + + @model_validator(mode="after") + def check_logprobs(self) -> "ChatCompletionRequest": + """Check if the logprobs requirements are valid.""" + if self.top_logprobs < 0 or self.top_logprobs > 5: + raise ValueError('"top_logprobs" must be in range [0, 5]') + if not self.logprobs and self.top_logprobs > 0: + raise ValueError('"logprobs" must be True to support "top_logprobs"') + return self + + +class ChatCompletionResponseChoice(BaseModel): + finish_reason: Optional[Literal["stop", "length", "tool_calls", "error"]] = None + index: int = 0 + message: ChatCompletionMessage + logprobs: Optional[LogProbs] = None + + +class ChatCompletionStreamResponseChoice(BaseModel): + finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None + index: int = 0 + delta: ChatCompletionMessage + logprobs: Optional[LogProbs] = None + + +class ChatCompletionResponse(BaseModel): + """OpenAI completion response protocol. + API reference: https://platform.openai.com/docs/api-reference/chat/object + """ + + id: str + choices: List[ChatCompletionResponseChoice] + created: int = Field(default_factory=lambda: int(time.time())) + model: str + system_fingerprint: str + object: Literal["chat.completion"] = "chat.completion" + usage: UsageInfo = Field( + default_factory=lambda: UsageInfo() # pylint: disable=unnecessary-lambda + ) + + +class ChatCompletionStreamResponse(BaseModel): + """OpenAI completion stream response protocol. 
+ API reference: https://platform.openai.com/docs/api-reference/chat/streaming + """ + + id: str + choices: List[ChatCompletionStreamResponseChoice] + created: int = Field(default_factory=lambda: int(time.time())) + model: str + system_fingerprint: str + object: Literal["chat.completion.chunk"] = "chat.completion.chunk" + + +################################################ + + +def openai_api_get_unsupported_fields( + request: Union[CompletionRequest, ChatCompletionRequest] +) -> List[str]: + """Get the unsupported fields in the request.""" + unsupported_field_default_values: List[Tuple[str, Any]] = [ + ("best_of", 1), + ("n", 1), + ("response_format", "text"), + ] + + unsupported_fields: List[str] = [] + for field, value in unsupported_field_default_values: + if hasattr(request, field) and getattr(request, field) != value: + unsupported_fields.append(field) + return unsupported_fields + + +def openai_api_get_generation_config( + request: Union[CompletionRequest, ChatCompletionRequest] +) -> Dict[str, Any]: + """Create the generation config from the given request.""" + kwargs: Dict[str, Any] = {} + arg_names = [ + "temperature", + "top_p", + "max_tokens", + "frequency_penalty", + "presence_penalty", + "logprobs", + "top_logprobs", + "logit_bias", + "seed", + "ignore_eos", + ] + for arg_name in arg_names: + kwargs[arg_name] = getattr(request, arg_name) + if kwargs["max_tokens"] is None: + # Setting to -1 means the generation will not stop until + # exceeding model capability or hit any stop criteria. + kwargs["max_tokens"] = -1 + if request.stop is not None: + kwargs["stop_strs"] = [request.stop] if isinstance(request.stop, str) else request.stop + return kwargs diff --git a/python/mlc_chat/protocol/protocol_utils.py b/python/mlc_chat/protocol/protocol_utils.py new file mode 100644 index 0000000..a9a68a1 --- /dev/null +++ b/python/mlc_chat/protocol/protocol_utils.py @@ -0,0 +1,58 @@ +"""Utility functions for request protocols""" + +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel + +from ..serve.config import GenerationConfig +from . import RequestProtocol +from .openai_api_protocol import ChatCompletionRequest as OpenAIChatCompletionRequest +from .openai_api_protocol import CompletionRequest as OpenAICompletionRequest +from .openai_api_protocol import ( + openai_api_get_generation_config, + openai_api_get_unsupported_fields, +) + + +class ErrorResponse(BaseModel): + """The class of error response.""" + + object: str = "error" + message: str + code: int = None + + +def get_unsupported_fields(request: RequestProtocol) -> List[str]: + """Get the unsupported fields of the request. + Return the list of unsupported field names. 
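+    For example, a completion request that sets ``n=2`` and keeps the other fields
+    at their defaults yields ``["n"]``.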
+ """ + if isinstance(request, (OpenAICompletionRequest, OpenAIChatCompletionRequest)): + return openai_api_get_unsupported_fields(request) + raise RuntimeError("Cannot reach here") + + +def get_generation_config( + request: RequestProtocol, + extra_stop_token_ids: Optional[List[int]] = None, + extra_stop_str: Optional[List[str]] = None, +) -> GenerationConfig: + """Create the generation config in MLC LLM out from the input request protocol.""" + kwargs: Dict[str, Any] + if isinstance(request, (OpenAICompletionRequest, OpenAIChatCompletionRequest)): + kwargs = openai_api_get_generation_config(request) + else: + raise RuntimeError("Cannot reach here") + + if extra_stop_token_ids is not None: + stop_token_ids = kwargs.get("stop_token_ids", []) + assert isinstance(stop_token_ids, list) + stop_token_ids += extra_stop_token_ids + kwargs["stop_token_ids"] = stop_token_ids + + if extra_stop_str is not None: + stop_strs = kwargs.get("stop_strs", []) + assert isinstance(stop_strs, list) + stop_strs += extra_stop_str + kwargs["stop_strs"] = stop_strs + + return GenerationConfig(**kwargs) diff --git a/python/mlc_chat/quantization/__init__.py b/python/mlc_chat/quantization/__init__.py new file mode 100644 index 0000000..31016a9 --- /dev/null +++ b/python/mlc_chat/quantization/__init__.py @@ -0,0 +1,6 @@ +"""A subpackage for quantization and dequantization algorithms""" +from .awq_quantization import AWQQuantize +from .ft_quantization import FTQuantize +from .group_quantization import GroupQuantize +from .no_quantization import NoQuantize +from .quantization import QUANTIZATION, Quantization diff --git a/python/mlc_chat/quantization/awq_quantization.py b/python/mlc_chat/quantization/awq_quantization.py new file mode 100644 index 0000000..116582f --- /dev/null +++ b/python/mlc_chat/quantization/awq_quantization.py @@ -0,0 +1,271 @@ +"""AWQ Quantization""" + +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional + +from tvm import DataType, DataTypeCode, te, tir, topi +from tvm.relax.frontend import nn +from tvm.runtime import NDArray + +from mlc_chat.loader import QuantizeMapping + +from .utils import convert_uint_to_float, is_final_fc + + +def _make_divisible(c, divisor): # pylint: disable=invalid-name + return (c + divisor - 1) // divisor + + +def _calculate_zeros_width(in_features, group_size=128, pack_num=8): + if group_size >= 128: + size_multiplier = 1 + elif group_size == 64: + size_multiplier = 2 + elif group_size == 32: + size_multiplier = 4 + else: + raise NotImplementedError + + base_width = _make_divisible(in_features // group_size, pack_num) + base_width = _make_divisible(base_width, size_multiplier) * size_multiplier + return base_width + + +@dataclass +class AWQQuantize: # pylint: disable=too-many-instance-attributes + """Configuration for AWQ quantization""" + + name: str + kind: str + group_size: int + quantize_dtype: str # "int3", "int4", "int8" + storage_dtype: str # "uint32" + model_dtype: str # "float16", "float32" + + num_elem_per_storage: int = 0 + num_storage_per_group: int = 0 + max_int_value: int = 0 + + prebuilt_quantize_func: Dict[str, Callable[[NDArray], NDArray]] = field( + default_factory=lambda: {} + ) + + def __post_init__(self): + assert self.kind == "awq" + quantize_dtype = DataType(self.quantize_dtype) + storage_dtype = DataType(self.storage_dtype) + model_dtype = DataType(self.model_dtype) + assert quantize_dtype.type_code == DataTypeCode.INT + assert storage_dtype.type_code == DataTypeCode.UINT + assert 
model_dtype.type_code == DataTypeCode.FLOAT + if storage_dtype.bits < quantize_dtype.bits: + raise ValueError("Storage unit should be greater or equal to quantized element") + + self.num_elem_per_storage = storage_dtype.bits // quantize_dtype.bits + if self.group_size % self.num_elem_per_storage != 0: + raise ValueError("Group size should be divisible by numbers of elements per storage") + self.num_storage_per_group = self.group_size // self.num_elem_per_storage + self.max_int_value = (2 ** (quantize_dtype.bits - 1)) - 1 + + def quantize_model( + self, + model: nn.Module, + quant_map: QuantizeMapping, + name_prefix: str, + ) -> nn.Module: + """ + Quantize model with awq quantization. + + Parameters + ---------- + model : nn.Module + The non-quantized nn.Module. + + quant_map : QuantizeMapping + The quantize mapping with name mapping and func mapping. + + name_prefix : str + The name prefix for visited weight. + + Returns + ------- + ret : nn.Module + The quantized nn.Module. + """ + + class _Mutator(nn.Mutator): + def __init__(self, config: AWQQuantize, quant_map: QuantizeMapping) -> None: + super().__init__() + self.config = config + self.quant_map = quant_map + + def visit_module(self, name: str, node: nn.Module) -> Any: + """ + The visiting method for awq quantization of nn.Module nodes. + + Parameters + ---------- + name : str + The name of the current node + + node : nn.Module + The current node of nn.Module to mutate. + + Returns + ------- + ret_node : Any + The new node to replace current node. + """ + + if isinstance(node, nn.Linear) and not is_final_fc(name): + return AWQQuantizeLinear.from_linear(node, self.config) + return self.visit(name, node) + + model.to(dtype=self.model_dtype) + mutator = _Mutator(self, quant_map) + model = mutator.visit(name_prefix, model) + return model + + def _dequantize( + self, + weight: te.Tensor, + zeros: te.Tensor, + scale: te.Tensor, + out_shape: Optional[List[tir.PrimExpr]] = None, + ): + float_weight = convert_uint_to_float( + weight, + DataType(self.quantize_dtype).bits, + self.num_elem_per_storage, + self.storage_dtype, + self.model_dtype, + out_shape=[weight.shape[0], weight.shape[1] * self.num_elem_per_storage], + ft_reorder=True, + ) + float_zeros = convert_uint_to_float( + zeros, + DataType(self.quantize_dtype).bits, + self.num_elem_per_storage, + self.storage_dtype, + self.model_dtype, + out_shape=[zeros.shape[0], zeros.shape[1] * self.num_elem_per_storage], + ft_reorder=True, + ) + float_weight = topi.transpose(float_weight) + float_zeros = topi.transpose(float_zeros) + scale = topi.transpose(scale) + return te.compute( + shape=( + [weight.shape[0], weight.shape[1] * self.num_elem_per_storage] + if out_shape is None + else out_shape + ), + fcompute=lambda i, j: tir.multiply( + tir.subtract(float_weight[i, j], float_zeros[i, j // self.group_size]), + scale[i, j // self.group_size], + ), + name="dequantize", + ) + + +class AWQQuantizeLinear(nn.Module): # pylint: disable=too-many-instance-attributes + """An nn.Linear module with AWQ quantization""" + + def __init__( # pylint: disable=too-many-arguments + self, + in_features: int, + out_features: int, + config: AWQQuantize, + bias: bool = True, + out_dtype: Optional[str] = None, + ) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.out_dtype = out_dtype + self.config = config + self.qweight = nn.Parameter( + (in_features, out_features // config.num_elem_per_storage), config.storage_dtype + ) + self.qzeros = nn.Parameter( + (in_features 
// config.group_size, out_features // config.num_elem_per_storage), + config.storage_dtype, + ) + self.scales = nn.Parameter( + (in_features // config.group_size, out_features), config.model_dtype + ) + if bias: + self.bias = nn.Parameter( + (out_features,), config.model_dtype if out_dtype is None else out_dtype + ) + else: + self.bias = None + + @staticmethod + def from_linear(linear: nn.Linear, config: AWQQuantize) -> "AWQQuantizeLinear": + """ + Converts a non-quantized nn.Linear to a group quantized AWQQuantizeLinear + + Parameters + ---------- + linear : nn.Linear + The non-quantized nn.Linear. + + config : AWQQuantize + The awq quantization config. + + Returns + ------- + ret : GroupQuantizeLinear + The awq quantized AWQQuantizeLinear layer. + """ + return AWQQuantizeLinear( + in_features=linear.in_features, + out_features=linear.out_features, + config=config, + bias=getattr(linear, "bias", None) is not None, + out_dtype=linear.out_dtype, + ) + + def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name + """ + Forward method for awq quantized linear layer + + Parameters + ---------- + x : nn.Tensor + The input tensor. + + Returns + ------- + ret : nn.Tensor + The output tensor for the group quantized linear layer. + """ + w = nn.op.tensor_expr_op( # pylint: disable=invalid-name + lambda weight, zeros, scale: self.config._dequantize( # pylint: disable=protected-access + weight, + zeros, + scale, + [tir.IntImm("int64", self.out_features), tir.IntImm("int64", self.in_features)], + ), + name_hint="dequantize", + args=[self.qweight, self.qzeros, self.scales], + ) + w = nn.op.permute_dims(w) # pylint: disable=invalid-name + x = nn.op.matmul(x, w, out_dtype=self.out_dtype) + if self.bias is not None: + x = x + self.bias + return x + + def to(self, dtype: Optional[str] = None) -> None: + """ + Override to() such that we do not convert bias if there is an out_dtype. + Otherwise, we might run into dtype mismatch when computing x + self.bias. 
+ """ + self.qweight.to(dtype=dtype) + self.qzeros.to(dtype=dtype) + self.scales.to(dtype=dtype) + if self.bias is not None and self.out_dtype is None: + self.bias.to(dtype=dtype) + if dtype is not None and isinstance(getattr(self, "dtype", None), str): + self.dtype = dtype # pylint: disable=attribute-defined-outside-init diff --git a/python/mlc_chat/quantization/ft_quantization.py b/python/mlc_chat/quantization/ft_quantization.py new file mode 100644 index 0000000..c30e85b --- /dev/null +++ b/python/mlc_chat/quantization/ft_quantization.py @@ -0,0 +1,396 @@ +"""The FasterTransformer quantization config""" + +from dataclasses import dataclass +from typing import Any, Callable, List, Literal, Optional, Tuple + +import tvm +from tvm import DataType, DataTypeCode, IRModule +from tvm import dlight as dl +from tvm import relax, te, tir +from tvm.relax.frontend import nn +from tvm.runtime import NDArray +from tvm.target import Target + +from ..loader import QuantizeMapping +from ..op import faster_transformer_dequantize_gemm +from ..support import logging +from ..support.auto_target import detect_cuda_arch_list +from ..support.style import bold +from .group_quantization import ( + GroupQuantize, + GroupQuantizeEmbedding, + GroupQuantizeLinear, +) +from .utils import is_final_fc + +logger = logging.getLogger(__name__) + + +@dataclass +class FTQuantize: # pylint: disable=too-many-instance-attributes + """Configuration for FasterTransformer quantization""" + + name: str + kind: str + quantize_dtype: Literal["int4", "int8"] + storage_dtype: Literal["int8"] + model_dtype: Literal["float16"] + group_size: Optional[int] = None + + num_elem_per_storage: int = 0 + max_int_value: int = 0 + + def fallback_group_quantize(self) -> GroupQuantize: + """ + The fallback group quantization config for other parameters. + + Returns + ------ + quantize: GroupQuantize + The group quantization config to fallback. + """ + return GroupQuantize( + name=self.name, + kind="group-quant", + group_size=32, # hardcoded to 32 as only supporting int4 quantization + quantize_dtype=self.quantize_dtype, + storage_dtype="uint32", + model_dtype=self.model_dtype, + linear_weight_layout="NK", + ) + + def __post_init__(self): + assert self.kind == "ft-quant" + quantize_dtype = DataType(self.quantize_dtype) + storage_dtype = DataType(self.storage_dtype) + assert self.quantize_dtype in ["int4", "int8"] + assert storage_dtype.type_code == DataTypeCode.INT + assert self.model_dtype == "float16" + assert self.group_size in [None, 64, 128] + if storage_dtype.bits < quantize_dtype.bits: + raise ValueError("Storage unit should be greater or equal to quantized element") + + self.num_elem_per_storage = storage_dtype.bits // quantize_dtype.bits + self.max_int_value = (2 ** (quantize_dtype.bits - 1)) - 1 + self._quantize_func_cache = {} + + def quantize_model( + self, + model: nn.Module, + quant_map: QuantizeMapping, + name_prefix: str, + ) -> nn.Module: + """ + Quantize model with FasterTransformer quantization + + Parameters + ---------- + model : nn.Module + The non-quantized nn.Module. + + quant_map : QuantizeMapping + The quantize mapping with name mapping and func mapping. + + name_prefix : str + The name prefix for visited weight. + + Returns + ------- + ret : nn.Module + The quantized nn.Module. 
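+
+        A minimal sketch of the expected call pattern, assuming ``config`` is an
+        ``FTQuantize`` instance and ``quant_map`` a freshly constructed
+        ``QuantizeMapping``::
+
+            quantized_model = config.quantize_model(model, quant_map, name_prefix="model")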
+ """ + + class _Mutator(nn.Mutator): + def __init__(self, config: FTQuantize, quant_map: QuantizeMapping) -> None: + super().__init__() + self.config = config + self.quant_map = quant_map + + def visit_module(self, name: str, node: nn.Module) -> Any: + """ + The visiting method for FasterTransformer quantization of nn.Module nodes. + + Parameters + ---------- + name : str + The name of the current node. + + node : nn.Module + The current node of nn.Module to mutate. + + Returns + ------ + ret_node: Any + The new node to replace current node. + """ + if isinstance(node, nn.Linear): + weight_name = f"{name}.weight" + self.quant_map.param_map[weight_name] = [f"{name}.q_weight", f"{name}.q_scale"] + if ( + # pylint: disable=too-many-boolean-expressions + is_final_fc(name) + or node.out_dtype == "float32" + or (self.config.quantize_dtype == "int4" and node.out_features % 8 != 0) + or (self.config.quantize_dtype == "int8" and node.out_features % 4 != 0) + ): + # Under any of the conditions we fall back to GroupQuantize + # For `is_final_fc()` see https://github.com/mlc-ai/mlc-llm/issues/1723 + # If simply skipping lm_head quantization degrades performance + # Other requirements are from CUTLASS + logger.info( + 'Fallback to GroupQuantize for nn.Linear: "%s", ' + + "weight.shape: %s, out_dtype: %s", + bold(name), + node.weight.shape, + node.out_dtype, + ) + group_quantize = self.config.fallback_group_quantize() + self.quant_map.map_func[weight_name] = group_quantize.quantize_weight + return GroupQuantizeLinear.from_linear(node, group_quantize) + self.quant_map.map_func[weight_name] = self.config.quantize_weight + return FTQuantizeLinear.from_linear(node, self.config) + if isinstance(node, nn.Embedding): + weight_name = f"{name}.weight" + self.quant_map.param_map[weight_name] = [f"{name}.q_weight", f"{name}.q_scale"] + group_quantize = self.config.fallback_group_quantize() + self.quant_map.map_func[weight_name] = group_quantize.quantize_weight + return GroupQuantizeEmbedding.from_embedding(node, group_quantize) + return self.visit(name, node) + + model.to(dtype=self.model_dtype) + mutator = _Mutator(self, quant_map) + model = mutator.visit(name_prefix, model) + return model + + def quantize_weight(self, weight: NDArray) -> List[NDArray]: + """ + Quantize weight with FasterTransformer quantization + + Parameters + ---------- + weight : NDArray + The original weight. + + Returns + ------ + ret: List[NDArray] + The list of FasterTransformer quantized weights. + """ + assert tvm.get_global_func("relax.ext.cutlass", True), ( + "Cutlass should be enabled in TVM runtime to quantize weight, " + "but not enabled in current TVM runtime environment. 
" + "To enable Cutlass in TVM runtime, please `set(USE_CUTLASS ON)` " + "in config.cmake when compiling TVM from source" + ) + assert len(weight.shape) == 2 + device = weight.device + device_type = device.MASK2STR[device.device_type] + if device_type == "cuda": + target = Target.current() + if target is None: + target = Target.from_device(device) + with target: + + def _create_quantize_func() -> IRModule: + bb = relax.BlockBuilder() # pylint: disable=invalid-name + weight_var = relax.Var( + "weight", relax.TensorStructInfo(weight.shape, weight.dtype) + ) + with bb.function(name="main", params=[weight_var]): + with bb.dataflow(): + lv0 = bb.emit_te( + self._quantize, weight_var + ) # pylint: disable=invalid-name + lv1 = bb.normalize(lv0[0]) + lv2 = bb.emit( + relax.call_pure_packed( + "cutlass.ft_preprocess_weight", + lv1, + detect_cuda_arch_list(target=target)[0], + DataType(self.quantize_dtype).bits == 4, + sinfo_args=lv1.struct_info, + ) + ) + gv = bb.emit_output( + relax.Tuple([lv2, lv0[1]]) + ) # pylint: disable=invalid-name + bb.emit_func_output(gv) + return bb.finalize() + + def _compile_quantize_func(mod: IRModule) -> Callable: + mod = dl.ApplyDefaultSchedule( # type: ignore # pylint: disable=not-callable + dl.gpu.Reduction(), + dl.gpu.GeneralReduction(), + dl.gpu.Fallback(), + )(mod) + ex = relax.build(mod, target=target) + vm = relax.VirtualMachine(ex, device) # pylint: disable=invalid-name + return vm["main"] + + key = str((int(weight.shape[0]), int(weight.shape[1]), weight.dtype, device_type)) + quantize_func = self._quantize_func_cache.get(key, None) + if quantize_func is None: + logger.info("Compiling quantize function for key: %s", key) + quantize_func = _compile_quantize_func(_create_quantize_func()) + self._quantize_func_cache[key] = quantize_func + data = quantize_func(weight) + return data + else: + raise NotImplementedError(f"Device type {device_type} is not supported") + + def _quantize( # pylint: disable=too-many-locals + self, + weight: te.Tensor, + ) -> Tuple[te.Tensor, te.Tensor]: + """FasterTransformer quantization for weight tensor, defined in tensor expression.""" + assert len(weight.shape) == 2 + n, k = weight.shape + + cur_group_size = k if not self.group_size else self.group_size + scale_shape = (tir.ceildiv(k, cur_group_size), n) + r = te.reduce_axis((0, cur_group_size), name="r") + + max_abs = te.compute( + shape=scale_shape, + fcompute=lambda j, i: te.max( + tir.if_then_else( + j * cur_group_size + r < k, + te.abs(weight[i, j * cur_group_size + r]), + te.min_value(self.model_dtype), + ), + axis=r, + ), + name="max_abs_value", + ) + max_int = tir.const(self.max_int_value, self.model_dtype) + scale = te.compute( + scale_shape, + lambda i, j: max_abs[i, j].astype(self.model_dtype) / max_int, + name="scale", + ) + # compute scaled weight + quantize_dtype = DataType(self.quantize_dtype) + bin_mask = tir.const((1 << quantize_dtype.bits) - 1, self.storage_dtype) + scaled_weight = te.compute( + shape=weight.shape, + fcompute=lambda i, j: tir.min( + tir.max( + tir.round(weight[i, j] / scale[j // cur_group_size, i]), + -max_int - 1, + ), + max_int, + ).astype(self.storage_dtype) + & bin_mask, + ) + + quantized_weight_shape = (k, tir.ceildiv(n, self.num_elem_per_storage)) + r = te.reduce_axis((0, self.num_elem_per_storage), name="r") # pylint: disable=invalid-name + quantized_weight = te.compute( + shape=quantized_weight_shape, + fcompute=lambda j, i: tir.sum( + tir.if_then_else( + i * self.num_elem_per_storage + r < n, + scaled_weight[i * self.num_elem_per_storage + 
r, j] + << ( + r.astype(self.storage_dtype) + * tir.const(quantize_dtype.bits, self.storage_dtype) + ), + tir.const(0, self.storage_dtype), + ), + axis=r, + ), + name="weight", + ) + + return quantized_weight, scale + + +class FTQuantizeLinear(nn.Module): # pylint: disable=too-many-instance-attributes + """An nn.Linear module with FasterTransformer quantization""" + + def __init__( # pylint: disable=too-many-arguments + self, + in_features: int, + out_features: int, + config: FTQuantize, + bias: bool = True, + out_dtype: Optional[str] = None, + ) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.out_dtype = out_dtype + self.config = config + cur_group_size = in_features if not config.group_size else config.group_size + self.q_weight = nn.Parameter( + (in_features, tir.ceildiv(out_features, config.num_elem_per_storage)), + config.storage_dtype, + ) + self.q_scale = nn.Parameter( + (tir.ceildiv(in_features, cur_group_size), out_features), config.model_dtype + ) + if bias: + self.bias = nn.Parameter( + (out_features,), config.model_dtype if out_dtype is None else out_dtype + ) + else: + self.bias = None + + @staticmethod + def from_linear(src: nn.Linear, config: FTQuantize) -> "FTQuantizeLinear": + """ + Converts a non-quantized nn.Linear to a FasterTransformer quantized FTQuantizeLinear + + Parameters + ---------- + src : nn.Linear + The non-quantized nn.Linear. + + config : FTQuantize + The FasterTransformer quantization config. + + Returns + ------- + ret : FTQuantizeLinear + The FasterTransformer quantized FTQuantizeLinear layer. + """ + quantized_linear = FTQuantizeLinear( + in_features=src.in_features, + out_features=src.out_features, + config=config, + bias=getattr(src, "bias", None) is not None, + out_dtype=src.out_dtype, + ) + if quantized_linear.bias is not None: + quantized_linear.bias.attrs = src.bias.attrs + return quantized_linear + + def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name + """ + Forward method for FasterTransformer quantized linear layer. + + Parameters + ---------- + x : nn.Tensor + The input tensor. + + Returns + ------- + ret : nn.Tensor + The output tensor for the FasterTransformer quantized linear layer. + """ + return faster_transformer_dequantize_gemm( + x, self.q_weight, self.q_scale, self.bias, group_size=self.config.group_size + ) + + def to(self, dtype: Optional[str] = None) -> None: + """ + Override to() such that we do not convert bias if there is an out_dtype. + Otherwise, we might run into dtype mismatch when computing x + self.bias. 
+ """ + self.q_weight.to(dtype=dtype) + self.q_scale.to(dtype=dtype) + if self.bias is not None and self.out_dtype is None: + self.bias.to(dtype=dtype) + if dtype is not None and isinstance(getattr(self, "dtype", None), str): + self.dtype = dtype # pylint: disable=attribute-defined-outside-init diff --git a/python/mlc_chat/quantization/group_quantization.py b/python/mlc_chat/quantization/group_quantization.py new file mode 100644 index 0000000..baf8662 --- /dev/null +++ b/python/mlc_chat/quantization/group_quantization.py @@ -0,0 +1,664 @@ +"""The group quantization config""" + +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable, List, Literal, Optional, Tuple, Union + +from tvm import DataType, DataTypeCode, IRModule +from tvm import dlight as dl +from tvm import relax, te, tir, topi +from tvm.relax.frontend import nn +from tvm.runtime import NDArray +from tvm.target import Target + +from mlc_chat.loader import QuantizeMapping +from mlc_chat.nn import MixtralExperts +from mlc_chat.support import logging +from mlc_chat.support import tensor_parallel as tp + +from .utils import convert_uint_to_float, is_final_fc + +logger = logging.getLogger(__name__) + + +@dataclass +class GroupQuantize: # pylint: disable=too-many-instance-attributes + """Configuration for group quantization""" + + name: str + kind: str + group_size: int + quantize_dtype: Literal["int3", "int4", "int8"] + storage_dtype: Literal["uint32"] + model_dtype: Literal["float16", "float32"] + linear_weight_layout: Literal["KN", "NK"] + quantize_embedding: bool = True + quantize_final_fc: bool = True + + num_elem_per_storage: int = 0 + num_storage_per_group: int = 0 + max_int_value: int = 0 + + def __post_init__(self): + assert self.kind == "group-quant" + quantize_dtype = DataType(self.quantize_dtype) + storage_dtype = DataType(self.storage_dtype) + model_dtype = DataType(self.model_dtype) + assert quantize_dtype.type_code == DataTypeCode.INT + assert storage_dtype.type_code == DataTypeCode.UINT + assert model_dtype.type_code == DataTypeCode.FLOAT + if storage_dtype.bits < quantize_dtype.bits: + raise ValueError("Storage unit should be greater or equal to quantized element") + + self.num_elem_per_storage = storage_dtype.bits // quantize_dtype.bits + if self.group_size % self.num_elem_per_storage != 0: + raise ValueError("Group size should be divisible by numbers of elements per storage") + self.num_storage_per_group = self.group_size // self.num_elem_per_storage + self.max_int_value = (2 ** (quantize_dtype.bits - 1)) - 1 + self.linear_quant_axis = 0 if self.linear_weight_layout == "KN" else 1 + self._quantize_func_cache = {} + + def quantize_model( + self, + model: nn.Module, + quant_map: QuantizeMapping, + name_prefix: str, + ) -> nn.Module: + """ + Quantize model with group quantization + + Parameters + ---------- + model : nn.Module + The non-quantized nn.Module. + + quant_map : QuantizeMapping + The quantize mapping with name mapping and func mapping. + + name_prefix : str + The name prefix for visited weight. + + Returns + ------- + ret : nn.Module + The quantized nn.Module. + """ + + class _Mutator(nn.Mutator): + def __init__(self, config: GroupQuantize, quant_map: QuantizeMapping) -> None: + super().__init__() + self.config = config + self.quant_map = quant_map + + def visit_module(self, name: str, node: nn.Module) -> Any: + """ + The visiting method for group quantization of nn.Module nodes. + + Parameters + ---------- + name : str + The name of the current node. 
+ + node : nn.Module + The current node of nn.Module to mutate. + + Returns + ------ + ret_node: Any + The new node to replace current node. + """ + if isinstance(node, nn.Linear) and ( + not is_final_fc(name) or self.config.quantize_final_fc + ): + weight_name = f"{name}.weight" + self.quant_map.param_map[weight_name] = [f"{name}.q_weight", f"{name}.q_scale"] + self.quant_map.map_func[weight_name] = partial( + self.config.quantize_weight, + output_transpose=self.config.linear_weight_layout == "KN", + ) + return GroupQuantizeLinear.from_linear(node, self.config) + if isinstance(node, nn.Embedding) and self.config.quantize_embedding: + weight_name = f"{name}.weight" + self.quant_map.param_map[weight_name] = [f"{name}.q_weight", f"{name}.q_scale"] + self.quant_map.map_func[weight_name] = self.config.quantize_weight + return GroupQuantizeEmbedding.from_embedding(node, self.config) + if isinstance(node, MixtralExperts): + weight_name = f"{name}.weight" + self.quant_map.param_map[weight_name] = [f"{name}.q_weight", f"{name}.q_scale"] + self.quant_map.map_func[weight_name] = self.config.quantize_weight + return GroupQuantizeMixtralExperts.from_mixtral_experts(node, self.config) + return self.visit(name, node) + + model.to(dtype=self.model_dtype) + mutator = _Mutator(self, quant_map) + model = mutator.visit(name_prefix, model) + return model + + def _dequantize( + self, + weight: te.Tensor, + scale: te.Tensor, + axis: int, + out_shape: Optional[List[tir.PrimExpr]] = None, + ): + tir_max_int = tir.const(self.max_int_value, self.model_dtype) + float_weight = convert_uint_to_float( + weight, + DataType(self.quantize_dtype).bits, + self.num_elem_per_storage, + self.storage_dtype, + self.model_dtype, + axis=axis, + out_shape=out_shape, + ) + if out_shape is None: + out_shape = weight.shape + out_shape[axis] *= self.num_elem_per_storage + axis = axis if axis >= 0 else len(out_shape) + axis + return te.compute( + shape=out_shape, + fcompute=lambda *idx: tir.multiply( + tir.subtract( + float_weight(*idx), + tir_max_int, + ), + scale(*idx[:axis], idx[axis] // self.group_size, *idx[axis + 1 :]), + ), + name="dequantize", + ) + + def quantize_weight( + self, weight: NDArray, axis: int = -1, output_transpose: bool = False + ) -> List[NDArray]: + """ + Quantize weight with group quantization + + Parameters + ---------- + weight : NDArray + The original weight. + + axis : int + The group axis. + + output_transpose : bool + Whether to transpose the output quantized weight. Only 2D weight is supported. + + Returns + ------ + ret: List[NDArray] + The list of group quantized weights. 
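+
+        A minimal sketch of the expected call pattern, assuming ``config`` is this
+        ``GroupQuantize`` instance and ``weight`` an ``NDArray``::
+
+            q_weight, q_scale = config.quantize_weight(weight, axis=-1)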
+ """ + device = weight.device + device_type = device.MASK2STR[device.device_type] + axis = axis if axis >= 0 else len(weight.shape) + axis + + def _create_quantize_func() -> IRModule: + bb = relax.BlockBuilder() # pylint: disable=invalid-name + weight_var = relax.Var("weight", relax.TensorStructInfo(weight.shape, weight.dtype)) + with bb.function(name="main", params=[weight_var]): + with bb.dataflow(): + lv = bb.emit_te(self._quantize, weight_var, axis, output_transpose) + gv = bb.emit_output(lv) # pylint: disable=invalid-name + bb.emit_func_output(gv) + return bb.finalize() + + def _compile_quantize_func(mod: IRModule) -> Callable: + if device_type in ["cuda", "rocm", "metal", "vulkan"]: + target = Target.current() + if target is None: + target = Target.from_device(device) + with target: + mod = dl.ApplyDefaultSchedule( # type: ignore # pylint: disable=not-callable + dl.gpu.Reduction(), + dl.gpu.GeneralReduction(), + dl.gpu.Fallback(), + )(mod) + elif device_type == "cpu": + target = "llvm" + mod = relax.transform.LegalizeOps()(mod) + else: + raise NotImplementedError(f"Device type {device_type} is not supported") + ex = relax.build(mod, target=target) + vm = relax.VirtualMachine(ex, device) # pylint: disable=invalid-name + return vm["main"] + + key = ( + f"({weight.shape}, {weight.dtype}, {device_type}, " + f"axis={axis}, output_transpose={output_transpose})" + ) + quantize_func = self._quantize_func_cache.get(key, None) + if quantize_func is None: + logger.info("Compiling quantize function for key: %s", key) + quantize_func = _compile_quantize_func(_create_quantize_func()) + self._quantize_func_cache[key] = quantize_func + return quantize_func(weight) + + def _quantize( # pylint: disable=too-many-locals + self, + weight: te.Tensor, + axis: int = -1, + output_transpose: bool = False, + ) -> Tuple[te.Tensor, te.Tensor]: + """Group quantization for weight tensor, defined in tensor expression.""" + max_int = tir.const(self.max_int_value, self.model_dtype) + shape = weight.shape # pylint: disable=invalid-name + axis = axis if axis >= 0 else len(shape) + axis + k = shape[axis] + quantize_dtype = DataType(self.quantize_dtype) + # compute scale per group + r = te.reduce_axis((0, self.group_size), name="r") # pylint: disable=invalid-name + num_group = tir.ceildiv(k, self.group_size) + scale_shape = (*shape[:axis], num_group, *shape[axis + 1 :]) + max_abs = te.compute( + shape=scale_shape, + fcompute=lambda *idx: te.max( + tir.if_then_else( + idx[axis] * self.group_size + r < k, + te.abs(weight(*idx[:axis], idx[axis] * self.group_size + r, *idx[axis + 1 :])), + te.min_value(self.model_dtype), + ), + axis=r, + ), + name="max_abs_value", + ) + scale = te.compute( + scale_shape, + lambda *idx: max_abs(*idx).astype(self.model_dtype) / max_int, + name="scale", + ) + # compute scaled weight + scaled_weight = te.compute( + shape=weight.shape, + fcompute=lambda *idx: tir.min( + tir.max( + tir.round( + weight(*idx) + / scale(*idx[:axis], idx[axis] // self.group_size, *idx[axis + 1 :]) + + max_int + ), + tir.const(0, self.model_dtype), + ), + max_int * 2, + ).astype(self.storage_dtype), + ) + # compute quantized weight per storage + r = te.reduce_axis((0, self.num_elem_per_storage), name="r") # pylint: disable=invalid-name + num_storage = self.num_storage_per_group * num_group + quantized_weight_shape = (*shape[:axis], num_storage, *shape[axis + 1 :]) + quantized_weight = te.compute( + shape=quantized_weight_shape, + fcompute=lambda *idx: tir.sum( + tir.if_then_else( + idx[axis] * 
self.num_elem_per_storage + r < k, + scaled_weight( + *idx[:axis], idx[axis] * self.num_elem_per_storage + r, *idx[axis + 1 :] + ) + << (r * quantize_dtype.bits), + 0, + ), + axis=r, + ), + name="weight", + ) + if output_transpose: + if len(quantized_weight.shape) != 2 or len(scale.shape) != 2: + raise ValueError( + "Does not support transpose output quantized weight with ndim != 2" + ) + quantized_weight = topi.transpose(quantized_weight) + scale = topi.transpose(scale) + return quantized_weight, scale + + +class GroupQuantizeLinear(nn.Module): # pylint: disable=too-many-instance-attributes + """An nn.Linear module with group quantization""" + + def __init__( # pylint: disable=too-many-arguments + self, + in_features: int, + out_features: Union[int, tir.Var], + config: GroupQuantize, + bias: bool = True, + out_dtype: Optional[str] = None, + ) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.out_dtype = out_dtype + self.config = config + num_group = tir.ceildiv(in_features, config.group_size) + if config.linear_weight_layout == "KN": + self.q_weight = nn.Parameter( + (config.num_storage_per_group * num_group, out_features), config.storage_dtype + ) + self.q_scale = nn.Parameter((num_group, out_features), config.model_dtype) + else: + self.q_weight = nn.Parameter( + (out_features, config.num_storage_per_group * num_group), config.storage_dtype + ) + self.q_scale = nn.Parameter((out_features, num_group), config.model_dtype) + if bias: + self.bias = nn.Parameter( + (out_features,), config.model_dtype if out_dtype is None else out_dtype + ) + else: + self.bias = None + + @staticmethod + def from_linear(src: nn.Linear, config: GroupQuantize) -> "GroupQuantizeLinear": + """ + Converts a non-quantized nn.Linear to a group quantized GroupQuantizeLinear + + Parameters + ---------- + src : nn.Linear + The non-quantized nn.Linear. + + config : GroupQuantize + The group quantization config. + + Returns + ------- + ret : GroupQuantizeLinear + The group quantized GroupQuantizeLinear layer. + """ + # For dynamic shape, src.out_features is `"name"`; src.weight.shape[0] is `tir.Var("name")` + out_features, in_features = src.weight.shape + quantized_linear = GroupQuantizeLinear( + in_features=in_features, + out_features=out_features, + config=config, + bias=getattr(src, "bias", None) is not None, + out_dtype=src.out_dtype, + ) + if quantized_linear.bias is not None: + quantized_linear.bias.attrs = src.bias.attrs + if "shard_strategy" in src.weight.attrs: + shard = src.weight.attrs["shard_strategy"] + _apply_sharding(shard, f"{shard.name}_q_weight", quantized_linear.q_weight) + _apply_sharding(shard, f"{shard.name}_q_scale", quantized_linear.q_scale) + return quantized_linear + + def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name + """ + Forward method for group quantized linear layer. + + Parameters + ---------- + x : nn.Tensor + The input tensor. + + Returns + ------- + ret : nn.Tensor + The output tensor for the group quantized linear layer. 
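+
+        Note that ``q_weight`` and ``q_scale`` are dequantized on the fly into a
+        weight in ``model_dtype``, which is then used in a regular matmul; the bias,
+        if present, is added afterwards.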
+ """ + w = nn.op.tensor_expr_op( # pylint: disable=invalid-name + lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access + weight, + scale, + axis=self.config.linear_quant_axis, + out_shape=( + [ + ( + tir.IntImm("int64", self.out_features) + if isinstance(self.out_features, int) + else weight.shape[0] + ), # Reuse same tir.Var for symbolic shape (after Exporter) + tir.IntImm("int64", self.in_features), + ] + if self.config.linear_weight_layout == "NK" + else [ + tir.IntImm("int64", self.in_features), + ( + tir.IntImm("int64", self.out_features) + if isinstance(self.out_features, int) + else weight.shape[1] + ), # Reuse same tir.Var for symbolic shape (after Exporter) + ] + ), + ), + name_hint="dequantize", + args=[self.q_weight, self.q_scale], + ) + if self.config.linear_weight_layout == "NK": + w = nn.op.permute_dims(w) # pylint: disable=invalid-name + x = nn.op.matmul(x, w, out_dtype=self.out_dtype) + if self.bias is not None: + x = x + self.bias + return x + + def to(self, dtype: Optional[str] = None) -> None: + """ + Override to() such that we do not convert bias if there is an out_dtype. + Otherwise, we might run into dtype mismatch when computing x + self.bias. + """ + self.q_weight.to(dtype=dtype) + self.q_scale.to(dtype=dtype) + if self.bias is not None and self.out_dtype is None: + self.bias.to(dtype=dtype) + if dtype is not None and isinstance(getattr(self, "dtype", None), str): + self.dtype = dtype # pylint: disable=attribute-defined-outside-init + + +class GroupQuantizeEmbedding(nn.Module): + """An nn.Embedding module with group quantization""" + + def __init__(self, num: Union[int, tir.Var], dim: int, config: GroupQuantize): + self.num = num + self.dim = dim + self.config = config + num_group = tir.ceildiv(dim, config.group_size) + self.q_weight = nn.Parameter( + (num, config.num_storage_per_group * num_group), config.storage_dtype + ) + self.q_scale = nn.Parameter((num, num_group), config.model_dtype) + + @staticmethod + def from_embedding(embedding: nn.Embedding, config: GroupQuantize) -> "GroupQuantizeEmbedding": + """ + Converts a non-quantized nn.Embedding to a group quantized GroupQuantizeEmbedding + + Parameters + ---------- + linear : nn.Embedding + The non-quantized nn.Embedding. + + config : GroupQuantize + The group quantization config. + + Returns + ------- + ret : GroupQuantizeEmbedding + The group quantized GroupQuantizeEmbedding layer. + """ + num, dim = embedding.weight.shape + return GroupQuantizeEmbedding(num, dim, config) + + def forward(self, x: nn.Tensor): # pylint: disable=invalid-name + """ + Forward method for group quantized embedding layer. + + Parameters + ---------- + x : nn.Tensor + The input tensor. + + Returns + ------- + ret : nn.Tensor + The output tensor for the embedding layer. 
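+
+        Note that the embedding weight is dequantized on the fly and the rows are
+        gathered with ``take``; inputs with more than one dimension are flattened
+        first and the result is reshaped back to ``[*x.shape, dim]``.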
+ """ + w = nn.op.tensor_expr_op( # pylint: disable=invalid-name + lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access + weight, + scale, + axis=-1, + out_shape=[ + ( + tir.IntImm("int64", self.num) + if isinstance(self.num, int) + else weight.shape[0] + ), # Reuse same tir.Var for symbolic shape (after Exporter) + tir.IntImm("int64", self.dim), + ], + ), + name_hint="dequantize", + args=[self.q_weight, self.q_scale], + ) + if x.ndim == 1: + return nn.op.take(w, x, axis=0) + return nn.op.reshape( + nn.op.take(w, nn.op.reshape(x, shape=[-1]), axis=0), + shape=[*x.shape, self.dim], + ) + + def lm_head_forward(self, x: nn.Tensor): + """The lm_head forwarding, which dequantizes the weight + and multiplies it with the input tensor. + + Parameters + ---------- + x : nn.Tensor + The input tensor. + + Returns + ------- + ret : nn.Tensor + The output tensor for the lm_head layer. + """ + w = nn.op.tensor_expr_op( # pylint: disable=invalid-name + lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access + weight, + scale, + axis=-1, + out_shape=[ + ( + tir.IntImm("int64", self.num) + if isinstance(self.num, int) + else weight.shape[0] + ), + tir.IntImm("int64", self.dim), + ], + ), + name_hint="dequantize", + args=[self.q_weight, self.q_scale], + ) + w = nn.op.permute_dims(w) + return nn.op.matmul(x, w, out_dtype="float32") + + +class GroupQuantizeMixtralExperts(nn.Module): # pylint: disable=too-many-instance-attributes + """An MixtralExperts module with group quantization""" + + def __init__( + self, + num_local_experts, + in_features, + out_features, + config: GroupQuantize, + ): # pylint: disable=too-many-arguments + self.num_local_experts = num_local_experts + self.in_features = in_features + self.out_features = out_features + self.config = config + num_group = tir.ceildiv(in_features, config.group_size) + self.q_weight = nn.Parameter( + (num_local_experts, out_features, config.num_storage_per_group * num_group), + config.storage_dtype, + ) + self.q_scale = nn.Parameter( + (num_local_experts, out_features, num_group), config.model_dtype + ) + self.quantize_dtype = config.quantize_dtype + self.group_size = config.group_size + self.dtype = config.model_dtype + if config.linear_weight_layout == "KN": + raise NotImplementedError("GroupQuantizeMixtralExperts does not support KN layout now.") + + @staticmethod + def from_mixtral_experts( + src: "MixtralExperts", config: GroupQuantize + ) -> "GroupQuantizeMixtralExperts": + """ + Converts a non-quantized MixtralExperts to a group quantized GroupQuantizeMixtralExperts + + Parameters + ---------- + src : MixtralExperts + The non-quantized MixtralExperts + + config : GroupQuantize + The group quantization config. + + Returns + ------- + ret : GroupQuantizeMixtralExperts + The group quantized GroupQuantizeMixtralExperts layer. + """ + quantized_mistral_experts = GroupQuantizeMixtralExperts( + num_local_experts=src.num_local_experts, + in_features=src.in_features, + out_features=src.out_features, + config=config, + ) + if "shard_strategy" in src.weight.attrs: + shard = src.weight.attrs["shard_strategy"] + _apply_sharding(shard, f"{shard.name}_q_weight", quantized_mistral_experts.q_weight) + _apply_sharding(shard, f"{shard.name}_q_scale", quantized_mistral_experts.q_scale) + return quantized_mistral_experts + + def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name + """Forward method for group quantized mistral experts. 
+ + Parameters + ---------- + x : nn.Tensor + The input tensor. + + indptr: nn.Tensor + The indptr tensor + + single_batch_decode: bool + Whether to use single-batch decode + + Returns + ------- + ret : nn.Tensor + The output tensor for the group quantized mistral experts layer. + """ + from mlc_chat.op import moe_matmul # pylint: disable=import-outside-toplevel + + assert x.ndim == 2 + if indptr.ndim == 2: # single-batch + assert indptr.shape[0] == 1 + return moe_matmul.dequantize_gemv( + x, + self.q_weight, + self.q_scale, + indptr, + quantize_dtype=self.quantize_dtype, + group_size=self.group_size, + ) + assert indptr.ndim == 1 + return moe_matmul.dequantize_group_gemm( + x, + self.q_weight, + self.q_scale, + indptr, + quantize_dtype=self.quantize_dtype, + indptr_dtype=indptr.dtype, + group_size=self.group_size, + ) + + +def _apply_sharding(shard, name: str, weight: nn.Parameter): + if isinstance(shard, tp.ShardSingleDim): + weight.attrs["shard_strategy"] = tp.ShardSingleDim( + name=name, + dim=shard.dim, + segs=shard.segs, + ) + else: + raise NotImplementedError(f"Unknowing sharding strategy: {shard}") diff --git a/python/mlc_chat/quantization/no_quantization.py b/python/mlc_chat/quantization/no_quantization.py new file mode 100644 index 0000000..b1944c1 --- /dev/null +++ b/python/mlc_chat/quantization/no_quantization.py @@ -0,0 +1,14 @@ +"""The no quantization config""" +from dataclasses import dataclass + + +@dataclass +class NoQuantize: # pylint: disable=too-many-instance-attributes + """Configuration for no quantization""" + + name: str + kind: str + model_dtype: str # "float16", "float32" + + def __post_init__(self): + assert self.kind == "no-quant" diff --git a/python/mlc_chat/quantization/quantization.py b/python/mlc_chat/quantization/quantization.py new file mode 100644 index 0000000..3fab898 --- /dev/null +++ b/python/mlc_chat/quantization/quantization.py @@ -0,0 +1,120 @@ +"""A centralized registry of all existing quantization methods and their configurations.""" +from typing import Any, Dict + +from .awq_quantization import AWQQuantize +from .ft_quantization import FTQuantize +from .group_quantization import GroupQuantize +from .no_quantization import NoQuantize + +Quantization = Any +"""Quantization is an object that represents an quantization algorithm. It is required to +have the following fields: + + name : str + The name of the quantization algorithm, for example, "q4f16_1". + + kind : str + The kind of quantization algorithm, for example, "group-quant", "faster-transformer". + +It is also required to have the following method: + + def quantize_model(self, module: nn.Module) -> nn.Module: + ... + + def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArray]: + ... 
+""" + +QUANTIZATION: Dict[str, Quantization] = { + "q0f16": NoQuantize( + name="q0f16", + kind="no-quant", + model_dtype="float16", + ), + "q0f32": NoQuantize( + name="q0f32", + kind="no-quant", + model_dtype="float32", + ), + "q3f16_0": GroupQuantize( + name="q3f16_0", + kind="group-quant", + group_size=40, + quantize_dtype="int3", + storage_dtype="uint32", + model_dtype="float16", + linear_weight_layout="KN", + quantize_embedding=True, + quantize_final_fc=True, + ), + "q3f16_1": GroupQuantize( + name="q3f16_1", + kind="group-quant", + group_size=40, + quantize_dtype="int3", + storage_dtype="uint32", + model_dtype="float16", + linear_weight_layout="NK", + quantize_embedding=True, + quantize_final_fc=True, + ), + "q4f16_0": GroupQuantize( + name="q4f16_0", + kind="group-quant", + group_size=32, + quantize_dtype="int4", + storage_dtype="uint32", + model_dtype="float16", + linear_weight_layout="KN", + quantize_embedding=True, + quantize_final_fc=True, + ), + "q4f16_1": GroupQuantize( + name="q4f16_1", + kind="group-quant", + group_size=32, + quantize_dtype="int4", + storage_dtype="uint32", + model_dtype="float16", + linear_weight_layout="NK", + quantize_embedding=True, + quantize_final_fc=True, + ), + "q4f32_1": GroupQuantize( + name="q4f32_1", + kind="group-quant", + group_size=32, + quantize_dtype="int4", + storage_dtype="uint32", + model_dtype="float32", + linear_weight_layout="NK", + quantize_embedding=True, + quantize_final_fc=True, + ), + "q4f16_2": GroupQuantize( + name="q4f16_2", + kind="group-quant", + group_size=32, + quantize_dtype="int4", + storage_dtype="uint32", + model_dtype="float16", + linear_weight_layout="NK", + quantize_embedding=False, + quantize_final_fc=False, + ), + "q4f16_autoawq": AWQQuantize( + name="q4f16_autoawq", + kind="awq", + group_size=128, + quantize_dtype="int4", + storage_dtype="uint32", + model_dtype="float16", + ), + "q4f16_ft": FTQuantize( + name="q4f16_ft", + kind="ft-quant", + quantize_dtype="int4", + storage_dtype="int8", + model_dtype="float16", + ), +} diff --git a/python/mlc_chat/quantization/utils.py b/python/mlc_chat/quantization/utils.py new file mode 100644 index 0000000..4159da8 --- /dev/null +++ b/python/mlc_chat/quantization/utils.py @@ -0,0 +1,47 @@ +"""Common utilities for quantization""" + +from typing import List, Optional + +from tvm import te, tir + + +def convert_uint_to_float( # pylint: disable=too-many-arguments + weight: te.Tensor, + bits: int, + num_elem_per_storage: int, + storage_dtype: str, + model_dtype: str, + axis: int = -1, + out_shape: Optional[List[tir.PrimExpr]] = None, + ft_reorder: Optional[bool] = False, +) -> te.Tensor: + """Convert a quantized uint weight to an unquantized float weight.""" + tir_bin_mask = tir.const((1 << bits) - 1, storage_dtype) + if out_shape is None: + out_shape = weight.shape + out_shape[axis] *= num_elem_per_storage + axis = axis if axis >= 0 else len(out_shape) + axis + return te.compute( + shape=out_shape, + fcompute=lambda *idx: tir.bitwise_and( + tir.shift_right( + weight(*idx[:axis], idx[axis] // num_elem_per_storage, *idx[axis + 1 :]), + ( + ( + (idx[axis] % num_elem_per_storage) % 2 * 4 + + (idx[axis] % num_elem_per_storage) // 2 + ) + * bits + if ft_reorder + else (idx[axis] % num_elem_per_storage) * bits + ).astype(storage_dtype), + ), + tir_bin_mask, + ).astype(model_dtype), + ) + + +def is_final_fc(name: str) -> bool: + """Determines whether the parameter is the last layer based on its name.""" + # TODO: use more specious condition to determine final fc # pylint: disable=fixme + 
return name in ["head", "lm_head"] diff --git a/python/mlc_chat/rest.py b/python/mlc_chat/rest.py new file mode 100644 index 0000000..d2911a1 --- /dev/null +++ b/python/mlc_chat/rest.py @@ -0,0 +1,492 @@ +# pylint: disable=missing-docstring,fixme +import argparse +import ast +import asyncio +import dataclasses +import json +from contextlib import asynccontextmanager +from typing import Dict, List + +import numpy as np +import uvicorn +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse + +from mlc_chat.chat_module import GenerationConfig +from mlc_chat.support.random import set_global_random_seed + +from .chat_module import ChatModule +from .interface.openai_api import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatMessage, + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + CompletionResponseStreamChoice, + CompletionStreamResponse, + DeltaMessage, + EmbeddingsRequest, + EmbeddingsResponse, + ToolCalls, + ToolChoice, + UsageInfo, + VisualStudioCodeCompletionRequest, + VisualStudioCodeCompletionResponse, +) + + +@dataclasses.dataclass +class RestAPIArgs: + """RestAPIArgs is the dataclass that organizes the arguments used for starting a REST API + server.""" + + model: str = dataclasses.field( + metadata={ + "help": ( + """ + The model folder after compiling with MLC-LLM build process. The parameter + can either be the model name with its quantization scheme + (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a full path to the model + folder. In the former case, we will use the provided name to search + for the model folder over possible paths. + """ + ) + } + ) + lib_path: str = dataclasses.field( + default=None, + metadata={ + "help": ( + """ + The full path to the model library file to use (e.g. a ``.so`` file). + """ + ) + }, + ) + device: str = dataclasses.field( + default="auto", + metadata={ + "help": ( + """ + The description of the device to run on. User should provide a string in the + form of 'device_name:device_id' or 'device_name', where 'device_name' is one of + 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto' (automatically detect the + local device), and 'device_id' is the device id to run on. If no 'device_id' + is provided, it will be set to 0 by default. + """ + ) + }, + ) + host: str = dataclasses.field( + default="127.0.0.1", + metadata={ + "help": ( + """ + The host at which the server should be started, defaults to ``127.0.0.1``. + """ + ) + }, + ) + port: int = dataclasses.field( + default=8000, + metadata={ + "help": ( + """ + The port on which the server should be started, defaults to ``8000``. + """ + ) + }, + ) + random_seed: int = dataclasses.field( + default=None, + metadata={ + "help": ( + """ + The random seed to initialize all the RNG used in mlc-chat. By default, + no seed is set. 
+ """ + ) + }, + ) + + +def convert_args_to_argparser() -> argparse.ArgumentParser: + """Convert from RestAPIArgs to an equivalent ArgumentParser.""" + args = argparse.ArgumentParser("MLC Chat REST API") + for field in dataclasses.fields(RestAPIArgs): + name = field.name.replace("_", "-") + field_name = f"--{name}" + # `kwargs` contains `help`, `choices`, and `action` + kwargs = field.metadata.copy() + if field.type == bool: + # boolean arguments do not need to specify `type` + args.add_argument(field_name, default=field.default, **kwargs) + else: + args.add_argument(field_name, type=field.type, default=field.default, **kwargs) + return args + + +session: Dict[str, ChatModule] = {} + + +@asynccontextmanager +async def lifespan(_app: FastAPI): + if ARGS.random_seed is not None: + set_global_random_seed(ARGS.random_seed) + chat_mod = ChatModule( + model=ARGS.model, + device=ARGS.device, + model_lib_path=ARGS.lib_path, + ) + session["chat_mod"] = chat_mod + yield + session.clear() + + +origins = ["*"] + +app = FastAPI(lifespan=lifespan) +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +class AsyncCompletionStream: + def __init__(self, generation_config: GenerationConfig): + self.generation_config = generation_config + + def __aiter__(self): + return self + + async def get_next_msg(self): + # pylint: disable=protected-access + if not session["chat_mod"]._stopped(): + session["chat_mod"]._decode(generation_config=self.generation_config) + msg = session["chat_mod"]._get_message() + return msg + # pylint: enable=protected-access + raise StopAsyncIteration + + async def __anext__(self): + if not session["chat_mod"]._stopped(): + task = asyncio.create_task(self.get_next_msg()) + msg = await task + return msg + raise StopAsyncIteration + + +def add_function_call(prompt: List[ChatMessage], function_string: str): + # update content of the last input message to include function string + user_query = prompt[-1].content + prompt[-1].content = f"<> {user_query} <> {function_string}\n" + + +def function_call_util(request: ChatCompletionRequest): + """Performs the necessary actions to add function calls to the prompt + returns True if function calls are added to the prompt else returns False + TODO: Check function name in tools.function['name'] + TODO: Currently auto mode default to generating function calls instead of smartly + checking weather to generate function calls or not + """ + + # return if no tools are provided + if request.tools is None: + return False + + # skip if tool_choice is set to none + if isinstance(request.tool_choice, str) and request.tool_choice == "none": + return False + + if isinstance(request.tool_choice, ToolChoice): + # force the model to use a specific function provided by tool_choice + if request.tool_choice.type != "function": + raise ValueError("Only 'function' tool choice is supported") + for tool in request.tools: + if tool.function["name"] == request.tool_choice.function["name"]: + add_function_call(request.messages, json.dumps(tool.function)) + return True + raise ValueError("ToolChoice.function.name not found in tools") + + if isinstance(request.tool_choice, str): + # Add all the functions to the input prompt + function_list = [] + for tool in request.tools: + if tool.type == "function": + function_list.append(tool.function) + else: + raise ValueError("Only 'function' tool.type is supported") + add_function_call(request.messages, json.dumps(function_list)) + else: + raise 
ValueError("Invalid toolChoice instance type") + return True + + +def convert_function_str_to_json(stringified_calls): + def parse_function_call(call_str): + node = ast.parse(call_str, mode="eval") + call_node = node.body + if isinstance(call_node, ast.Call): + name = call_node.func.id + arguments = {} + for keyword in call_node.keywords: + arguments[keyword.arg] = ast.literal_eval(keyword.value) + return {"name": name, "arguments": arguments} + return None + + calls = ast.literal_eval(stringified_calls) + result = [parse_function_call(call_str) for call_str in calls] + return result + + +@app.post("/v1/chat/completions") +async def request_chat_completion(request: ChatCompletionRequest): + """ + Creates model response for the given chat conversation. + The messages field contains a list of messages (describing the conversation history). eg: + ```"messages": [{"role": "user", "content": "What's my name?"}, + {"role": "assistant", "content": "Your name is Llama."}, + {"role": "user", "content": "No, that's your name. My name is X."}, + {"role": "assistant", "content": "Ah, my apologies! Your name is X! "}, + {"role": "user", "content": "What is the meaning of life?"}, + ] + ``` + ] + """ + generation_config = GenerationConfig( + temperature=request.temperature, + repetition_penalty=request.repetition_penalty, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + top_p=request.top_p, + mean_gen_len=request.mean_gen_len, + max_gen_len=request.max_gen_len, + n=request.n, + stop=request.stop, + ) + + session["chat_mod"].reset_chat() # Reset previous history, KV cache, etc. + + use_function_call = function_call_util(request) + + if request.stream: + session["chat_mod"]._prefill( # pylint: disable=protected-access + input=request.messages, + generation_config=generation_config, + ) + + async def iter_response(): + prev_txt = "" + async for content in AsyncCompletionStream(generation_config=generation_config): + if content: + # Remove the replacement character (U+FFFD) from the response + # This is to handle emojis. An emoji might be made up of multiple tokens. + # In the Rest streaming setting, if an emoji gets truncated in the middle of + # its encoded byte sequence, a replacement character will appear. 
+ valid_content = content.replace("�", "") + chunk = ChatCompletionStreamResponse( + choices=[ + ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage( + role="assistant", content=valid_content[len(prev_txt) :] + ), + finish_reason="stop", + ) + ] + ) + prev_txt = valid_content + yield f"data: {chunk.json(exclude_unset=True)}\n\n" + yield "data: [DONE]\n\n" + + return StreamingResponse(iter_response(), media_type="text/event-stream") + msg = session["chat_mod"].generate( + prompt=request.messages, generation_config=generation_config, stateless=True + ) + if isinstance(msg, str): + msg = [msg] + + choices = [] + for index, msg_i in enumerate(msg): + if use_function_call: + choices.append( + ChatCompletionResponseChoice( + index=index, + message=ChatMessage( + role="assistant", + content=None, + tool_calls=[ + ToolCalls( + function=fn_json_obj, + ) + for fn_json_obj in convert_function_str_to_json(msg_i) + ], + ), + finish_reason="tool_calls", + ) + ) + else: + choices.append( + ChatCompletionResponseChoice( + index=index, + message=ChatMessage( + role="assistant", + content=msg_i, + ), + finish_reason="stop", + ) + ) + + return ChatCompletionResponse( + choices=choices, + # TODO: Fill in correct usage info + usage=UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0), + ) + + +@app.post("/v1/completions") +async def request_completion(request: CompletionRequest): + """ + Creates a completion for a given prompt. + """ + + generation_config = GenerationConfig( + temperature=request.temperature, + repetition_penalty=request.repetition_penalty, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + top_p=request.top_p, + mean_gen_len=request.mean_gen_len, + max_gen_len=request.max_gen_len, + n=request.n, + stop=request.stop, + ) + + session["chat_mod"].reset_chat() + # Langchain's load_qa_chain.run expects the input to be a list with the query + if isinstance(request.prompt, list): + if len(request.prompt) > 1: + raise ValueError( + """ + The /v1/completions endpoint currently only supports single message prompts. + Please ensure your request contains only one message + """ + ) + prompt = request.prompt[0] + else: + prompt = request.prompt + + if request.stream: + session["chat_mod"]._prefill( # pylint: disable=protected-access + input=prompt, + generation_config=generation_config, + ) + + async def iter_response(): + prev_txt = "" + async for content in AsyncCompletionStream(generation_config=generation_config): + if content: + chunk = CompletionStreamResponse( + choices=[ + CompletionResponseStreamChoice( + index=0, + text=content[len(prev_txt) :], + finish_reason="stop", + ) + ] + ) + prev_txt = content + yield f"data: {chunk.json(exclude_unset=True)}\n\n" + yield "data: [DONE]\n\n" + + return StreamingResponse(iter_response(), media_type="text/event-stream") + msg = session["chat_mod"].generate(prompt=prompt, generation_config=generation_config) + if isinstance(msg, str): + msg = [msg] + return CompletionResponse( + choices=[ + CompletionResponseChoice(index=index, text=msg[index]) for index in range(len(msg)) + ], + # TODO: Fill in correct usage info + usage=UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0), + ) + + +@app.post("/v1/embeddings") +async def request_embeddings(request: EmbeddingsRequest): + """ + Gets embedding for some text. 
+ """ + inps = [] + if isinstance(request.input, str): + inps.append(request.input) + elif isinstance(request.input, list): + inps = request.input + else: + assert f"Invalid input type {type(request.input)}" + + data = [] + for i, inp in enumerate(inps): + session["chat_mod"].reset_chat() + emb = session["chat_mod"].embed_text(input=inp).numpy() + mean_emb = np.squeeze(np.mean(emb, axis=1), axis=0) + norm_emb = mean_emb / np.linalg.norm(mean_emb) + data.append({"object": "embedding", "embedding": norm_emb.tolist(), "index": i}) + # TODO: Fill in correct usage info + return EmbeddingsResponse( + data=data, usage=UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0) + ) + + +@app.post("/chat/reset") +async def reset(): + """ + Reset the chat for the currently initialized model. + """ + session["chat_mod"].reset_chat() + + +@app.get("/stats") +async def read_stats(): + """ + Get the runtime stats. + """ + return session["chat_mod"].stats() + + +@app.get("/verbose_stats") +async def read_stats_verbose(): + """ + Get the verbose runtime stats. + """ + return session["chat_mod"].stats(verbose=True) + + +@app.post("/v1/llm-vscode/completions") +async def request_llm_vscode(request: VisualStudioCodeCompletionRequest): + """ + Creates a vscode code completion for a given prompt. + Follows huggingface LSP (https://github.com/huggingface/llm-ls) + """ + generation_config = GenerationConfig( + temperature=request.parameters.temperature, + top_p=request.parameters.top_p, + mean_gen_len=request.parameters.max_new_tokens, + max_gen_len=request.parameters.max_new_tokens, + ) + msg = session["chat_mod"].generate(prompt=request.inputs, generation_config=generation_config) + + return VisualStudioCodeCompletionResponse(generated_text=msg) + + +ARGS = convert_args_to_argparser().parse_args() +if __name__ == "__main__": + uvicorn.run("mlc_chat.rest:app", host=ARGS.host, port=ARGS.port, reload=False, access_log=False) diff --git a/python/mlc_chat/serve/__init__.py b/python/mlc_chat/serve/__init__.py new file mode 100644 index 0000000..59185ec --- /dev/null +++ b/python/mlc_chat/serve/__init__.py @@ -0,0 +1,11 @@ +"""Subdirectory of serving.""" + +# Load MLC LLM library by importing base +from .. import base +from .async_engine import AsyncThreadedEngine +from .config import EngineMode, GenerationConfig, KVCacheConfig +from .data import Data, RequestStreamOutput, TextData, TokenData +from .engine import Engine +from .grammar import BNFGrammar, GrammarStateMatcher +from .request import Request +from .server import PopenServer diff --git a/python/mlc_chat/serve/_ffi_api.py b/python/mlc_chat/serve/_ffi_api.py new file mode 100644 index 0000000..282c80c --- /dev/null +++ b/python/mlc_chat/serve/_ffi_api.py @@ -0,0 +1,6 @@ +"""FFI APIs for mlc_chat.serve""" +import tvm._ffi + +# Exports functions registered via TVM_REGISTER_GLOBAL with the "mlc.serve" prefix. +# e.g. TVM_REGISTER_GLOBAL("mlc.serve.TextData") +tvm._ffi._init_api("mlc.serve", __name__) # pylint: disable=protected-access diff --git a/python/mlc_chat/serve/async_engine.py b/python/mlc_chat/serve/async_engine.py new file mode 100644 index 0000000..74058ea --- /dev/null +++ b/python/mlc_chat/serve/async_engine.py @@ -0,0 +1,330 @@ +"""The MLC LLM Asynchronous Serving Engine. +Acknowledgment: Part of the code was adapted from the vLLM project. 
+""" + +import asyncio +import sys +import threading +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union + +import tvm + +from ..streamer import TextStreamer +from ..tokenizer import Tokenizer +from . import data +from .config import EngineMode, GenerationConfig, KVCacheConfig +from .engine import ModelInfo, _estimate_max_total_sequence_length, _process_model_args +from .event_trace_recorder import EventTraceRecorder +from .request import Request + + +class AsyncRequestStream: + """The asynchronous stream for requests. + + Each request has its own unique stream. + The stream exposes the method `push` for engine to push new generated + delta text to the stream, and the method `finish` for engine to mark + the finish of generation. + + The stream implements `__aiter__` and `__anext__`, which the engine + can use to iterates all the generated tokens in order asynchronously. + """ + + # The asynchronous queue to hold elements of + # - either a tuple of (str, int, List[str], Optional[str]), denoting the + # delta output text, the number of delta tokens, the logprob JSON strings + # of delta tokens, and the optional finish reason respectively, + # - or an exception. + if sys.version_info >= (3, 9): + _queue: asyncio.Queue[ # pylint: disable=unsubscriptable-object + Union[Tuple[str, int, Optional[List[str]], Optional[str]], Exception] + ] + else: + _queue: asyncio.Queue + # The finish flag. + _finished: bool + + def __init__(self) -> None: + self._queue = asyncio.Queue() + self._finished = False + + def push( + self, + item_or_exception: Union[Tuple[str, int, Optional[List[str]], Optional[str]], Exception], + ) -> None: + """Push a new token to the stream.""" + if self._finished: + # No new item is expected after finish. + self._queue.put_nowait( + RuntimeError( + "The request has already finished. " + "The stream is not supposed to accept new items." + ) + ) + return + self._queue.put_nowait(item_or_exception) + + def finish(self) -> None: + """Mark the finish of the generation in the stream.""" + self._queue.put_nowait(StopIteration()) + self._finished = True + + def __aiter__(self): + return self + + async def __anext__(self) -> Tuple[str, int, Optional[List[str]], Optional[str]]: + result = await self._queue.get() + if isinstance(result, StopIteration): + raise StopAsyncIteration + if isinstance(result, Exception): + raise result + return result + + +class AsyncThreadedEngine: # pylint: disable=too-many-instance-attributes + """The asynchronous engine for generate text asynchronously, + backed by ThreadedEngine. + + This class wraps a synchronous threaded engine that runs on + a standalone thread inside, and exports the asynchronous `generate` + method as the main text generation interface, which yields the + generated tokens. The internal threaded engine keeps running an + event loop that drives the engine. + + Parameters + ---------- + models : Union[ModelInfo, List[ModelInfo]] + One or a list of model info (specifying which models to load and + which device to load to) to launch the engine. + + kv_cache_config : KVCacheConfig + The configuration of the paged KV cache. + + engine_mode : Optional[EngineMode] + The Engine execution mode. + + enable_tracing : bool + A boolean indicating if to enable event logging for requests. 
+ """ + + def __init__( + self, + models: Union[ModelInfo, List[ModelInfo]], + kv_cache_config: KVCacheConfig, + engine_mode: Optional[EngineMode] = None, + enable_tracing: bool = False, + ) -> None: + if isinstance(models, ModelInfo): + models = [models] + ( + model_args, + config_file_paths, + tokenizer_path, + self.max_single_sequence_length, + prefill_chunk_size, + self.conv_template_name, + ) = _process_model_args(models) + self.trace_recorder = EventTraceRecorder() if enable_tracing else None + + if kv_cache_config.max_total_sequence_length is None: + kv_cache_config.max_total_sequence_length = _estimate_max_total_sequence_length( + models, config_file_paths + ) + if kv_cache_config.prefill_chunk_size is None: + kv_cache_config.prefill_chunk_size = prefill_chunk_size + elif kv_cache_config.prefill_chunk_size > prefill_chunk_size: + raise ValueError( + f"The specified prefill chunk size {kv_cache_config.prefill_chunk_size} is " + f"larger than the maximum prefill chunk size {prefill_chunk_size} supported by " + "models. Please specify a smaller prefill chunk size." + ) + + module = tvm.get_global_func("mlc.serve.create_threaded_engine", allow_missing=False)() + self._ffi = { + key: module[key] + for key in [ + "add_request", + "abort_request", + "run_background_loop", + "init_background_engine", + "exit_background_loop", + ] + } + self.tokenizer = Tokenizer(tokenizer_path) + if engine_mode is None: + # The default engine mode: non-speculative + engine_mode = EngineMode() + + # The mapping from request ids to request asynchronous stream. + self._request_tools: Dict[str, Tuple[AsyncRequestStream, TextStreamer]] = {} + + def _background_loop(): + self._ffi["init_background_engine"]( + self.max_single_sequence_length, + tokenizer_path, + kv_cache_config.asjson(), + engine_mode.asjson(), + self._request_stream_callback, + self.trace_recorder, + *model_args, + ) + self._ffi["run_background_loop"]() + + # Create the background engine-driving thread and start the loop. + self._background_loop_thread: threading.Thread = threading.Thread(target=_background_loop) + self._background_loop_thread.start() + # The main thread request handling asyncio event loop, which will + # be lazily initialized. + self._async_event_loop: Optional[asyncio.AbstractEventLoop] = None + self._terminated = False + + def terminate(self): + """Terminate the engine.""" + self._terminated = True + self._ffi["exit_background_loop"]() + self._background_loop_thread.join() + + async def generate( + self, prompt: Union[str, List[int]], generation_config: GenerationConfig, request_id: str + ) -> AsyncGenerator[Tuple[str, int, Optional[List[str]], Optional[str]], Any]: + """Asynchronous text generation interface. + The method is a coroutine that streams a tuple at a time via yield. + Each tuple is contained of + - the delta text in type str, + - the number of delta tokens in type int, + - the logprob JSON strings of delta tokens, + - the optional finish reason in type Optional[str]. + + Parameters + ---------- + prompt : Union[str, List[int]] + The input prompt in forms of text string or a list of token ids. + + generation_config : GenerationConfig + The generation config of the request. + + request_id : str + The unique identifier (in string) or this generation request. + """ + if self._terminated: + raise ValueError("The AsyncThreadedEngine has terminated.") + if self._async_event_loop is None: + # Lazily set the asyncio event loop so that the event + # loop is the main driving event loop of the process. 
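# A usage sketch for the AsyncThreadedEngine defined in this file, driven from
# an asyncio program. The model id and library path are placeholders; substitute
# a model you have actually compiled.
import asyncio

from mlc_chat.serve import AsyncThreadedEngine, GenerationConfig, KVCacheConfig
from mlc_chat.serve.engine import ModelInfo


async def main() -> None:
    engine = AsyncThreadedEngine(
        ModelInfo(
            "Llama-2-7b-chat-hf-q4f16_1",  # placeholder model id
            model_lib_path="dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so",
        ),
        KVCacheConfig(),
    )
    try:
        # `generate` yields (delta_text, num_delta_tokens, delta_logprobs, finish_reason).
        async for delta_text, _, _, finish_reason in engine.generate(
            "What is the meaning of life?", GenerationConfig(max_tokens=64), request_id="0"
        ):
            print(delta_text, end="", flush=True)
            if finish_reason is not None:
                break
    finally:
        engine.terminate()


# asyncio.run(main())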
+ self._async_event_loop = asyncio.get_event_loop() + + # Create the request with the given id, input data, generation + # config and the created callback. + input_data = data.TextData(prompt) if isinstance(prompt, str) else data.TokenData(prompt) + request = Request(request_id, input_data, generation_config) + + # Create the unique stream of the request. + stream = AsyncRequestStream() + if request_id in self._request_tools: + # Report error in the stream if the request id already exists. + stream.push( + RuntimeError( + f'The request id "{request_id} already exists. ' + 'Please make sure the request id is unique."' + ) + ) + else: + # Record the stream in the tracker + self._request_tools[request_id] = (stream, TextStreamer(self.tokenizer)) + self._ffi["add_request"](request) + + # Iterate the stream asynchronously and yield the token. + try: + async for request_output in stream: + yield request_output + except (Exception, asyncio.CancelledError) as e: # pylint: disable=broad-exception-caught + await self.abort(request_id) + raise e + + async def abort(self, request_id: str) -> None: + """Generation abortion interface. + + Parameter + --------- + request_id : str + The id of the request to abort. + """ + self._abort(request_id) + + def _abort(self, request_id: str): + """Internal implementation of request abortion.""" + self._request_tools.pop(request_id, None) + self._ffi["abort_request"](request_id) + + def _request_stream_callback(self, delta_outputs: List[data.RequestStreamOutput]) -> None: + """The request stream callback function for engine to stream back + the request generation results. + + Parameters + ---------- + delta_outputs : List[data.RequestStreamOutput] + The delta output of each requests. + Check out data.RequestStreamOutput for the fields of the outputs. + + Note + ---- + This callback function uses `call_soon_threadsafe` in asyncio to + schedule the invocation in the event loop, so that the underlying + callback logic will be executed asynchronously in the future rather + than right now. + """ + # Schedule a callback run in the event loop without executing right now. + # NOTE: This function causes GIL during execution. + self._async_event_loop.call_soon_threadsafe( + self._request_stream_callback_impl, delta_outputs + ) + + def _request_stream_callback_impl(self, delta_outputs: List[data.RequestStreamOutput]) -> None: + """The underlying implementation of request stream callback.""" + for delta_output in delta_outputs: + ( + request_id, + delta_token_ids, + delta_logprob_json_strs, + finish_reason, + ) = delta_output.unpack() + tools = self._request_tools.get(request_id, None) + if tools is None: + continue + + self.record_event(request_id, event="start callback") + stream, text_streamer = tools + + self.record_event(request_id, event="start detokenization") + delta_text = text_streamer.put(delta_token_ids) + if finish_reason is not None: + delta_text += text_streamer.finish() + self.record_event(request_id, event="finish detokenization") + + # Push new delta text to the stream. + stream.push((delta_text, len(delta_token_ids), delta_logprob_json_strs, finish_reason)) + if finish_reason is not None: + stream.finish() + self._request_tools.pop(request_id, None) + self.record_event(request_id, event="finish callback") + + def record_event(self, request_id: str, event: str) -> None: + """Record a event for the the input request in the trace + recorder when the recorder exists. + + Parameters + ---------- + request_id : str + The subject request of the event. 
+
+        event : str
+            The event name, given as a string.
+            It can have one of the following patterns:
+            - "start xxx", which marks the start of event "xxx",
+            - "finish xxx", which marks the finish of event "xxx",
+            - "yyy", which marks the instant event "yyy".
+            The "starts" and "finishes" will be automatically paired in the trace recorder.
+        """
+        if self.trace_recorder is None:
+            return
+        self.trace_recorder.add_event(request_id, event)
diff --git a/python/mlc_chat/serve/config.py b/python/mlc_chat/serve/config.py
new file mode 100644
index 0000000..ccc152a
--- /dev/null
+++ b/python/mlc_chat/serve/config.py
@@ -0,0 +1,155 @@
+"""Configuration dataclasses used in MLC LLM serving"""
+
+import json
+from dataclasses import asdict, dataclass, field
+from typing import Dict, List, Optional
+
+
+@dataclass
+class GenerationConfig:  # pylint: disable=too-many-instance-attributes
+    """The generation configuration dataclass.
+
+    Parameters
+    ----------
+    temperature : float
+        The value applied to the logits to modulate the next token probabilities.
+
+    top_p : float
+        In sampling, only the most probable tokens whose probabilities sum up to
+        `top_p` are kept for sampling.
+
+    frequency_penalty : float
+        Positive values penalize new tokens based on their existing frequency
+        in the text so far, decreasing the model's likelihood to repeat the same
+        line verbatim.
+
+    presence_penalty : float
+        Positive values penalize new tokens based on whether they appear in the text
+        so far, increasing the model's likelihood to talk about new topics.
+
+    repetition_penalty : float
+        The penalty term applied to the logits to control token repetition in generation.
+        It is suppressed when either frequency_penalty or presence_penalty is non-zero.
+
+    logprobs : bool
+        Whether to return log probabilities of the output tokens or not.
+        If true, the log probabilities of each output token will be returned.
+
+    top_logprobs : int
+        An integer between 0 and 5 specifying the number of most likely
+        tokens to return at each token position, each with an associated
+        log probability.
+        `logprobs` must be set to True if this parameter is used.
+
+    logit_bias : Optional[Dict[int, float]]
+        The bias value added to the logits of the selected tokens prior to sampling.
+
+    max_tokens : Optional[int]
+        The maximum number of generated tokens,
+        or None, in which case the generation will not stop
+        until exceeding the model capability or hitting any stop criterion.
+
+    seed : Optional[int]
+        The random seed of the generation.
+        The seed will be a random value if not specified.
+
+    stop_strs : List[str]
+        The list of strings that mark the end of generation.
+
+    stop_token_ids : List[int]
+        The list of token ids that mark the end of generation.
+
+    ignore_eos : bool
+        When it is true, ignore the eos token and generate tokens until `max_tokens`.
+        Default is set to False.
+ """ + + temperature: float = 0.8 + top_p: float = 0.95 + frequency_penalty: float = 0.0 + presence_penalty: float = 0.0 + repetition_penalty: float = 1.0 + logprobs: bool = False + top_logprobs: int = 0 + logit_bias: Optional[Dict[int, float]] = field(default_factory=dict) + + max_tokens: Optional[int] = 128 + seed: Optional[int] = None + stop_strs: List[str] = field(default_factory=list) + stop_token_ids: List[int] = field(default_factory=list) + ignore_eos: bool = False + + def asjson(self) -> str: + """Return the config in string of JSON format.""" + return json.dumps(asdict(self)) + + @staticmethod + def from_json(json_str: str) -> "GenerationConfig": + """Construct a config from JSON string.""" + return GenerationConfig(**json.loads(json_str)) + + +@dataclass +class KVCacheConfig: + """The KV cache initialization configuration. + + Parameters + ---------- + page_size : int + The number of consecutive tokens handled in each page in paged KV cache. + + max_num_sequence : int + The maximum number of sequences that are allowed to processed by the KV + cache at any time. + + max_total_sequence_length : Optional[int] + The maximum total number of tokens whose KV data are allowed to exist + in the KV cache at any time. + Set it to None to enable automatic computation of the max total + sequence length. + + prefill_chunk_size : Optional[int] + The maximum total sequence length in a prefill. + If not specified, it will be automatically inferred from model config. + """ + + page_size: int = 16 + max_num_sequence: int = 32 + max_total_sequence_length: Optional[int] = None + prefill_chunk_size: Optional[int] = None + + def asjson(self) -> str: + """Return the config in string of JSON format.""" + return json.dumps(asdict(self)) + + @staticmethod + def from_json(json_str: str) -> "KVCacheConfig": + """Construct a config from JSON string.""" + return KVCacheConfig(**json.loads(json_str)) + + +@dataclass +class EngineMode: + """The Engine execution mode. + + Parameters + ---------- + enable_speculative : bool + Whether the speculative decoding mode is enabled, default False. + + spec_draft_length : int + The number of tokens to generate in speculative proposal (draft), default 4. + """ + + enable_speculative: bool = False + spec_draft_length: int = 4 + + def asjson(self) -> str: + """Return the config in string of JSON format.""" + return json.dumps(asdict(self)) + + @staticmethod + def from_json(json_str: str) -> "EngineMode": + """Construct a config from JSON string.""" + return EngineMode(**json.loads(json_str)) diff --git a/python/mlc_chat/serve/data.py b/python/mlc_chat/serve/data.py new file mode 100644 index 0000000..15c0a4f --- /dev/null +++ b/python/mlc_chat/serve/data.py @@ -0,0 +1,117 @@ +"""Classes denoting multi-modality data used in MLC LLM serving""" + +from typing import List, Optional, Tuple + +import tvm._ffi +from tvm.runtime import Object + +from . import _ffi_api + + +@tvm._ffi.register_object("mlc.serve.Data") # pylint: disable=protected-access +class Data(Object): + """The base class of multi-modality data (text, tokens, embedding, etc).""" + + def __init__(self): + pass + + +@tvm._ffi.register_object("mlc.serve.TextData") # pylint: disable=protected-access +class TextData(Data): + """The class of text data, containing a text string. + + Parameters + ---------- + text : str + The text string. 
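# The config dataclasses above are plain dataclasses serialized with
# json.dumps(asdict(...)), so they round-trip through asjson/from_json; this is
# the JSON-string form handed to the engine (see the asjson() calls at engine
# initialization).
from mlc_chat.serve.config import EngineMode, GenerationConfig, KVCacheConfig

gen_cfg = GenerationConfig(temperature=0.7, top_p=0.9, max_tokens=256, stop_strs=["</s>"])
kv_cfg = KVCacheConfig(page_size=16, max_num_sequence=32)
mode = EngineMode(enable_speculative=False)

for cfg in (gen_cfg, kv_cfg, mode):
    json_str = cfg.asjson()
    assert type(cfg).from_json(json_str) == cfg  # dataclass equality after a round trip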
+ """ + + def __init__(self, text: str): + self.__init_handle_by_constructor__(_ffi_api.TextData, text) # type: ignore # pylint: disable=no-member + + @property + def text(self) -> str: + """The text data in `str`.""" + return str(_ffi_api.TextDataGetTextString(self)) # type: ignore # pylint: disable=no-member + + def __str__(self) -> str: + return self.text + + +@tvm._ffi.register_object("mlc.serve.TokenData") # type: ignore # pylint: disable=protected-access +class TokenData(Data): + """The class of token data, containing a list of token ids. + + Parameters + ---------- + token_ids : List[int] + The list of token ids. + """ + + def __init__(self, token_ids: List[int]): + self.__init_handle_by_constructor__(_ffi_api.TokenData, *token_ids) # type: ignore # pylint: disable=no-member + + @property + def token_ids(self) -> List[int]: + """Return the token ids of the TokenData.""" + return list(_ffi_api.TokenDataGetTokenIds(self)) # type: ignore # pylint: disable=no-member + + +@tvm._ffi.register_object("mlc.serve.RequestStreamOutput") # pylint: disable=protected-access +class RequestStreamOutput(Object): + """The generated delta request output that is streamed back + through callback stream function. + It contains four fields (in order): + + request_id : str + The id of the request that the function is invoked for. + + delta_tokens : List[int] + The new generated tokens since the last callback invocation + for the input request. + + delta_logprob_json_strs : Optional[List[str]] + The logprobs JSON strings of the new generated tokens + since last invocation. + + finish_reason : Optional[str] + The finish reason of the request when it is finished, + of None if the request has not finished yet. + + Note + ---- + We do not provide constructor, since in practice only C++ side + instantiates this class. + """ + + def unpack(self) -> Tuple[str, List[int], Optional[List[str]], Optional[str]]: + """Return the fields of the delta output in a tuple. + + Returns + ------- + request_id : str + The id of the request that the function is invoked for. + + delta_tokens : List[int] + The new generated tokens since the last callback invocation + for the input request. + + delta_logprob_json_strs : Optional[List[str]] + The logprobs JSON strings of the new generated tokens + since last invocation. + + finish_reason : Optional[str] + The finish reason of the request when it is finished, + of None if the request has not finished yet. 
+ """ + fields = _ffi_api.RequestStreamOutputUnpack(self) # type: ignore # pylint: disable=no-member + return ( + str(fields[0]), + list(fields[1]), + ( + [str(logprob_json_str) for logprob_json_str in fields[2]] + if fields[2] is not None + else None + ), + str(fields[3]) if fields[3] is not None else None, + ) diff --git a/python/mlc_chat/serve/engine.py b/python/mlc_chat/serve/engine.py new file mode 100644 index 0000000..407fb72 --- /dev/null +++ b/python/mlc_chat/serve/engine.py @@ -0,0 +1,494 @@ +"""The MLC LLM Serving Engine.""" + +import json +import os +import subprocess +import sys +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import tvm +from tvm.runtime import Device + +from mlc_chat.serve import data +from mlc_chat.support import logging +from mlc_chat.support.auto_device import detect_device +from mlc_chat.support.style import green + +from ..chat_module import _get_chat_config, _get_lib_module_path, _get_model_path +from ..streamer import TextStreamer +from ..tokenizer import Tokenizer +from . import data +from .config import EngineMode, GenerationConfig, KVCacheConfig +from .event_trace_recorder import EventTraceRecorder +from .request import Request + +logging.enable_logging() +logger = logging.getLogger(__name__) + + +@dataclass +class ModelInfo: + """The model info dataclass. + + Parameters + ---------- + model : str + The identifier of the input model. + It may be a compiled model's id (e.g., "Llama-2-7b-chat-hf-q4f16_1"), + or a full path to a model directory + (e.g., "dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1") + + device : str + The device where to run the model. + It can be "auto", "device_name" (e.g., "cuda") or + "device_name:device_id" (e.g., "cuda:1"). + + model_lib_path : str + The path to the compiled library of the model. 
+ E.g., "dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so" + """ + + model: str + model_lib_path: str + device: Device = "auto" # type: ignore + + def __post_init__(self): + if isinstance(self.device, str): + self.device = detect_device(self.device) + assert isinstance(self.device, Device) + + +def _create_tvm_module( + creator: str, ffi_funcs: Sequence[str], creator_args: Optional[List[Any]] = None +) -> Dict[str, Callable]: + """Internal method to create a module.""" + if creator_args is None: + creator_args = [] + module = tvm.get_global_func(creator, allow_missing=False)(*creator_args) + return {key: module[key] for key in ffi_funcs} + + +def _process_model_args( + models: List[ModelInfo], +) -> Tuple[List[Any], List[str], str, int, int, Optional[str]]: + """Process the input ModelInfo to get the engine initialization arguments.""" + max_single_sequence_length = int(1e9) + prefill_chunk_size = int(1e9) + tokenizer_path: Optional[str] = None + conv_template_name: Optional[str] = None + config_file_paths: List[str] = [] + + def _convert_model_info(model: ModelInfo) -> List[Any]: + nonlocal max_single_sequence_length, prefill_chunk_size, tokenizer_path, conv_template_name + + device = model.device + model_path, config_file_path = _get_model_path(model.model) + config_file_paths.append(config_file_path) + chat_config = _get_chat_config(config_file_path, user_chat_config=None) + if chat_config.context_window_size: + max_single_sequence_length = min( + max_single_sequence_length, + chat_config.context_window_size, + ) + if chat_config.prefill_chunk_size: + prefill_chunk_size = min(prefill_chunk_size, chat_config.prefill_chunk_size) + if tokenizer_path is None: + tokenizer_path = model_path + if conv_template_name is None: + conv_template_name = chat_config.conv_template + # Try look up model library, and do JIT compile if model library not found. + try: + model_lib_path = _get_lib_module_path( + model=model.model, + model_path=model_path, + chat_config=chat_config, + model_lib_path=model.model_lib_path, + device_name=device.MASK2STR[device.device_type], + config_file_path=config_file_path, + ) + except FileNotFoundError: + from mlc_chat.interface import ( # pylint: disable=import-outside-toplevel + jit, + ) + + model_lib_path = str( + jit.jit( + model_path=Path(model_path), + chat_config=asdict(chat_config), + device=device, + ) + ) + return [model_lib_path, model_path, device.device_type, device.device_id] + + model_args: List[Any] = sum( + (_convert_model_info(model) for model in models), + start=[], + ) + + return ( + model_args, + config_file_paths, + tokenizer_path, + max_single_sequence_length, + prefill_chunk_size, + conv_template_name, + ) + + +def _estimate_max_total_sequence_length( # pylint: disable=too-many-locals + models: List[ModelInfo], config_file_paths: List[str] +) -> int: + """Estimate the max total sequence length (capacity) of the KV cache.""" + assert len(models) != 0 + + kv_bytes_per_token = 0 + params_bytes = 0 + temp_func_bytes = 0 + + for model, config_file_path in zip(models, config_file_paths): + # Read metadata for the parameter size and the temporary memory size. 
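# The capacity estimate computed below reduces to a few multiplications; this
# standalone helper mirrors that arithmetic so the constants are easier to
# audit. The example numbers are hypothetical (a Llama-2-7B-like fp16 config on
# a 24 GB card), not measurements.
def estimate_max_total_sequence_length_sketch(  # pylint: disable=too-many-arguments
    gpu_size_bytes: float,
    params_bytes: float,
    temp_func_bytes: float,
    num_layers: int,
    hidden_size: int,
    num_qo_heads: int,
    num_kv_heads: int,
    tensor_parallel_shards: int = 1,
) -> int:
    kv_bytes_per_token = (
        (hidden_size / num_qo_heads)               # head dimension
        * (num_kv_heads / tensor_parallel_shards)  # KV heads on a single GPU
        * num_layers
        * 4                                        # key + value in fp16: 2 tensors * 2 bytes
        * 1.10                                     # safety over-estimation, matching the code below
    )
    usable = gpu_size_bytes * 0.97 - params_bytes * 1.04 - temp_func_bytes
    return int(usable / kv_bytes_per_token)


# Hypothetical example: 32 layers, hidden size 4096, 32 query heads, 32 KV heads,
# ~13.5 GB of fp16 weights and ~1 GB of temporary buffers on a 24 GB GPU
# gives roughly 14k tokens of total KV cache capacity.
print(estimate_max_total_sequence_length_sketch(
    gpu_size_bytes=24e9, params_bytes=13.5e9, temp_func_bytes=1e9,
    num_layers=32, hidden_size=4096, num_qo_heads=32, num_kv_heads=32,
))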
+ cmd = [ + sys.executable, + "-m", + "mlc_chat.cli.model_metadata", + model.model_lib_path, + "--print-memory-usage-in-json", + "--mlc-chat-config", + config_file_path, + ] + usage_str = subprocess.check_output(cmd, universal_newlines=True) + usage_json = json.loads(usage_str) + params_bytes += usage_json["params_bytes"] + temp_func_bytes = max(temp_func_bytes, usage_json["temp_func_bytes"]) + + # Read model config and compute the kv size per token. + with open(config_file_path, mode="rt", encoding="utf-8") as file: + json_object = json.load(file) + model_config = json_object["model_config"] + num_layers = model_config["num_hidden_layers"] + hidden_size = model_config["hidden_size"] + num_qo_heads = model_config["num_attention_heads"] + num_kv_heads = model_config["num_key_value_heads"] + tensor_parallel_shards = model_config["tensor_parallel_shards"] + kv_bytes_per_token += ( + (hidden_size / num_qo_heads) + * (num_kv_heads / tensor_parallel_shards) # on single GPU + * num_layers + * 4 # key, value, fp16 + * 1.10 # over estimation to guarantee safety + ) + + # Get single-card GPU size. + gpu_size_bytes = os.environ.get("MLC_GPU_SIZE_BYTES", default=None) + if gpu_size_bytes is None: + gpu_size_bytes = models[0].device.total_global_memory + if gpu_size_bytes is None: + raise ValueError( + "Cannot read total GPU global memory from device. " + 'Please the GPU memory size in bytes through "MLC_GPU_SIZE_BYTES" env variable.' + ) + + max_total_sequence_length = int( + (int(gpu_size_bytes) * 0.97 - params_bytes * 1.04 - temp_func_bytes) / kv_bytes_per_token + ) + assert max_total_sequence_length > 0, ( + "Cannot estimate KV cache capacity. " + f"The model weight size {params_bytes} may be larger than GPU memory size {gpu_size_bytes}" + ) + + total_size = ( + params_bytes * 1.05 + temp_func_bytes + kv_bytes_per_token * max_total_sequence_length + ) + logger.info( + "%s: %d.", + green('Estimated KVCacheConfig "max_total_sequence_length"'), + max_total_sequence_length, + ) + logger.info( + "%s: %.2f MB (Parameters: %.2f MB. KVCache: %.2f MB. Temporary buffer: %.2f MB)", + green("Estimated total single GPU memory usage"), + total_size / 1024 / 1024, + params_bytes / 1024 / 1024, + kv_bytes_per_token * max_total_sequence_length / 1024 / 1024, + temp_func_bytes / 1024 / 1024, + ) + return int(max_total_sequence_length) + + +class Engine: + """The Python interface of request serving engine for MLC LLM. + + The engine can run one or multiple LLM models internally for + text generation. Usually, when there are multiple models, + speculative inference will be activated, where the first model + (index 0) is the main "large model" that has better generation + quality, and all other models are "small" models that used for + speculation. + + The engine receives requests from the "add_request" method. For + an given request, the engine will keep generating new tokens for + the request until finish (under certain criterion). After finish, + the engine will return the generation result through the callback + function provided by the request. + + Parameters + ---------- + models : Union[ModelInfo, List[ModelInfo]] + One or a list of model info (specifying which models to load and + which device to load to) to launch the engine. + + kv_cache_config : KVCacheConfig + The configuration of the paged KV cache. + + request_stream_callback : Optional[Callable[[str, data.TokenData, Optional[str]], None]] + The provided callback function to handle the generation + output. 
It has the signature of `(str, data.TokenData, bool) -> None`, + where + - the first string is the request id, + - the TokenData contains the generated **delta** token ids since + the last invocation of the callback on the specific request, + - the optional string value denotes the finish reason if the + generation of the request is finished, or None if it has not finished. + + The callback function is optional at construction, but it needs to + be set before the engine executing requests. This can be done via + the `set_request_stream_callback` method. Otherwise, the engine will raise + exception. + + engine_mode : Optional[EngineMode] + The Engine execution mode. + + enable_tracing : bool + A boolean indicating if to enable event logging for requests. + """ + + def __init__( # pylint: disable=too-many-arguments + self, + models: Union[ModelInfo, List[ModelInfo]], + kv_cache_config: KVCacheConfig, + engine_mode: Optional[EngineMode] = None, + request_stream_callback: Optional[Callable[[List[data.RequestStreamOutput]], None]] = None, + enable_tracing: bool = False, + ): + if isinstance(models, ModelInfo): + models = [models] + ( + model_args, + config_file_paths, + tokenizer_path, + self.max_single_sequence_length, + prefill_chunk_size, + self.conv_template_name, + ) = _process_model_args(models) + self._ffi = _create_tvm_module( + "mlc.serve.create_engine", + ffi_funcs=[ + "init", + "add_request", + "abort_request", + "step", + "stats", + "reset", + "get_request_stream_callback", + "set_request_stream_callback", + ], + ) + self.trace_recorder = EventTraceRecorder() if enable_tracing else None + + if kv_cache_config.max_total_sequence_length is None: + kv_cache_config.max_total_sequence_length = _estimate_max_total_sequence_length( + models, config_file_paths + ) + if kv_cache_config.prefill_chunk_size is None: + kv_cache_config.prefill_chunk_size = prefill_chunk_size + elif kv_cache_config.prefill_chunk_size > prefill_chunk_size: + raise ValueError( + f"The specified prefill chunk size {kv_cache_config.prefill_chunk_size} is " + f"larger than the maximum prefill chunk size {prefill_chunk_size} supported by " + "models. Please specify a smaller prefill chunk size." + ) + + if engine_mode is None: + # The default engine mode: non-speculative + engine_mode = EngineMode() + + self._ffi["init"]( + self.max_single_sequence_length, + tokenizer_path, + kv_cache_config.asjson(), + engine_mode.asjson(), + request_stream_callback, + self.trace_recorder, + *model_args, + ) + self.tokenizer = Tokenizer(tokenizer_path) + + def generate( + self, + prompts: Union[str, List[str], List[int], List[List[int]]], + generation_config: Union[GenerationConfig, List[GenerationConfig]], + ) -> Tuple[List[str], List[Optional[List[str]]]]: + """Generate texts for a list of input prompts. + Each prompt can be a string or a list of token ids. + The generation for each prompt is independent. + Return the generation results, one for each prompt. + + Parameters + ---------- + prompts : Union[str, List[str], List[int], List[List[int]]] + One or a list of input prompts for text generation. + Each prompt can be a string or a list of token ids. + + generation_config : Union[GenerationConfig, List[GenerationConfig]] + The generation config for each requests. + If the it is a single GenerationConfig instance, + this config will be shared by all the prompts. + Otherwise, one generation config is required for every + prompt. 
+ + Returns + ------- + output_text : List[str] + The text generation results, one string for each input prompt. + + output_logprobs_str : List[Optional[List[str]]] + The logprob strings of each token for each input prompt, or None + if an input prompt does not require logprobs. + """ + if isinstance(prompts, str): + # `prompts` is a single string. + prompts = [prompts] + else: + assert isinstance(prompts, list), ( + "Input `prompts` is expected to be a string, a list of " + "str, a list of token ids or multiple lists of token ids." + ) + if len(prompts) == 0: + return [], [] + if isinstance(prompts[0], int): + # `prompts` is a list of token ids + prompts = [prompts] # type: ignore + + num_requests = len(prompts) + if not isinstance(generation_config, list): + generation_config = [generation_config] * num_requests + + assert ( + len(generation_config) == num_requests + ), "Number of generation config and number of prompts mismatch" + + num_finished_requests = 0 + output_texts: List[str] = [] + output_logprobs_str: List[Optional[List[str]]] = [] + text_streamers: List[TextStreamer] = [] + for i in range(num_requests): + output_texts.append("") + output_logprobs_str.append([] if generation_config[i].logprobs else None) + text_streamers.append(TextStreamer(self.tokenizer)) + + # Save a copy of the original function callback since `generate` + # overrides the callback function. + # The original callback will be set back later on. + original_callback = self._ffi["get_request_stream_callback"]() + + # Define the callback function for request generation results + def request_stream_callback(delta_outputs: List[data.RequestStreamOutput]): + nonlocal num_finished_requests + for delta_output in delta_outputs: + ( + request_id, + delta_token_ids, + delta_logprob_json_strs, + finish_reason, + ) = delta_output.unpack() + rid = int(request_id) + text_streamer = text_streamers[rid] + if output_logprobs_str[rid] is not None: + assert delta_logprob_json_strs is not None + output_logprobs_str[rid] += delta_logprob_json_strs + + delta_text = text_streamer.put(delta_token_ids) + if finish_reason is not None: + delta_text += text_streamer.finish() + + output_texts[rid] += delta_text + if finish_reason is not None: + num_finished_requests += 1 + + # Override the callback function in engine. + self._ffi["set_request_stream_callback"](request_stream_callback) + + # Add requests to engine. + for req_id, (prompt, generation_cfg) in enumerate(zip(prompts, generation_config)): + input_data = ( + data.TextData(prompt) + if isinstance(prompt, str) + else data.TokenData(prompt) # type: ignore + ) + self.add_request( + Request( + request_id=str(req_id), + inputs=input_data, + generation_config=generation_cfg, + ) + ) + + while num_finished_requests != num_requests: + self.step() + + # Restore the callback function in engine. + self._ffi["set_request_stream_callback"](original_callback) + return output_texts, output_logprobs_str + + def add_request(self, request: Request) -> None: + """Add a new request to the engine. + + Parameters + ---------- + request : Request + The request to add. + """ + self._ffi["add_request"](request) + + def abort_request(self, request_id: str) -> None: + """Abort the generation of the request corresponding to the input request id. + + Parameters + ---------- + request_id : str + The unique id of the request to abort. + """ + self._ffi["abort_request"](request_id) + + def step(self) -> None: + """The main function that the engine takes a step of action. 
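# A usage sketch for batch generation with the synchronous Engine above: the
# engine drives itself inside generate() by calling step() until every request
# finishes. The model id and library path are placeholders for a locally
# compiled model.
from mlc_chat.serve import Engine, GenerationConfig, KVCacheConfig
from mlc_chat.serve.engine import ModelInfo

engine = Engine(
    ModelInfo(
        "Llama-2-7b-chat-hf-q4f16_1",  # placeholder model id
        model_lib_path="dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so",
    ),
    KVCacheConfig(max_num_sequence=8),
)
prompts = ["What is the meaning of life?", "Write a haiku about GPUs."]
texts, _logprobs = engine.generate(prompts, GenerationConfig(max_tokens=64))
for prompt, text in zip(prompts, texts):
    print(f"{prompt!r} -> {text!r}")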
+ + At each step, the engine may decide to + - run prefill for one (or more) requests, + - run one-step decode for the all existing requests + ... + + In the end of certain actions (e.g., decode), the engine will + check if any request has finished, and will return the + generation results for those finished requests. + """ + self._ffi["step"]() + + def reset(self) -> None: + """Reset the engine, clean up all running data and statistics.""" + self._ffi["reset"]() + + def stats(self) -> Dict[str, float]: + """The engine runtime statistics. + We collect the following entries: + - single token prefill latency (s/tok): avg latency of processing one token in prefill + - single token decode latency (s/tok): avg latency of processing one token in decode + - engine time for prefill (sec) + - engine time for decode (sec) + - total number of processed tokens in prefill. + - total number of processed tokens in decode. + """ + stats_json_str = self._ffi["stats"]() + return json.loads(stats_json_str) diff --git a/python/mlc_chat/serve/entrypoints/__init__.py b/python/mlc_chat/serve/entrypoints/__init__.py new file mode 100644 index 0000000..3002bf8 --- /dev/null +++ b/python/mlc_chat/serve/entrypoints/__init__.py @@ -0,0 +1,2 @@ +"""The entrypoints for MLC LLM server.""" +from . import debug_entrypoints, openai_entrypoints diff --git a/python/mlc_chat/serve/entrypoints/debug_entrypoints.py b/python/mlc_chat/serve/entrypoints/debug_entrypoints.py new file mode 100644 index 0000000..45da755 --- /dev/null +++ b/python/mlc_chat/serve/entrypoints/debug_entrypoints.py @@ -0,0 +1,48 @@ +"""MLC LLM server debug entrypoints""" +import json +from http import HTTPStatus + +import fastapi + +from ..server import ServerContext +from . import entrypoint_utils + +app = fastapi.APIRouter() + +################ /debug/dump_event_trace ################ + + +@app.post("/debug/dump_event_trace") +async def debug_dump_event_trace(request: fastapi.Request): + """Return the recorded events in Chrome Trace Event Format in JSON string. + The input request payload should have only one field, specifying the + model to query. For example: `{"model": "Llama-2-7b-chat-hf-q0f16"}`. + """ + # Get the raw request body as bytes + request_raw_data = await request.body() + request_json_str = request_raw_data.decode("utf-8") + try: + # Parse the JSON string + request_dict = json.loads(request_json_str) + except json.JSONDecodeError: + return entrypoint_utils.create_error_response( + HTTPStatus.BAD_REQUEST, message=f"Invalid request {request_json_str}" + ) + if "model" not in request_dict: + return entrypoint_utils.create_error_response( + HTTPStatus.BAD_REQUEST, message=f"Invalid request {request_json_str}" + ) + + # - Check the requested model. + model = request_dict["model"] + async_engine = ServerContext.get_engine(model) + if async_engine is None: + return entrypoint_utils.create_error_response( + HTTPStatus.BAD_REQUEST, message=f'The requested model "{model}" is not served.' 
+ ) + if async_engine.trace_recorder is None: + return entrypoint_utils.create_error_response( + HTTPStatus.BAD_REQUEST, message=f'The requested model "{model}" does not enable tracing' + ) + + return json.loads(async_engine.trace_recorder.dump_json()) diff --git a/python/mlc_chat/serve/entrypoints/entrypoint_utils.py b/python/mlc_chat/serve/entrypoints/entrypoint_utils.py new file mode 100644 index 0000000..5a9924b --- /dev/null +++ b/python/mlc_chat/serve/entrypoints/entrypoint_utils.py @@ -0,0 +1,92 @@ +"""Utility functions for server entrypoints""" + +import uuid +from http import HTTPStatus +from typing import Callable, List, Optional, Union + +import fastapi + +from ...protocol import RequestProtocol +from ...protocol.protocol_utils import ErrorResponse, get_unsupported_fields + + +def random_uuid() -> str: + """Generate a random id in hexadecimal string.""" + return uuid.uuid4().hex + + +def create_error_response(status_code: HTTPStatus, message: str) -> fastapi.responses.JSONResponse: + """Create a JSON response that reports error with regarding the input message.""" + return fastapi.responses.JSONResponse( + ErrorResponse(message=message, code=status_code.value).model_dump_json(), + status_code=status_code.value, + ) + + +def check_unsupported_fields( + request: RequestProtocol, +) -> Optional[fastapi.responses.JSONResponse]: + """Check if the request has unsupported fields. Return an error if so.""" + unsupported_fields = get_unsupported_fields(request) + if len(unsupported_fields) != 0: + unsupported_fields = [f'"{field}"' for field in unsupported_fields] + return create_error_response( + HTTPStatus.BAD_REQUEST, + message=f'Request fields {", ".join(unsupported_fields)} are not supported right now.', + ) + return None + + +def check_prompts_length( + prompts: List[List[int]], max_single_sequence_length: int +) -> Optional[fastapi.responses.JSONResponse]: + """Check if the total prompt length exceeds the max single sequence + sequence length allowed by the served model. Return an error if so. + """ + total_length = 0 + for prompt in prompts: + total_length += len(prompt) + if total_length > max_single_sequence_length: + return create_error_response( + HTTPStatus.BAD_REQUEST, + message=f"Request prompt has {total_length} tokens in total," + f" larger than the model capacity {max_single_sequence_length}.", + ) + return None + + +def process_prompts( + input_prompts: Union[str, List[int], List[Union[str, List[int]]]], + ftokenize: Callable[[str], List[int]], +) -> Union[List[List[int]], fastapi.responses.JSONResponse]: + """Convert all input tokens to list of token ids with regard to the + given tokenization function. + For each input prompt, return the list of token ids after tokenization. + """ + error_msg = f"Invalid request prompt {input_prompts}" + + # Case 1. The prompt is a single string. + if isinstance(input_prompts, str): + return [ftokenize(input_prompts)] + + assert isinstance(input_prompts, list) + if len(input_prompts) == 0: + return create_error_response(HTTPStatus.BAD_REQUEST, message=error_msg) + + # Case 2. The prompt is a list of token ids. + if isinstance(input_prompts[0], int): + if not all(isinstance(token_id, int) for token_id in input_prompts): + return create_error_response(HTTPStatus.BAD_REQUEST, message=error_msg) + return [input_prompts] + + # Case 3. A list of prompts. 
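For reference, the three accepted input shapes can be illustrated with a toy tokenizer standing in for the real `encode` function; the helper name and the token ids below are illustrative only, and case 3 is handled by the loop that follows.

def toy_tokenize(text):
    # Stand-in for async_engine.tokenizer.encode: one fake id per word.
    return [len(word) for word in text.split()]

process_prompts("hello world", toy_tokenize)         # -> [[5, 5]]
process_prompts([1, 2, 3], toy_tokenize)             # -> [[1, 2, 3]]
process_prompts(["hi there", [4, 5]], toy_tokenize)  # -> [[2, 5], [4, 5]]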
+ output_prompts: List[List[int]] = [] + for input_prompt in input_prompts: + is_str = isinstance(input_prompt, str) + is_token_ids = isinstance(input_prompt, list) and all( + isinstance(token_id, int) for token_id in input_prompt + ) + if not (is_str or is_token_ids): + return create_error_response(HTTPStatus.BAD_REQUEST, message=error_msg) + output_prompts.append(ftokenize(input_prompt) if is_str else input_prompt) # type: ignore + return output_prompts diff --git a/python/mlc_chat/serve/entrypoints/openai_entrypoints.py b/python/mlc_chat/serve/entrypoints/openai_entrypoints.py new file mode 100644 index 0000000..de85ab8 --- /dev/null +++ b/python/mlc_chat/serve/entrypoints/openai_entrypoints.py @@ -0,0 +1,537 @@ +"""OpenAI API-compatible server entrypoints in MLC LLM""" + +# pylint: disable=too-many-locals,too-many-return-statements,too-many-statements +import ast +import json +from http import HTTPStatus +from typing import AsyncGenerator, Dict, List, Optional, Union + +import fastapi + +from ...protocol import protocol_utils +from ...protocol.conversation_protocol import Conversation +from ...protocol.openai_api_protocol import ( + ChatCompletionMessage, + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatCompletionStreamResponse, + ChatCompletionStreamResponseChoice, + ChatFunctionCall, + ChatToolCall, + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + ListResponse, + LogProbs, + LogProbsContent, + ModelResponse, + UsageInfo, +) +from ..server import ServerContext +from . import entrypoint_utils + +app = fastapi.APIRouter() + +################ v1/models ################ + + +@app.get("/v1/models") +async def request_models(): + """OpenAI-compatible served model query API. + API reference: https://platform.openai.com/docs/api-reference/models + """ + return ListResponse(data=[ModelResponse(id=model) for model in ServerContext.get_model_list()]) + + +################ v1/completions ################ + + +@app.post("/v1/completions") +async def request_completion(request: CompletionRequest, raw_request: fastapi.Request): + """OpenAI-compatible completion API. + API reference: https://platform.openai.com/docs/api-reference/completions/create + """ + # - Check the requested model. + async_engine = ServerContext.get_engine(request.model) + if async_engine is None: + return entrypoint_utils.create_error_response( + HTTPStatus.BAD_REQUEST, message=f'The requested model "{request.model}" is not served.' + ) + request_id = f"cmpl-{entrypoint_utils.random_uuid()}" + async_engine.record_event(request_id, event="receive request") + + # - Check if unsupported arguments are specified. + error = entrypoint_utils.check_unsupported_fields(request) + if error is not None: + return error + + # - Process prompt and check validity. + async_engine.record_event(request_id, event="start tokenization") + prompts = entrypoint_utils.process_prompts(request.prompt, async_engine.tokenizer.encode) + async_engine.record_event(request_id, event="finish tokenization") + if isinstance(prompts, fastapi.responses.JSONResponse): + # Errored when processing the prompts + return prompts + if len(prompts) > 1: + return entrypoint_utils.create_error_response( + HTTPStatus.BAD_REQUEST, + message="Entrypoint /v1/completions only accept single prompt. 
" + f"However, {len(prompts)} prompts {prompts} are received.", + ) + error = entrypoint_utils.check_prompts_length(prompts, async_engine.max_single_sequence_length) + if error is not None: + return error + prompt = prompts[0] + + # Process generation config. Create request id. + generation_cfg = protocol_utils.get_generation_config(request) + + # Streaming response. + if request.stream: + + async def completion_stream_generator() -> AsyncGenerator[str, None]: + assert request.n == 1 + + # - Echo back the prompt. + if request.echo: + text = async_engine.tokenizer.decode(prompt) + response = CompletionResponse( + id=request_id, + choices=[CompletionResponseChoice(text=text)], + model=request.model, + usage=UsageInfo( + prompt_tokens=len(prompt), + completion_tokens=0, + ), + ) + yield f"data: {response.model_dump_json()}\n\n" + + # - Generate new tokens. + num_completion_tokens = 0 + finish_reason = None + async_engine.record_event(request_id, event="invoke generate") + async for ( + delta_text, + num_delta_tokens, + delta_logprob_json_strs, + finish_reason, + ) in async_engine.generate(prompt, generation_cfg, request_id): + num_completion_tokens += num_delta_tokens + if delta_text == "": + # Ignore empty delta text -- do not yield. + continue + + response = CompletionResponse( + id=request_id, + choices=[ + CompletionResponseChoice( + finish_reason=finish_reason, + text=delta_text, + logprobs=( + LogProbs( + content=[ + LogProbsContent.model_validate_json(logprob_json_str) + for logprob_json_str in delta_logprob_json_strs + ] + ) + if delta_logprob_json_strs is not None + else None + ), + ) + ], + model=request.model, + usage=UsageInfo( + prompt_tokens=len(prompt), + completion_tokens=num_completion_tokens, + ), + ) + yield f"data: {response.model_dump_json()}\n\n" + async_engine.record_event(request_id, event="finish") + + # - Echo the suffix. + if request.suffix is not None: + assert finish_reason is not None + response = CompletionResponse( + id=request_id, + choices=[ + CompletionResponseChoice( + finish_reason=finish_reason, + text=request.suffix, + ) + ], + model=request.model, + usage=UsageInfo( + prompt_tokens=len(prompt), + completion_tokens=num_completion_tokens, + ), + ) + yield f"data: {response.model_dump_json()}\n\n" + + yield "data: [DONE]\n\n" + + return fastapi.responses.StreamingResponse( + completion_stream_generator(), media_type="text/event-stream" + ) + + # Normal response. + output_text = "" if not request.echo else async_engine.tokenizer.decode(prompt) + num_completion_tokens = 0 + finish_reason: Optional[str] = None + logprob_json_strs: Optional[List[str]] = [] if generation_cfg.logprobs else None + async_engine.record_event(request_id, event="invoke generate") + async for ( + delta_text, + num_delta_tokens, + delta_logprob_json_strs, + finish_reason, + ) in async_engine.generate(prompt, generation_cfg, request_id): + if await raw_request.is_disconnected(): + # In non-streaming cases, the engine will not be notified + # when the request is disconnected. + # Therefore, we check if it is disconnected each time, + # and abort the request from engine if so. 
+ await async_engine.abort(request_id) + return entrypoint_utils.create_error_response( + HTTPStatus.BAD_REQUEST, message="The request has disconnected" + ) + output_text += delta_text + num_completion_tokens += num_delta_tokens + if logprob_json_strs is not None: + assert delta_logprob_json_strs is not None + logprob_json_strs += delta_logprob_json_strs + assert finish_reason is not None + suffix = request.suffix if request.suffix is not None else "" + async_engine.record_event(request_id, event="finish") + response = CompletionResponse( + id=request_id, + choices=[ + CompletionResponseChoice( + finish_reason=finish_reason, + text=output_text + suffix, + logprobs=( + LogProbs( + content=[ + LogProbsContent.model_validate_json(logprob_json_str) + for logprob_json_str in logprob_json_strs + ] + ) + if logprob_json_strs is not None + else None + ), + ) + ], + model=request.model, + usage=UsageInfo( + prompt_tokens=len(prompt), + completion_tokens=num_completion_tokens, + ), + ) + return response + + +################ v1/chat/completions ################ + + +def chat_completion_check_message_validity( + messages: List[ChatCompletionMessage], +) -> Optional[str]: + """Check if the given chat messages are valid. Return error message if invalid.""" + for i, message in enumerate(messages): + if message.role == "system" and i != 0: + return f"System prompt at position {i} in the message list is invalid." + if message.role == "tool": + return "Tool as the message author is not supported yet." + if message.tool_call_id is not None: + if message.role != "tool": + return "Non-tool message having `tool_call_id` is invalid." + if isinstance(message.content, list): + if message.role != "user": + return "Non-user message having a list of content is invalid." + return "User message having a list of content is not supported yet." + if message.tool_calls is not None: + if message.role != "assistant": + return "Non-assistant message having `tool_calls` is invalid." + return "Assistant message having `tool_calls` is not supported yet." + return None + + +def check_function_call_usage( + request: ChatCompletionRequest, conv_template: Conversation +) -> Optional[str]: + """Check if function calling is used and update the conversation template. + Return error message if invalid request format for function calling. 
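A request body that passes this check might look as follows; the field layout follows the OpenAI tools schema, and the model name, function name, and parameter schema are illustrative assumptions.

request_body = {
    "model": "Llama-2-7b-chat-hf-q4f16_1",
    "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
    "tools": [
        {
            "type": "function",  # only the 'function' tool type is accepted
            "function": {
                "name": "get_weather",
                "description": "Look up the current weather for a city.",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
    # Either "auto", "none", or a dict naming exactly one function.
    "tool_choice": {"type": "function", "function": {"name": "get_weather"}},
}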
+ """ + + # return if no tools are provided or tool_choice is set to none + if request.tools is None or ( + isinstance(request.tool_choice, str) and request.tool_choice == "none" + ): + conv_template.use_function_calling = False + return None + + # select the tool based on the tool_choice if specified + if isinstance(request.tool_choice, dict): + if request.tool_choice["type"] != "function": + return "Only 'function' tool choice is supported" + + if len(request.tool_choice["function"]) > 1: + return "Only one tool is supported when tool_choice is specified" + + for tool in request.tools: + if tool.function.name == request.tool_choice["function"]["name"]: + conv_template.use_function_calling = True + conv_template.function_string = tool.function.model_dump_json() + return None + + return ( + f"The tool_choice function {request.tool_choice['function']['name']}" + " is not found in the tools list" + ) + + if isinstance(request.tool_choice, str) and request.tool_choice != "auto": + return f"Invalid tool_choice value: {request.tool_choice}" + + function_list = [] + for tool in request.tools: + if tool.type != "function": + return "Only 'function' tool type is supported" + function_list.append(tool.function.model_dump()) + + conv_template.use_function_calling = True + conv_template.function_string = json.dumps(function_list) + return None + + +def convert_function_str_to_json(stringified_calls: str) -> List[Union[Dict, None]]: + """Convert a (possibly list) of function call string to a list of json objects. + Return None for invalid function call string.""" + + def parse_function_call(call_str: str): + node = ast.parse(call_str, mode="eval") + call_node = node.body + if isinstance(call_node, ast.Call) and isinstance(call_node.func, ast.Name): + name = call_node.func.id + arguments = {} + for keyword in call_node.keywords: + arguments[keyword.arg] = ast.literal_eval(keyword.value) + return {"name": name, "arguments": arguments} + return None + + if ( + stringified_calls[0] == "[" and stringified_calls[-1] == "]" + ): # hacky way to check if string list + calls = ast.literal_eval(stringified_calls) + else: + calls = [stringified_calls] + function_calls_json = [parse_function_call(call_str) for call_str in calls] + return function_calls_json + + +@app.post("/v1/chat/completions") +async def request_chat_completion( + request: ChatCompletionRequest, raw_request: fastapi.Request +): # pylint: disable=too-many-branches + """OpenAI-compatible chat completion API. + API reference: https://platform.openai.com/docs/api-reference/chat + """ + # - Check the requested model. + async_engine = ServerContext.get_engine(request.model) + if async_engine is None: + return entrypoint_utils.create_error_response( + HTTPStatus.BAD_REQUEST, message=f'The requested model "{request.model}" is not served.' + ) + request_id = f"chatcmpl-{entrypoint_utils.random_uuid()}" + async_engine.record_event(request_id, event="receive request") + + # - Check if the model supports chat conversation. + conv_template = ServerContext.get_conv_template(request.model) + if conv_template is None: + return entrypoint_utils.create_error_response( + HTTPStatus.BAD_REQUEST, + message=f'The requested model "{request.model}" does not support chat.', + ) + + # - Check if unsupported arguments are specified. + error = entrypoint_utils.check_unsupported_fields(request) + if error is not None: + return error + + # - Process messages and update the conversation template in three steps: + # i. Check the message validity. + # ii. 
Add the input messages to the conversation template. + # iii. Add the additional message for the assistant. + error_msg = chat_completion_check_message_validity(request.messages) + if error_msg is not None: + return entrypoint_utils.create_error_response(HTTPStatus.BAD_REQUEST, message=error_msg) + + # Check for function calling usage and update the conversation template + error_msg = check_function_call_usage(request, conv_template) + if error_msg is not None: + return entrypoint_utils.create_error_response(HTTPStatus.BAD_REQUEST, message=error_msg) + + for message in request.messages: + role = message.role + content = message.content + assert isinstance(content, str), "Internal error: content is not a string." + if role == "system": + conv_template.system_message = content if content is not None else "" + continue + + assert role != "tool", "Internal error: tool role." + conv_template.messages.append((role, content)) + conv_template.messages.append(("assistant", None)) + + # - Get the prompt from template, and encode to token ids. + # - Check prompt length + async_engine.record_event(request_id, event="start tokenization") + prompts = entrypoint_utils.process_prompts( + conv_template.as_prompt(), async_engine.tokenizer.encode + ) + async_engine.record_event(request_id, event="finish tokenization") + assert isinstance(prompts, list) and len(prompts) == 1, "Internal error" + if conv_template.system_prefix_token_ids is not None: + prompts[0] = conv_template.system_prefix_token_ids + prompts[0] + error = entrypoint_utils.check_prompts_length(prompts, async_engine.max_single_sequence_length) + if error is not None: + return error + prompt = prompts[0] + + # Process generation config. Create request id. + generation_cfg = protocol_utils.get_generation_config( + request, + extra_stop_token_ids=conv_template.stop_token_ids, + extra_stop_str=conv_template.stop_str, + ) + + # Streaming response. + if request.stream: + + async def completion_stream_generator() -> AsyncGenerator[str, None]: + assert request.n == 1 + async_engine.record_event(request_id, event="invoke generate") + async for ( + delta_text, + _, + delta_logprob_json_strs, + finish_reason, + ) in async_engine.generate(prompt, generation_cfg, request_id): + if delta_text == "": + async_engine.record_event(request_id, event="skip empty delta text") + # Ignore empty delta text -- do not yield. + continue + + if conv_template.use_function_calling: + finish_reason = "tool_calls" + + response = ChatCompletionStreamResponse( + id=request_id, + choices=[ + ChatCompletionStreamResponseChoice( + finish_reason=finish_reason, + delta=ChatCompletionMessage(content=delta_text, role="assistant"), + logprobs=( + LogProbs( + content=[ + LogProbsContent.model_validate_json(logprob_json_str) + for logprob_json_str in delta_logprob_json_strs + ] + ) + if delta_logprob_json_strs is not None + else None + ), + ) + ], + model=request.model, + system_fingerprint="", + ) + async_engine.record_event(request_id, event=f"yield delta text {delta_text}") + yield f"data: {response.model_dump_json()}\n\n" + async_engine.record_event(request_id, event="finish") + yield "data: [DONE]\n\n" + + return fastapi.responses.StreamingResponse( + completion_stream_generator(), media_type="text/event-stream" + ) + + # Normal response. 
+ output_text = "" + num_completion_tokens = 0 + finish_reason: Optional[str] = None + logprob_json_strs: Optional[List[str]] = [] if generation_cfg.logprobs else None + async_engine.record_event(request_id, event="invoke generate") + async for ( + delta_text, + num_delta_tokens, + delta_logprob_json_strs, + finish_reason, + ) in async_engine.generate(prompt, generation_cfg, request_id): + if await raw_request.is_disconnected(): + # In non-streaming cases, the engine will not be notified + # when the request is disconnected. + # Therefore, we check if it is disconnected each time, + # and abort the request from engine if so. + await async_engine.abort(request_id) + return entrypoint_utils.create_error_response( + HTTPStatus.BAD_REQUEST, message="The request has disconnected" + ) + output_text += delta_text + num_completion_tokens += num_delta_tokens + if logprob_json_strs is not None: + assert delta_logprob_json_strs is not None + logprob_json_strs += delta_logprob_json_strs + assert finish_reason is not None + + async_engine.record_event(request_id, event="finish") + + if conv_template.use_function_calling: + try: + fn_json_list = convert_function_str_to_json(output_text) + except (SyntaxError, ValueError): + output_text = "Got an invalid function call output from model" + finish_reason = "error" + else: + tool_calls = [ + ChatToolCall( + type="function", + function=ChatFunctionCall( + name=fn_json_obj["name"], arguments=fn_json_obj["arguments"] + ), + ) + for fn_json_obj in fn_json_list + if fn_json_obj is not None + ] + if len(tool_calls) == 0: + output_text = "Got an invalid function call output from model" + finish_reason = "error" + else: + finish_reason = "tool_calls" + + message = ( + ChatCompletionMessage(role="assistant", content=output_text) + if (not conv_template.use_function_calling or finish_reason == "error") + else ChatCompletionMessage(role="assistant", content=None, tool_calls=tool_calls) + ) + + return ChatCompletionResponse( + id=request_id, + choices=[ + ChatCompletionResponseChoice( + finish_reason=finish_reason, + message=message, + logprobs=( + LogProbs( + content=[ + LogProbsContent.model_validate_json(logprob_json_str) + for logprob_json_str in logprob_json_strs + ] + ) + if logprob_json_strs is not None + else None + ), + ) + ], + model=request.model, + system_fingerprint="", + usage=UsageInfo(prompt_tokens=len(prompt), completion_tokens=num_completion_tokens), + ) diff --git a/python/mlc_chat/serve/event_trace_recorder.py b/python/mlc_chat/serve/event_trace_recorder.py new file mode 100644 index 0000000..7a8a817 --- /dev/null +++ b/python/mlc_chat/serve/event_trace_recorder.py @@ -0,0 +1,41 @@ +"""The event trace recorder in MLC LLM serving""" + +import tvm._ffi +from tvm.runtime import Object + +from . import _ffi_api + + +@tvm._ffi.register_object("mlc.serve.EventTraceRecorder") # pylint: disable=protected-access +class EventTraceRecorder(Object): + """The event trace recorder for requests.""" + + def __init__(self) -> None: + """Initialize a trace recorder.""" + self.__init_handle_by_constructor__( + _ffi_api.EventTraceRecorder # type: ignore # pylint: disable=no-member + ) + + def add_event(self, request_id: str, event: str) -> None: + """Record a event for the the input request in the trace recorder. + + Parameters + ---------- + request_id : str + The subject request of the event. + + event : str + The event in a string name. 
+ It can have one of the following patterns: + - "start xxx", which marks the start of event "xxx", + - "finish xxx", which marks the finish of event "xxx", + - "yyy", which marks the instant event "yyy". + The "starts" and "finishes" will be automatically paired in the trace recorder. + """ + return _ffi_api.EventTraceRecorderAddEvent( # type: ignore # pylint: disable=no-member + self, request_id, event + ) + + def dump_json(self) -> str: + """Dump the logged events in Chrome Trace Event Format in JSON string.""" + return _ffi_api.EventTraceRecorderDumpJSON(self) # type: ignore # pylint: disable=no-member diff --git a/python/mlc_chat/serve/grammar.py b/python/mlc_chat/serve/grammar.py new file mode 100644 index 0000000..3df954c --- /dev/null +++ b/python/mlc_chat/serve/grammar.py @@ -0,0 +1,242 @@ +"""Classes handling the grammar guided generation of MLC LLM serving""" +from typing import List, Union + +import tvm._ffi +from tvm.runtime import Object + +from ..tokenizer import Tokenizer +from . import _ffi_api + + +@tvm._ffi.register_object("mlc.serve.BNFGrammar") # pylint: disable=protected-access +class BNFGrammar(Object): + """This class stores the abstract syntax tree (AST) of the Backus-Naur Form (BNF) grammar and + provides utilities to parse and print the AST. User should provide a BNF/EBNF (Extended + Backus-Naur Form) grammar, and use from_ebnf_string to parse and simplify the grammar into an + AST of BNF grammar. + """ + + @staticmethod + def from_ebnf_string( + ebnf_string: str, normalize: bool = True, simplify: bool = True + ) -> "BNFGrammar": + r"""Parse a BNF grammar from a string in BNF/EBNF format. + + This method accepts the EBNF notation from the W3C XML Specification + (https://www.w3.org/TR/xml/#sec-notation), which is a popular standard, with the following + changes: + - Using # as comment mark instead of /**/ + - Using C-style unicode escape sequence \u01AB, \U000001AB, \xAB instead of #x0123 + - Do not support A-B (match A and not match B) yet + + See tests/python/serve/json.ebnf for an example. + + Parameters + ---------- + ebnf_string : str + The grammar string. + + normalize : bool + Whether to normalize the grammar. Default: true. Only set to false for the purpose of + testing. + + In The normalized form of a BNF grammar, every rule is in the form: + `rule_name ::= ("" | (element1_1 element1_2 ...) | (element2_1 element2_2 ...) | ...)`. + + I.e. a list of choices, each choice is a sequence of elements. Elements can be a + character class or a rule reference. And if the rule can be empty, the first choice + will be an empty string. + + simplify : bool + Whether to simplify the grammar to make matching more efficient. Default: true. Not + implemented yet. + + Returns + ------- + grammar : BNFGrammar + The parsed BNF grammar. + """ + return _ffi_api.BNFGrammarFromEBNFString( # type: ignore # pylint: disable=no-member + ebnf_string, normalize, simplify + ) + + def to_string(self) -> str: + """Print the BNF grammar to a string, in standard BNF format. + + Returns + ------- + grammar_string : str + The BNF grammar string. + """ + return str(_ffi_api.BNFGrammarToString(self)) # type: ignore # pylint: disable=no-member + + def __str__(self) -> str: + return self.to_string() + + @staticmethod + def from_json(json_string: str) -> "BNFGrammar": + """Load a BNF grammar from the raw representation of the AST in JSON format. + + Parameters + ---------- + json_string : str + The JSON string. + + Returns + ------- + grammar : BNFGrammar + The loaded BNF grammar. 
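A small sketch of `from_ebnf_string` with a toy grammar in the accepted notation; the root rule being named `main` is an assumption made for illustration.

ebnf = (
    "# A toy grammar: one or more lowercase words separated by single spaces.\n"
    'main ::= word (" " word)*\n'
    "word ::= [a-z]+\n"
)
grammar = BNFGrammar.from_ebnf_string(ebnf)
print(grammar)  # prints the normalized BNF form described above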
+ """ + return _ffi_api.BNFGrammarFromJSON(json_string) # type: ignore # pylint: disable=no-member + + def to_json(self, prettify: bool = True) -> str: + """Serialize the AST. Dump the raw representation of the AST to a JSON file. + + Parameters + ---------- + prettify : bool + Whether to format the JSON string. If False, all whitespaces will be removed. + + Returns + ------- + json_string : str + The JSON string. + """ + return str( + _ffi_api.BNFGrammarToJSON(self, prettify) # type: ignore # pylint: disable=no-member + ) + + @staticmethod + def get_grammar_of_json() -> "BNFGrammar": + """Get the grammar of standard JSON. + + Returns + ------- + grammar : BNFGrammar + The JSON grammar. + """ + return _ffi_api.BNFGrammarGetGrammarOfJSON() # type: ignore # pylint: disable=no-member + + +@tvm._ffi.register_object("mlc.serve.GrammarStateMatcher") # pylint: disable=protected-access +class GrammarStateMatcher(Object): + """A stateful matcher to match tokens to the specified BNF grammar. This class is the core logic + of the grammar-guided generation. + + This class implements the non-deterministic pushdown automaton (NPDA) matching algorithm to + match characters to a BNF grammar. It keep track of the current state of the matching process by + maintaining several stacks internally as possible paths in the NPDA. It also supports + backtracking. + + It is particularly capable of finding the set of tokens that are acceptable for the next step + and storing them in a bitmask. This aids in grammar-guided generation. + + Parameters + ---------- + grammar : BNFGrammar + The BNF grammar to match. + + tokenizer : Union[None, Tokenizer, List[str]] + The tokenizer to use, or the list of tokens. + + (For debug purpose) If None, the matcher will use an empty token set, and can only accept + and match characters. Default: None. + + max_rollback_steps : int + The maximum number of steps to rollback when backtracking. Default: 0. + """ + + def __init__( + self, + grammar: BNFGrammar, + tokenizer: Union[None, Tokenizer, List[str]] = None, + max_rollback_steps: int = 0, + ): + if isinstance(tokenizer, list): + self.__init_handle_by_constructor__( + _ffi_api.GrammarStateMatcherFromTokenTable, # type: ignore # pylint: disable=no-member + grammar, + *tokenizer, + max_rollback_steps, + ) + else: + self.__init_handle_by_constructor__( + _ffi_api.GrammarStateMatcherFromTokenizer, # type: ignore # pylint: disable=no-member + grammar, + tokenizer, + max_rollback_steps, + ) + + def accept_token(self, token_id: int) -> bool: + """Accept one token and update the state of the matcher. + + Parameters + ---------- + token_id : int + The id of the token to accept. + + Returns + ------- + accepted : bool + Whether the token is accepted. + """ + return _ffi_api.GrammarStateMatcherAcceptToken(self, token_id) # type: ignore # pylint: disable=no-member + + def find_next_rejected_tokens(self) -> List[int]: + """Find the ids of the rejected tokens for the next step. + + Returns + ------- + rejected_token_ids : List[int] + A list of rejected token ids. + """ + + return _ffi_api.GrammarStateMatcherFindNextRejectedTokens(self) # type: ignore # pylint: disable=no-member + + def rollback(self, num_tokens: int) -> None: + """Rollback the matcher to a previous state. + + Parameters + ---------- + num_tokens : int + The number of tokens to rollback. It cannot exceed the current number of steps, nor can + it exceed the specified maximum number of rollback steps. 
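A character-level (debug) matching sketch using the built-in JSON grammar and the debug helpers defined below; no tokenizer is attached, and the expected True/False results are assumptions about how the JSON grammar behaves.

json_grammar = BNFGrammar.get_grammar_of_json()
matcher = GrammarStateMatcher(json_grammar)  # tokenizer=None: debug/character mode

print(matcher.debug_match_complete_string('{"answer": 42}'))  # expected: True
matcher.reset_state()
print(matcher.debug_match_complete_string('{"answer": 42'))   # expected: False (unterminated)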
+ """ + _ffi_api.GrammarStateMatcherRollback(self, num_tokens) # type: ignore # pylint: disable=no-member + + def max_rollback_steps(self) -> int: + """Get the maximum number of rollback steps allowed. + + Returns + ------- + max_rollback_steps : int + The maximum number of rollback steps. + """ + return _ffi_api.GrammarStateMatcherMaxRollbackSteps(self) # type: ignore # pylint: disable=no-member + + def reset_state(self) -> None: + """Reset the matcher to the initial state.""" + _ffi_api.GrammarStateMatcherResetState(self) # type: ignore # pylint: disable=no-member + + def debug_accept_char(self, codepoint: int) -> bool: + """Accept one unicode codepoint to the current state. + + Parameters + ---------- + codepoint : int + The unicode codepoint of the character to be accepted. + """ + return _ffi_api.GrammarStateMatcherDebugAcceptCodepoint( # type: ignore # pylint: disable=no-member + self, codepoint + ) + + def debug_match_complete_string(self, string: str) -> bool: + """Check if a matcher can accept the complete string, and then reach the end of the + grammar. + + Parameters + ---------- + string : str + The string to be matched. + """ + return _ffi_api.GrammarStateMatcherDebugMatchCompleteString(self, string) # type: ignore # pylint: disable=no-member diff --git a/python/mlc_chat/serve/request.py b/python/mlc_chat/serve/request.py new file mode 100644 index 0000000..5c2d8ad --- /dev/null +++ b/python/mlc_chat/serve/request.py @@ -0,0 +1,58 @@ +"""The request class in MLC LLM serving""" + +from typing import List, Union + +import tvm._ffi +from tvm.runtime import Object + +from . import _ffi_api +from .config import GenerationConfig +from .data import Data + + +@tvm._ffi.register_object("mlc.serve.Request") # pylint: disable=protected-access +class Request(Object): + """The user submitted text-generation request, which contains + a unique request id, a list of multi-modal inputs, a set of + generation configuration parameters. + + Parameters + ---------- + request_id : str + The unique identifier of the request. + Different requests should have different ids. + + inputs : List[Data] + The user inputs of a request. Input may have multi-modality. + + generation_config : GenerationConfig + The sampling configuration which may contain temperature, + top_p, repetition_penalty, max_gen_len, etc. 
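A construction sketch for the class documented here; the `data` import location, the prompt text, and the `GenerationConfig` field names are assumptions.

from mlc_chat.serve import GenerationConfig, Request, data  # import paths assumed

req = Request(
    request_id="req-0",
    inputs=data.TextData("Explain the KV cache in one sentence."),
    generation_config=GenerationConfig(temperature=0.7, top_p=0.95),  # fields assumed
)
print(req.generation_config)  # round-trips through the JSON form used by the FFI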
+ """ + + def __init__( + self, + request_id: str, + inputs: Union[Data, List[Data]], + generation_config: GenerationConfig, + ): + if not isinstance(inputs, list): + inputs = [inputs] + self.__init_handle_by_constructor__( + _ffi_api.Request, # type: ignore # pylint: disable=no-member + request_id, + inputs, + generation_config.asjson(), + ) + + @property + def inputs(self) -> List[Data]: + """The inputs of the request.""" + return _ffi_api.RequestGetInputs(self) # type: ignore # pylint: disable=no-member + + @property + def generation_config(self) -> GenerationConfig: + """The generation config of the request.""" + return GenerationConfig.from_json( + _ffi_api.RequestGetGenerationConfigJSON(self) # type: ignore # pylint: disable=no-member + ) diff --git a/python/mlc_chat/serve/server/__init__.py b/python/mlc_chat/serve/server/__init__.py new file mode 100644 index 0000000..cd4fce2 --- /dev/null +++ b/python/mlc_chat/serve/server/__init__.py @@ -0,0 +1,3 @@ +"""The server related data structure and tools in MLC LLM serve.""" +from .popen_server import PopenServer +from .server_context import ServerContext diff --git a/python/mlc_chat/serve/server/__main__.py b/python/mlc_chat/serve/server/__main__.py new file mode 100644 index 0000000..e57e9f4 --- /dev/null +++ b/python/mlc_chat/serve/server/__main__.py @@ -0,0 +1,71 @@ +"""Entrypoint of RESTful HTTP request server in MLC LLM""" +import argparse +import json + +import fastapi +import uvicorn +from fastapi.middleware.cors import CORSMiddleware + +from .. import async_engine, config +from .server_context import ServerContext + + +def parse_args_and_initialize() -> argparse.Namespace: + """Parse the server arguments and initialize the engine.""" + + args = argparse.ArgumentParser() # pylint: disable=redefined-outer-name + args.add_argument("--model", type=str, required=True) + args.add_argument("--model-lib-path", type=str, required=True) + args.add_argument("--device", type=str, default="auto") + args.add_argument("--max-batch-size", type=int, default=80) + args.add_argument("--max-total-seq-length", type=int) + args.add_argument("--prefill-chunk-size", type=int) + args.add_argument("--enable-tracing", action="store_true") + + args.add_argument("--host", type=str, default="127.0.0.1", help="host name") + args.add_argument("--port", type=int, default=8000, help="port") + args.add_argument("--allow-credentials", action="store_true", help="allow credentials") + args.add_argument("--allowed-origins", type=json.loads, default=["*"], help="allowed origins") + args.add_argument("--allowed-methods", type=json.loads, default=["*"], help="allowed methods") + args.add_argument("--allowed-headers", type=json.loads, default=["*"], help="allowed headers") + + parsed = args.parse_args() + + # Initialize model loading info and KV cache config + model_info = async_engine.ModelInfo( + model=parsed.model, + model_lib_path=parsed.model_lib_path, + device=parsed.device, + ) + kv_cache_config = config.KVCacheConfig( + max_num_sequence=parsed.max_batch_size, + max_total_sequence_length=parsed.max_total_seq_length, + prefill_chunk_size=parsed.prefill_chunk_size, + ) + # Create engine and start the background loop + engine = async_engine.AsyncThreadedEngine( + model_info, kv_cache_config, enable_tracing=parsed.enable_tracing + ) + + ServerContext.add_model(parsed.model, engine) + return parsed + + +if __name__ == "__main__": + # Parse the arguments and initialize the asynchronous engine. 
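Once the server above is running, it can be queried over plain HTTP. A minimal client sketch, assuming the default host/port and an OpenAI-style response layout; the model name is a placeholder.

# Launch first, e.g.:
#   python -m mlc_chat.serve.server --model <model> --model-lib-path <model-lib> --device auto
import requests

resp = requests.post(
    "http://127.0.0.1:8000/v1/chat/completions",
    json={
        "model": "Llama-2-7b-chat-hf-q4f16_1",
        "messages": [{"role": "user", "content": "Hello!"}],
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])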
+ args: argparse.Namespace = parse_args_and_initialize() + app = fastapi.FastAPI() + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # Include the routers from subdirectories. + from ..entrypoints import debug_entrypoints, openai_entrypoints + + app.include_router(openai_entrypoints.app) + app.include_router(debug_entrypoints.app) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/python/mlc_chat/serve/server/popen_server.py b/python/mlc_chat/serve/server/popen_server.py new file mode 100644 index 0000000..09e4688 --- /dev/null +++ b/python/mlc_chat/serve/server/popen_server.py @@ -0,0 +1,119 @@ +"""The MLC LLM server launched in a subprocess.""" +import subprocess +import sys +import time +from pathlib import Path +from typing import Optional + +import psutil +import requests + + +class PopenServer: # pylint: disable=too-many-instance-attributes + """The wrapper of MLC LLM server, which runs the server in + a background subprocess.""" + + def __init__( # pylint: disable=too-many-arguments + self, + model: str, + model_lib_path: str, + device: str = "auto", + *, + max_batch_size: int = 80, + max_total_sequence_length: Optional[int] = None, + enable_tracing: bool = False, + host: str = "127.0.0.1", + port: int = 8000, + ) -> None: + """Please check out `python/mlc_chat/serve/server/__main__.py` + for the server arguments.""" + self.model = model + self.model_lib_path = model_lib_path + self.device = device + self.max_batch_size = max_batch_size + self.max_total_sequence_length = max_total_sequence_length + self.enable_tracing = enable_tracing + self.host = host + self.port = port + self._proc: Optional[subprocess.Popen] = None + + def start(self) -> None: + """Launch the server in a popen subprocess. + Wait until the server becomes ready before return. + """ + cmd = [sys.executable] + cmd += ["-m", "mlc_chat.serve.server"] + cmd += ["--model", self.model] + cmd += ["--model-lib-path", self.model_lib_path] + cmd += ["--device", self.device] + cmd += ["--max-batch-size", str(self.max_batch_size)] + if self.max_total_sequence_length is not None: + cmd += ["--max-total-seq-length", str(self.max_total_sequence_length)] + if self.enable_tracing: + cmd += ["--enable-tracing"] + + cmd += ["--host", self.host] + cmd += ["--port", str(self.port)] + process_path = str(Path(__file__).resolve().parents[4]) + self._proc = subprocess.Popen(cmd, cwd=process_path) # pylint: disable=consider-using-with + # NOTE: DO NOT USE `stdout=subprocess.PIPE, stderr=subprocess.PIPE` + # in subprocess.Popen here. PIPE has a fixed-size buffer with may block + # and hang forever. + + # Try to query the server until it is ready. + openai_v1_models_url = "http://127.0.0.1:8000/v1/models" + query_result = None + timeout = 60 + attempts = 0 + while query_result is None and attempts < timeout: + try: + query_result = requests.get(openai_v1_models_url, timeout=60) + except: # pylint: disable=bare-except + attempts += 1 + time.sleep(1) + + # Check if the subprocess terminates unexpectedly or + # the queries reach the timeout. + process_return_code = self._proc.poll() + if process_return_code is not None: + raise RuntimeError( + "The server fails to launch. " + f'Please check if "{self.model}" is a valid model compiled by MLC LLM.' 
+ ) + if attempts == timeout: + self.terminate() + raise RuntimeError(f"The server fails to launch in {timeout} seconds.") + + def terminate(self) -> None: + """Terminate the server subprocess.""" + if self._proc is None: + return + + # Kill all the child processes. + def kill_child_processes(): + try: + parent = psutil.Process(self._proc.pid) + children = parent.children(recursive=True) + except psutil.NoSuchProcess: + return + + for process in children: + try: + process.kill() + except psutil.NoSuchProcess: + pass + + kill_child_processes() + + # Kill the process. + try: + self._proc.kill() + except OSError: + pass + + # Join the process to avoid zombies. + try: + self._proc.wait(timeout=10.0) + except subprocess.TimeoutExpired: + pass + self._proc = None diff --git a/python/mlc_chat/serve/server/server_context.py b/python/mlc_chat/serve/server/server_context.py new file mode 100644 index 0000000..d382bb7 --- /dev/null +++ b/python/mlc_chat/serve/server/server_context.py @@ -0,0 +1,47 @@ +"""Server context that shared by multiple entrypoint files.""" + +from typing import Dict, List, Optional + +from ...conversation_template import ConvTemplateRegistry +from ...protocol.conversation_protocol import Conversation +from .. import async_engine + + +class ServerContext: + """The global server context, including the running models + and corresponding async engines. + """ + + _models: Dict[str, async_engine.AsyncThreadedEngine] = {} + _conv_templates: Dict[str, Conversation] = {} + + @staticmethod + def add_model(hosted_model: str, engine: async_engine.AsyncThreadedEngine) -> None: + """Add a new model to the server context together with the engine.""" + if hosted_model in ServerContext._models: + raise RuntimeError(f"Model {hosted_model} already running.") + ServerContext._models[hosted_model] = engine + + # Get the conversation template. + if engine.conv_template_name is not None: + conv_template = ConvTemplateRegistry.get_conv_template(engine.conv_template_name) + if conv_template is not None: + ServerContext._conv_templates[hosted_model] = conv_template + + @staticmethod + def get_engine(model: str) -> Optional[async_engine.AsyncThreadedEngine]: + """Get the async engine of the requested model.""" + return ServerContext._models.get(model, None) + + @staticmethod + def get_conv_template(model: str) -> Optional[Conversation]: + """Get the conversation template of the requested model.""" + conv_template = ServerContext._conv_templates.get(model, None) + if conv_template is not None: + return conv_template.model_copy(deep=True) + return None + + @staticmethod + def get_model_list() -> List[str]: + """Get the list of models on serve.""" + return list(ServerContext._models.keys()) diff --git a/python/mlc_chat/streamer.py b/python/mlc_chat/streamer.py new file mode 100644 index 0000000..1eb88af --- /dev/null +++ b/python/mlc_chat/streamer.py @@ -0,0 +1,84 @@ +"""Streamers in MLC LLM.""" + +from typing import List, Union + +import tvm +import tvm._ffi +from tvm.runtime import Object, ShapeTuple + +from . import _ffi_api +from .tokenizer import Tokenizer + + +@tvm._ffi.register_object("mlc.TextStreamer") # pylint: disable=protected-access +class TextStreamer(Object): + """The class that streams back validated utf-8 text strings + that generated by tokenizer. 
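A usage sketch of the streamer class described here (its `put`/`finish` methods follow below); the tokenizer path and the token ids are placeholders.

from mlc_chat.streamer import TextStreamer
from mlc_chat.tokenizer import Tokenizer

tokenizer = Tokenizer("dist/Llama-2-7b-chat-hf-q4f16_1-MLC")  # path is a placeholder
streamer = TextStreamer(tokenizer)

text = ""
for delta_token_ids in ([101, 2023], [2003], [1037, 3231, 102]):  # placeholder ids
    # put() may buffer ids that do not yet decode into valid UTF-8.
    text += streamer.put(list(delta_token_ids))
text += streamer.finish()  # flush whatever is still buffered
print(text)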
+ """ + + def __init__(self, tokenizer: Tokenizer) -> None: + """Create the text streamer from tokenizer""" + self.__init_handle_by_constructor__( + _ffi_api.TextStreamer, tokenizer # type: ignore # pylint: disable=no-member + ) + + def put(self, delta_tokens: Union[List[int], ShapeTuple]) -> str: + """Put new delta tokens into the streamer, and get the UTF-8-valid + delta string. The text streamer may hold some of the input delta tokens + which cannot decode into valid UTF-8 strings. The returned string + is always guaranteed to be UTF-8 valid. + + Parameters + ---------- + delta_tokens : Union[List[int], ShapeTuple] + The new tokens to put into the streamer. + + Returns + ------- + delta_text : str + The decoded delta string after putting the input new tokens. + """ + if isinstance(delta_tokens, list): + delta_tokens = ShapeTuple(delta_tokens) + return _ffi_api.TextStreamerPut( # type: ignore # pylint: disable=no-member + self, delta_tokens + ) + + def finish(self) -> str: + """Return the string decoded by remaining tokens.""" + return _ffi_api.TextStreamerFinish(self) # type: ignore # pylint: disable=no-member + + +@tvm._ffi.register_object("mlc.StopStrHandler") # pylint: disable=protected-access +class StopStrHandler(Object): + """The stop string handler in MLC LLM, which takes input delta tokens + one at a time, and return the output delta token before stopping due to + stop strings.""" + + def __init__(self, stop_strs: List[str], tokenizer: Tokenizer) -> None: + self.__init_handle_by_constructor__( + _ffi_api.StopStrHandler, # type: ignore # pylint: disable=no-member + stop_strs, + tokenizer, + ) + + def put(self, token_id: int) -> List[int]: + """Add new input delta token to the handler, return output + delta tokens before stopping. The stop string handler may hold + some of the input delta token which may be part of a stop string. + The returned tokens are always guaranteed not to be part of stop string. + """ + return list( + _ffi_api.StopStrHandlerPut(self, token_id) # type: ignore # pylint: disable=no-member + ) + + def finish(self) -> List[int]: + """Stop string handling has finished, return remaining cached token ids.""" + return list( + _ffi_api.StopStringHandlerFinish(self) # type: ignore # pylint: disable=no-member + ) + + @property + def stop_triggered(self) -> bool: + """Check if the generation has stopped due to stop string.""" + return _ffi_api.StopStrHandlerStopTriggered(self) # type: ignore # pylint: disable=no-member diff --git a/python/mlc_chat/support/__init__.py b/python/mlc_chat/support/__init__.py new file mode 100644 index 0000000..ca5d7a6 --- /dev/null +++ b/python/mlc_chat/support/__init__.py @@ -0,0 +1,4 @@ +""" +Common utilities used in the Python package. Do not import anything by default, +as they may introduce unnecessary dependencies. 
+""" diff --git a/python/mlc_chat/support/argparse.py b/python/mlc_chat/support/argparse.py new file mode 100644 index 0000000..81211e8 --- /dev/null +++ b/python/mlc_chat/support/argparse.py @@ -0,0 +1,15 @@ +"""An enhanced argument parser for mlc-chat.""" +import argparse +import sys + + +class ArgumentParser(argparse.ArgumentParser): + """An enhanced argument parser for mlc-chat.""" + + def error(self, message): + """Overrides the behavior when erroring out""" + print("-" * 25 + " Usage " + "-" * 25) + self.print_help() + print("-" * 25 + " Error " + "-" * 25) + print(message, file=sys.stderr) + sys.exit(2) diff --git a/python/mlc_chat/support/auto_config.py b/python/mlc_chat/support/auto_config.py new file mode 100644 index 0000000..a5b73b7 --- /dev/null +++ b/python/mlc_chat/support/auto_config.py @@ -0,0 +1,192 @@ +"""Help function for detecting the model configuration file `config.json`""" +import json +import tempfile +from pathlib import Path +from typing import TYPE_CHECKING + +from . import logging +from .style import bold, green + +if TYPE_CHECKING: + from mlc_chat.model import Model # pylint: disable=unused-import + from mlc_chat.quantization import Quantization # pylint: disable=unused-import + + +logger = logging.getLogger(__name__) + +FOUND = green("Found") + + +def detect_mlc_chat_config(mlc_chat_config: str) -> Path: + """Detect and return the path that points to mlc-chat-config.json. + If `mlc_chat_config` is a directory, it looks for mlc-chat-config.json below it. + + Parameters + --------- + mlc_chat_config : str + The path to `mlc-chat-config.json`, or the directory containing + `mlc-chat-config.json`. + + Returns + ------- + mlc_chat_config_json_path : pathlib.Path + The path points to mlc_chat_config.json. + """ + # pylint: disable=import-outside-toplevel + from mlc_chat.model import MODEL_PRESETS + + from .download import download_mlc_weights + + # pylint: enable=import-outside-toplevel + + if mlc_chat_config.startswith("HF://") or mlc_chat_config.startswith("http"): + mlc_chat_config_path = Path(download_mlc_weights(model_url=mlc_chat_config)) + elif isinstance(mlc_chat_config, str) and mlc_chat_config in MODEL_PRESETS: + logger.info("%s mlc preset model: %s", FOUND, mlc_chat_config) + content = MODEL_PRESETS[mlc_chat_config].copy() + content["model_preset_tag"] = mlc_chat_config + temp_file = tempfile.NamedTemporaryFile( # pylint: disable=consider-using-with + suffix=".json", + delete=False, + ) + logger.info("Dumping config to: %s", temp_file.name) + mlc_chat_config_path = Path(temp_file.name) + with mlc_chat_config_path.open("w", encoding="utf-8") as mlc_chat_config_file: + json.dump(content, mlc_chat_config_file, indent=2) + else: + mlc_chat_config_path = Path(mlc_chat_config) + if not mlc_chat_config_path.exists(): + raise ValueError(f"{mlc_chat_config_path} does not exist.") + + if mlc_chat_config_path.is_dir(): + # search mlc-chat-config.json under path + mlc_chat_config_json_path = mlc_chat_config_path / "mlc-chat-config.json" + if not mlc_chat_config_json_path.exists(): + raise ValueError(f"Fail to find mlc_chat_config.json under {mlc_chat_config_path}.") + else: + mlc_chat_config_json_path = mlc_chat_config_path + + logger.info("%s model configuration: %s", FOUND, mlc_chat_config_json_path) + return mlc_chat_config_json_path + + +def detect_config(config: str) -> Path: + """Detect and return the path that points to config.json. If `config` is a directory, + it looks for config.json below it. 
+ + Parameters + --------- + config : str + The preset name of the model, or the path to `config.json`, or the directory containing + `config.json`. + + Returns + ------- + config_json_path : pathlib.Path + The path points to config.json. + """ + from mlc_chat.model import MODEL_PRESETS # pylint: disable=import-outside-toplevel + + if isinstance(config, str) and config in MODEL_PRESETS: + logger.info("%s preset model: %s", FOUND, config) + content = MODEL_PRESETS[config].copy() + content["model_preset_tag"] = config + temp_file = tempfile.NamedTemporaryFile( # pylint: disable=consider-using-with + suffix=".json", + delete=False, + ) + logger.info("Dumping config to: %s", temp_file.name) + config_path = Path(temp_file.name) + with config_path.open("w", encoding="utf-8") as config_file: + json.dump(content, config_file, indent=2) + else: + config_path = Path(config) + if not config_path.exists(): + raise ValueError(f"{config_path} does not exist.") + + if config_path.is_dir(): + # search config.json under config path + config_json_path = config_path / "config.json" + if not config_json_path.exists(): + raise ValueError(f"Fail to find config.json under {config_path}.") + else: + config_json_path = config_path + + logger.info("%s model configuration: %s", FOUND, config_json_path) + return config_json_path + + +def detect_model_type(model_type: str, config: Path) -> "Model": + """Detect the model type from the configuration file. If `model_type` is "auto", it will be + inferred from the configuration file. Otherwise, it will be used as the model type, and sanity + check will be performed. + + Parameters + ---------- + model_type : str + The model type, for example, "llama". + + config : pathlib.Path + The path to config.json. + + Returns + ------- + model : mlc_chat.compiler.Model + The model type. + """ + + from mlc_chat.model import MODELS # pylint: disable=import-outside-toplevel + + if model_type == "auto": + with open(config, "r", encoding="utf-8") as config_file: + cfg = json.load(config_file) + if "model_type" not in cfg and ( + "model_config" not in cfg or "model_type" not in cfg["model_config"] + ): + raise ValueError( + f"'model_type' not found in: {config}. " + f"Please explicitly specify `--model-type` instead." + ) + model_type = cfg["model_type"] if "model_type" in cfg else cfg["model_config"]["model_type"] + if model_type in ["mixformer-sequential"]: + model_type = "phi-msft" + logger.info("%s model type: %s. Use `--model-type` to override.", FOUND, bold(model_type)) + if model_type not in MODELS: + raise ValueError(f"Unknown model type: {model_type}. Available ones: {list(MODELS.keys())}") + return MODELS[model_type] + + +def detect_quantization(quantization_arg: str, config: Path) -> "Quantization": + """Detect the model quantization scheme from the configuration file or `--quantization` + argument. If `--quantization` is provided, it will override the value on the configuration + file. + + Parameters + ---------- + quantization_arg : str + The quantization scheme, for example, "q4f16_1". + + config : pathlib.Path + The path to mlc-chat-config.json. + + Returns + ------- + quantization : mlc_chat.quantization.Quantization + The model quantization scheme. 
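A sketch of how these detection helpers compose; the directory paths are placeholders, and passing `None` as the quantization argument falls back to the value recorded in `mlc-chat-config.json`.

from mlc_chat.support.auto_config import (
    detect_config,
    detect_mlc_chat_config,
    detect_model_type,
    detect_quantization,
)

config_json = detect_config("/path/to/hf-model")        # locates config.json
model = detect_model_type("auto", config_json)          # inferred from "model_type"

mlc_config_json = detect_mlc_chat_config("/path/to/mlc-model")  # locates mlc-chat-config.json
quantization = detect_quantization(None, mlc_config_json)       # read from the config file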
+ """ + from mlc_chat.quantization import ( # pylint: disable=import-outside-toplevel + QUANTIZATION, + ) + + with open(config, "r", encoding="utf-8") as config_file: + cfg = json.load(config_file) + if quantization_arg is not None: + quantization = QUANTIZATION[quantization_arg] + elif "quantization" in cfg: + quantization = QUANTIZATION[cfg["quantization"]] + else: + raise ValueError( + f"'quantization' not found in: {config}. " + f"Please explicitly specify `--quantization` instead." + ) + return quantization diff --git a/python/mlc_chat/support/auto_device.py b/python/mlc_chat/support/auto_device.py new file mode 100644 index 0000000..6d18de4 --- /dev/null +++ b/python/mlc_chat/support/auto_device.py @@ -0,0 +1,87 @@ +"""Automatic detection of the device available on the local machine.""" +import subprocess +import sys +from typing import Dict, Optional + +import tvm +from tvm.runtime import Device + +from . import logging +from .style import bold, green, red + +FOUND = green("Found") +NOT_FOUND = red("Not found") +AUTO_DETECT_DEVICES = ["cuda", "rocm", "metal", "vulkan", "opencl"] +_RESULT_CACHE: Dict[str, bool] = {} + + +logger = logging.getLogger(__name__) + + +def detect_device(device_hint: str) -> Optional[Device]: + """Detect locally available device from string hint.""" + if device_hint == "auto": + device = None + for device_type in AUTO_DETECT_DEVICES: + cur_device = tvm.device(dev_type=device_type, dev_id=0) + if _device_exists(cur_device): + if device is None: + device = cur_device + if device is None: + logger.info("%s: No available device detected", NOT_FOUND) + return None + logger.info("Using device: %s", bold(device2str(device))) + return device + try: + device = tvm.device(device_hint) + except Exception as err: + raise ValueError(f"Invalid device name: {device_hint}") from err + if not _device_exists(device): + raise ValueError(f"Device is not found on your local environment: {device_hint}") + return device + + +def device2str(device: Device) -> str: + """Convert a TVM device object to string.""" + return f"{tvm.runtime.Device.MASK2STR[device.device_type]}:{device.device_id}" + + +def _device_exists(device: Device) -> bool: + device_type = tvm.runtime.Device.MASK2STR[device.device_type] + device_str = device2str(device) + if device_str in _RESULT_CACHE: + return _RESULT_CACHE[device_str] + cmd = [ + sys.executable, + "-m", + "mlc_chat.cli.check_device", + device_type, + ] + prefix = "check_device:" + subproc_outputs = [ + line[len(prefix) :].strip() + for line in subprocess.run( + cmd, + capture_output=True, + text=True, + check=False, + ) + .stdout.strip() + .splitlines() + if line.startswith(prefix) + ] + if subproc_outputs: + if subproc_outputs[0]: + for i in subproc_outputs[0].split(","): + logger.info("%s device: %s:%s", FOUND, device_type, i) + _RESULT_CACHE[f"{device_type}:{i}"] = True + else: + logger.error( + "GPU device detection failed. 
Please report this issue with the output of command: %s", + " ".join(cmd), + ) + if device_str in _RESULT_CACHE: + return _RESULT_CACHE[device_str] + logger.info("%s device: %s", NOT_FOUND, device_str) + _RESULT_CACHE[device_str] = False + return False diff --git a/python/mlc_chat/support/auto_target.py b/python/mlc_chat/support/auto_target.py new file mode 100644 index 0000000..80041db --- /dev/null +++ b/python/mlc_chat/support/auto_target.py @@ -0,0 +1,398 @@ +"""Helper functions for target auto-detection.""" +import os +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple + +from tvm import IRModule, relax +from tvm._ffi import get_global_func, register_func +from tvm.contrib import ndk, tar, xcode +from tvm.ir.transform import Pass +from tvm.target import Target + +from . import logging +from .auto_device import AUTO_DETECT_DEVICES, detect_device, device2str +from .constants import MLC_MULTI_ARCH +from .style import bold, green, red + +if TYPE_CHECKING: + from mlc_chat.compiler.compile import CompileArgs + + +logger = logging.getLogger(__name__) + +# TODO: add help message on how to specify the target manually # pylint: disable=fixme +HELP_MSG = """TBD""" +FOUND = green("Found") +NOT_FOUND = red("Not found") +BuildFunc = Callable[[IRModule, "CompileArgs", Pass], None] + + +def detect_target_and_host(target_hint: str, host_hint: str = "auto") -> Tuple[Target, BuildFunc]: + """Detect the configuration for the target device and its host, for example, target GPU and + the host CPU. + + Parameters + ---------- + target_hint : str + The hint for the target device. + + host_hint : str + The hint for the host CPU, default is "auto". + """ + target, build_func = _detect_target_gpu(target_hint) + if target.host is None: + target = Target(target, host=_detect_target_host(host_hint)) + if target.kind.name == "cuda": + _register_cuda_hook(target) + return target, build_func + + +def _detect_target_gpu(hint: str) -> Tuple[Target, BuildFunc]: + if hint in ["iphone", "android", "webgpu", "mali", "opencl"]: + hint += ":generic" + if hint == "auto" or hint in AUTO_DETECT_DEVICES: + target: Optional[Target] = None + device = detect_device(hint) + if device is not None: + device_str = device2str(device) + try: + target = Target.from_device(device) + except ValueError: + logger.info("%s: Cannot detect target from device: %s", NOT_FOUND, device_str) + if target is None: + raise ValueError(f"No target detected from device: {hint}. 
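For reference, a hedged probe of the device and target helpers; it assumes a TVM-Unity build with at least one working GPU runtime, and the device strings are examples:

```python
# Illustrative probe of the auto-detection helpers.
from mlc_chat.support.auto_device import detect_device, device2str
from mlc_chat.support.auto_target import detect_target_and_host

dev = detect_device("auto")            # probes cuda, rocm, metal, vulkan, opencl in order
if dev is not None:
    print("picked:", device2str(dev))  # e.g. "cuda:0"

target, build_func = detect_target_and_host("auto")
print(target.kind.name)                # build_func is later used to export the compiled library
```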
Please specify explicitly") + logger.info( + '%s configuration of target device "%s": %s', + FOUND, + bold(device_str), + target.export(), + ) + return target, _build_default() + if hint in PRESET: + preset = PRESET[hint] + target = Target(preset["target"]) # type: ignore[index] + build = preset.get("build", _build_default) # type: ignore[attr-defined] + return target, build() + if _is_device(hint): + logger.info("Detecting target device: %s", hint) + target = Target.from_device(hint) + logger.info("%s target: %s", FOUND, target.export()) + return target, _build_default() + try: + logger.info("Try creating device target from string: %s", hint) + target = Target(hint) + logger.info("%s target: %s", FOUND, target.export()) + return target, _build_default() + except Exception as err: + logger.info("%s: Failed to create target", NOT_FOUND) + raise ValueError(f"Invalid target: {hint}") from err + + +def _detect_target_host(hint: str) -> Target: + """Detect the host CPU architecture.""" + if hint == "auto": + target_triple = get_global_func("tvm.codegen.llvm.GetDefaultTargetTriple")() + target = Target.from_device("cpu") + logger.info("%s host LLVM triple: %s", FOUND, bold(target.attrs["mtriple"])) + logger.info("%s host LLVM CPU: %s", FOUND, bold(target.attrs["mcpu"])) + return target + target_triple = hint + logger.info("Using LLVM triple specified by --host: %s", bold(target_triple)) + return Target({"kind": "llvm", "mtriple": target_triple}) + + +def _is_device(device: str): + if " " in device: + return False + if device.count(":") != 1: + return False + return True + + +def _add_system_lib_prefix(mod: IRModule, prefix: str, is_system_lib: bool) -> IRModule: + if is_system_lib and prefix: + mod = mod.with_attrs({"system_lib_prefix": prefix}) # type: ignore[dict-item] + elif is_system_lib: + logger.warning( + "%s is not specified when building a static library", + bold("--system-lib-prefix"), + ) + elif prefix: + logger.warning( + "--system-lib-prefix is specified, but it will not take any effect " + "when building the shared library" + ) + return mod + + +def _build_metal_x86_64(): + def build(mod: IRModule, args: "CompileArgs", pipeline=None): + output = args.output + mod = _add_system_lib_prefix(mod, args.system_lib_prefix, is_system_lib=False) + assert output.suffix == ".dylib" + relax.build( + mod, + target=args.target, + pipeline=pipeline, + ).export_library( + str(output), + fcompile=xcode.create_dylib, + sdk="macosx", + arch="x86_64", + ) + + return build + + +def _build_iphone(): + @register_func("tvm_callback_metal_compile", override=True) + def compile_metal(src, target): + if target.libs: + return xcode.compile_metal(src, sdk=target.libs[0]) + return xcode.compile_metal(src) + + def build(mod: IRModule, args: "CompileArgs", pipeline=None): + output = args.output + mod = _add_system_lib_prefix(mod, args.system_lib_prefix, is_system_lib=True) + assert output.suffix == ".tar" + relax.build( + mod, + target=args.target, + pipeline=pipeline, + system_lib=True, + ).export_library( + str(output), + fcompile=tar.tar, + ) + + return build + + +def _build_android(): + def build(mod: IRModule, args: "CompileArgs", pipeline=None): + output = args.output + mod = _add_system_lib_prefix(mod, args.system_lib_prefix, is_system_lib=True) + assert output.suffix == ".tar" + relax.build( + mod, + target=args.target, + pipeline=pipeline, + system_lib=True, + ).export_library( + str(output), + fcompile=tar.tar, + ) + + return build + + +def _build_webgpu(): + def build(mod: IRModule, args: 
"CompileArgs", pipeline=None): + output = args.output + mod = _add_system_lib_prefix(mod, args.system_lib_prefix, is_system_lib=True) + assert output.suffix == ".wasm" + relax.build( + mod, + target=args.target, + pipeline=pipeline, + system_lib=True, + ).export_library( + str(output), + ) + + return build + + +def _build_mali(): + def build(mod: IRModule, args: "CompileArgs", pipeline=None): + output = args.output + mod = _add_system_lib_prefix(mod, args.system_lib_prefix, is_system_lib=True) + assert output.suffix == ".so" + mod = relax.build( + mod, + target=args.target, + pipeline=pipeline, + system_lib=True, + ) + if "TVM_NDK_CC" in os.environ: + mod.export_library(str(output), fcompile=ndk.create_shared) + else: + mod.export_library(str(output)) + + return build + + +def _build_default(): + def build(mod: IRModule, args: "CompileArgs", pipeline=None): + output = args.output + if output.suffix in [".tar", ".lib"]: + system_lib = True + elif output.suffix in [".so", ".dylib", ".dll"]: + system_lib = False + else: + logger.warning("Unknown output suffix: %s. Assuming shared library.", output.suffix) + system_lib = False + mod = _add_system_lib_prefix(mod, args.system_lib_prefix, is_system_lib=system_lib) + relax.build( + mod, + target=args.target, + pipeline=pipeline, + system_lib=system_lib, + ).export_library( + str(output), + ) + + return build + + +def detect_cuda_arch_list(target: Target) -> List[int]: + """Detect the CUDA architecture list from the target.""" + assert target.kind.name == "cuda", f"Expect target to be CUDA, but got {target}" + if MLC_MULTI_ARCH is not None: + multi_arch = [int(x.strip()) for x in MLC_MULTI_ARCH.split(",")] + else: + assert target.arch.startswith("sm_") + multi_arch = [int(target.arch[3:])] + multi_arch = list(set(multi_arch)) + return multi_arch + + +def _register_cuda_hook(target: Target): + if MLC_MULTI_ARCH is None: + default_arch = target.attrs.get("arch", None) + logger.info("Generating code for CUDA architecture: %s", bold(default_arch)) + logger.info( + "To produce multi-arch fatbin, set environment variable %s. " + "Example: MLC_MULTI_ARCH=70,72,75,80,86,87,89,90", + bold("MLC_MULTI_ARCH"), + ) + multi_arch = None + else: + logger.info("%s %s: %s", FOUND, bold("MLC_MULTI_ARCH"), MLC_MULTI_ARCH) + multi_arch = [int(x.strip()) for x in MLC_MULTI_ARCH.split(",")] + logger.info("Generating code for CUDA architecture: %s", multi_arch) + + @register_func("tvm_callback_cuda_compile", override=True) + def tvm_callback_cuda_compile(code, target): # pylint: disable=unused-argument + """use nvcc to generate fatbin code for better optimization""" + from tvm.contrib import nvcc # pylint: disable=import-outside-toplevel + + if multi_arch is None: + ptx = nvcc.compile_cuda(code, target_format="fatbin") + else: + arch = [] + for compute_version in multi_arch: + arch += ["-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}"] + ptx = nvcc.compile_cuda(code, target_format="fatbin", arch=arch) + return ptx + + +def detect_system_lib_prefix( + target_hint: str, prefix_hint: str, model_name: str, quantization: str +) -> str: + """Detect the iOS / Android system lib prefix to identify the library needed to load the app. + + Parameters + ---------- + target_hint : str + The hint for the target device. + + prefix_hint : str + The hint for the system lib prefix. 
+ """ + if prefix_hint == "auto" and target_hint in ["iphone", "android"]: + prefix = f"{model_name}_{quantization}_".replace("-", "_") + logger.warning( + "%s is automatically picked from the filename, %s, this allows us to use the filename " + "as the model_lib in android/iOS builds. Please avoid renaming the .tar file when " + "uploading the prebuilt.", + bold("--system-lib-prefix"), + bold(prefix), + ) + return prefix + if target_hint not in ["iphone", "android"]: + return "" + return prefix_hint + + +PRESET = { + "iphone:generic": { + "target": { + "kind": "metal", + "max_threads_per_block": 256, + "max_shared_memory_per_block": 32768, + "thread_warp_size": 1, + "libs": ["iphoneos"], + "host": { + "kind": "llvm", + "mtriple": "arm64-apple-darwin", + }, + }, + "build": _build_iphone, + }, + "android:generic": { + "target": { + "kind": "opencl", + "host": { + "kind": "llvm", + "mtriple": "aarch64-linux-android", + }, + }, + "build": _build_android, + }, + "metal:x86-64": { + "target": { + "kind": "metal", + "max_threads_per_block": 256, + "max_shared_memory_per_block": 32768, + "thread_warp_size": 1, + }, + "build": _build_metal_x86_64, + }, + "webgpu:generic": { + "target": { + "kind": "webgpu", + "host": { + "kind": "llvm", + "mtriple": "wasm32-unknown-unknown-wasm", + }, + }, + "build": _build_webgpu, + }, + "opencl:generic": { + "target": { + "kind": "opencl", + }, + }, + "mali:generic": { + "target": { + "kind": "opencl", + "host": { + "kind": "llvm", + "mtriple": "aarch64-linux-gnu", + }, + }, + "build": _build_mali, + }, + "metal:generic": { + "target": { + "kind": "metal", + "max_threads_per_block": 256, + "max_shared_memory_per_block": 32768, + "thread_warp_size": 1, + }, + }, + "vulkan:generic": { + "target": { + "kind": "vulkan", + "max_threads_per_block": 256, + "max_shared_memory_per_block": 32768, + "thread_warp_size": 1, + "supports_float16": 1, + "supports_int16": 1, + "supports_int8": 1, + "supports_8bit_buffer": 1, + "supports_16bit_buffer": 1, + "supports_storage_buffer_storage_class": 1, + }, + }, +} diff --git a/python/mlc_chat/support/auto_weight.py b/python/mlc_chat/support/auto_weight.py new file mode 100644 index 0000000..84d8621 --- /dev/null +++ b/python/mlc_chat/support/auto_weight.py @@ -0,0 +1,177 @@ +"""Help functions for detecting weight paths and weight formats.""" +import json +from pathlib import Path +from typing import List, Optional, Tuple + +from . import logging +from .style import bold, green, red + +logger = logging.getLogger(__name__) + +FOUND = green("Found") +NOT_FOUND = red("Not found") + + +def detect_weight( + weight_path: Path, + config_json_path: Path, + weight_format: str = "auto", +) -> Tuple[Path, str]: + """Detect the weight directory, and detect the weight format. + + Parameters + --------- + weight_path : pathlib.Path + The path to weight files. If `weight_path` is not None, check if it exists. Otherwise, find + `weight_path` in `config.json` or use the same directory as `config.json`. + + config_json_path: pathlib.Path + The path to `config.json`. + + weight_format : str + The hint for the weight format. If it is "auto", guess the weight format. + Otherwise, check the weights are in that format. 
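A usage sketch for this weight detector; the checkpoint directory is a hypothetical Hugging Face checkout assumed to contain a `config.json` and weight index files:

```python
# Sketch: auto-detecting the source weight format of a local checkout.
from pathlib import Path

from mlc_chat.support.auto_weight import detect_weight

ckpt = Path("dist/models/Llama-2-7b-hf")  # placeholder path
weight_config, fmt = detect_weight(ckpt, ckpt / "config.json", weight_format="auto")
print(weight_config, fmt)  # e.g. .../model.safetensors.index.json, "huggingface-safetensor"
```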
+ Available weight formats: + - auto (guess the weight format) + - huggingface-torch (validate via checking pytorch_model.bin.index.json) + - huggingface-safetensor (validate via checking model.safetensors.index.json) + - awq + - ggml + - gguf + + Returns + ------- + weight_config_path : pathlib.Path + The path that points to the weights config file or the weights directory. + + weight_format : str + The valid weight format. + """ + if weight_path is None: + assert ( + config_json_path is not None and config_json_path.exists() + ), "Please provide config.json path." + + # 1. Find the weight_path in config.json + with open(config_json_path, encoding="utf-8") as i_f: + config = json.load(i_f) + if "weight_path" in config: + weight_path = Path(config["weight_path"]) + logger.info('Found "weight_path" in config.json: %s', weight_path) + if not weight_path.exists(): + raise ValueError(f"weight_path doesn't exist: {weight_path}") + else: + # 2. Find the weights file in the same directory as config.json + weight_path = config_json_path.parent + else: + if not weight_path.exists(): + raise ValueError(f"weight_path doesn't exist: {weight_path}") + + logger.info("Finding weights in: %s", weight_path) + + # check weight format + # weight_format = "auto", guess the weight format. + # otherwise, check the weight format is valid. + if weight_format == "auto": + return _guess_weight_format(weight_path) + + if weight_format not in AVAILABLE_WEIGHT_FORMAT: + raise ValueError( + f"Available weight format list: {AVAILABLE_WEIGHT_FORMAT}, but got {weight_format}" + ) + if weight_format in CHECK_FORMAT_METHODS: + check_func = CHECK_FORMAT_METHODS[weight_format] + weight_config_path = check_func(weight_path) + if not weight_config_path: + raise ValueError(f"The weight is not in {weight_format} format.") + else: + weight_config_path = weight_path + return weight_config_path, weight_format + + +def _guess_weight_format(weight_path: Path) -> Tuple[Path, str]: + possible_formats: List[Tuple[Path, str]] = [] + for weight_format, check_func in CHECK_FORMAT_METHODS.items(): + weight_config_path = check_func(weight_path) + if weight_config_path: + possible_formats.append((weight_config_path, weight_format)) + + if len(possible_formats) == 0: + raise ValueError( + "Fail to detect source weight format. " + "Use `--source-format` to explicitly specify the format." + ) + + weight_config_path, selected_format = possible_formats[0] + logger.info( + "Using source weight configuration: %s. Use `--source` to override.", + bold(str(weight_config_path)), + ) + logger.info( + "Using source weight format: %s. Use `--source-format` to override.", + bold(selected_format), + ) + return weight_config_path, selected_format + + +def _check_pytorch(weight_path: Path) -> Optional[Path]: + pytorch_json_path = weight_path / "pytorch_model.bin.index.json" + if pytorch_json_path.exists(): + logger.info( + "%s source weight format: huggingface-torch. Source configuration: %s", + FOUND, + pytorch_json_path, + ) + return pytorch_json_path + + pytorch_file_path = weight_path / "pytorch_model.bin" + if pytorch_file_path.exists(): + logger.info( + "%s source weight format: huggingface-torch. 
Source configuration: %s", + FOUND, + pytorch_file_path, + ) + return pytorch_file_path + + logger.info("%s Huggingface PyTorch", NOT_FOUND) + return None + + +def _check_safetensor(weight_path: Path) -> Optional[Path]: + safetensor_json_path = weight_path / "model.safetensors.index.json" + if safetensor_json_path.exists(): + logger.info( + "%s source weight format: huggingface-safetensor. Source configuration: %s", + FOUND, + safetensor_json_path, + ) + return safetensor_json_path + + safetensor_file_path = weight_path / "model.safetensors" + if safetensor_file_path.exists(): + from safetensors.torch import ( # pylint: disable=import-outside-toplevel,import-error + load_file, + ) + + weights = load_file(safetensor_file_path, device="cpu") + weight_map = {key: "model.safetensors" for key in weights} + with open(safetensor_json_path, "w", encoding="utf-8") as file: + json.dump({"weight_map": weight_map}, file, indent=2) + logger.info( + "%s source weight format: huggingface-safetensor. Source configuration: %s", + FOUND, + safetensor_json_path, + ) + return safetensor_json_path + + logger.info("%s Huggingface Safetensor", NOT_FOUND) + return None + + +CHECK_FORMAT_METHODS = { + "huggingface-torch": _check_pytorch, + "huggingface-safetensor": _check_safetensor, +} + +# "ggml", "gguf" are not supported yet. +AVAILABLE_WEIGHT_FORMAT = ["huggingface-torch", "huggingface-safetensor", "awq"] diff --git a/python/mlc_chat/support/config.py b/python/mlc_chat/support/config.py new file mode 100644 index 0000000..e3ccfce --- /dev/null +++ b/python/mlc_chat/support/config.py @@ -0,0 +1,113 @@ +""" +A common base class for configuration. A configuration could be initialized from its constructor, +a JSON string or a JSON file, and irrelevant fields during initialization are automatically moved +to the `kwargs` field. + +Take model configuration as an example: it is usually a JSON file in HuggingFace that contains +the model's hyperparameters. For instance, Vicuna-13b-v1.5-16k contains the following +[JSON file](https://huggingface.co/lmsys/vicuna-13b-v1.5-16k/blob/main/config.json). +The base class allows us to load the configuration from this JSON file, moving irrelevant fields +into `kwargs`, such as `transformers_version` and `use_cache`. +""" +# pylint: disable=too-few-public-methods +import dataclasses +import json +from pathlib import Path +from typing import Any, Dict, Type, TypeVar + +from . import logging +from .style import bold, red + +logger = logging.getLogger(__name__) + +ConfigClass = TypeVar("ConfigClass", bound="ConfigBase") + + +@dataclasses.dataclass +class ConfigBase: + """Base class for configurations, providing a common interface for loading configs from a + JSON file or a dict. It requires the subclasses to be dataclasses, and has an `kwargs` field + that stores the extra fields that are not defined in the dataclass. + """ + + @classmethod + def from_dict(cls: Type[ConfigClass], source: Dict[str, Any]) -> ConfigClass: + """Create a config object from a dictionary. + + Parameters + ---------- + source : Dict[str, Any] + Source to create config from, usually loaded from `config.json` in HuggingFace style. + + Returns + ------- + cfg : ConfigClass + An instance of the config object. 
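To make the `kwargs` behavior concrete, a small sketch with an invented dataclass; the config class name and its fields are illustrative and not part of MLC:

```python
# Hypothetical subclass, only to show how unknown JSON keys land in `kwargs`.
import dataclasses
from typing import Any, Dict

from mlc_chat.support.config import ConfigBase


@dataclasses.dataclass
class TinyModelConfig(ConfigBase):  # invented name
    hidden_size: int
    vocab_size: int
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)


cfg = TinyModelConfig.from_dict(
    {"hidden_size": 2048, "vocab_size": 32000, "transformers_version": "4.36.0"}
)
print(cfg.hidden_size)  # 2048
print(cfg.kwargs)       # {'transformers_version': '4.36.0'} -- irrelevant keys are preserved
print(cfg.asdict())     # kwargs stripped again: {'hidden_size': 2048, 'vocab_size': 32000}
```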
+ """ + field_names = [field.name for field in dataclasses.fields(cls)] # type: ignore[arg-type] + fields = {k: v for k, v in source.items() if k in field_names} + kwargs = {k: v for k, v in source.items() if k not in field_names} + return cls(**fields, kwargs=kwargs) # type: ignore[call-arg] + + @classmethod + def from_file(cls: Type[ConfigClass], source: Path) -> ConfigClass: + """Create a config object from a file. + + Parameters + ---------- + cfg_cls : Type[ConfigClass] + The config class to create, for example, LlamaConfig. + + source : pathlib.Path + Path to the source file, usually `config.json` in HuggingFace repo. + + Returns + ------- + cfg : ConfigClass + An instance of the config object. + """ + with source.open("r", encoding="utf-8") as in_file: + return cls.from_dict(json.load(in_file)) + + def asdict(self): + """Convert the config object to a dictionary. + + Returns + ------- + Dict[str, Any] + A dictionary representation of the config object. + """ + result = dataclasses.asdict(self) + result.pop("kwargs") + return result + + +class ConfigOverrideBase: + """Base class for ConfigOverride, providing a common interface for overriding configs. + It requires the subclasses to be dataclasses. + """ + + def apply(self, config): + """Apply the overrides to the given config.""" + updated = config.asdict() + for field in dataclasses.fields(self): + key = field.name + value = getattr(self, key) + if value is None: + continue + if key not in updated: + logger.warning( + "%s: Cannot override %s, because %s does not have this field", + red("Warning"), + bold(key), + bold(type(config).__name__), + ) + else: + logger.info( # pylint: disable=logging-fstring-interpolation + f"Overriding {bold(key)} from {updated[key]} to {value}" + ) + updated[key] = value + return type(config).from_dict(updated) + + +__all__ = ["ConfigBase", "ConfigOverrideBase"] diff --git a/python/mlc_chat/support/constants.py b/python/mlc_chat/support/constants.py new file mode 100644 index 0000000..09e4893 --- /dev/null +++ b/python/mlc_chat/support/constants.py @@ -0,0 +1,55 @@ +"""Environment variables used by the MLC LLM.""" +import os +import sys +from pathlib import Path + + +def _check(): + if MLC_JIT_POLICY not in ["ON", "OFF", "REDO", "READONLY"]: + raise ValueError( + 'Invalid MLC_JIT_POLICY. It has to be one of "ON", "OFF", "REDO", "READONLY"' + f"but got {MLC_JIT_POLICY}." + ) + + +def _get_cache_dir() -> Path: + if "MLC_CACHE_DIR" in os.environ: + result = Path(os.environ["MLC_CACHE_DIR"]) + elif sys.platform == "win32": + result = Path(os.environ["LOCALAPPDATA"]) + result = result / "mlc_chat" + elif os.getenv("XDG_CACHE_HOME", None) is not None: + result = Path(os.getenv("XDG_CACHE_HOME")) + result = result / "mlc_chat" + else: + result = Path(os.path.expanduser("~/.cache")) + result = result / "mlc_chat" + result.mkdir(parents=True, exist_ok=True) + if not result.is_dir(): + raise ValueError( + f"The default cache directory is not a directory: {result}. " + "Use environment variable MLC_CACHE_DIR to specify a valid cache directory." 
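The two environment knobs defined in this module that users touch most often are `MLC_CACHE_DIR` and `MLC_MULTI_ARCH`; a short sketch with example values only:

```python
# Example values. MLC_MULTI_ARCH must be set before mlc_chat is imported, because this
# module reads it once at import time.
import os

os.environ["MLC_MULTI_ARCH"] = "80,89,90"  # one -gencode pair per arch in the CUDA fatbin

from mlc_chat.support.constants import MLC_CACHE_DIR, MLC_MULTI_ARCH

print(MLC_MULTI_ARCH)                   # "80,89,90"
print(MLC_CACHE_DIR / "model_weights")  # e.g. ~/.cache/mlc_chat/model_weights
print(MLC_CACHE_DIR / "model_lib")      # the model_lib cache created above
```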
+ ) + (result / "model_weights").mkdir(parents=True, exist_ok=True) + (result / "model_lib").mkdir(parents=True, exist_ok=True) + return result + + +def _get_dso_suffix() -> str: + if "MLC_DSO_SUFFIX" in os.environ: + return os.environ["MLC_DSO_SUFFIX"] + if sys.platform == "win32": + return "dll" + if sys.platform == "darwin": + return "dylib" + return "so" + + +MLC_TEMP_DIR = os.getenv("MLC_TEMP_DIR", None) +MLC_MULTI_ARCH = os.environ.get("MLC_MULTI_ARCH", None) +MLC_CACHE_DIR: Path = _get_cache_dir() +MLC_JIT_POLICY = os.environ.get("MLC_JIT_POLICY", "ON") +MLC_DSO_SUFFIX = _get_dso_suffix() + + +_check() diff --git a/python/mlc_chat/support/convert_tiktoken.py b/python/mlc_chat/support/convert_tiktoken.py new file mode 100644 index 0000000..9bf0504 --- /dev/null +++ b/python/mlc_chat/support/convert_tiktoken.py @@ -0,0 +1,163 @@ +""" +Adapted from https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee +Generator of mlc-chat-config.json and tokenizer configuration. +""" + +# pylint: disable=import-error +# isort: off +import json +import os +from typing import Dict, List, Optional + +from transformers import AutoTokenizer +from transformers.models.gpt2.tokenization_gpt2 import ( + bytes_to_unicode, +) + +byte_encoder = bytes_to_unicode() + + +def token_bytes_to_string(b): + """Convert a token from bytes to a string""" + return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")]) + + +def bpe( + mergeable_ranks: Dict[bytes, int], token: bytes, max_rank: Optional[int] = None +) -> List[bytes]: + """Adapted from https://github.com/openai/tiktoken/issues/60#issuecomment-1499977960""" + parts = [bytes([b]) for b in token] + while True: + min_idx = None + min_rank = None + for i, pair in enumerate(zip(parts[:-1], parts[1:])): + rank = mergeable_ranks.get(pair[0] + pair[1]) + if rank is not None and (min_rank is None or rank < min_rank): + min_idx = i + min_rank = rank + if min_rank is None or (max_rank is not None and min_rank >= max_rank): + break + assert min_idx is not None + parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2 :] + return parts + + +def generate_vocab_and_merges(encoder, mergeable_ranks): + """Generate vocab and merges in huggingface tokenizers format""" + merges = [] + vocab = {} + for token, rank in mergeable_ranks.items(): + vocab[token_bytes_to_string(token)] = rank + + if len(token) == 1: + continue + merged = tuple(bpe(mergeable_ranks, token, max_rank=rank)) + assert len(merged) == 2 + + merges.append(" ".join(map(token_bytes_to_string, merged))) + + # Also add special tokens + vocab.update(encoder._special_tokens) # pylint: disable=protected-access + + return vocab, merges + + +def convert_tiktoken(model_path, output_dir, context_window_size=None): + """Convert tiktoken tokenizers to huggingface tokenizers style""" + tiktoken_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + encoder = tiktoken_tokenizer.tokenizer + + vocab, merges = generate_vocab_and_merges(encoder, tiktoken_tokenizer.get_vocab()) + + added_tokens = [ + { + "id": id, + "content": content, + "single_word": False, + "lstrip": False, + "rstrip": False, + "normalized": False, + "special": True, + } + for content, id in encoder._special_tokens.items() # pylint: disable=protected-access + ] + + tokenizer_template = { + "version": "1.0", + "truncation": None, + "padding": None, + "added_tokens": added_tokens, + "normalizer": None, + "pre_tokenizer": { + "type": "ByteLevel", + "add_prefix_space": False, + "trim_offsets": 
True, + "use_regex": True, + }, + "post_processor": { + "type": "ByteLevel", + "add_prefix_space": True, + "trim_offsets": False, + "use_regex": True, + }, + "decoder": { + "type": "ByteLevel", + "add_prefix_space": True, + "trim_offsets": True, + "use_regex": True, + }, + "model": { + "type": "BPE", + "dropout": None, + "unk_token": None, + "continuing_subword_prefix": "", + "end_of_word_suffix": "", + "fuse_unk": False, + "byte_fallback": False, + "vocab": vocab, + "merges": merges, + }, + } + + tokenizer_config_template = { + "add_prefix_space": False, + "bos_token": "<|endoftext|>", + "clean_up_tokenization_spaces": True, + "eos_token": "<|endoftext|>", + "unk_token": "<|endoftext|>", + } + + tokenizer_name = type(tiktoken_tokenizer).__name__ + + tokenizer_config_template["tokenizer_class"] = tokenizer_name + if context_window_size: + tokenizer_config_template["model_max_length"] = context_window_size + tokenizer_config_template = dict(sorted(tokenizer_config_template.items(), key=lambda x: x[0])) + + os.makedirs(output_dir, exist_ok=True) + + # Save to files + with open(os.path.join(output_dir, "vocab.json"), "w", encoding="utf-8") as fp: + json.dump(vocab, fp, indent=2, ensure_ascii=False) + + with open(os.path.join(output_dir, "tokenizer.json"), "w", encoding="utf-8") as fp: + json.dump(tokenizer_template, fp, indent=2, ensure_ascii=False) + + with open(os.path.join(output_dir, "tokenizer_config.json"), "w", encoding="utf-8") as fp: + json.dump(tokenizer_config_template, fp, indent=2, ensure_ascii=False) + + with open(os.path.join(output_dir, "special_tokens_map.json"), "w", encoding="utf-8") as fp: + json.dump( + { + "bos_token": "<|endoftext|>", + "eos_token": "<|endoftext|>", + "unk_token": "<|endoftext|>", + }, + fp, + indent=2, + ensure_ascii=False, + ) + + with open(os.path.join(output_dir, "merges.txt"), "w", encoding="utf-8") as fp: + fp.write("#version: 0.2\n") + fp.write("\n".join(merges)) diff --git a/python/mlc_chat/support/download.py b/python/mlc_chat/support/download.py new file mode 100644 index 0000000..10b1620 --- /dev/null +++ b/python/mlc_chat/support/download.py @@ -0,0 +1,147 @@ +"""Common utilities for downloading files from HuggingFace or other URLs online.""" +import concurrent.futures as cf +import hashlib +import json +import os +import shutil +import subprocess +import tempfile +from pathlib import Path +from typing import Optional, Tuple + +import requests # pylint: disable=import-error + +from . 
import logging, tqdm +from .constants import MLC_CACHE_DIR, MLC_TEMP_DIR +from .style import bold + +logger = logging.getLogger(__name__) + + +def _ensure_directory_not_exist(path: Path, force_redo: bool) -> None: + if path.exists(): + if force_redo: + logger.info("Deleting existing directory: %s", path) + shutil.rmtree(path) + else: + raise ValueError(f"Directory already exists: {path}") + else: + path.parent.mkdir(parents=True, exist_ok=True) + + +def git_clone(url: str, destination: Path, ignore_lfs: bool) -> None: + """Clone a git repository into a directory.""" + repo_name = ".tmp" + command = ["git", "clone", url, repo_name] + _ensure_directory_not_exist(destination, force_redo=False) + try: + with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir: + logger.info("[Git] Cloning %s to %s", bold(url), destination) + subprocess.run( + command, + env={"GIT_LFS_SKIP_SMUDGE": "1"}, + cwd=tmp_dir, + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + git_dir = os.path.join(tmp_dir, repo_name) + if not ignore_lfs: + git_lfs_pull(Path(git_dir)) + shutil.move(git_dir, str(destination)) + except subprocess.CalledProcessError as error: + raise ValueError( + f"Git clone failed with return code {error.returncode}: {error.stderr}. " + f"The command was: {command}" + ) from error + + +def git_lfs_pull(repo_dir: Path) -> None: + """Pull files with Git LFS.""" + filenames = ( + subprocess.check_output( + ["git", "-C", str(repo_dir), "lfs", "ls-files", "-n"], + stderr=subprocess.STDOUT, + ) + .decode("utf-8") + .splitlines() + ) + logger.info("[Git LFS] Downloading %d files with Git LFS: %s", len(filenames), filenames) + with tqdm.redirect(): + for file in tqdm.tqdm(filenames): + logger.info("[Git LFS] Downloading %s", file) + subprocess.check_output( + ["git", "-C", str(repo_dir), "lfs", "pull", "--include", file], + stderr=subprocess.STDOUT, + ) + + +def download_file( + url: str, + destination: Path, + md5sum: Optional[str], +) -> Tuple[str, Path]: + """Download a file from a URL to a destination file.""" + with requests.get(url, stream=True, timeout=30) as response: + response.raise_for_status() + with destination.open("wb") as file: + for chunk in response.iter_content(chunk_size=8192): + file.write(chunk) + if md5sum is not None: + hash_md5 = hashlib.md5() + with destination.open("rb") as file: + for chunk in iter(lambda: file.read(8192), b""): + hash_md5.update(chunk) + file_md5 = hash_md5.hexdigest() + if file_md5 != md5sum: + raise ValueError( + f"MD5 checksum mismatch for downloaded file: {destination}. 
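Downstream of these helpers, the weight download entry point defined just below is typically called with an `HF://` URL; a sketch with a placeholder repository name, assuming `git` on PATH and network access:

```python
# Illustrative download flow; "mlc-ai/demo-model-q4f16_1-MLC" is a placeholder repo.
from mlc_chat.support.download import download_mlc_weights
from mlc_chat.support.logging import enable_logging

enable_logging()  # make the progress messages visible
weights_dir = download_mlc_weights("HF://mlc-ai/demo-model-q4f16_1-MLC")
print(weights_dir)  # lands under MLC_CACHE_DIR / "model_weights"
```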
" + f"Expected {md5sum}, got {file_md5}" + ) + return url, destination + + +def download_mlc_weights( # pylint: disable=too-many-locals + model_url: str, + num_processes: int = 4, + force_redo: bool = False, +) -> Path: + """Download weights for a model from the HuggingFace Git LFS repo.""" + prefixes, mlc_prefix = ["HF://", "https://huggingface.co/"], "" + mlc_prefix = next(p for p in prefixes if model_url.startswith(p)) + assert mlc_prefix + + git_url_template = "https://huggingface.co/{user}/{repo}.git" + bin_url_template = "https://huggingface.co/{user}/{repo}/resolve/main/{record_name}" + + if model_url.count("/") != 1 + mlc_prefix.count("/") or not model_url.startswith(mlc_prefix): + raise ValueError(f"Invalid model URL: {model_url}") + user, repo = model_url[len(mlc_prefix) :].split("/") + git_dir = MLC_CACHE_DIR / "model_weights" / user / repo + try: + _ensure_directory_not_exist(git_dir, force_redo=force_redo) + except ValueError: + logger.info("Weights already downloaded: %s", bold(str(git_dir))) + return git_dir + with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir_prefix: + tmp_dir = Path(tmp_dir_prefix) / "tmp" + git_url = git_url_template.format(user=user, repo=repo) + git_clone(git_url, tmp_dir, ignore_lfs=True) + shutil.rmtree(tmp_dir / ".git", ignore_errors=True) + with (tmp_dir / "ndarray-cache.json").open(encoding="utf-8") as in_file: + param_metadata = json.load(in_file)["records"] + with cf.ProcessPoolExecutor(max_workers=num_processes) as executor: + futures = [] + for record in param_metadata: + record_name = record["dataPath"] + file_url = bin_url_template.format(user=user, repo=repo, record_name=record_name) + file_dest = tmp_dir / record_name + file_md5 = record.get("md5sum", None) + futures.append(executor.submit(download_file, file_url, file_dest, file_md5)) + with tqdm.redirect(): + for future in tqdm.tqdm(cf.as_completed(futures), total=len(futures)): + file_url, file_dest = future.result() + logger.info("Downloaded %s to %s", file_url, file_dest) + logger.info("Moving %s to %s", tmp_dir, bold(str(git_dir))) + shutil.move(str(tmp_dir), str(git_dir)) + return git_dir diff --git a/python/mlc_chat/support/logging.py b/python/mlc_chat/support/logging.py new file mode 100644 index 0000000..f2611c7 --- /dev/null +++ b/python/mlc_chat/support/logging.py @@ -0,0 +1,20 @@ +""" +Logging support for MLC. It derives from Python's logging module, and in the future, +it can be easily replaced by other logging modules such as structlog. 
+""" +import logging + + +def enable_logging(): + """Enable MLC's default logging format""" + logging.basicConfig( + level=logging.INFO, + style="{", + datefmt="%Y-%m-%d %H:%M:%S", + format="[{asctime}] {levelname} {filename}:{lineno}: {message}", + ) + + +def getLogger(name: str): # pylint: disable=invalid-name + """Get a logger according to the given name""" + return logging.getLogger(name) diff --git a/python/mlc_chat/support/preshard.py b/python/mlc_chat/support/preshard.py new file mode 100644 index 0000000..09db02c --- /dev/null +++ b/python/mlc_chat/support/preshard.py @@ -0,0 +1,126 @@ +"""Functions for pre-sharding weights""" +from typing import Any, Dict, List + +from tvm import IRModule +from tvm import dlight as dl +from tvm import relax +from tvm.relax.frontend import nn +from tvm.runtime import Device +from tvm.target import Target + + +def _sharded_param_name(param_name, worker_id): + return f"{param_name}_shard-{worker_id}" + + +def _update_quantize_map( + quantize_map: Any, + named_params: Dict[str, nn.Parameter], + mlc_name: str, + tensor_parallel_shards: int, +): + param_names: List[str] = [mlc_name] + + if mlc_name in quantize_map.param_map: + # the parameter is quantized + quantized_params = quantize_map.param_map[mlc_name] + param_names = quantized_params + quantize_func = quantize_map.map_func[mlc_name] + + for worker_id in range(tensor_parallel_shards): + sharded_mlc_name = _sharded_param_name(mlc_name, worker_id) + quantize_map.param_map[sharded_mlc_name] = [ + _sharded_param_name(param_name, worker_id) for param_name in quantized_params + ] + quantize_map.map_func[sharded_mlc_name] = quantize_func + + for param_name in param_names: + param = named_params.pop(param_name) + for worker_id in range(tensor_parallel_shards): + named_params[_sharded_param_name(param_name, worker_id)] = param + + +def _create_shard_func( + bb: relax.BlockBuilder, param: nn.Parameter, tensor_parallel_shards: int +): # pylint: disable=too-many-locals + shard_strategy = param.attrs.get("shard_strategy", None) + # generate tir shard function + tir_func = shard_strategy.gen_tir(shards=tensor_parallel_shards, weight=param) + tir_func = tir_func.with_attr("global_symbol", f"{shard_strategy.name}_tir") + # add tir shard function to the IRModule + tir_gvar = bb.add_func(tir_func, func_name=f"{shard_strategy.name}_tir") + # create relax function that + # 1. shard weight with tir shard function, result: [num_shards, *sharded_weight_shape] + # 2. split the sharded weight along dim 0, result: num_shards * [1, *sharded_weight_shape] + # 3. 
squeeze the 0th-dim of all shards, result: num_shards * [*sharded_weight_shape] + weight_shape = param.shape + weight_shape[shard_strategy.dim] = weight_shape[shard_strategy.dim] * tensor_parallel_shards + sharded_weight_shape = [tensor_parallel_shards, *param.shape] + weight_var = relax.Var("weight", relax.TensorStructInfo(weight_shape, param.dtype)) + with bb.function(name=shard_strategy.name, params=[weight_var]): + with bb.dataflow(): + lv0 = bb.emit( + relax.call_tir( + tir_gvar, + weight_var, + out_sinfo=relax.TensorStructInfo(sharded_weight_shape, param.dtype), + ) + ) + lv1 = bb.emit(relax.op.split(lv0, indices_or_sections=tensor_parallel_shards, axis=0)) + output_vars = [] + for i in range(tensor_parallel_shards): + lvi = bb.emit(relax.TupleGetItem(lv1, i)) + squeezed_lvi = bb.emit(relax.op.squeeze(lvi, 0)) + output_vars.append(squeezed_lvi) + gv = bb.emit_output(output_vars) + bb.emit_func_output(gv) + + +def _compile_shard_funcs(mod: IRModule, device: Device): + target = Target.from_device(device) + with target: + mod = relax.transform.LegalizeOps()(mod) + mod = dl.ApplyDefaultSchedule( # type: ignore # pylint: disable=not-callable + dl.gpu.Matmul(), + dl.gpu.GEMV(), + dl.gpu.Reduction(), + dl.gpu.GeneralReduction(), + dl.gpu.Fallback(), + )(mod) + ex = relax.build(mod, target=target) + vm = relax.VirtualMachine(ex, device) + return vm + + +def apply_preshard( + quantize_map: Any, + named_params: Dict[str, nn.Parameter], + tensor_parallel_shards: int, + args: Any, +): + """Update quantize_map and named_params, create shard functions based on shard strategies.""" + model_config = args.model.config.from_file(args.config) + model_config.tensor_parallel_shards = tensor_parallel_shards + model = args.model.model(model_config) + model.to(args.quantization.model_dtype) + + bb = relax.BlockBuilder() + param_to_shard_func = {} + shard_func_names = set() + for name, param in model.state_dict().items(): + shard_strategy = param.attrs.get("shard_strategy", None) + if shard_strategy is not None: + _update_quantize_map(quantize_map, named_params, name, tensor_parallel_shards) + + # create shard functions + param_to_shard_func[name] = shard_strategy.name + if shard_strategy.name not in shard_func_names: + _create_shard_func(bb, param, tensor_parallel_shards) + shard_func_names.add(shard_strategy.name) + + mod = bb.finalize() + vm = _compile_shard_funcs(mod, args.device) + + for name in param_to_shard_func: + param_to_shard_func[name] = vm[param_to_shard_func[name]] + return param_to_shard_func diff --git a/python/mlc_chat/support/random.py b/python/mlc_chat/support/random.py new file mode 100644 index 0000000..0568276 --- /dev/null +++ b/python/mlc_chat/support/random.py @@ -0,0 +1,16 @@ +"""Utility functions for random number generation.""" +import sys + + +def set_global_random_seed(seed): + """Set global random seed for python, numpy, torch and tvm.""" + if "numpy" in sys.modules: + sys.modules["numpy"].random.seed(seed) + if "torch" in sys.modules: + sys.modules["torch"].manual_seed(seed) + if "random" in sys.modules: + sys.modules["random"].seed(seed) + if "tvm" in sys.modules: + set_seed = sys.modules["tvm"].get_global_func("mlc.random.set_seed") + if set_seed: + set_seed(seed) diff --git a/python/mlc_chat/support/style.py b/python/mlc_chat/support/style.py new file mode 100644 index 0000000..5b2272e --- /dev/null +++ b/python/mlc_chat/support/style.py @@ -0,0 +1,62 @@ +"""Printing styles.""" + +from enum import Enum + + +class Styles(Enum): + """Predefined set of styles to be used. 
+ + Reference: + - https://en.wikipedia.org/wiki/ANSI_escape_code#3-bit_and_4-bit + - https://stackoverflow.com/a/17303428 + """ + + RED = "\033[91m" + GREEN = "\033[92m" + YELLOW = "\033[93m" + BLUE = "\033[94m" + PURPLE = "\033[95m" + CYAN = "\033[96m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + END = "\033[0m" + + +def red(text: str) -> str: + """Return red text.""" + return f"{Styles.RED.value}{text}{Styles.END.value}" + + +def green(text: str) -> str: + """Return green text.""" + return f"{Styles.GREEN.value}{text}{Styles.END.value}" + + +def yellow(text: str) -> str: + """Return yellow text.""" + return f"{Styles.YELLOW.value}{text}{Styles.END.value}" + + +def blue(text: str) -> str: + """Return blue text.""" + return f"{Styles.BLUE.value}{text}{Styles.END.value}" + + +def purple(text: str) -> str: + """Return purple text.""" + return f"{Styles.PURPLE.value}{text}{Styles.END.value}" + + +def cyan(text: str) -> str: + """Return cyan text.""" + return f"{Styles.CYAN.value}{text}{Styles.END.value}" + + +def bold(text: str) -> str: + """Return bold text.""" + return f"{Styles.BOLD.value}{text}{Styles.END.value}" + + +def underline(text: str) -> str: + """Return underlined text.""" + return f"{Styles.UNDERLINE.value}{text}{Styles.END.value}" diff --git a/python/mlc_chat/support/tensor_parallel.py b/python/mlc_chat/support/tensor_parallel.py new file mode 100644 index 0000000..4d58662 --- /dev/null +++ b/python/mlc_chat/support/tensor_parallel.py @@ -0,0 +1,97 @@ +"""Sharding operators for tensor parallelism.""" +import dataclasses +from contextlib import contextmanager +from typing import Any, Dict, List, Optional + +from tvm import te, tir, topi +from tvm.relax.frontend import nn + + +@dataclasses.dataclass +class ShardSingleDim: + """ + Shard a tensor by a single dimension. + + + Parameters + ---------- + name : str + The name of the shard func + + dim : int + The dimension to shard + + segs : Optional[List[int]] + The length of segments along `dim`. Default to None. If specified, + shard a tensor by its "segmented" dimension, where each segment has a different length + and sharded evenly on each worker. 
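To illustrate what a segmented single-dimension shard specification describes, a small sketch; the fused-QKV numbers are invented, and a `SimpleNamespace` stands in for `nn.Tensor` because only `shape` and `dtype` are consulted by the `gen_shard_info` helper defined further down:

```python
# Invented numbers: each worker keeps 4096 + 1024 + 1024 = 6144 rows of a fused QKV
# weight, so with 2 workers the unsharded source has 12288 rows and the shard function
# emits a (2, 6144, 4096) tensor (one slice per worker).
from types import SimpleNamespace

from mlc_chat.support.tensor_parallel import ShardSingleDim

spec = ShardSingleDim("shard_qkv", dim=0, segs=[4096, 1024, 1024])
fake_weight = SimpleNamespace(shape=[6144, 4096], dtype="float16")  # stands in for nn.Tensor
print(spec.gen_shard_info(shards=2, weight=fake_weight))
# {'func_name': 'shard_qkv', 'out_shape': (2, 6144, 4096), 'out_dtype': 'float16'}
```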
+ + """ + + name: str + dim: int + segs: Optional[List[int]] = None + + def gen_tir(self, shards: int, weight: nn.Tensor) -> tir.PrimFunc: + """Generate a TIR function that shards the weight tensor by its rows.""" + shape = weight.shape + segs = self.segs or [shape[self.dim]] + assert sum(segs) == shape[self.dim] + w = te.placeholder( + [*shape[: self.dim], shape[self.dim] * shards, *shape[self.dim + 1 :]], + weight.dtype, + name="w", + ) + ws: List[te.Tensor] = [] + offset = 0 + for idx, sub_seg in enumerate(segs): + ws.append( + topi.transpose( + topi.reshape( + te.compute( + (*shape[: self.dim], sub_seg * shards, *shape[self.dim + 1 :]), + lambda *idx: w[ + idx[: self.dim] + + (idx[self.dim] + offset,) # pylint: disable=cell-var-from-loop + + idx[self.dim + 1 :] + ], + name=f"w_{idx}", + ), + (*shape[: self.dim], shards, sub_seg, *shape[self.dim + 1 :]), + ), + [self.dim, *range(self.dim), *range(self.dim + 1, len(shape) + 1)], + ) + ) + offset += sub_seg * shards + o = topi.concatenate(ws, axis=1 + self.dim) + func = te.create_prim_func([w, o]) + return func + + def gen_shard_info(self, shards: int, weight: nn.Tensor) -> Dict[str, Any]: + """Generate shard info for this sharding strategy.""" + return { + "func_name": self.name, + "out_shape": (shards, *weight.shape), + "out_dtype": weight.dtype, + } + + +@contextmanager +def shard_bias(linear: nn.Linear, tensor_parallel_shards: int): + """ + A context manager to shard the bias of a linear into `tensor_parallel_shards` shards. + + + Parameters + ---------- + linear : nn.Linear + The linear layer whose bias would be sharded. + + tensor_parallel_shards : int + The number of shards. + """ + original_bias = linear.bias + if tensor_parallel_shards > 1: + linear.bias = linear.bias / tensor_parallel_shards + yield + linear.bias = original_bias diff --git a/python/mlc_chat/support/tqdm.py b/python/mlc_chat/support/tqdm.py new file mode 100644 index 0000000..9adceca --- /dev/null +++ b/python/mlc_chat/support/tqdm.py @@ -0,0 +1,38 @@ +"""Utils to better use tqdm""" +import contextlib +import inspect +import io + +from tqdm import tqdm +from tqdm.contrib.logging import logging_redirect_tqdm as _redirect_logging + + +@contextlib.contextmanager +def _redirect_print(): + old_print = print + + def new_print(*args, **kwargs): + with io.StringIO() as output: + kwargs["file"] = output + kwargs["end"] = "" + old_print(*args, **kwargs) + content = output.getvalue() + tqdm.write(content) + + try: + inspect.builtins.print = new_print + yield + finally: + inspect.builtins.print = old_print + + +@contextlib.contextmanager +def redirect(): + """Redirect tqdm output to logging and print.""" + + with _redirect_logging(): + with _redirect_print(): + yield + + +__all__ = ["tqdm", "redirect"] diff --git a/python/mlc_chat/tokenizer.py b/python/mlc_chat/tokenizer.py new file mode 100644 index 0000000..6158ef4 --- /dev/null +++ b/python/mlc_chat/tokenizer.py @@ -0,0 +1,55 @@ +"""The tokenizer and related tools in MLC LLM. +This tokenizer essentially wraps and binds the HuggingFace tokenizer +library and sentencepiece. +Reference: https://github.com/mlc-ai/tokenizers-cpp +""" +from typing import List + +import tvm +import tvm._ffi +from tvm.runtime import Object + +from . 
import _ffi_api + + +@tvm._ffi.register_object("mlc.Tokenizer") # pylint: disable=protected-access +class Tokenizer(Object): + """The tokenizer class in MLC LLM.""" + + def __init__(self, tokenizer_path: str) -> None: + """Create the tokenizer from tokenizer directory path.""" + self.__init_handle_by_constructor__( + _ffi_api.Tokenizer, tokenizer_path # type: ignore # pylint: disable=no-member + ) + + def encode(self, text: str) -> List[int]: + """Encode text into ids. + + Parameters + ---------- + text : str + The text string to encode. + + Returns + ------- + token_ids : List[int] + The list of encoded token ids. + """ + return list(_ffi_api.TokenizerEncode(self, text)) # type: ignore # pylint: disable=no-member + + def decode(self, token_ids: List[int]) -> str: + """Decode token ids into text. + + Parameters + ---------- + token_ids : List[int] + The token ids to decode to string. + + Returns + ------- + text : str + The decoded text string. + """ + return _ffi_api.TokenizerDecode( # type: ignore # pylint: disable=no-member + self, tvm.runtime.ShapeTuple(token_ids) + ) diff --git a/python/setup.py b/python/setup.py new file mode 100644 index 0000000..f866e9a --- /dev/null +++ b/python/setup.py @@ -0,0 +1,131 @@ +# pylint: disable=invalid-name, exec-used +"""Setup MLC LLM package.""" +import os +import shutil + +from setuptools import find_packages, setup +from setuptools.dist import Distribution + +CURRENT_DIR = os.path.dirname(__file__) +CONDA_BUILD = os.getenv("CONDA_BUILD") is not None + + +def get_lib_path(): + """Get library path, name and version""" + # Directly exec libinfo to get the right setup + libinfo_py = os.path.join(CURRENT_DIR, "./mlc_chat/libinfo.py") + libinfo = {"__file__": libinfo_py} + with open(libinfo_py, "rb") as f: + exec(compile(f.read(), libinfo_py, "exec"), libinfo, libinfo) + version = libinfo["__version__"] + + # conda installs libraries into env instead of packaging with pip + if not CONDA_BUILD: + libs = [ + libinfo["find_lib_path"]("mlc_llm")[0], + libinfo["find_lib_path"]("mlc_llm_module")[0], + ] + else: + libs = None + + return libs, version + + +def git_describe_version(original_version): + """Get git describe version.""" + ver_py = os.path.join(CURRENT_DIR, "..", "version.py") + libver = {"__file__": ver_py} + with open(ver_py, "rb") as f: + exec(compile(f.read(), ver_py, "exec"), libver, libver) + _, gd_version = libver["git_describe_version"]() + if gd_version is not None and gd_version != original_version: + print(f"Use git describe based version {gd_version}") + if gd_version is None: + print(f"Use original version {original_version}") + return original_version + return gd_version + + +LIB_LIST, __version__ = get_lib_path() +__version__ = git_describe_version(__version__) + + +class BinaryDistribution(Distribution): + """This class is needed in order to create OS specific wheels.""" + + def has_ext_modules(self): + """Return True for binary distribution.""" + return True + + def is_pure(self): + """Return False for binary distribution.""" + return False + + +def main(): + """The main entrypoint.""" + setup_kwargs = {} + if not CONDA_BUILD: + with open("MANIFEST.in", "w", encoding="utf-8") as fo: + for path in LIB_LIST: + if os.path.isfile(path): + shutil.copy(path, os.path.join(CURRENT_DIR, "mlc_chat")) + _, libname = os.path.split(path) + fo.write(f"include mlc_chat/{libname}\n") + setup_kwargs = {"include_package_data": True} + + setup( + name="mlc_chat", + version=__version__, + description="MLC Chat: an universal runtime running LLMs", + 
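A round-trip sketch for the tokenizer wrapper above; the tokenizer directory is a placeholder and the call assumes an MLC-LLM build whose FFI registers `mlc.Tokenizer`:

```python
# Illustrative encode/decode round trip.
from mlc_chat.tokenizer import Tokenizer

tok = Tokenizer("dist/models/my-model")   # placeholder tokenizer directory
ids = tok.encode("Hello from MLC!")
print(ids)                                # token ids
print(tok.decode(ids))                    # should round-trip back to (roughly) the input text
```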
url="https://llm.mlc.ai/", + author="MLC LLM Contributors", + license="Apache 2.0", + # See https://pypi.org/classifiers/ + classifiers=[ + "License :: OSI Approved :: Apache Software License", + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + ], + keywords="machine learning", + zip_safe=False, + packages=find_packages(), + entry_points={ + "console_scripts": [ + "mlc_chat = mlc_chat.__main__:main", + ], + }, + package_dir={"mlc_chat": "mlc_chat"}, + install_requires=[ + "fastapi", + "uvicorn", + "shortuuid", + "torch", + "safetensors", + "requests", + "tqdm", + "tiktoken", + "prompt_toolkit", + ], + distclass=BinaryDistribution, + **setup_kwargs, + ) + + def _remove_path(path): + if os.path.exists(path): + if os.path.isfile(path): + os.remove(path) + elif os.path.isdir(path): + shutil.rmtree(path) + + if not CONDA_BUILD: + # Wheel cleanup + os.remove("MANIFEST.in") + for path in LIB_LIST: + _, libname = os.path.split(path) + _remove_path(f"mlc_chat/{libname}") + + +main() diff --git a/run_scripts/README.md b/run_scripts/README.md new file mode 100644 index 0000000..2483993 --- /dev/null +++ b/run_scripts/README.md @@ -0,0 +1,36 @@ +# MLC run scripts + +This directory includes scripts for running MLC on different targets. +We support runtime through `expect` scripts, to interact with the `mlc_chat` executable for local and jetson execution. + +Interaction with android/ios apps is happening over [phonelab](https://github.com/brave-experiments/blade-public). +Interaction with jetson is also coordinated with [jetsonlab](https://github.com/brave-experiments/jetsonlab-public). + +## Structure + +```bash +├── run-mlc-chat-cli.exp # MLC expect script +├── run-mlc.sh # Wrapper shell script +└── run_expect_all.sh # Script for running all experiments +``` + +## How to run? + +The `run-mlc.sh` script is the entry point for running experiments locally. Outside of this repo, this is used by `jetsonlab` for automated runtime of benchmarks. However, one can invoke the script manually if they desire. + +``` +./run-mlc.sh + +: The path to the model +: The path to the library of the model (e.g. so file) +: The path of the input prompts json file +: The ordinal of the conversation to start from +: The ordinal of the conversation to end at +: The output path for logs and metrics. +: The filename to use for energy events timestamps +: The iteration (i.e. repetition) that this experiment is running. +``` + +## Known issues + +* If you run the expect script on Mac OS, there is an issue where a message "your terminal doesn't support cursor position requests (CPR)" prevents the automation. \ No newline at end of file diff --git a/run_scripts/run-mlc-chat-cli.exp b/run_scripts/run-mlc-chat-cli.exp new file mode 100755 index 0000000..db733e2 --- /dev/null +++ b/run_scripts/run-mlc-chat-cli.exp @@ -0,0 +1,142 @@ +#!/usr/bin/expect + +# Note: This script is used to run the MLC models in interactive mode on jetson devices. 
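If the measurement CSVs produced by the script below need to be inspected by hand, a hypothetical loader is sketched here; the file name and the `melt_measurements` location follow the layout the script uses, but both are examples:

```python
# Hypothetical post-processing of one measurements CSV (header: start_date,duration,state).
import csv

with open("results/melt_measurements/measurements_iter0_conv0.csv", newline="") as f:
    rows = list(csv.DictReader(f))

for row in rows:
    print(float(row["duration"]), row["state"][:40])
```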
+# Author: Stefanos Laskaridis (stefanos@brave.com) + +package require json + +# Check if an argument is provided +if { $argc != 8 } { + puts "Usage: $argv0 model_path model_lib_path input_prompts_filename conversation_from conversation_to output_path events_filename iteration" + exit 1 +} + +# config +set timeout -1 +set sleep_time 5 +set model_path [lindex $argv 0] +set model_lib_path [lindex $argv 1] +set input_prompts_filename [lindex $argv 2] +set conversation_from [expr {int([lindex $argv 3])}] +set conversation_to [expr {int([lindex $argv 4])}] +set output_path [lindex $argv 5] +set events_filename [lindex $argv 6] +set iteration [lindex $argv 7] + +# create output path +exec mkdir -p "$output_path/melt_measurements/" + +set log_path "$output_path/melt_measurements/llm_output_iter${iteration}_conv${conversation_from}.txt" +set measurements "$output_path/melt_measurements/measurements_iter${iteration}_conv${conversation_from}.csv" + +# log file +log_file $log_path + +# build expect prompt based on given model_path +if {[string first "Llama-2" $model_path] != -1} { + set expect_prompt "\[INST\]:\ " +} elseif {[string first "mistral" $model_path] != -1} { + set expect_prompt "\[INST\]:\ " +} elseif {[string first "TinyLlama" $model_path] != -1} { + set expect_prompt "<|im_start|>user: " +} elseif {[string first "stablelm" $model_path] != -1} { + set expect_prompt "<|user|>" +} elseif {[string first "google_gemma" $model_path] != -1} { + set expect_prompt "user: " +} else { + # error + puts "Error: Unknown model for given model_path: $model_path" + exit 1 +} +set expect_prompt "\n$expect_prompt" + +# define store metrics function +proc store_metrics {start_time end_time state measurements} { + set duration [expr {double($end_time - $start_time) / 1000.0}] + set start_time_epoch [expr {$start_time / 1000.0}] + set parsed_state [string map {\n \\n} $state] + exec echo "$start_time_epoch,$duration,\"$parsed_state\"\r" >> "$measurements" +} + +# Read the JSON file +set file_data [read [open $input_prompts_filename r]] +set input_prompts [json::json2dict $file_data] + +# set range +if {$conversation_to > [expr [llength $input_prompts] -1] } { + set conversation_to [expr [llength $input_prompts] -1] +} +set input_prompts [lrange $input_prompts $conversation_from $conversation_to] + + +# init measurements file (write csv header) +exec echo "start_date,duration,state\r" > "$measurements" + +# init variables, this init states are proxy to model loading +set start_time [clock milliseconds] +set state "load_model" + +# build command +set command "spawn mlc_chat chat --model-lib-path $model_lib_path --energy-events $events_filename $model_path" + +# Execute the command +eval $command + +sleep $sleep_time + +# iterate through conversations +foreach conversation $input_prompts { + + # iterate through prompts + foreach prompt $conversation { + + expect -ex $expect_prompt { + + # save metrics of previous prompt (or model load if first iteration) + set end_time [clock milliseconds] + store_metrics $start_time $end_time $state $measurements + + sleep $sleep_time + + # save state vars for next iteration and send the prompt + set state $prompt + + # Send stats on every prompt + send "/stats\r" + sleep 1 + + set start_time [clock milliseconds] + + # escape any \n characters in the prompt + set parsed_prompt [string map {\n \\n} $prompt] + send "$parsed_prompt\r" + } + } + + expect -ex $expect_prompt { + # print stats + send "/stats\r" + } + # expect -ex $expect_prompt { + # send "/reload\r" + # } + # expect 
-ex $expect_prompt {
+    #     # reload model/context
+    #     send "/reset\r"
+    # }
+}
+
+# finish
+expect -ex $expect_prompt {
+
+    # save last metrics
+    set end_time [clock milliseconds]
+    store_metrics $start_time $end_time $prompt $measurements
+
+    sleep $sleep_time
+
+    # exit
+    send "/exit\r"
+    sleep 10
+    expect eof
+}
diff --git a/run_scripts/run-mlc.sh b/run_scripts/run-mlc.sh
new file mode 100755
index 0000000..1d095d5
--- /dev/null
+++ b/run_scripts/run-mlc.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+
+# Note: This script is used to run the MLC models locally.
+# Author: Stefanos Laskaridis (stefanos@brave.com)
+
+# show script usage
+if [ $# -ne 8 ]
+then
+    echo "===================================================="
+    echo "USAGE: $0 model_path model_lib_path input_prompts_filename conversation_from conversation_to output_path events_filename iteration"
+    echo "Passed parameters: $@"
+    echo "===================================================="
+    exit -1
+fi
+MODEL_PATH=$1
+MODEL_LIB_PATH=$2
+INPUT_PROMPTS_FILENAME=$3
+CONVERSATION_FROM=$4
+CONVERSATION_TO=$5
+OUTPUT_PATH=$6
+EVENTS_FILENAME=$7
+ITERATION=$8
+
+FILE_DIRECTORY="$(dirname "${BASH_SOURCE[0]}")"
+REAL_OUTPUT_PATH=$(realpath $OUTPUT_PATH)
+
+# Check device type and set expect_script accordingly
+expect_script="run-mlc-chat-cli.exp"
+
+# iterate per conversation
+for (( i=CONVERSATION_FROM; i
+impl From<tvm_rt::Error> for ChatModuleError {
+    fn from(e: tvm_rt::Error) -> Self {
+        Self::TvmRuntime(e)
+    }
+}
+
+pub type Result<T> = result::Result<T, ChatModuleError>;
+
+#[derive(Debug, Clone)]
+pub struct ChatMessage {
+    role: String,
+    content: String,
+}
+
+impl ChatMessage {
+    pub fn new(role: &str, content: &str) -> Self {
+        ChatMessage {
+            role: role.to_owned(),
+            content: content.to_owned(),
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub enum Prompt {
+    String(String),
+    MessageList(Vec<ChatMessage>),
+}
+
+impl From<&str> for Prompt {
+    fn from(s: &str) -> Self {
+        Prompt::String(s.to_owned())
+    }
+}
+
+impl From<String> for Prompt {
+    fn from(s: String) -> Self {
+        Prompt::String(s)
+    }
+}
+
+impl From<Vec<ChatMessage>> for Prompt {
+    fn from(messages: Vec<ChatMessage>) -> Self {
+        Prompt::MessageList(messages)
+    }
+}
+
+#[derive(Debug, Copy, Clone)]
+pub enum PlaceInPrompt {
+    All = 0,
+    Begin = 1,
+    Middle = 2,
+    End = 3,
+}
+
+impl PlaceInPrompt {
+    pub fn to_value(&self) -> i32 {
+        *self as i32
+    }
+}
+
+macro_rules! tvm_func_invoke {
+    // Handle the case with return type
+    ($self:ident, $func_name:ident($($args:expr),*) -> $ret_type:ty) => {
+        {
+            let f = $self.chat_module.get_function(stringify!($func_name), false)?;
+            let res: $ret_type = f.invoke(vec![$($args.into()),*])?.try_into().expect("call should succeed");
+            Ok(res)
+        }
+    };
+    // Handle the case without return type
+    ($self:ident, $func_name:ident($($args:expr),*)) => {
+        {
+            let f = $self.chat_module.get_function(stringify!($func_name), false)?;
+            f.invoke(vec![$($args.into()),*])?;
+            Ok(())
+        }
+    };
+}
+
+/// Parse the input device identifier into device name and id.
+///
+/// # Arguments
+/// * `device` - The device identifier to parse. It can be in the format "device_name" (e.g., "cuda")
+///   or "device_name:device_id" (e.g., "cuda:1").
+///
+/// # Returns
+/// * `device_name` - The name of the device.
+/// * `device_id` - The id of the device, or 0 if not specified in the input.
+fn parse_device_str(device: &str) -> (&str, i32) {
+    let device_err_msg = format!(
+        "Invalid device name: {}.
Please enter the device in the form \ + 'device_name:device_id' or 'device_name', where 'device_name' needs to be \ + one of 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto'.", + device + ); + let device_args: Vec<&str> = device.split(':').collect(); + match device_args.len() { + 1 => (device_args[0], 0), + 2 => (device_args[0], device_args[1].parse::().unwrap()), + _ => panic!("{}", device_err_msg), + } +} + +/// Use user-provided argument `model` to search for a valid model path. +/// We define "valid" as having an `mlc-chat-config.json` right under the folder. +/// +/// # Arguments +/// * `model`: User's input; may be a compiled model's name, or a full path. +/// +/// # Returns +/// * `model_path`: A "valid" path to model folder with `mlc-chat-config.json` existing under it. +/// * `chat_file`: The path to the `mlc-chat-config.json` file. +/// +/// # Panics +/// * If a valid model_path cannot be found. +pub fn get_model_path(model: &str) -> (PathBuf, PathBuf) { + // Note that the order of this list corresponds to our search priority + let candidate_paths = vec![ + PathBuf::from(model), // full path, or just the name + PathBuf::from(format!("{}/params", model)), // Default directory after mlc_llm.build_model() + PathBuf::from(format!("dist/prebuilt/{}", model)), // Using prebuilt workflow + PathBuf::from(format!("dist/{}/params", model)), // Default directory after mlc_llm.build_model() in the current path + PathBuf::from(format!("dist/prebuilt/mlc-chat-{}", model)), // Also prebuilt workflow, but missed prefix + ]; + + // Look for the first folder that has `mlc-chat-config.json` under it + for candidate in &candidate_paths { + let chat_file = candidate.join("mlc-chat-config.json"); + if chat_file.is_file() { + info!("Using model folder: {:?}", candidate.canonicalize().unwrap()); + info!("Using mlc chat config: {:?}", chat_file.canonicalize().unwrap()); + return (candidate.clone(), chat_file); + } + } + + let mut found_folder = false; + let mut valid_dir_str = String::new(); + for candidate in &candidate_paths { + if candidate.is_dir() { + valid_dir_str += &format!("- {:?}\n", candidate.canonicalize().unwrap()); + found_folder = true; + } + } + + if found_folder { + // Error 1: there is a folder, but not an mlc-llm model folder (E1) + let err_msg = format!( + "The model folder provided does not seem to refer to a valid mlc-llm model folder.\n\ + Specifically, we cannot find `mlc-chat-config.json`, a required file. You should \ + provide a path that contains the file.\n\ + According to your input `model`, we looked at folder(s):\n\ + {}\n\ + MLC-Chat consumes models that are processed by the MLC-LLM build process.\n\ + ", + valid_dir_str, + ); + panic!("{}", err_msg); + } else { + // Error 2: cannot find a folder (E0) + let all_paths_str = candidate_paths + .iter() + .map(|path| format!("- {}\n", path.display())) + .collect::(); + let err_msg = format!( + "Cannot find the model folder. We searched over the following possible paths:\n\ + {}\n\ + You can try to pass in `model=/path/to/your-model-path`, and confirm \ + that it contains `mlc-chat-config.json`, among other essential files.\n\ + ", + all_paths_str, + ); + panic!("{}", err_msg); + } +} + +/// Read in the config file in model path, then potentially override with user input. +/// +/// # Arguments +/// * `config_file_path`: &Path +/// `chat_file` returned by a function like `get_model_path()`. 
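+///
+/// # Returns
+/// The `ChatConfig` deserialized from the `mlc-chat-config.json` file.
+///
+/// # Example
+/// A minimal sketch for illustration only (`get_chat_config` is crate-private and the
+/// model name below is an assumption):
+/// ```ignore
+/// let (_model_path, chat_file) = get_model_path("Llama-2-7b-chat-hf-q4f16_1");
+/// let chat_config = get_chat_config(&chat_file).expect("config should parse");
+/// println!("conv template: {:?}", chat_config.conv_template);
+/// ```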
+fn get_chat_config(config_file_path: &Path) -> result::Result> { + // Read the base configuration from the file + let file_contents = fs::read_to_string(config_file_path)?; + let final_chat_config = ChatConfig::from_json(&file_contents)?; + Ok(final_chat_config) +} + +/// Look up the model library and return a corresponding `tvm` runtime Module. +/// +/// # Arguments +/// * `model` - A string representing either the name of a compiled model or a full path to it. +/// * `model_path` - The path to the model, as determined by `get_model_path`. +/// * `chat_config` - The chat configuration, possibly with overrides, returned by `get_chat_config`. +/// * `model_lib_path` - An optional string specifying the full path to the model library. This is prioritized if provided. +/// * `device_name` - A string representing the device for which the library model file name will be constructed. +/// * `config_file_path` - The path to the `mlc-chat-config.json` file, used for constructing error messages. +/// +/// # Returns +/// The path pointing to the model library we find. +fn get_lib_module_path( + model: &str, model_path: &Path, chat_config: &ChatConfig, model_lib_path: Option<&str>, device_name: &str, + config_file_path: &Path, +) -> PathBuf { + // 1. Use user's model_lib_path if provided + if let Some(lib_path) = model_lib_path { + let path = Path::new(lib_path); + if path.is_file() { + info!("Using library model: {:?}", path); + return path.to_path_buf(); + } else { + panic!("The `model_lib_path` you passed in is not a file: {:?}.", lib_path); + } + } + + // 2. Generate all possible file names according to OS + let mut candidate_paths = Vec::new(); + if let Some(model_lib) = &chat_config.model_lib { + let candidate_lib_names: Vec = if cfg!(target_os = "linux") { + vec![format!("{}-{}.so", model_lib, device_name)] + } else if cfg!(target_os = "macos") { + vec![ + format!("{}-{}.dylib", model_lib, device_name), + format!("{}-{}.so", model_lib, device_name), + ] + } else if cfg!(target_os = "windows") { + vec![format!("{}-{}.dll", model_lib, device_name)] + } else { + vec![ + format!("{}-{}.dylib", model_lib, device_name), + format!("{}-{}.so", model_lib, device_name), + format!("{}-{}.dll", model_lib, device_name), + ] + }; + + // 3. Generate possible model library paths + let pardir_model_path = model_path.parent().unwrap(); + for lib_name in &candidate_lib_names { + let paths: Vec = vec![ + lib_name.clone(), + format!("dist/prebuilt/lib/{}", lib_name), + format!("dist/{}/{}", model, lib_name), + model_path.join(lib_name).to_string_lossy().into_owned(), + pardir_model_path.join(lib_name).to_string_lossy().into_owned(), + ]; + + candidate_paths.extend(paths); + } + + // 4. Search for model library + for candidate in &candidate_paths { + let candidate_path = Path::new(candidate); + if candidate_path.is_file() { + info!("Using library model: {:?}", candidate_path); + return candidate_path.to_path_buf(); + } + } + + // 5. Error + let mut err_msg = format!( + "Cannot find the model library that corresponds to `{:?}`.\n\ + `{:?}` is either provided in the `chat_config` \ + you passed in, or specified in {:?}.\n\ + We searched over the following possible paths: \n", + model_lib, model_lib, config_file_path + ); + for candidate in &candidate_paths { + err_msg += &format!("- {}\n", candidate); + } + err_msg += &format!( + "If you would like to directly specify the model library path, you may \ + consider passing in the `ChatModule.model_lib_path` parameter." 
+ ); + + panic!("{}", err_msg); + } else { + panic!("Cannot find the model library, you need to either pass it in, or specify in the chat_config file."); + } +} + +/// The ChatModule for MLC LLM. +/// +/// # Examples +/// +/// ``` +/// use mlc_llm::chat_module::ChatModule; +/// +/// // Create a ChatModule instance +/// let cm = ChatModule::new("Llama-2-7b-chat-hf-q4f16_1", "cuda", None, None).unwrap(); +/// +/// // Generate a response for a given prompt +/// let output = cm.generate("what is the meaning of life?", None).unwrap(); +/// +/// // Print prefill and decode performance statistics +/// println!("Statistics: {:?}\n", cm.stats(false).unwrap()); +/// +/// let output = cm.generate("what is Rust?", None).unwrap(); +/// ``` +pub struct ChatModule { + chat_module: Module, + chat_config: ChatConfig, +} + +impl ChatModule { + pub fn new(model: &str, device: &str, model_lib_path: Option<&str>) -> Result { + let device_err_msg = format!( + "Invalid device name: {}. Please enter the device in the form \ + 'device_name:device_id' or 'device_name', where 'device_name' needs to be \ + one of 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto'.", + device + ); + + let (device_name, device_id) = parse_device_str(device); + + // 1. Get device name and id + let device_type = match device_name { + "cuda" => 2, + "opencl" => 4, + "vulkan" => 7, + "metal" => 8, + "rocm" => 10, + _ => panic!("{}", device_err_msg), + }; + + unsafe { + LLMChatDummyLinkFunc(); + } + + static GLOBAL_FUNC_NAME: &str = "mlc.llm_chat_create"; + let f = Function::get(GLOBAL_FUNC_NAME).ok_or(ChatModuleError::GlobalFuncNotFound)?; + let m: Module = f + .invoke(vec![device_type.into(), device_id.into()]) + .unwrap() + .try_into() + .expect("call should succeed"); + + // 2. Look up the model path + let (model_path, config_file_path) = get_model_path(model); + + // 3. Instantiate chat_config + let chat_config = get_chat_config(&config_file_path).unwrap(); + + // 4. Look up the model library + let model_lib_path = get_lib_module_path( + model, + &model_path, + &chat_config, + model_lib_path, + device_name, + &config_file_path, + ); + + let chat_mod = Self { + chat_module: m, + chat_config, + }; + let model_lib_str = model_lib_path.as_path().display().to_string(); + let model_path_str = model_path.as_path().display().to_string(); + chat_mod.reload(&model_lib_str, &model_path_str, "").unwrap(); + Ok(chat_mod) + } + + /// Reload the chat module from the given library and model path. + fn reload(&self, lib: &str, model_path: &str, app_config_json: &str) -> Result<()> { + tvm_func_invoke!(self, reload(lib, model_path, app_config_json)) + } + + /// Reset the chat session, clear all chat history, and potentially + /// override the original `mlc-chat-config.json`. + pub fn reset_chat(&self) -> Result<()> { + // TODO: add optional user-specified ChatConfig + tvm_func_invoke!(self, reset_chat()) + } + + /// Get the runtime stats of the encoding step, decoding step (and embedding step if exists) + /// of the chat module in text form. + pub fn stats(&self, verbose: bool) -> Result { + if verbose { + return tvm_func_invoke!(self, verbose_runtime_stats_text() -> String); + } + tvm_func_invoke!(self, runtime_stats_text() -> String) + } + + /// Check if the stop condition is met for the current round. + fn stopped(&self) -> Result { + tvm_func_invoke!(self, stopped() -> bool) + } + + /// Get the output message in the current round. 
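+ ///
+ /// Typically used together with [prefill], [decode] and [stopped] in a streaming
+ /// loop, which is what `generate` does internally. A rough sketch (illustration
+ /// only, error handling elided):
+ /// ```ignore
+ /// self.prefill(&prompt, true, PlaceInPrompt::All, None)?;
+ /// while !self.stopped()? {
+ ///     self.decode(None)?; // advance generation by one token
+ /// }
+ /// let output = self.get_message()?; // full text produced in this round
+ /// ```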
+ fn get_message(&self) -> Result { + tvm_func_invoke!(self, get_message() -> String) + } + + /// Decode the next token, the decoding result is stored in a buffer and + /// can be retrieved by [get_message]. + fn decode(&self, generation_config: Option<&GenerationConfig>) -> Result<()> { + let generation_config_str = match generation_config { + Some(config) => serde_json::to_string(config).unwrap(), + None => { + let config = GenerationConfig::from_chat_config(&self.chat_config); + serde_json::to_string(&config).unwrap() + } + }; + tvm_func_invoke!(self, decode(generation_config_str)) + } + + /// Load JSON config and override existing configurations for the chat module. + fn load_json_override(&self, config_str: &str, partial_update: bool) -> Result<()> { + tvm_func_invoke!(self, load_json_override(config_str, &partial_update)) + } + + /// Get the configuration of the chat module in a single json string. + fn get_config_json(&self) -> Result { + tvm_func_invoke!(self, get_config_json() -> String) + } + + /// Get the name of role 0 in the conversation. + fn get_role_0(&self) -> Result { + tvm_func_invoke!(self, get_role0() -> String) + } + + /// Get the name of role 1 in the conversation. + fn get_role_1(&self) -> Result { + tvm_func_invoke!(self, get_role1() -> String) + } + + /// A high-level method that returns the full response from the chat module given a user + /// prompt. User can optionally specify which callback method to use upon receiving the + /// response. + /// + /// # Arguments + /// * `prompt` - The user input prompt, i.e. a question to ask the chat module. + /// It can also be the whole conversation history (list of messages with role and content) + /// + /// # Examples + /// ``` + /// // Single prompt case, the `prompt` can be a &str + /// let prompt = "what is the meaning of life?"; + /// + /// // Multi-prompt case, the `prompt` can be Vec + /// let message1 = ChatMessage::new("user", "suppose we already have projects llama, alpaca and vicuna, what do you think would be a great name for the next project?"); + /// let message2 = ChatMessage::new( + /// "assistant", + /// "based on the previous projects, a possible name for the next project could be \"cervidae\" which is the scientific name for deer family. this name reflects the collaboration and teamwork involved in the development of the project, and also nods to the previous projects that have been developed by the team."); + /// let message3 = ChatMessage::new("user", "I like cervidae, but the name is too long!"); + /// let prompt = vec![message1, message2, message3]; + /// ``` + /// + /// * `generation_config` - The generation config object to override the ChatConfig generation settings. + /// + /// # Returns + /// * `output` - The generated full output from the chat module. 
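+ ///
+ /// # Example (usage sketch)
+ /// A hedged end-to-end sketch, not part of the original docs: the model name, device,
+ /// and generation settings below are illustrative assumptions.
+ /// ```no_run
+ /// use mlc_llm::chat_module::ChatModule;
+ /// use mlc_llm::config::GenerationConfigBuilder;
+ ///
+ /// let cm = ChatModule::new("Llama-2-7b-chat-hf-q4f16_1", "cuda", None).unwrap();
+ /// let gen_config = GenerationConfigBuilder::default()
+ ///     .temperature(Some(0.6))
+ ///     .n(Some(2)) // request two independent completions
+ ///     .build()
+ ///     .unwrap();
+ /// let outputs = cm.generate("What is Rust?", Some(&gen_config)).unwrap();
+ /// assert_eq!(outputs.len(), 2);
+ /// ```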
+ pub fn generate( + &self, prompt: impl Into, generation_config: Option<&GenerationConfig>, + ) -> Result> { + // TODO: add progress_callback + let mut new_msgs: Vec = vec![]; + let mut num_return_sequences: usize = 1; + + if let Some(gc) = generation_config { + if let Some(n) = gc.n { + num_return_sequences = n; + } + } + + let prompt = prompt.into(); + for _ in 0..num_return_sequences { + self.reset_chat().unwrap(); + self.prefill(&prompt, true, PlaceInPrompt::All, generation_config) + .unwrap(); + + while !self.stopped().unwrap() { + self.decode(generation_config)?; + } + let new_msg = self.get_message().unwrap(); + new_msgs.push(new_msg); + } + + Ok(new_msgs) + } + + /// Runs the prefill stage for a given input and optionally decodes the first output token. + /// The user can decide where to place the input in the prompt. + /// + /// # Arguments + /// + /// * `input` - A `String` or a `Vec`. The user input prompt, i.e., a question to ask the chat module. + /// It can also be the whole conversation history (list of messages with role and content). + /// + /// # Examples + /// ``` + /// // Single prompt case, the `prompt` can be a &str + /// "what is the meaning of life?"; + /// + /// // Multi-prompt case, the `prompt` can be Vec + /// vec![ + /// ChatMessage::new("user", "Hello, how are you?"), + /// ChatMessage::new("assistant", "I'm fine, thank you. How about you?"), + /// ChatMessage::new("user", "I'm good too."), + /// ] + /// ``` + /// * `decode_next_token` - A boolean indicating whether to decode the next token after prefilling. + /// * `place_in_prompt` - The place of the input message in the prompt, as defined by the `PlaceInPrompt` enum. + /// * `generation_config` - An optional `GenerationConfig` to override the ChatConfig generation settings. 
+ /// + /// # Examples + /// + /// ``` + /// let input = "Hello, how are you?"; + /// let decode_next_token = true; + /// let place_in_prompt = PlaceInPrompt::All; + /// let generation_config = Some(GenerationConfig::new()); + /// + /// prefill(input, decode_next_token, place_in_prompt, generation_config); + /// ``` + fn prefill( + &self, input: &Prompt, decode_next_token: bool, place_in_promt: PlaceInPrompt, + generation_config: Option<&GenerationConfig>, + ) -> Result<()> { + let generation_config_str = match generation_config { + Some(config) => serde_json::to_string(config).unwrap(), + None => { + let config = GenerationConfig::from_chat_config(&self.chat_config); + serde_json::to_string(&config).unwrap() + } + }; + + let input_string = match input { + Prompt::String(inp) => inp.clone(), + Prompt::MessageList(chat_msgs) => { + let mut chat_msgs = chat_msgs.clone(); + if chat_msgs.len() == 1 { + chat_msgs.remove(0).content + } else { + let chat_config = ChatConfig::from_json(&(self.get_config_json()?)).unwrap(); + let mut conv_config = chat_config + .conv_config + .unwrap_or_else(|| ConvConfigBuilder::default().build().unwrap()); + + let role0 = self.get_role_0()?; + let role1 = self.get_role_1()?; + + let last_msg = chat_msgs.last().expect("No last message in the vector").clone(); + if last_msg.role != "user" { + panic!("Last message should be from user."); + } + + let mut messages = Vec::new(); + let msg_len = chat_msgs.len(); + for msg in chat_msgs.into_iter().take(msg_len - 1) { + match msg.role.as_str() { + "user" => messages.push(vec![role0.clone(), msg.content]), + "assistant" => messages.push(vec![role1.clone(), msg.content]), + _ => panic!("Only user and assistant roles are supported."), + } + } + + conv_config.messages = Some(messages); + conv_config.offset = Some(0); + + let mut map = HashMap::new(); + map.insert("conv_config", conv_config); + self.load_json_override(&serde_json::to_string(&map).unwrap(), true)?; + + last_msg.content + } + } + }; + + tvm_func_invoke!( + self, + prefill( + input_string, + &decode_next_token, + place_in_promt.to_value(), + generation_config_str + ) + ) + } +} diff --git a/rust/src/config.rs b/rust/src/config.rs new file mode 100644 index 0000000..a623395 --- /dev/null +++ b/rust/src/config.rs @@ -0,0 +1,273 @@ +use serde::{Deserialize, Serialize}; + +/// A struct that represents user-defined partial configuration for conversation template. +/// +/// This can be passed in to the instantiation of a [ChatModule](crate::chat_module::ChatModule) +/// instance to override the default setting in `mlc-chat-config.json` under the +/// model folder. Note that we will first load the predefined template +/// with the name specified in `conv_template`. +/// +/// Since the configuration is partial, everything will be optional. +#[derive(Clone, Default, Builder, Debug, Serialize, Deserialize)] +#[builder(default)] +pub struct ConvConfig { + /// Token list prefixing the conversation. + prefix_tokens: Option>, + + /// Name of the conversation. + name: Option, + + /// The prompt encoded before starting the chat. + system: Option, + + /// An array that describes the role names of the user and the model. + roles: Option>, + + /// The chat history represented as an array of string pairs. + pub messages: Option>>, + + /// The offset used to begin the chat from the chat history. + pub offset: Option, + + /// Specifies whether we are in chat-bot mode (`0`) or pure LM prompt mode (`1`). 
+ separator_style: Option, + + /// An array of strings indicating the separators to be used after a user message and a model message respectively. + seps: Option>, + + /// A string indicating the separator between a role and a message. + role_msg_sep: Option, + + /// A string indicating the separator to append to a role when there is no message yet. + role_empty_sep: Option, + + /// When the `stop_str` is encountered, the model will stop generating output. + stop_str: Option, + + /// A list of token IDs that act as stop tokens. + stop_tokens: Option>, + + /// Determines whether a beginning-of-string (bos) token should be added before the input tokens. + add_bos: Option, +} + +impl ConvConfig { + pub fn post_init(&mut self) { + if let Some(messages) = &self.messages { + if self.offset.is_none() { + self.offset = Some(messages.len()); + } + } + } +} + +/// A struct that represents user-defined partial configuration for the chat config file. +/// +/// An instance of [ChatConfig] can be passed in to override the default setting. +/// Since the configuration is partial, everything will be optional. +/// +/// Note: This struct is used to represent the chat config during intermediate processing. +#[derive(Builder, Debug, Default, Serialize, Deserialize)] +#[builder(default)] +pub struct ChatConfig { + /// The necessary model library to launch this model architecture. + /// Recommended to reuse model library when possible. + pub model_lib: Option, + + /// Uniquely identifying the model in application. Also used by + /// CLI to specify which model to run. + pub local_id: Option, + + /// The name of the conversation template that this chat uses. + pub conv_template: Option, + + /// Temperature applied to logits before sampling. Encourages diverse outputs if higher. + pub temperature: Option, + + /// Controls the likelihood of the model generating repeated texts. + /// See the CTRL paper for more details: + repetition_penalty: Option, + + /// Determines the set of tokens from which we sample during decoding. + /// More info on top-p sampling: + top_p: Option, + + /// Approximated average number of generated tokens in each round. + mean_gen_len: Option, + + /// Maximum number of tokens to be generated in each round. + max_gen_len: Option, + + /// Fraction of maximum window size to shift when it is exceeded. + shift_fill_factor: Option, + + /// List of tokenizer files of the model. + tokenizer_files: Option>, + + /// Partial overriding configuration for conversation template. + pub conv_config: Option, + + /// The category of the model's architecture (e.g. `llama`, `gpt_neox`, `rwkv`). + model_category: Option, + + /// Name of the model (e.g. `Llama-2-7b-chat-hf`). + model_name: Option, + + /// Tensor parallel degree. + num_shards: Option, + + /// Maximum kv cache window size. + max_window_size: Option, +} + +impl ChatConfig { + pub fn from_json(json_str: &str) -> Result { + serde_json::from_str(json_str) + } +} + +/// A struct that represents user-defined generation configuration. +/// +/// An instance of [GenerationConfig] can be passed into the +/// [ChatModule::generate](crate::chat_module::ChatModule::generate) function +/// to override the default generation settings specified in `mlc-chat-config.json` +/// and `ChatConfig` under the model folder. +/// +/// Once the generation ends, `GenerationConfig` is discarded, as the values +/// are only intended to override the `ChatConfig` generation settings during a +/// single generation, unless it is recurrently passed to the `generate` function. 
+/// This allows for changing generation settings over time, without permanently +/// overriding the `ChatConfig`. +/// +/// Since the configuration is partial, all fields are optional. +#[derive(Builder, Debug, Default, Serialize, Deserialize)] +#[builder(default)] +pub struct GenerationConfig { + /// The temperature applied to logits before sampling. The default value is + /// `0.7`. A higher temperature encourages more diverse outputs, while a + /// lower temperature produces more deterministic outputs. + temperature: Option, + + /// The repetition penalty controls the likelihood of the model generating + /// repeated texts. The default value is set to `1.0`, indicating that no + /// repetition penalty is applied. Increasing the value reduces the + /// likelihood of repeat text generation. However, setting a high + /// `repetition_penalty` may result in the model generating meaningless + /// texts. The ideal choice of repetition penalty may vary among models. Only + /// Active when presence_penalty and frequency_penalty are both `0.0`. + + /// For more details on how repetition penalty controls text generation, please + /// check out the CTRL paper . + repetition_penalty: Option, + + /// This parameter determines the set of tokens from which we sample during + /// decoding. The default value is set to `0.95`. At each step, we select + /// tokens from the minimal set that has a cumulative probability exceeding + /// the ``top_p` parameter. + + /// For additional information on top-p sampling, please refer to this blog + /// post: . + top_p: Option, + + /// The approximated average number of generated tokens in each round. Used + /// to determine whether the maximum window size would be exceeded. + mean_gen_len: Option, + + /// This parameter determines the maximum length of the generated text. If it is + /// not set, the model will generate text until it encounters a stop token. + max_gen_len: Option, + + /// Number between `-2.0` and `2.0`. Positive values penalize new tokens based on + /// whether they appear in the text so far, increasing the model's likelihood + /// to talk about new topics. Negative values can increase the likelihood of + /// repetition. + presence_penalty: Option, + + /// Number between `-2.0` and `2.0`. Positive values penalize new tokens based on their + /// existing frequency in the text so far, decreasing the model's likelihood to + /// repeat the same line verbatim. Negative values can increase the likelihood of + /// repetition. + frequency_penalty: Option, + + /// This parameter determines the number of text samples to generate. The default + /// value is `1`. Note that this parameter is only used when `stream` is set to + /// `false`. + pub n: Option, + + /// When `stop` is encountered, the model will stop generating output. + /// It can be a string or a list of strings. If it is a list of strings, the model + /// will stop generating output when any of the strings in the list is encountered. + /// Note that this parameter does not override the default stop string of the model. 
+ stop: Option>, +} + +impl GenerationConfig { + pub fn from_chat_config(chat_config: &ChatConfig) -> Self { + Self { + temperature: chat_config.temperature, + repetition_penalty: chat_config.repetition_penalty, + top_p: chat_config.top_p, + mean_gen_len: chat_config.mean_gen_len, + max_gen_len: chat_config.max_gen_len, + presence_penalty: Some(0.0), + frequency_penalty: Some(0.0), + n: Some(0), + stop: None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_conv_config() { + let mut config = ConvConfig { + messages: Some(vec![vec!["User: Hi".to_string(), "Assistant: Hello".to_string()]]), + offset: None, + ..Default::default() + }; + config.post_init(); + assert_eq!(config.offset, Some(1)); + } + + #[test] + fn test_chat_config() { + let json_data = r#" + { + "model_lib": "some_lib", + "local_id": "id123", + "temperature": 0.7 + } + "#; + + let config = ChatConfig::from_json(json_data).unwrap(); + + assert_eq!(config.model_lib, Some("some_lib".to_string())); + assert_eq!(config.local_id, Some("id123".to_string())); + assert_eq!(config.temperature, Some(0.7)); + let _pretty_json = serde_json::to_string_pretty(&config).unwrap(); + } + + #[test] + fn test_generation_config() { + let chat_config = ChatConfigBuilder::default() + .temperature(Some(0.7)) + .top_p(Some(0.8)) + .mean_gen_len(Some(50)) + .max_gen_len(Some(75)) + .build() + .unwrap(); + + let gen_config = GenerationConfig::from_chat_config(&chat_config); + + assert_eq!(gen_config.temperature, chat_config.temperature); + assert_eq!(gen_config.repetition_penalty, chat_config.repetition_penalty); + assert_eq!(gen_config.top_p, chat_config.top_p); + assert_eq!(gen_config.mean_gen_len, chat_config.mean_gen_len); + assert_eq!(gen_config.max_gen_len, chat_config.max_gen_len); + assert_eq!(gen_config.presence_penalty, Some(0.0)); + assert_eq!(gen_config.frequency_penalty, Some(0.0)); + } +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs new file mode 100644 index 0000000..a8315d7 --- /dev/null +++ b/rust/src/lib.rs @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#[macro_use] extern crate derive_builder; + +pub mod chat_module; +pub mod config; diff --git a/scripts/build_site.sh b/scripts/build_site.sh new file mode 100755 index 0000000..6340ee8 --- /dev/null +++ b/scripts/build_site.sh @@ -0,0 +1,9 @@ +#!/bin/bash +set -euxo pipefail + +cd docs && make html && cd .. + +cd site && jekyll b && cd .. 
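+# Replace the Jekyll output's docs folder with the freshly built Sphinx HTML.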
+ +rm -rf site/_site/docs +cp -r docs/_build/html site/_site/docs diff --git a/scripts/check_url_validity.py b/scripts/check_url_validity.py new file mode 100644 index 0000000..3cbb29e --- /dev/null +++ b/scripts/check_url_validity.py @@ -0,0 +1,44 @@ +import requests +import argparse +import re +from pathlib import Path + + +def find_urls_in_file(file_path): + with open(file_path, "r") as file: + content = file.read() + + # Regular expression pattern to match URLs + url_pattern = re.compile( + r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+" + ) + + # Find all matches of URLs in the content + urls = re.findall(url_pattern, content) + return [url.strip(">") for url in urls] + + +def main(): + parser = argparse.ArgumentParser( + description="Check validity of links in documentation" + ) + parser.add_argument( + "--directory", type=str, default="docs", help="Directory of documentation." + ) + args = parser.parse_args() + + # traversal the directory and find all rst files + doc_directory = Path(args.directory) + for file_path in doc_directory.glob("**/*.rst"): + print("Checking {}...".format(file_path)) + for url in find_urls_in_file(file_path): + try: + r = requests.get(url) + if r.status_code == 404: + print("404 not found: {}".format(url)) + except Exception as e: + print("Error connecting {}, error: {}".format(url, e)) + + +if __name__ == "__main__": + main() diff --git a/scripts/gh_deploy_site.sh b/scripts/gh_deploy_site.sh new file mode 100755 index 0000000..1b21c52 --- /dev/null +++ b/scripts/gh_deploy_site.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# NOTE: this script is triggered by github action automatically +# when megred into main + +set -euxo pipefail + +scripts/build_site.sh + +git fetch +git checkout -B gh-pages origin/gh-pages +rm -rf docs .gitignore +mkdir -p docs +cp -rf site/_site/* docs +touch docs/.nojekyll + +DATE=`date` +git add docs && git commit -am "Build at ${DATE}" +git push origin gh-pages +git checkout main && git submodule update +echo "Finish deployment at ${DATE}" diff --git a/scripts/local_deploy_site.sh b/scripts/local_deploy_site.sh new file mode 100755 index 0000000..9e75aae --- /dev/null +++ b/scripts/local_deploy_site.sh @@ -0,0 +1,8 @@ +#!/bin/bash +# NOTE: use this script to check local site + +set -euxo pipefail + +scripts/build_site.sh + +cd site && jekyll serve --skip-initial-build --host localhost --baseurl /mlc-llm --port 8888 diff --git a/scripts/prep_emcc_deps.sh b/scripts/prep_emcc_deps.sh new file mode 100755 index 0000000..2c1306c --- /dev/null +++ b/scripts/prep_emcc_deps.sh @@ -0,0 +1,19 @@ +#!/bin/bash +# This file prepares all the necessary dependencies for the web build. +set -euxo pipefail + +emcc --version +npm --version + +TVM_HOME_SET="${TVM_HOME:-}" + +git submodule update --init --recursive + +if [[ -z ${TVM_HOME_SET} ]]; then + echo "Do not find TVM_HOME env variable, use 3rdparty/tvm". + echo "Make sure you set TVM_HOME in your env variable to use emcc build correctly" + export TVM_HOME="${TVM_HOME:-3rdparty/tvm}" +fi + +cd ${TVM_HOME}/web && make +cd - diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..b972149 --- /dev/null +++ b/setup.py @@ -0,0 +1,47 @@ +from distutils.core import setup +from setuptools.dist import Distribution +from setuptools import find_packages +import os + +# Note there is no need to setup when +# running locally. 
+ +CURRENT_DIR = os.path.dirname(__file__) + + +def git_describe_version(original_version): + """Get git describe version.""" + ver_py = os.path.join(CURRENT_DIR, "version.py") + libver = {"__file__": ver_py} + exec(compile(open(ver_py, "rb").read(), ver_py, "exec"), libver, libver) + _, gd_version = libver["git_describe_version"]() + if gd_version is not None and gd_version != original_version: + print("Use git describe based version %s" % gd_version) + return gd_version + + +__version__ = git_describe_version(None) + +setup( + name="mlc_llm", + version=__version__, + description="MLC LLM: Universal Compilation of Large Language Models", + url="https://llm.mlc.ai/", + author="MLC LLM Contributors", + license="Apache 2.0", + # See https://pypi.org/classifiers/ + classifiers=[ + "License :: OSI Approved :: Apache Software License", + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + ], + keywords="machine learning", + zip_safe=False, + packages=find_packages(), + package_dir={"mlc_llm": "mlc_llm"}, + install_requires=["numpy", "torch", "transformers", "scipy", "timm"], + entry_points={"console_scripts": ["mlc_llm_build = mlc_llm.build:main"]}, + distclass=Distribution, +) diff --git a/site/.gitignore b/site/.gitignore new file mode 100644 index 0000000..51b3599 --- /dev/null +++ b/site/.gitignore @@ -0,0 +1,4 @@ +dist +llm-chat-config.json +_includes/stable_diffusion.html +_site diff --git a/site/CNAME b/site/CNAME new file mode 100644 index 0000000..0b04c40 --- /dev/null +++ b/site/CNAME @@ -0,0 +1 @@ +llm.mlc.ai \ No newline at end of file diff --git a/site/_config.yml b/site/_config.yml new file mode 100644 index 0000000..9806232 --- /dev/null +++ b/site/_config.yml @@ -0,0 +1,42 @@ +name: "MLC LLM" +short_name: "MLC LLM" + +url: https://llm.mlc.ai/ + +exclude: [README.md, serve_local.sh] + +plugins: + - jekyll-remote-theme + +remote_theme: mlc-ai/jekyll-theme-mlc + + +# Colorize code snippets with the rogue module if we want to deploy on GH. +highlighter: rouge + +markdown: kramdown + +# The path structure for blog posts. +permalink: /blog/:year/:month/:day/:title.html + +# Number of news stories on the front page. +front_page_news: 8 + +# Base pathname for links. +base: '' + +# make pages for the _projects folder +collections: + projects: + output: true + +course_title: + +# Navigation bar links. 
+navigation: + - title: Home + link: / + - title: Docs + link: /docs + - title: Github + link: https://github.com/mlc-ai/mlc-llm diff --git a/site/gif/android-demo.gif b/site/gif/android-demo.gif new file mode 100644 index 0000000..aec883f Binary files /dev/null and b/site/gif/android-demo.gif differ diff --git a/site/gif/ios-demo.gif b/site/gif/ios-demo.gif new file mode 100644 index 0000000..7256afe Binary files /dev/null and b/site/gif/ios-demo.gif differ diff --git a/site/gif/linux-demo.gif b/site/gif/linux-demo.gif new file mode 100644 index 0000000..15cfc9d Binary files /dev/null and b/site/gif/linux-demo.gif differ diff --git a/site/img/android/android-diagram.png b/site/img/android/android-diagram.png new file mode 100644 index 0000000..5f49f7c Binary files /dev/null and b/site/img/android/android-diagram.png differ diff --git a/site/img/android/android-studio.png b/site/img/android/android-studio.png new file mode 100644 index 0000000..7c40215 Binary files /dev/null and b/site/img/android/android-studio.png differ diff --git a/site/img/android/android-vs-ios.png b/site/img/android/android-vs-ios.png new file mode 100644 index 0000000..2436797 Binary files /dev/null and b/site/img/android/android-vs-ios.png differ diff --git a/site/img/android/local-advantage.png b/site/img/android/local-advantage.png new file mode 100644 index 0000000..854864f Binary files /dev/null and b/site/img/android/local-advantage.png differ diff --git a/site/img/diag.svg b/site/img/diag.svg new file mode 100644 index 0000000..af9d1c7 --- /dev/null +++ b/site/img/diag.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/site/img/multi-gpu/figure-1.svg b/site/img/multi-gpu/figure-1.svg new file mode 100644 index 0000000..d3083cf --- /dev/null +++ b/site/img/multi-gpu/figure-1.svg @@ -0,0 +1,247 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/site/img/multi-gpu/figure-2.svg b/site/img/multi-gpu/figure-2.svg new file mode 100644 index 0000000..70d35f5 --- /dev/null +++ b/site/img/multi-gpu/figure-2.svg @@ -0,0 +1,418 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/site/img/multi-gpu/figure-3.svg b/site/img/multi-gpu/figure-3.svg new file mode 100644 index 0000000..078231f --- /dev/null +++ b/site/img/multi-gpu/figure-3.svg @@ -0,0 +1,167 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/site/index.md b/site/index.md new file mode 100644 index 0000000..44befd4 --- /dev/null +++ b/site/index.md @@ -0,0 +1,68 @@ +--- +layout: default +title: Home +notitle: true +--- + +# MLC LLM + +MLC LLM is a universal solution that allows any language model to be deployed natively on a diverse set of hardware backends and native applications. + +Please visit [Getting Started](https://llm.mlc.ai/docs/get_started/try_out.html) for detailed instructions. + +## Demos + +- [iOS](#ios) +- [Android](#android) +- [Windows Linux Mac](#windows-linux-mac) +- [Web browser](#web-browser) + +### iOS + +Our iOS app, MLCChat, is available on [App Store](https://apps.apple.com/us/app/mlc-chat/id6448482937) for iPhone and iPad. +You can try out the [Testflight app](https://testflight.apple.com/join/57zd7oxa) that sometimes contains beta release of latest models. +This app is tested on iPhone 15 Pro Max, iPhone 14 Pro Max, iPhone 14 Pro and iPhone 12 Pro. 
+Besides the [Getting Started](https://llm.mlc.ai/docs/get_started/try_out.html) page, +[documentation](https://llm.mlc.ai/docs/deploy/ios.html) is available for building iOS apps with MLC LLM. + + +

+<!-- iOS demo GIF embed (gif/ios-demo.gif) -->
+ +Note: Llama-7B takes 4GB of RAM and RedPajama-3B takes 2.2GB to run. We recommend a recent device with 6GB of RAM for Llama-7B, or 4GB of RAM for RedPajama-3B, to run the app. Text generation speed may vary over time; it can be slow at first and then recover to a normal speed. + +### Android + +The demo APK is available to [download](https://github.com/mlc-ai/binary-mlc-llm-libs/releases/download/Android/mlc-chat.apk). The demo is tested on Samsung S23 with the Snapdragon 8 Gen 2 chip, Redmi Note 12 Pro with Snapdragon 685, and Google Pixel phones. +Besides the [Getting Started](https://llm.mlc.ai/docs/get_started/try_out.html) page, +[documentation](https://llm.mlc.ai/docs/deploy/android.html) is available for building Android apps with MLC LLM. + +

+<!-- Android demo GIF embed (gif/android-demo.gif) -->
+ +### Windows Linux Mac + +Our C++ interface runs on AMD, Intel, Apple, and NVIDIA GPUs. +Besides the [Getting Started](https://llm.mlc.ai/docs/get_started/try_out.html) page, +[documentation](https://llm.mlc.ai/docs/deploy/cli.html) is available for building C++ apps with MLC LLM. + +

+<!-- CLI demo GIF embed (gif/linux-demo.gif) -->
+ +### Web Browser + +[WebLLM](https://webllm.mlc.ai/) is our companion project that deploys MLC LLM natively to browsers using WebGPU and WebAssembly. Still everything runs inside the browser without server resources, and accelerated by local GPUs (e.g. AMD, Intel, Apple or NVIDIA). + +## Links + +* Our official [GitHub repo](https://github.com/mlc-ai/mlc-llm); +* Our companion project [WebLLM](https://webllm.mlc.ai/) that enables running LLMs purely in browser. +* [Web Stable Diffusion](https://websd.mlc.ai/) is another MLC-series that runs the diffusion models purely in the browser. +* [Machine Learning Compilation course](https://mlc.ai) is available for a systematic walkthrough of our approach to universal deployment. + +## Disclaimer + +The pre-packaged demos are subject to the model License. diff --git a/site/privacy.md b/site/privacy.md new file mode 100644 index 0000000..f7f2d29 --- /dev/null +++ b/site/privacy.md @@ -0,0 +1,10 @@ +--- +layout: default +title: Home +notitle: true +--- + +# MLC Chat App Privacy + +MLC Chat run all generation locally. +All data stays in users' device and is not collected by the app. diff --git a/tests/cpp/conv_unittest.cc b/tests/cpp/conv_unittest.cc new file mode 100644 index 0000000..98d01a5 --- /dev/null +++ b/tests/cpp/conv_unittest.cc @@ -0,0 +1,27 @@ +#include +#include + +void _TestConversationJSONRoundTrip(std::string templ_name) { + mlc::llm::Conversation conv = mlc::llm::Conversation::FromTemplate(templ_name); + std::string conv_json = conv.GetConfigJSON(); + mlc::llm::Conversation conv_new; + conv_new.LoadJSONOverride(conv_json, false); + ASSERT_EQ(conv, conv_new); +} + +void _TestConversationPartialUpdate() { + mlc::llm::Conversation conv; + std::string json_str = "{\"offset\": -1}"; + ASSERT_ANY_THROW(conv.LoadJSONOverride(json_str, false)); + conv.LoadJSONOverride(json_str, true); + ASSERT_EQ(conv.offset, -1); +} + +TEST(ConversationTest, ConversationJSONRoundTripTest) { + _TestConversationJSONRoundTrip("vicuna_v1.1"); + _TestConversationJSONRoundTrip("conv_one_shot"); + _TestConversationJSONRoundTrip("redpajama_chat"); + _TestConversationJSONRoundTrip("LM"); +} + +TEST(ConversationTest, ConversationPartialUpdateTest) { _TestConversationPartialUpdate(); } diff --git a/tests/legacy-python/compare_lib.py b/tests/legacy-python/compare_lib.py new file mode 100644 index 0000000..5bcea1e --- /dev/null +++ b/tests/legacy-python/compare_lib.py @@ -0,0 +1,213 @@ +import argparse +import json +import os +from typing import List + +import numpy as np +import torch +import tvm +from transformers import AutoTokenizer, LlamaTokenizer +from tvm import relax, rpc +from tvm.relax.testing.lib_comparator import LibCompareVMInstrument + +from mlc_llm import utils + + +class LibCompare(LibCompareVMInstrument): + def __init__(self, mod, device, time_eval, skip_rounds=0): + super().__init__(mod, device, True) + self.time_eval = time_eval + self.time_eval_results = {} + self.visited = set([]) + self.skip_rounds = skip_rounds + self.atol = 1e-2 + self.rtol = 1e-3 + + def skip_instrument(self, func, name, before_run, ret_val, *args): + print(f"run {name}") + if name.startswith("shape_func"): + return True + if self.counter < self.skip_rounds: + self.counter += 1 + print(f"[{self.counter}] Skip validating {name}..") + return True + if name in self.visited: + if self.time_eval and name in self.time_eval_results: + record = self.time_eval_results[name] + self.time_eval_results[name] = (record[0], record[1] + 1) + return True + self.visited.add(name) + return False 
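+ # The compare() hook below delegates to LibCompareVMInstrument.compare for the
+ # numerical check on the comparison device and, when --time-eval is passed, records
+ # a one-off time_evaluator measurement per kernel; repeat invocations are counted
+ # in skip_instrument above.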
+ + def compare( + self, + name: str, + ref_args: List[tvm.nd.NDArray], + new_args: List[tvm.nd.NDArray], + ret_indices: List[int], + ): + super().compare(name, ref_args, new_args, ret_indices) + + if self.time_eval and name not in self.time_eval_results: + res = self.mod.time_evaluator( + name, self.device, number=20, repeat=3 # , cache_flush_bytes=256 * 10**6 + )(*new_args) + self.time_eval_results[name] = (res.mean, 1) + print(f"Time-eval result {name} on {self.device}: {res}") + + +def print_as_table(sorted_list): + print( + "Name".ljust(50) + + "Time (ms)".ljust(12) + + "Count".ljust(8) + + "Total time (ms)".ljust(18) + + "Percentage (%)" + ) + total_time = sum([record[1][0] * record[1][1] for record in sorted_list]) * 1000 + for record in sorted_list: + time = record[1][0] * 1000 + weighted_time = time * record[1][1] + percentage = weighted_time / total_time * 100 + print( + record[0].ljust(50) + + "{:.4f}".format(time).ljust(12) + + str(record[1][1]).ljust(8) + + "{:.4f}".format(weighted_time).ljust(18) + + "{:.2f}".format(percentage) + ) + print("Total time: {:.4f} ms".format(total_time)) + print() + + +class TestState: + def __init__(self, args): + self.primary_device = tvm.device(args.primary_device) + ex = tvm.runtime.load_module( + os.path.join( + args.artifact_path, + f"{args.model}-{args.quantization.name}-{args.primary_device}.so", + ) + ) + self.vm = relax.VirtualMachine(ex, self.primary_device) + if args.cmp_device == "iphone": + lib_name = f"{args.model}-{args.quantization.name}-{args.cmp_device}.dylib" + local_lib_path = os.path.join(args.artifact_path, lib_name) + proxy_host = os.environ.get("TVM_RPC_PROXY_HOST", "127.0.0.1") + proxy_port = int(os.environ.get("TVM_RPC_PROXY_PORT", "9090")) + self.sess = rpc.connect(proxy_host, proxy_port, "iphone") + self.sess.upload(local_lib_path) + self.lib = self.sess.load_module(lib_name) + self.cmp_device = self.sess.metal() + elif args.cmp_device == "android": + lib_name = f"{args.model}-{args.quantization.name}-{args.cmp_device}.so" + local_lib_path = os.path.join(args.artifact_path, lib_name) + tracker_host = os.environ.get("TVM_TRACKER_HOST", "0.0.0.0") + tracker_port = int(os.environ.get("TVM_TRACKER_PORT", "9190")) + tracker = rpc.connect_tracker(tracker_host, tracker_port) + self.sess = tracker.request("android") + self.sess.upload(local_lib_path) + self.lib = self.sess.load_module(lib_name) + self.cmp_device = self.sess.cl(0) + else: + self.sess = None + self.lib = tvm.runtime.load_module( + os.path.join( + args.artifact_path, + f"{args.model}-{args.quantization.name}-{args.cmp_device}.so", + ) + ) + self.cmp_device = tvm.device(args.cmp_device) + self.const_params_dict = utils.load_params(args.artifact_path, self.primary_device) + self.cmp_instrument = LibCompare( + self.lib, + self.cmp_device, + time_eval=args.time_eval, + skip_rounds=args.skip_rounds, + ) + self.vm.set_instrument(self.cmp_instrument) + + +def deploy_to_pipeline(args) -> None: + with open(os.path.join(args.artifact_path, "params", "mlc-chat-config.json"), "r") as f: + config = json.load(f) + + primary_device = tvm.device(args.primary_device) + const_params = utils.load_params(args.artifact_path, primary_device) + state = TestState(args) + + if config["model_category"] == "llama": + tokenizer = LlamaTokenizer.from_pretrained( + os.path.join(args.artifact_path, "params"), trust_remote_code=True + ) + else: + tokenizer = AutoTokenizer.from_pretrained( + os.path.join(args.artifact_path, "params"), trust_remote_code=True + ) + + print("Tokenizing...") + 
inputs = tvm.nd.array( + tokenizer(args.prompt, return_tensors="pt").input_ids.to(torch.int32).numpy(), + primary_device, + ) + first_sampled_token = tvm.nd.array(np.array([[6234]]).astype("int32"), primary_device) + seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1]]) + second_seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1] + 1]) + kv_caches = state.vm["create_kv_cache"]() + + print("Running inference...") + print("======================= Starts Encoding =======================") + logits, kv_caches = state.vm["prefill"](inputs, seq_len_shape, kv_caches, const_params) + print_as_table( + sorted( + state.cmp_instrument.time_eval_results.items(), + key=lambda x: -(x[1][0] * x[1][1]), + ) + ) + state.cmp_instrument.time_eval_results.clear() + state.cmp_instrument.visited.clear() + print("======================= Starts Decoding =======================") + logits, kv_caches = state.vm["decode"]( + first_sampled_token, second_seq_len_shape, kv_caches, const_params + ) + print_as_table( + sorted( + state.cmp_instrument.time_eval_results.items(), + key=lambda x: -(x[1][0] * x[1][1]), + ) + ) + state.cmp_instrument.time_eval_results.clear() + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument("--local-id", type=str, required=True) + args.add_argument("--artifact-path", type=str, default="dist") + args.add_argument("--primary-device", type=str, default="auto") + args.add_argument("--cmp-device", type=str, required=True) + args.add_argument("--prompt", type=str, default="The capital of Canada is") + args.add_argument("--time-eval", default=False, action="store_true") + args.add_argument("--skip-rounds", type=int, default=0) + parsed = args.parse_args() + parsed.model, parsed.quantization = parsed.local_id.rsplit("-", 1) + utils.argparse_postproc_common(parsed) + + parsed.artifact_path = os.path.join( + parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}" + ) + + if parsed.primary_device == "auto": + if tvm.cuda().exist: + parsed.primary_device = "cuda" + elif tvm.metal().exist: + parsed.primary_device = "metal" + elif tvm.rocm().exist: + parsed.primary_device = "rocm" + else: + raise ValueError("Cannot auto deduce device-name, please set it") + return parsed + + +if __name__ == "__main__": + args = _parse_args() + deploy_to_pipeline(args) diff --git a/tests/legacy-python/dump_intermediate.py b/tests/legacy-python/dump_intermediate.py new file mode 100644 index 0000000..59bcd85 --- /dev/null +++ b/tests/legacy-python/dump_intermediate.py @@ -0,0 +1,172 @@ +"""Debug a model by printing out argument information before and after each function.""" + +import argparse +import json +import os + +import numpy as np +import torch +import tvm +from transformers import AutoTokenizer +from tvm import relax + +from mlc_llm import utils + +# pylint: disable=redefined-outer-name + + +def _extract_metadata(model_lib): + # pylint: disable=import-outside-toplevel + from tvm.runtime import device, load_module + from tvm.runtime.relax_vm import VirtualMachine + + # pylint: enable=import-outside-toplevel + + return json.loads(VirtualMachine(load_module(model_lib), device("cpu"))["_metadata"]()) + + +class DumpInstrument: # pylint: disable=too-few-public-methods + """Defines what to do before and after each function.""" + + def __init__(self, verbose=True): + self.verbose = verbose + self.counter = 0 + self.first_nan_occurred = False + self.first_inf_occurred = False + + def __call__(self, func, name, before_run, ret_val, *args): + # Determine what functions to look at 
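+ # Skip the pre-run hook, TVM VM builtins, and calls whose arguments are not
+ # NDArrays; every remaining call has its output buffer (args[-1]) inspected below.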
+ if before_run: # Whether before the function is called or after + return + # if self.first_nan_occurred: + # return + # if self.first_inf_occurred: + # return + if name.startswith("vm.builtin."): + return + if any(not isinstance(x, tvm.nd.NDArray) for x in args): + return + + # Decide what to print or save about the function's arguments (where args[-1] is the + # buffer we write the result to) + func_name = ( + f"f{self.counter}_before_{name}" if before_run else f"f{self.counter}_after_{name}" + ) + print(func_name) + + # Write your own behavior below. For example, we can count the number of INF/NaN in args[-1] + num_nans = np.sum(np.isnan(args[-1].numpy())) + num_infs = np.sum(np.isinf(args[-1].numpy())) + if num_nans > 0: + print(f"has NaN: {num_nans}") + self.first_nan_occurred = True + if num_infs > 0: + print(f"has INF: {num_infs}") + self.first_inf_occurred = True + + # You can also save the the arguments to experiment offline + # if self.counter == 769: + # for i, ndarray in enumerate(args): + # save_name = func_name + f"_arg{i}" + # np.save(f"./debug/{save_name}.npy", ndarray.numpy()) + + self.counter += 1 + + +def print_as_table(sorted_list): # pylint: disable=missing-function-docstring + # pylint: disable=consider-using-f-string + print( + "Name".ljust(50) + + "Time (ms)".ljust(12) + + "Count".ljust(8) + + "Total time (ms)".ljust(18) + + "Percentage (%)" + ) + total_time = sum([record[1][0] * record[1][1] for record in sorted_list]) * 1000 + for record in sorted_list: + time = record[1][0] * 1000 + weighted_time = time * record[1][1] + percentage = weighted_time / total_time * 100 + print( + record[0].ljust(50) + + "{:.4f}".format(time).ljust(12) + + str(record[1][1]).ljust(8) + + "{:.4f}".format(weighted_time).ljust(18) + + "{:.2f}".format(percentage) + ) + print("Total time: {:.4f} ms".format(total_time)) + print() + + +class TestState: + """Embodies the virtual machine and instrument.""" + + def __init__(self, args): + self.primary_device = tvm.device(args.primary_device) + ex = tvm.runtime.load_module(args.model_lib_path) + self.vm = relax.VirtualMachine(ex, self.primary_device) + self.sess = None + self.instrument = DumpInstrument(verbose=True) + self.vm.set_instrument(self.instrument) + + +def deploy_to_pipeline(args) -> None: + """Main pipeline forst testing; can be modified for specific testing purposes.""" + primary_device = tvm.device(args.primary_device) + model_metadata = _extract_metadata(args.model_lib_path) + const_params = utils.load_params_SLM(args.model, primary_device, model_metadata) + state = TestState(args) + tokenizer = AutoTokenizer.from_pretrained(os.path.join(args.model), trust_remote_code=True) + + print("Tokenizing...") + inputs = tokenizer(args.prompt, return_tensors="pt").input_ids.to(torch.int32).numpy() + first_sampled_token = tvm.nd.array(np.array([[6234]]).astype("int32"), primary_device) + seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1]]) + second_seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1] + 1]) + kv_caches = state.vm["_initialize_effect"]() + + print("Running inference...") + print("======================= Starts Encoding =======================") + + try: + prefill_func = state.vm["prefill"] + except AttributeError: + prefill_func = None + + if inputs.shape[1] > 1 and prefill_func: + inputs = tvm.nd.array(inputs, device=primary_device) + logits, kv_caches = prefill_func(inputs, seq_len_shape, kv_caches, const_params) + else: + for i in range(inputs.shape[1]): + input_slice = tvm.nd.array(inputs[:, i : i + 1], 
device=primary_device) + logits, kv_caches = state.vm["decode"]( + input_slice, seq_len_shape, kv_caches, const_params + ) + + print("======================= Starts Decoding =======================") + logits, kv_caches = state.vm["decode"]( + first_sampled_token, second_seq_len_shape, kv_caches, const_params + ) + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument("--model", type=str, required=True) # The model weight folder + args.add_argument("--model-lib-path", type=str, required=True) # Path to the model library + args.add_argument("--primary-device", type=str, default="auto") # Device to run on + args.add_argument("--prompt", type=str, default="The capital of Canada is") + parsed = args.parse_args() + + if parsed.primary_device == "auto": + if tvm.cuda().exist: + parsed.primary_device = "cuda" + elif tvm.metal().exist: + parsed.primary_device = "metal" + else: + raise ValueError("Cannot auto deduce device-name, please set it") + return parsed + + +if __name__ == "__main__": + args = _parse_args() + deploy_to_pipeline(args) diff --git a/tests/legacy-python/evaluate.py b/tests/legacy-python/evaluate.py new file mode 100644 index 0000000..4a370c5 --- /dev/null +++ b/tests/legacy-python/evaluate.py @@ -0,0 +1,202 @@ +# pylint: disable=invalid-name,missing-docstring +# Used as reference + +import argparse +import json +import os +import time +from typing import List, Tuple + +import numpy as np +import torch +import tvm +from transformers import AutoTokenizer, LlamaTokenizer # type: ignore[import] +from tvm import relax +from tvm.relax.testing.lib_comparator import LibCompareVMInstrument +from tvm.runtime import ShapeTuple + +from mlc_llm import utils + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument("--local-id", type=str, required=True) + args.add_argument("--device-name", type=str, default="auto") + args.add_argument("--debug-dump", action="store_true", default=False) + args.add_argument("--artifact-path", type=str, default="dist") + args.add_argument("--prompt", type=str, default="The capital of Canada is") + args.add_argument("--profile", action="store_true", default=False) + parsed = args.parse_args() + parsed.model, parsed.quantization = parsed.local_id.rsplit("-", 1) + utils.argparse_postproc_common(parsed) + parsed.artifact_path = os.path.join( + parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}" + ) + return parsed + + +class LibCompare(LibCompareVMInstrument): + def __init__(self, mod, device): + super().__init__(mod, device, verbose=False) + self.time_eval_results = {} + + def compare( + self, + name: str, + ref_args: List[tvm.nd.NDArray], + new_args: List[tvm.nd.NDArray], + ret_indices: List[int], + ): + if name.startswith("shape_func"): + return + if name not in self.time_eval_results: + super().compare(name, ref_args, new_args, ret_indices) + res = self.mod.time_evaluator( + name, + dev=self.device, + number=100, + repeat=3, + )(*new_args).mean + shapes = [arg.shape for arg in new_args] + total_bytes = sum(arg.numpy().size * arg.numpy().itemsize for arg in new_args) + self.time_eval_results[name] = (res, 1, shapes, total_bytes) + else: + record = self.time_eval_results[name] + self.time_eval_results[name] = ( + record[0], + record[1] + 1, + record[2], + record[3], + ) + + +def print_as_table(sorted_list: List[Tuple[str, Tuple[float, int]]]): + print( + "Name".ljust(50) + + "Time (ms)".ljust(12) + + "Count".ljust(8) + + "Total time (ms)".ljust(18) + + "Pct (%)".ljust(10) + + "Memory (MB)".ljust(16) + + 
"Bandwidth (GB/s)".ljust(18) + + "Shape" + ) + total_time = sum(record[1][0] * record[1][1] for record in sorted_list) * 1000 + for record in sorted_list: + time_used = record[1][0] * 1000 + weighted_time = time_used * record[1][1] + percentage = weighted_time / total_time * 100 + total_bytes = record[1][3] + bandwidth = total_bytes / record[1][0] / (1024**3) + + print( + record[0].ljust(50) + + f"{time_used:.4f}".ljust(12) + + str(record[1][1]).ljust(8) + + f"{weighted_time:.4f}".ljust(18) + + f"{percentage:.2f}".ljust(10) + + f"{total_bytes / (1024 * 1024):.2f}".ljust(16) + + f"{bandwidth:.4f}".format(bandwidth).ljust(18) + + ", ".join(str(s) for s in record[1][2]) + ) + print(f"Total time: {total_time:.4f} ms") + print() + + +def deploy_to_pipeline(args) -> None: # pylint: disable=too-many-locals + device = tvm.device(args.device_name) + const_params = utils.load_params(args.artifact_path, device) + ex = tvm.runtime.load_module( + os.path.join( + args.artifact_path, + f"{args.model}-{args.quantization.name}-{args.device_name}.so", + ) + ) + vm = relax.VirtualMachine(ex, device) + + with open( + os.path.join(args.artifact_path, "params", "mlc-chat-config.json"), + "r", + encoding="utf-8", + ) as f: + config = json.load(f) + + if config["model_category"] == "llama": + tokenizer = LlamaTokenizer.from_pretrained( + os.path.join(args.artifact_path, "params"), trust_remote_code=True + ) + else: + tokenizer = AutoTokenizer.from_pretrained( + os.path.join(args.artifact_path, "params"), trust_remote_code=True + ) + + print("Tokenizing...") + inputs = tvm.nd.array( + tokenizer(args.prompt, return_tensors="pt").input_ids.to(torch.int32).numpy(), + device, + ) + first_sampled_token = tvm.nd.array(np.array([[6234]]).astype("int32"), device) + seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1]]) + second_seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1] + 1]) + kv_caches = vm["create_kv_cache"]() + # skip warm up + + logits, kv_caches = vm["prefill"](inputs, seq_len_shape, kv_caches, const_params) + logits, kv_caches = vm["decode"]( + first_sampled_token, second_seq_len_shape, kv_caches, const_params + ) + device.sync() + + kv_caches = vm["create_kv_cache"]() + print("Running inference...") + start = time.time() + logits, kv_caches = vm["prefill"](inputs, seq_len_shape, kv_caches, const_params) + device.sync() + encoding_end = time.time() + logits, kv_caches = vm["decode"]( + first_sampled_token, second_seq_len_shape, kv_caches, const_params + ) + device.sync() + end = time.time() + if args.debug_dump: + fcache_view = tvm.get_global_func("vm.builtin.attention_kv_cache_view") + first_k_cache = fcache_view(kv_caches[0], ShapeTuple([7, 32, 128])) + print(f"output kv_cache[0]:\n{first_k_cache.numpy().transpose(1, 0, 2)}") + print(f"output logits:\n{logits.numpy()}") + print( + f"Time elapsed: encoding {(encoding_end - start)} seconds, " + f"decoding {end - encoding_end} secs" + ) + + if args.profile: + cmp_instrument = LibCompare(ex, device) + vm.set_instrument(cmp_instrument) + + print("Profiling...") + kv_caches = vm["create_kv_cache"]() + + logits, kv_caches = vm["prefill"](inputs, seq_len_shape, kv_caches, const_params) + print("======================= Encoding Profiling =======================") + print_as_table( + sorted( + cmp_instrument.time_eval_results.items(), + key=lambda x: -(x[1][0] * x[1][1]), + ) + ) + cmp_instrument.time_eval_results.clear() + + logits, kv_caches = vm["decode"]( + first_sampled_token, second_seq_len_shape, kv_caches, const_params + ) + 
print("======================= Decoding Profiling =======================") + print_as_table( + sorted( + cmp_instrument.time_eval_results.items(), + key=lambda x: -(x[1][0] * x[1][1]), + ) + ) + + +if __name__ == "__main__": + ARGS = _parse_args() + deploy_to_pipeline(ARGS) diff --git a/tests/legacy-python/module_intercept.py b/tests/legacy-python/module_intercept.py new file mode 100644 index 0000000..e63bb21 --- /dev/null +++ b/tests/legacy-python/module_intercept.py @@ -0,0 +1,147 @@ +"""This script is an example of running and comparing the outputs of two different TVM Relax VMs. +""" +# pylint: disable=missing-docstring,invalid-name +import json + +import numpy as np +import torch +import tvm +from transformers import LlamaTokenizer +from tvm import relax +from tvm.contrib import tvmjs + +KVCACHE_FUNCS = [ + "vm.builtin.attention_kv_cache_append", + "vm.builtin.attention_kv_cache_view", +] +DEVICE = "cuda:0" +PROMPT = "What is the meaning of life?" +TOKENIZER = "./dist/debug-llama/" + +COMBO = { + "CURRENT": { + "model_lib": "./dist/debug-llama/llama.so", + "params": "./dist/debug-llama", + "target_func": "fused_fused_dequantize1_NT_matmul6", + }, + "LEGACY": { + "model_lib": "./dist/Llama-2-7b-chat-hf-q4f16_1/Llama-2-7b-chat-hf-q4f16_1-cuda.so", + "params": "./dist/Llama-2-7b-chat-hf-q4f16_1/params", + "target_func": "fused_fused_decode2_NT_matmul", + }, +} + + +class Instrument: # pylint: disable=too-few-public-methods + def __init__( + self, + target_func: str, + ): + self.first_time = True + self.target_func = target_func + self.saved_args = [] # type: ignore + + def __call__( + self, + func, + func_symbol: str, + before_run: bool, + ret_value, + *args, + ): + if before_run: + return + if func_symbol.startswith("vm.builtin."): + if func_symbol not in KVCACHE_FUNCS: + return + if func_symbol == self.target_func and self.first_time: + self.first_time = False + for arg in args: + print(arg.shape, arg.dtype) + self.saved_args.append(arg.numpy()) + + +class TestState: + def __init__(self, device, model_lib, target_func): + self.mod = relax.VirtualMachine( + tvm.runtime.load_module(model_lib), + device, + ) + self.inst = Instrument(target_func=target_func) + self.mod.set_instrument(self.inst) + + +def _tokenize(sentence: str): + tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER, trust_remote_code=True) + tokens = tokenizer(PROMPT, return_tensors="pt").input_ids.to(torch.int32).numpy() + print(f"Tokenizing: {sentence}") + print(f"Tokens: {tokens}") + return tokens + + +def _load_params(params, device, metadata): + param_dict, _ = tvmjs.load_ndarray_cache(params, device) + param_list = [] + for name in [x["name"] for x in metadata["params"]]: + param_list.append(param_dict[name]) + return param_list + + +def _load_params_legacy(params, device): + param_dict, metadata = tvmjs.load_ndarray_cache(params, device) + param_list = [] + for i in range(metadata["ParamSize"]): + param_list.append(param_dict[f"param_{i}"]) + return param_list + + +def _as_input_tuple(scalar): + return tvm.runtime.ShapeTuple([scalar]) + + +@tvm.register_func("debug_save") +def _debug_save(x, _): + return tvm.nd.array(x.numpy(), x.device) + + +def main() -> None: + device = tvm.device(DEVICE) + prompt = _tokenize(PROMPT) + + def _run_legacy(model_lib, params, target_func): + state = TestState(device, model_lib, target_func) + kv_cache = state.mod["create_kv_cache"]() + param_list = _load_params_legacy(params, device) + state.mod["prefill"]( + tvm.nd.array(prompt, device), + _as_input_tuple(len(prompt[0])), + 
kv_cache, + param_list, + ) + return state.inst.saved_args + + def _run_current(model_lib, params, target_func): + state = TestState(device, model_lib, target_func) + metadata = json.loads(state.mod["_metadata"]()) + kv_cache = state.mod["_initialize_effect"]() + param_list = _load_params(params, device, metadata) + state.mod["prefill"]( + tvm.nd.array(prompt, device), + _as_input_tuple(len(prompt[0])), + kv_cache, + param_list, + ) + return state.inst.saved_args + + print("============== Running new flow =================") + new_args = _run_current(**COMBO["CURRENT"]) + print("============== Running old flow =================") + old_args = _run_legacy(**COMBO["LEGACY"]) + + for i, (new_arg, old_arg) in enumerate(zip(new_args, old_args)): + print(f"Checking arg {i}") + np.testing.assert_allclose(new_arg, old_arg, rtol=1e-12, atol=1e-12) + + +if __name__ == "__main__": + main() diff --git a/tests/legacy-python/test_batching_llama.py b/tests/legacy-python/test_batching_llama.py new file mode 100644 index 0000000..ff11188 --- /dev/null +++ b/tests/legacy-python/test_batching_llama.py @@ -0,0 +1,160 @@ +# pylint: disable=invalid-name,missing-docstring +# Used as reference + +import argparse +import json +import os + +import numpy as np +import torch +import tvm +from transformers import LlamaTokenizer # type: ignore[import] +from tvm import relax +from tvm.runtime import ShapeTuple + +from mlc_llm import utils + +############################################################## +# Test file for e2e Llama with batching enabled by directly +# calling functions in VM. +# +# NOTE: the test will not be runnable until the attention +# compute function is integrated into Llama. This is left as +# an item that we will work on in the near future. +############################################################## + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument("--local-id", type=str, default="Llama-2-7b-chat-hf-q4f16_1") + args.add_argument("--device-name", type=str, default="auto") + args.add_argument("--artifact-path", type=str, default="dist") + args.add_argument("--prompt", type=str, default="What's the meaning of life?") + args.add_argument("--profile", action="store_true", default=False) + parsed = args.parse_args() + parsed.model, parsed.quantization = parsed.local_id.rsplit("-", 1) + utils.argparse_postproc_common(parsed) + parsed.artifact_path = os.path.join( + parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}" + ) + return parsed + + +def sample_from_logits(vm, logits, device): + temperature = 0.7 + top_p = 0.95 + + num_sequence = logits.shape[0] + temperature_arr = tvm.nd.array(np.full((num_sequence,), temperature, dtype="float32"), device) + probs = vm["softmax_with_temperature"](logits, temperature_arr).numpy() + + sampled_tokens = [] + fsample_top_p_from_prob = tvm.get_global_func("vm.builtin.sample_top_p_from_prob") + for seq_id in range(num_sequence): + token = fsample_top_p_from_prob(tvm.nd.array(probs[seq_id]), top_p, np.random.sample()) + sampled_tokens.append(token) + return sampled_tokens + + +def deploy_to_pipeline(args) -> None: # pylint: disable=too-many-locals + device = tvm.device(args.device_name) + const_params = utils.load_params(args.artifact_path, device) + ex = tvm.runtime.load_module( + os.path.join( + args.artifact_path, + f"{args.model}-{args.quantization.name}-{args.device_name}.so", + ) + ) + vm = relax.VirtualMachine(ex, device) + + with open( + os.path.join(args.artifact_path, "params", "mlc-chat-config.json"), + "r", + 
encoding="utf-8", + ) as f: + config = json.load(f) + + assert config["model_category"] == "llama" + tokenizer = LlamaTokenizer.from_pretrained( + os.path.join(args.artifact_path, "params"), trust_remote_code=True + ) + + num_sequences = 4 + generated_tokens = [[], [], [], []] + prompts = [ + "What's the meaning of life?", + "Introduce the history of Pittsburgh to me.", + "Write a three-day Seattle travel plan.", + "What is Alaska famous of?", + ] + num_decode_steps = 256 + + print("Create KV cache...") + max_total_seq_len = 16384 + page_size = 16 + kv_cache = vm["create_kv_cache"](ShapeTuple([num_sequences, max_total_seq_len, page_size])) + + fadd_sequence = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_add_sequence") + freset_append_length = tvm.get_global_func( + "vm.builtin.paged_attention_kv_cache_reset_append_lengths" + ) + freserve = tvm.get_global_func( + "vm.builtin.paged_attention_kv_cache_reserve_extra_length_for_append" + ) + fsync = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_sync_aux_array_to_device") + + for seq_id in range(num_sequences): + print(f"Process seq {seq_id} for prefill...") + inputs = tvm.nd.array( + tokenizer(prompts[seq_id], return_tensors="pt").input_ids.to(torch.int32).numpy(), + device, + ) + seq_length = inputs.shape[1] + embedding = vm["embed"](inputs, const_params) + + seq_id_in_cache = fadd_sequence(kv_cache) + assert seq_id_in_cache == seq_id + + freset_append_length(kv_cache) + freserve(kv_cache, seq_id, seq_length) + fsync(kv_cache) + + print(f"Prefilling seq {seq_id}...") + logits, _ = vm["prefill_with_embed"](embedding, kv_cache, const_params) + + tokens = sample_from_logits(vm, logits, device) + assert len(tokens) == 1 + generated_tokens[seq_id].append(tokens[0]) + + print("Decoding...") + for step in range(num_decode_steps): + inputs = tvm.nd.array( + np.array( + [[generated_tokens[seq_id][-1]] for seq_id in range(num_sequences)], dtype="int32" + ), + device, + ) + embedding = vm["embed"](inputs, const_params) + freset_append_length(kv_cache) + for seq_id in range(num_sequences): + freserve(kv_cache, seq_id, 1) + fsync(kv_cache) + + logits, _ = vm["decode_with_embed"](embedding, kv_cache, const_params) + tokens = sample_from_logits(vm, logits, device) + assert len(tokens) == num_sequences + + for seq_id in range(num_sequences): + generated_tokens[seq_id].append(tokens[seq_id]) + + for seq_id in range(num_sequences): + output = tokenizer.decode(generated_tokens[seq_id]) + print("====================================================================") + print(f"Prompt {seq_id}: {prompts[seq_id]}") + print(f"Output: {output}") + print("\n\n") + + +if __name__ == "__main__": + ARGS = _parse_args() + deploy_to_pipeline(ARGS) diff --git a/tests/legacy-python/test_build_args.py b/tests/legacy-python/test_build_args.py new file mode 100644 index 0000000..8f32d12 --- /dev/null +++ b/tests/legacy-python/test_build_args.py @@ -0,0 +1,175 @@ +"""For testing the functionality of `BuildArgs` and `convert_build_args_to_argparser`.""" +import argparse +import dataclasses +import unittest + +from mlc_llm import BuildArgs, core, utils + + +def old_make_args(): + """The exact old way of creating `ArgumentParser`, used to test whether + `BuildArgs` is equivalent to this.""" + args = argparse.ArgumentParser() + args.add_argument( + "--model", + type=str, + default="auto", + help=( + 'The name of the model to build. 
If it is "auto", we will ' + 'automatically set the model name according to "--model-path", ' + '"hf-path" or the model folders under "--artifact-path/models"' + ), + ) + args.add_argument( + "--hf-path", + type=str, + default=None, + help="Hugging Face path from which to download params, tokenizer, and config", + ) + args.add_argument( + "--quantization", + type=str, + choices=[*utils.quantization_schemes.keys()], + default=list(utils.quantization_schemes.keys())[0], + help="The quantization mode we use to compile.", + ) + args.add_argument( + "--max-seq-len", + type=int, + default=-1, + help="The maximum allowed sequence length for the model.", + ) + args.add_argument( + "--target", type=str, default="auto", help="The target platform to compile the model for." + ) + args.add_argument( + "--reuse-lib", + type=str, + default=None, + help="Whether to reuse a previously generated lib.", + ) + args.add_argument( + "--artifact-path", type=str, default="dist", help="Where to store the output." + ) + args.add_argument( + "--use-cache", + type=int, + default=1, + help="Whether to use previously pickled IRModule and skip trace.", + ) + args.add_argument( + "--debug-dump", + action="store_true", + default=False, + help="Whether to dump debugging files during compilation.", + ) + args.add_argument( + "--debug-load-script", + action="store_true", + default=False, + help="Whether to load the script for debugging.", + ) + args.add_argument( + "--llvm-mingw", + type=str, + default="", + help="/path/to/llvm-mingw-root, use llvm-mingw to cross compile to windows.", + ) + args.add_argument( + "--system-lib", action="store_true", default=False, help="A parameter to `relax.build`." + ) + args.add_argument( + "--sep-embed", + action="store_true", + default=False, + help=( + "Build with separated embedding layer, only applicable to LlaMa. " + "This feature is in testing stage, and will be formally replaced after " + "massive overhaul of embedding feature for all models and use cases" + ), + ) + + return args + + +# Referred to HfArgumentParserTest from https://github.com/huggingface/ +# transformers/blob/e84bf1f734f87aa2bedc41b9b9933d00fc6add98/tests/utils +# /test_hf_argparser.py#L143 +class BuildArgsTest(unittest.TestCase): + """Tests whether BuildArgs reaches parity with regular ArgumentParser.""" + + def argparsers_equal(self, parse_a: argparse.ArgumentParser, parse_b: argparse.ArgumentParser): + """ + Small helper to check pseudo-equality of parsed arguments on `ArgumentParser` instances. 
+ """ + self.assertEqual( + len(parse_a._actions), len(parse_b._actions) + ) # pylint: disable=protected-access + for x, y in zip(parse_a._actions, parse_b._actions): # pylint: disable=protected-access + xx = {k: v for k, v in vars(x).items() if k != "container"} + yy = {k: v for k, v in vars(y).items() if k != "container"} + # Choices with mixed type have custom function as "type" + # So we need to compare results directly for equality + if xx.get("choices", None) and yy.get("choices", None): + for expected_choice in yy["choices"] + xx["choices"]: + self.assertEqual(xx["type"](expected_choice), yy["type"](expected_choice)) + del xx["type"], yy["type"] + + self.assertEqual(xx, yy) + + def test_new_and_old_arg_parse_are_equivalent(self): + """Tests whether creating `ArgumentParser` from `BuildArgs` is equivalent + to the conventional way of creating it.""" + self.argparsers_equal(core.convert_build_args_to_argparser(), old_make_args()) + + def test_namespaces_are_equivalent_str(self): + """Tests whether the resulting namespaces from command line entry + and Python API entry are equivalent, as they are passed down to the + same workflow.""" + # Namespace that would be created through Python API build_model + build_args = BuildArgs(model="RedPJ", target="cuda") + build_args_as_dict = dataclasses.asdict(build_args) + build_args_namespace = argparse.Namespace(**build_args_as_dict) + + # Namespace that would be created through commandline + empty_args = core.convert_build_args_to_argparser() + parsed_args = empty_args.parse_args(["--model", "RedPJ", "--target", "cuda"]) + + self.assertEqual(build_args_namespace, parsed_args) + + # Modify build_args so that it would not be equivalent + build_args = BuildArgs(model="RedPJ", target="vulkan") + build_args_as_dict = dataclasses.asdict(build_args) + build_args_namespace = argparse.Namespace(**build_args_as_dict) + + self.assertNotEqual(build_args_namespace, parsed_args) + + def test_namespaces_are_equivalent_str_boolean_int(self): + """Same test, but for a mixture of argument types.""" + # 1. Equal + build_args = BuildArgs(model="RedPJ", max_seq_len=20, debug_dump=True) + build_args_as_dict = dataclasses.asdict(build_args) + build_args_namespace = argparse.Namespace(**build_args_as_dict) + + # Namespace that would be created through commandline + empty_args = core.convert_build_args_to_argparser() + parsed_args = empty_args.parse_args( + ["--model", "RedPJ", "--max-seq-len", "20", "--debug-dump"] + ) + self.assertEqual(build_args_namespace, parsed_args) + + # 2. Not equal - missing boolean + build_args = BuildArgs(model="RedPJ", max_seq_len=20) + build_args_as_dict = dataclasses.asdict(build_args) + build_args_namespace = argparse.Namespace(**build_args_as_dict) + self.assertNotEqual(build_args_namespace, parsed_args) + + # 3. 
Not equal - different integer + build_args = BuildArgs(model="RedPJ", max_seq_len=18, debug_dump=True) + build_args_as_dict = dataclasses.asdict(build_args) + build_args_namespace = argparse.Namespace(**build_args_as_dict) + self.assertNotEqual(build_args_namespace, parsed_args) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/legacy-python/test_build_model_from_args.py b/tests/legacy-python/test_build_model_from_args.py new file mode 100644 index 0000000..b342e03 --- /dev/null +++ b/tests/legacy-python/test_build_model_from_args.py @@ -0,0 +1,142 @@ +import argparse +import os +import unittest +from unittest.mock import MagicMock, mock_open, patch + +from mlc_llm import utils +from mlc_llm.core import build_model_from_args + + +class MockMkdir(object): + def __init__(self): + self.received_args = None + + def __call__(self, *args): + self.received_args = args + + +class BuildModelTest(unittest.TestCase): + def setUp(self): + self._orig_mkdir = os.mkdir + os.mkdir = MockMkdir() + + self.mock_args = argparse.Namespace() + self.mock_args.quantization = utils.quantization_schemes["q8f16_1"] + self.mock_args.debug_dump = False + self.mock_args.use_cache = False + self.mock_args.sep_embed = False + self.mock_args.build_model_only = True + self.mock_args.use_safetensors = False + self.mock_args.convert_weights_only = False + self.mock_args.no_cutlass_attn = True + self.mock_args.no_cutlass_norm = True + self.mock_args.reuse_lib = True + self.mock_args.artifact_path = "/tmp/" + self.mock_args.model_path = "/tmp/" + self.mock_args.model = "/tmp/" + self.mock_args.target_kind = "cuda" + self.mock_args.max_seq_len = 2048 + + def tearDown(self): + os.mkdir = self._orig_mkdir + + @patch("builtins.open", new_callable=mock_open, read_data="data") + @patch("json.load", MagicMock(side_effect=[{}])) + def test_llama_model(self, mock_file): + self.mock_args.model_category = "llama" + + build_model_from_args(self.mock_args) + + @patch("builtins.open", new_callable=mock_open, read_data="data") + @patch( + "json.load", + MagicMock( + side_effect=[ + { + "use_parallel_residual": False, + "hidden_size": 32, + "intermediate_size": 32, + "num_attention_heads": 32, + "num_hidden_layers": 28, + "vocab_size": 1024, + "rotary_pct": 1, + "rotary_emb_base": 1, + "layer_norm_eps": 1, + } + ] + ), + ) + def test_gpt_neox_model(self, mock_file): + self.mock_args.model_category = "gpt_neox" + self.mock_args.model = "dolly-test" + + build_model_from_args(self.mock_args) + + @patch("builtins.open", new_callable=mock_open, read_data="data") + @patch("json.load", MagicMock(side_effect=[{}])) + def test_gpt_bigcode_model(self, mock_file): + self.mock_args.model_category = "gpt_bigcode" + self.mock_args.model = "gpt_bigcode" + + build_model_from_args(self.mock_args) + + @patch("builtins.open", new_callable=mock_open, read_data="data") + @patch("json.load", MagicMock(side_effect=[{}])) + def test_minigpt_model(self, mock_file): + self.mock_args.model_category = "minigpt" + self.mock_args.model = "minigpt4-7b" + + build_model_from_args(self.mock_args) + + @patch("builtins.open", new_callable=mock_open, read_data="data") + @patch( + "json.load", + MagicMock( + side_effect=[ + { + "vocab_size": 1024, + "n_embd": 32, + "n_inner": 32, + "n_head": 32, + "n_layer": 28, + "bos_token_id": 28, + "eos_token_id": 1, + "rotary_dim": 1, + "tie_word_embeddings": 1, + } + ] + ), + ) + def test_gptj_model(self, mock_file): + self.mock_args.model_category = "gptj" + self.mock_args.model = "gpt-j-" + + 
build_model_from_args(self.mock_args) + + @patch("builtins.open", new_callable=mock_open, read_data="data") + @patch( + "json.load", + MagicMock( + side_effect=[ + { + "num_hidden_layers": 16, + "vocab_size": 1024, + "hidden_size": 16, + "intermediate_size": 32, + } + ] + ), + ) + def test_rwkv_model(self, mock_file): + self.mock_args.model_category = "rwkv" + self.mock_args.model = "rwkv-" + + build_model_from_args(self.mock_args) + + @patch("builtins.open", new_callable=mock_open, read_data="data") + @patch("json.load", MagicMock(side_effect=[{}])) + def test_chatglm_model(self, mock_file): + self.mock_args.model_category = "chatglm" + self.mock_args.model = "chatglm2" + + build_model_from_args(self.mock_args) diff --git a/tests/legacy-python/test_sliding_window_mask.py b/tests/legacy-python/test_sliding_window_mask.py new file mode 100644 index 0000000..51be2d0 --- /dev/null +++ b/tests/legacy-python/test_sliding_window_mask.py @@ -0,0 +1,338 @@ +# fmt: off +"""For testing `_make_sliding_window_mask` in mistral.py""" + +import unittest + +import numpy as np +import tvm +from mlc_llm.relax_model.mistral import _make_sliding_window_mask +from tvm import relax +from tvm.runtime import ShapeTuple + + +def _create_vm(): + # pylint: disable=too-many-locals + bb = relax.BlockBuilder() + + # Step 1: Build `_make_sliding_window_mask()` into an IRModule + bsz = tvm.tir.Var("bsz", "int64") + seq_length = tvm.tir.Var("seq_length", "int64") # tgt_len + kv_seq_len = tvm.tir.Var("kv_seq_len", "int64") + sliding_window = tvm.tir.Var("sliding_window", "int64") + + with bb.function("main"): + # Convert to relax.Var because params to an IRModule function needs to be relax.Var + bsz_shape = relax.Var("bsz", relax.ShapeStructInfo((bsz,))) + seq_length_shape = relax.Var("seq_length", relax.ShapeStructInfo((seq_length,))) + kv_seq_len_shape = relax.Var("kv_seq_len", relax.ShapeStructInfo((kv_seq_len,))) + sliding_window_shape = relax.Var("sliding_window", relax.ShapeStructInfo((sliding_window,))) + + # Convert back to tir.Var since `_prepare_sliding_window_mask` needs it to be tir.Var + with bb.dataflow(): + bsz_input = bsz_shape.struct_info.values[0] + seq_length_input = seq_length_shape.struct_info.values[0] + kv_seq_len_input = kv_seq_len_shape.struct_info.values[0] + sliding_window_input = sliding_window_shape.struct_info.values[0] + mask = _make_sliding_window_mask( + (bsz_input, seq_length_input), + kv_seq_len_input, + sliding_window_input, + "float32", + ) + params = [ + bsz_shape, + seq_length_shape, + kv_seq_len_shape, + sliding_window_shape, + ] + gv = bb.emit_output(mask) + bb.emit_func_output(gv, params) + + # Step 2. Optimize IRModule + mod = bb.get() + mod = relax.pipeline.get_pipeline()(mod) # pylint: disable=no-value-for-parameter + with tvm.target.Target("cuda"): + mod = tvm.tir.transform.DefaultGPUSchedule()(mod) + + # Step 3. Deploy to GPU + ex = relax.build(mod, "cuda") + vm = relax.VirtualMachine(ex, tvm.cuda()) #pylint: disable=redefined-outer-name + return vm + + +vm = _create_vm() + +class SlidingWindowMaskTest(unittest.TestCase): + """ + The sliding window mask is based on figure 3 of the Mistral paper. + There are three cases when making a mask: first prefill, subsequent prefill, + and decoding. + + 1. First Prefill + This is when the cache is empty (i.e. kv_seq_len == 0). If tgt_len <= sliding_window, + this is just a normal causal mask. Otherwise, e.g. tgt_len = 3, WS = 2, we create a + mask below: + 1, 0, 0 + 1, 1, 0 + 0, 1, 1 + + 2. 
Subsequent Prefill + This is when the cache is not empty and yet tgt_len > 1. + e.g. t0-t4 in cache; current input is t5-t7; WS=5 + 0, 1, 2, 3, 4, | 5, 6, 7 + + 0, 1, 1, 1, 1, | 1, 0, 0 + 0, 0, 1, 1, 1, | 1, 1, 0 + 0, 0, 0, 1, 1, | 1, 1, 1 + [in cache] [current] + + 3. Decode + It will always be ones with shape (1 + kv_seq_len) since cache_size equals sliding_window. + Note that a prefilling (first or subsequent) with chunk_size of 1 is equivalent to a decode + in mask making. + """ + + ################### 1. TESTS FOR FIRST PREFILL ################### + def test_first_prefill_chunk_size_smaller_than_ws(self): + """ + When chunk size < WS, we return a normal causal mask. + Here, chunk size 3, WS 5. + """ + bsz = ShapeTuple([1]) + seq_length = ShapeTuple([3]) # chunk size is 3 + kv_seq_len = ShapeTuple([3]) + sliding_window = ShapeTuple([5]) + + result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) + + correct = np.array([[[ + [3.402823e38, -3.402823e38, -3.402823e38], + [3.402823e38, 3.402823e38, -3.402823e38], + [3.402823e38, 3.402823e38, 3.402823e38], + ]]]).astype("float32") + + np.testing.assert_array_equal(result.numpy(), correct) + + def test_first_prefill_chunk_size_equals_ws(self): + """ + When chunk_size == WS, we also return a normal causal mask. + Here both chunk size and WS are 5. + """ + bsz = ShapeTuple([1]) + seq_length = ShapeTuple([5]) + kv_seq_len = ShapeTuple([5]) + sliding_window = ShapeTuple([5]) + + result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) + + correct = np.array([[[ + [3.402823e38, -3.402823e38, -3.402823e38, -3.402823e38, -3.402823e38], + [3.402823e38, 3.402823e38, -3.402823e38, -3.402823e38, -3.402823e38], + [3.402823e38, 3.402823e38, 3.402823e38, -3.402823e38, -3.402823e38], + [3.402823e38, 3.402823e38, 3.402823e38, 3.402823e38, -3.402823e38], + [3.402823e38, 3.402823e38, 3.402823e38, 3.402823e38, 3.402823e38], + ]]]).astype("float32") + + np.testing.assert_array_equal(result.numpy(), correct) + + def test_first_prefill_chunk_size_greater_than_ws(self): + """ + When chunk_size > WS, return a normal causal mask but each row only has at most WS 1's. + Here chunk_size = 5, WS=3. + """ + bsz = ShapeTuple([1]) + seq_length = ShapeTuple([5]) + kv_seq_len = ShapeTuple([5]) + sliding_window = ShapeTuple([3]) + + result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) + + correct = np.array([[[ + [3.402823e38, -3.402823e38, -3.402823e38, -3.402823e38, -3.402823e38], + [3.402823e38, 3.402823e38, -3.402823e38, -3.402823e38, -3.402823e38], + [3.402823e38, 3.402823e38, 3.402823e38, -3.402823e38, -3.402823e38], + [-3.402823e38, 3.402823e38, 3.402823e38, 3.402823e38, -3.402823e38], + [-3.402823e38, -3.402823e38, 3.402823e38, 3.402823e38, 3.402823e38], + ]]]).astype("float32") + + np.testing.assert_array_equal(result.numpy(), correct) + + def test_first_prefill_chunk_size_one(self): + """ + Corner case: the prompt only has 1 token. + """ + bsz = ShapeTuple([1]) + seq_length = ShapeTuple([1]) + kv_seq_len = ShapeTuple([1]) + sliding_window = ShapeTuple([3]) + + result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) + + correct = np.array([[[ + [3.402823e38] + ]]]).astype("float32") + + np.testing.assert_array_equal(result.numpy(), correct) + + ################### 2. TESTS FOR SUBSEQUENT PREFILL ################### + def test_subsequent_prefill_1(self): + """ + Test 1: chunk size is 3, WS is 5, cache carrying t0, t1, t2; input t3, t4, t5. 
+ """ + + bsz = ShapeTuple([1]) + seq_length = ShapeTuple([3]) + kv_seq_len = ShapeTuple([6]) + sliding_window = ShapeTuple([5]) + + result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) + + correct = np.array([[[ + # pylint: disable=line-too-long + # | IN CACHE | CURRENT CHUNK | + # t0 t1 t2 t3 t4 t5 + [ 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38], + [ 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38], + [-3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38] + ]]]).astype("float32") + + np.testing.assert_array_equal(result.numpy(), correct) + + def test_subsequent_prefill_2(self): + """ + Test 2: chunk size is 3, WS is 5, cache carrying t1 - t5 (t0 is overwritten); + input t6, t7, t8. + """ + bsz = ShapeTuple([1]) + seq_length = ShapeTuple([3]) + kv_seq_len = ShapeTuple([8]) + sliding_window = ShapeTuple([5]) + + result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) + + correct = np.array([[[ + # pylint: disable=line-too-long + # | IN CACHE | CURRENT CHUNK | + # t1 t2 t3 t4 t5 t6 t7 t8 + [-3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38], + [-3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38], + [-3.402823e+38, -3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38] + ]]]).astype("float32") + + np.testing.assert_array_equal(result.numpy(), correct) + + def test_subsequent_prefill_3(self): + """ + Test 3: chunk size is 5, WS is 5, cache carrying t0-t4; input t5-t9. + """ + bsz = ShapeTuple([1]) + seq_length = ShapeTuple([5]) + kv_seq_len = ShapeTuple([10]) + sliding_window = ShapeTuple([5]) + + result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) + + correct = np.array([[[ + # pylint: disable=line-too-long + # | IN CACHE | CURRENT CHUNK | + # t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 + [-3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38], + [-3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38], + [-3.402823e+38, -3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38], + [-3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38], + [-3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38] + ]]]).astype("float32") + + np.testing.assert_array_equal(result.numpy(), correct) + + def test_subsequent_prefill_4(self): + """ + Test 4: chunk size is 5, WS is 3, cache carrying t2-t4 (t0, t1 did not + stay in cache); input t5-t9. 
+ """ + bsz = ShapeTuple([1]) + seq_length = ShapeTuple([5]) + kv_seq_len = ShapeTuple([8]) + sliding_window = ShapeTuple([3]) + + result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) + + correct = np.array([[[ + # pylint: disable=line-too-long + # | IN CACHE | CURRENT CHUNK | + # t2 t3 t4 t5 t6 t7 t8 t9 + [-3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38], + [-3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38], + [-3.402823e+38, -3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38], + [-3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38], + [-3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38] + ]]]).astype("float32") + + np.testing.assert_array_equal(result.numpy(), correct) + + def test_subsequent_prefill_5(self): + """ + Test 5: chunk size is 5, WS is 5, cache carrying t5-t9 (t0-t4 overwritten); + input t10 (remainder of a prompt). Note that this test can also be + viewed as a decode. That is, prefilling a chunk of size 1, is the same is decoding. + """ + bsz = ShapeTuple([1]) + seq_length = ShapeTuple([1]) + kv_seq_len = ShapeTuple([6]) + sliding_window = ShapeTuple([5]) + + result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) + + correct = np.array([[[ + # pylint: disable=line-too-long + # | IN CACHE |CURRENT CHUNK| + # t5 t6 t7 t8 t9 t10 + [-3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38] + ]]]).astype("float32") + + np.testing.assert_array_equal(result.numpy(), correct) + + ################### 3. TESTS FOR DECODE ################### + def test_decode_1(self): + """ + Test 1: chunk size is 5, WS is 5, cache carrying t5-t9 (t0-t4 overwritten); + input t10 (decoding). + """ + bsz = ShapeTuple([1]) + seq_length = ShapeTuple([1]) + kv_seq_len = ShapeTuple([6]) + sliding_window = ShapeTuple([5]) + + result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) + + correct = np.array([[[ + # pylint: disable=line-too-long + # | IN CACHE |CURRENT CHUNK| + # t5 t6 t7 t8 t9 t10 + [-3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38] + ]]]).astype("float32") + + np.testing.assert_array_equal(result.numpy(), correct) + + def test_decode_2(self): + """ + Test 2 (Cache not full): prompt is size 4, WS is 5, cache carrying t0-t3; input t4. 
+ """ + bsz = ShapeTuple([1]) + seq_length = ShapeTuple([1]) + kv_seq_len = ShapeTuple([5]) + sliding_window = ShapeTuple([5]) + + result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) + + correct = np.array([[[ + # | IN CACHE |CURRENT CHUNK| + # t0 t1 t2 t3 t4 + [3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38] + ]]]).astype("float32") + + np.testing.assert_array_equal(result.numpy(), correct) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/python/__init__.py b/tests/python/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/python/api/test_python.py b/tests/python/api/test_python.py new file mode 100644 index 0000000..ceba066 --- /dev/null +++ b/tests/python/api/test_python.py @@ -0,0 +1,45 @@ +# pylint: disable=missing-docstring +import pytest + +from mlc_chat import ChatModule, GenerationConfig +from mlc_chat.callback import StreamToStdout + +MODELS = ["Llama-2-7b-chat-hf-q4f16_1"] + + +@pytest.mark.parametrize("model", MODELS) +def test_chat_module_creation_and_generate(model: str): + chat_module = ChatModule(model=model) + _ = chat_module.generate( + prompt="How to make a cake?", + ) + print(f"Statistics: {chat_module.stats()}\n") + + +@pytest.mark.parametrize("model", MODELS) +def test_chat_module_creation_and_generate_with_stream(model: str): + chat_module = ChatModule(model=model) + _ = chat_module.generate( + prompt="How to make a cake?", + progress_callback=StreamToStdout(callback_interval=2), + ) + print(f"Statistics: {chat_module.stats()}\n") + + +@pytest.mark.parametrize( + "generation_config", + [ + GenerationConfig(temperature=0.7, presence_penalty=0.1, frequency_penalty=0.5, top_p=0.9), + GenerationConfig(stop=["cake", "make"], n=3), + GenerationConfig(max_gen_len=40, repetition_penalty=1.2), + ], +) +@pytest.mark.parametrize("model", MODELS) +def test_chat_module_generation_config(generation_config: GenerationConfig, model: str): + chat_module = ChatModule(model=model) + output = chat_module.generate( + prompt="How to make a cake?", + generation_config=generation_config, + ) + print(output) + print(f"Statistics: {chat_module.stats()}\n") diff --git a/tests/python/api/test_rest.py b/tests/python/api/test_rest.py new file mode 100644 index 0000000..f4ef442 --- /dev/null +++ b/tests/python/api/test_rest.py @@ -0,0 +1,105 @@ +# pylint: disable=missing-docstring +import json +import os +import signal +import subprocess +import time + +import pytest +import requests + +MODELS = ["Llama-2-7b-chat-hf-q4f16_1"] + + +@pytest.fixture +def run_rest_server(model): + cmd = f"python -m mlc_chat.rest --model {model}" + print(cmd) + os.environ["PYTHONPATH"] = "./python" + with subprocess.Popen(cmd.split()) as server_proc: + # wait for server to start + while True: + try: + _ = requests.get("http://localhost:8000/stats", timeout=5) + break + except requests.exceptions.ConnectionError: + time.sleep(1) + yield + server_proc.send_signal(signal.SIGINT) + server_proc.wait() + + +@pytest.mark.usefixtures("run_rest_server") +@pytest.mark.parametrize("stream", [True, False]) +@pytest.mark.parametrize("model", MODELS) +def test_rest_chat_completions(model, stream): + payload = { + "model": model, + "messages": [ + { + "role": "user", + "content": "Hello, I am Bob", + }, + { + "role": "assistant", + "content": "Hello, I am a chatbot.", + }, + { + "role": "user", + "content": "What is my name?", + }, + ], + "stream": stream, + "frequency_penalty": 0.0, + "presence_penalty": 0.0, + "temperature": 1.0, + "top_p": 0.95, + } + if 
stream: + with requests.post( + "http://127.0.0.1:8000/v1/chat/completions", json=payload, stream=True, timeout=120 + ) as model_response: + print("With streaming:") + for chunk in model_response: + data = chunk[6:-2] + if data != b"[DONE]": + content = json.loads(data)["choices"][0]["delta"].get("content", "") + print(f"{content}", end="", flush=True) + print("\n") + else: + model_response = requests.post( + "http://127.0.0.1:8000/v1/chat/completions", json=payload, timeout=120 + ) + print(f"\n{model_response.json()['choices'][0]['message']['content']}\n") + + +@pytest.mark.usefixtures("run_rest_server") +@pytest.mark.parametrize("stream", [True, False]) +@pytest.mark.parametrize("model", MODELS) +def test_rest_completions(model, stream): + payload = { + "model": model, + "prompt": "What is the meaning of life?", + "stream": stream, + "frequency_penalty": 0.0, + "presence_penalty": 0.0, + "temperature": 1.0, + "n": 3, + } + if stream: + with requests.post( + "http://127.0.0.1:8000/v1/completions", json=payload, stream=True, timeout=120 + ) as model_response: + print("With streaming:") + for chunk in model_response: + data = chunk[6:-2] + if data != b"[DONE]": + content = json.loads(data)["choices"][0]["text"] + print(f"{content}", end="", flush=True) + print("\n") + else: + model_response = requests.post( + "http://127.0.0.1:8000/v1/completions", json=payload, timeout=120 + ) + assert len(model_response.json()["choices"]) == 3 + print(f"\n{model_response.json()['choices'][0]['text']}\n") diff --git a/tests/python/compiler_pass/test_fuse_ft_dequantize_matmul_epilogue.py b/tests/python/compiler_pass/test_fuse_ft_dequantize_matmul_epilogue.py new file mode 100644 index 0000000..eed1010 --- /dev/null +++ b/tests/python/compiler_pass/test_fuse_ft_dequantize_matmul_epilogue.py @@ -0,0 +1,342 @@ +# pylint: disable=invalid-name,missing-docstring,too-few-public-methods +import tvm +from tvm.ir import assert_structural_equal +from tvm.script import ir as I +from tvm.script import relax as R + +from mlc_chat.compiler_pass.fuse_ft_dequantize_matmul_epilogue import ( + FuseFTDequantizeEpilogue, +) + + +def test_fuse_bias(): + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor((1, 1, 4096), "float16"), + weight: R.Tensor((4096, 512), "int8"), + scale: R.Tensor((1, 1024), "float16"), + bias: R.Tensor((1, 1, 1024), "float16"), + ): + with R.dataflow(): + lv1 = R.call_dps_packed( + "fastertransformer.gemm_fp16_int", + ( + x, + weight, + scale, + "identity", + R.prim_value(1), + R.prim_value(1024), + R.prim_value(4096), + R.prim_value(4096), + ), + out_sinfo=R.Tensor((1, 1, 1024), "float16"), + ) + lv2 = R.add(lv1, bias) + R.output(lv2) + return lv2 + + @I.ir_module + class After: + @R.function + def main( + x: R.Tensor((1, 1, 4096), "float16"), + weight: R.Tensor((4096, 512), "int8"), + scale: R.Tensor((1, 1024), "float16"), + bias: R.Tensor((1, 1, 1024), "float16"), + ) -> R.Tensor((1, 1, 1024), "float16"): + with R.dataflow(): + lv2 = R.call_dps_packed( + "fastertransformer.gemm_fp16_int_bias", + ( + x, + weight, + scale, + bias, + R.str("identity"), + R.prim_value(1), + R.prim_value(1024), + R.prim_value(4096), + R.prim_value(4096), + R.prim_value(0), + ), + out_sinfo=R.Tensor((1, 1, 1024), "float16"), + ) + R.output(lv2) + return lv2 + + seq = tvm.transform.Sequential([FuseFTDequantizeEpilogue()]) + mod = seq(Before) + assert_structural_equal(mod, After) + + +def test_fuse_activation(): + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor((1, 1, 4096), "float16"), + 
weight: R.Tensor((4096, 512), "int8"), + scale: R.Tensor((1, 1024), "float16"), + ): + with R.dataflow(): + lv1 = R.call_dps_packed( + "fastertransformer.gemm_fp16_int", + ( + x, + weight, + scale, + "identity", + R.prim_value(1), + R.prim_value(1024), + R.prim_value(4096), + R.prim_value(4096), + ), + out_sinfo=R.Tensor((1, 1, 1024), "float16"), + ) + lv2 = R.nn.silu(lv1) + R.output(lv2) + return lv2 + + @I.ir_module + class After: + @R.function + def main( + x: R.Tensor((1, 1, 4096), "float16"), + weight: R.Tensor((4096, 512), "int8"), + scale: R.Tensor((1, 1024), "float16"), + ) -> R.Tensor((1, 1, 1024), "float16"): + with R.dataflow(): + lv2 = R.call_dps_packed( + "fastertransformer.gemm_fp16_int", + ( + x, + weight, + scale, + R.str("silu"), + R.prim_value(1), + R.prim_value(1024), + R.prim_value(4096), + R.prim_value(4096), + ), + out_sinfo=R.Tensor((1, 1, 1024), "float16"), + ) + R.output(lv2) + return lv2 + + seq = tvm.transform.Sequential([FuseFTDequantizeEpilogue()]) + mod = seq(Before) + assert_structural_equal(mod, After) + + +def test_fuse_bias_activation(): + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor((1, 1, 4096), "float16"), + weight: R.Tensor((4096, 512), "int8"), + scale: R.Tensor((1, 1024), "float16"), + bias: R.Tensor((1, 1, 1024), "float16"), + ): + with R.dataflow(): + lv1 = R.call_dps_packed( + "fastertransformer.gemm_fp16_int", + ( + x, + weight, + scale, + "identity", + R.prim_value(1), + R.prim_value(1024), + R.prim_value(4096), + R.prim_value(4096), + ), + out_sinfo=R.Tensor((1, 1, 1024), "float16"), + ) + lv2 = R.add(lv1, bias) + lv3 = R.nn.relu(lv2) + R.output(lv3) + return lv3 + + @I.ir_module + class After: + @R.function + def main( + x: R.Tensor((1, 1, 4096), "float16"), + weight: R.Tensor((4096, 512), "int8"), + scale: R.Tensor((1, 1024), "float16"), + bias: R.Tensor((1, 1, 1024), "float16"), + ) -> R.Tensor((1, 1, 1024), "float16"): + with R.dataflow(): + lv2 = R.call_dps_packed( + "fastertransformer.gemm_fp16_int_bias", + ( + x, + weight, + scale, + bias, + R.str("relu"), + R.prim_value(1), + R.prim_value(1024), + R.prim_value(4096), + R.prim_value(4096), + R.prim_value(0), + ), + out_sinfo=R.Tensor((1, 1, 1024), "float16"), + ) + R.output(lv2) + return lv2 + + seq = tvm.transform.Sequential([FuseFTDequantizeEpilogue()]) + mod = seq(Before) + assert_structural_equal(mod, After) + + +def test_fuse_residual_binary(): + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor((1, 1, 4096), "float16"), + weight: R.Tensor((4096, 512), "int8"), + scale: R.Tensor((1, 1024), "float16"), + bias: R.Tensor((1, 1, 1024), "float16"), + residual: R.Tensor((1, 1, 1024), "float16"), + ): + with R.dataflow(): + lv1 = R.call_dps_packed( + "fastertransformer.gemm_fp16_int", + ( + x, + weight, + scale, + "identity", + R.prim_value(1), + R.prim_value(1024), + R.prim_value(4096), + R.prim_value(4096), + ), + out_sinfo=R.Tensor((1, 1, 1024), "float16"), + ) + lv2 = R.add(lv1, bias) + lv3 = R.nn.relu(lv2) + lv4 = R.multiply(lv3, residual) + R.output(lv4) + return lv4 + + @I.ir_module + class After: + @R.function + def main( + x: R.Tensor((1, 1, 4096), "float16"), + weight: R.Tensor((4096, 512), "int8"), + scale: R.Tensor((1, 1024), "float16"), + bias: R.Tensor((1, 1, 1024), "float16"), + residual: R.Tensor((1, 1, 1024), "float16"), + ) -> R.Tensor((1, 1, 1024), "float16"): + with R.dataflow(): + lv2 = R.call_dps_packed( + "fastertransformer.gemm_fp16_int_bias_residual", + ( + x, + weight, + scale, + bias, + residual, + R.str("relu"), + 
R.str("multiply"), + R.str("identity"), + R.prim_value(1), + R.prim_value(1024), + R.prim_value(4096), + R.prim_value(4096), + ), + out_sinfo=R.Tensor((1, 1, 1024), "float16"), + ) + R.output(lv2) + return lv2 + + seq = tvm.transform.Sequential([FuseFTDequantizeEpilogue()]) + mod = seq(Before) + assert_structural_equal(mod, After) + + +def test_fuse_residual_unary(): + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor((1, 1, 4096), "float16"), + weight: R.Tensor((4096, 512), "int8"), + scale: R.Tensor((1, 1024), "float16"), + bias: R.Tensor((1, 1, 1024), "float16"), + residual: R.Tensor((1, 1, 1024), "float16"), + ): + with R.dataflow(): + lv1 = R.call_dps_packed( + "fastertransformer.gemm_fp16_int", + ( + x, + weight, + scale, + "identity", + R.prim_value(1), + R.prim_value(1024), + R.prim_value(4096), + R.prim_value(4096), + ), + out_sinfo=R.Tensor((1, 1, 1024), "float16"), + ) + lv2 = R.add(lv1, bias) + lv3 = R.nn.relu(lv2) + lv4 = R.add(lv3, residual) + lv5 = R.nn.gelu(lv4) + R.output(lv5) + return lv5 + + @I.ir_module + class After: + @R.function + def main( + x: R.Tensor((1, 1, 4096), "float16"), + weight: R.Tensor((4096, 512), "int8"), + scale: R.Tensor((1, 1024), "float16"), + bias: R.Tensor((1, 1, 1024), "float16"), + residual: R.Tensor((1, 1, 1024), "float16"), + ) -> R.Tensor((1, 1, 1024), "float16"): + with R.dataflow(): + lv2 = R.call_dps_packed( + "fastertransformer.gemm_fp16_int_bias_residual", + ( + x, + weight, + scale, + bias, + residual, + R.str("relu"), + R.str("plus"), + R.str("gelu"), + R.prim_value(1), + R.prim_value(1024), + R.prim_value(4096), + R.prim_value(4096), + ), + out_sinfo=R.Tensor((1, 1, 1024), "float16"), + ) + R.output(lv2) + return lv2 + + seq = tvm.transform.Sequential([FuseFTDequantizeEpilogue()]) + mod = seq(Before) + assert_structural_equal(mod, After) + + +if __name__ == "__main__": + test_fuse_bias() + test_fuse_activation() + test_fuse_bias_activation() + test_fuse_residual_binary() + test_fuse_residual_unary() diff --git a/tests/python/conftest.py b/tests/python/conftest.py new file mode 100644 index 0000000..b19fce7 --- /dev/null +++ b/tests/python/conftest.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-module-docstring,unused-import +import pytest +import tvm.testing + +pytest_plugins = ["tvm.testing.plugin"] diff --git a/tests/python/integration/test_model_compile.py b/tests/python/integration/test_model_compile.py new file mode 100644 index 0000000..92c8894 --- /dev/null +++ b/tests/python/integration/test_model_compile.py @@ -0,0 +1,155 @@ +# pylint: disable=missing-docstring +import concurrent.futures as cf +import os +import shlex +import subprocess +import sys +import tempfile +from itertools import product + +import tvm + +from mlc_chat.model import MODEL_PRESETS +from mlc_chat.support.constants import MLC_TEMP_DIR + +OPT_LEVEL = "O2" +DEVICE2TARGET = { + "cuda": { + "kind": "cuda", + "arch": "sm_86", + "max_threads_per_block": 1024, + "max_num_threads": 1024, + "max_shared_memory_per_block": 49152, + "thread_warp_size": 32, + }, + "rocm": { + "kind": "rocm", + "mtriple": "amdgcn-amd-amdhsa-hcc", + "mcpu": "gfx1100", + "thread_warp_size": 32, + "max_threads_per_block": 1024, + "max_num_threads": 256, + "max_shared_memory_per_block": 65536, + }, + "vulkan": { + "kind": "vulkan", + "max_threads_per_block": 1024, + "max_num_threads": 256, + "max_shared_memory_per_block": 32768, + "thread_warp_size": 1, + "supports_int16": 1, + "supports_float32": 1, + "supports_int32": 1, + "supports_int8": 1, + "supports_16bit_buffer": 1, + "supports_float16": 1, + }, + "metal": "metal", + "wasm": "webgpu", + "android": "android", + "ios": "iphone", +} +DEVICE2SUFFIX = { + "cuda": "so", + "rocm": "so", + "vulkan": "so", + "metal": "dylib", + "wasm": "wasm", + "android": "tar", + "ios": "tar", +} +MODELS = list(MODEL_PRESETS.keys()) +QUANTS = [ # TODO(@junrushao): use `list(mlc_chat.quantization.QUANTIZATION.keys())` + "q0f16", + "q0f32", + "q3f16_1", + "q4f16_1", + "q4f32_1", +] +TENSOR_PARALLEL_SHARDS = [ + 1, +] + + +def run_command(log_file, cmd): + with open(log_file, "w", encoding="utf-8") as file: + subprocess.check_call( + cmd, + stdout=file, + stderr=subprocess.STDOUT, + ) + + +def test_model_compile(): # pylint: disable=too-many-locals + device = sys.argv[1] + num_workers = int(sys.argv[2]) + target = DEVICE2TARGET[device] + if not isinstance(target, str): + target = str(tvm.target.Target(target)) + suffix = DEVICE2SUFFIX[device] + + passed_cmds = [] + failed_cmds = [] + with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir: + with cf.ProcessPoolExecutor(max_workers=num_workers) as executor: + log_files = [] + cmds = [] + futures = [] + for idx, (model, quant, tp_shard) in enumerate( + product( + MODELS, + QUANTS, + TENSOR_PARALLEL_SHARDS, + ) + ): + log_file = os.path.join(tmp_dir, f"lib{idx}.log") + cmd = [ + sys.executable, + "-m", + "mlc_chat", + "compile", + model, + "--quantization", + quant, + "--overrides", + f"tensor_parallel_shards={tp_shard}", + "--device", + target, + "--opt", + OPT_LEVEL, + "-o", + os.path.join(tmp_dir, f"lib{idx}.{suffix}"), + ] + future = executor.submit(run_command, log_file, cmd) + log_files.append(log_file) + cmds.append(cmd) + futures.append(future) + for log_file, cmd, future in zip(log_files, cmds, futures): + cmd = shlex.join(cmd) + try: + future.result() + passed_cmds.append(cmd) + print(f"[PASS] {cmd}") + except Exception: # pylint: disable=broad-except + failed_cmds.append(cmd) + print("-------------------------------") + print(f"[FAIL] {cmd}") + with open(log_file, "r", encoding="utf-8") as file: + print(file.read()) + print("-------------------------------") + print("-------------------------------") + 
print(f"Total {len(passed_cmds)} passed, {len(failed_cmds)} failed.") + print("-------------------------------") + print("Passed commands:") + for cmd in passed_cmds: + print(cmd) + if failed_cmds: + print("-------------------------------") + print("Failed commands:") + for cmd in failed_cmds: + print(cmd) + sys.exit(1) + + +if __name__ == "__main__": + test_model_compile() diff --git a/tests/python/loader/test_awq.py b/tests/python/loader/test_awq.py new file mode 100644 index 0000000..d945a95 --- /dev/null +++ b/tests/python/loader/test_awq.py @@ -0,0 +1,40 @@ +# pylint: disable=missing-docstring +from pathlib import Path +from typing import Union + +import pytest +import tvm + +from mlc_chat.loader import HuggingFaceLoader +from mlc_chat.model import MODEL_PRESETS, MODELS +from mlc_chat.quantization import QUANTIZATION +from mlc_chat.support import logging, tqdm + +logging.enable_logging() + + +@pytest.mark.parametrize( + "param_path", + [ + "./dist/models/llama-2-7b-w4-g128-awq.pt", + "./dist/models/Llama-2-7B-AWQ/model.safetensors", + ], +) +def test_load_llama(param_path: Union[str, Path]): + path_params = Path(param_path) + + model = MODELS["llama"] + quantization = QUANTIZATION["q4f16_awq"] + config = model.config.from_dict(MODEL_PRESETS["llama2_7b"]) + loader = HuggingFaceLoader( + path=path_params, + extern_param_map=model.source["awq"](config, quantization), + ) + with tqdm.redirect(): + for _name, _param in loader.load(tvm.device("cpu")): + ... + + +if __name__ == "__main__": + test_load_llama(param_path="./dist/models/llama-2-7b-w4-g128-awq.pt") + test_load_llama(param_path="./dist/models/Llama-2-7B-AWQ/model.safetensors") diff --git a/tests/python/loader/test_huggingface.py b/tests/python/loader/test_huggingface.py new file mode 100644 index 0000000..dfbef55 --- /dev/null +++ b/tests/python/loader/test_huggingface.py @@ -0,0 +1,69 @@ +# pylint: disable=missing-docstring +from pathlib import Path +from typing import Union + +import pytest +import tvm + +from mlc_chat.loader import HuggingFaceLoader +from mlc_chat.model import MODELS +from mlc_chat.support import logging, tqdm + +logging.enable_logging() + + +@pytest.mark.parametrize( + "base_path", + [ + "./dist/models/Llama-2-7b-hf", + "./dist/models/Llama-2-13b-hf", + "./dist/models/Llama-2-70b-hf", + ], +) +def test_load_torch_llama(base_path: Union[str, Path]): + base_path = Path(base_path) + path_config = base_path / "config.json" + path_params = base_path / "pytorch_model.bin.index.json" + + model = MODELS["llama"] + config = model.config.from_file(path_config) + loader = HuggingFaceLoader( + path=path_params, + extern_param_map=model.source["huggingface-torch"](config, None), + ) + with tqdm.redirect(): + for _name, _param in loader.load(device=tvm.device("cpu")): + return # To reduce the time of the test + + +@pytest.mark.parametrize( + "base_path", + [ + "./dist/models/Llama-2-7b-hf", + "./dist/models/Llama-2-13b-hf", + "./dist/models/Llama-2-70b-hf", + ], +) +def test_load_safetensor_llama(base_path: Union[str, Path]): + base_path = Path(base_path) + path_config = base_path / "config.json" + path_params = base_path / "model.safetensors.index.json" + + model = MODELS["llama"] + config = model.config.from_file(path_config) + loader = HuggingFaceLoader( + path=path_params, + extern_param_map=model.source["huggingface-safetensor"](config, None), + ) + with tqdm.redirect(): + for _name, _param in loader.load(device=tvm.device("cpu")): + return # To reduce the time of the test + + +if __name__ == "__main__": + 
test_load_torch_llama(base_path="./dist/models/Llama-2-7b-hf") + test_load_torch_llama(base_path="./dist/models/Llama-2-13b-hf") + test_load_torch_llama(base_path="./dist/models/Llama-2-70b-hf") + test_load_safetensor_llama(base_path="./dist/models/Llama-2-7b-hf") + test_load_safetensor_llama(base_path="./dist/models/Llama-2-13b-hf") + test_load_safetensor_llama(base_path="./dist/models/Llama-2-70b-hf") diff --git a/tests/python/model/test_gpt2.py b/tests/python/model/test_gpt2.py new file mode 100644 index 0000000..9517ad1 --- /dev/null +++ b/tests/python/model/test_gpt2.py @@ -0,0 +1,21 @@ +# pylint: disable=invalid-name,missing-docstring +import pytest + +from mlc_chat.model import MODEL_PRESETS, MODELS + + +@pytest.mark.parametrize("model_name", ["gpt2"]) +def test_gpt2_creation(model_name: str): + model_info = MODELS["gpt2"] + config = model_info.config.from_dict(MODEL_PRESETS[model_name]) + model = model_info.model(config) + mod, named_params = model.export_tvm( + spec=model.get_default_spec(), # type: ignore + ) + mod.show(black_format=False) + for name, param in named_params: + print(name, param.shape, param.dtype) + + +if __name__ == "__main__": + test_gpt2_creation("gpt2") diff --git a/tests/python/model/test_gptNeox.py b/tests/python/model/test_gptNeox.py new file mode 100644 index 0000000..d4fcfdd --- /dev/null +++ b/tests/python/model/test_gptNeox.py @@ -0,0 +1,21 @@ +# pylint: disable=invalid-name,missing-docstring +import pytest + +from mlc_chat.model import MODEL_PRESETS, MODELS + + +@pytest.mark.parametrize("model_name", ["redpajama_3b_v1"]) +def test_mistral_creation(model_name: str): + model_info = MODELS["gpt_neox"] + config = model_info.config.from_dict(MODEL_PRESETS[model_name]) + model = model_info.model(config) + mod, named_params = model.export_tvm( + spec=model.get_default_spec(), # type: ignore + ) + mod.show(black_format=False) + for name, param in named_params: + print(name, param.shape, param.dtype) + + +if __name__ == "__main__": + test_mistral_creation("redpajama_3b_v1") diff --git a/tests/python/model/test_kv_cache.py b/tests/python/model/test_kv_cache.py new file mode 100644 index 0000000..970b7ba --- /dev/null +++ b/tests/python/model/test_kv_cache.py @@ -0,0 +1,207 @@ +# pylint: disable=line-too-long,missing-docstring +import tvm +from tvm import tir +from tvm.relax.frontend.nn import core, modules, spec +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T + +from mlc_chat.nn.kv_cache import FlashInferPagedKVCache, PagedKVCache, RopeMode + +# mypy: disable-error-code="attr-defined" +# pylint: disable=invalid-name,unused-argument,too-many-locals,too-many-statements + + +def test_nn_module_paged_kv_cache(): + # fmt: off + @I.ir_module + class Module: + @T.prim_func + def fused_rope(var_qkv: T.handle, var_position_map: T.handle, var_q: T.handle, var_k: T.handle, var_v: T.handle, apply_rope: T.int32): # pylint: disable=too-many-arguments + T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)}) + seq_len = T.int64() + qkv = T.match_buffer(var_qkv, (seq_len, 96, 128), "float16") + position_map = T.match_buffer(var_position_map, (seq_len,), "int32") + q = T.match_buffer(var_q, (seq_len, 32, 128), "float16") + k = T.match_buffer(var_k, (seq_len, 32, 128), "float16") + v = T.match_buffer(var_v, (seq_len, 32, 128), "float16") + for iters_0, iters_1, iters_2 in T.grid(seq_len, 96, 128): + with T.block("llama_fused_rope"): + s, h, d = T.axis.remap("SSS", [iters_0, iters_1, iters_2]) + T.reads(position_map[s], 
qkv[s, h, d - 64:d - 64 + 129]) + T.writes(q[s, h, d], k[s, h - 32, d], v[s, h - 64, d]) + if h < 32: + q[s, h, d] = T.if_then_else(apply_rope > 0 and d < 128, T.Cast("float16", T.cos(T.Cast("float32", T.Cast("float16", position_map[s])) / T.pow(T.float32(10000), T.Cast("float32", d * 2 % 128) / T.float32(128)))) * qkv[s, h, d] + T.Cast("float16", T.sin(T.Cast("float32", T.Cast("float16", position_map[s])) / T.pow(T.float32(10000), T.Cast("float32", d * 2 % 128) / T.float32(128)))) * T.if_then_else(d < 64, qkv[s, h, d + 64] * T.float16(-1), qkv[s, h, d - 64]), qkv[s, h, d]) + else: + if h < 64: + k[s, h - 32, d] = T.if_then_else(apply_rope > 0 and d < 128, T.Cast("float16", T.cos(T.Cast("float32", T.Cast("float16", position_map[s])) / T.pow(T.float32(10000), T.Cast("float32", d * 2 % 128) / T.float32(128)))) * qkv[s, h, d] + T.Cast("float16", T.sin(T.Cast("float32", T.Cast("float16", position_map[s])) / T.pow(T.float32(10000), T.Cast("float32", d * 2 % 128) / T.float32(128)))) * T.if_then_else(d < 64, qkv[s, h, d + 64] * T.float16(-1), qkv[s, h, d - 64]), qkv[s, h, d]) + else: + v[s, h - 64, d] = qkv[s, h, d] + + @T.prim_func + def tir_kv_cache_debug_get_kv(var_pages: T.handle, var_position_map: T.handle, var_k_data: T.handle, var_v_data: T.handle, layer_id: T.int64): + T.func_attr({"tir.noalias": T.bool(True)}) + num_pages, page_size = T.int64(), T.int64(is_size_var=True) + pages = T.match_buffer(var_pages, (num_pages, 2, 32, page_size, 128), "float16") + seqlen = T.int64(is_size_var=True) + position_map = T.match_buffer(var_position_map, (seqlen,), "int32") + k_data = T.match_buffer(var_k_data, (32, seqlen, 32, 128), "float16") + v_data = T.match_buffer(var_v_data, (32, seqlen, 32, 128), "float16") + for p, h, d in T.grid(seqlen, 32, 128): + with T.block("copy0"): + vp, vh, vd = T.axis.remap("SSS", [p, h, d]) + T.reads(position_map[vp], pages[T.Cast("int64", position_map[vp]) // page_size, 0:2, vh, T.Cast("int64", position_map[vp]) % page_size, vd]) + T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd]) + position: T.int32 = position_map[vp] # type: ignore[name-defined] + k_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 0, vh, T.Cast("int64", position) % page_size, vd] + v_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 1, vh, T.Cast("int64", position) % page_size, vd] + + @T.prim_func + def tir_kv_cache_transpose_append(var_pages: T.handle, var_k_data: T.handle, var_v_data: T.handle, var_position_map: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + num_pages = T.int64() + pages = T.match_buffer(var_pages, (num_pages, 2, 32, 16, 128), "float16") + ntoken = T.int64(is_size_var=True) + k_data = T.match_buffer(var_k_data, (ntoken, 32, 128), "float16") + v_data = T.match_buffer(var_v_data, (ntoken, 32, 128), "float16") + position_map = T.match_buffer(var_position_map, (ntoken,), "int32") + # with T.block("root"): + for global_pos, h, f in T.grid(ntoken, 32, 128): + with T.block("k_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) + T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf]) + position: T.int32 = position_map[vgpos] # type: ignore[no-redef] + pages[position // 16, 0, vh, position % 16, vf] = k_data[vgpos, vh, vf] + with T.block("v_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) + T.writes(pages[position_map[vgpos] 
// 16, 1, vh, position_map[vgpos] % 16, vf]) + position: T.int32 = position_map[vgpos] # type: ignore[no-redef] + pages[position // 16, 1, vh, position % 16, vf] = v_data[vgpos, vh, vf] + + @T.prim_func + def tir_rotary(var_q: T.handle, var_k: T.handle, var_append_len_indptr: T.handle, var_rope_offsets: T.handle, _0: T.int32, _1: T.int32, _2: T.int32, _3: T.int32, _4: T.int32, _5: T.float32, _6: T.float32): + T.func_attr({"tir.is_scheduled": 1}) + total_len = T.int32() + q = T.match_buffer(var_q, (total_len, 32, 128), "float16") + k = T.match_buffer(var_k, (total_len, 32, 128), "float16") + batch_size = T.int32() + append_len_indptr = T.match_buffer(var_append_len_indptr, (batch_size + 1,), "int32") + rope_offsets = T.match_buffer(var_rope_offsets, (batch_size,), "int32") + with T.block(""): + T.reads() + T.writes() + for b_h in T.thread_binding(batch_size * 64, thread="blockIdx.x"): # pylint: disable=too-many-nested-blocks + b: T.int32 = b_h // 64 + h: T.int32 = b_h % 64 + instance_offset: T.int32 = append_len_indptr[b] + rope_offset: T.int32 = rope_offsets[b] + append_len: T.int32 = append_len_indptr[b + 1] - append_len_indptr[b] + for s0 in range((append_len + 31) // 32): + for s1 in T.thread_binding(32, thread="threadIdx.y"): + for d0 in T.thread_binding(32, thread="threadIdx.x"): + for d1 in T.vectorized(4): + s: T.int32 = s0 * 32 + s1 + d: T.int32 = d0 * 4 + d1 + if s < append_len and d < 128: + if h < 32: + q[s + instance_offset, h, d] = T.Cast("float16", T.cos(T.Cast("float32", s + rope_offset) / T.pow(T.float32(10000), T.Cast("float32", d * 2 % 128) / T.float32(128)))) * q[s + instance_offset, h, d] + T.Cast("float16", T.sin(T.Cast("float32", s + rope_offset) / T.pow(T.float32(10000), T.Cast("float32", d * 2 % 128) / T.float32(128)))) * T.if_then_else(d < 64, q[s + instance_offset, h, d + 64] * T.float16(-1), q[s + instance_offset, h, d - 64]) + else: + k[s + instance_offset, h - 32, d] = T.Cast("float16", T.cos(T.Cast("float32", s + rope_offset) / T.pow(T.float32(10000), T.Cast("float32", d * 2 % 128) / T.float32(128)))) * k[s + instance_offset, h - 32, d] + T.Cast("float16", T.sin(T.Cast("float32", s + rope_offset) / T.pow(T.float32(10000), T.Cast("float32", d * 2 % 128) / T.float32(128)))) * T.if_then_else(d < 64, k[s + instance_offset, h - 32, d + 64] * T.float16(-1), k[s + instance_offset, h - 32, d - 64]) + + @R.function + def _initialize_effect() -> R.Tuple(R.Object): + with R.dataflow(): + _io: R.Object = R.null_value() # type: ignore + lv: R.Tuple(R.Object) = (_io,) # type: ignore + gv: R.Tuple(R.Object) = lv # type: ignore + R.output(gv) + return gv + + @R.function + def create_flashinfer_paged_kv_cache(max_batch_size: R.Shape(["max_batch_size_1"]), max_total_seq_len: R.Shape(["max_total_seq_len_1"]), prefill_chunk_size: R.Shape(["prefill_chunk_size_1"]), page_size: R.Shape(["page_size_1"]), _io: R.Object) -> R.Tuple(R.Object, R.Tuple(R.Object)): + max_batch_size_1 = T.int64() + max_total_seq_len_1 = T.int64() + prefill_chunk_size_1 = T.int64() + page_size_1 = T.int64() + R.func_attr({"num_input": 5}) + cls = Module + with R.dataflow(): + lv2: R.Tensor((), dtype="float16") = R.zeros(R.shape([]), dtype="float16") # type: ignore + paged_kv_cache: R.Object = R.call_packed("vm.builtin.paged_attention_kv_cache_create", R.shape([max_batch_size_1, max_total_seq_len_1, prefill_chunk_size_1, page_size_1]), R.prim_value(32), R.prim_value(32), R.prim_value(32), R.prim_value(128), R.prim_value(0), R.prim_value(1), R.prim_value(10000), lv2, cls.tir_kv_cache_transpose_append, 
R.ExternFunc("paged_kv_cache.attention_kernel_prefill"), R.ExternFunc("paged_kv_cache.attention_kernel_decode"), R.ExternFunc("flashinfer.attention_kernel_prefill_with_ragged_kv_cache"), R.ExternFunc("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_begin_forward"), R.ExternFunc("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_end_forward"), R.ExternFunc("paged_kv_cache.attention_kernel_prefill_begin_forward"), R.ExternFunc("paged_kv_cache.attention_kernel_prefill_end_forward"), R.ExternFunc("paged_kv_cache.attention_kernel_decode_begin_forward"), R.ExternFunc("paged_kv_cache.attention_kernel_decode_end_forward"), R.ExternFunc("flashinfer.merge_state_in_place"), cls.fused_rope, cls.tir_rotary, cls.tir_kv_cache_debug_get_kv, sinfo_args=(R.Object,)) + gv2: R.Tuple(R.Object, R.Tuple(R.Object)) = paged_kv_cache, (_io,) # type: ignore + R.output(gv2) + return gv2 + + @R.function + def forward(cache: R.Object, q: R.Tensor((1, 100, 32, 128), dtype="float16"), k: R.Tensor((1, 100, 32, 128), dtype="float16"), v: R.Tensor((1, 100, 32, 128), dtype="float16"), _io: R.Object) -> R.Tuple(R.Tensor((1, 100, 32, 128), dtype="float16"), R.Tuple(R.Object)): + R.func_attr({"num_input": 5}) + with R.dataflow(): + reshape: R.Tensor((100, 32, 128), dtype="float16") = R.reshape(q, R.shape([100, 32, 128])) # type: ignore + reshape1: R.Tensor((100, 32, 128), dtype="float16") = R.reshape(k, R.shape([100, 32, 128])) # type: ignore + reshape2: R.Tensor((100, 32, 128), dtype="float16") = R.reshape(v, R.shape([100, 32, 128])) # type: ignore + lv1 = R.call_dps_packed("vm.builtin.paged_attention_kv_cache_attention", (cache, R.prim_value(0), reshape, reshape1, reshape2), out_sinfo=R.Tensor((100, 32, 128), dtype="float16")) + reshape3: R.Tensor((1, 100, 32, 128), dtype="float16") = R.reshape(lv1, R.shape([1, 100, 32, 128])) # type: ignore + gv1: R.Tuple(R.Tensor((1, 100, 32, 128), dtype="float16"), R.Tuple(R.Object)) = reshape3, (_io,) # type: ignore + R.output(gv1) + return gv1 + # fmt: on + + class PagedKVCacheTest(modules.Module): + def forward( + self, + cache: PagedKVCache, + q: core.Tensor, + k: core.Tensor, + v: core.Tensor, + ) -> core.Tensor: + return cache.attention(0, q, k, v) + + def create_flashinfer_paged_kv_cache( + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + ) -> PagedKVCache: + return FlashInferPagedKVCache( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + head_dim=128, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=10000, + rotary_dim=128, + dtype="float16", + target=tvm.target.Target("cuda"), + ) + + export_results = PagedKVCacheTest().export_tvm( + spec={ + "forward": { + "cache": spec.Object(object_type=PagedKVCache), + "q": spec.Tensor((1, 100, 32, 128), "float16"), + "k": spec.Tensor((1, 100, 32, 128), "float16"), + "v": spec.Tensor((1, 100, 32, 128), "float16"), + }, + "create_flashinfer_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, + }, + }, + debug=True, + ) + tvm_mod = export_results[0] + tvm.ir.assert_structural_equal(tvm_mod, Module, True) + + +if __name__ == "__main__": + test_nn_module_paged_kv_cache() diff --git a/tests/python/model/test_llama.py b/tests/python/model/test_llama.py new file mode 100644 index 0000000..8ea682f --- /dev/null +++ 
b/tests/python/model/test_llama.py @@ -0,0 +1,23 @@ +# pylint: disable=invalid-name,missing-docstring +import pytest + +from mlc_chat.model import MODEL_PRESETS, MODELS + + +@pytest.mark.parametrize("model_name", ["llama2_7b", "llama2_13b", "llama2_70b"]) +def test_llama2_creation(model_name: str): + model_info = MODELS["llama"] + config = model_info.config.from_dict(MODEL_PRESETS[model_name]) + model = model_info.model(config) + mod, named_params = model.export_tvm( + spec=model.get_default_spec(), # type: ignore + ) + mod.show(black_format=False) + for name, param in named_params: + print(name, param.shape, param.dtype) + + +if __name__ == "__main__": + test_llama2_creation("llama2_7b") + test_llama2_creation("llama2_13b") + test_llama2_creation("llama2_70b") diff --git a/tests/python/model/test_llama_quantization.py b/tests/python/model/test_llama_quantization.py new file mode 100644 index 0000000..4d4c761 --- /dev/null +++ b/tests/python/model/test_llama_quantization.py @@ -0,0 +1,73 @@ +# pylint: disable=invalid-name,missing-docstring +import pytest + +from mlc_chat.model import MODEL_PRESETS, MODELS +from mlc_chat.quantization import QUANTIZATION +from mlc_chat.quantization.group_quantization import ( + GroupQuantizeEmbedding, + GroupQuantizeLinear, +) + + +@pytest.mark.parametrize( + "model_name", + ["llama2_7b", "llama2_13b", "llama2_70b"], +) +@pytest.mark.parametrize( + "quant_name", + ["q3f16_1", "q4f16_1", "q4f32_1"], +) +def test_llama2_group_quantization(model_name: str, quant_name: str): + model_info = MODELS["llama"] + config = model_info.config.from_dict(MODEL_PRESETS[model_name]) + model, quant_map = model_info.quantize["group-quant"](config, QUANTIZATION[quant_name]) + assert "model.embed_tokens.weight" in quant_map.param_map + assert isinstance( + model.model.embed_tokens, # type: ignore[attr-defined] + GroupQuantizeEmbedding, + ) + assert "lm_head.weight" in quant_map.param_map + assert isinstance(model.lm_head, GroupQuantizeLinear) # type: ignore[attr-defined] + for i in range(config.num_hidden_layers): + assert f"model.layers.{i}.self_attn.qkv_proj.weight" in quant_map.param_map + assert isinstance( + model.model.layers[i].self_attn.qkv_proj, # type: ignore[attr-defined] + GroupQuantizeLinear, + ) + assert f"model.layers.{i}.self_attn.o_proj.weight" in quant_map.param_map + assert isinstance( + model.model.layers[i].self_attn.o_proj, # type: ignore[attr-defined] + GroupQuantizeLinear, + ) + assert f"model.layers.{i}.mlp.gate_up_proj.weight" in quant_map.param_map + assert isinstance( + model.model.layers[i].mlp.gate_up_proj, # type: ignore[attr-defined] + GroupQuantizeLinear, + ) + assert f"model.layers.{i}.mlp.down_proj.weight" in quant_map.param_map + assert isinstance( + model.model.layers[i].mlp.down_proj, # type: ignore[attr-defined] + GroupQuantizeLinear, + ) + + +@pytest.mark.parametrize( + "model_name", + ["llama2_7b", "llama2_13b", "llama2_70b"], +) +@pytest.mark.parametrize( + "quant_name", + ["q0f16", "q0f32"], +) +def test_llama2_no_quantization(model_name: str, quant_name: str): + model_info = MODELS["llama"] + config = model_info.config.from_dict(MODEL_PRESETS[model_name]) + _, quant_map = model_info.quantize["no-quant"](config, QUANTIZATION[quant_name]) + assert len(quant_map.param_map) == 0 + assert len(quant_map.map_func) == 0 + + +if __name__ == "__main__": + test_llama2_group_quantization("llama2_7b", "q4f16_1") + test_llama2_group_quantization("llama2_13b", "q4f16_1") + test_llama2_group_quantization("llama2_70b", "q4f16_1") diff --git 
a/tests/python/model/test_mistral.py b/tests/python/model/test_mistral.py new file mode 100644 index 0000000..631b592 --- /dev/null +++ b/tests/python/model/test_mistral.py @@ -0,0 +1,21 @@ +# pylint: disable=invalid-name,missing-docstring +import pytest + +from mlc_chat.model import MODEL_PRESETS, MODELS + + +@pytest.mark.parametrize("model_name", ["mistral_7b"]) +def test_mistral_creation(model_name: str): + model_info = MODELS["mistral"] + config = model_info.config.from_dict(MODEL_PRESETS[model_name]) + model = model_info.model(config) + mod, named_params = model.export_tvm( + spec=model.get_default_spec(), # type: ignore + ) + mod.show(black_format=False) + for name, param in named_params: + print(name, param.shape, param.dtype) + + +if __name__ == "__main__": + test_mistral_creation("mistral_7b") diff --git a/tests/python/model/test_phi.py b/tests/python/model/test_phi.py new file mode 100644 index 0000000..e3f55f2 --- /dev/null +++ b/tests/python/model/test_phi.py @@ -0,0 +1,22 @@ +# pylint: disable=invalid-name,missing-docstring +import pytest + +from mlc_chat.model import MODEL_PRESETS, MODELS + + +@pytest.mark.parametrize("model_name", ["phi-1_5", "phi-2"]) +def test_phi_creation(model_name: str): + model_info = MODELS["phi-msft"] + config = model_info.config.from_dict(MODEL_PRESETS[model_name]) + model = model_info.model(config) + mod, named_params = model.export_tvm( + spec=model.get_default_spec(), # type: ignore + ) + mod.show(black_format=False) + for name, param in named_params: + print(name, param.shape, param.dtype) + + +if __name__ == "__main__": + test_phi_creation("phi-1_5") + test_phi_creation("phi-2") diff --git a/tests/python/quantization/test_awq_quantization.py b/tests/python/quantization/test_awq_quantization.py new file mode 100644 index 0000000..244271a --- /dev/null +++ b/tests/python/quantization/test_awq_quantization.py @@ -0,0 +1,89 @@ +# pylint: disable=invalid-name,missing-docstring +from typing import List + +import numpy as np +import pytest +import torch +import tvm +import tvm.testing +from tvm import DataType +from tvm.relax.frontend import nn + +from mlc_chat.loader import QuantizeMapping +from mlc_chat.quantization import QUANTIZATION, AWQQuantize + + +def dequantize_np( + config: AWQQuantize, + weight: np.ndarray, + zeros: np.ndarray, + scale: np.ndarray, +) -> np.ndarray: + def decode_int_arr(int_arr: np.ndarray, num_elem_per_storage: int, bits: int): + bin_mask = (1 << bits) - 1 + int_arr_repeated = np.repeat(int_arr, num_elem_per_storage, axis=-1) + indice_j = np.indices(int_arr_repeated.shape)[1] + arr_bin = np.bitwise_and( + np.right_shift( + int_arr_repeated, + (indice_j % num_elem_per_storage) * bits, + ), + bin_mask, + ) + return arr_bin + + weight_bin = decode_int_arr( + weight, config.num_elem_per_storage, DataType(config.quantize_dtype).bits + ) + zero_bin = decode_int_arr( + zeros, config.num_elem_per_storage, DataType(config.quantize_dtype).bits + ) + scale_repeated = np.repeat(scale, config.group_size, axis=-1) + zero_bin_repeated = np.repeat(zero_bin, config.group_size, axis=-1) + return (weight_bin - zero_bin_repeated) * scale_repeated + + +@pytest.mark.parametrize( + "quant_name, shape, dtype", + [ + ("q4f16_awq", [2, 4096], "float16"), + ], +) +def test_dequantize_weight(quant_name: str, shape: List[int], dtype: str): + class Test(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(shape[1], shape[0], bias=False, dtype=dtype) + + def forward(self, x: nn.Tensor): + return self.linear(x) + + 
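+ # Build the AWQ-quantized linear layer, fill it with random packed weights, zeros and scales, + # then feed it an identity matrix so the output equals the dequantized weight (transposed), + # which is compared against the NumPy reference above.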
config = QUANTIZATION[quant_name] + assert isinstance(config, AWQQuantize) + weight_np = np.random.randint( + np.iinfo(config.storage_dtype).min, + np.iinfo(config.storage_dtype).max, + (shape[0], shape[1] // config.num_elem_per_storage), + ).astype(config.storage_dtype) + zeros_np = np.random.randint( + np.iinfo(config.storage_dtype).min, + np.iinfo(config.storage_dtype).max, + (shape[0], shape[1] // config.num_elem_per_storage // config.group_size), + ).astype(config.storage_dtype) + scale_np = np.random.random((shape[0], shape[1] // config.group_size)).astype( + config.model_dtype + ) + mod = config.quantize_model(Test(), QuantizeMapping({}, {}), "") + mod.linear.qweight.data = weight_np + mod.linear.qzeros.data = zeros_np + mod.linear.scales.data = scale_np + model = mod.jit(spec={"forward": {"x": nn.spec.Tensor((shape[1], shape[1]), dtype)}}) + out = model["forward"]( + torch.from_numpy(np.diag(np.ones(shape[1]).astype(dtype))) # pylint: disable=no-member + ) + ref = dequantize_np(config, weight_np, zeros_np, scale_np).T + tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + test_dequantize_weight("q4f16_awq", [2, 4096], "float16") diff --git a/tests/python/quantization/test_group_quantization.py b/tests/python/quantization/test_group_quantization.py new file mode 100644 index 0000000..72133ff --- /dev/null +++ b/tests/python/quantization/test_group_quantization.py @@ -0,0 +1,189 @@ +# pylint: disable=invalid-name,missing-docstring +from typing import List + +import numpy as np +import pytest +import torch +import tvm +import tvm.testing +from tvm import DataType +from tvm.relax.frontend import nn + +from mlc_chat.loader import QuantizeMapping +from mlc_chat.quantization import QUANTIZATION +from mlc_chat.quantization.group_quantization import ( + GroupQuantize, + GroupQuantizeEmbedding, + GroupQuantizeLinear, +) + + +def quantize_np(config: GroupQuantize, weight: np.ndarray): + n, k = weight.shape + weight_padded = np.pad( + weight, ((0, 0), (0, (config.group_size - k % config.group_size) % config.group_size)) + ) + n, k = weight_padded.shape + weight_reshaped = np.reshape(weight_padded, (n, k // config.group_size, config.group_size)) + max_abs = np.maximum(np.max(np.abs(weight_reshaped), axis=-1), 1e-4) + scale = np.divide(max_abs, config.max_int_value) + scale_reshaped = np.reshape(scale, (*scale.shape, 1)) + weight_scaled_reshaped = np.clip( + np.add( + np.round(np.divide(weight_reshaped, scale_reshaped)), + config.max_int_value, + ), + 0, + config.max_int_value * 2, + ).astype(config.storage_dtype) + weight_filtered = np.reshape(weight_scaled_reshaped, (n, k)) + weight_filtered[..., weight.shape[1] :] = 0 + weight_scaled = np.reshape( + weight_filtered, (n, k // config.num_elem_per_storage, config.num_elem_per_storage) + ) + indice_k = np.indices(weight_scaled.shape, dtype=config.storage_dtype)[-1] + quantized_weight = np.sum( + np.left_shift(weight_scaled, indice_k * DataType(config.quantize_dtype).bits), + axis=-1, + dtype=config.storage_dtype, + ) + return quantized_weight, scale + + +def dequantize_np( + config: GroupQuantize, + weight: np.ndarray, + scale: np.ndarray, + out_shape: List[int] = None, +): + assert weight.shape[0] == scale.shape[0] + bin_mask = (1 << DataType(config.quantize_dtype).bits) - 1 + max_int = config.max_int_value + out_shape = ( + [weight.shape[0], weight.shape[1] * config.num_elem_per_storage] + if out_shape is None + else out_shape + ) + weight_repeated = np.repeat(weight, config.num_elem_per_storage, 
axis=-1) + scale_repeated = np.repeat(scale, config.group_size, axis=-1) + indice_j = np.indices(weight_repeated.shape)[1] + weight_bin = np.bitwise_and( + np.right_shift( + weight_repeated, + (indice_j % config.num_elem_per_storage) * DataType(config.quantize_dtype).bits, + ), + bin_mask, + ) + assert weight_bin.shape[1] <= scale_repeated.shape[1] + return ((weight_bin - max_int) * scale_repeated[..., : weight_bin.shape[1]])[ + : out_shape[0], : out_shape[1] + ] + + +@pytest.mark.parametrize( + "quant_name, shape, dtype, device", + [ + ("q3f16_1", [2, 13], "float16", "cpu"), + ("q3f16_1", [16, 120], "float16", "cpu"), + ("q4f16_1", [2, 13], "float16", "cpu"), + ("q4f16_1", [16, 128], "float16", "cpu"), + ("q4f32_1", [2, 13], "float32", "cpu"), + ("q4f32_1", [16, 128], "float32", "cpu"), + ], +) +def test_quantize_weight(quant_name: str, shape: List[int], dtype: str, device: str): + config = QUANTIZATION[quant_name] + assert isinstance(config, GroupQuantize) + weight_np = np.random.random(shape).astype(dtype) + output = config.quantize_weight(tvm.nd.array(weight_np, device=tvm.device(device))) + quantized_weight, scale = output[0].numpy(), output[1].numpy() + quantized_weight_ref, scale_ref = quantize_np(config, weight_np) + tvm.testing.assert_allclose(scale, scale_ref, rtol=1e-3, atol=1e-3) + tvm.testing.assert_allclose( + dequantize_np(config, quantized_weight, scale, shape), + dequantize_np(config, quantized_weight_ref, scale_ref, shape), + rtol=1e-2 if quant_name.startswith("q3") else 1e-3, + atol=0.4 if quant_name.startswith("q3") else 0.2, + ) + + +@pytest.mark.parametrize( + "quant_name, shape, dtype", + [ + ("q3f16_1", [2, 13], "float16"), + ("q3f16_1", [16, 120], "float16"), + ("q4f16_1", [2, 13], "float16"), + ("q4f16_1", [16, 128], "float16"), + ("q4f32_1", [2, 13], "float32"), + ("q4f32_1", [16, 128], "float32"), + ], +) +def test_dequantize_weight(quant_name: str, shape: List[int], dtype: str): + class Test(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(shape[1], shape[0], bias=False, dtype=dtype) + + def forward(self, x: nn.Tensor): + return self.linear(x) + + config = QUANTIZATION[quant_name] + assert isinstance(config, GroupQuantize) + num_group = -(shape[1] // -config.group_size) + weight_np = np.random.randint( + np.iinfo(config.storage_dtype).min, + np.iinfo(config.storage_dtype).max, + (shape[0], config.num_storage_per_group * num_group), + ).astype(config.storage_dtype) + scale_np = np.random.random((shape[0], num_group)).astype(config.model_dtype) + mod = config.quantize_model(Test(), QuantizeMapping({}, {}), "") + mod.linear.q_weight.data = weight_np + mod.linear.q_scale.data = scale_np + model = mod.jit(spec={"forward": {"x": nn.spec.Tensor((shape[1], shape[1]), dtype)}}) + out = model["forward"]( + torch.from_numpy(np.diag(np.ones(shape[1]).astype(dtype))) # pylint: disable=no-member + ) + ref = dequantize_np(config, weight_np, scale_np, shape).T + tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize( + "quant_name, shape, dtype", + [ + ("q3f16_1", [16, 128], "float16"), + ("q4f16_1", [16, 128], "float16"), + ("q4f32_1", [16, 128], "float32"), + ], +) +def test_quantize_model(quant_name: str, shape: List[int], dtype: str): + class Test(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(shape[0], shape[1], dtype=dtype) + self.embedding = nn.Embedding(shape[0], shape[1], dtype=dtype) + + def forward(self, x: nn.Tensor): + return self.linear(x) + + 
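+ # Quantizing the module should swap nn.Linear and nn.Embedding for their group-quantized + # counterparts, and the quantization map should record the packed parameter names + # ("q_weight"/"q_scale") together with the weight-quantization function.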
config = QUANTIZATION[quant_name] + assert isinstance(config, GroupQuantize) + quant_map = QuantizeMapping({}, {}) + mod = config.quantize_model(Test(), quant_map, "model") + assert quant_map.param_map["model.linear.weight"] == [ + "model.linear.q_weight", + "model.linear.q_scale", + ] + assert quant_map.map_func["model.linear.weight"] == config.quantize_weight + assert isinstance(mod.linear, GroupQuantizeLinear) + assert quant_map.param_map["model.embedding.weight"] == [ + "model.embedding.q_weight", + "model.embedding.q_scale", + ] + assert quant_map.map_func["model.embedding.weight"] == config.quantize_weight + assert isinstance(mod.embedding, GroupQuantizeEmbedding) + + +if __name__ == "__main__": + test_quantize_weight("q4f16_1", [16, 128], "float16", "llvm") + test_quantize_model("q4f16_1", [16, 128], "float16") + test_dequantize_weight("q4f16_1", [16, 128], "float16") diff --git a/tests/python/serve/benchmark.py b/tests/python/serve/benchmark.py new file mode 100644 index 0000000..26e9d9a --- /dev/null +++ b/tests/python/serve/benchmark.py @@ -0,0 +1,159 @@ +# pylint: disable=import-error,line-too-long,missing-docstring,no-member,too-many-locals +# type: ignore +import argparse +import json +import os +import random +import time +from typing import Any, Callable, List, Tuple + +import numpy as np +from transformers import AutoTokenizer + +from mlc_chat.serve import Engine, GenerationConfig, KVCacheConfig +from mlc_chat.serve.engine import ModelInfo + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument("--model-lib-path", type=str, required=True) + # Download dataset from + # https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + args.add_argument("--dataset", type=str, required=True) + args.add_argument("--device", type=str, default="auto") + args.add_argument("--num-prompts", type=int, default=500) + args.add_argument("--batch-size", type=int, default=80) + args.add_argument("--page-size", type=int, default=16) + args.add_argument("--max-total-seq-length", type=int) + args.add_argument("--seed", type=int, default=0) + + parsed = args.parse_args() + parsed.model = os.path.dirname(parsed.model_lib_path) + assert parsed.batch_size % 16 == 0 + assert parsed.page_size == 16 + return parsed + + +def sample_requests( + dataset_path: str, num_requests: int, model_path: str +) -> Tuple[List[str], List[GenerationConfig]]: + """Sample requests from dataset. + Acknowledgement to the benchmark scripts in the vLLM project. + """ + tokenizer = AutoTokenizer.from_pretrained(model_path) + + with open(dataset_path, encoding="utf-8") as f: + dataset = json.load(f) + + # Filter out the conversations with less than 2 turns. + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + if len(data["conversations"]) >= 2 + ] + # Tokenize the prompts and completions. + prompts = [prompt for prompt, _ in dataset] + prompt_token_ids = tokenizer(prompts).input_ids + completions = [completion for _, completion in dataset] + completion_token_ids = tokenizer(completions).input_ids + tokenized_dataset = [] + for i in range(len(dataset)): + output_len = len(completion_token_ids[i]) + tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) + + # Filter out too long sequences. 
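+ # Keep only prompts with at least 4 prompt and 4 output tokens, at most 1024 prompt tokens, + # and at most 2048 prompt+output tokens, mirroring the vLLM benchmark script.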
+ filtered_dataset: List[Tuple[str, int, int]] = [] + for prompt, prompt_token_ids, output_len in tokenized_dataset: + prompt_len = len(prompt_token_ids) + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. + continue + filtered_dataset.append((prompt, prompt_len, output_len)) + + # Sample the requests. + sampled_requests = random.sample(filtered_dataset, num_requests) + + # Construct generation config. + prompts = [prompt for prompt, _, _ in sampled_requests] + generation_config_list = [ + GenerationConfig(temperature=1.0, top_p=1.0, max_tokens=output_len) + for _, _, output_len in sampled_requests + ] + return prompts, generation_config_list + + +def time_evaluator(func: Callable, args: List[Any], num_runs: int = 3): + times = [] + for _ in range(num_runs): + start = time.perf_counter() + func(*args) + end = time.perf_counter() + times.append(end - start) + + return np.array(times) + + +def benchmark(args: argparse.Namespace): + random.seed(args.seed) + + # Initialize model loading info and KV cache config + model = ModelInfo(args.model, args.model_lib_path, args.device) + kv_cache_config = KVCacheConfig( + page_size=args.page_size, + max_num_sequence=args.batch_size, + max_total_sequence_length=args.max_total_seq_length, + ) + + # Create engine + engine = Engine(model, kv_cache_config) + # Sample prompts from dataset + prompts, generation_config = sample_requests(args.dataset, args.num_prompts, args.model) + # Engine statistics + num_runs = 1 + single_token_prefill_latency = [] + single_token_decode_latency = [] + engine_total_prefill_time = [] + engine_total_decode_time = [] + total_prefill_tokens = [] + total_decode_tokens = [] + + def engine_generate(): + engine.reset() + engine.generate(prompts, generation_config) + engine_stats = engine.stats() + single_token_prefill_latency.append(engine_stats["single_token_prefill_latency"]) + single_token_decode_latency.append(engine_stats["single_token_decode_latency"]) + engine_total_prefill_time.append(engine_stats["engine_total_prefill_time"]) + engine_total_decode_time.append(engine_stats["engine_total_decode_time"]) + total_prefill_tokens.append(engine_stats["total_prefill_tokens"]) + total_decode_tokens.append(engine_stats["total_decode_tokens"]) + + e2e_latency = time_evaluator(engine_generate, args=[], num_runs=num_runs) + single_token_prefill_latency = np.array(single_token_prefill_latency) + single_token_decode_latency = np.array(single_token_decode_latency) + engine_total_prefill_time = np.array(engine_total_prefill_time) + engine_total_decode_time = np.array(engine_total_decode_time) + total_prefill_tokens = np.array(total_prefill_tokens) + total_decode_tokens = np.array(total_decode_tokens) + prefill_throughput = total_prefill_tokens / engine_total_prefill_time + decode_throughput = total_decode_tokens / engine_total_decode_time + overall_throughput = (total_prefill_tokens + total_decode_tokens) / e2e_latency + + print(args) + print(f"Average end-to-end latency: {e2e_latency.mean():.4f} seconds for the entire batch") + print(f"Single token prefill latency: {single_token_prefill_latency.mean() * 1e3:.4f} ms/tok") + print(f"Single token decode latency: {single_token_decode_latency.mean() * 1e3:.4f} ms/tok") + print(f"Engine prefill time: {engine_total_prefill_time.mean():.4f} s") + print(f"Engine decode time: {engine_total_decode_time.mean():.4f} s") + print(f"Request throughput: {args.num_prompts / 
e2e_latency.mean():.4f} req/s") + print(f"Prefill token throughput: {prefill_throughput.mean():.4f} tok/s") + print(f"Decode token throughput: {decode_throughput.mean():.4f} tok/s") + print(f"Overall token throughput: {overall_throughput.mean():.4f} tok/s") + + +if __name__ == "__main__": + ARGS = _parse_args() + benchmark(ARGS) diff --git a/tests/python/serve/evaluate_engine.py b/tests/python/serve/evaluate_engine.py new file mode 100644 index 0000000..9fd21f6 --- /dev/null +++ b/tests/python/serve/evaluate_engine.py @@ -0,0 +1,76 @@ +# pylint: disable=line-too-long,missing-docstring +import argparse +import os +import random +from typing import List, Tuple + +from mlc_chat.serve import Engine, GenerationConfig, KVCacheConfig +from mlc_chat.serve.engine import ModelInfo + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument("--model-lib-path", type=str) + args.add_argument("--device", type=str, default="auto") + args.add_argument("--batch-size", type=int, default=80) + args.add_argument("--page-size", type=int, default=16) + args.add_argument("--max-total-seq-length", type=int) + args.add_argument("--seed", type=int, default=0) + + parsed = args.parse_args() + parsed.model = os.path.dirname(parsed.model_lib_path) + assert parsed.batch_size % 16 == 0 + assert parsed.page_size == 16 + assert parsed.max_total_seq_length >= 2048 + return parsed + + +def generate_requests( + num_requests: int, input_length: int, output_length: int +) -> Tuple[List[List[int]], List[GenerationConfig]]: + prompt_ids = [] + for _ in range(num_requests): + token_ids = [] + for _ in range(input_length): + token_ids.append(random.randint(0, 30000)) + prompt_ids.append(token_ids) + generation_config_list = [ + GenerationConfig(temperature=1.0, top_p=1.0, max_tokens=output_length) + ] * num_requests + return prompt_ids, generation_config_list + + +def benchmark(args: argparse.Namespace): + random.seed(args.seed) + + # Initialize model loading info and KV cache config + model = ModelInfo(args.model, args.model_lib_path, args.device) + kv_cache_config = KVCacheConfig( + page_size=args.page_size, + max_num_sequence=args.batch_size, + max_total_sequence_length=args.max_total_seq_length, + ) + + # Create engine + engine = Engine(model, kv_cache_config) + + print(args) + for num_requests in [1, 2, 4, 8, 16, 32, 64]: + if num_requests > args.batch_size: + continue + for input_length in [64, 128, 256, 512, 1024]: + if num_requests * input_length >= 16384: + continue + for output_length in [4]: + print(f"nreq={num_requests}\t" f"in={input_length}\t" f"out={output_length}") + prompt_ids, generation_config = generate_requests( + num_requests, input_length, output_length + ) + engine.reset() + engine.generate(prompt_ids, generation_config) + print() + + +if __name__ == "__main__": + ARGS = _parse_args() + benchmark(ARGS) diff --git a/tests/python/serve/json.ebnf b/tests/python/serve/json.ebnf new file mode 100644 index 0000000..fc3fb22 --- /dev/null +++ b/tests/python/serve/json.ebnf @@ -0,0 +1,22 @@ +# Adopted from https://www.crockford.com/mckeeman.html +main ::= element +value ::= object | array | string | number | "true" | "false" | "null" +object ::= "{" ws "}" | "{" members "}" +members ::= member | member "," members +member ::= ws string ws ":" element +array ::= "[" ws "]" | "[" elements "]" +elements ::= element | element "," elements +element ::= ws value ws +string ::= "\"" characters "\"" +characters ::= "" | character characters +character ::= [^"\\] | "\\" escape +escape ::= "\"" | "\\" | "/" 
| "b" | "f" | "n" | "r" | "t" | "u" hex hex hex hex +hex ::= [A-Fa-f0-9] +number ::= integer fraction exponent +integer ::= digit | onenine digits | "-" digit | "-" onenine digits +digits ::= digit | digit digits +digit ::= [0-9] +onenine ::= [1-9] +fraction ::= "" | "." digits +exponent ::= "" | ("e" | "E") ("" | "+" | "-") digits +ws ::= "" | "\u0020" ws | "\u000A" ws | "\u000D" ws | "\u0009" ws diff --git a/tests/python/serve/server/conftest.py b/tests/python/serve/server/conftest.py new file mode 100644 index 0000000..004b148 --- /dev/null +++ b/tests/python/serve/server/conftest.py @@ -0,0 +1,35 @@ +# pylint: disable=missing-module-docstring,missing-function-docstring +import os +from typing import Tuple + +import pytest + +from mlc_chat.serve import PopenServer + + +@pytest.fixture(scope="session") +def served_model() -> Tuple[str, str]: + model_lib_path = os.environ.get("MLC_SERVE_MODEL_LIB") + if model_lib_path is None: + raise ValueError( + 'Environment variable "MLC_SERVE_MODEL_LIB" not found. ' + "Please set it to model lib compiled by MLC LLM " + "(e.g., `dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so`)." + ) + model = os.path.dirname(model_lib_path) + return model, model_lib_path + + +@pytest.fixture(scope="session") +def launch_server(served_model): # pylint: disable=redefined-outer-name + """A pytest session-level fixture which launches the server in a subprocess.""" + server = PopenServer( + model=served_model[0], + model_lib_path=served_model[1], + enable_tracing=True, + ) + server.start() + yield + + # Fixture teardown code. + server.terminate() diff --git a/tests/python/serve/server/test_server.py b/tests/python/serve/server/test_server.py new file mode 100644 index 0000000..a30b744 --- /dev/null +++ b/tests/python/serve/server/test_server.py @@ -0,0 +1,1036 @@ +"""Server tests in MLC LLM. +Before running any test, we use pytest fixtures to launch a +test-session-wide server in a subprocess, and then execute the tests. + +The recommended way to run the tests is to use the following command: + MLC_SERVE_MODEL_LIB="YOUR_MODEL_LIB" pytest -vv tests/python/serve/server/test_server.py + +Here "YOUR_MODEL_LIB" is a compiled model library like +`dist/Llama-2-7b-chat-hf-q4f16_1/Llama-2-7b-chat-hf-q4f16_1-cuda.so`, +as long as the model is built with batching and embedding separation enabled. + +To directly run the Python file (a.k.a., not using pytest), you need to +launch the server in ahead before running this file. 
This can be done in +two steps: +- start a new shell session, run + python -m mlc_chat.serve.server --model "YOUR_MODEL_LIB" +- start another shell session, run this file + MLC_SERVE_MODEL_LIB="YOUR_MODEL_LIB" python tests/python/serve/server/test_server.py +""" + +# pylint: disable=missing-function-docstring,too-many-arguments,too-many-locals,too-many-branches +import json +import os +from http import HTTPStatus +from typing import Dict, List, Optional, Tuple + +import pytest +import requests +from openai import OpenAI + +OPENAI_BASE_URL = "http://127.0.0.1:8000/v1" +OPENAI_V1_MODELS_URL = "http://127.0.0.1:8000/v1/models" +OPENAI_V1_COMPLETION_URL = "http://127.0.0.1:8000/v1/completions" +OPENAI_V1_CHAT_COMPLETION_URL = "http://127.0.0.1:8000/v1/chat/completions" +DEBUG_DUMP_EVENT_TRACE_URL = "http://127.0.0.1:8000/debug/dump_event_trace" + + +def check_openai_nonstream_response( + response: Dict, + *, + is_chat_completion: bool, + model: str, + object_str: str, + num_choices: int, + finish_reason: str, + completion_tokens: Optional[int] = None, + echo_prompt: Optional[str] = None, + suffix: Optional[str] = None, + stop: Optional[List[str]] = None, + require_substr: Optional[List[str]] = None, +): + assert response["model"] == model + assert response["object"] == object_str + + choices = response["choices"] + assert isinstance(choices, list) + assert len(choices) == num_choices + for idx, choice in enumerate(choices): + assert choice["index"] == idx + assert choice["finish_reason"] == finish_reason + + text: str + if not is_chat_completion: + assert isinstance(choice["text"], str) + text = choice["text"] + if echo_prompt is not None: + assert text + if suffix is not None: + assert text + else: + message = choice["message"] + assert message["role"] == "assistant" + assert isinstance(message["content"], str) + text = message["content"] + + if stop is not None: + for stop_str in stop: + assert stop_str not in text + if require_substr is not None: + for substr in require_substr: + assert substr in text + + usage = response["usage"] + assert isinstance(usage, dict) + assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] + assert usage["prompt_tokens"] > 0 + if completion_tokens is not None: + assert usage["completion_tokens"] == completion_tokens + + +def check_openai_stream_response( + responses: List[Dict], + *, + is_chat_completion: bool, + model: str, + object_str: str, + num_choices: int, + finish_reason: str, + completion_tokens: Optional[int] = None, + echo_prompt: Optional[str] = None, + suffix: Optional[str] = None, + stop: Optional[List[str]] = None, + require_substr: Optional[List[str]] = None, +): + assert len(responses) > 0 + + finished = [False for _ in range(num_choices)] + outputs = ["" for _ in range(num_choices)] + for response in responses: + assert response["model"] == model + assert response["object"] == object_str + + choices = response["choices"] + assert isinstance(choices, list) + assert len(choices) == num_choices + for idx, choice in enumerate(choices): + assert choice["index"] == idx + + if not is_chat_completion: + assert isinstance(choice["text"], str) + outputs[idx] += choice["text"] + else: + delta = choice["delta"] + assert delta["role"] == "assistant" + assert isinstance(delta["content"], str) + outputs[idx] += delta["content"] + + if finished[idx]: + assert choice["finish_reason"] == finish_reason + elif choice["finish_reason"] is not None: + assert choice["finish_reason"] == finish_reason + finished[idx] = True + + if not 
is_chat_completion: + usage = response["usage"] + assert isinstance(usage, dict) + assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] + assert usage["prompt_tokens"] > 0 + if completion_tokens is not None: + assert usage["completion_tokens"] <= completion_tokens + + if not is_chat_completion: + if completion_tokens is not None: + assert responses[-1]["usage"]["completion_tokens"] == completion_tokens + + for output in outputs: + if echo_prompt is not None: + assert output.startswith(echo_prompt) + if suffix is not None: + assert output.endswith(suffix) + if stop is not None: + for stop_str in stop: + assert stop_str not in output + if require_substr is not None: + for substr in require_substr: + assert substr in output + + +def expect_error(response_str: str, msg_prefix: Optional[str] = None): + response = json.loads(response_str) + assert response["object"] == "error" + assert isinstance(response["message"], str) + if msg_prefix is not None: + assert response["message"].startswith(msg_prefix) + + +def test_openai_v1_models( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + response = requests.get(OPENAI_V1_MODELS_URL, timeout=60).json() + assert response["object"] == "list" + models = response["data"] + assert isinstance(models, list) + assert len(models) == 1 + + model_card = models[0] + assert isinstance(model_card, dict) + assert model_card["id"] == served_model[0], f"{model_card['id']} {served_model[0]}" + assert model_card["object"] == "model" + assert model_card["owned_by"] == "MLC-LLM" + + +@pytest.mark.parametrize("stream", [False, True]) +def test_openai_v1_completions( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + prompt = "What is the meaning of life?" + max_tokens = 256 + payload = { + "model": served_model[0], + "prompt": prompt, + "max_tokens": max_tokens, + "stream": stream, + } + + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + if not stream: + check_openai_nonstream_response( + response.json(), + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reason="length", + completion_tokens=max_tokens, + ) + else: + responses = [] + for chunk in response.iter_lines(chunk_size=512): + if not chunk or chunk == b"data: [DONE]": + continue + responses.append(json.loads(chunk.decode("utf-8")[6:])) + check_openai_stream_response( + responses, + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reason="length", + completion_tokens=max_tokens, + ) + + +@pytest.mark.parametrize("stream", [False, True]) +def test_openai_v1_completions_openai_package( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + client = OpenAI(base_url=OPENAI_BASE_URL, api_key="None") + prompt = "What is the meaning of life?" 
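+ # The rest mirrors the plain-HTTP completion test above, but goes through the official + # `openai` Python client instead of `requests`.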
+ max_tokens = 256 + response = client.completions.create( + model=served_model[0], + prompt=prompt, + max_tokens=max_tokens, + stream=stream, + ) + if not stream: + check_openai_nonstream_response( + response.model_dump(), + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reason="length", + completion_tokens=max_tokens, + ) + else: + responses = [] + for chunk in response: # pylint: disable=not-an-iterable + responses.append(chunk.model_dump()) + check_openai_stream_response( + responses, + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reason="length", + completion_tokens=max_tokens, + ) + + +def test_openai_v1_completions_invalid_requested_model( + launch_server, # pylint: disable=unused-argument +): + # `launch_server` is a pytest fixture defined in conftest.py. + + model = "unserved_model" + payload = { + "model": model, + "prompt": "What is the meaning of life?", + "max_tokens": 10, + } + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + expect_error( + response_str=response.json(), msg_prefix=f'The requested model "{model}" is not served.' + ) + + +@pytest.mark.parametrize("stream", [False, True]) +def test_openai_v1_completions_echo( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + prompt = "What is the meaning of life?" + max_tokens = 256 + payload = { + "model": served_model[0], + "prompt": prompt, + "max_tokens": max_tokens, + "echo": True, + "stream": stream, + } + + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + if not stream: + check_openai_nonstream_response( + response.json(), + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reason="length", + completion_tokens=max_tokens, + echo_prompt=prompt, + ) + else: + responses = [] + for chunk in response.iter_lines(chunk_size=512): + if not chunk or chunk == b"data: [DONE]": + continue + responses.append(json.loads(chunk.decode("utf-8")[6:])) + check_openai_stream_response( + responses, + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reason="length", + completion_tokens=max_tokens, + echo_prompt=prompt, + ) + + +@pytest.mark.parametrize("stream", [False, True]) +def test_openai_v1_completions_suffix( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + prompt = "What is the meaning of life?" + suffix = "Hello, world!" 
+ max_tokens = 256 + payload = { + "model": served_model[0], + "prompt": prompt, + "max_tokens": max_tokens, + "suffix": suffix, + "stream": stream, + } + + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + if not stream: + check_openai_nonstream_response( + response.json(), + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reason="length", + completion_tokens=max_tokens, + suffix=suffix, + ) + else: + responses = [] + for chunk in response.iter_lines(chunk_size=512): + if not chunk or chunk == b"data: [DONE]": + continue + responses.append(json.loads(chunk.decode("utf-8")[6:])) + check_openai_stream_response( + responses, + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reason="length", + completion_tokens=max_tokens, + suffix=suffix, + ) + + +@pytest.mark.parametrize("stream", [False, True]) +def test_openai_v1_completions_stop_str( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + # Choose "in" as the stop string since it is very unlikely that + # "in" does not appear in the generated output. + prompt = "What is the meaning of life?" + stop = ["in"] + max_tokens = 256 + payload = { + "model": served_model[0], + "prompt": prompt, + "max_tokens": max_tokens, + "stop": stop, + "stream": stream, + } + + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + if not stream: + check_openai_nonstream_response( + response.json(), + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reason="stop", + stop=stop, + ) + else: + responses = [] + for chunk in response.iter_lines(chunk_size=512): + if not chunk or chunk == b"data: [DONE]": + continue + responses.append(json.loads(chunk.decode("utf-8")[6:])) + check_openai_stream_response( + responses, + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reason="stop", + stop=stop, + ) + + +@pytest.mark.parametrize("stream", [False, True]) +def test_openai_v1_completions_temperature( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + prompt = "What's the meaning of life?" 
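+ # With temperature 0.0 the request asks for greedy sampling; the checks below only + # validate the response structure, not the sampled content.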
+ max_tokens = 128 + payload = { + "model": served_model[0], + "prompt": prompt, + "max_tokens": max_tokens, + "stream": stream, + "temperature": 0.0, + } + + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + if not stream: + check_openai_nonstream_response( + response.json(), + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reason="length", + ) + else: + responses = [] + for chunk in response.iter_lines(chunk_size=512): + if not chunk or chunk == b"data: [DONE]": + continue + responses.append(json.loads(chunk.decode("utf-8")[6:])) + check_openai_stream_response( + responses, + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reason="length", + ) + + +@pytest.mark.parametrize("stream", [False, True]) +def test_openai_v1_completions_logit_bias( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + # NOTE: This test only tests that the system does not break on logit bias. + # The test does not promise the correctness of logit bias handling. + + prompt = "What's the meaning of life?" + max_tokens = 128 + payload = { + "model": served_model[0], + "prompt": prompt, + "max_tokens": max_tokens, + "stream": stream, + "logit_bias": {338: -100}, # 338 is " is" in Llama tokenizer. + } + + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + if not stream: + check_openai_nonstream_response( + response.json(), + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reason="length", + ) + else: + responses = [] + for chunk in response.iter_lines(chunk_size=512): + if not chunk or chunk == b"data: [DONE]": + continue + responses.append(json.loads(chunk.decode("utf-8")[6:])) + check_openai_stream_response( + responses, + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reason="length", + ) + + +@pytest.mark.parametrize("stream", [False, True]) +def test_openai_v1_completions_presence_frequency_penalty( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + prompt = "What's the meaning of life?" + max_tokens = 128 + payload = { + "model": served_model[0], + "prompt": prompt, + "max_tokens": max_tokens, + "stream": stream, + "frequency_penalty": 2.0, + "presence_penalty": 2.0, + } + + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + if not stream: + check_openai_nonstream_response( + response.json(), + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reason="length", + ) + else: + responses = [] + for chunk in response.iter_lines(chunk_size=512): + if not chunk or chunk == b"data: [DONE]": + continue + responses.append(json.loads(chunk.decode("utf-8")[6:])) + check_openai_stream_response( + responses, + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reason="length", + ) + + +def test_openai_v1_completions_seed( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. 
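+ # Send the same request twice with a fixed seed; the two completions are expected + # to produce exactly the same text.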
+ + prompt = "What's the meaning of life?" + max_tokens = 128 + payload = { + "model": served_model[0], + "prompt": prompt, + "max_tokens": max_tokens, + "stream": False, + "seed": 233, + } + + response1 = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + response2 = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + for response in [response1, response2]: + check_openai_nonstream_response( + response.json(), + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reason="length", + ) + + text1 = response1.json()["choices"][0]["text"] + text2 = response2.json()["choices"][0]["text"] + assert text1 == text2 + + +@pytest.mark.parametrize("stream", [False, True]) +def test_openai_v1_completions_prompt_overlong( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + num_tokens = 17000 + prompt = [128] * num_tokens + payload = { + "model": served_model[0], + "prompt": prompt, + "max_tokens": 256, + "stream": stream, + } + + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + error_msg_prefix = ( + f"Request prompt has {num_tokens} tokens in total, larger than the model capacity" + ) + if not stream: + expect_error(response.json(), msg_prefix=error_msg_prefix) + else: + num_chunks = 0 + for chunk in response.iter_lines(chunk_size=512): + if not chunk: + continue + num_chunks += 1 + expect_error(json.loads(chunk.decode("utf-8")), msg_prefix=error_msg_prefix) + assert num_chunks == 1 + + +@pytest.mark.parametrize("stream", [False, True]) +def test_openai_v1_completions_invalid_logprobs( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + payload = { + "model": served_model[0], + "prompt": "What is the meaning of life?", + "max_tokens": 256, + "stream": stream, + "logprobs": False, + "top_logprobs": 4, + } + + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY + assert response.json()["detail"][0]["msg"].endswith( + '"logprobs" must be True to support "top_logprobs"' + ) + + payload["logprobs"] = True + payload["top_logprobs"] = 6 + + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY + assert response.json()["detail"][0]["msg"].endswith('"top_logprobs" must be in range [0, 5]') + + +def test_openai_v1_completions_unsupported_args( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + # Right now "best_of" is unsupported. + best_of = 2 + payload = { + "model": served_model[0], + "prompt": "What is the meaning of life?", + "max_tokens": 256, + "best_of": best_of, + } + + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + error_msg_prefix = 'Request fields "best_of" are not supported right now.' 
+ expect_error(response.json(), msg_prefix=error_msg_prefix) + + +def test_openai_v1_completions_request_cancellation( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + # Use a large max_tokens and small timeout to force timeouts. + payload = { + "model": served_model[0], + "prompt": "What is the meaning of life?", + "max_tokens": 2048, + "stream": False, + } + with pytest.raises(requests.exceptions.Timeout): + requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=1) + + # The server should still be alive after a request cancelled. + # We query `v1/models` to validate the server liveness. + response = requests.get(OPENAI_V1_MODELS_URL, timeout=60).json() + + assert response["object"] == "list" + models = response["data"] + assert isinstance(models, list) + assert len(models) == 1 + + model_card = models[0] + assert isinstance(model_card, dict) + assert model_card["id"] == served_model[0] + assert model_card["object"] == "model" + assert model_card["owned_by"] == "MLC-LLM" + + +CHAT_COMPLETION_MESSAGES = [ + # messages #0 + [{"role": "user", "content": "Hello! Our project is MLC LLM."}], + # messages #1 + [ + {"role": "user", "content": "Hello! Our project is MLC LLM."}, + { + "role": "assistant", + "content": "Hello! It's great to hear about your project, MLC LLM.", + }, + {"role": "user", "content": "What is the name of our project?"}, + ], + # messages #2 + [ + { + "role": "system", + "content": "You are a helpful, respectful and honest assistant. " + "You always ends your response with an emoji.", + }, + {"role": "user", "content": "Hello! Our project is MLC LLM."}, + ], +] + + +@pytest.mark.parametrize("stream", [False, True]) +@pytest.mark.parametrize("messages", CHAT_COMPLETION_MESSAGES) +def test_openai_v1_chat_completions( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, + messages: List[Dict[str, str]], +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + payload = { + "model": served_model[0], + "messages": messages, + "stream": stream, + } + + response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60) + if not stream: + check_openai_nonstream_response( + response.json(), + is_chat_completion=True, + model=served_model[0], + object_str="chat.completion", + num_choices=1, + finish_reason="stop", + ) + else: + responses = [] + for chunk in response.iter_lines(chunk_size=512): + if not chunk or chunk == b"data: [DONE]": + continue + responses.append(json.loads(chunk.decode("utf-8")[6:])) + check_openai_stream_response( + responses, + is_chat_completion=True, + model=served_model[0], + object_str="chat.completion.chunk", + num_choices=1, + finish_reason="stop", + ) + + +@pytest.mark.parametrize("stream", [False, True]) +@pytest.mark.parametrize("messages", CHAT_COMPLETION_MESSAGES) +def test_openai_v1_chat_completions_openai_package( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, + messages: List[Dict[str, str]], +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. 
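+ # Same chat completion checks as above, but issued through the official `openai` Python + # client, additionally requesting logprobs for each choice.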
+ + client = OpenAI(base_url=OPENAI_BASE_URL, api_key="None") + response = client.chat.completions.create( + model=served_model[0], + messages=messages, + stream=stream, + logprobs=True, + top_logprobs=2, + ) + if not stream: + check_openai_nonstream_response( + response.model_dump(), + is_chat_completion=True, + model=served_model[0], + object_str="chat.completion", + num_choices=1, + finish_reason="stop", + ) + else: + responses = [] + for chunk in response: # pylint: disable=not-an-iterable + responses.append(chunk.model_dump()) + check_openai_stream_response( + responses, + is_chat_completion=True, + model=served_model[0], + object_str="chat.completion.chunk", + num_choices=1, + finish_reason="stop", + ) + + +@pytest.mark.parametrize("stream", [False, True]) +def test_openai_v1_chat_completions_max_tokens( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + messages = [{"role": "user", "content": "Write a novel with at least 500 words."}] + max_tokens = 16 + payload = { + "model": served_model[0], + "messages": messages, + "stream": stream, + "max_tokens": max_tokens, + } + + response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60) + if not stream: + check_openai_nonstream_response( + response.json(), + is_chat_completion=True, + model=served_model[0], + object_str="chat.completion", + num_choices=1, + finish_reason="length", + completion_tokens=max_tokens, + ) + else: + responses = [] + for chunk in response.iter_lines(chunk_size=512): + if not chunk or chunk == b"data: [DONE]": + continue + responses.append(json.loads(chunk.decode("utf-8")[6:])) + check_openai_stream_response( + responses, + is_chat_completion=True, + model=served_model[0], + object_str="chat.completion.chunk", + num_choices=1, + finish_reason="length", + completion_tokens=max_tokens, + ) + + +@pytest.mark.parametrize("stream", [False, True]) +def test_openai_v1_chat_completions_ignore_eos( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + messages = [{"role": "user", "content": "Write a sentence with less than 20 words."}] + max_tokens = 128 + payload = { + "model": served_model[0], + "messages": messages, + "stream": stream, + "max_tokens": max_tokens, + "ignore_eos": True, + } + + response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60) + if not stream: + check_openai_nonstream_response( + response.json(), + is_chat_completion=True, + model=served_model[0], + object_str="chat.completion", + num_choices=1, + finish_reason="length", + completion_tokens=max_tokens, + ) + else: + responses = [] + for chunk in response.iter_lines(chunk_size=512): + if not chunk or chunk == b"data: [DONE]": + continue + responses.append(json.loads(chunk.decode("utf-8")[6:])) + check_openai_stream_response( + responses, + is_chat_completion=True, + model=served_model[0], + object_str="chat.completion.chunk", + num_choices=1, + finish_reason="length", + completion_tokens=max_tokens, + ) + + +@pytest.mark.parametrize("stream", [False, True]) +def test_openai_v1_chat_completions_system_prompt_wrong_pos( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. 
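+ # A system message placed after a user message is invalid; the server is expected to + # reject the request with a single error response in both streaming and non-streaming modes.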
+ + messages = [ + {"role": "user", "content": "Hello! Our project is MLC LLM."}, + { + "role": "system", + "content": "You are a helpful, respectful and honest assistant. " + "You always ends your response with an emoji.", + }, + ] + payload = { + "model": served_model[0], + "messages": messages, + "stream": stream, + } + + response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60) + error_msg = "System prompt at position 1 in the message list is invalid." + if not stream: + expect_error(response.json(), msg_prefix=error_msg) + else: + num_chunks = 0 + for chunk in response.iter_lines(chunk_size=512): + if not chunk: + continue + num_chunks += 1 + expect_error(json.loads(chunk.decode("utf-8")), msg_prefix=error_msg) + assert num_chunks == 1 + + +def test_debug_dump_event_trace( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + # We only check that the request does not fail. + payload = {"model": served_model[0]} + response = requests.post(DEBUG_DUMP_EVENT_TRACE_URL, json=payload, timeout=60) + assert response.status_code == HTTPStatus.OK + + +if __name__ == "__main__": + model_lib_path = os.environ.get("MLC_SERVE_MODEL_LIB") + if model_lib_path is None: + raise ValueError( + 'Environment variable "MLC_SERVE_MODEL_LIB" not found. ' + "Please set it to model lib compiled by MLC LLM " + "(e.g., `dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so`)." + ) + MODEL = (os.path.dirname(model_lib_path), model_lib_path) + + test_openai_v1_models(MODEL, None) + + test_openai_v1_completions(MODEL, None, stream=False) + test_openai_v1_completions(MODEL, None, stream=True) + test_openai_v1_completions_openai_package(MODEL, None, stream=False) + test_openai_v1_completions_openai_package(MODEL, None, stream=True) + test_openai_v1_completions_invalid_requested_model(None) + test_openai_v1_completions_echo(MODEL, None, stream=False) + test_openai_v1_completions_echo(MODEL, None, stream=True) + test_openai_v1_completions_suffix(MODEL, None, stream=False) + test_openai_v1_completions_suffix(MODEL, None, stream=True) + test_openai_v1_completions_stop_str(MODEL, None, stream=False) + test_openai_v1_completions_stop_str(MODEL, None, stream=True) + test_openai_v1_completions_temperature(MODEL, None, stream=False) + test_openai_v1_completions_temperature(MODEL, None, stream=True) + test_openai_v1_completions_logit_bias(MODEL, None, stream=False) + test_openai_v1_completions_logit_bias(MODEL, None, stream=True) + test_openai_v1_completions_presence_frequency_penalty(MODEL, None, stream=False) + test_openai_v1_completions_presence_frequency_penalty(MODEL, None, stream=True) + test_openai_v1_completions_seed(MODEL, None) + test_openai_v1_completions_prompt_overlong(MODEL, None, stream=False) + test_openai_v1_completions_prompt_overlong(MODEL, None, stream=True) + test_openai_v1_completions_invalid_logprobs(MODEL, None, stream=False) + test_openai_v1_completions_invalid_logprobs(MODEL, None, stream=True) + test_openai_v1_completions_unsupported_args(MODEL, None) + test_openai_v1_completions_request_cancellation(MODEL, None) + + for msg in CHAT_COMPLETION_MESSAGES: + test_openai_v1_chat_completions(MODEL, None, stream=False, messages=msg) + test_openai_v1_chat_completions(MODEL, None, stream=True, messages=msg) + test_openai_v1_chat_completions_openai_package(MODEL, None, stream=False, messages=msg) + 
test_openai_v1_chat_completions_openai_package(MODEL, None, stream=True, messages=msg) + test_openai_v1_chat_completions_max_tokens(MODEL, None, stream=False) + test_openai_v1_chat_completions_max_tokens(MODEL, None, stream=True) + test_openai_v1_chat_completions_ignore_eos(MODEL, None, stream=False) + test_openai_v1_chat_completions_ignore_eos(MODEL, None, stream=True) + test_openai_v1_chat_completions_system_prompt_wrong_pos(MODEL, None, stream=False) + test_openai_v1_chat_completions_system_prompt_wrong_pos(MODEL, None, stream=True) + + test_debug_dump_event_trace(MODEL, None) diff --git a/tests/python/serve/server/test_server_function_call.py b/tests/python/serve/server/test_server_function_call.py new file mode 100644 index 0000000..3fff27b --- /dev/null +++ b/tests/python/serve/server/test_server_function_call.py @@ -0,0 +1,210 @@ +# pylint: disable=line-too-long +""" +Test script for function call in chat completion. To run this script, use the following command: +MLC_SERVE_MODEL_LIB=dist/gorilla-openfunctions-v1-q4f16_1_MLC/gorilla-openfunctions-v1-q4f16_1-cuda.so +MLC_SERVE_MODEL_LIB=${MLC_SERVE_MODEL_LIB} python -m pytest -x tests/python/serve/server/test_server_function_call.py +""" + +# pylint: disable=missing-function-docstring,too-many-arguments,too-many-locals,too-many-branches +import json +import os +from typing import Dict, List, Optional, Tuple + +import pytest +import requests + +OPENAI_V1_CHAT_COMPLETION_URL = "http://127.0.0.1:8000/v1/chat/completions" + + +def check_openai_nonstream_response( + response: Dict, + *, + model: str, + object_str: str, + num_choices: int, + finish_reason: List[str], + completion_tokens: Optional[int] = None, +): + print(response) + assert response["model"] == model + assert response["object"] == object_str + + choices = response["choices"] + assert isinstance(choices, list) + assert len(choices) == num_choices + for idx, choice in enumerate(choices): + assert choice["index"] == idx + assert choice["finish_reason"] in finish_reason + + # text: str + message = choice["message"] + assert message["role"] == "assistant" + if choice["finish_reason"] == "tool_calls": + assert message["content"] is None + assert isinstance(message["tool_calls"], list) + else: + assert message["tool_calls"] is None + assert message["content"] is not None + + usage = response["usage"] + assert isinstance(usage, dict) + assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] + assert usage["prompt_tokens"] > 0 + + if completion_tokens is not None: + assert usage["completion_tokens"] == completion_tokens + + +def check_openai_stream_response( + responses: List[Dict], + *, + model: str, + object_str: str, + num_choices: int, + finish_reason: str, + echo_prompt: Optional[str] = None, + suffix: Optional[str] = None, + stop: Optional[List[str]] = None, + require_substr: Optional[List[str]] = None, +): + assert len(responses) > 0 + + finished = [False for _ in range(num_choices)] + outputs = ["" for _ in range(num_choices)] + for response in responses: + assert response["model"] == model + assert response["object"] == object_str + + choices = response["choices"] + assert isinstance(choices, list) + assert len(choices) == num_choices + for idx, choice in enumerate(choices): + assert choice["index"] == idx + + delta = choice["delta"] + assert delta["role"] == "assistant" + assert isinstance(delta["content"], str) + outputs[idx] += delta["content"] + + if finished[idx]: + assert choice["finish_reason"] == finish_reason + elif 
choice["finish_reason"] is not None: + assert choice["finish_reason"] == finish_reason + finished[idx] = True + + for output in outputs: + if echo_prompt is not None: + assert output.startswith(echo_prompt) + if suffix is not None: + assert output.endswith(suffix) + if stop is not None: + for stop_str in stop: + assert stop_str not in output + if require_substr is not None: + for substr in require_substr: + assert substr in output + + +tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } +] + + +CHAT_COMPLETION_MESSAGES = [ + # messages #0 + [ + { + "role": "user", + "content": "What is the current weather in Pittsburgh, PA?", + } + ], + # messages #1 + [ + { + "role": "user", + "content": "What is the current weather in Pittsburgh, PA and Tokyo, JP?", + } + ], + # messages #2 + [ + { + "role": "user", + "content": "What is the current weather in Pittsburgh, PA in fahrenheit?", + } + ], +] + + +@pytest.mark.parametrize("stream", [False, True]) +@pytest.mark.parametrize("messages", CHAT_COMPLETION_MESSAGES) +def test_openai_v1_chat_completion_function_call( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, + messages: List[Dict[str, str]], +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + payload = { + "model": served_model[0], + "messages": messages, + "stream": stream, + "tools": tools, + } + + response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60) + if not stream: + check_openai_nonstream_response( + response.json(), + model=served_model[0], + object_str="chat.completion", + num_choices=1, + finish_reason=["tool_calls", "error"], + ) + else: + responses = [] + for chunk in response.iter_lines(chunk_size=512): + if not chunk or chunk == b"data: [DONE]": + continue + responses.append(json.loads(chunk.decode("utf-8")[6:])) + check_openai_stream_response( + responses, + model=served_model[0], + object_str="chat.completion.chunk", + num_choices=1, + finish_reason="tool_calls", + ) + + +if __name__ == "__main__": + model_lib_path = os.environ.get("MLC_SERVE_MODEL_LIB") + if model_lib_path is None: + raise ValueError( + 'Environment variable "MLC_SERVE_MODEL_LIB" not found. ' + "Please set it to model lib compiled by MLC LLM " + "(e.g., `./dist/gorilla-openfunctions-v1-q4f16_1_MLC/gorilla-openfunctions-v1-q4f16_1-cuda.so`) " + "which supports function calls." 
+ ) + MODEL = (os.path.dirname(model_lib_path), model_lib_path) + + for msg in CHAT_COMPLETION_MESSAGES: + test_openai_v1_chat_completion_function_call(MODEL, None, stream=False, messages=msg) + test_openai_v1_chat_completion_function_call(MODEL, None, stream=True, messages=msg) diff --git a/tests/python/serve/test_event_trace_recorder.py b/tests/python/serve/test_event_trace_recorder.py new file mode 100644 index 0000000..fb2a5f2 --- /dev/null +++ b/tests/python/serve/test_event_trace_recorder.py @@ -0,0 +1,44 @@ +# pylint: disable=missing-module-docstring,missing-function-docstring +import json + +from mlc_chat.serve.event_trace_recorder import EventTraceRecorder + + +def test_event_trace_recorder(): + trace_recorder = EventTraceRecorder() + request_ids = ["x", "y"] + num_decode = 5 + + for request_id in request_ids: + trace_recorder.add_event(request_id, event="start tokenization") + trace_recorder.add_event(request_id, event="finish tokenization") + trace_recorder.add_event(request_id, event="add request") + trace_recorder.add_event(request_id, event="start embed") + trace_recorder.add_event(request_id, event="finish embed") + trace_recorder.add_event(request_id, event="start prefill") + trace_recorder.add_event(request_id, event="finish prefill") + + for _ in range(num_decode): + for request_id in request_ids: + trace_recorder.add_event(request_id, event="start decode") + trace_recorder.add_event(request_id, event="finish decode") + for request_id in request_ids: + trace_recorder.add_event(request_id, event="start detokenization") + trace_recorder.add_event(request_id, event="finish detokenization") + + events = json.loads(trace_recorder.dump_json()) + decode_count = {} + for event in events: + request_id = event["tid"] + if event["name"].startswith("decode"): + if request_id not in decode_count: + decode_count[request_id] = 1 + else: + decode_count[request_id] += 1 + + for _, decode_cnt in decode_count.items(): + assert decode_cnt == num_decode * 2, decode_cnt + + +if __name__ == "__main__": + test_event_trace_recorder() diff --git a/tests/python/serve/test_grammar_parser.py b/tests/python/serve/test_grammar_parser.py new file mode 100644 index 0000000..dd6cc64 --- /dev/null +++ b/tests/python/serve/test_grammar_parser.py @@ -0,0 +1,260 @@ +# pylint: disable=missing-module-docstring,missing-function-docstring +import os + +import pytest +import tvm.testing +from tvm._ffi.base import TVMError + +from mlc_chat.serve import BNFGrammar + + +def test_bnf_simple(): + before = """main ::= b c +b ::= "b" +c ::= "c" +""" + expected = """main ::= ((b c)) +b ::= (([b])) +c ::= (([c])) +""" + bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + after = bnf_grammar.to_string() + assert after == expected + + +def test_ebnf(): + before = """main ::= b c | b main +b ::= "b"* +c ::= [acep-z]+ +d ::= "d"? +""" + expected = """main ::= ((b c) | (b main)) +b ::= [b]* +c ::= ((c_2)) +d ::= ((d_1)) +c_1 ::= (([acep-z])) +c_2 ::= ((c_1 c_2) | (c_1)) +d_1 ::= ("" | ([d])) +""" + bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + after = bnf_grammar.to_string() + assert after == expected + + +def test_char(): + before = r"""main ::= [a-z] [A-z] "\u0234" "\U00000345\xff" [-A-Z] [--] [^a] rest +rest ::= [a-zA-Z0-9-] [\u0234-\U00000345] [测-试] [\--\]] rest1 +rest1 ::= "\?\"\'测试あc" "👀" "" +""" + expected = r"""main ::= (([a-z] [A-z] ([\u0234]) ([\u0345] [\u00ff]) [\-A-Z] [\-\-] [^a] rest)) +rest ::= (([a-zA-Z0-9\-] [\u0234-\u0345] [\u6d4b-\u8bd5] [\--\]] rest1)) +rest1 ::= ((([\?] 
[\"] [\'] [\u6d4b] [\u8bd5] [\u3042] [c]) ([\U0001f440]) "")) +""" + # Disable unwrap_nesting_rules to expose the result before unwrapping. + bnf_grammar = BNFGrammar.from_ebnf_string(before, False, False) + after = bnf_grammar.to_string() + assert after == expected + + +def test_space(): + before = """ + +main::="a" "b" ("c""d" +"e") | + +"f" | "g" +""" + expected = """main ::= (([a] [b] [c] [d] [e]) | ([f]) | ([g])) +""" + bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + after = bnf_grammar.to_string() + assert after == expected + + +def test_nest(): + before = """main::= "a" ("b" | "c" "d") | (("e" "f")) +""" + expected = """main ::= (([a] main_choice) | ([e] [f])) +main_choice ::= (([b]) | ([c] [d])) +""" + bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + after = bnf_grammar.to_string() + assert after == expected + + +def test_flatten(): + before = """main ::= or_test sequence_test nested_test empty_test +or_test ::= ([a] | "b") | "de" | "" | or_test | [^a-z] +sequence_test ::= [a] "a" ("b" ("c" | "d")) ("d" "e") sequence_test "" +nested_test ::= ("a" ("b" ("c" "d"))) | ("a" | ("b" | "c")) | nested_rest +nested_rest ::= ("a" | ("b" "c" | ("d" | "e" "f"))) | ((("g"))) +empty_test ::= "d" | (("" | "" "") "" | "a" "") | ("" ("" | "")) "" "" +""" + expected = """main ::= ((or_test sequence_test nested_test empty_test)) +or_test ::= ("" | ([a]) | ([b]) | ([d] [e]) | (or_test) | ([^a-z])) +sequence_test ::= (([a] [a] [b] sequence_test_choice [d] [e] sequence_test)) +nested_test ::= (([a] [b] [c] [d]) | ([a]) | ([b]) | ([c]) | (nested_rest)) +nested_rest ::= (([a]) | ([b] [c]) | ([d]) | ([e] [f]) | ([g])) +empty_test ::= ("" | ([d]) | ([a])) +sequence_test_choice ::= (([c]) | ([d])) +""" + bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + after = bnf_grammar.to_string() + assert after == expected + + +def test_json(): + current_file_path = os.path.abspath(__file__) + json_ebnf_path = os.path.join(os.path.dirname(current_file_path), "json.ebnf") + + with open(json_ebnf_path, "r", encoding="utf-8") as file: + before = file.read() + + expected = r"""main ::= ((element)) +value ::= ((object) | (array) | (string) | (number) | ([t] [r] [u] [e]) | ([f] [a] [l] [s] [e]) | ([n] [u] [l] [l])) +object ::= (([{] ws [}]) | ([{] members [}])) +members ::= ((member) | (member [,] members)) +member ::= ((ws string ws [:] element)) +array ::= (([[] ws [\]]) | ([[] elements [\]])) +elements ::= ((element) | (element [,] elements)) +element ::= ((ws value ws)) +string ::= (([\"] characters [\"])) +characters ::= ("" | (character characters)) +character ::= (([^\"\\]) | ([\\] escape)) +escape ::= (([\"]) | ([\\]) | ([/]) | ([b]) | ([f]) | ([n]) | ([r]) | ([t]) | ([u] hex hex hex hex)) +hex ::= (([A-Fa-f0-9])) +number ::= ((integer fraction exponent)) +integer ::= ((digit) | (onenine digits) | ([\-] digit) | ([\-] onenine digits)) +digits ::= ((digit) | (digit digits)) +digit ::= (([0-9])) +onenine ::= (([1-9])) +fraction ::= ("" | ([.] 
digits)) +exponent ::= ("" | (exponent_choice exponent_choice_1 digits)) +ws ::= ("" | ([ ] ws) | ([\n] ws) | ([\r] ws) | ([\t] ws)) +exponent_choice ::= (([e]) | ([E])) +exponent_choice_1 ::= ("" | ([+]) | ([\-])) +""" + + bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + after = bnf_grammar.to_string() + assert after == expected + + +def test_to_string_roundtrip(): + """Checks the printed result can be parsed, and the parsing-printing process is idempotent.""" + + before = r"""main ::= (b c) | (b main) +b ::= b_1 d +c ::= c_1 +d ::= d_1 +b_1 ::= ([b] b_1) | "" +c_1 ::= (c_2 c_1) | c_2 +c_2 ::= [acep-z] +d_1 ::= [d] | "" +""" + bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, True, False) + output_string_1 = bnf_grammar_1.to_string() + bnf_grammar_2 = BNFGrammar.from_ebnf_string(output_string_1, True, False) + output_string_2 = bnf_grammar_2.to_string() + assert output_string_1 == output_string_2 + + +def test_error(): + with pytest.raises( + TVMError, match='TVMError: EBNF parse error at line 1, column 11: Rule "a" is not defined' + ): + BNFGrammar.from_ebnf_string("main ::= a b") + + with pytest.raises( + TVMError, match="TVMError: EBNF parse error at line 1, column 15: Expect element" + ): + BNFGrammar.from_ebnf_string('main ::= "a" |') + + with pytest.raises(TVMError, match='TVMError: EBNF parse error at line 1, column 15: Expect "'): + BNFGrammar.from_ebnf_string('main ::= "a" "') + + with pytest.raises( + TVMError, match="TVMError: EBNF parse error at line 1, column 1: Expect rule name" + ): + BNFGrammar.from_ebnf_string('::= "a"') + + with pytest.raises( + TVMError, + match="TVMError: EBNF parse error at line 1, column 12: Character class should not contain " + "newline", + ): + BNFGrammar.from_ebnf_string("main ::= [a\n]") + + with pytest.raises( + TVMError, match="TVMError: EBNF parse error at line 1, column 11: Invalid escape sequence" + ): + BNFGrammar.from_ebnf_string(r'main ::= "\@"') + + with pytest.raises( + TVMError, match="TVMError: EBNF parse error at line 1, column 11: Invalid escape sequence" + ): + BNFGrammar.from_ebnf_string(r'main ::= "\uFF"') + + with pytest.raises( + TVMError, + match="TVMError: EBNF parse error at line 1, column 14: Invalid character class: " + "lower bound is larger than upper bound", + ): + BNFGrammar.from_ebnf_string(r"main ::= [Z-A]") + + with pytest.raises( + TVMError, match="TVMError: EBNF parse error at line 1, column 6: Expect ::=" + ): + BNFGrammar.from_ebnf_string(r'main := "a"') + + with pytest.raises( + TVMError, + match='TVMError: EBNF parse error at line 2, column 9: Rule "main" is defined multiple ' + "times", + ): + BNFGrammar.from_ebnf_string('main ::= "a"\nmain ::= "b"') + + with pytest.raises( + TVMError, + match='TVMError: EBNF parse error at line 1, column 10: There must be a rule named "main"', + ): + BNFGrammar.from_ebnf_string('a ::= "a"') + + +def test_to_json(): + before = """main ::= b c | b main +b ::= "bcd" +c ::= [a-z] +""" + expected = ( + '{"rule_expr_indptr":[0,3,6,10,13,16,20,24,28,32,36,41,44,48,51],"rule_expr_data"' + ":[3,1,1,3,1,2,4,2,0,1,3,1,1,3,1,0,4,2,3,4,5,2,2,5,0,2,98,98,0,2,99,99,0,2,100,100," + '4,3,7,8,9,5,1,10,0,2,97,122,4,1,12,5,1,13],"rules":[{"body_expr_id":6,"name":"main"},' + '{"body_expr_id":11,"name":"b"},{"body_expr_id":14,"name":"c"}]}' + ) + bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + after = bnf_grammar.to_json(False) + assert after == expected + + +def test_to_json_roundtrip(): + before = r"""main ::= ((b c) | (b main)) +b ::= ((b_1 d)) +c ::= 
((c_1)) +d ::= ((d_1)) +b_1 ::= ("" | ([b] b_1)) +c_1 ::= ((c_2 c_1) | (c_2)) +c_2 ::= (([acep-z])) +d_1 ::= ("" | ([d])) +""" + bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, True, False) + output_json_1 = bnf_grammar_1.to_json(False) + bnf_grammar_2 = BNFGrammar.from_json(output_json_1) + output_json_2 = bnf_grammar_2.to_json(False) + output_str = bnf_grammar_2.to_string() + assert output_json_1 == output_json_2 + assert output_str == before + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/serve/test_grammar_state_matcher.py b/tests/python/serve/test_grammar_state_matcher.py new file mode 100644 index 0000000..cf7229a --- /dev/null +++ b/tests/python/serve/test_grammar_state_matcher.py @@ -0,0 +1,387 @@ +# pylint: disable=missing-module-docstring,missing-function-docstring +# pylint: disable=redefined-outer-name,unbalanced-tuple-unpacking +from typing import List + +import pytest +import tvm +import tvm.testing + +from mlc_chat.serve import BNFGrammar, GrammarStateMatcher +from mlc_chat.tokenizer import Tokenizer + + +@pytest.fixture(scope="function") +def json_grammar(): + return BNFGrammar.get_grammar_of_json() + + +(json_input_accepted,) = tvm.testing.parameters( + ('{"name": "John"}',), + ('{ "name" : "John" } \n',), + ("{}",), + ("[]",), + ('{"name": "Alice", "age": 30, "city": "New York"}',), + ('{"name": "Mike", "hobbies": ["reading", "cycling", "hiking"]}',), + ('{"name": "Emma", "address": {"street": "Maple Street", "city": "Boston"}}',), + ('[{"name": "David"}, {"name": "Sophia"}]',), + ( + '{"name": "William", "age": null, "married": true, "children": ["Liam", "Olivia"],' + ' "hasPets": false}', + ), + ( + '{"name": "Olivia", "contact": {"email": "olivia@example.com", "address": ' + '{"city": "Chicago", "zipcode": "60601"}}}', + ), + ( + '{"name": "Liam", "skills": ["Java", "Python"], "experience": ' + '[{"company": "CompanyA", "years": 5}, {"company": "CompanyB", "years": 3}]}', + ), + ( + '{"person": {"name": "Ethan", "age": 40}, "education": {"degree": "Masters", ' + '"university": "XYZ University"}, "work": [{"company": "ABC Corp", "position": ' + '"Manager"}, {"company": "DEF Corp", "position": "Senior Manager"}]}', + ), + ( + '{"name": "Charlotte", "details": {"personal": {"age": 35, "hobbies": ["gardening", ' + '"painting"]}, "professional": {"occupation": "Engineer", "skills": ' + '["CAD", "Project Management"], "projects": [{"name": "Project A", ' + '"status": "Completed"}, {"name": "Project B", "status": "In Progress"}]}}}', + ), +) + + +def test_json_accept(json_grammar: BNFGrammar, json_input_accepted: str): + assert GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_accepted) + + +# test_json_accept(json_grammar(), '{"name": "John"}') +# exit() + +(json_input_refused,) = tvm.testing.parameters( + (r'{ name: "John" }',), + (r'{ "name": "John", "age": 30, }',), # x + (r'{ "name": "John", "address": { "street": "123 Main St", "city": "New York" }',), + (r'{ "name": "John", "age": 30, "hobbies": ["reading", "traveling",], }',), # x + (r'{ "name": "John", "age": 30.5.7 }',), + (r'{ "name": "John, "age": 30, "hobbies": ["reading", "traveling"] }',), + ( + r'{ "name": "John", "age": 30, "hobbies": ["reading", { "type": "outdoor", "list": ' + r'["hiking", "swimming",]}] }', # + ), + (r'{ "name": "John", "age": 30, "status": "\P\J" }',), + ( + r'{ "name": "John", "age": 30, "hobbies": ["reading", "traveling"], "address": ' + r'{ "street": "123 Main St", "city": "New York", "coordinates": { "latitude": 40.7128, ' + 
r'"longitude": -74.0060 }}}, "work": { "company": "Acme", "position": "developer" }}', + ), +) + + +def test_json_refuse(json_grammar: BNFGrammar, json_input_refused): + assert not GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_refused) + + +(json_input_pressure,) = tvm.testing.parameters( + # Extra long string: 1k chars + ( + '["Lorem ipsum dolor sit amet, consectetur adipiscing elit. Integer nec odio. Praesent ' + "libero. Sed cursus ante dapibus diam. Sed nisi. Nulla quis sem at nibh elementum " + "imperdiet. Duis sagittis ipsum. Praesent mauris. Fusce nec tellus sed augue semper " + "porta. Mauris massa. Vestibulum lacinia arcu eget nulla. Class aptent taciti sociosqu " + "ad litora torquent per conubia nostra, per inceptos himenaeos. Curabitur sodales ligula " + "in libero. Sed dignissim lacinia nunc. Curabitur tortor. Pellentesque nibh. Aenean quam. " + "In scelerisque sem at dolor. Maecenas mattis. Sed convallis tristique sem. Proin ut " + "ligula vel nunc egestas porttitor. Morbi lectus risus, iaculis vel, suscipit quis, " + "luctus non, massa. Fusce ac turpis quis ligula lacinia aliquet. Mauris ipsum. Nulla " + "metus metus, ullamcorper vel, tincidunt sed, euismod in, nibh. Quisque volutpat " + "condimentum velit. Class aptent taciti sociosqu ad litora torquent per conubia nostra, " + "per inceptos himenaeos. Nam nec ante. Sed lacinia, urna non tincidunt mattis, tortor " + "neque adipiscing diam, a cursus ipsum ante quis turpis. Nulla facilisi. Ut fringilla. " + "Suspendisse potenti. Nunc feugiat mi a tellus consequat imperdiet. Vestibulum sapien. " + "Proin quam. Etiam ultrices. Suspendisse in justo eu magna luctus suscipit. Sed lectus. " + "Integer euismod lacus luctus magna. Quisque cursus, metus vitae pharetra auctor, sem " + 'massa mattis sem, at interdum magna augue eget diam."]', + ), + # long and complex json: 3k chars + ( + r"""{ + "web-app": { + "servlet": [ + { + "servlet-name": "cofaxCDS", + "servlet-class": "org.cofax.cds.CDSServlet", + "init-param": { + "configGlossary:installationAt": "Philadelphia, PA", + "configGlossary:adminEmail": "ksm@pobox.com", + "configGlossary:poweredBy": "Cofax", + "configGlossary:poweredByIcon": "/images/cofax.gif", + "configGlossary:staticPath": "/content/static", + "templateProcessorClass": "org.cofax.WysiwygTemplate", + "templateLoaderClass": "org.cofax.FilesTemplateLoader", + "templatePath": "templates", + "templateOverridePath": "", + "defaultListTemplate": "listTemplate.htm", + "defaultFileTemplate": "articleTemplate.htm", + "useJSP": false, + "jspListTemplate": "listTemplate.jsp", + "jspFileTemplate": "articleTemplate.jsp", + "cachePackageTagsTrack": 200, + "cachePackageTagsStore": 200, + "cachePackageTagsRefresh": 60, + "cacheTemplatesTrack": 100, + "cacheTemplatesStore": 50, + "cacheTemplatesRefresh": 15, + "cachePagesTrack": 200, + "cachePagesStore": 100, + "cachePagesRefresh": 10, + "cachePagesDirtyRead": 10, + "searchEngineListTemplate": "forSearchEnginesList.htm", + "searchEngineFileTemplate": "forSearchEngines.htm", + "searchEngineRobotsDb": "WEB-INF/robots.db", + "useDataStore": true, + "dataStoreClass": "org.cofax.SqlDataStore", + "redirectionClass": "org.cofax.SqlRedirection", + "dataStoreName": "cofax", + "dataStoreDriver": "com.microsoft.jdbc.sqlserver.SQLServerDriver", + "dataStoreUrl": "jdbc:microsoft:sqlserver://LOCALHOST:1433;DatabaseName=goon", + "dataStoreUser": "sa", + "dataStorePassword": "dataStoreTestQuery", + "dataStoreTestQuery": "SET NOCOUNT ON;select test='test';", + 
"dataStoreLogFile": "/usr/local/tomcat/logs/datastore.log", + "dataStoreInitConns": 10, + "dataStoreMaxConns": 100, + "dataStoreConnUsageLimit": 100, + "dataStoreLogLevel": "debug", + "maxUrlLength": 500 + } + }, + { + "servlet-name": "cofaxEmail", + "servlet-class": "org.cofax.cds.EmailServlet", + "init-param": { + "mailHost": "mail1", + "mailHostOverride": "mail2" + } + }, + { + "servlet-name": "cofaxAdmin", + "servlet-class": "org.cofax.cds.AdminServlet" + }, + { + "servlet-name": "fileServlet", + "servlet-class": "org.cofax.cds.FileServlet" + }, + { + "servlet-name": "cofaxTools", + "servlet-class": "org.cofax.cms.CofaxToolsServlet", + "init-param": { + "templatePath": "toolstemplates/", + "log": 1, + "logLocation": "/usr/local/tomcat/logs/CofaxTools.log", + "logMaxSize": "", + "dataLog": 1, + "dataLogLocation": "/usr/local/tomcat/logs/dataLog.log", + "dataLogMaxSize": "", + "removePageCache": "/content/admin/remove?cache=pages&id=", + "removeTemplateCache": "/content/admin/remove?cache=templates&id=", + "fileTransferFolder": "/usr/local/tomcat/webapps/content/fileTransferFolder", + "lookInContext": 1, + "adminGroupID": 4, + "betaServer": true + } + } + ], + "servlet-mapping": { + "cofaxCDS": "/", + "cofaxEmail": "/cofaxutil/aemail/*", + "cofaxAdmin": "/admin/*", + "fileServlet": "/static/*", + "cofaxTools": "/tools/*" + }, + "taglib": { + "taglib-uri": "cofax.tld", + "taglib-location": "/WEB-INF/tlds/cofax.tld" + } + } +} """, + ), +) + + +def test_json_pressure(json_grammar: BNFGrammar, json_input_pressure): + assert GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_pressure) + + +(input_find_rejected_tokens, expected_rejected_sizes) = tvm.testing.parameters( + ( + # short test + '{"id": 1,"name": "Example"} ', + [ + # fmt: off + 31989, 31907, 278, 278, 278, 31973, 31841, 31841, 31948, 31910, 278, 278, 278, 278, + 278, 31973, 31841, 31841, 271, 271, 271, 271, 271, 271, 271, 271, 31974, 31980, 31980 + # fmt: on + ], + ), + ( + # long test + """{ +"id": 1, +"na": "ex", +"ac": True, +"t": ["t1", "t2"], +"ne": {"lv2": {"val": "dp"}, "arr": [1, 2, 3]}, +"res": "res" +} +""", + [ + # fmt: off + 31989, 31907, 31907, 278, 278, 278, 31973, 31841, 31841, 31948, 31910, 31910, 278, 278, + 278, 31973, 31841, 31841, 271, 271, 271, 31974, 31910, 31910, 278, 278, 278, 31973, + 31841, 31841, 31841, 31841, 31841, 31841, 31841, 31841, 271, 271, 31974, 31974, 31974, + 31974, 31974, 31974, 31974, 31974, 31910, 31910, 278, 278, 278, 31973, 31973, 31973, + 31973, 31973, 31973, 31973, 31973, 31841, 31841, 31903, 278, 278, 278, 278, 31973, + 31841, 31841, 31901, 278, 278, 278, 278, 31973, 31841, 31841, 270, 270, 270, 31968, + 31970, 31910, 31910, 278, 278, 278, 278, 31973, 31841, 31841, 31835, 31943, 31841, + 31841, 31943, 31841, 31841, 31943, 31970, 31974, 31910, 31910, 278, 278, 278, 278, + 31973, 31841, 31841, 271, 271, 271, 271, 31974, 31974, 31980, 31980 + # fmt: on + ], + ), +) + + +def test_find_rejected_tokens( + json_grammar: BNFGrammar, input_find_rejected_tokens: str, expected_rejected_sizes: List[int] +): + tokenizer_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + tokenizer = Tokenizer(tokenizer_path) + grammar_state_matcher = GrammarStateMatcher(json_grammar, tokenizer) + + real_sizes = [] + for c in input_find_rejected_tokens: + rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens() + real_sizes.append(len(rejected_token_ids)) + print("Accepting char:", c) + grammar_state_matcher.debug_accept_char(ord(c)) + rejected_token_ids = 
grammar_state_matcher.find_next_rejected_tokens() + real_sizes.append(len(rejected_token_ids)) + assert real_sizes == expected_rejected_sizes + + +def test_accept_token(json_grammar: BNFGrammar): + token_table = [ + # fmt: off + "", "", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', + # fmt: on + ] + input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}", "\n"] + input_ids = [token_table.index(t) for t in input_splitted] + + grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table) + + result = [] + + expected = [ + ["{"], + ['"', "}", "\n", " ", '"a":true'], + ["a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " "], + ["a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " "], + [":", "\n", " ", ':"'], + ['"', "{", "6", "\n", " "], + ["}", ", ", "6", "\n", " "], + [" ", "\n", '"', '"a":true'], + [" ", "\n", '"', '"a":true'], + ["}", ", ", "\n", " "], + ["", "\n", " "], + ["", "\n", " "], + ] + + for id in input_ids: + rejected = grammar_state_matcher.find_next_rejected_tokens() + accepted = list(set(range(len(token_table))) - set(rejected)) + accepted_tokens = [token_table[i] for i in accepted] + result.append(accepted_tokens) + assert id in accepted + grammar_state_matcher.accept_token(id) + + rejected = grammar_state_matcher.find_next_rejected_tokens() + accepted = list(set(range(len(token_table))) - set(rejected)) + accepted_tokens = [token_table[i] for i in accepted] + result.append(accepted_tokens) + + assert result == expected + + +def test_rollback(json_grammar: BNFGrammar): + token_table = [ + # fmt: off + "", "", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', + # fmt: on + ] + input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', " ", "}", "\n"] + input_ids = [token_table.index(t) for t in input_splitted] + + grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table, 5) + + assert grammar_state_matcher.max_rollback_steps() == 5 + + input_ids_splitted = [input_ids[i : i + 2] for i in range(0, len(input_ids), 2)] + + for i_1, i_2 in input_ids_splitted: + orig_result = [] + orig_result.append(grammar_state_matcher.find_next_rejected_tokens()) + grammar_state_matcher.accept_token(i_1) + orig_result.append(grammar_state_matcher.find_next_rejected_tokens()) + grammar_state_matcher.accept_token(i_2) + grammar_state_matcher.rollback(2) + result_after_rollback = [] + result_after_rollback.append(grammar_state_matcher.find_next_rejected_tokens()) + grammar_state_matcher.accept_token(i_1) + result_after_rollback.append(grammar_state_matcher.find_next_rejected_tokens()) + grammar_state_matcher.accept_token(i_2) + assert orig_result == result_after_rollback + + +def test_reset(json_grammar: BNFGrammar): + token_table = [ + # fmt: off + "", "", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', + # fmt: on + ] + input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', " ", "}", "\n"] + input_ids = [token_table.index(t) for t in input_splitted] + + grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table) + + orig_result = [] + + for i in input_ids: + orig_result.append(grammar_state_matcher.find_next_rejected_tokens()) + grammar_state_matcher.accept_token(i) + + grammar_state_matcher.reset_state() + + result_after_reset = [] + + for i in input_ids: + result_after_reset.append(grammar_state_matcher.find_next_rejected_tokens()) + grammar_state_matcher.accept_token(i) + + assert orig_result == 
result_after_reset + + +if __name__ == "__main__": + # Run a benchmark to show the performance before running tests + test_find_rejected_tokens( + BNFGrammar.get_grammar_of_json(), + '{"id": 1,"name": "Example"} ', + [ + # fmt: off + 31989, 31907, 278, 278, 278, 31973, 31841, 31841, 31948, 31910, 278, 278, 278, 278, + 278, 31973, 31841, 31841, 271, 271, 271, 271, 271, 271, 271, 271, 31974, 31980, 31980 + # fmt: on + ], + ) + + tvm.testing.main() diff --git a/tests/python/serve/test_serve_async_engine.py b/tests/python/serve/test_serve_async_engine.py new file mode 100644 index 0000000..df8e64b --- /dev/null +++ b/tests/python/serve/test_serve_async_engine.py @@ -0,0 +1,72 @@ +# pylint: disable=chained-comparison,line-too-long,missing-docstring, +# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable +import asyncio +from typing import List + +from mlc_chat.serve import AsyncThreadedEngine, GenerationConfig, KVCacheConfig +from mlc_chat.serve.engine import ModelInfo + +prompts = [ + "What is the meaning of life?", + "Introduce the history of Pittsburgh to me. Please elaborate in detail.", + "Write a three-day Seattle travel plan. Please elaborate in detail.", + "What is Alaska famous of? Please elaborate in detail.", + "What is the difference between Lambda calculus and Turing machine? Please elaborate in detail.", + "What are the necessary components to assemble a desktop computer? Please elaborate in detail.", + "Why is Vitamin D important to human beings? Please elaborate in detail.", + "Where is milk tea originated from? Please elaborate in detail.", + "Where is the southernmost place in United States? Please elaborate in detail.", + "Do you know AlphaGo? What capabilities does it have, and what achievements has it got? Please elaborate in detail.", +] + + +async def test_engine_generate(): + # Initialize model loading info and KV cache config + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16) + # Create engine + async_engine = AsyncThreadedEngine(model, kv_cache_config) + + num_requests = 10 + max_tokens = 256 + generation_cfg = GenerationConfig(max_tokens=max_tokens) + + outputs: List[str] = ["" for _ in range(num_requests)] + + async def generate_task( + async_engine: AsyncThreadedEngine, + prompt: str, + generation_cfg: GenerationConfig, + request_id: str, + ): + print(f"generate task for request {request_id}") + rid = int(request_id) + async for delta_text, _, _, _ in async_engine.generate( + prompt, generation_cfg, request_id=request_id + ): + outputs[rid] += delta_text + + tasks = [ + asyncio.create_task( + generate_task(async_engine, prompts[i], generation_cfg, request_id=str(i)) + ) + for i in range(num_requests) + ] + + await asyncio.gather(*tasks) + + # Print output. 
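+    # Every generate_task coroutine streamed its delta texts into `outputs[rid]`,
+    # so after `asyncio.gather` returns, each entry holds the full generation for
+    # the corresponding prompt.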
+ print("All finished") + for req_id, output in enumerate(outputs): + print(f"Prompt {req_id}: {prompts[req_id]}") + print(f"Output {req_id}:{output}\n") + + async_engine.terminate() + del async_engine + + +if __name__ == "__main__": + asyncio.run(test_engine_generate()) diff --git a/tests/python/serve/test_serve_async_engine_spec.py b/tests/python/serve/test_serve_async_engine_spec.py new file mode 100644 index 0000000..89a113d --- /dev/null +++ b/tests/python/serve/test_serve_async_engine_spec.py @@ -0,0 +1,82 @@ +# pylint: disable=chained-comparison,line-too-long,missing-docstring, +# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable +import asyncio +from typing import List + +from mlc_chat.serve import ( + AsyncThreadedEngine, + EngineMode, + GenerationConfig, + KVCacheConfig, +) +from mlc_chat.serve.engine import ModelInfo + +prompts = [ + "What is the meaning of life?", + "Introduce the history of Pittsburgh to me. Please elaborate in detail.", + "Write a three-day Seattle travel plan. Please elaborate in detail.", + "What is Alaska famous of? Please elaborate in detail.", + "What is the difference between Lambda calculus and Turing machine? Please elaborate in detail.", + "What are the necessary components to assemble a desktop computer? Please elaborate in detail.", + "Why is Vitamin D important to human beings? Please elaborate in detail.", + "Where is milk tea originated from? Please elaborate in detail.", + "Where is the southernmost place in United States? Please elaborate in detail.", + "Do you know AlphaGo? What capabilities does it have, and what achievements has it got? Please elaborate in detail.", +] + + +async def test_engine_generate(): + # Initialize model loading info and KV cache config + ssm = ModelInfo( + "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so", + ) + llm = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16) + engine_mode = EngineMode(enable_speculative=True) + # Create engine + async_engine = AsyncThreadedEngine([llm, ssm], kv_cache_config, engine_mode) + + num_requests = 10 + max_tokens = 256 + generation_cfg = GenerationConfig(max_tokens=max_tokens) + + outputs: List[str] = ["" for _ in range(num_requests)] + + async def generate_task( + async_engine: AsyncThreadedEngine, + prompt: str, + generation_cfg: GenerationConfig, + request_id: str, + ): + print(f"generate task for request {request_id}") + rid = int(request_id) + async for delta_text, _, _, _ in async_engine.generate( + prompt, generation_cfg, request_id=request_id + ): + outputs[rid] += delta_text + + tasks = [ + asyncio.create_task( + generate_task(async_engine, prompts[i], generation_cfg, request_id=str(i)) + ) + for i in range(num_requests) + ] + + await asyncio.gather(*tasks) + + # Print output. 
+ print("All finished") + for req_id, output in enumerate(outputs): + print(f"Prompt {req_id}: {prompts[req_id]}") + print(f"Output {req_id}:{output}\n") + + async_engine.terminate() + del async_engine + + +if __name__ == "__main__": + asyncio.run(test_engine_generate()) diff --git a/tests/python/serve/test_serve_engine.py b/tests/python/serve/test_serve_engine.py new file mode 100644 index 0000000..373a97a --- /dev/null +++ b/tests/python/serve/test_serve_engine.py @@ -0,0 +1,392 @@ +# pylint: disable=chained-comparison,line-too-long,missing-docstring, +# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable +from typing import Callable, List, Optional + +import numpy as np + +from mlc_chat.serve import ( + Engine, + GenerationConfig, + KVCacheConfig, + Request, + RequestStreamOutput, + data, +) +from mlc_chat.serve.engine import ModelInfo + +prompts = [ + "What is the meaning of life?", + "Introduce the history of Pittsburgh to me. Please elaborate in detail.", + "Write a three-day Seattle travel plan. Please elaborate in detail.", + "What is Alaska famous of? Please elaborate in detail.", + "What is the difference between Lambda calculus and Turing machine? Please elaborate in detail.", + "What are the necessary components to assemble a desktop computer? Please elaborate in detail.", + "Why is Vitamin D important to human beings? Please elaborate in detail.", + "Where is milk tea originated from? Please elaborate in detail.", + "Where is the southernmost place in United States? Please elaborate in detail.", + "Do you know AlphaGo? What capabilities does it have, and what achievements has it got? Please elaborate in detail.", +] + + +def create_requests( + num_requests: int, + stop_token_id: Optional[int] = None, + temperature: float = 0.8, + repetition_penalty: float = 1.0, + max_tokens_low: int = 256, + max_tokens_high: int = 257, +) -> List[Request]: + assert num_requests >= 0 and num_requests <= len(prompts) + + stop_token_ids = [stop_token_id] if stop_token_id is not None else [] + requests = [] + for req_id, prompt in zip(range(num_requests), prompts): + max_tokens = np.random.randint(max_tokens_low, max_tokens_high) + requests.append( + Request( + request_id=str(req_id), + inputs=data.TextData(prompt), + generation_config=GenerationConfig( + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens=max_tokens, + stop_token_ids=stop_token_ids, + ), + ) + ) + return requests + + +def test_engine_basic(): + """Test engine **without continuous batching**. + + - Add all requests to the engine altogether in the beginning. + - All requests have the same max_tokens. This means all requests + will end together. + - Engine keeps running `step` for estimated number of steps (number of + requests + max_tokens - 1). Then check the output of each request. + """ + + # Initialize model loading info and KV cache config + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16) + + # Hyperparameters for tests (you can try different combinations). 
+ num_requests = 10 # [4, 8, 10] + temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] + repetition_penalty = 1.0 # [1.0, 1.01] + max_tokens: int = 256 # [32, 128, 256] + np.random.seed(0) + + # Output list + outputs = [[] for _ in range(num_requests)] + + # Define the callback function for request generation results + def fcallback(delta_outputs: List[RequestStreamOutput]): + for delta_output in delta_outputs: + request_id, delta_token_ids, _, _ = delta_output.unpack() + outputs[int(request_id)] += delta_token_ids + + # Create engine + engine = Engine(model, kv_cache_config, request_stream_callback=fcallback) + + # Create requests + requests = create_requests( + num_requests, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens_low=max_tokens, + max_tokens_high=max_tokens + 1, + ) + + # Add all requests to engine + for request in requests: + engine.add_request(request) + + num_steps = num_requests + max_tokens - 1 + # Run steps + for step in range(num_steps): + engine.step() + + for req_id, output in enumerate(outputs): + print(f"Prompt {req_id}: {requests[req_id].inputs[0]}") + print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") + + +def test_engine_continuous_batching_1(): + """Test engine **with continuous batching**. + + - Add all requests to the engine altogether in the beginning. + - All requests have a random maximum generation length. So each + request keeps generating until reaching the maximum length. + - Engine keeps running `step` for estimated number of steps (number of + requests + the maximum max_tokens - 1). Then check the output + of each request. + """ + + # Initialize model loading info and KV cache config + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16) + + # Hyperparameters for tests (you can try different combinations) + num_requests = 10 # [4, 8, 10] + temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] + repetition_penalty = 1.00 # [1.0, 1.01] + max_tokens_low = 128 + max_tokens_high = 384 + np.random.seed(0) + + # Output list + outputs = [[] for _ in range(num_requests)] + finish_time = [None] * num_requests + + # Define the callback class for request generation results + class CallbackTimer: + timer: int = -1 + + def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: + def fcallback(delta_outputs: List[RequestStreamOutput]): + for delta_output in delta_outputs: + request_id, delta_token_ids, _, finish_reason = delta_output.unpack() + if finish_reason is not None: + print(f"Request {request_id} finished at step {self.timer}.") + outputs[int(request_id)] += delta_token_ids + finish_time[int(request_id)] = self.timer + + return fcallback + + def step(self) -> None: + self.timer += 1 + + # Create engine + timer = CallbackTimer() + engine = Engine(model, kv_cache_config, request_stream_callback=timer.callback_getter()) + + # Create requests + requests = create_requests( + num_requests, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens_low=max_tokens_low, + max_tokens_high=max_tokens_high, + ) + + # Add all requests to engine + for request in requests: + engine.add_request(request) + + num_steps = num_requests + max(request.generation_config.max_tokens for request in requests) - 1 + # Run steps + for step in range(num_steps): + timer.step() + assert timer.timer == step + engine.step() + + for req_id, (request, output, fin_time) in 
enumerate(zip(requests, outputs, finish_time)): + print(f"Prompt {req_id}: {request.inputs[0]}") + print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") + assert fin_time == request.generation_config.max_tokens - 1 + + +def test_engine_continuous_batching_2(): + """Test engine **with continuous batching**. + + - Add all requests to the engine altogether in the beginning. + - All requests have the stop token. So each request keeps generating + until having the stop token or reaching the maximum length. + - Engine keeps running `step` for estimated number of steps (number of + requests + the maximum max_tokens - 1). Then check the output + of each request. + """ + + # Initialize model loading info and KV cache config + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16) + + # Hyperparameters for tests (you can try different combinations) + num_requests = 10 # [4, 8, 10] + temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] + repetition_penalty = 1.00 # [1.0, 1.01] + stop_token_id = 2 + max_tokens = 512 + np.random.seed(0) + + # Output list + outputs = [[] for _ in range(num_requests)] + finish_time = [None] * num_requests + + # Define the callback class for request generation results + class CallbackTimer: + timer: int = -1 + + def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: + def fcallback(delta_outputs: List[RequestStreamOutput]): + for delta_output in delta_outputs: + request_id, delta_token_ids, _, finish_reason = delta_output.unpack() + if finish_reason is not None: + print(f"Request {request_id} finished at step {self.timer}.") + outputs[int(request_id)] += delta_token_ids + finish_time[int(request_id)] = self.timer + + return fcallback + + def step(self) -> None: + self.timer += 1 + + # Create engine + timer = CallbackTimer() + engine = Engine(model, kv_cache_config, request_stream_callback=timer.callback_getter()) + + # Create requests + requests = create_requests( + num_requests, + stop_token_id=stop_token_id, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens_low=max_tokens, + max_tokens_high=max_tokens + 1, + ) + + # Add all requests to engine + for request in requests: + engine.add_request(request) + + num_steps = num_requests + max_tokens - 1 + # Run steps + for step in range(num_steps): + timer.step() + assert timer.timer == step + engine.step() + + for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)): + print(f"Prompt {req_id}: {request.inputs[0]}") + if fin_time < num_requests + max_tokens - 2: + print(f"Request {req_id} ends early on the stop token") + print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") + + +def test_engine_continuous_batching_3(): + """Test engine **with continuous batching**. + + - Add requests randomly between time [0, 200). + - All requests have a random maximum generation length. So each + request keeps generating until reaching the maximum length. + - Engine keeps running `step` until all requests finish. + Then check the output of each request. 
+ """ + + # Initialize model loading info and KV cache config + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16) + + # Hyperparameters for tests (you can try different combinations) + num_requests = 10 # [4, 8, 10] + temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] + repetition_penalty = 1.00 # [1.0, 1.01] + stop_token_id = 2 + max_tokens_low = 64 + max_tokens_high = 192 + np.random.seed(0) + + # Output list + outputs = [[] for _ in range(num_requests)] + finish_time = [None] * num_requests + + # Define the callback class for request generation results + class CallbackTimer: + timer: int = -1 + finished_requests: int = 0 + + def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: + def fcallback(delta_outputs: List[RequestStreamOutput]): + for delta_output in delta_outputs: + request_id, delta_token_ids, _, finish_reason = delta_output.unpack() + if finish_reason is not None: + print(f"Request {request_id} finished at step {self.timer}.") + self.finished_requests += 1 + outputs[int(request_id)] += delta_token_ids + finish_time[int(request_id)] = self.timer + + return fcallback + + def step(self) -> None: + self.timer += 1 + + def all_finished(self) -> bool: + return self.finished_requests == num_requests + + # Create engine + timer = CallbackTimer() + engine = Engine(model, kv_cache_config, request_stream_callback=timer.callback_getter()) + + # Create requests + requests = create_requests( + num_requests, + stop_token_id=stop_token_id, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens_low=max_tokens_low, + max_tokens_high=max_tokens_high, + ) + + # Assign the time to add requests to engine + request_add_time = [np.random.randint(0, 200) for _ in range(num_requests)] + + # Run steps + while not timer.all_finished(): + timer.step() + + # Add requests to engine + for req_id, add_time in enumerate(request_add_time): + if add_time == timer.timer: + print(f"add request {req_id} at step {timer.timer}") + engine.add_request(requests[req_id]) + + engine.step() + + for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)): + print(f"Prompt {req_id}: {request.inputs[0]}") + print(f"Finish time: {fin_time}") + print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") + + +def test_engine_generate(): + # Initialize model loading info and KV cache config + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16) + # Create engine + engine = Engine(model, kv_cache_config) + + num_requests = 10 + max_tokens = 256 + + # Generate output. 
+ output_texts, _ = engine.generate( + prompts[:num_requests], GenerationConfig(max_tokens=max_tokens) + ) + for req_id, output in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + print(f"Output {req_id}:{output}\n") + + +if __name__ == "__main__": + test_engine_basic() + test_engine_continuous_batching_1() + test_engine_continuous_batching_2() + test_engine_continuous_batching_3() + test_engine_generate() diff --git a/tests/python/serve/test_serve_engine_spec.py b/tests/python/serve/test_serve_engine_spec.py new file mode 100644 index 0000000..1eee361 --- /dev/null +++ b/tests/python/serve/test_serve_engine_spec.py @@ -0,0 +1,372 @@ +# pylint: disable=chained-comparison,line-too-long,missing-docstring, +# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable +from typing import Callable, List, Optional + +import numpy as np + +from mlc_chat.serve import ( + Engine, + EngineMode, + GenerationConfig, + KVCacheConfig, + Request, + RequestStreamOutput, + data, +) +from mlc_chat.serve.engine import ModelInfo + +prompts = [ + "What is the meaning of life?", + "Introduce the history of Pittsburgh to me. Please elaborate in detail.", + "Write a three-day Seattle travel plan. Please elaborate in detail.", + "What is Alaska famous of? Please elaborate in detail.", + "What is the difference between Lambda calculus and Turing machine? Please elaborate in detail.", + "What are the necessary components to assemble a desktop computer? Please elaborate in detail.", + "Why is Vitamin D important to human beings? Please elaborate in detail.", + "Where is milk tea originated from? Please elaborate in detail.", + "Where is the southernmost place in United States? Please elaborate in detail.", + "Do you know AlphaGo? What capabilities does it have, and what achievements has it got? Please elaborate in detail.", +] + + +def create_requests( + num_requests: int, + stop_token_id: Optional[int] = None, + temperature: float = 0.8, + repetition_penalty: float = 1.0, + max_tokens_low: int = 256, + max_tokens_high: int = 257, +) -> List[Request]: + assert num_requests >= 0 and num_requests <= len(prompts) + + stop_token_ids = [stop_token_id] if stop_token_id is not None else [] + requests = [] + for req_id, prompt in zip(range(num_requests), prompts): + max_tokens = np.random.randint(max_tokens_low, max_tokens_high) + requests.append( + Request( + request_id=str(req_id), + inputs=data.TextData(prompt), + generation_config=GenerationConfig( + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens=max_tokens, + stop_token_ids=stop_token_ids, + ), + ) + ) + return requests + + +def test_engine_basic(): + """Test engine **without continuous batching**. + + - Add all requests to the engine altogether in the beginning. + - All requests have the same max_tokens. This means all requests + will end together. + - Engine keeps running `step` for estimated number of steps (number of + requests + max_tokens - 1). Then check the output of each request. 
+ """ + + # Initialize model loading info and KV cache config + ssm = ModelInfo( + "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so", + ) + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16) + engine_mode = EngineMode(enable_speculative=True) + + # Hyperparameters for tests (you can try different combinations). + num_requests = len(prompts) # [4, 8, 10] + temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] + repetition_penalty = 1.0 # [1.0, 1.01] + max_tokens: int = 256 # [32, 128, 256] + np.random.seed(0) + + # Output list + outputs = [[] for _ in range(num_requests)] + + # Define the callback function for request generation results + def fcallback(delta_outputs: List[RequestStreamOutput]): + for delta_output in delta_outputs: + request_id, delta_token_ids, _, _ = delta_output.unpack() + outputs[int(request_id)] += delta_token_ids + + # Create engine + engine = Engine([model, ssm], kv_cache_config, engine_mode, fcallback) + + # Create requests + requests = create_requests( + num_requests, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens_low=max_tokens, + max_tokens_high=max_tokens + 1, + ) + + # Add all requests to engine + for request in requests: + engine.add_request(request) + + num_steps = num_requests + max_tokens - 1 + # Run steps + for step in range(num_steps): + engine.step() + + for req_id, output in enumerate(outputs): + print(f"Prompt {req_id}: {requests[req_id].inputs[0]}") + print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") + + +def test_engine_continuous_batching_1(): + """Test engine **with continuous batching**. + + - Add all requests to the engine altogether in the beginning. + - All requests have a random maximum generation length. So each + request keeps generating until reaching the maximum length. + - Engine keeps running `step` for estimated number of steps (number of + requests + the maximum max_tokens - 1). Then check the output + of each request. 
+ """ + + # Initialize model loading info and KV cache config + ssm = ModelInfo( + "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so", + ) + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16) + engine_mode = EngineMode(enable_speculative=True) + + # Hyperparameters for tests (you can try different combinations) + num_requests = len(prompts) # [4, 8, 10] + temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] + repetition_penalty = 1.00 # [1.0, 1.01] + max_tokens_low = 128 + max_tokens_high = 384 + np.random.seed(0) + + # Output list + outputs = [[] for _ in range(num_requests)] + finish_time = [None] * num_requests + + # Define the callback class for request generation results + class CallbackTimer: + timer: int = -1 + + def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: + def fcallback(delta_outputs: List[RequestStreamOutput]): + for delta_output in delta_outputs: + request_id, delta_token_ids, _, finish_reason = delta_output.unpack() + if finish_reason is not None: + print(f"Request {request_id} finished at step {self.timer}.") + outputs[int(request_id)] += delta_token_ids + finish_time[int(request_id)] = self.timer + + return fcallback + + def step(self) -> None: + self.timer += 1 + + # Create engine + timer = CallbackTimer() + engine = Engine([model, ssm], kv_cache_config, engine_mode, timer.callback_getter()) + + # Create requests + requests = create_requests( + num_requests, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens_low=max_tokens_low, + max_tokens_high=max_tokens_high, + ) + + # Add all requests to engine + for request in requests: + engine.add_request(request) + + num_steps = num_requests + max(request.generation_config.max_tokens for request in requests) - 1 + # Run steps + for step in range(num_steps): + timer.step() + assert timer.timer == step + engine.step() + + for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)): + print(f"Prompt {req_id}: {request.inputs[0]}") + print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") + # assert fin_time == request.generation_config.max_tokens - 1 + + +def test_engine_generate(): + # Initialize model loading info and KV cache config + ssm = ModelInfo( + "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so", + ) + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16) + engine_mode = EngineMode(enable_speculative=True) + # Create engine + engine = Engine([model, ssm], kv_cache_config, engine_mode) + + num_requests = 10 + max_tokens = 256 + + # Generate output. 
+    output_texts, _ = engine.generate(
+        prompts[:num_requests], GenerationConfig(max_tokens=max_tokens)
+    )
+    for req_id, output in enumerate(output_texts):
+        print(f"Prompt {req_id}: {prompts[req_id]}")
+        print(f"Output {req_id}:{output}\n")
+
+
+def test_engine_efficiency():
+    """Test engine normal decoding efficiency (baseline without speculative decoding)."""
+
+    # Initialize model loading info and KV cache config
+    model = ModelInfo(
+        "dist/Llama-2-13b-chat-hf-q4f16_1-MLC",
+        model_lib_path="dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so",
+    )
+    kv_cache_config = KVCacheConfig(page_size=16)
+
+    # Hyperparameters for tests (you can try different combinations).
+    num_requests = 1  # [4, 8, 10]
+    temperature = 0.9  # [0, 0.8, 0.9, 1.0, 1.1]
+    repetition_penalty = 1.0  # [1.0, 1.01]
+    max_tokens: int = 512
+    np.random.seed(0)
+
+    # Output list
+    outputs = [[] for _ in range(num_requests)]
+
+    # Define the callback function for request generation results
+    def fcallback(delta_outputs: List[RequestStreamOutput]):
+        for delta_output in delta_outputs:
+            request_id, delta_token_ids, _, _ = delta_output.unpack()
+            outputs[int(request_id)] += delta_token_ids
+
+    # Create engine
+    engine = Engine(model, kv_cache_config, request_stream_callback=fcallback)
+
+    # Create requests
+    requests = create_requests(
+        num_requests,
+        temperature=temperature,
+        repetition_penalty=repetition_penalty,
+        max_tokens_low=max_tokens,
+        max_tokens_high=max_tokens + 1,
+    )
+
+    # Add all requests to engine
+    for request in requests:
+        engine.add_request(request)
+
+    num_steps = num_requests + max_tokens - 1
+    # Run steps
+    for step in range(num_steps):
+        engine.step()
+
+    for eg, name in zip([engine], ["Normal Decoding"]):
+        stats = eg.stats()
+        print("engine name:", name)
+        if name == "Speculative Decoding":
+            print("total draft tokens:", stats["total_draft_tokens"])
+            print("total accepted tokens:", stats["total_accepted_tokens"])
+            print(
+                "Accept rate:",
+                stats["total_accepted_tokens"] / (1e-10 + stats["total_draft_tokens"]),
+            )
+        print("engine total decode time:", stats["engine_total_decode_time"])
+        print()
+
+
+def test_engine_spec_efficiency():
+    """Test engine speculative decoding efficiency."""
+
+    # Initialize model loading info and KV cache config
+    ssm = ModelInfo(
+        "dist/Llama-2-7b-chat-hf-q4f16_1-MLC",
+        model_lib_path="dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so",
+    )
+    # If Flashinfer allows head_dim < 128, we can test this model
+    # ssm = ModelInfo(
+    #     "dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC",
+    #     model_lib_path="dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC-cuda.so",
+    # )
+    model = ModelInfo(
+        "dist/Llama-2-13b-chat-hf-q4f16_1-MLC",
+        model_lib_path="dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so",
+    )
+    kv_cache_config = KVCacheConfig(page_size=16)
+    engine_mode = EngineMode(enable_speculative=True, spec_draft_length=6)
+
+    # Hyperparameters for tests (you can try different combinations).
+ num_requests = 1 # [4, 8, 10] + temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] + repetition_penalty = 1.0 # [1.0, 1.01] + max_tokens: int = 512 + np.random.seed(0) + + # Output list + outputs = [[] for _ in range(num_requests)] + + # Define the callback function for request generation results + def fcallback(delta_outputs: List[RequestStreamOutput]): + for delta_output in delta_outputs: + request_id, delta_token_ids, _, _ = delta_output.unpack() + outputs[int(request_id)] += delta_token_ids + + # Create engine + spec_engine = Engine([model, ssm], kv_cache_config, engine_mode, fcallback) + + # Create requests + requests = create_requests( + num_requests, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens_low=max_tokens, + max_tokens_high=max_tokens + 1, + ) + + # Add all requests to engine + for request in requests: + spec_engine.add_request(request) + + num_steps = num_requests + max_tokens - 1 + # Run steps + for step in range(num_steps): + spec_engine.step() + + for eg, name in zip([spec_engine], ["Speculative Decoding"]): + stats = eg.stats() + print("engine name:", name) + if name == "Speculative Decoding": + print("total draft tokens:", stats["total_draft_tokens"]) + print("total accepted tokens:", stats["total_accepted_tokens"]) + print( + "Accept rate:", + stats["total_accepted_tokens"] / (1e-10 + stats["total_draft_tokens"]), + ) + print("engine total decode time:", stats["engine_total_decode_time"]) + print() + + +if __name__ == "__main__": + test_engine_basic() + test_engine_continuous_batching_1() + test_engine_generate() + test_engine_efficiency() + test_engine_spec_efficiency() diff --git a/tests/python/support/test_auto_config.py b/tests/python/support/test_auto_config.py new file mode 100644 index 0000000..77c6a0d --- /dev/null +++ b/tests/python/support/test_auto_config.py @@ -0,0 +1,41 @@ +# pylint: disable=missing-docstring +import json +import tempfile +from pathlib import Path + +import pytest + +from mlc_chat.support import logging +from mlc_chat.support.auto_config import detect_config + +logging.enable_logging() + + +def _create_json_file(json_path, data): + with open(json_path, "w", encoding="utf-8") as i_f: + json.dump(data, i_f) + + +def test_detect_config(): + with tempfile.TemporaryDirectory() as tmpdir: + base_path = Path(tmpdir) + config_json_path = base_path / "config.json" + _create_json_file(config_json_path, {}) + + assert detect_config(base_path) == config_json_path + assert detect_config(config_json_path) == config_json_path + + +def test_detect_config_fail(): + with pytest.raises(ValueError): + detect_config(Path("do/not/exist")) + + with tempfile.TemporaryDirectory() as tmpdir: + base_path = Path(tmpdir) + with pytest.raises(ValueError): + assert detect_config(base_path) + + +if __name__ == "__main__": + test_detect_config() + test_detect_config_fail() diff --git a/tests/python/support/test_auto_weight.py b/tests/python/support/test_auto_weight.py new file mode 100644 index 0000000..dfbefff --- /dev/null +++ b/tests/python/support/test_auto_weight.py @@ -0,0 +1,118 @@ +# pylint: disable=missing-docstring +import json +import os +import tempfile +from pathlib import Path + +import pytest + +from mlc_chat.support import logging +from mlc_chat.support.auto_weight import detect_weight + +logging.enable_logging() + + +def _create_json_file(json_path, data): + with open(json_path, "w", encoding="utf-8") as i_f: + json.dump(data, i_f) + + +@pytest.mark.parametrize( + "weight_format, index_filename, result", + [ + 
("huggingface-torch", "pytorch_model.bin.index.json", "huggingface-torch"), + ("huggingface-safetensor", "model.safetensors.index.json", "huggingface-safetensor"), + ("auto", "pytorch_model.bin.index.json", "huggingface-torch"), + ("auto", "model.safetensors.index.json", "huggingface-safetensor"), + ], +) +def test_detect_weight(weight_format, index_filename, result): + with tempfile.TemporaryDirectory() as tmpdir: + base_path = Path(tmpdir) + if index_filename is not None: + weight_index_file = base_path / index_filename + _create_json_file(weight_index_file, {}) + assert detect_weight(base_path, None, weight_format) == (weight_index_file, result) + + +@pytest.mark.parametrize( + "weight_format, index_filename, result", + [ + ("huggingface-torch", "pytorch_model.bin.index.json", "huggingface-torch"), + ("huggingface-safetensor", "model.safetensors.index.json", "huggingface-safetensor"), + ("auto", "pytorch_model.bin.index.json", "huggingface-torch"), + ("auto", "model.safetensors.index.json", "huggingface-safetensor"), + ], +) +def test_detect_weight_in_config_json(weight_format, index_filename, result): + with tempfile.TemporaryDirectory() as config_dir, tempfile.TemporaryDirectory() as weight_dir: + config_path = Path(config_dir) + weight_path = Path(weight_dir) + config_json_path = config_path / "config.json" + _create_json_file(config_json_path, {"weight_path": weight_dir}) + if index_filename is not None: + weight_index_file = weight_path / index_filename + _create_json_file(weight_index_file, {}) + + assert detect_weight(None, config_json_path, weight_format) == (weight_index_file, result) + + +@pytest.mark.parametrize( + "weight_format, index_filename, result", + [ + ("huggingface-torch", "pytorch_model.bin.index.json", "huggingface-torch"), + ("huggingface-safetensor", "model.safetensors.index.json", "huggingface-safetensor"), + ("auto", "pytorch_model.bin.index.json", "huggingface-torch"), + ("auto", "model.safetensors.index.json", "huggingface-safetensor"), + ], +) +def test_detect_weight_same_dir_config_json(weight_format, index_filename, result): + with tempfile.TemporaryDirectory() as tmpdir: + base_path = Path(tmpdir) + config_json_path = base_path / "config.json" + _create_json_file(config_json_path, {}) + if index_filename is not None: + weight_index_file = Path(os.path.join(tmpdir, index_filename)) + _create_json_file(weight_index_file, {}) + assert detect_weight(None, config_json_path, weight_format) == (weight_index_file, result) + + +def test_find_weight_fail(): + with tempfile.TemporaryDirectory() as tmpdir: + base_path = Path(tmpdir) + with pytest.raises(ValueError): + detect_weight(Path("do/not/exist"), base_path, "awq") + with pytest.raises(AssertionError): + detect_weight(None, Path("do/not/exist"), "awq") + + +if __name__ == "__main__": + test_detect_weight("huggingface-torch", "pytorch_model.bin.index.json", "huggingface-torch") + test_detect_weight( + "huggingface-safetensor", "model.safetensors.index.json", "huggingface-safetensor" + ) + test_detect_weight("auto", "pytorch_model.bin.index.json", "huggingface-torch") + test_detect_weight("auto", "model.safetensors.index.json", "huggingface-safetensor") + test_detect_weight_in_config_json( + "huggingface-torch", "pytorch_model.bin.index.json", "huggingface-torch" + ) + test_detect_weight_in_config_json( + "huggingface-safetensor", "model.safetensors.index.json", "huggingface-safetensor" + ) + test_detect_weight_in_config_json("auto", "pytorch_model.bin.index.json", "huggingface-torch") + 
test_detect_weight_in_config_json(
+        "auto", "model.safetensors.index.json", "huggingface-safetensor"
+    )
+    test_detect_weight_same_dir_config_json(
+        "huggingface-torch", "pytorch_model.bin.index.json", "huggingface-torch"
+    )
+    test_detect_weight_same_dir_config_json(
+        "huggingface-safetensor", "model.safetensors.index.json", "huggingface-safetensor"
+    )
+    test_detect_weight_same_dir_config_json(
+        "auto", "pytorch_model.bin.index.json", "huggingface-torch"
+    )
+    test_detect_weight_same_dir_config_json(
+        "auto", "model.safetensors.index.json", "huggingface-safetensor"
+    )
+    test_find_weight_fail()
diff --git a/tests/python/support/test_streamer.py b/tests/python/support/test_streamer.py
new file mode 100644
index 0000000..4f51ea1
--- /dev/null
+++ b/tests/python/support/test_streamer.py
@@ -0,0 +1,198 @@
+"""Streamer tests in MLC LLM.
+
+Please specify the local path to the llama2 tokenizer via environment
+variable before running this test.
+The recommended way to run the tests is to use the following command:
+    MLC_LLAMA_TOKENIZER_PATH="path/to/llama/tokenizer" \
+        pytest -vv tests/python/support/test_streamer.py
+
+Here "MLC_LLAMA_TOKENIZER_PATH" can be chosen from
+- a llama2 weight directory (e.g., "path/to/Llama-2-7b-chat-hf"),
+- a sentencepiece llama2 tokenizer path
+  (e.g., "path/to/Llama-2-7b-chat-hf/tokenizer.model").
+
+To directly run the Python file (i.e., not using pytest), you also need to
+specify the tokenizer path via environment variable.
+"""
+
+# pylint: disable=missing-function-docstring
+import os
+import time
+from typing import List, Tuple
+
+import pytest
+
+from mlc_chat.streamer import StopStrHandler, TextStreamer
+from mlc_chat.tokenizer import Tokenizer
+
+# fmt: off
+para_input_tokens = [18585, 29892, 1244, 29915, 29879, 263, 3273, 14880, 1048, 953, 29877, 2397,
+                     29892, 988, 1269, 1734, 338, 5643, 491, 385, 953, 29877, 2397, 29901, 13, 13,
+                     29950, 1032, 727, 29991, 29871, 243, 162, 148, 142, 306, 29915, 29885, 1244, 304,
+                     1371, 1234, 738, 5155, 366, 505, 1048, 953, 29877, 2397, 29871, 243, 162, 167, 151,
+                     29889, 7440, 366, 1073, 393, 953, 29877, 2397, 508, 367, 1304, 304, 27769, 23023,
+                     1080, 322, 21737, 297, 263, 2090, 322, 1708, 1319, 982, 29973, 29871, 243, 162, 155,
+                     135, 2688, 508, 884, 367, 1304, 304, 788, 263, 6023, 310, 2022, 2877, 304, 596, 7191,
+                     322, 11803, 29889, 29871, 243, 162, 149, 152, 1126, 29892, 1258, 366, 1073, 393, 727,
+                     526, 1584, 953, 29877, 2397, 8090, 322, 14188, 366, 508, 1708, 29973, 29871, 243, 162,
+                     145, 177, 243, 162, 148, 131, 1105, 29892, 748, 14432, 322, 679, 907, 1230, 411, 953,
+                     29877, 2397, 29991, 29871, 243, 162, 149, 168, 243, 162, 145, 171]
+
+DECODED_PARAGRAPH = (
+    "Sure, here's a short paragraph about emoji, "
+    "where each word is followed by an emoji:\n\n"
+    "Hey there! 👋 I'm here to help answer any questions you have about emoji 🤔. "
+    "Did you know that emoji can be used to convey emotions and feelings in a "
+    "fun and playful way? 😄 "
+    "They can also be used to add a touch of personality to your messages and posts. 💕 "
+    "And, did you know that there are even emoji games and activities you can play? 🎮👀 "
+    "So, go ahead and get creative with emoji! 💥🎨"
+)
+# fmt: on
+
+
+def _get_tokenizer_path() -> str:
+    path = os.environ.get("MLC_LLAMA_TOKENIZER_PATH")
+    if path is None:
+        raise ValueError(
+            'Environment variable "MLC_LLAMA_TOKENIZER_PATH" not found. '
+            "Please set it to a valid llama tokenizer path."
+ ) + return path + + +@pytest.fixture +def llama_tokenizer_path() -> str: + return _get_tokenizer_path() + + +def test_text_streamer(llama_tokenizer_path: str): # pylint: disable=redefined-outer-name + text_streamer = TextStreamer(Tokenizer(llama_tokenizer_path)) + total_text = "" + for token in para_input_tokens: + total_text += text_streamer.put([token]) + total_text += text_streamer.finish() + + assert total_text == DECODED_PARAGRAPH + + +def stop_handler_process_tokens( + stop_handler: StopStrHandler, tokens: List[int], tokenizer: Tokenizer +) -> str: + returned_tokens = [] + for token in tokens: + returned_tokens += stop_handler.put(token) + if stop_handler.stop_triggered: + break + + if not stop_handler.stop_triggered: + returned_tokens += stop_handler.finish() + + return tokenizer.decode(returned_tokens) + + +def test_stop_str_handler_stop(llama_tokenizer_path: str): # pylint: disable=redefined-outer-name + stop_strs = [" 🤔"] + tokenizer = Tokenizer(llama_tokenizer_path) + stop_handler = StopStrHandler(stop_strs, tokenizer) + + total_text = stop_handler_process_tokens(stop_handler, para_input_tokens, tokenizer) + expected_text = ( + "Sure, here's a short paragraph about emoji, " + "where each word is followed by an emoji:\n\n" + "Hey there! 👋 I'm here to help answer any questions you have about emoji" + ) + + assert total_text == expected_text + + +def test_stop_str_handler_not_stop( + llama_tokenizer_path: str, # pylint: disable=redefined-outer-name +): + stop_strs = ["^^"] + tokenizer = Tokenizer(llama_tokenizer_path) + stop_handler = StopStrHandler(stop_strs, tokenizer) + + total_text = stop_handler_process_tokens(stop_handler, para_input_tokens, tokenizer) + assert total_text == DECODED_PARAGRAPH + + +def test_stop_str_handler_return_cached_tokens( + llama_tokenizer_path: str, # pylint: disable=redefined-outer-name +): + tokens = para_input_tokens[:26] # until "\n\n" + stop_strs = ["\n\n\n"] + tokenizer = Tokenizer(llama_tokenizer_path) + stop_handler = StopStrHandler(stop_strs, tokenizer) + + total_text = stop_handler_process_tokens(stop_handler, tokens, tokenizer) + expected_text = ( + "Sure, here's a short paragraph about emoji, " + "where each word is followed by an emoji:\n\n" + ) + + assert total_text == expected_text + + +def test_stop_str_handler_throughput( + llama_tokenizer_path: str, # pylint: disable=redefined-outer-name +): + stop_strs = ["[INST]"] + tokenizer = Tokenizer(llama_tokenizer_path) + stop_handler = StopStrHandler(stop_strs, tokenizer) + + tokens = para_input_tokens * 20 + returned_tokens = [] + + tbegin = time.perf_counter() + for token in tokens: + returned_tokens += stop_handler.put(token) + assert not stop_handler.stop_triggered + tend = time.perf_counter() + + throughput = len(tokens) / (tend - tbegin) + print( + f"num tokens = {len(tokens)}, " + f"time elapsed = {tend - tbegin:.5f} sec, " + f"throughput = {throughput}" + ) + assert throughput >= 100000 + + +emoji_tokens_expected_result = [ + # HF: "�����", SentencePiece: "�👀" + ([177, 243, 162, 148, 131], ("�����", "�👀")), + # Both: "👀👀" + ([243, 162, 148, 131, 243, 162, 148, 131], ("👀👀",)), + # Both: "👀👀👀" + ([243, 162, 148, 131, 243, 162, 148, 131, 243, 162, 148, 131], ("👀👀👀",)), + # HF: "👀�������", SentencePiece: "👀���👀" + ([243, 162, 148, 131, 162, 148, 131, 243, 162, 148, 131], ("👀�������", "👀���👀")), + # Both: "👀��� have👀" + ([243, 162, 148, 131, 162, 148, 131, 505, 243, 162, 148, 131], ("👀��� have👀",)), +] + + +@pytest.mark.parametrize("tokens_and_results", emoji_tokens_expected_result) +def 
test_text_streamer_emojis(
+    llama_tokenizer_path: str, tokens_and_results: Tuple[List[int], Tuple[str]]
+):  # pylint: disable=redefined-outer-name
+    text_streamer = TextStreamer(Tokenizer(llama_tokenizer_path))
+    total_text = ""
+    tokens, expected_results = tokens_and_results
+    for token in tokens:
+        total_text += text_streamer.put([token])
+    total_text += text_streamer.finish()
+    assert total_text in expected_results
+
+
+if __name__ == "__main__":
+    tokenizer_path = _get_tokenizer_path()
+    test_text_streamer(tokenizer_path)
+    test_stop_str_handler_stop(tokenizer_path)
+    test_stop_str_handler_not_stop(tokenizer_path)
+    test_stop_str_handler_return_cached_tokens(tokenizer_path)
+    test_stop_str_handler_throughput(tokenizer_path)
+
+    for tokens_and_res in emoji_tokens_expected_result:
+        test_text_streamer_emojis(tokenizer_path, tokens_and_res)
diff --git a/version.py b/version.py
new file mode 100644
index 0000000..c7868f8
--- /dev/null
+++ b/version.py
@@ -0,0 +1,144 @@
+# pylint: disable=missing-docstring
+import argparse
+import logging
+import os
+import subprocess
+
+# Modify the following value during release
+# ---------------------------------------------------
+# Current version:
+# We use the version of the incoming release for code
+# that is under development.
+#
+# It is also the fallback version to be used when --git-describe
+# is not invoked, or when the repository does not present the
+# git tags in a format that this script can use.
+#
+# Two tag formats are supported:
+# - vMAJ.MIN.PATCH (e.g. v0.8.0) or
+# - vMAJ.MIN.devN (e.g. v0.8.dev0)
+
+# ---------------------------------------------------
+
+__version__ = "0.1.dev0"
+PROJ_ROOT = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
+
+
+def py_str(cstr):
+    return cstr.decode("utf-8")
+
+
+def git_describe_version():
+    """Get PEP-440 compatible public and local version using git describe.
+
+    Returns
+    -------
+    pub_ver: str
+        Public version.
+
+    local_ver: str
+        Local version (with additional label appended to pub_ver).
+
+    Notes
+    -----
+    - We follow PEP 440's convention of public version
+      and local versions.
+    - Only tags conforming to vMAJOR.MINOR.REV (e.g. "v0.7.0")
+      are considered in order to generate the version string.
+      See the use of `--match` in the `git` command below.
+
+    Here are some examples:
+
+    - pub_ver = '0.7.0', local_ver = '0.7.0':
+      We are at the 0.7.0 release.
+    - pub_ver = '0.8.dev94', local_ver = '0.8.dev94+g0d07a329e':
+      We are at the 0.8 development cycle.
+      The current source contains 94 additional commits
+      after the most recent tag (v0.7.0),
+      the git short hash tag of the current commit is 0d07a329e.
+    """
+    cmd = [
+        "git",
+        "describe",
+        "--tags",
+        "--match",
+        "v[0-9]*.[0-9]*.[0-9]*",
+        "--match",
+        "v[0-9]*.[0-9]*.dev[0-9]*",
+    ]
+    with subprocess.Popen(
+        cmd,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        cwd=PROJ_ROOT,
+    ) as proc:
+        (out, _) = proc.communicate()
+
+    if proc.returncode != 0:
+        msg = py_str(out)
+        logging.warning("git describe: %s", msg)
+        return None, None
+    describe = py_str(out).strip()
+    arr_info = describe.split("-")
+
+    # Remove the v prefix, mainly to be robust
+    # to the case where the v prefix is not present.
+    if arr_info[0].startswith("v"):
+        arr_info[0] = arr_info[0][1:]
+
+    # hit the exact tag
+    if len(arr_info) == 1:
+        return arr_info[0], arr_info[0]
+
+    if len(arr_info) != 3:
+        logging.warning("Invalid output from git describe %s", describe)
+        return None, None
+
+    dev_pos = arr_info[0].find(".dev")
+
+    # Development versions:
+    # The code will reach this point in case it can't match a full release version, such as v0.7.0.
+    #
+    # 1. in case the last known label looks like vMAJ.MIN.devN e.g. v0.8.dev0, we use
+    #    the current behavior of just using vMAJ.MIN.devNNNN+gGIT_REV
+    if dev_pos != -1:
+        dev_version = arr_info[0][: arr_info[0].find(".dev")]
+    # 2. in case the last known label looks like vMAJ.MIN.PATCH e.g. v0.8.0
+    #    then we just carry on with a similar version to what git describe provides, which is
+    #    vMAJ.MIN.PATCH.devNNNN+gGIT_REV
+    else:
+        dev_version = arr_info[0]
+
+    pub_ver = f"{dev_version}.dev{arr_info[1]}"
+    local_ver = f"{pub_ver}+{arr_info[2]}"
+    return pub_ver, local_ver
+
+
+def main():
+    logging.basicConfig(level=logging.INFO)
+    parser = argparse.ArgumentParser(description="Detect and synchronize version.")
+    parser.add_argument(
+        "--print-version",
+        action="store_true",
+        help="Print version to the command line. No changes are applied to files.",
+    )
+    parser.add_argument(
+        "--git-describe",
+        action="store_true",
+        help="Use git describe to generate development version.",
+    )
+    parser.add_argument("--dry-run", action="store_true")
+    opt = parser.parse_args()
+    pub_ver, local_ver = None, None
+    if opt.git_describe:
+        pub_ver, local_ver = git_describe_version()
+    if pub_ver is None:
+        pub_ver = __version__
+    if local_ver is None:
+        local_ver = __version__
+    if opt.print_version:
+        print(local_ver)
+
+
+if __name__ == "__main__":
+    main()