Skip to content

Commit

Permalink
Add llama and resnet emitc tests
Browse files Browse the repository at this point in the history
  • Loading branch information
svuckovicTT committed Feb 7, 2025
1 parent b8f6d2e commit e640031
Show file tree
Hide file tree
Showing 9 changed files with 3,248 additions and 6 deletions.
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/Utils/TransformUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

namespace mlir::tt::ttnn::utils {
// Get or insert device for the given operation.
GetDeviceOp getOrInsertDevice(mlir::PatternRewriter &rewriter,
GetDeviceOp getOrInsertDevice(mlir::RewriterBase &rewriter,
mlir::Operation *op);

// Helper method to insert a ToLayoutOp to convert the input operand to the
Expand Down
45 changes: 44 additions & 1 deletion lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,49 @@ class RepeatOpConversionPattern
}
};

// RepeatInterleave op conversion pattern
//
class RepeatInterleaveOpConversionPattern
: public TTNNToEmitCBaseOpConversionPattern<ttnn::RepeatInterleaveOp> {
public:
using TTNNToEmitCBaseOpConversionPattern<
ttnn::RepeatInterleaveOp>::TTNNToEmitCBaseOpConversionPattern;

LogicalResult
matchAndRewrite(ttnn::RepeatInterleaveOp repeatInterleaveOp,
ttnn::RepeatInterleaveOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Create operands vector
//
llvm::SmallVector<Value, 2> operands{
adaptor.getOperands()[0],
};

// Create ArrayAttr object holding attributes and pointers to operands
//
ArrayAttr arrayAttrs = rewriter.getArrayAttr({
rewriter.getIndexAttr(0), // input tensor
repeatInterleaveOp.getRepeatsAttr(), repeatInterleaveOp.getDimAttr(),
repeatInterleaveOp.getMemoryConfig().has_value()
? (operands.append(1, ttnn_to_emitc::utils::createMemoryConfigOp(
rewriter,
repeatInterleaveOp.getMemoryConfigAttr(),
repeatInterleaveOp.getLoc())
->getResult(0)),
mlir::cast<Attribute>(rewriter.getIndexAttr(1)))
: ttnn_to_emitc::utils::createStdNullopt(
rewriter), // ttnn::MemoryConfig
});

rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
repeatInterleaveOp,
this->getTypeConverter()->convertType(repeatInterleaveOp.getType()),
this->convertOpName(repeatInterleaveOp), arrayAttrs, nullptr, operands);

return success();
}
};

// GetDeviceOp conversion pattern
//
namespace {
Expand Down Expand Up @@ -1322,7 +1365,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
//
patterns.add<TransposeOpConversionPattern, ConcatOpConversionPattern,
ReshapeOpConversionPattern, RepeatOpConversionPattern,
DefaultOpConversionPattern<ttnn::RepeatInterleaveOp>,
RepeatInterleaveOpConversionPattern,
DefaultOpConversionPattern<ttnn::SliceOp>,
DefaultOpConversionPattern<ttnn::PermuteOp>>(typeConverter, ctx);

Expand Down
18 changes: 16 additions & 2 deletions lib/Dialect/TTNN/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"
#include "ttmlir/Dialect/TTNN/Utils/TransformUtils.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"

#include "mlir/Analysis/Liveness.h"
Expand Down Expand Up @@ -214,11 +215,24 @@ class TTNNCreateInputGenerators

// Create a new tensor
//
mlir::Value tensorValue = rewriter.create<ttnn::OnesOp>(
ttnn::OnesOp onesOp = rewriter.create<ttnn::OnesOp>(
forwardFuncOp->getLoc(), tensorType, shapeAttr, dTypeAttr,
tensorLayoutAttr, nullptr, nullptr);

generatedTensors.push_back(tensorValue);
// If tensor is meant to be on device, add ToDevice op
//
if (layoutAttr.isDeviceBufferType()) {
ttnn::GetDeviceOp device =
ttnn::utils::getOrInsertDevice(rewriter, onesOp);

mlir::Value tensorOnDevice = rewriter.create<ttnn::ToDeviceOp>(
forwardFuncOp->getLoc(), tensorType, onesOp.getResult(),
device.getResult(), nullptr);

generatedTensors.push_back(tensorOnDevice);
} else {
generatedTensors.push_back(onesOp.getResult());
}
}

// Return the generated tensors
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/Utils/TransformUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
namespace mlir::tt::ttnn::utils {
// Gets or inserts a GetDeviceOp at the top of the current block of the given
// operation.
GetDeviceOp getOrInsertDevice(PatternRewriter &rewriter, Operation *op) {
GetDeviceOp getOrInsertDevice(RewriterBase &rewriter, Operation *op) {
Block *block = op->getBlock();
for (auto &op : block->getOperations()) {
if (auto deviceOp = dyn_cast<ttnn::GetDeviceOp>(op)) {
Expand Down
2,667 changes: 2,667 additions & 0 deletions test/ttmlir/EmitC/TTNN/models/llama_prefill.mlir

Large diffs are not rendered by default.

508 changes: 508 additions & 0 deletions test/ttmlir/EmitC/TTNN/models/resnet.mlir

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions test/ttmlir/EmitC/TTNN/tensor/repeat_interleave.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %basename_t.ttnn
// RUN: ttmlir-opt --ttnn-modify-signatures-for-dylib --convert-ttnn-to-emitc %t.mlir > %t2.mlir
// RUN: ttmlir-translate --mlir-to-cpp %t2.mlir > %basename_t.cpp

func.func @repeat_interleave(%arg0: tensor<4x6xf32>) -> tensor<4x24xf32> {
%0 = tensor.empty() : tensor<4x24xf32>
%1 = "ttir.repeat_interleave"(%arg0, %0) {repeats = 4 : ui32, dim = 1 : si32} : (tensor<4x6xf32>, tensor<4x24xf32>) -> tensor<4x24xf32>
return %1 : tensor<4x24xf32>
}
1 change: 0 additions & 1 deletion tools/ttnn-standalone/ci_compile_dylib.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ def compile_shared_object(cpp_file_path, output_dir):
destination_path = os.path.join(output_dir, output_file_name)
shutil.copy2(compiled_so_path, destination_path)
print(f" Successfully copied compiled file to {destination_path}.")
os.remove(source_cpp_path)
except subprocess.CalledProcessError as e:
print(f" Error during build process: {e}")
print(e.stderr)
Expand Down
1 change: 1 addition & 0 deletions tools/ttnn-standalone/ttnn-precompiled.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "operations/creation.hpp"
#include "operations/data_movement/concat/concat.hpp"
#include "operations/data_movement/repeat/repeat.hpp"
#include "operations/data_movement/repeat_interleave/repeat_interleave.hpp"
#include "operations/data_movement/transpose/transpose.hpp"
#include "operations/eltwise/binary/binary.hpp"
#include "operations/eltwise/binary/binary_composite.hpp"
Expand Down

0 comments on commit e640031

Please sign in to comment.