[core] Modify quake.subveq such that it folds constant arguments.
This addresses technical debt in quake.subveq: constant operands were not
folded into the operation, as they are for extract_ref and many other
operations.

Signed-off-by: Eric Schweitz <[email protected]>
schweitzpgi committed Nov 22, 2024
1 parent c5a6a7c commit 2edf57b
Showing 14 changed files with 321 additions and 170 deletions.
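
As a rough illustration of the intent (hand-written IR, not taken from this commit; the printed form is inferred from the new assembly format in QuakeOps.td below, and %v is a placeholder veq value), constant bounds that previously had to be materialized as arith.constant operands can now be carried by the operation itself, much as quake.extract_ref already folds its constant index:

```
// Before: bounds are SSA constants feeding the op.
%c2 = arith.constant 2 : i64
%c5 = arith.constant 5 : i64
%s = quake.subveq %v, %c2, %c5 : (!quake.veq<8>, i64, i64) -> !quake.veq<4>

// After folding: the constant bounds live on the op itself.
%s = quake.subveq %v, 2, 5 : (!quake.veq<8>) -> !quake.veq<4>
```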
125 changes: 125 additions & 0 deletions include/cudaq/Optimizer/Dialect/Quake/Canonical.h
@@ -0,0 +1,125 @@
/****************************************************************-*- C++ -*-****
* Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. *
* All rights reserved. *
* *
* This source code and the accompanying materials are made available under *
* the terms of the Apache License 2.0 which accompanies this distribution. *
******************************************************************************/

#pragma once

#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeTypes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"

namespace quake::canonical {

inline mlir::Value createCast(mlir::PatternRewriter &rewriter,
                              mlir::Location loc, mlir::Value inVal) {
  auto i64Ty = rewriter.getI64Type();
  assert(inVal.getType() != rewriter.getIndexType() &&
         "use of index type is deprecated");
  return rewriter.create<cudaq::cc::CastOp>(loc, i64Ty, inVal,
                                            cudaq::cc::CastOpMode::Unsigned);
}

class ExtractRefFromSubVeqPattern
    : public mlir::OpRewritePattern<ExtractRefOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Replace a pattern such as:
  // ```
  // %1 = ... : !quake.veq<4>
  // %2 = quake.subveq %1, %c2, %c3 : (!quake.veq<4>, i32, i32) ->
  //      !quake.veq<2>
  // %3 = quake.extract_ref %2[0] : (!quake.veq<2>) -> !quake.ref
  // ```
  // with:
  // ```
  // %1 = ... : !quake.veq<4>
  // %3 = quake.extract_ref %1[2] : (!quake.veq<4>) -> !quake.ref
  // ```
  mlir::LogicalResult
  matchAndRewrite(ExtractRefOp extract,
                  mlir::PatternRewriter &rewriter) const override {
    auto subveq = extract.getVeq().getDefiningOp<SubVeqOp>();
    if (!subveq)
      return mlir::failure();
    // Let the combining of back-to-back subveq ops happen first.
    if (subveq.getVeq().getDefiningOp<SubVeqOp>())
      return mlir::failure();

    mlir::Value offset;
    auto loc = extract.getLoc();
    auto low = [&]() -> mlir::Value {
      if (subveq.hasConstantLowerBound())
        return rewriter.create<mlir::arith::ConstantIntOp>(
            loc, subveq.getConstantLowerBound(), 64);
      return subveq.getLower();
    }();
    if (extract.hasConstantIndex()) {
      mlir::Value cv = rewriter.create<mlir::arith::ConstantIntOp>(
          loc, extract.getConstantIndex(), low.getType());
      offset = rewriter.create<mlir::arith::AddIOp>(loc, cv, low);
    } else {
      auto cast1 = createCast(rewriter, loc, extract.getIndex());
      auto cast2 = createCast(rewriter, loc, low);
      offset = rewriter.create<mlir::arith::AddIOp>(loc, cast1, cast2);
    }
    rewriter.replaceOpWithNewOp<ExtractRefOp>(extract, subveq.getVeq(), offset);
    return mlir::success();
  }
};

// Combine back-to-back quake.subveq operations.
//
// %10 = quake.subveq %4, 1, 6 : (!quake.veq<?>) -> !quake.veq<6>
// %11 = quake.subveq %10, 0, 2 : (!quake.veq<6>) -> !quake.veq<3>
// ───────────────────────────────────────────────────────────────
// %11 = quake.subveq %4, 1, 3 : (!quake.veq<?>) -> !quake.veq<3>
class CombineSubVeqsPattern : public mlir::OpRewritePattern<SubVeqOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(SubVeqOp subveq,
                  mlir::PatternRewriter &rewriter) const override {
    auto prior = subveq.getVeq().getDefiningOp<SubVeqOp>();
    if (!prior)
      return mlir::failure();

    auto loc = subveq.getLoc();

    // Lambda to create a Value for the lower bound of `s`.
    auto lofunc = [&](SubVeqOp s) -> mlir::Value {
      if (s.hasConstantLowerBound())
        return rewriter.create<mlir::arith::ConstantIntOp>(
            loc, s.getConstantLowerBound(), 64);
      return s.getLower();
    };
    auto priorlo = lofunc(prior);
    auto svlo = lofunc(subveq);

    // Lambda for creating the upper bound Value.
    auto svup = [&]() -> mlir::Value {
      if (subveq.hasConstantUpperBound())
        return rewriter.create<mlir::arith::ConstantIntOp>(
            loc, subveq.getConstantUpperBound(), 64);
      return subveq.getUpper();
    }();
    auto cast1 = createCast(rewriter, loc, priorlo);
    auto cast2 = createCast(rewriter, loc, svlo);
    auto cast3 = createCast(rewriter, loc, svup);
    mlir::Value sum1 = rewriter.create<mlir::arith::AddIOp>(loc, cast1, cast2);
    mlir::Value sum2 = rewriter.create<mlir::arith::AddIOp>(loc, cast1, cast3);
    auto veqTy = subveq.getType();
    rewriter.replaceOpWithNewOp<SubVeqOp>(subveq, veqTy, prior.getVeq(), sum1,
                                          sum2);
    return mlir::success();
  }
};

} // namespace quake::canonical
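
For illustration only (IR invented for this note, not taken from the commit's tests; %v is a placeholder of type !quake.veq<10>), the two patterns above compose as follows: back-to-back subveq operations are merged first, and the extract_ref is then rebased onto the original veq:

```
%s1 = quake.subveq %v, 2, 6 : (!quake.veq<10>) -> !quake.veq<5>
%s2 = quake.subveq %s1, 1, 3 : (!quake.veq<5>) -> !quake.veq<3>
%r  = quake.extract_ref %s2[0] : (!quake.veq<3>) -> !quake.ref

// After CombineSubVeqsPattern:
//   %s2 = quake.subveq %v, 3, 5 : (!quake.veq<10>) -> !quake.veq<3>
// After ExtractRefFromSubVeqPattern:
//   %r = quake.extract_ref %v[3] : (!quake.veq<10>) -> !quake.ref
```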
24 changes: 0 additions & 24 deletions include/cudaq/Optimizer/Dialect/Quake/Canonical.td
@@ -64,28 +64,4 @@ def FuseConstantToExtractRefPattern : Pat<
(createExtractRefOp $veq, $index),
[(SizeIsPresentPred $index)]>;

def createSizedSubVeqOp : NativeCodeCall<
"quake::createSizedSubVeqOp($_builder, $_loc, $0, $1, $2, $3)">;

def ArgIsConstantPred : Constraint<CPred<
"dyn_cast_or_null<mlir::arith::ConstantOp>($0.getDefiningOp())">>;

def IsUnknownVec : Constraint<CPred<
"dyn_cast_or_null<mlir::arith::ConstantOp>($0.getDefiningOp())">>;

// %1 = constant 4 : i64
// %2 = constant 10 : i64
// %3 = quake.subveq (%0 : !quake.ref<12>, %1 : i64, %2 : i64) : !quake.ref<?>
// ─────────────────────────────────────────────────────────────────────────────
// %1 = constant 4 : i64
// %2 = constant 10 : i64
// %new3 = quake.subveq (%0 : !quake.ref<12>, %1 : i64, %2 : i64) :
// !quake.ref<7>
// %3 = quake.relax_size %new3 : (!quake.ref<7>) -> !quake.ref<?>
def FuseConstantToSubveqPattern : Pat<
(quake_SubVeqOp:$subveq $v, $lo, $hi),
(createSizedSubVeqOp $subveq, $v, $lo, $hi),
[(UnknownSizePred $subveq), (ArgIsConstantPred $lo),
(ArgIsConstantPred $hi)]>;

#endif
5 changes: 0 additions & 5 deletions include/cudaq/Optimizer/Dialect/Quake/QuakeOps.h
@@ -30,11 +30,6 @@ mlir::Value createConstantAlloca(mlir::PatternRewriter &builder,
mlir::Location loc, mlir::OpResult result,
mlir::ValueRange args);

mlir::Value createSizedSubVeqOp(mlir::PatternRewriter &builder,
mlir::Location loc, mlir::OpResult result,
mlir::Value inVec, mlir::Value lo,
mlir::Value hi);

void getResetEffectsImpl(
mlir::SmallVectorImpl<
mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
34 changes: 30 additions & 4 deletions include/cudaq/Optimizer/Dialect/Quake/QuakeOps.td
@@ -273,7 +273,7 @@ def quake_RelaxSizeOp : QuakeOp<"relax_size", [Pure]> {
let hasCanonicalizer = 1;
}

def quake_SubVeqOp : QuakeOp<"subveq", [Pure]> {
def quake_SubVeqOp : QuakeOp<"subveq", [AttrSizedOperandSegments, Pure]> {
let summary = "Extract a subvector from a veq reference value.";
let description = [{
The `subveq` operation returns a subvector of references, type
@@ -298,16 +298,42 @@ def quake_SubVeqOp : QuakeOp<"subveq", [Pure]> {

let arguments = (ins
VeqType:$veq,
AnySignlessIntegerOrIndex:$low,
AnySignlessIntegerOrIndex:$high
Optional<AnySignlessIntegerOrIndex>:$lower,
Optional<AnySignlessIntegerOrIndex>:$upper,
I64Attr:$rawLower,
I64Attr:$rawUpper
);
let results = (outs VeqType:$qsub);

let assemblyFormat = [{
operands attr-dict `:` functional-type(operands, results)
$veq `,` custom<RawIndex>($lower, $rawLower) `,` custom<RawIndex>($upper,
$rawUpper) `:` functional-type(operands, results) attr-dict
}];

let hasCanonicalizer = 1;
let hasVerifier = 1;

let builders = [
OpBuilder<(ins "mlir::Type":$veqTy, "mlir::Value":$input,
"mlir::Value":$lower, "mlir::Value":$upper), [{
return build($_builder, $_state, veqTy, input, lower, upper,
quake::SubVeqOp::kDynamicIndex, quake::SubVeqOp::kDynamicIndex);
}]>,
OpBuilder<(ins "mlir::Type":$veqTy, "mlir::Value":$input,
"std::int64_t":$lower, "std::int64_t":$upper), [{
return build($_builder, $_state, veqTy, input, {}, {}, lower, upper);
}]>
];

let extraClassDeclaration = [{
static constexpr std::size_t kDynamicIndex =
std::numeric_limits<std::size_t>::max();

bool hasConstantLowerBound() { return getRawLower() != kDynamicIndex; }
bool hasConstantUpperBound() { return getRawUpper() != kDynamicIndex; }
std::size_t getConstantLowerBound() { return getRawLower(); }
std::size_t getConstantUpperBound() { return getRawUpper(); }
}];
}
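
The two builders above correspond to the dynamic and constant forms of the operation. As a sketch (the printed syntax below is an assumption based on the custom<RawIndex> directive; %v, %lo, and %hi are placeholders), the Value-based builder keeps the bounds as operands and sets the raw attributes to kDynamicIndex, while the integer builder drops the bound operands and stores the values in rawLower/rawUpper:

```
// Dynamic bounds: %lo and %hi remain operands.
%d = quake.subveq %v, %lo, %hi : (!quake.veq<?>, i64, i64) -> !quake.veq<?>

// Constant bounds: no bound operands; the values are carried as attributes.
%c = quake.subveq %v, 1, 4 : (!quake.veq<8>) -> !quake.veq<4>
```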

def quake_VeqSizeOp : QuakeOp<"veq_size", [Pure]> {
18 changes: 14 additions & 4 deletions lib/Optimizer/CodeGen/QuakeToCC.cpp
@@ -246,8 +246,19 @@
matchAndRewrite(quake::SubVeqOp subveq, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = subveq.getLoc();
auto diff = rewriter.create<arith::SubIOp>(loc, adaptor.getHigh(),
adaptor.getLow());
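// When a bound operand was folded into the op, rebuild it here as an i64
// constant from the corresponding raw attribute.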
auto up = [&]() -> Value {
if (!adaptor.getUpper())
return rewriter.create<arith::ConstantIntOp>(loc, adaptor.getRawUpper(),
64);
return adaptor.getUpper();
}();
auto lo = [&]() -> Value {
if (!adaptor.getLower())
return rewriter.create<arith::ConstantIntOp>(loc, adaptor.getRawLower(),
64);
return adaptor.getLower();
}();
auto diff = rewriter.create<arith::SubIOp>(loc, up, lo);
auto one = rewriter.create<arith::ConstantIntOp>(loc, 1, 64);
auto length = rewriter.create<arith::AddIOp>(loc, diff, one);
// Compute the pointer to the first element in the subveq and build a new
@@ -260,8 +271,7 @@
loc, ptrptrTy, adaptor.getVeq(), ArrayRef<cudaq::cc::ComputePtrArg>{0});
auto qspanData = rewriter.create<cudaq::cc::LoadOp>(loc, qspanDataPtr);
auto buffer = rewriter.create<cudaq::cc::ComputePtrOp>(
loc, ptrI64Ty, qspanData,
ArrayRef<cudaq::cc::ComputePtrArg>{adaptor.getLow()});
loc, ptrI64Ty, qspanData, ArrayRef<cudaq::cc::ComputePtrArg>{lo});
auto qspanTy = cudaq::opt::getCudaqQubitSpanType(rewriter.getContext());
Value newspan = rewriter.create<cudaq::cc::AllocaOp>(loc, qspanTy);
rewriter.create<func::CallOp>(loc, std::nullopt,
14 changes: 12 additions & 2 deletions lib/Optimizer/CodeGen/QuakeToLLVM.cpp
@@ -393,8 +393,18 @@ class SubveqOpRewrite : public ConvertOpToLLVMPattern<quake::SubVeqOp> {
rtSubveqFuncName, arrayTy, {arrayTy, i32Ty, i64Ty, i64Ty, i64Ty},
parentModule);

Value lowArg = adaptor.getOperands()[1];
Value highArg = adaptor.getOperands()[2];
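// Rebuild folded bounds as i64 constants from the raw attributes; otherwise
// use the SSA operands directly.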
auto lowArg = [&]() -> Value {
if (!adaptor.getLower())
return rewriter.create<arith::ConstantIntOp>(loc, adaptor.getRawLower(),
64);
return adaptor.getLower();
}();
auto highArg = [&]() -> Value {
if (!adaptor.getUpper())
return rewriter.create<arith::ConstantIntOp>(loc, adaptor.getRawUpper(),
64);
return adaptor.getUpper();
}();
auto extend = [&](Value &v) -> Value {
if (v.getType().isa<IntegerType>() &&
v.getType().cast<IntegerType>().getWidth() < 64)