This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment to write fast code at higher productivity than CUDA, but also with higher flexibility than other existing DSLs.
The foundations of this project are described in the following MAPL2019 publication: Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations. Please consider citing this work if you use Triton!
The official documentation contains installation instructions and tutorials. See also these third-party Triton puzzles, which can all be run using the Triton interpreter (no GPU required).
You can install the latest stable release of Triton from pip:
pip install triton
Binary wheels are available for CPython 3.8-3.12 and PyPy 3.8-3.9.
And the latest nightly release:
pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
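The stable and nightly wheels are published under different distribution names (triton and triton-nightly, as in the two pip commands above). If you switch between them, a quick way to see which one pip has installed is to ask the standard library's importlib.metadata. This is a minimal sketch; the helper name installed_triton_dists is illustrative and not part of Triton.

```python
from importlib import metadata
from typing import Dict, Optional

# Distribution names taken from the pip commands above.
TRITON_DISTS = ("triton", "triton-nightly")


def installed_triton_dists() -> Dict[str, Optional[str]]:
    """Map each Triton distribution name to its installed version, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in TRITON_DISTS:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_triton_dists().items():
        print(f"{name}: {version or 'not installed'}")
```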
To install from source:
git clone https://github.com/triton-lang/triton.git;
cd triton;
pip install ninja cmake wheel pybind11; # build-time dependencies
pip install -e python
Or with a virtualenv:
git clone https://github.com/triton-lang/triton.git;
cd triton;
python -m venv .venv --prompt triton;
source .venv/bin/activate;
pip install ninja cmake wheel pybind11; # build-time dependencies
pip install -e python
Triton uses LLVM to generate code for GPUs and CPUs. Normally, the Triton build downloads a prebuilt LLVM, but you can also build LLVM from source and use that.
LLVM does not have a stable API, so the Triton build will not work at an arbitrary LLVM version.
- Find the version of LLVM that Triton builds against. Check cmake/llvm-hash.txt to see the current version. For example, if it says:
  49af6502c6dcb4a7f7520178bd14df396f78240c
  this means that the version of Triton you have builds against LLVM 49af6502.
- git checkout LLVM at this revision. Optionally, make additional modifications to LLVM.
- Build LLVM. For example, you might run:
  $ cd $HOME/llvm-project  # your clone of LLVM.
  $ mkdir build
  $ cd build
  $ cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir;llvm" -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU"
  $ ninja
- Grab a snack; this will take a while.
- Build Triton as above, but set the following environment variables:
  # Modify as appropriate to point to your LLVM build.
  $ export LLVM_BUILD_DIR=$HOME/llvm-project/build
  $ cd <triton install>
  $ LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include \
    LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib \
    LLVM_SYSPATH=$LLVM_BUILD_DIR \
    pip install -e python
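The first and last steps above can also be scripted. The sketch below reads cmake/llvm-hash.txt from a Triton checkout and runs pip install -e python with the three LLVM_* variables from the last step. The helper names are illustrative, and both paths you pass in stand for your own checkouts; checking out and building LLVM itself is still left to the shell commands above.

```python
import os
import subprocess
import sys
from pathlib import Path
from typing import Dict, Optional


def llvm_revision(triton_root: Path) -> str:
    """Full LLVM commit hash pinned in <triton_root>/cmake/llvm-hash.txt."""
    return (triton_root / "cmake" / "llvm-hash.txt").read_text().strip()


def llvm_build_env(llvm_build_dir: Path, base: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    """Copy `base` (default: os.environ) and add the three LLVM_* variables from the last step."""
    env = dict(os.environ if base is None else base)
    env["LLVM_INCLUDE_DIRS"] = str(llvm_build_dir / "include")
    env["LLVM_LIBRARY_DIR"] = str(llvm_build_dir / "lib")
    env["LLVM_SYSPATH"] = str(llvm_build_dir)
    return env


def install_with_custom_llvm(triton_root: Path, llvm_build_dir: Path) -> None:
    """Run `pip install -e python` in the Triton checkout against a local LLVM build."""
    print(f"Triton expects LLVM {llvm_revision(triton_root)[:8]}; make sure your LLVM build matches.")
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-e", "python"],
        cwd=triton_root,
        env=llvm_build_env(llvm_build_dir),
        check=True,
    )
```

For example, install_with_custom_llvm(Path.home() / "triton", Path.home() / "llvm-project" / "build") mirrors the shell commands. Invoking pip as sys.executable -m pip is a deliberate choice: it guarantees the install targets the same interpreter (and virtualenv) that runs the script.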
Tips for building
- Set TRITON_BUILD_WITH_CLANG_LLD=true as an environment variable to use clang and lld. lld in particular results in faster builds.
- Set TRITON_BUILD_WITH_CCACHE=true to build with ccache.
- Set TRITON_HOME=/some/path to change the location of the .triton directory where Triton's cache is located and downloads are stored during the build. By default, this is the user's home directory. It can be changed anytime.
- Pass --no-build-isolation to pip install to make no-op builds faster. Without this, every invocation of pip install uses a different symlink to cmake, and this forces ninja to rebuild most of the .a files.
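The build tips above can be combined into a single invocation. A minimal sketch, assuming it runs from the Triton checkout root; build_command and build_env are illustrative helper names, and which options to enable is up to you.

```python
import os
import sys
from typing import Dict, List, Optional


def build_command(no_build_isolation: bool = True) -> List[str]:
    """pip command for an editable Triton install, optionally skipping build isolation."""
    cmd = [sys.executable, "-m", "pip", "install"]
    if no_build_isolation:
        # Reuses the same cmake across invocations so no-op rebuilds stay fast.
        cmd.append("--no-build-isolation")
    cmd += ["-e", "python"]
    return cmd


def build_env(
    clang_lld: bool = False,
    ccache: bool = False,
    triton_home: Optional[str] = None,
    base: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
    """Copy `base` (default: os.environ) and set the optional build variables."""
    env = dict(os.environ if base is None else base)
    if clang_lld:
        env["TRITON_BUILD_WITH_CLANG_LLD"] = "true"
    if ccache:
        env["TRITON_BUILD_WITH_CCACHE"] = "true"
    if triton_home is not None:
        env["TRITON_HOME"] = triton_home
    return env
```

For example, subprocess.run(build_command(), env=build_env(clang_lld=True, ccache=True), check=True) would build with clang/lld and ccache. Because --no-build-isolation makes pip build in the current environment, the build-time dependencies from the pip install ninja cmake wheel pybind11 step must already be installed.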
vscode intellisense has some difficulty figuring out how to build Triton's C++ (probably because, in our build, users don't invoke cmake directly, but instead use setup.py). Teach vscode how to compile Triton as follows.
- Do a local build. Run the command:
  pip install -e python
- Get the full path to the compile_commands.json file produced by the build:
  find python/build -name 'compile_commands.json' | xargs readlink -f
  You might get a full path similar to /Users/{username}/triton/python/build/cmake.macosx-11.1-arm64-cpython-3.12/compile_commands.json
- In vscode, install the C/C++ extension, then open the command palette (Shift + Command + P on Mac, or Shift + Ctrl + P on Windows/Linux) and open C/C++: Edit Configurations (UI).
- Open "Advanced Settings" and paste the full path to compile_commands.json into the "Compile Commands" textbox.
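The C/C++: Edit Configurations (UI) page stores its settings in .vscode/c_cpp_properties.json, so the steps above can also be scripted. In this minimal sketch, find_compile_commands mirrors the find | xargs readlink -f command, and write_vscode_config writes one configuration whose compileCommands entry points at the result. The helper names and the configuration name "Triton" are illustrative, the file layout assumes the extension's version 4 format, and the script overwrites any existing c_cpp_properties.json, so merge by hand if you already have one.

```python
import json
from pathlib import Path
from typing import Optional


def find_compile_commands(triton_root: Path) -> Optional[Path]:
    """Like `find python/build -name compile_commands.json | xargs readlink -f`, first match only."""
    build_dir = triton_root / "python" / "build"
    if not build_dir.is_dir():
        return None
    matches = sorted(build_dir.rglob("compile_commands.json"))
    # resolve() makes the path absolute and follows symlinks, like readlink -f.
    return matches[0].resolve() if matches else None


def write_vscode_config(triton_root: Path) -> Path:
    """Write .vscode/c_cpp_properties.json pointing IntelliSense at the build's compile_commands.json."""
    compile_commands = find_compile_commands(triton_root)
    if compile_commands is None:
        raise FileNotFoundError("no compile_commands.json under python/build; run `pip install -e python` first")
    config = {
        "configurations": [{"name": "Triton", "compileCommands": str(compile_commands)}],
        "version": 4,
    }
    out = triton_root / ".vscode" / "c_cpp_properties.json"
    out.parent.mkdir(exist_ok=True)
    out.write_text(json.dumps(config, indent=4) + "\n")
    return out
```

If you have built for several Python versions, more than one compile_commands.json exists and the sketch simply takes the first in sorted order.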
There currently isn't a turnkey way to run all the Triton tests, but you can use the following recipe.
# One-time setup. Note we have to reinstall local Triton because torch
# overwrites it with the public version.
$ pip install scipy numpy torch pytest lit pandas matplotlib && pip install -e python
# Run Python tests using your local GPU.
$ python3 -m pytest python/test/unit
# Move to builddir. Fill in <...> with the full path, e.g.
# `cmake.linux-x86_64-cpython-3.11`.
$ cd python/build/cmake<...>
# Run C++ unit tests.
$ ctest -j32
# Run lit tests.
$ lit test
You may find it helpful to make a symlink to the builddir and tell your local git to ignore it.
$ ln -s python/build/cmake<...> build
$ echo build >> .git/info/exclude
Then you can e.g. rebuild and run lit with the following command.
$ ninja -C build && ( cd build ; lit test )
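Instead of filling in <...> by hand, a small script can locate the CMake build directory and create the build symlink and the .git/info/exclude entry described above. A minimal sketch; it assumes the build directory matches python/build/cmake.* and picks the most recently modified one if there are several. find_cmake_build_dir and link_build_dir are illustrative names.

```python
from pathlib import Path


def find_cmake_build_dir(triton_root: Path) -> Path:
    """Return the python/build/cmake.* directory, preferring the most recently modified."""
    candidates = [p for p in (triton_root / "python" / "build").glob("cmake.*") if p.is_dir()]
    if not candidates:
        raise FileNotFoundError("no python/build/cmake.* directory; build Triton first")
    return max(candidates, key=lambda p: p.stat().st_mtime)


def link_build_dir(triton_root: Path, link_name: str = "build") -> Path:
    """Like `ln -s python/build/cmake<...> build` plus `echo build >> .git/info/exclude`."""
    target = find_cmake_build_dir(triton_root)
    link = triton_root / link_name
    if link.is_symlink():
        link.unlink()  # refresh a stale link, e.g. after building for another Python version
    link.symlink_to(target.relative_to(triton_root), target_is_directory=True)

    git_dir = triton_root / ".git"
    if git_dir.is_dir():  # only touch the exclude file inside a real git checkout
        exclude = git_dir / "info" / "exclude"
        exclude.parent.mkdir(exist_ok=True)
        text = exclude.read_text() if exclude.exists() else ""
        if link_name not in text.splitlines():  # avoid duplicate entries on reruns
            if text and not text.endswith("\n"):
                text += "\n"
            exclude.write_text(text + link_name + "\n")
    return link
```

After link_build_dir(Path.cwd()) from the checkout root, ninja -C build and ( cd build ; lit test ) work as shown above.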
For detailed instructions on how to debug Triton's frontend, please refer to this tutorial. The following includes additional tips for hacking on Triton's backend.
Helpful environment variables
- MLIR_ENABLE_DUMP=1 dumps the IR before every MLIR pass Triton runs, for all kernels. Use MLIR_ENABLE_DUMP=kernelName to dump for a specific kernel only.
  - Triton cache can interfere with the dump. In cases where MLIR_ENABLE_DUMP=1 does not work, try cleaning your Triton cache: rm -r ~/.triton/cache/*
- LLVM_IR_ENABLE_DUMP=1 dumps the IR before every pass run over the LLVM IR.
- TRITON_REPRODUCER_PATH=<reproducer_path> will generate an MLIR reproducer file at <reproducer_path> before each MLIR compiler stage. If any of the stages fail, <reproducer_path> will be a local MLIR reproducer captured right before the failing pass.
- TRITON_INTERPRET=1 uses the Triton interpreter instead of running on the GPU. You can insert Python breakpoints in your kernel code!
- TRITON_ENABLE_LLVM_DEBUG=1 passes -debug to LLVM, printing a lot of debugging information to stdout. If this is too noisy, run with just TRITON_LLVM_DEBUG_ONLY instead to limit the output.
  An alternative way to reduce output noisiness is to run with LLVM_IR_ENABLE_DUMP=1, extract the IR before the LLVM pass of interest, and then run LLVM's opt standalone, perhaps passing -debug-only=foo on the command line.
- TRITON_LLVM_DEBUG_ONLY=<comma-separated> is the equivalent of LLVM's -debug-only command-line option. This limits the LLVM debug output to specific pass or component names (which are specified using #define DEBUG_TYPE throughout LLVM and Triton) so that the debug output is less noisy. TRITON_LLVM_DEBUG_ONLY accepts one or more comma-separated values (e.g. TRITON_LLVM_DEBUG_ONLY="tritongpu-remove-layout-conversions" or TRITON_LLVM_DEBUG_ONLY="tritongpu-remove-layout-conversions,regalloc").
- USE_IR_LOC={ttir,ttgir} reparses the IR such that the location information will be the line number of the IR file with that particular extension, instead of the line number of the Python file. This can provide a direct mapping from the IR to llir/ptx. When used with performance tools, it can provide a breakdown on IR instructions.
- TRITON_PRINT_AUTOTUNING=1 prints out the best autotuning config and total time spent for each kernel after autotuning is complete.
- DISABLE_LLVM_OPT disables LLVM optimizations for make_llir and make_ptx if its value is true when parsed as a bool. Otherwise, it is parsed as a list of flags to disable LLVM optimizations. One use case is DISABLE_LLVM_OPT="disable-lsr": loop strength reduction is known to cause up to 10% performance changes for certain kernels with register pressure.
- TRITON_ALWAYS_COMPILE=1 forces kernels to be compiled regardless of cache hits.
- MLIR_ENABLE_TIMING dumps the timing information for each MLIR pass.
- LLVM_ENABLE_TIMING dumps the timing information for each LLVM pass.
- TRITON_DEFAULT_FP_FUSION overrides the default behavior of allowing fp fusion (mul+add->fma).
- MLIR_ENABLE_REMARK enables the performance warnings that are emitted as remarks.
- TRITON_KERNEL_DUMP enables the dumping of the IR from each compilation stage and the final ptx.
- TRITON_DUMP_DIR specifies the directory to save the dumped IR and ptx when TRITON_KERNEL_DUMP is set to 1.
- TRITON_KERNEL_OVERRIDE enables the override of the compiled kernel with a user-specified IR/ptx at the beginning of each compilation stage.
- TRITON_OVERRIDE_DIR specifies the directory from which to load the IR/ptx files when TRITON_KERNEL_OVERRIDE is set to 1.
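Several of these variables are usually combined for one debugging run. The sketch below builds such an environment for a subprocess and clears the cache mentioned under MLIR_ENABLE_DUMP. The helper names are illustrative, and the cache location (<TRITON_HOME>/.triton/cache, defaulting to ~/.triton/cache) is inferred from the TRITON_HOME build tip rather than taken from Triton's code. Setting TRITON_ALWAYS_COMPILE=1 should keep a cache hit from silently skipping the dump, as an alternative to clearing the cache.

```python
import os
import shutil
from pathlib import Path
from typing import Dict, Iterable, Optional


def triton_cache_dir(env: Optional[Dict[str, str]] = None) -> Path:
    """Assumed cache location: <TRITON_HOME or the home directory>/.triton/cache."""
    env = os.environ if env is None else env
    home = env.get("TRITON_HOME") or str(Path.home())
    return Path(home) / ".triton" / "cache"


def clear_triton_cache(env: Optional[Dict[str, str]] = None) -> None:
    """Same effect as `rm -r ~/.triton/cache/*`, but honoring TRITON_HOME."""
    cache = triton_cache_dir(env)
    if not cache.is_dir():
        return
    for entry in list(cache.iterdir()):
        if entry.is_dir() and not entry.is_symlink():
            shutil.rmtree(entry)
        else:
            entry.unlink()


def debug_env(
    kernel: Optional[str] = None,
    llvm_debug_only: Iterable[str] = (),
    base: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
    """Copy `base` (default: os.environ) and add MLIR dump and LLVM debug variables."""
    env = dict(os.environ if base is None else base)
    env["TRITON_ALWAYS_COMPILE"] = "1"  # compile even on a cache hit so the dump is produced
    env["MLIR_ENABLE_DUMP"] = kernel or "1"  # a kernel name, or "1" for all kernels
    only = ",".join(llvm_debug_only)  # TRITON_LLVM_DEBUG_ONLY takes a comma-separated list
    if only:
        env["TRITON_LLVM_DEBUG_ONLY"] = only
    return env
```

For example, subprocess.run([sys.executable, "my_kernel.py"], env=debug_env(kernel="add_kernel", llvm_debug_only=["regalloc"]), check=True), where my_kernel.py and add_kernel are hypothetical stand-ins for your own script and kernel.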
Kernel Override Steps
export TRITON_ALWAYS_COMPILE=1
export TRITON_KERNEL_DUMP=1
export TRITON_DUMP_DIR=<dump_dir>
export TRITON_KERNEL_OVERRIDE=1
export TRITON_OVERRIDE_DIR=<override_dir>
# Step 1: Run the kernel once to dump kernel's IRs and ptx in $TRITON_DUMP_DIR
# Step 2: Copy $TRITON_DUMP_DIR/<kernel_hash> to $TRITON_OVERRIDE_DIR
# Step 3: Delete the stages that you do not want to override and modify the stage you do want to override
# Step 4: Run the kernel again to see the overridden result
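Steps 2 and 3 can be scripted. A minimal sketch, assuming that $TRITON_DUMP_DIR/<kernel_hash> holds one file per stage, distinguished by extension (for example .ttir, .ttgir, .llir, .ptx). The STAGE_SUFFIXES set and the prepare_override helper are illustrative assumptions, so check the actual file names in your dump directory first. Non-stage files (such as metadata) are left in place, and the helper refuses to overwrite an existing override directory so earlier edits are not lost.

```python
import shutil
from pathlib import Path
from typing import Iterable

# Assumed stage-file extensions; adjust to what your dump directory actually contains.
STAGE_SUFFIXES = {".ttir", ".ttgir", ".llir", ".ptx", ".amdgcn"}


def prepare_override(dump_dir: Path, override_dir: Path, kernel_hash: str, keep: Iterable[str]) -> Path:
    """Copy a kernel's dump (step 2), then delete the stage files you don't want to override (step 3)."""
    keep_set = set(keep)
    unknown = keep_set - STAGE_SUFFIXES
    if unknown:
        raise ValueError(f"unknown stage suffixes: {sorted(unknown)}")
    dest = override_dir / kernel_hash
    # copytree raises FileExistsError if dest already exists, protecting earlier edits.
    shutil.copytree(dump_dir / kernel_hash, dest)
    for f in list(dest.iterdir()):
        if f.suffix in STAGE_SUFFIXES and f.suffix not in keep_set:
            f.unlink()
    return dest
```

After prepare_override(Path(dump_dir), Path(override_dir), kernel_hash, keep=[".ttgir"]), edit the remaining .ttgir file by hand and run the kernel again (step 4).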
Version 2.0 is out! New features include:
- Many, many bug fixes
- Performance improvements
- Backend rewritten to use MLIR
- Support for kernels that contain back-to-back matmuls (e.g., flash attention)
Community contributions are more than welcome, whether it be to fix bugs or to add new features on GitHub. For more detailed instructions, please visit our contributor's guide.
Supported Platforms:
- Linux
Supported Hardware:
- NVIDIA GPUs (Compute Capability 8.0+)
- AMD GPUs (ROCm 5.2+)
- Under development: CPUs
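To check whether an NVIDIA GPU meets the compute-capability requirement above, compare its (major, minor) capability with (8, 0). A minimal sketch intended for NVIDIA GPUs only; the helper name is illustrative, and the optional device query assumes PyTorch with CUDA is installed (as in the test setup above) and uses torch.cuda.get_device_capability.

```python
from typing import Optional, Tuple

MIN_NVIDIA_CAPABILITY = (8, 0)  # "Compute Capability 8.0+" from the list above


def nvidia_gpu_supported(capability: Optional[Tuple[int, int]] = None, device: int = 0) -> bool:
    """True if `capability` (or that of CUDA device `device`, queried via PyTorch) is >= 8.0."""
    if capability is None:
        import torch  # assumed available; only needed when querying a real device

        if not torch.cuda.is_available():
            return False
        capability = torch.cuda.get_device_capability(device)
    # Tuple comparison orders by major version first, then minor.
    return tuple(capability) >= MIN_NVIDIA_CAPABILITY
```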