Merge branch 'triton-lang:main' into streamkv0.3
xiaohuguo2023 authored Nov 6, 2024
2 parents a493381 + 1cf7b1b commit 6979834
Showing 159 changed files with 9,357 additions and 3,358 deletions.
19 changes: 18 additions & 1 deletion CMakeLists.txt
@@ -12,7 +12,7 @@ set(CMAKE_CXX_STANDARD 17)

set(CMAKE_INCLUDE_CURRENT_DIR ON)

-project(triton)
+project(triton CXX)
include(CTest)

if(NOT WIN32)
@@ -26,8 +26,25 @@ option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON)
option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON)
+option(TRITON_BUILD_WITH_CCACHE "Build with ccache (if available)" ON)
set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends")

+if(TRITON_BUILD_WITH_CCACHE)
+  find_program(CCACHE_PROGRAM ccache)
+  if(CCACHE_PROGRAM)
+    set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}"
+        CACHE STRING "C compiler launcher")
+    set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}"
+        CACHE STRING "CXX compiler launcher")
+  else()
+    message(
+      STATUS
+        "Could not find ccache. Consider installing ccache to speed up compilation."
+    )
+  endif()
+endif()


# Ensure Python3 vars are set correctly
# used conditionally in this file and by lit tests

27 changes: 14 additions & 13 deletions README.md
@@ -6,12 +6,10 @@ The Triton Conference is happening again on September 17th, 2024 in Fremont (CA)

If you are interested in attending, please fill up [this form](https://docs.google.com/forms/d/e/1FAIpQLSecHC1lkalcm0h3JDUbspekDX5bmBvMxgVTLaK3e-61bzDDbg/viewform).


| **`Documentation`** | **`Nightly Wheels`** |
|-------------------- | -------------------- |
| [![Documentation](https://github.com/triton-lang/triton/actions/workflows/documentation.yml/badge.svg)](https://triton-lang.org/) | [![Wheels](https://github.com/triton-lang/triton/actions/workflows/wheels.yml/badge.svg?branch=release/2.0.x)](https://github.com/triton-lang/triton/actions/workflows/wheels.yml) |


# Triton

This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment to write fast code at higher productivity than CUDA, but also with higher flexibility than other existing DSLs.
@@ -24,20 +22,21 @@ The [official documentation](https://triton-lang.org) contains installation inst

You can install the latest stable release of Triton from pip:

-```bash
+```shell
pip install triton
```

Binary wheels are available for CPython 3.8-3.12 and PyPy 3.8-3.9.

And the latest nightly release:

-```bash
+```shell
pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
```

# Install from source

-```
+```shell
git clone https://github.com/triton-lang/triton.git;
cd triton;

@@ -47,7 +46,7 @@ pip install -e python

Or with a virtualenv:

-```
+```shell
git clone https://github.com/triton-lang/triton.git;
cd triton;

@@ -156,14 +155,14 @@ $ lit test
You may find it helpful to make a symlink to the builddir and tell your local
git to ignore it.

-```
+```shell
$ ln -s python/build/cmake<...> build
$ echo build >> .git/info/exclude
```

Then you can e.g. rebuild and run lit with the following command.

-```
+```shell
$ ninja -C build && ( cd build ; lit test )
```

@@ -217,6 +216,7 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi
# Changelog

Version 2.0 is out! New features include:

- Many, many bug fixes
- Performance improvements
- Backend rewritten to use MLIR
@@ -226,13 +226,14 @@ Version 2.0 is out! New features include:

Community contributions are more than welcome, whether it be to fix bugs or to add new features at [github](https://github.com/triton-lang/triton/). For more detailed instructions, please visit our [contributor's guide](CONTRIBUTING.md).


# Compatibility

Supported Platforms:
-* Linux

+- Linux

Supported Hardware:
-* NVIDIA GPUs (Compute Capability 7.0+)
-* AMD GPUs (ROCm 5.2+)
-* Under development: CPUs

+- NVIDIA GPUs (Compute Capability 8.0+)
+- AMD GPUs (ROCm 5.2+)
+- Under development: CPUs
2 changes: 2 additions & 0 deletions bin/RegisterTritonDialects.h
@@ -63,6 +63,8 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::registerTritonAMDGPUStreamPipelineV2();
mlir::registerTritonAMDGPUCanonicalizePointers();
mlir::registerTritonAMDGPUConvertToBufferOps();
+mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
+mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();

// TODO: register Triton & TritonGPU passes
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
2 changes: 1 addition & 1 deletion cmake/llvm-hash.txt
@@ -1 +1 @@
-b5cc222d7429fe6f18c787f633d5262fac2e676f
+fa57c7a6a5f594a9e3ae2dbe3542cf89a20cdd73
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -145,7 +145,7 @@ def documenter(app, obj, parent):
autosummary_generate = True

# versioning config
-smv_tag_whitelist = r'^(v3.0.0)$'
+smv_tag_whitelist = r'^(v3.2.0)$'
smv_branch_whitelist = r'^main$'
smv_remote_whitelist = None
smv_released_pattern = r'^tags/.*$'
4 changes: 1 addition & 3 deletions docs/python-api/triton-semantics.rst
@@ -14,9 +14,7 @@ The algorithm is as follows:

2. **Width** If both tensors are of dtypes of the same kind, and one of them is of a higher width, the other one is promoted to this dtype: ``(float32, float16) -> float32``

-3. **Supremum** If both tensors are of the same width and signedness but different dtypes, they are both promoted to the next larger dtype. ``(float16, bfloat16) -> float32``
-
-3.1 If both tensors are of different ``fp8`` dtypes, they are both cast to ``float16``.
+3. **Prefer float16** If both tensors are of the same width and signedness but different dtypes (``float16`` and ``bfloat16`` or different ``fp8`` types), they are both promoted to ``float16``. ``(float16, bfloat16) -> float16``

4. **Prefer unsigned** Otherwise (same width, different signedness), they are promoted to the unsigned dtype: ``(int32, uint32) -> uint32``
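
For illustration only, the updated rules can be sketched as a small C++ helper. The `DType` struct and constants below are hypothetical stand-ins for Triton's dtype machinery, rule 1 (kind promotion) is assumed to have already run, and this is a sketch of the documented rules rather than the compiler's actual code:

```cpp
#include <string>

// Hypothetical dtype descriptor, just enough to express rules 2-4 above.
struct DType {
  enum class Kind { Int, UInt, Float } kind;
  int width;         // bits
  std::string name;  // distinguishes e.g. float16 from bfloat16
  bool operator==(const DType &o) const {
    return kind == o.kind && width == o.width && name == o.name;
  }
};

const DType Float16{DType::Kind::Float, 16, "float16"};
const DType BFloat16{DType::Kind::Float, 16, "bfloat16"};

DType promote(const DType &a, const DType &b) {
  if (a == b)
    return a;
  // Rule 2 (width): promote the narrower operand to the wider dtype.
  if (a.width != b.width)
    return a.width > b.width ? a : b;
  // Rule 3 (prefer float16): same-width floats of different dtypes
  // (float16/bfloat16, or two different fp8 types) become float16.
  if (a.kind == DType::Kind::Float)
    return Float16;
  // Rule 4 (prefer unsigned): same width, different signedness.
  return a.kind == DType::Kind::UInt ? a : b;
}
```

With these definitions, `promote(Float16, BFloat16)` yields `float16`, matching the `(float16, bfloat16) -> float16` example above.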

1 change: 1 addition & 0 deletions docs/python-api/triton.language.rst
@@ -59,6 +59,7 @@ Linear Algebra Ops
:nosignatures:

dot
+dot_scaled


Memory/Pointer Ops
6 changes: 1 addition & 5 deletions include/triton/Analysis/Utility.h
@@ -66,7 +66,7 @@ class ReduceOpHelper {
// The shape of the shared memory space needed for the reduction.
SmallVector<unsigned> getScratchRepShape();

-SmallVector<unsigned> getOrderWithAxisAtBeginning();
+SmallVector<unsigned> getThreadOrderWithAxisAtBeginning();

unsigned getScratchSizeInBytes();

@@ -214,10 +214,6 @@ bool atomicNeedsSharedMemory(Value result);

bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

-bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
-
-bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

// Return true if the src and dst layout match.
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
RankedTensorType dstTy);
@@ -15,17 +15,6 @@ namespace mlir::triton {

namespace gpu {

-SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
-                                 Type ouType);
-
-SmallVector<Value> unpackI32(const SmallVector<Value> &inValues, Type srcTy,
-                             ConversionPatternRewriter &rewriter, Location loc,
-                             const LLVMTypeConverter *typeConverter);
-
-SmallVector<Value> packI32(const SmallVector<Value> &inValues, Type srcTy,
-                           ConversionPatternRewriter &rewriter, Location loc,
-                           const LLVMTypeConverter *typeConverter);

Type getElementType(Value value);

class MultipleOperandsRange
@@ -187,8 +176,6 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
for (auto operand : adaptor.getOperands()) {
auto argTy = op->getOperand(0).getType();
auto subOperands = unpackLLElements(loc, operand, rewriter);
-subOperands = unpackI32(subOperands, argTy, rewriter, loc,
-                        this->getTypeConverter());
allOperands.resize(subOperands.size());
for (auto v : llvm::enumerate(subOperands))
allOperands[v.index()].push_back(v.value());
@@ -209,13 +196,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
}
it += curr.size();
}
-if (op->getNumOperands() > 0) {
-  auto argTy = op->getOperand(0).getType();
-  resultVals = reorderValues(resultVals, argTy, resultTy);
-}
resultVals = maybeDeduplicate(op, resultVals);
-resultVals =
-    packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
Value view = packLLElements(loc, this->getTypeConverter(), resultVals,
rewriter, resultTy);
rewriter.replaceOp(op, view);
@@ -27,15 +27,33 @@ constexpr int patternBenefitPrioritizeOverLLVMConversions = 10;
constexpr int patternBenefitClampOptimizedPattern = 20;
constexpr int patternBenefitConvertLayoutOptimizedPattern = 20;

+struct BackendCallbacks {
+  /**
+   * A backend-specific callback for appending auxiliary data during
+   * `LocalStoreOp` conversion.
+   *
+   * @param[in] op The reference to the re-written `LocalStoreOp`.
+   * @param[in] count The number of issued LLVM instructions.
+   * @param[in] type The input type of issued LLVM instructions.
+   */
+  std::function<void(triton::gpu::LocalStoreOp op, size_t llvmOpCount,
+                     Type llvmOpType)>
+      localStoreOpConversion = nullptr;
+};

void populateElementwiseOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
PatternBenefit benefit);

-void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter,
-                                   const TargetInfoBase &targetInfo,
-                                   RewritePatternSet &patterns,
-                                   PatternBenefit benefit);
+// The given callback is invoked at the end of a successful rewrite. The
+// callback receives 1) the current source op, 2) the number of issued LLVM
+// instructions and 3) their input types. Each MLIR backend can provide a
+// callback and, thus, handle backend-specific behaviors.
+void populateMemoryOpToLLVMPattern(
+    LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
+    RewritePatternSet &patterns, PatternBenefit benefit,
+    std::optional<BackendCallbacks> backendCallbacks = std::nullopt);
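
As a rough sketch of how a backend might use this hook (the variable names and surrounding setup are illustrative, not actual Triton backend code):

```cpp
// Count the LLVM instructions emitted while lowering each LocalStoreOp,
// e.g. to feed an instruction-scheduling heuristic afterwards.
BackendCallbacks callbacks;
size_t totalStoreInstrs = 0;
callbacks.localStoreOpConversion =
    [&](triton::gpu::LocalStoreOp op, size_t llvmOpCount, Type llvmOpType) {
      totalStoreInstrs += llvmOpCount;
      // ... stash (op, llvmOpCount, llvmOpType) for later passes ...
    };

// `typeConverter`, `targetInfo`, `patterns`, and `benefit` come from the
// surrounding conversion-pass setup.
populateMemoryOpToLLVMPattern(typeConverter, targetInfo, patterns, benefit,
                              callbacks);
```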

void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
23 changes: 18 additions & 5 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -391,6 +391,19 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
Value base = gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal);
return base;
}

+// -----------------------------------------------------------------------
+// MXFP utilities
+// -----------------------------------------------------------------------
+
+// Convert each value, which is an int8 containing 2 packed mxfp4 values,
+// into 2 standalone bf16 values
+SmallVector<Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc,
+                                          ArrayRef<Value> values);
+
+// Scale a mxfp4 value by a given scale.
+Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale);

} // namespace LLVM
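
For context on what those helpers compute: an mxfp4 value is a 4-bit e2m1 float (1 sign, 2 exponent, 1 mantissa bit), packed two per int8, and MX scales are e8m0, a bare power-of-two exponent with bias 127. Below is a standalone sketch of that arithmetic (the format's math per the OCP microscaling spec, not Triton's actual lowering; the nibble order is an assumption):

```cpp
#include <cmath>
#include <cstdint>
#include <utility>

// Decode one e2m1 nibble. Representable magnitudes:
// 0, 0.5, 1, 1.5, 2, 3, 4, 6.
static float decodeE2M1(uint8_t nibble) {
  float sign = (nibble & 0x8) ? -1.0f : 1.0f;
  int exp = (nibble >> 1) & 0x3; // exponent bias is 1
  int man = nibble & 0x1;
  float mag = (exp == 0) ? 0.5f * man // subnormal
                         : (1.0f + 0.5f * man) * float(1 << (exp - 1));
  return sign * mag;
}

// Analogue of convertMxfp4x2ToBf16x2: one packed byte -> two values
// (low nibble first here; the in-register order may differ).
static std::pair<float, float> decodeMxfp4x2(uint8_t packed) {
  return {decodeE2M1(packed & 0xF), decodeE2M1(packed >> 4)};
}

// Analogue of mxfpScaleBf16: multiply by the shared e8m0 block scale,
// i.e. by 2^(scale - 127).
static float applyE8M0Scale(float v, uint8_t scale) {
  return std::ldexp(v, int(scale) - 127);
}
```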

/* ------------------------------------ */
@@ -1366,11 +1379,11 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
Location loc, RewriterBase &rewriter,
const TargetInfoBase &target);

-void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
-                              Type elemLlvmTy, ArrayRef<Value> srcVals,
-                              Value smemBase, ArrayRef<Value> dstStrides,
-                              Location loc, RewriterBase &rewriter,
-                              const TargetInfoBase &target);
+void storeDistributedToShared(
+    MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
+    ArrayRef<Value> srcVals, Value smemBase, ArrayRef<Value> dstStrides,
+    Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
+    std::pair<size_t, Type> *const llvmOpCount = nullptr);

inline Value getStructFromSharedMemoryObject(Location loc,
const SharedMemoryObject &smemObj,
9 changes: 5 additions & 4 deletions include/triton/Dialect/Triton/IR/TritonAttrDefs.td
@@ -119,15 +119,16 @@ def TT_InputPrecisionAttr : I32EnumAttr<
let cppNamespace = "::mlir::triton";
}

-// Type for F8F6F4 kind of floats.
-def TT_F8F6F4TypeAttr : I32EnumAttr<
-    "F8F6F4Type", "",
+// Type for ScaleDotElemType kind of floats.
+def TT_ScaleDotElemTypeAttr : I32EnumAttr<
+    "ScaleDotElemType", "",
[
I32EnumAttrCase<"E4M3", 0, "e4m3">,
I32EnumAttrCase<"E5M2", 1, "e5m2">,
I32EnumAttrCase<"E2M3", 2, "e2m3">,
I32EnumAttrCase<"E3M2", 3, "e3m2">,
I32EnumAttrCase<"E2M1", 4, "e2m1">
I32EnumAttrCase<"E2M1", 4, "e2m1">,
I32EnumAttrCase<"BF16", 5, "bf16">

]>{
let cppNamespace = "::mlir::triton";
29 changes: 19 additions & 10 deletions include/triton/Dialect/Triton/IR/TritonOps.td
@@ -108,6 +108,8 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
let assemblyFormat = "$src attr-dict (`,` `rounding` `=` $rounding^)? `:` type($src) `->` type($result)";

let hasVerifier = 1;

+let hasFolder = 1;
}

//
@@ -685,15 +687,15 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,

let arguments = (
ins
-// inputs are integer types as they are packed types and we currently
-// don't have a representation for those.
-TT_IntTensor:$lhs,
-TT_IntTensor:$rhs,
+// inputs are floats if we have a type for them, otherwise (fp4),
+// they are packed in pairs in an I8Tensor
+RankedTensorOf<[TT_Float,I8]>:$lhs,
+RankedTensorOf<[TT_Float,I8]>:$rhs,
TT_FloatTensor:$c,
-TT_IntTensor:$lhs_scale,
-Optional<TT_IntTensor>:$rhs_scale,
-TT_F8F6F4TypeAttr:$lhs_type,
-TT_F8F6F4TypeAttr:$rhs_type
+RankedTensorOf<[I8]>:$lhs_scale,
+Optional<RankedTensorOf<[I8]>>:$rhs_scale,
+TT_ScaleDotElemTypeAttr:$lhs_type,
+TT_ScaleDotElemTypeAttr:$rhs_type
);

let results = (outs TT_FloatTensor:$d);
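
For intuition about these operands, here is a scalar reference of the lhs-scaled case, assuming one e8m0 scale per 32-element block along K. The block size and decode convention follow the OCP MX spec and are assumptions here, not this op's verified semantics:

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// d[m][n] = sum over k of (lhs[m][k] * 2^(lhs_scale[m][k/32] - 127)) * rhs[k][n]
std::vector<float> scaledDotReference(
    const std::vector<float> &lhs,        // M x K, elements already decoded
    const std::vector<uint8_t> &lhsScale, // M x (K/32), e8m0 block scales
    const std::vector<float> &rhs,        // K x N
    int M, int N, int K) {
  std::vector<float> d(M * N, 0.0f);
  for (int m = 0; m < M; ++m)
    for (int k = 0; k < K; ++k) {
      float s = std::ldexp(1.0f, int(lhsScale[m * (K / 32) + k / 32]) - 127);
      for (int n = 0; n < N; ++n)
        d[m * N + n] += lhs[m * K + k] * s * rhs[k * N + n];
    }
  return d;
}
```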
@@ -776,7 +778,8 @@ def TT_ScanReturnOp: TT_Op<"scan.return",
def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise,
SameOperandsAndResultEncoding,
SameVariadicOperandSize,
-DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ConditionallySpeculatable]> {

let description = [{
call an external function $symbol implemented in $libpath/$libname with $args
@@ -788,6 +791,12 @@ def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise,
let results = (outs TT_Type:$result);

let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)";

+let extraClassDeclaration = [{
+  // Interface method for ConditionallySpeculatable.
+  Speculation::Speculatability getSpeculatability();
+}];

}
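
The header only declares the interface method; a plausible definition, assuming it keys off the op's existing `pure` attribute (the real body lives in the dialect's .cpp file and may differ):

```cpp
// Hypothetical sketch: an extern call is safe to speculate only when the
// user marked it pure (no side effects).
Speculation::Speculatability ExternElementwiseOp::getSpeculatability() {
  return getPure() ? Speculation::Speculatable
                   : Speculation::NotSpeculatable;
}
```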

//
@@ -891,7 +900,7 @@ def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
`tt.assert` takes a condition tensor and a message string.
If the condition is false, the message is printed, and the program is aborted.
}];
-let arguments = (ins TT_Tensor:$condition, StrAttr:$message);
+let arguments = (ins AnyTypeOf<[I1, I1Tensor]>:$condition, StrAttr:$message);
let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)";
}
