[compiler & runtime] move device_file_name from func.func to byre.compute #439

Merged · 11 commits · Sep 6, 2024
6 changes: 2 additions & 4 deletions compiler/include/byteir/Conversion/FuncToByre/FuncToByre.h
@@ -27,14 +27,12 @@ class ModuleOp;
void populateFuncToByreTensorPattern(RewritePatternSet &patterns,
bool appendArgTypes);

void populateGPULaunchFuncToByrePattern(RewritePatternSet &patterns,
bool useBarePtrCallConv);
void populateGPULaunchFuncToByrePattern(RewritePatternSet &patterns);

std::unique_ptr<OperationPass<ModuleOp>>
createConvertFuncToByreTensorPass(bool appendArgTypes = false);

std::unique_ptr<Pass>
createConvertGPULaunchFuncToByrePass(bool useBarePtrCallConv = false);
std::unique_ptr<Pass> createConvertGPULaunchFuncToByrePass();

} // namespace mlir

12 changes: 4 additions & 8 deletions compiler/include/byteir/Conversion/Passes.td
@@ -198,7 +198,10 @@ def GenPTXConfig : Pass<"gen-ptx-config", "func::FuncOp"> {
Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
/*default=*/"false",
"Replace memref arguments in GPU functions with bare pointers."
"All memrefs must have static shape">
"All memrefs must have static shape">,
Option<"fileName", "file-name", "std::string",
/*default*/"\"unified\"",
"The target ptx file for current PTXOp.">
];
}

@@ -353,13 +356,6 @@ def ConvertGPULaunchFuncToByre : Pass<"gpu-launch-func-to-byre"> {
"mlir::byre::ByreDialect",
"mlir::gpu::GPUDialect"
];

let options = [
Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
/*default=*/"false",
"Replace memref arguments in GPU functions with bare pointers."
"All memrefs must have static shape">,
];
}

//===----------------------------------------------------------------------===//
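Taken together, these option changes mean the PTX file name now enters the pipeline through `gen-ptx-config{file-name=...}` and travels on the launch op, while `gpu-launch-func-to-byre` no longer takes any option. A minimal sketch of the intended result; the driver invocation, kernel/function names, shapes, and `model.ptx` are illustrative assumptions, not taken from this PR's tests:

```mlir
// Hypothetical invocation (assumes the byteir-opt driver):
//   byteir-opt --pass-pipeline='builtin.module(func.func(gen-ptx-config{file-name=model.ptx}))' host.mlir
// After the pass, each gpu.launch_func carries the target PTX file as a
// discardable attribute instead of the file name living on func.func.
module attributes {gpu.container_module} {
  gpu.module @unified {
    gpu.func @Unknown0(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) kernel {
      gpu.return
    }
  }
  func.func @forward(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
    %c1 = arith.constant 1 : index
    %c256 = arith.constant 256 : index
    gpu.launch_func @unified::@Unknown0
        blocks in (%c1, %c1, %c1) threads in (%c256, %c1, %c1)
        args(%arg0 : memref<1024xf32>, %arg1 : memref<1024xf32>)
        {device_file_name = "model.ptx"}
    return
  }
}
```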
3 changes: 2 additions & 1 deletion compiler/include/byteir/Conversion/ToPTX/ToPTX.h
@@ -29,7 +29,8 @@ class FuncOp;
} // namespace func

std::unique_ptr<OperationPass<func::FuncOp>>
createGenPTXConfigPass(bool useBarePtrCallConv = false);
createGenPTXConfigPass(bool useBarePtrCallConv = false,
const std::string &fileName = "unified");

// TODO move to general GPU
std::unique_ptr<OperationPass<ModuleOp>>
8 changes: 0 additions & 8 deletions compiler/include/byteir/Pipelines/ByreHost.h
@@ -30,14 +30,6 @@ struct ByreHostPipelineOptions
*this, "entry-func",
llvm::cl::desc("An optional string to speicify entry function."),
llvm::cl::init("main")};
Option<std::string> deviceFile{
*this, "device-file-name",
llvm::cl::desc("An optional string to speicify device file name."),
llvm::cl::init("kernel")};
Option<std::string> target{
*this, "target",
llvm::cl::desc("An optional attribute to target device."),
llvm::cl::init("")};
};

void createByreHostPipeline(OpPassManager &pm,
4 changes: 4 additions & 0 deletions compiler/include/byteir/Pipelines/GPU/GPUOpt.h
100644 → 100755
@@ -34,6 +34,10 @@ struct GPUOptPipelineOptions
llvm::cl::desc("An optional attribute to speicify whether using bare ptr "
"call convention."),
llvm::cl::init(false)};
Option<std::string> fileName{
*this, "device-file-name",
llvm::cl::desc("To specify the generated kernel will be written to."),
llvm::cl::init("device_kernel.ptx")};
};

void createGPUOptPipeline(OpPassManager &pm,
30 changes: 12 additions & 18 deletions compiler/lib/Conversion/FuncToByre/FuncToByre.cpp
@@ -99,9 +99,8 @@ class ConvertGPULaunchFuncToByrePattern
: public OpRewritePattern<gpu::LaunchFuncOp> {

public:
ConvertGPULaunchFuncToByrePattern(MLIRContext *ctx, bool useBarePtrCallConv)
: OpRewritePattern<gpu::LaunchFuncOp>(ctx),
useBarePtrCallConv(useBarePtrCallConv) {}
ConvertGPULaunchFuncToByrePattern(MLIRContext *ctx)
: OpRewritePattern<gpu::LaunchFuncOp>(ctx) {}

LogicalResult matchAndRewrite(gpu::LaunchFuncOp launchOp,
PatternRewriter &rewriter) const override {
@@ -129,10 +128,11 @@ class ConvertGPULaunchFuncToByrePattern
computeOp->setAttr("BlockSize.y", rewriter.getI32IntegerAttr(by));
computeOp->setAttr("BlockSize.z", rewriter.getI32IntegerAttr(bz));

if (useBarePtrCallConv) {
computeOp->setAttr(byre::getKernelCallConventionAttrName(),
rewriter.getStringAttr("bare_ptr"));
// get discardable attributes from gpu.launch
for (auto attr : launchOp->getDiscardableAttrDictionary().getValue()) {
computeOp->setAttr(attr.getName(), attr.getValue());
}

rewriter.eraseOp(launchOp);

return success();
@@ -164,14 +164,11 @@ struct ConvertFuncToByreTensorPass
struct ConvertGPULaunchFuncToByrePass
: public ConvertGPULaunchFuncToByreBase<ConvertGPULaunchFuncToByrePass> {
public:
ConvertGPULaunchFuncToByrePass(bool useBarePtrCallConv)
: ConvertGPULaunchFuncToByreBase() {
this->useBarePtrCallConv = useBarePtrCallConv;
}
ConvertGPULaunchFuncToByrePass() : ConvertGPULaunchFuncToByreBase() {}
void runOnOperation() override {
MLIRContext &ctx = getContext();
RewritePatternSet patterns(&ctx);
populateGPULaunchFuncToByrePattern(patterns, useBarePtrCallConv);
populateGPULaunchFuncToByrePattern(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
signalPassFailure();
@@ -186,18 +183,15 @@ void mlir::populateFuncToByreTensorPattern(RewritePatternSet &patterns,
appendArgTypes);
}

void mlir::populateGPULaunchFuncToByrePattern(RewritePatternSet &patterns,
bool useBarePtrCallConv) {
patterns.add<ConvertGPULaunchFuncToByrePattern>(patterns.getContext(),
useBarePtrCallConv);
void mlir::populateGPULaunchFuncToByrePattern(RewritePatternSet &patterns) {
patterns.add<ConvertGPULaunchFuncToByrePattern>(patterns.getContext());
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertFuncToByreTensorPass(bool appendArgTypes) {
return std::make_unique<ConvertFuncToByreTensorPass>(appendArgTypes);
}

std::unique_ptr<Pass>
mlir::createConvertGPULaunchFuncToByrePass(bool useBarePtrCallConv) {
return std::make_unique<ConvertGPULaunchFuncToByrePass>(useBarePtrCallConv);
std::unique_ptr<Pass> mlir::createConvertGPULaunchFuncToByrePass() {
return std::make_unique<ConvertGPULaunchFuncToByrePass>();
}
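To illustrate the change above: a hedged sketch of what the updated `ConvertGPULaunchFuncToByrePattern` is expected to emit for the launch op from the earlier sketch. The `byre.compute` line is schematic (the op's exact assembly and verifier requirements are not reproduced here); the callee `@PTXOp` and `kernel_name` follow the PTXOp convention visible in the GenPTXConfig.cpp hunk below, and the grid/block values are illustrative, but the `GridSize.*`/`BlockSize.*` keys and the copied discardable `device_file_name` come directly from this hunk:

```mlir
// Input: gpu.launch_func annotated by gen-ptx-config (see the earlier sketch):
//   gpu.launch_func @unified::@Unknown0
//       blocks in (%c1, %c1, %c1) threads in (%c256, %c1, %c1)
//       args(%arg0 : memref<1024xf32>, %arg1 : memref<1024xf32>)
//       {device_file_name = "model.ptx"}
//
// Output (schematic): a byre.compute op that keeps the launch dimensions as
// GridSize.*/BlockSize.* attributes and now also inherits every discardable
// attribute from the launch op, including device_file_name.
byre.compute @PTXOp(%arg0, %arg1)
    {kernel_name = "Unknown0",
     GridSize.x = 1 : i32, GridSize.y = 1 : i32, GridSize.z = 1 : i32,
     BlockSize.x = 256 : i32, BlockSize.y = 1 : i32, BlockSize.z = 1 : i32,
     device_file_name = "model.ptx"}
    : memref<1024xf32>, memref<1024xf32>
```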
4 changes: 4 additions & 0 deletions compiler/lib/Conversion/ToPTX/CollectGPUKernel.cpp
@@ -16,6 +16,9 @@
//===----------------------------------------------------------------------===//

#include "byteir/Conversion/ToPTX/ToPTX.h"
#include "byteir/Dialect/Byre/Common.h"
#include "byteir/Dialect/mhlo/Transforms/HloFuser.h"
#include "byteir/Transforms/ShapeFuncOutlining.h"
#include "byteir/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -30,6 +33,7 @@ using namespace mlir;
using namespace mlir::byre;
using namespace mlir::gpu;
using namespace llvm;
using namespace func;

namespace {

46 changes: 19 additions & 27 deletions compiler/lib/Conversion/ToPTX/GenPTXConfig.cpp
@@ -55,36 +55,25 @@ static bool isAliasOp(Operation &op) {
// support static for now
// TODO extend it to support dynamic block/grid sizes
// TODO unify CUDA/PTX into the same pass with compilation option
static void addFuncAttrs(func::FuncOp func, bool useBarePtrCallConv) {
static void addFuncAttrs(func::FuncOp func, bool useBarePtrCallConv,
std::string &fileName) {
mlir::OpBuilder opBuilder(func);

for (auto launchOp : func.getOps<gpu::LaunchFuncOp>()) {
launchOp->setAttr("device_file_name", opBuilder.getStringAttr(fileName));
if (useBarePtrCallConv) {
launchOp->setAttr(byre::getKernelCallConventionAttrName(),
opBuilder.getStringAttr("bare_ptr"));
}
}

// handle elementwise fusion
if (func->hasAttr(getByteIRElementwiseFusionAttrName())) {
mlir::OpBuilder opBuilder(func);

if (func.getOps<gpu::LaunchFuncOp>().empty())
return;

gpu::LaunchFuncOp launchOp = *func.getOps<gpu::LaunchFuncOp>().begin();

func->setAttr(getByrePrefix() + "kernel_name",
opBuilder.getStringAttr(launchOp.getKernelName().getValue()));

// Handle 1D only, since element-wise is only using 1D (linearized)
auto grid = launchOp.getGridSizeOperandValues();
int64_t gx = cast<ConstantIndexOp>(grid.x.getDefiningOp()).value();
func->setAttr(getByrePrefix() + "GridSize.x",
opBuilder.getIntegerAttr(opBuilder.getIntegerType(32), gx));

auto block = launchOp.getBlockSizeOperandValues();
int64_t bx = cast<ConstantIndexOp>(block.x.getDefiningOp()).value();
func->setAttr(getByrePrefix() + "BlockSize.x",
opBuilder.getIntegerAttr(opBuilder.getIntegerType(32), bx));

func->setAttr(getByreComputeName(), opBuilder.getStringAttr("PTXOp"));
func->setAttr(getByreForceComputeNameAttrName(), opBuilder.getUnitAttr());
if (useBarePtrCallConv)
func->setAttr(getByrePrefix() + getKernelCallConventionAttrName(),
opBuilder.getStringAttr("bare_ptr"));

// Handle arg mapping here
// LWC: this is tentative when we are using GPU Kernel Outlining.
// TODO: drop this when we are arrange our arg placement in our own gpu
@@ -146,19 +135,22 @@

// Main Pass
struct GenPTXConfigPass : public GenPTXConfigBase<GenPTXConfigPass> {
GenPTXConfigPass(bool useBarePtrCallConv) : GenPTXConfigBase() {
GenPTXConfigPass(bool useBarePtrCallConv, const std::string &fileName)
: GenPTXConfigBase() {
this->useBarePtrCallConv = useBarePtrCallConv;
this->fileName = fileName;
}

void runOnOperation() override {
func::FuncOp func = getOperation();
addFuncAttrs(func, this->useBarePtrCallConv);
addFuncAttrs(func, this->useBarePtrCallConv, this->fileName);
}
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createGenPTXConfigPass(bool useBarePtrCallConv) {
return std::make_unique<GenPTXConfigPass>(useBarePtrCallConv);
mlir::createGenPTXConfigPass(bool useBarePtrCallConv,
const std::string &fileName) {
return std::make_unique<GenPTXConfigPass>(useBarePtrCallConv, fileName);
}
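The reworked `addFuncAttrs` also covers the bare-pointer path: with `use-bare-ptr-memref-call-conv=true`, the launch op gets the byre kernel-call-convention attribute in addition to `device_file_name`, and the FuncToByre pattern above then forwards both onto `byre.compute`. A minimal sketch, with the call-convention key shown as a placeholder spelling (the real key comes from `byre::getKernelCallConventionAttrName()`):

```mlir
// Schematic: one launch op after
//   gen-ptx-config{use-bare-ptr-memref-call-conv=true file-name=model.ptx}
// "kernel_call_convention" below is a placeholder for the attribute name
// returned by byre::getKernelCallConventionAttrName().
gpu.launch_func @unified::@Unknown0
    blocks in (%c1, %c1, %c1) threads in (%c256, %c1, %c1)
    args(%arg0 : memref<1024xf32>, %arg1 : memref<1024xf32>)
    {device_file_name = "model.ptx", kernel_call_convention = "bare_ptr"}
```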
21 changes: 4 additions & 17 deletions compiler/lib/Pipelines/ByreHost.cpp
@@ -27,28 +27,15 @@ using namespace mlir;
using namespace mlir::byre;

namespace {
void createByreHostPipelineImpl(OpPassManager &pm, const std::string &entryFunc,
const std::string &deviceFile,
const std::string &target) {
void createByreHostPipelineImpl(OpPassManager &pm,
const std::string &entryFunc) {
pm.addPass(createCollectFuncPass(
byre::ByreDialect::getEntryPointFunctionAttrName()));

std::string stringAttr = "device_file_name:String:" + deviceFile;
pm.addPass(createFuncTagPass(/*anchorTag=*/"", stringAttr, entryFunc));

// currently use SetOpSpace + SetArgSpace to specify space here
// TODO: later move to GPUOpt after general copy finish
if (!target.empty()) {
// FIXME(chhuang) disable set-op-space here to avoid set discardable attr to
// host side ops, which leads to serialize fail.
// pm.addNestedPass<func::FuncOp>(createSetOpSpacePass(entryFunc, target));
pm.addPass(createSetArgSpacePass(entryFunc, target, true));
}
}
} // namespace

void mlir::createByreHostPipeline(OpPassManager &pm,
const ByreHostPipelineOptions &options) {
invokeOpPassPipelineBuilder(createByreHostPipelineImpl, pm, options.entryFunc,
options.deviceFile, options.target);
invokeOpPassPipelineBuilder(createByreHostPipelineImpl, pm,
options.entryFunc);
}
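Net effect on the host side: `byre-host` no longer tags the entry function with `device_file_name` (and drops the device-file and target options), since the runtime can now read the file name from each `byre.compute`. A schematic before/after of the entry function, with the entry-point attribute name shown as it is commonly printed rather than quoted from this PR:

```mlir
// Before (schematic): the file name was a function-level tag set via func-tag:
//   func.func @main(...) attributes {byre.entry_point, device_file_name = "model.ptx"}
// After (schematic): only the entry-point marker remains on the function;
// each byre.compute op inside carries its own device_file_name (see the
// earlier sketch for the full attribute set).
func.func @main(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) attributes {byre.entry_point} {
  byre.compute @PTXOp(%arg0, %arg1)
      {kernel_name = "Unknown0", device_file_name = "model.ptx"}
      : memref<1024xf32>, memref<1024xf32>
  return
}
```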
12 changes: 7 additions & 5 deletions compiler/lib/Pipelines/GPU/GPUOpt.cpp
@@ -44,7 +44,6 @@ using namespace mlir::bufferization;

namespace {
void createElementwiseGPUOptPipelineImpl(OpPassManager &pm,
const bool &useBarePtrCallConv,
const std::string &target) {
// apply PromotoBufferStack to func's with
// getByteIRElementwiseFusionAttrName
@@ -79,7 +78,6 @@ void createElementwiseGPUOptPipelineImpl(OpPassManager &pm,
pm.addPass(createConvertFuncToGPUPass(/*bs=*/{256, 1, 1}));

addCleanUpExtPassPipeline(pm);
pm.addNestedPass<func::FuncOp>(createGenPTXConfigPass(useBarePtrCallConv));
}

void createReductionGPUOptPipelineImpl(OpPassManager &pm) {
@@ -127,9 +125,12 @@ void createReductionGPUOptPipelineImpl(OpPassManager &pm) {
}

void createGPUOptPipelineImpl(OpPassManager &pm, const bool &useBarePtrCallConv,
const std::string &target) {
createElementwiseGPUOptPipelineImpl(pm, useBarePtrCallConv, target);
const std::string &target,
const std::string &fileName) {
createElementwiseGPUOptPipelineImpl(pm, target);
createReductionGPUOptPipelineImpl(pm);
pm.addNestedPass<func::FuncOp>(
createGenPTXConfigPass(useBarePtrCallConv, fileName));
pm.addPass(createCollectGPUKernelPass("unified", false));
}

@@ -138,5 +139,6 @@ void createGPUOptPipelineImpl(OpPassManager &pm, const bool &useBarePtrCallConv,
void mlir::createGPUOptPipeline(OpPassManager &pm,
const GPUOptPipelineOptions &options) {
invokeOpPassPipelineBuilder(createGPUOptPipelineImpl, pm,
options.useBarePtrCallConv, options.target);
options.useBarePtrCallConv, options.target,
options.fileName);
}
30 changes: 9 additions & 21 deletions compiler/python/byteir/compile.py
100644 → 100755
@@ -126,18 +126,14 @@ def _compile_cuda(
_print_verbose(module, "// IR Dump After SCF Opt:") if verbose else ...
with context:
if useBarePtrCallConv:
PassManager.parse("builtin.module(gpu-opt{use-bare-ptr-memref-call-conv=true})").run(module.operation)
PassManager.parse("builtin.module(gpu-opt{use-bare-ptr-memref-call-conv=true device-file-name="+ output_file_prefix + ".ptx" + "})").run(module.operation)
else:
PassManager.parse("builtin.module(gpu-opt)").run(module.operation)
PassManager.parse("builtin.module(gpu-opt{device-file-name=" + output_file_prefix + ".ptx" + "})").run(module.operation)
_print_verbose(module, "// IR Dump After GPU Opt:") if verbose else ...
with context:
PassManager.parse("builtin.module(func.func(remove-func-body{anchor-attr=__byteir_elementwise_fusion__}))").run(module.operation)
PassManager.parse("builtin.module(inline)").run(module.operation)
PassManager.parse("builtin.module(func.func(lccl-to-byre))").run(module.operation)
if useBarePtrCallConv:
PassManager.parse("builtin.module(func.func(gpu-launch-func-to-byre{use-bare-ptr-memref-call-conv=true}))").run(module.operation)
else:
PassManager.parse("builtin.module(func.func(gpu-launch-func-to-byre))").run(module.operation)
PassManager.parse("builtin.module(func.func(gpu-launch-func-to-byre))").run(module.operation)
PassManager.parse("builtin.module(func.func(set-op-space{" + entry_func_str + " space={}".format(target) + "}))").run(module.operation)
PassManager.parse("builtin.module(set-arg-space{" + entry_func_str + " all-space={}".format(target) + "})").run(module.operation)
_print_verbose(module, "// IR Dump After Set Space Opt:") if verbose else ...
@@ -159,7 +155,7 @@

# create host mlir
with context:
PassManager.parse("builtin.module(byre-host{device-file-name=" + output_file_prefix + ".ptx" + " " + target_str + " " + entry_func_str + "})").run(module.operation)
PassManager.parse("builtin.module(byre-host)").run(module.operation)
_print_verbose(module, "// IR Dump After Byre Host:") if verbose else ...
# write to output host mlir file
assert output_type is OutputType.MLIR, "TBD: emit mlirbc"
Expand Down Expand Up @@ -236,18 +232,14 @@ def _compile_cuda_with_ait(
_print_verbose(processor.module, "// IR Dump After SCF Opt:") if verbose else ...
with context:
if useBarePtrCallConv:
PassManager.parse("builtin.module(gpu-opt{use-bare-ptr-memref-call-conv=true})").run(processor.module.operation)
PassManager.parse("builtin.module(gpu-opt{use-bare-ptr-memref-call-conv=true device-file-name="+ output_file_prefix + ".ptx" + "})").run(module.operation)
else:
PassManager.parse("builtin.module(gpu-opt)").run(processor.module.operation)
PassManager.parse("builtin.module(gpu-opt{device-file-name=" + output_file_prefix + ".ptx" + "})").run(module.operation)
_print_verbose(processor.module, "// IR Dump After GPU Opt:") if verbose else ...
with context:
PassManager.parse("builtin.module(func.func(remove-func-body{anchor-attr=__byteir_elementwise_fusion__}))").run(processor.module.operation)
PassManager.parse("builtin.module(inline)").run(processor.module.operation)
PassManager.parse("builtin.module(func.func(lccl-to-byre))").run(module.operation)
if useBarePtrCallConv:
PassManager.parse("builtin.module(func.func(gpu-launch-func-to-byre{use-bare-ptr-memref-call-conv=true}))").run(processor.module.operation)
else:
PassManager.parse("builtin.module(func.func(gpu-launch-func-to-byre))").run(processor.module.operation)
PassManager.parse("builtin.module(func.func(gpu-launch-func-to-byre))").run(processor.module.operation)
PassManager.parse("builtin.module(func.func(set-op-space{" + entry_func_str + " space={}".format(target) + "}))").run(processor.module.operation)
PassManager.parse("builtin.module(set-arg-space{" + entry_func_str + " all-space={}".format(target) + "})").run(processor.module.operation)
_print_verbose(processor.module, "// IR Dump After Set Space Opt:") if verbose else ...
@@ -268,7 +260,7 @@
byteir.translate_to_ptx(device_module, output_file_dir + "/" + output_file_prefix, gpu_arch)

with context:
PassManager.parse("builtin.module(byre-host{device-file-name=" + output_file_prefix + ".ptx" + " " + target_str + " " + entry_func_str + "})").run(processor.module.operation)
PassManager.parse("builtin.module(byre-host)").run(processor.module.operation)
_print_verbose(processor.module, "// IR Dump After Byre Host:") if verbose else ...
# write to output host mlir
assert output_type is OutputType.MLIR, "TBD: emit mlirbc"
@@ -348,12 +340,8 @@ def _compile_cpu(

# create host module
with context:
PassManager.parse("builtin.module(byre-host{" + target_str + " " + entry_func_str + "})").run(module.operation)
PassManager.parse("builtin.module(byre-host)").run(module.operation)
_print_verbose(module, "// IR Dump After Byre Host:") if verbose else ...
# FIXME: remove `device_file_name` attr as byre v1.0.0 not support this attr.
target_attr_name = "device_file_name"
PassManager.parse("builtin.module(remove-func-tag{" + f"attr-name={target_attr_name} " + f" func-name={entry_func} " + "})").run(module.operation)
_print_verbose(module, "// IR Dump After Remove func tag:") if verbose else ...

output_host_mlir_path = os.path.join(output_file_dir, output_file_prefix + "." + OutputType.MLIR.value)
output_host_mlirbc_path = os.path.join(output_file_dir, output_file_prefix + "." + OutputType.MLIRBC.value)