Skip to content

Commit

Permalink
Add some missing type verification. (#2400)
Browse files Browse the repository at this point in the history
* Add some missing type verification.

Signed-off-by: Eric Schweitz <[email protected]>

* Remove unneeded change.

Signed-off-by: Eric Schweitz <[email protected]>

---------

Signed-off-by: Eric Schweitz <[email protected]>
  • Loading branch information
schweitzpgi authored Nov 21, 2024
1 parent 78f6014 commit 4f2104b
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 8 deletions.
3 changes: 2 additions & 1 deletion include/cudaq/Optimizer/Dialect/Quake/QuakeTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ def VeqType : QuakeType<"Veq", "veq"> {

let parameters = (ins "std::size_t":$size);

let genVerifyDecl = 1;
let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
Expand All @@ -177,8 +176,10 @@ def StruqType : QuakeType<"Struq", "struq"> {

let parameters = (ins
"mlir::StringAttr":$name,
// members must be NonStruqRefType.
ArrayRefParameter<"mlir::Type">:$members
);

let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
Expand Down
10 changes: 10 additions & 0 deletions lib/Optimizer/Dialect/Quake/QuakeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,16 @@ LogicalResult quake::AllocaOp::verify() {
return emitOpError("size operand required");
}
}
} else {
// Size has no semantics for any type other than quake.veq.
if (getSize())
return emitOpError("cannot specify size with this quantum type");

if (auto stqTy = dyn_cast<StruqType>(getResult().getType()))
for (auto t : stqTy.getMembers())
if (auto vt = dyn_cast<VeqType>(t))
if (!vt.hasSpecifiedSize())
return emitOpError("struq type must have specified size");
}

// Check the uses. If any use is a InitializeStateOp, then it must be the only
Expand Down
10 changes: 3 additions & 7 deletions lib/Optimizer/Dialect/Quake/QuakeTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,6 @@ Type quake::VeqType::parse(AsmParser &parser) {
return get(parser.getContext(), size);
}

LogicalResult
quake::VeqType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
std::size_t size) {
// FIXME: Do we want to check the size of the veq for some bound?
return success();
}

//===----------------------------------------------------------------------===//

Type quake::StruqType::parse(AsmParser &parser) {
Expand All @@ -77,6 +70,9 @@ Type quake::StruqType::parse(AsmParser &parser) {
break;
if (!succeeded(*optTy))
return {};
if (!llvm::isa<quake::RefType, quake::VeqType>(member))
parser.emitError(parser.getCurrentLocation(),
"invalid struq member type");
members.push_back(member);
} while (succeeded(parser.parseOptionalComma()));
if (parser.parseGreater())
Expand Down
17 changes: 17 additions & 0 deletions test/Quake/invalid.qke
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ func.func @test_struq() {
%0 = quake.alloca !quake.veq<4>
%1 = arith.constant 1 : i32
%2 = arith.constant 2.0 : f32
// expected-error@+2 {{invalid struq member type}}
// expected-error@+1 {{must be non-struct quantum reference type}}
%6 = quake.make_struq %0, %1, %2 : (!quake.veq<4>, i32, f32) -> !quake.struq<!quake.veq<?>, i32, f32>
return
Expand Down Expand Up @@ -54,3 +55,19 @@ func.func @test_struq(%arg : !quake.struq<!quake.veq<1>, !quake.veq<2>, !quake.v
%6 = quake.get_member %arg[3] : (!quake.struq<!quake.veq<1>, !quake.veq<2>, !quake.veq<3>>) -> !quake.veq<1>
return
}

// -----

func.func @test_struq() {
// expected-error@+1 {{struq type must have specified size}}
%0 = quake.alloca !quake.struq<!quake.veq<1>, !quake.veq<?>>
return
}

// -----

func.func @test_struq() {
// expected-error@+1 {{invalid struq member type}}
%0 = quake.alloca !quake.struq<!quake.struq<!quake.veq<3>, !quake.ref>, !quake.veq<7>>
return
}
2 changes: 2 additions & 0 deletions test/Quake/roundtrip-ops.qke
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,7 @@ func.func @quantum_product_type() {
%22 = quake.make_struq %20, %21 : (!quake.veq<7>, !quake.veq<8>) -> !quake.struq<!quake.veq<7>, !quake.veq<8>>
%23 = quake.get_member %22[0] : (!quake.struq<!quake.veq<7>, !quake.veq<8>>) -> !quake.veq<7>
%24 = quake.get_member %22[1] : (!quake.struq<!quake.veq<7>, !quake.veq<8>>) -> !quake.veq<8>
%25 = quake.alloca !quake.struq<!quake.veq<8>, !quake.veq<5>>
return
}

Expand All @@ -837,5 +838,6 @@ func.func @quantum_product_type() {
// CHECK: %[[VAL_12:.*]] = quake.make_struq %[[VAL_10]], %[[VAL_11]] : (!quake.veq<7>, !quake.veq<8>) -> !quake.struq<!quake.veq<7>, !quake.veq<8>>
// CHECK: %[[VAL_13:.*]] = quake.get_member %[[VAL_12]][0] : (!quake.struq<!quake.veq<7>, !quake.veq<8>>) -> !quake.veq<7>
// CHECK: %[[VAL_14:.*]] = quake.get_member %[[VAL_12]][1] : (!quake.struq<!quake.veq<7>, !quake.veq<8>>) -> !quake.veq<8>
// CHECK: %[[VAL_15:.*]] = quake.alloca !quake.struq<!quake.veq<8>, !quake.veq<5>>
// CHECK: return
// CHECK: }

0 comments on commit 4f2104b

Please sign in to comment.