Skip to content

Commit

Permalink
Enforce upper bound on loop count on all instructions in loops (vbpf#745
Browse files Browse the repository at this point in the history
)

Emit warning if maximum number of iterations is reached.

Signed-off-by: Alan Jowett <[email protected]>
  • Loading branch information
Alan-Jowett authored Nov 6, 2024
1 parent 076ae86 commit 7616c29
Show file tree
Hide file tree
Showing 15 changed files with 296 additions and 43 deletions.
14 changes: 13 additions & 1 deletion src/asm_cfg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "asm_syntax.hpp"
#include "crab/cfg.hpp"
#include "crab/wto.hpp"

using std::optional;
using std::set;
Expand Down Expand Up @@ -358,11 +359,22 @@ std::map<std::string, int> collect_stats(const cfg_t& cfg) {
return res;
}

// ISSUE: 774 - Rationalize the list of bools being passed to prepare_cfg.
cfg_t prepare_cfg(const InstructionSeq& prog, const program_info& info, const bool simplify,
const bool must_have_exit) {
const bool check_for_termination, const bool must_have_exit) {
// Convert the instruction sequence to a deterministic control-flow graph.
cfg_t det_cfg = instruction_seq_to_cfg(prog, must_have_exit);

// Detect loops using Weak Topological Ordering (WTO) and insert counters at loop entry points. WTO provides a
// hierarchical decomposition of the CFG that identifies all strongly connected components (cycles) and their entry
// points. These entry points serve as natural locations for loop counters that help verify program termination.
if (check_for_termination) {
wto_t wto(det_cfg);
wto.for_each_loop_head([&](const label_t& label) {
det_cfg.get_node(label).insert_front(IncrementLoopCounter{label});
});
}

// Annotate the CFG by adding in assertions before every memory instruction.
explicate_assertions(det_cfg, info);

Expand Down
4 changes: 4 additions & 0 deletions src/asm_ostream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ std::ostream& operator<<(std::ostream& os, ValidAccess const& a) {
return os;
}

std::ostream& operator<<(std::ostream& os, const BoundedLoopCount& a) {
return os << crab::variable_t::loop_counter(to_string(a.name)) << " < " << a.limit;
}

static crab::variable_t typereg(const Reg& r) { return crab::variable_t::reg(crab::data_kind_t::types, r.v); }

std::ostream& operator<<(std::ostream& os, ValidSize const& a) {
Expand Down
11 changes: 10 additions & 1 deletion src/asm_syntax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,17 @@ struct ZeroCtxOffset {
constexpr bool operator==(const ZeroCtxOffset&) const = default;
};

struct BoundedLoopCount {
label_t name;
bool operator==(const BoundedLoopCount&) const = default;
// Maximum number of loop iterations allowed during verification.
// This prevents infinite loops while allowing reasonable bounded loops.
// When exceeded, verification fails as the loop might not terminate.
static constexpr int limit = 100000;
};

using AssertionConstraint = std::variant<Comparable, Addable, ValidDivisor, ValidAccess, ValidStore, ValidSize,
ValidMapKeyValue, ValidCall, TypeConstraint, FuncConstraint, ZeroCtxOffset>;
ValidMapKeyValue, ValidCall, TypeConstraint, FuncConstraint, ZeroCtxOffset, BoundedLoopCount>;

struct Assert {
AssertionConstraint cst;
Expand Down
5 changes: 2 additions & 3 deletions src/assertions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,8 @@ class AssertExtractor {
return {};
}

vector<Assert> operator()(IncrementLoopCounter) const {
assert(false);
return {};
vector<Assert> operator()(IncrementLoopCounter& ipc) const {
return {{BoundedLoopCount{ipc.name}}};
}

vector<Assert> operator()(LoadMapFd const&) const { return {}; }
Expand Down
11 changes: 10 additions & 1 deletion src/crab/cfg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ class basic_block_t final {
m_ts.push_back(arg);
}

/// Insert an instruction at the front of the basic block.
/// @note Cannot modify entry or exit blocks.
void insert_front(const Instruction& arg) {
assert(label() != label_t::entry);
assert(label() != label_t::exit);
m_ts.insert(m_ts.begin(), arg);
}

explicit basic_block_t(label_t _label) : m_label(_label) {}

~basic_block_t() = default;
Expand Down Expand Up @@ -611,7 +619,8 @@ std::vector<std::string> stats_headers();

std::map<std::string, int> collect_stats(const cfg_t&);

cfg_t prepare_cfg(const InstructionSeq& prog, const program_info& info, bool simplify, bool must_have_exit = true);
cfg_t prepare_cfg(const InstructionSeq& prog, const program_info& info, bool simplify, bool check_for_termination,
bool must_have_exit = true);

void explicate_assertions(cfg_t& cfg, const program_info& info);
std::vector<Assert> get_assertions(Instruction ins, const program_info& info, const std::optional<label_t>& label);
Expand Down
11 changes: 10 additions & 1 deletion src/crab/ebpf_domain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1435,6 +1435,14 @@ void ebpf_domain_t::operator()(const TypeConstraint& s) {
}
}

void ebpf_domain_t::operator()(const BoundedLoopCount& s) {
// Enforces an upper bound on loop iterations by checking that the loop counter
// does not exceed the specified limit
using namespace crab::dsl_syntax;
const auto counter = variable_t::loop_counter(to_string(s.name));
require(m_inv, counter <= s.limit, "Loop counter is too large");
}

void ebpf_domain_t::operator()(const FuncConstraint& s) {
// Look up the helper function id.
const reg_pack_t& reg = reg_pack(s.reg);
Expand Down Expand Up @@ -2954,6 +2962,7 @@ extended_number ebpf_domain_t::get_loop_count_upper_bound() const {
}

void ebpf_domain_t::operator()(const IncrementLoopCounter& ins) {
this->add(variable_t::loop_counter(to_string(ins.name)), 1);
const auto counter = variable_t::loop_counter(to_string(ins.name));
this->add(counter, 1);
}
} // namespace crab
1 change: 1 addition & 0 deletions src/crab/ebpf_domain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class ebpf_domain_t final {
void operator()(const ValidStore&);
void operator()(const ZeroCtxOffset&);
void operator()(const IncrementLoopCounter&);
void operator()(const BoundedLoopCount&);

void initialize_loop_counter(const label_t& label);
static ebpf_domain_t calculate_constant_limits();
Expand Down
25 changes: 10 additions & 15 deletions src/crab/fwd_analyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class member_component_visitor final {
class interleaved_fwd_fixpoint_iterator_t final {
using iterator = invariant_table_t::iterator;

cfg_t& _cfg;
const cfg_t& _cfg;
wto_t _wto;
invariant_table_t _pre, _post;

Expand All @@ -68,7 +68,7 @@ class interleaved_fwd_fixpoint_iterator_t final {
void set_pre(const label_t& label, const ebpf_domain_t& v) { _pre[label] = v; }

void transform_to_post(const label_t& label, ebpf_domain_t pre) {
basic_block_t& bb = _cfg.get_node(label);
const basic_block_t& bb = _cfg.get_node(label);
pre(bb);
_post[label] = std::move(pre);
}
Expand Down Expand Up @@ -101,7 +101,7 @@ class interleaved_fwd_fixpoint_iterator_t final {
}

public:
explicit interleaved_fwd_fixpoint_iterator_t(cfg_t& cfg) : _cfg(cfg), _wto(cfg) {
explicit interleaved_fwd_fixpoint_iterator_t(const cfg_t& cfg) : _cfg(cfg), _wto(cfg) {
for (const auto& label : _cfg.labels()) {
_pre.emplace(label, ebpf_domain_t::bottom());
_post.emplace(label, ebpf_domain_t::bottom());
Expand All @@ -116,23 +116,18 @@ class interleaved_fwd_fixpoint_iterator_t final {

void operator()(const std::shared_ptr<wto_cycle_t>& cycle);

friend std::pair<invariant_table_t, invariant_table_t> run_forward_analyzer(cfg_t& cfg, ebpf_domain_t entry_inv);
friend std::pair<invariant_table_t, invariant_table_t> run_forward_analyzer(const cfg_t& cfg, ebpf_domain_t entry_inv);
};

std::pair<invariant_table_t, invariant_table_t> run_forward_analyzer(cfg_t& cfg, ebpf_domain_t entry_inv) {
std::pair<invariant_table_t, invariant_table_t> run_forward_analyzer(const cfg_t& cfg, ebpf_domain_t entry_inv) {
// Go over the CFG in weak topological order (accounting for loops).
interleaved_fwd_fixpoint_iterator_t analyzer(cfg);
if (thread_local_options.check_termination) {
std::vector<label_t> cycle_heads;
for (auto& component : analyzer._wto) {
if (const auto pc = std::get_if<std::shared_ptr<wto_cycle_t>>(&component)) {
cycle_heads.push_back((*pc)->head());
}
}
for (const label_t& label : cycle_heads) {
entry_inv.initialize_loop_counter(label);
cfg.get_node(label).insert(IncrementLoopCounter{label});
}
// Initialize loop counters for potential loop headers.
// This enables enforcement of upper bounds on loop iterations
// during program verification.
// TODO: Consider making this an instruction instead of an explicit call.
analyzer._wto.for_each_loop_head([&](const label_t& label) { entry_inv.initialize_loop_counter(label); });
}
analyzer.set_pre(cfg.entry_label(), entry_inv);
for (auto& component : analyzer._wto) {
Expand Down
2 changes: 1 addition & 1 deletion src/crab/fwd_analyzer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ namespace crab {

using invariant_table_t = std::map<label_t, ebpf_domain_t>;

std::pair<invariant_table_t, invariant_table_t> run_forward_analyzer(cfg_t& cfg, ebpf_domain_t entry_inv);
std::pair<invariant_table_t, invariant_table_t> run_forward_analyzer(const cfg_t& cfg, ebpf_domain_t entry_inv);

} // namespace crab
17 changes: 17 additions & 0 deletions src/crab/wto.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,21 @@ class wto_t final {

friend std::ostream& operator<<(std::ostream& o, const wto_t& wto);
const wto_nesting_t& nesting(const label_t& label);

/**
* Visit the heads of all loops in the WTO.
*
* @param f The callable to be invoked for each loop head.
*
* The order in which the heads are visited is not specified.
*/
template<typename F>
void for_each_loop_head(F&& f) const {
for (const auto& component : *this) {
if (const auto pc = std::get_if<std::shared_ptr<wto_cycle_t>>(&component)) {
f((*pc)->head());
}
}
}

};
27 changes: 14 additions & 13 deletions src/crab_verifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ struct checks_db final {
checks_db() = default;
};

static checks_db generate_report(cfg_t& cfg, const crab::invariant_table_t& pre_invariants,
static checks_db generate_report(const cfg_t& cfg, const crab::invariant_table_t& pre_invariants,
const crab::invariant_table_t& post_invariants) {
checks_db m_db;
for (const label_t& label : cfg.sorted_labels()) {
basic_block_t& bb = cfg.get_node(label);
const basic_block_t& bb = cfg.get_node(label);
ebpf_domain_t from_inv(pre_invariants.at(label));
from_inv.set_require_check(
[&m_db, label](auto& inv, const crab::linear_constraint_t& cst, const std::string& s) {
Expand Down Expand Up @@ -95,8 +95,13 @@ static checks_db generate_report(cfg_t& cfg, const crab::invariant_table_t& pre_
}

if (thread_local_options.check_termination) {
const auto last_inv = post_invariants.at(cfg.exit_label());
m_db.max_loop_count = last_inv.get_loop_count_upper_bound();
// Gather the upper bound of loop counts from post-invariants.
for (const auto [label, inv] : post_invariants) {
if (inv.is_bottom()) {
continue;
}
m_db.max_loop_count = std::max(m_db.max_loop_count, inv.get_loop_count_upper_bound());
}
}
return m_db;
}
Expand Down Expand Up @@ -127,13 +132,9 @@ static void print_report(std::ostream& os, const checks_db& db, const Instructio
}
}
os << "\n";
const crab::number_t max_loop_count{100000};
if (db.max_loop_count > max_loop_count) {
os << "Could not prove termination.\n";
}
}

static checks_db get_analysis_report(std::ostream& s, cfg_t& cfg, const crab::invariant_table_t& pre_invariants,
static checks_db get_analysis_report(std::ostream& s, const cfg_t& cfg, const crab::invariant_table_t& pre_invariants,
const crab::invariant_table_t& post_invariants,
const std::optional<InstructionSeq>& prog = std::nullopt) {
// Analyze the control-flow graph.
Expand Down Expand Up @@ -161,7 +162,7 @@ static checks_db get_analysis_report(std::ostream& s, cfg_t& cfg, const crab::in
return db;
}

static checks_db get_ebpf_report(std::ostream& s, cfg_t& cfg, program_info info, const ebpf_verifier_options_t* options,
static checks_db get_ebpf_report(std::ostream& s, const cfg_t& cfg, program_info info, const ebpf_verifier_options_t* options,
const std::optional<InstructionSeq>& prog = std::nullopt) {
global_program_info = std::move(info);
crab::domains::clear_global_state();
Expand All @@ -182,7 +183,7 @@ static checks_db get_ebpf_report(std::ostream& s, cfg_t& cfg, program_info info,
}

/// Returned value is true if the program passes verification.
bool run_ebpf_analysis(std::ostream& s, cfg_t& cfg, const program_info& info, const ebpf_verifier_options_t* options,
bool run_ebpf_analysis(std::ostream& s, const cfg_t& cfg, const program_info& info, const ebpf_verifier_options_t* options,
ebpf_verifier_stats_t* stats) {
if (options == nullptr) {
options = &ebpf_verifier_default_options;
Expand Down Expand Up @@ -219,7 +220,7 @@ std::tuple<string_invariant, bool> ebpf_analyze_program_for_test(std::ostream& o
throw std::runtime_error("Entry invariant is inconsistent");
}
try {
cfg_t cfg = prepare_cfg(prog, info, options.simplify, false);
const cfg_t cfg = prepare_cfg(prog, info, options.simplify, options.check_termination, false);
auto [pre_invariants, post_invariants] = run_forward_analyzer(cfg, std::move(entry_inv));
const checks_db report = get_analysis_report(std::cerr, cfg, pre_invariants, post_invariants);
print_report(os, report, prog, false);
Expand All @@ -242,7 +243,7 @@ bool ebpf_verify_program(std::ostream& os, const InstructionSeq& prog, const pro

// Convert the instruction sequence to a control-flow graph
// in a "passive", non-deterministic form.
cfg_t cfg = prepare_cfg(prog, info, options->simplify);
const cfg_t cfg = prepare_cfg(prog, info, options->simplify, options->check_termination);

std::optional<InstructionSeq> prog_opt = std::nullopt;
if (options->print_failures) {
Expand Down
2 changes: 1 addition & 1 deletion src/crab_verifier.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include "spec_type_descriptors.hpp"
#include "string_constraints.hpp"

bool run_ebpf_analysis(std::ostream& s, cfg_t& cfg, const program_info& info, const ebpf_verifier_options_t* options,
bool run_ebpf_analysis(std::ostream& s, const cfg_t& cfg, const program_info& info, const ebpf_verifier_options_t* options,
ebpf_verifier_stats_t* stats);

bool ebpf_verify_program(std::ostream& os, const InstructionSeq& prog, const program_info& info,
Expand Down
4 changes: 2 additions & 2 deletions src/main/check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ int main(int argc, char** argv) {
return !res;
} else if (domain == "stats") {
// Convert the instruction sequence to a control-flow graph.
cfg_t cfg = prepare_cfg(prog, raw_prog.info, ebpf_verifier_options.simplify);
const cfg_t cfg = prepare_cfg(prog, raw_prog.info, ebpf_verifier_options.simplify, ebpf_verifier_options.check_termination);

// Just print eBPF program stats.
auto stats = collect_stats(cfg);
Expand All @@ -264,7 +264,7 @@ int main(int argc, char** argv) {
std::cout << "\n";
} else if (domain == "cfg") {
// Convert the instruction sequence to a control-flow graph.
cfg_t cfg = prepare_cfg(prog, raw_prog.info, ebpf_verifier_options.simplify);
const cfg_t cfg = prepare_cfg(prog, raw_prog.info, ebpf_verifier_options.simplify, ebpf_verifier_options.check_termination);
std::cout << cfg;
std::cout << "\n";
} else {
Expand Down
6 changes: 3 additions & 3 deletions src/test/test_verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ TEST_SECTION_FAIL("cilium", "bpf_xdp_snat_linux.o", "2/16")
// False positive, unknown cause
TEST_SECTION_FAIL("linux", "test_map_in_map_kern.o", "kprobe/sys_connect")

void test_analyze_thread(cfg_t* cfg, program_info* info, bool* res) {
void test_analyze_thread(const cfg_t* cfg, program_info* info, bool* res) {
*res = run_ebpf_analysis(std::cout, *cfg, *info, nullptr, nullptr);
}

Expand All @@ -598,15 +598,15 @@ TEST_CASE("multithreading", "[verify][multithreading]") {
auto prog_or_error1 = unmarshal(raw_prog1);
auto prog1 = std::get_if<InstructionSeq>(&prog_or_error1);
REQUIRE(prog1 != nullptr);
cfg_t cfg1 = prepare_cfg(*prog1, raw_prog1.info, true);
const cfg_t cfg1 = prepare_cfg(*prog1, raw_prog1.info, true, true);

auto raw_progs2 = read_elf("ebpf-samples/bpf_cilium_test/bpf_netdev.o", "2/2", nullptr, &g_ebpf_platform_linux);
REQUIRE(raw_progs2.size() == 1);
raw_program raw_prog2 = raw_progs2.back();
auto prog_or_error2 = unmarshal(raw_prog2);
auto prog2 = std::get_if<InstructionSeq>(&prog_or_error2);
REQUIRE(prog2 != nullptr);
cfg_t cfg2 = prepare_cfg(*prog2, raw_prog2.info, true);
const cfg_t cfg2 = prepare_cfg(*prog2, raw_prog2.info, true, true);

bool res1, res2;
std::thread a(test_analyze_thread, &cfg1, &raw_prog1.info, &res1);
Expand Down
Loading

0 comments on commit 7616c29

Please sign in to comment.