diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 9f110f0b97e5dd..44cba05d717afd 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -674,6 +674,7 @@ const std::vector<std::string> kPirMkldnnBf16Passes{ "cpu_bfloat16_placement_pass", "cpu_bfloat16_pass", "cpu_bfloat16_type_placement_pass", + "cpu_special_ops_bf16_pass", "cpu_bf16_quantize_squash_pass", }; diff --git a/paddle/fluid/pir/transforms/onednn/cpu_special_ops_bf16_pass.cc b/paddle/fluid/pir/transforms/onednn/cpu_special_ops_bf16_pass.cc new file mode 100644 index 00000000000000..eb586c40c16773 --- /dev/null +++ b/paddle/fluid/pir/transforms/onednn/cpu_special_ops_bf16_pass.cc @@ -0,0 +1,166 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/pir/transforms/onednn/cpu_special_ops_bf16_pass.h" + +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" +#include "paddle/fluid/pir/utils/general_functions.h" + +#include "paddle/pir/include/core/builtin_op.h" +#include "paddle/pir/include/pass/pass.h" +#include "paddle/pir/include/pass/pass_registry.h" + +namespace { + +template <typename IrType1, typename IrType2> +static pir::Type create_type(pir::Type type, + pir::Type out_dtype, + pir::IrContext *ctx) { + auto input_type = type.dyn_cast<IrType1>(); + return IrType2::get(ctx, + out_dtype, + input_type.dims(), + input_type.data_layout(), + input_type.lod(), + input_type.offset()); +} + +// For ops like conv and concat, their input is sometimes packed as VectorType, +// hence current quantization doesn't work. Here we deal with them specifically. +class ConcatBf16QuantizePattern + : public pir::OpRewritePattern<paddle::onednn::dialect::ConcatOp> { + public: + using pir::OpRewritePattern< + paddle::onednn::dialect::ConcatOp>::OpRewritePattern; + bool MatchAndRewrite( + paddle::onednn::dialect::ConcatOp op, + pir::PatternRewriter &rewriter) const override { // NOLINT + // The input should come from combine. 
+ pir::CombineOp pre_op = + pir::GetDefiningOpForInput(op, 0)->dyn_cast<pir::CombineOp>(); + if (!pre_op) return false; + if (!pre_op.out().HasOneUse()) return false; + + auto op_attributes = op->attributes(); + auto onednn_data_type = op_attributes.at("mkldnn_data_type") + .dyn_cast<pir::StrAttribute>() + .AsString(); + if (onednn_data_type == "bfloat16") return false; + op_attributes["mkldnn_data_type"] = rewriter.str_attr("bfloat16"); + + auto combine_inputs = pre_op.inputs(); + + for (size_t idx = 0; idx < combine_inputs.size(); idx++) { + auto type = pre_op->operand_type(idx); + // Currently we only process case where elements are all DenseTensor(s) + if (!type.isa<paddle::dialect::DenseTensorType>()) return false; + // All Tensors should be fp32 + auto dtype = pir::GetDataTypeFromValue(pre_op->operand_source(idx)); + if (!dtype.isa<pir::Float32Type>()) return false; + } + + pir::IrContext *ctx = rewriter.ir_context(); + + std::unordered_map<std::string, pir::Attribute> q_attributes; + q_attributes["scale"] = rewriter.float_attr(1.0f); + q_attributes["shift"] = rewriter.float_attr(0.0f); + q_attributes["is_negative_input"] = rewriter.bool_attr(false); + q_attributes["output_format"] = rewriter.str_attr("NCHW"); + q_attributes["bfloat16"] = rewriter.bool_attr(true); + + // Insert quantize before combine + std::vector<pir::Value> new_combine_inputs(combine_inputs.size()); + for (size_t idx = 0; idx < combine_inputs.size(); idx++) { + paddle::onednn::dialect::QuantizeOp quant_op = + rewriter.Build<paddle::onednn::dialect::QuantizeOp>( + combine_inputs[idx], q_attributes); + auto type = quant_op->result_type(0); + pir::Type new_type = + create_type<paddle::dialect::DenseTensorType, + paddle::dialect::DenseTensorType>( + type, pir::BFloat16Type::get(ctx), ctx); + quant_op->result(0).set_type(new_type); + new_combine_inputs[idx] = quant_op.output(); + } + // Create new combine + pir::CombineOp new_combine = + rewriter.Build<pir::CombineOp>(new_combine_inputs); + rewriter.ReplaceAllUsesWith(pre_op.out(), new_combine.out()); + rewriter.EraseOp(pre_op); + + // Create new concat + auto concat_info = + ctx->GetRegisteredOpInfo(paddle::onednn::dialect::ConcatOp::name()); + if (!concat_info) return false; + + 
std::vector<pir::Type> op_item_inner_output_types; + auto type = op->result_type(0); + pir::Type new_type = + create_type<paddle::dialect::DenseTensorType, + paddle::dialect::DenseTensorType>( + type, pir::BFloat16Type::get(ctx), ctx); + op_item_inner_output_types.push_back(new_type); + + paddle::onednn::dialect::ConcatOp new_concat = + rewriter + .Build({new_combine.out(), op.axis()}, + op_attributes, + op_item_inner_output_types, + concat_info) + ->dyn_cast<paddle::onednn::dialect::ConcatOp>(); + + // Insert dequant op under concat + std::unordered_map<std::string, pir::Attribute> dq_attributes; + dq_attributes["scale"] = rewriter.float_attr(1.0f); + dq_attributes["shift"] = rewriter.float_attr(0.0f); + paddle::onednn::dialect::DequantizeOp dequant_op = + rewriter.Build<paddle::onednn::dialect::DequantizeOp>(new_concat.out(), + dq_attributes); + + rewriter.ReplaceAllUsesWith(op.out(), dequant_op.output()); + rewriter.EraseOp(op); + return true; + } +}; + +class CPUSpecialOpsBf16Pass : public pir::PatternRewritePass { + public: + CPUSpecialOpsBf16Pass() + : pir::PatternRewritePass("cpu_special_ops_bf16_pass", 2) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + uint32_t benefit = 100; + + auto concat_bf16_quant_pattern = + std::make_unique<ConcatBf16QuantizePattern>( + context, benefit--, std::vector<std::string>{}); + ps.Add(std::move(concat_bf16_quant_pattern)); + + return ps; + } +}; + +} // namespace + +namespace pir { + +std::unique_ptr<Pass> CreateCPUSpecialOpsBf16Pass() { + return std::make_unique<CPUSpecialOpsBf16Pass>(); +} + +} // namespace pir + +REGISTER_IR_PASS(cpu_special_ops_bf16_pass, CPUSpecialOpsBf16Pass); diff --git a/paddle/fluid/pir/transforms/onednn/cpu_special_ops_bf16_pass.h b/paddle/fluid/pir/transforms/onednn/cpu_special_ops_bf16_pass.h new file mode 100644 index 00000000000000..9dcf771121c24b --- /dev/null +++ b/paddle/fluid/pir/transforms/onednn/cpu_special_ops_bf16_pass.h @@ -0,0 +1,26 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include <memory> +#include "paddle/pir/include/core/dll_decl.h" + +namespace pir { + +class Pass; + +IR_API std::unique_ptr<Pass> CreateCPUSpecialOpsBf16Pass(); + +} // namespace pir diff --git a/paddle/fluid/pir/transforms/passes.h b/paddle/fluid/pir/transforms/passes.h index 40c73e8c5cb257..c3a6629cac9e7f 100644 --- a/paddle/fluid/pir/transforms/passes.h +++ b/paddle/fluid/pir/transforms/passes.h @@ -91,6 +91,7 @@ USE_PIR_PASS(cpu_bfloat16_placement_pass); USE_PIR_PASS(cpu_bfloat16_type_placement_pass); USE_PIR_PASS(cpu_bfloat16_pass); USE_PIR_PASS(cpu_bf16_quantize_squash_pass); +USE_PIR_PASS(cpu_special_ops_bf16_pass); #endif #ifdef PADDLE_WITH_XPU diff --git a/test/ir/pir/fused_pass/onednn/test_cpu_bfloat16_pir_pass.py b/test/ir/pir/fused_pass/onednn/test_cpu_bfloat16_pir_pass.py index 7caef918729b2f..734611b5fe52ff 100644 --- a/test/ir/pir/fused_pass/onednn/test_cpu_bfloat16_pir_pass.py +++ b/test/ir/pir/fused_pass/onednn/test_cpu_bfloat16_pir_pass.py @@ -1121,5 +1121,52 @@ def test_check_output(self): self.check_pass_correct() + +class TestConcatBfloatQuantizePass(PassTest): + def is_program_valid(self, program=None): + return True + + def build_ir_program(self): + with paddle.pir_utils.IrGuard(): + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.pir.core.program_guard(main_prog, start_prog): + x = paddle.static.data( + name='x', shape=[5, 5, 5, 5], dtype='float32' + ) + y = paddle.static.data( + name='y', shape=[5, 5, 5, 5], dtype='float32' + ) + z = paddle.static.data( + name='z', shape=[5, 5, 5, 
5], dtype='float32' + ) + out = paddle.concat((x, y, z)) + out = paddle.assign(out) + self.pass_attr_list = [ + {'onednn_placement_pass': {}}, + {'cpu_special_ops_bf16_pass': {}}, + ] + self.feeds = { + "x": np.random.random((5, 5, 5, 5)).astype("float32"), + "y": np.random.random((5, 5, 5, 5)).astype("float32"), + "z": np.random.random((5, 5, 5, 5)).astype("float32"), + } + self.fetch_list = [out] + self.valid_op_map = { + "onednn_op.concat": 1, + "onednn_op.dequantize": 1, + "onednn_op.quantize": 3, + } + return [main_prog, start_prog] + + def sample_program(self): + yield self.build_ir_program(), False + + def setUp(self): + self.places.append(paddle.CPUPlace()) + + def test_check_output(self): + self.check_pass_correct(rtol=1e-02, atol=1e-02) + + if __name__ == "__main__": unittest.main()