diff --git a/scripts/format-code b/scripts/format-code index 476406f2a..d1e443344 100644 --- a/scripts/format-code +++ b/scripts/format-code @@ -169,15 +169,15 @@ check_clang-format() cf=$(command -v clang-format 2> /dev/null) if [[ ! -x ${cf} ]]; then echo "clang-format is not installed" - exit 1 + exit 0 # do not fail, just warn fi cf="clang-format" fi - local required_cfver='17.0.3' - # shellcheck disable=SC2155 - local cfver=$(${cf} --version | grep -o -E '[0-9]+\.[0-9]+\.[0-9]+' | head -1) - check_version "${required_cfver}" "${cfver}" +# local required_cfver='17.0.3' +# # shellcheck disable=SC2155 +# local cfver=$(${cf} --version | grep -o -E '[0-9]+\.[0-9]+\.[0-9]+' | head -1) +# check_version "${required_cfver}" "${cfver}" } check_clang-format diff --git a/src/asm_cfg.cpp b/src/asm_cfg.cpp index 37c50febd..d1fb4a7b8 100644 --- a/src/asm_cfg.cpp +++ b/src/asm_cfg.cpp @@ -52,7 +52,7 @@ static void add_cfg_nodes(cfg_t& cfg, const label_t& caller_label, const label_t basic_block_t& caller_node = cfg.get_node(caller_label); const std::string stack_frame_prefix = to_string(caller_label); for (auto& inst : caller_node) { - if (const auto pcall = std::get_if(&inst)) { + if (const auto pcall = std::get_if(&inst.cmd)) { pcall->stack_frame_prefix = stack_frame_prefix; } } @@ -73,9 +73,9 @@ static void add_cfg_nodes(cfg_t& cfg, const label_t& caller_label, const label_t const label_t label{macro_label.from, macro_label.to, stack_frame_prefix}; auto& bb = cfg.insert(label); for (auto inst : cfg.get_node(macro_label)) { - if (const auto pexit = std::get_if(&inst)) { + if (const auto pexit = std::get_if(&inst.cmd)) { pexit->stack_frame_prefix = label.stack_frame_prefix; - } else if (const auto pcall = std::get_if(&inst)) { + } else if (const auto pcall = std::get_if(&inst.cmd)) { pcall->stack_frame_prefix = label.stack_frame_prefix; } bb.insert(inst); @@ -123,7 +123,7 @@ static void add_cfg_nodes(cfg_t& cfg, const label_t& caller_label, const label_t for (const auto& macro_label : seen_labels) { for (const label_t label(macro_label.from, macro_label.to, caller_label_str); const auto& inst : cfg.get_node(label)) { - if (const auto pins = std::get_if(&inst)) { + if (const auto pins = std::get_if(&inst.cmd)) { if (stack_frame_depth >= MAX_CALL_STACK_FRAMES) { throw std::runtime_error{"too many call stack frames"}; } @@ -153,7 +153,7 @@ static cfg_t instruction_seq_to_cfg(const InstructionSeq& insts, const bool must cfg.get_node(cfg.entry_label()) >> bb; } - bb.insert(inst); + bb.insert({.cmd = inst}); if (falling_from) { cfg.get_node(*falling_from) >> bb; falling_from = {}; @@ -236,9 +236,7 @@ static cfg_t to_nondet(const cfg_t& cfg) { basic_block_t& newbb = res.insert(this_label); for (const auto& ins : bb) { - if (!std::holds_alternative(ins)) { - newbb.insert(ins); - } + newbb.insert(ins); } for (const label_t& prev_label : bb.prev_blocks_set()) { @@ -250,7 +248,7 @@ static cfg_t to_nondet(const cfg_t& cfg) { auto nextlist = bb.next_blocks_set(); if (nextlist.size() == 2) { label_t mid_label = this_label; - auto jmp = std::get(*bb.rbegin()); + auto jmp = std::get(bb.rbegin()->cmd); nextlist.erase(jmp.target); label_t fallthrough = *nextlist.begin(); @@ -262,7 +260,7 @@ static cfg_t to_nondet(const cfg_t& cfg) { for (const auto& [next_label, cond1] : jumps) { label_t jump_label = label_t::make_jump(mid_label, next_label); basic_block_t& jump_bb = res.insert(jump_label); - jump_bb.insert(Assume{cond1}); + jump_bb.insert({.cmd = Assume{cond1}}); newbb >> jump_bb; jump_bb >> res.insert(next_label); } @@ -275,7 +273,7 @@ static cfg_t to_nondet(const cfg_t& cfg) { return res; } -/// Get the type of given instruction. +/// Get the type of given Instruction. /// Most of these type names are also statistics header labels. static std::string instype(Instruction ins) { if (const auto pcall = std::get_if(&ins)) { @@ -333,21 +331,21 @@ std::map collect_stats(const cfg_t& cfg) { res["basic_blocks"]++; basic_block_t const& bb = cfg.get_node(this_label); - for (Instruction ins : bb) { - if (const auto pins = std::get_if(&ins)) { + for (const auto& ins : bb) { + if (const auto pins = std::get_if(&ins.cmd)) { if (pins->mapfd == -1) { res["map_in_map"] = 1; } } - if (const auto pins = std::get_if(&ins)) { + if (const auto pins = std::get_if(&ins.cmd)) { if (pins->reallocate_packet) { res["reallocate"] = 1; } } - if (const auto pins = std::get_if(&ins)) { + if (const auto pins = std::get_if(&ins.cmd)) { res[pins->is64 ? "arith64" : "arith32"]++; } - res[instype(ins)]++; + res[instype(ins.cmd)]++; } if (unique(bb.prev_blocks()).size() > 1) { res["joins"]++; @@ -369,7 +367,7 @@ cfg_t prepare_cfg(const InstructionSeq& prog, const program_info& info, const pr if (options.check_for_termination) { const wto_t wto(det_cfg); wto.for_each_loop_head( - [&](const label_t& label) { det_cfg.get_node(label).insert_front(IncrementLoopCounter{label}); }); + [&](const label_t& label) { det_cfg.get_node(label).insert_front({.cmd = IncrementLoopCounter{label}}); }); } // Annotate the CFG by adding in assertions before every memory instruction. diff --git a/src/asm_marshal.cpp b/src/asm_marshal.cpp index c46207b4f..9ce4d2f09 100644 --- a/src/asm_marshal.cpp +++ b/src/asm_marshal.cpp @@ -205,8 +205,6 @@ struct MarshalVisitor { vector operator()(Assume const&) const { throw std::invalid_argument("Cannot marshal assumptions"); } - vector operator()(Assert const&) const { throw std::invalid_argument("Cannot marshal assertions"); } - vector operator()(Jmp const& b) const { if (b.cond) { ebpf_inst res{ diff --git a/src/asm_ostream.cpp b/src/asm_ostream.cpp index 4dd1fa641..f76e023e1 100644 --- a/src/asm_ostream.cpp +++ b/src/asm_ostream.cpp @@ -105,87 +105,87 @@ std::ostream& operator<<(std::ostream& os, const Condition::Op op) { static string size(const int w) { return string("u") + std::to_string(w * 8); } -std::ostream& operator<<(std::ostream& os, ValidStore const& a) { - return os << a.mem << ".type != stack -> " << TypeConstraint{a.val, TypeGroup::number}; -} +// ReSharper disable CppMemberFunctionMayBeConst +struct AssertionPrinterVisitor { + std::ostream& _os; + + void operator()(ValidStore const& a) { + _os << a.mem << ".type != stack -> " << TypeConstraint{a.val, TypeGroup::number}; + } + + void operator()(ValidAccess const& a) { + if (a.or_null) { + _os << "(" << TypeConstraint{a.reg, TypeGroup::number} << " and " << a.reg << ".value == 0) or "; + } + _os << "valid_access(" << a.reg << ".offset"; + if (a.offset > 0) { + _os << "+" << a.offset; + } else if (a.offset < 0) { + _os << a.offset; + } -std::ostream& operator<<(std::ostream& os, ValidAccess const& a) { - if (a.or_null) { - os << "(" << TypeConstraint{a.reg, TypeGroup::number} << " and " << a.reg << ".value == 0) or "; - } - os << "valid_access(" << a.reg << ".offset"; - if (a.offset > 0) { - os << "+" << a.offset; - } else if (a.offset < 0) { - os << a.offset; - } - - if (a.width == Value{Imm{0}}) { - // a.width == 0, meaning we only care it's an in-bound pointer, - // so it can be compared with another pointer to the same region. - os << ") for comparison/subtraction"; - } else { - os << ", width=" << a.width << ") for "; - if (a.access_type == AccessType::read) { - os << "read"; + if (a.width == Value{Imm{0}}) { + // a.width == 0, meaning we only care it's an in-bound pointer, + // so it can be compared with another pointer to the same region. + _os << ") for comparison/subtraction"; } else { - os << "write"; + _os << ", width=" << a.width << ") for "; + if (a.access_type == AccessType::read) { + _os << "read"; + } else { + _os << "write"; + } } } - return os; -} - -std::ostream& operator<<(std::ostream& os, const BoundedLoopCount& a) { - return os << crab::variable_t::loop_counter(to_string(a.name)) << " < " << a.limit; -} -static crab::variable_t typereg(const Reg& r) { return crab::variable_t::reg(crab::data_kind_t::types, r.v); } + void operator()(const BoundedLoopCount& a) { + _os << crab::variable_t::loop_counter(to_string(a.name)) << " < " << a.limit; + } -std::ostream& operator<<(std::ostream& os, ValidSize const& a) { - const auto op = a.can_be_zero ? " >= " : " > "; - return os << a.reg << ".value" << op << 0; -} + static crab::variable_t typereg(const Reg& r) { return crab::variable_t::reg(crab::data_kind_t::types, r.v); } -std::ostream& operator<<(std::ostream& os, ValidCall const& a) { - const EbpfHelperPrototype proto = global_program_info->platform->get_helper_prototype(a.func); - return os << "valid call(" << proto.name << ")"; -} + void operator()(ValidSize const& a) { + const auto op = a.can_be_zero ? " >= " : " > "; + _os << a.reg << ".value" << op << 0; + } -std::ostream& operator<<(std::ostream& os, ValidMapKeyValue const& a) { - return os << "within stack(" << a.access_reg << ":" << (a.key ? "key_size" : "value_size") << "(" << a.map_fd_reg - << "))"; -} + void operator()(ValidCall const& a) { + const EbpfHelperPrototype proto = global_program_info->platform->get_helper_prototype(a.func); + _os << "valid call(" << proto.name << ")"; + } -std::ostream& operator<<(std::ostream& os, ZeroCtxOffset const& a) { - return os << crab::variable_t::reg(crab::data_kind_t::ctx_offsets, a.reg.v) << " == 0"; -} + void operator()(ValidMapKeyValue const& a) { + _os << "within stack(" << a.access_reg << ":" << (a.key ? "key_size" : "value_size") << "(" << a.map_fd_reg + << "))"; + } -std::ostream& operator<<(std::ostream& os, Comparable const& a) { - if (a.or_r2_is_number) { - os << TypeConstraint{a.r2, TypeGroup::number} << " or "; + void operator()(ZeroCtxOffset const& a) { + _os << crab::variable_t::reg(crab::data_kind_t::ctx_offsets, a.reg.v) << " == 0"; } - return os << typereg(a.r1) << " == " << typereg(a.r2) << " in " << TypeGroup::singleton_ptr; -} -std::ostream& operator<<(std::ostream& os, Addable const& a) { - return os << TypeConstraint{a.ptr, TypeGroup::pointer} << " -> " << TypeConstraint{a.num, TypeGroup::number}; -} + void operator()(Comparable const& a) { + if (a.or_r2_is_number) { + _os << TypeConstraint{a.r2, TypeGroup::number} << " or "; + } + _os << typereg(a.r1) << " == " << typereg(a.r2) << " in " << TypeGroup::singleton_ptr; + } -std::ostream& operator<<(std::ostream& os, ValidDivisor const& a) { return os << a.reg << " != 0"; } + void operator()(Addable const& a) { + _os << TypeConstraint{a.ptr, TypeGroup::pointer} << " -> " << TypeConstraint{a.num, TypeGroup::number}; + } -std::ostream& operator<<(std::ostream& os, TypeConstraint const& tc) { - const string cmp_op = is_singleton_type(tc.types) ? "==" : "in"; - return os << typereg(tc.reg) << " " << cmp_op << " " << tc.types; -} + void operator()(ValidDivisor const& a) { _os << a.reg << " != 0"; } -std::ostream& operator<<(std::ostream& os, FuncConstraint const& fc) { return os << typereg(fc.reg) << " is helper"; } + void operator()(TypeConstraint const& tc) { + const string cmp_op = is_singleton_type(tc.types) ? "==" : "in"; + _os << typereg(tc.reg) << " " << cmp_op << " " << tc.types; + } -std::ostream& operator<<(std::ostream& os, AssertionConstraint const& a) { - return std::visit([&](const auto& a) -> std::ostream& { return os << a; }, a); -} + void operator()(FuncConstraint const& fc) { _os << typereg(fc.reg) << " is helper"; } +}; // ReSharper disable CppMemberFunctionMayBeConst -struct InstructionPrinterVisitor { +struct CommandPrinterVisitor { std::ostream& os_; void visit(const auto& item) { std::visit(*this, item); } @@ -259,7 +259,7 @@ struct InstructionPrinterVisitor { void operator()(Exit const& b) { os_ << "exit"; } void operator()(Jmp const& b) { - // A "standalone" jump instruction. + // A "standalone" jump Instruction. // Print the label without offset calculations. if (b.cond) { os_ << "if "; @@ -351,8 +351,6 @@ struct InstructionPrinterVisitor { print(b.cond); } - void operator()(Assert const& a) { os_ << "assert " << a.cst; } - void operator()(IncrementLoopCounter const& a) { os_ << crab::variable_t::loop_counter(to_string(a.name)) << "++"; } }; // ReSharper restore CppMemberFunctionMayBeConst @@ -364,7 +362,7 @@ string to_string(label_t const& label) { } std::ostream& operator<<(std::ostream& os, Instruction const& ins) { - std::visit(InstructionPrinterVisitor{os}, ins); + std::visit(CommandPrinterVisitor{os}, ins); return os; } @@ -374,7 +372,12 @@ string to_string(Instruction const& ins) { return str.str(); } -string to_string(AssertionConstraint const& constraint) { +std::ostream& operator<<(std::ostream& os, const Assertion& a) { + std::visit(AssertionPrinterVisitor{os}, a); + return os; +} + +string to_string(Assertion const& constraint) { std::stringstream str; str << constraint; return str.str(); @@ -407,10 +410,10 @@ void print(const InstructionSeq& insts, std::ostream& out, const std::optional ["; for (const label_t& label : bb.next_blocks_set()) { diff --git a/src/asm_ostream.hpp b/src/asm_ostream.hpp index ce0442ebd..83d210cd9 100644 --- a/src/asm_ostream.hpp +++ b/src/asm_ostream.hpp @@ -12,7 +12,7 @@ #include "crab_utils/num_safety.hpp" // We use a 16-bit offset whenever it fits in 16 bits. -inline std::function label_to_offset16(pc_t pc) { +inline std::function label_to_offset16(const pc_t pc) { return [=](const label_t& label) { const int64_t offset = label.from - gsl::narrow(pc) - 1; const bool is16 = @@ -22,7 +22,7 @@ inline std::function label_to_offset16(pc_t pc) { } // We use the JA32 opcode with the offset in 'imm' when the offset -// of an unconditional jump doesn't fit in a int16_t. +// of an unconditional jump doesn't fit in an int16_t. inline std::function label_to_offset32(const pc_t pc) { return [=](const label_t& label) { const int64_t offset = label.from - gsl::narrow(pc) - 1; @@ -54,19 +54,5 @@ inline std::ostream& operator<<(std::ostream& os, Value const& a) { return os << std::get(a); } -inline std::ostream& operator<<(std::ostream& os, Undefined const& a) { return os << Instruction{a}; } -inline std::ostream& operator<<(std::ostream& os, LoadMapFd const& a) { return os << Instruction{a}; } -inline std::ostream& operator<<(std::ostream& os, Bin const& a) { return os << Instruction{a}; } -inline std::ostream& operator<<(std::ostream& os, Un const& a) { return os << Instruction{a}; } -inline std::ostream& operator<<(std::ostream& os, Call const& a) { return os << Instruction{a}; } -inline std::ostream& operator<<(std::ostream& os, Callx const& a) { return os << Instruction{a}; } -inline std::ostream& operator<<(std::ostream& os, Exit const& a) { return os << Instruction{a}; } -inline std::ostream& operator<<(std::ostream& os, Jmp const& a) { return os << Instruction{a}; } -inline std::ostream& operator<<(std::ostream& os, Packet const& a) { return os << Instruction{a}; } -inline std::ostream& operator<<(std::ostream& os, Mem const& a) { return os << Instruction{a}; } -inline std::ostream& operator<<(std::ostream& os, Atomic const& a) { return os << Instruction{a}; } -inline std::ostream& operator<<(std::ostream& os, Assume const& a) { return os << Instruction{a}; } -inline std::ostream& operator<<(std::ostream& os, Assert const& a) { return os << Instruction{a}; } -inline std::ostream& operator<<(std::ostream& os, IncrementLoopCounter const& a) { return os << Instruction{a}; } -std::ostream& operator<<(std::ostream& os, AssertionConstraint const& a); -std::string to_string(AssertionConstraint const& constraint); +std::ostream& operator<<(std::ostream& os, const Assertion& a); +std::string to_string(const Assertion& constraint); diff --git a/src/asm_syntax.hpp b/src/asm_syntax.hpp index 72c6b4e80..a008fd73a 100644 --- a/src/asm_syntax.hpp +++ b/src/asm_syntax.hpp @@ -310,6 +310,17 @@ struct Assume { constexpr bool operator==(const Assume&) const = default; }; +struct IncrementLoopCounter { + label_t name; + bool operator==(const IncrementLoopCounter&) const = default; +}; + +using Instruction = std::variant; + +using LabeledInstruction = std::tuple>; +using InstructionSeq = std::vector; + /// Condition check whether something is a valid size. struct ValidSize { Reg reg; @@ -406,28 +417,17 @@ struct BoundedLoopCount { static constexpr int limit = 100000; }; -using AssertionConstraint = std::variant; +using Assertion = std::variant; -struct Assert { - AssertionConstraint cst; - Assert(AssertionConstraint cst) : cst(std::move(cst)) {} - constexpr bool operator==(const Assert&) const = default; +struct GuardedInstruction { + Instruction cmd; + std::vector preconditions; + bool operator==(const GuardedInstruction&) const = default; }; -struct IncrementLoopCounter { - label_t name; - bool operator==(const IncrementLoopCounter&) const = default; -}; - -using Instruction = std::variant; - -using LabeledInstruction = std::tuple>; -using InstructionSeq = std::vector; - // cpu=v4 supports 32-bit PC offsets so we need a large enough type. -using pc_t = size_t; +using pc_t = uint32_t; } // namespace asm_syntax diff --git a/src/asm_unmarshal.hpp b/src/asm_unmarshal.hpp index af06089d4..9a7a96053 100644 --- a/src/asm_unmarshal.hpp +++ b/src/asm_unmarshal.hpp @@ -15,7 +15,7 @@ * * \param raw_prog is the input program to parse. * \param[out] notes is a vector for storing errors and warnings. - * \return a sequence of instruction if successful, an error string otherwise. + * \return a sequence of instructions if successful, an error string otherwise. */ std::variant unmarshal(const raw_program& raw_prog, std::vector>& notes); diff --git a/src/assertions.cpp b/src/assertions.cpp index 6001494f8..96ce593e5 100644 --- a/src/assertions.cpp +++ b/src/assertions.cpp @@ -20,8 +20,8 @@ class AssertExtractor { static Imm imm(const Value& v) { return std::get(v); } - static vector zero_offset_ctx(const Reg reg) { - vector res; + static vector zero_offset_ctx(const Reg reg) { + vector res; res.emplace_back(TypeConstraint{reg, TypeGroup::ctx}); res.emplace_back(ZeroCtxOffset{reg}); return res; @@ -37,25 +37,20 @@ class AssertExtractor { explicit AssertExtractor(program_info info, std::optional label) : info{std::move(info)}, current_label(label) {} - vector operator()(const Undefined&) const { + vector operator()(const Undefined&) const { assert(false); return {}; } - vector operator()(const Assert&) const { - assert(false); - return {}; - } + vector operator()(const IncrementLoopCounter& ipc) const { return {{BoundedLoopCount{ipc.name}}}; } - vector operator()(const IncrementLoopCounter& ipc) const { return {{BoundedLoopCount{ipc.name}}}; } - - vector operator()(const LoadMapFd&) const { return {}; } + vector operator()(const LoadMapFd&) const { return {}; } /// Packet access implicitly uses R6, so verify that R6 still has a pointer to the context. - vector operator()(const Packet&) const { return zero_offset_ctx({6}); } + vector operator()(const Packet&) const { return zero_offset_ctx({6}); } - vector operator()(const Exit&) const { - vector res; + vector operator()(const Exit&) const { + vector res; if (current_label->stack_frame_prefix.empty()) { // Verify that Exit returns a number. res.emplace_back(TypeConstraint{Reg{R0_RETURN_VALUE}, TypeGroup::number}); @@ -63,8 +58,8 @@ class AssertExtractor { return res; } - vector operator()(const Call& call) const { - vector res; + vector operator()(const Call& call) const { + vector res; std::optional map_fd_reg; res.emplace_back(ValidCall{call.func, call.stack_frame_prefix}); for (ArgSingle arg : call.singles) { @@ -90,7 +85,7 @@ class AssertExtractor { res.emplace_back(ValidMapKeyValue{arg.reg, *map_fd_reg, arg.kind == ArgSingle::Kind::PTR_TO_MAP_KEY}); break; case ArgSingle::Kind::PTR_TO_CTX: - for (const Assert& a : zero_offset_ctx(arg.reg)) { + for (const Assertion& a : zero_offset_ctx(arg.reg)) { res.emplace_back(a); } break; @@ -120,21 +115,21 @@ class AssertExtractor { return res; } - vector operator()(const CallLocal&) const { return {}; } + vector operator()(const CallLocal&) const { return {}; } - vector operator()(const Callx& callx) const { - vector res; + vector operator()(const Callx& callx) const { + vector res; res.emplace_back(TypeConstraint{callx.func, TypeGroup::number}); res.emplace_back(FuncConstraint{callx.func}); return res; } [[nodiscard]] - vector explicate(const Condition& cond) const { + vector explicate(const Condition& cond) const { if (info.type.is_privileged) { return {}; } - vector res; + vector res; if (const auto pimm = std::get_if(&cond.right)) { if (pimm->v != 0) { // no need to check for valid access, it must be a number @@ -181,17 +176,17 @@ class AssertExtractor { return res; } - vector operator()(const Assume& ins) const { return explicate(ins.cond); } + vector operator()(const Assume& ins) const { return explicate(ins.cond); } - vector operator()(const Jmp& ins) const { + vector operator()(const Jmp& ins) const { if (!ins.cond) { return {}; } return explicate(*ins.cond); } - vector operator()(const Mem& ins) const { - vector res; + vector operator()(const Mem& ins) const { + vector res; const Reg basereg = ins.access.basereg; Imm width{static_cast(ins.access.width)}; const int offset = ins.access.offset; @@ -219,24 +214,26 @@ class AssertExtractor { return res; } - vector operator()(const Atomic& ins) const { + vector operator()(const Atomic& ins) const { return { - Assert{TypeConstraint{ins.valreg, TypeGroup::number}}, - Assert{TypeConstraint{ins.access.basereg, TypeGroup::pointer}}, - Assert{make_valid_access(ins.access.basereg, ins.access.offset, - Imm{static_cast(ins.access.width)}, false)}, + Assertion{TypeConstraint{ins.valreg, TypeGroup::number}}, + Assertion{TypeConstraint{ins.access.basereg, TypeGroup::pointer}}, + Assertion{make_valid_access(ins.access.basereg, ins.access.offset, + Imm{static_cast(ins.access.width)}, false)}, }; } - vector operator()(const Un& ins) const { return {Assert{TypeConstraint{ins.dst, TypeGroup::number}}}; } + vector operator()(const Un& ins) const { + return {Assertion{TypeConstraint{ins.dst, TypeGroup::number}}}; + } - vector operator()(const Bin& ins) const { + vector operator()(const Bin& ins) const { switch (ins.op) { case Bin::Op::MOV: if (const auto src = std::get_if(&ins.v)) { if (!ins.is64) { - return {Assert{TypeConstraint{*src, TypeGroup::number}}}; + return {Assertion{TypeConstraint{*src, TypeGroup::number}}}; } } return {}; @@ -244,27 +241,27 @@ class AssertExtractor { case Bin::Op::MOVSX16: case Bin::Op::MOVSX32: if (const auto src = std::get_if(&ins.v)) { - return {Assert{TypeConstraint{*src, TypeGroup::number}}}; + return {Assertion{TypeConstraint{*src, TypeGroup::number}}}; } return {}; case Bin::Op::ADD: { if (const auto src = std::get_if(&ins.v)) { - return {Assert{TypeConstraint{ins.dst, TypeGroup::ptr_or_num}}, - Assert{TypeConstraint{*src, TypeGroup::ptr_or_num}}, Assert{Addable{*src, ins.dst}}, - Assert{Addable{ins.dst, *src}}}; + return {Assertion{TypeConstraint{ins.dst, TypeGroup::ptr_or_num}}, + Assertion{TypeConstraint{*src, TypeGroup::ptr_or_num}}, Assertion{Addable{*src, ins.dst}}, + Assertion{Addable{ins.dst, *src}}}; } - return {Assert{TypeConstraint{ins.dst, TypeGroup::ptr_or_num}}}; + return {Assertion{TypeConstraint{ins.dst, TypeGroup::ptr_or_num}}}; } case Bin::Op::SUB: { if (const auto reg = std::get_if(&ins.v)) { - vector res; + vector res; // disallow map-map since same type does not mean same offset // TODO: map identities res.emplace_back(TypeConstraint{ins.dst, TypeGroup::ptr_or_num}); res.emplace_back(Comparable{.r1 = ins.dst, .r2 = *reg, .or_r2_is_number = true}); return res; } - return {Assert{TypeConstraint{ins.dst, TypeGroup::ptr_or_num}}}; + return {Assertion{TypeConstraint{ins.dst, TypeGroup::ptr_or_num}}}; } case Bin::Op::UDIV: case Bin::Op::UMOD: @@ -272,25 +269,26 @@ class AssertExtractor { case Bin::Op::SMOD: { if (const auto src = std::get_if(&ins.v)) { const bool is_signed = (ins.op == Bin::Op::SDIV || ins.op == Bin::Op::SMOD); - return {Assert{TypeConstraint{ins.dst, TypeGroup::number}}, Assert{ValidDivisor{*src, is_signed}}}; + return {Assertion{TypeConstraint{ins.dst, TypeGroup::number}}, + Assertion{ValidDivisor{*src, is_signed}}}; } - return {Assert{TypeConstraint{ins.dst, TypeGroup::number}}}; + return {Assertion{TypeConstraint{ins.dst, TypeGroup::number}}}; } // For all other binary operations, the destination register must be a number and the source must either be an // immediate or a number. default: if (const auto src = std::get_if(&ins.v)) { - return {Assert{TypeConstraint{ins.dst, TypeGroup::number}}, - Assert{TypeConstraint{*src, TypeGroup::number}}}; + return {Assertion{TypeConstraint{ins.dst, TypeGroup::number}}, + Assertion{TypeConstraint{*src, TypeGroup::number}}}; } else { - return {Assert{TypeConstraint{ins.dst, TypeGroup::number}}}; + return {Assertion{TypeConstraint{ins.dst, TypeGroup::number}}}; } } assert(false); } }; -vector get_assertions(Instruction ins, const program_info& info, const std::optional& label) { +vector get_assertions(Instruction ins, const program_info& info, const std::optional& label) { return std::visit(AssertExtractor{info, label}, ins); } @@ -302,13 +300,8 @@ vector get_assertions(Instruction ins, const program_info& info, const s void explicate_assertions(cfg_t& cfg, const program_info& info) { for (auto& [label, bb] : cfg) { (void)label; // unused - vector insts; - for (const auto& ins : bb) { - for (const auto& a : get_assertions(ins, info, bb.label())) { - insts.emplace_back(a); - } - insts.push_back(ins); + for (auto& ins : bb) { + ins.preconditions = get_assertions(ins.cmd, info, bb.label()); } - bb.swap_instructions(insts); } } diff --git a/src/crab/cfg.hpp b/src/crab/cfg.hpp index 9c4ce70fa..8dd6e9ba0 100644 --- a/src/crab/cfg.hpp +++ b/src/crab/cfg.hpp @@ -35,12 +35,11 @@ class cfg_t; class basic_block_t final { friend class cfg_t; - private: public: basic_block_t(const basic_block_t&) = delete; using label_vec_t = std::set; - using stmt_list_t = std::vector; + using stmt_list_t = std::vector; using neighbour_const_iterator = label_vec_t::const_iterator; using neighbour_const_reverse_iterator = label_vec_t::const_reverse_iterator; using iterator = stmt_list_t::iterator; @@ -54,15 +53,15 @@ class basic_block_t final { label_vec_t m_prev, m_next; public: - void insert(const Instruction& arg) { + void insert(const GuardedInstruction& arg) { assert(label() != label_t::entry); assert(label() != label_t::exit); m_ts.push_back(arg); } - /// Insert an instruction at the front of the basic block. + /// Insert a GuardedInstruction at the front of the basic block. /// @note Cannot modify entry or exit blocks. - void insert_front(const Instruction& arg) { + void insert_front(const GuardedInstruction& arg) { assert(label() != label_t::entry); assert(label() != label_t::exit); m_ts.insert(m_ts.begin(), arg); @@ -625,11 +624,11 @@ struct prepare_cfg_options { cfg_t prepare_cfg(const InstructionSeq& prog, const program_info& info, const prepare_cfg_options& options); void explicate_assertions(cfg_t& cfg, const program_info& info); -std::vector get_assertions(Instruction ins, const program_info& info, const std::optional& label); +std::vector get_assertions(Instruction ins, const program_info& info, const std::optional& label); void print_dot(const cfg_t& cfg, std::ostream& out); void print_dot(const cfg_t& cfg, const std::string& outfile); -std::ostream& operator<<(std::ostream& o, const crab::basic_block_t& bb); +std::ostream& operator<<(std::ostream& o, const basic_block_t& bb); std::ostream& operator<<(std::ostream& o, const crab::basic_block_rev_t& bb); std::ostream& operator<<(std::ostream& o, const cfg_t& cfg); diff --git a/src/crab/ebpf_domain.cpp b/src/crab/ebpf_domain.cpp index b9fee364a..02e37074a 100644 --- a/src/crab/ebpf_domain.cpp +++ b/src/crab/ebpf_domain.cpp @@ -1195,9 +1195,20 @@ static linear_constraint_t type_is_not_stack(const reg_pack_t& r) { return r.type != T_STACK; } +void ebpf_domain_t::operator()(const Assertion& assertion) { + if (check_require || thread_local_options.assume_assertions) { + this->current_assertion = to_string(assertion); + std::visit(*this, assertion); + this->current_assertion.clear(); + } +} + void ebpf_domain_t::operator()(const basic_block_t& bb) { - for (const Instruction& statement : bb) { - std::visit(*this, statement); + for (const GuardedInstruction& ins : bb) { + for (const Assertion& assertion : ins.preconditions) { + (*this)(assertion); + } + std::visit(*this, ins.cmd); } } @@ -1382,7 +1393,9 @@ void ebpf_domain_t::operator()(const Exit& a) { restore_callee_saved_registers(prefix); } -void ebpf_domain_t::operator()(const Jmp& a) {} +void ebpf_domain_t::operator()(const Jmp&) const { + // This is a NOP. It only exists to hold the jump preconditions. +} void ebpf_domain_t::operator()(const Comparable& s) { using namespace crab::dsl_syntax; @@ -1446,18 +1459,18 @@ void ebpf_domain_t::operator()(const BoundedLoopCount& s) { void ebpf_domain_t::operator()(const FuncConstraint& s) { // Look up the helper function id. const reg_pack_t& reg = reg_pack(s.reg); - auto src_interval = m_inv.eval_interval(reg.svalue); - if (auto sn = src_interval.singleton()) { + const auto src_interval = m_inv.eval_interval(reg.svalue); + if (const auto sn = src_interval.singleton()) { if (sn->fits()) { // We can now process it as if the id was immediate. - int32_t imm = sn->cast_to(); + const int32_t imm = sn->cast_to(); if (!global_program_info->platform->is_helper_usable(imm)) { require(m_inv, linear_constraint_t::false_const(), "invalid helper function id " + std::to_string(imm)); return; } Call call = make_call(imm, *global_program_info->platform); - for (Assert a : get_assertions(call, *global_program_info, {})) { - (*this)(a); + for (const Assertion& assertion : get_assertions(call, *global_program_info, {})) { + (*this)(assertion); } return; } @@ -1758,14 +1771,6 @@ void ebpf_domain_t::operator()(const ZeroCtxOffset& s) { require(m_inv, reg.ctx_offset == 0, "Nonzero context offset"); } -void ebpf_domain_t::operator()(const Assert& stmt) { - if (check_require || thread_local_options.assume_assertions) { - this->current_assertion = to_string(stmt.cst); - std::visit(*this, stmt.cst); - this->current_assertion.clear(); - } -} - void ebpf_domain_t::operator()(const Packet& a) { const auto reg = reg_pack(R0_RETURN_VALUE); constexpr Reg r0_reg{R0_RETURN_VALUE}; diff --git a/src/crab/ebpf_domain.hpp b/src/crab/ebpf_domain.hpp index c51c61fe6..9f3cee592 100644 --- a/src/crab/ebpf_domain.hpp +++ b/src/crab/ebpf_domain.hpp @@ -52,32 +52,34 @@ class ebpf_domain_t final { // abstract transformers void operator()(const basic_block_t& bb); - void operator()(const Addable&); - void operator()(const Assert&); void operator()(const Assume&); void operator()(const Bin&); void operator()(const Call&); void operator()(const CallLocal&); void operator()(const Callx&); - void operator()(const Comparable&); void operator()(const Exit&); - void operator()(const FuncConstraint&); - void operator()(const Jmp&); + void operator()(const Jmp&) const; void operator()(const LoadMapFd&); void operator()(const Atomic&); void operator()(const Mem&); - void operator()(const ValidDivisor&); void operator()(const Packet&); - void operator()(const TypeConstraint&); void operator()(const Un&); void operator()(const Undefined&); + void operator()(const IncrementLoopCounter&); + + void operator()(const Assertion&); + + void operator()(const Addable&); + void operator()(const Comparable&); + void operator()(const FuncConstraint&); + void operator()(const ValidDivisor&); + void operator()(const TypeConstraint&); void operator()(const ValidAccess&); void operator()(const ValidCall&); void operator()(const ValidMapKeyValue&); void operator()(const ValidSize&); void operator()(const ValidStore&); void operator()(const ZeroCtxOffset&); - void operator()(const IncrementLoopCounter&); void operator()(const BoundedLoopCount&); void initialize_loop_counter(const label_t& label); diff --git a/src/crab/split_dbm.cpp b/src/crab/split_dbm.cpp index 25e2cc69f..1c8620718 100644 --- a/src/crab/split_dbm.cpp +++ b/src/crab/split_dbm.cpp @@ -1206,7 +1206,8 @@ string_invariant SplitDBM::to_set() const { std::ostream& operator<<(std::ostream& o, const SplitDBM& dom) { return o << dom.to_set(); } bool SplitDBM::eval_expression_overflow(const linear_expression_t& e, Weight& out) const { - [[maybe_unused]] const bool overflow = convert_NtoW_overflow(e.constant_term(), out); + [[maybe_unused]] + const bool overflow = convert_NtoW_overflow(e.constant_term(), out); assert(!overflow); for (const auto& [variable, coefficient] : e.variable_terms()) { Weight coef; diff --git a/src/crab/wto.cpp b/src/crab/wto.cpp index a179cae11..4dc3e2018 100644 --- a/src/crab/wto.cpp +++ b/src/crab/wto.cpp @@ -41,7 +41,7 @@ struct visit_args_t { std::weak_ptr containing_cycle; visit_args_t(const visit_task_type_t t, label_t v, wto_partition_t& p, std::weak_ptr cc) - : type(t), vertex(std::move(v)), partition(p), containing_cycle(std::move(cc)){}; + : type(t), vertex(std::move(v)), partition(p), containing_cycle(std::move(cc)) {}; }; struct wto_vertex_data_t { diff --git a/src/test/test_marshal.cpp b/src/test/test_marshal.cpp index 68e08dfb6..230d045c9 100644 --- a/src/test/test_marshal.cpp +++ b/src/test/test_marshal.cpp @@ -263,8 +263,7 @@ static void compare_unmarshal_marshal(const ebpf_inst& ins1, const ebpf_inst& in program_info info{.platform = &g_ebpf_platform_linux, .type = g_ebpf_platform_linux.get_program_type("unspec", "unspec")}; constexpr ebpf_inst exit{.opcode = INST_OP_EXIT}; - InstructionSeq parsed = - std::get(unmarshal(raw_program{"", "", 0, "", {ins1, ins2, exit, exit}, info})); + auto parsed = std::get(unmarshal(raw_program{"", "", 0, "", {ins1, ins2, exit, exit}, info})); REQUIRE(parsed.size() == 3); auto [_, single, _2] = parsed.front(); (void)_; // unused @@ -290,14 +289,16 @@ static void compare_marshal_unmarshal(const Instruction& ins, bool double_cmd = REQUIRE(single == ins); } -static void check_marshal_unmarshal_fail(const Instruction& ins, std::string expected_error_message, +static void check_marshal_unmarshal_fail(const Instruction& ins, const std::string& expected_error_message, const ebpf_platform_t& platform = g_ebpf_platform_linux) { const program_info info{.platform = &platform, .type = platform.get_program_type("unspec", "unspec")}; - std::string error_message = std::get(unmarshal(raw_program{"", "", 0, "", marshal(ins, 0), info})); - REQUIRE(error_message == expected_error_message); + auto result = unmarshal(raw_program{"", "", 0, "", marshal(ins, 0), info}); + auto* error_message = std::get_if(&result); + REQUIRE(error_message != nullptr); + REQUIRE(*error_message == expected_error_message); } -static void check_unmarshal_fail(ebpf_inst inst, std::string expected_error_message, +static void check_unmarshal_fail(ebpf_inst inst, const std::string& expected_error_message, const ebpf_platform_t& platform = g_ebpf_platform_linux) { program_info info{.platform = &platform, .type = platform.get_program_type("unspec", "unspec")}; std::vector insns = {inst}; @@ -319,7 +320,7 @@ static void check_unmarshal_fail_goto(ebpf_inst inst, const std::string& expecte } // Check that unmarshaling a 64-bit immediate instruction fails. -static void check_unmarshal_fail(ebpf_inst inst1, ebpf_inst inst2, std::string expected_error_message, +static void check_unmarshal_fail(ebpf_inst inst1, ebpf_inst inst2, const std::string& expected_error_message, const ebpf_platform_t& platform = g_ebpf_platform_linux) { program_info info{.platform = &platform, .type = platform.get_program_type("unspec", "unspec")}; std::vector insns{inst1, inst2};