diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/MANP.cpp index faaaf51f8d..c4dea87029 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/MANP.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/MANP.cpp @@ -624,12 +624,13 @@ getSqMANP(mlir::tensor::ExpandShapeOp op, static std::optional getSqMANP(mlir::concretelang::FHELinalg::SumOp op, llvm::ArrayRef operandMANPs) { + llvm::APInt result = llvm::APInt(1, 1, false); auto inputType = op.getOperand().getType().dyn_cast(); uint64_t numberOfElementsInTheInput = inputType.getNumElements(); if (numberOfElementsInTheInput == 0) { - return llvm::APInt{1, 1, false}; + return result; } uint64_t numberOfElementsAddedTogetherInEachOutputCell = 1; @@ -659,9 +660,8 @@ getSqMANP(mlir::concretelang::FHELinalg::SumOp op, "Missing squared Minimal Arithmetic Noise Padding for encrypted " "operands"); - llvm::APInt operandMANP = operandMANPs[0]->getValue().getMANP().value(); - - return APIntWidthExtendUMul(noiseMultiplier, operandMANP); + result = operandMANPs[0]->getValue().getMANP().value(); + return APIntWidthExtendUMul(noiseMultiplier, result); } static std::optional @@ -755,10 +755,11 @@ class MANPAnalysis debug(debug) {} void setToEntryState(MANPLattice *lattice) override { + auto baseMANP = llvm::APInt(1, 1); if (isEncryptedFunctionParameter(lattice->getPoint())) { // Set minimal MANP for encrypted function arguments - propagateIfChanged(lattice, lattice->join(MANPLatticeValue{ - std::optional{llvm::APInt(1, 1)}})); + propagateIfChanged( + lattice, lattice->join(MANPLatticeValue{std::optional{baseMANP}})); } // In case of block arguments used in the block of a linalg.genric // operation: map the MANP values of the operands into the block arguments @@ -781,8 +782,8 @@ class MANPAnalysis operandRange = genericOp.getOutputs(); } auto v = operandRange[argIndex]; - auto manp = this->getLatticeElement(v)->getValue().getMANP().value_or( - llvm::APInt(1, 1)); + auto manp = + this->getLatticeElement(v)->getValue().getMANP().value_or(baseMANP); propagateIfChanged(lattice, lattice->join(MANPLatticeValue{manp})); } } else {