Skip to content

Commit

Permalink
Make the assertions part of a generalized instruction (#779)
Browse files Browse the repository at this point in the history
Make the assertions part of the generalized instruction

Each guarded instruction is a pair (Instruction, vector<Assertion>).

Jump instructions are retained in the nondeterministic graph, since they carry the preconditions (assertions) of the jump.

Signed-off-by: Elazar Gershuni <[email protected]>
  • Loading branch information
elazarg authored Nov 8, 2024
1 parent bd55b33 commit f07d66a
Show file tree
Hide file tree
Showing 14 changed files with 224 additions and 227 deletions.
10 changes: 5 additions & 5 deletions scripts/format-code
Original file line number Diff line number Diff line change
Expand Up @@ -169,15 +169,15 @@ check_clang-format()
cf=$(command -v clang-format 2> /dev/null)
if [[ ! -x ${cf} ]]; then
echo "clang-format is not installed"
exit 1
exit 0 # do not fail, just warn
fi
cf="clang-format"
fi

local required_cfver='17.0.3'
# shellcheck disable=SC2155
local cfver=$(${cf} --version | grep -o -E '[0-9]+\.[0-9]+\.[0-9]+' | head -1)
check_version "${required_cfver}" "${cfver}"
# local required_cfver='17.0.3'
# # shellcheck disable=SC2155
# local cfver=$(${cf} --version | grep -o -E '[0-9]+\.[0-9]+\.[0-9]+' | head -1)
# check_version "${required_cfver}" "${cfver}"
}

check_clang-format
Expand Down
32 changes: 15 additions & 17 deletions src/asm_cfg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ static void add_cfg_nodes(cfg_t& cfg, const label_t& caller_label, const label_t
basic_block_t& caller_node = cfg.get_node(caller_label);
const std::string stack_frame_prefix = to_string(caller_label);
for (auto& inst : caller_node) {
if (const auto pcall = std::get_if<CallLocal>(&inst)) {
if (const auto pcall = std::get_if<CallLocal>(&inst.cmd)) {
pcall->stack_frame_prefix = stack_frame_prefix;
}
}
Expand All @@ -73,9 +73,9 @@ static void add_cfg_nodes(cfg_t& cfg, const label_t& caller_label, const label_t
const label_t label{macro_label.from, macro_label.to, stack_frame_prefix};
auto& bb = cfg.insert(label);
for (auto inst : cfg.get_node(macro_label)) {
if (const auto pexit = std::get_if<Exit>(&inst)) {
if (const auto pexit = std::get_if<Exit>(&inst.cmd)) {
pexit->stack_frame_prefix = label.stack_frame_prefix;
} else if (const auto pcall = std::get_if<Call>(&inst)) {
} else if (const auto pcall = std::get_if<Call>(&inst.cmd)) {
pcall->stack_frame_prefix = label.stack_frame_prefix;
}
bb.insert(inst);
Expand Down Expand Up @@ -123,7 +123,7 @@ static void add_cfg_nodes(cfg_t& cfg, const label_t& caller_label, const label_t
for (const auto& macro_label : seen_labels) {
for (const label_t label(macro_label.from, macro_label.to, caller_label_str);
const auto& inst : cfg.get_node(label)) {
if (const auto pins = std::get_if<CallLocal>(&inst)) {
if (const auto pins = std::get_if<CallLocal>(&inst.cmd)) {
if (stack_frame_depth >= MAX_CALL_STACK_FRAMES) {
throw std::runtime_error{"too many call stack frames"};
}
Expand Down Expand Up @@ -153,7 +153,7 @@ static cfg_t instruction_seq_to_cfg(const InstructionSeq& insts, const bool must
cfg.get_node(cfg.entry_label()) >> bb;
}

bb.insert(inst);
bb.insert({.cmd = inst});
if (falling_from) {
cfg.get_node(*falling_from) >> bb;
falling_from = {};
Expand Down Expand Up @@ -236,9 +236,7 @@ static cfg_t to_nondet(const cfg_t& cfg) {
basic_block_t& newbb = res.insert(this_label);

for (const auto& ins : bb) {
if (!std::holds_alternative<Jmp>(ins)) {
newbb.insert(ins);
}
newbb.insert(ins);
}

for (const label_t& prev_label : bb.prev_blocks_set()) {
Expand All @@ -250,7 +248,7 @@ static cfg_t to_nondet(const cfg_t& cfg) {
auto nextlist = bb.next_blocks_set();
if (nextlist.size() == 2) {
label_t mid_label = this_label;
auto jmp = std::get<Jmp>(*bb.rbegin());
auto jmp = std::get<Jmp>(bb.rbegin()->cmd);

nextlist.erase(jmp.target);
label_t fallthrough = *nextlist.begin();
Expand All @@ -262,7 +260,7 @@ static cfg_t to_nondet(const cfg_t& cfg) {
for (const auto& [next_label, cond1] : jumps) {
label_t jump_label = label_t::make_jump(mid_label, next_label);
basic_block_t& jump_bb = res.insert(jump_label);
jump_bb.insert(Assume{cond1});
jump_bb.insert({.cmd = Assume{cond1}});
newbb >> jump_bb;
jump_bb >> res.insert(next_label);
}
Expand All @@ -275,7 +273,7 @@ static cfg_t to_nondet(const cfg_t& cfg) {
return res;
}

/// Get the type of given instruction.
/// Get the type of the given Instruction.
/// Most of these type names are also statistics header labels.
static std::string instype(Instruction ins) {
if (const auto pcall = std::get_if<Call>(&ins)) {
Expand Down Expand Up @@ -333,21 +331,21 @@ std::map<std::string, int> collect_stats(const cfg_t& cfg) {
res["basic_blocks"]++;
basic_block_t const& bb = cfg.get_node(this_label);

for (Instruction ins : bb) {
if (const auto pins = std::get_if<LoadMapFd>(&ins)) {
for (const auto& ins : bb) {
if (const auto pins = std::get_if<LoadMapFd>(&ins.cmd)) {
if (pins->mapfd == -1) {
res["map_in_map"] = 1;
}
}
if (const auto pins = std::get_if<Call>(&ins)) {
if (const auto pins = std::get_if<Call>(&ins.cmd)) {
if (pins->reallocate_packet) {
res["reallocate"] = 1;
}
}
if (const auto pins = std::get_if<Bin>(&ins)) {
if (const auto pins = std::get_if<Bin>(&ins.cmd)) {
res[pins->is64 ? "arith64" : "arith32"]++;
}
res[instype(ins)]++;
res[instype(ins.cmd)]++;
}
if (unique(bb.prev_blocks()).size() > 1) {
res["joins"]++;
Expand All @@ -369,7 +367,7 @@ cfg_t prepare_cfg(const InstructionSeq& prog, const program_info& info, const pr
if (options.check_for_termination) {
const wto_t wto(det_cfg);
wto.for_each_loop_head(
[&](const label_t& label) { det_cfg.get_node(label).insert_front(IncrementLoopCounter{label}); });
[&](const label_t& label) { det_cfg.get_node(label).insert_front({.cmd = IncrementLoopCounter{label}}); });
}

// Annotate the CFG by adding in assertions before every memory instruction.
Expand Down
2 changes: 0 additions & 2 deletions src/asm_marshal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,6 @@ struct MarshalVisitor {

vector<ebpf_inst> operator()(Assume const&) const { throw std::invalid_argument("Cannot marshal assumptions"); }

vector<ebpf_inst> operator()(Assert const&) const { throw std::invalid_argument("Cannot marshal assertions"); }

vector<ebpf_inst> operator()(Jmp const& b) const {
if (b.cond) {
ebpf_inst res{
Expand Down
160 changes: 87 additions & 73 deletions src/asm_ostream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,87 +105,87 @@ std::ostream& operator<<(std::ostream& os, const Condition::Op op) {

/// Format a width as the string "u<bits>", where bits = w * 8
/// (w is presumably a byte count -- confirm against callers).
static string size(const int w) {
    const int bits = w * 8;
    return "u" + std::to_string(bits);
}

std::ostream& operator<<(std::ostream& os, ValidStore const& a) {
return os << a.mem << ".type != stack -> " << TypeConstraint{a.val, TypeGroup::number};
}
// ReSharper disable CppMemberFunctionMayBeConst
struct AssertionPrinterVisitor {
std::ostream& _os;

void operator()(ValidStore const& a) {
_os << a.mem << ".type != stack -> " << TypeConstraint{a.val, TypeGroup::number};
}

void operator()(ValidAccess const& a) {
if (a.or_null) {
_os << "(" << TypeConstraint{a.reg, TypeGroup::number} << " and " << a.reg << ".value == 0) or ";
}
_os << "valid_access(" << a.reg << ".offset";
if (a.offset > 0) {
_os << "+" << a.offset;
} else if (a.offset < 0) {
_os << a.offset;
}

std::ostream& operator<<(std::ostream& os, ValidAccess const& a) {
if (a.or_null) {
os << "(" << TypeConstraint{a.reg, TypeGroup::number} << " and " << a.reg << ".value == 0) or ";
}
os << "valid_access(" << a.reg << ".offset";
if (a.offset > 0) {
os << "+" << a.offset;
} else if (a.offset < 0) {
os << a.offset;
}

if (a.width == Value{Imm{0}}) {
// a.width == 0, meaning we only care it's an in-bound pointer,
// so it can be compared with another pointer to the same region.
os << ") for comparison/subtraction";
} else {
os << ", width=" << a.width << ") for ";
if (a.access_type == AccessType::read) {
os << "read";
if (a.width == Value{Imm{0}}) {
// a.width == 0, meaning we only care it's an in-bound pointer,
// so it can be compared with another pointer to the same region.
_os << ") for comparison/subtraction";
} else {
os << "write";
_os << ", width=" << a.width << ") for ";
if (a.access_type == AccessType::read) {
_os << "read";
} else {
_os << "write";
}
}
}
return os;
}

std::ostream& operator<<(std::ostream& os, const BoundedLoopCount& a) {
return os << crab::variable_t::loop_counter(to_string(a.name)) << " < " << a.limit;
}

static crab::variable_t typereg(const Reg& r) { return crab::variable_t::reg(crab::data_kind_t::types, r.v); }
void operator()(const BoundedLoopCount& a) {
_os << crab::variable_t::loop_counter(to_string(a.name)) << " < " << a.limit;
}

std::ostream& operator<<(std::ostream& os, ValidSize const& a) {
const auto op = a.can_be_zero ? " >= " : " > ";
return os << a.reg << ".value" << op << 0;
}
static crab::variable_t typereg(const Reg& r) { return crab::variable_t::reg(crab::data_kind_t::types, r.v); }

std::ostream& operator<<(std::ostream& os, ValidCall const& a) {
const EbpfHelperPrototype proto = global_program_info->platform->get_helper_prototype(a.func);
return os << "valid call(" << proto.name << ")";
}
void operator()(ValidSize const& a) {
const auto op = a.can_be_zero ? " >= " : " > ";
_os << a.reg << ".value" << op << 0;
}

std::ostream& operator<<(std::ostream& os, ValidMapKeyValue const& a) {
return os << "within stack(" << a.access_reg << ":" << (a.key ? "key_size" : "value_size") << "(" << a.map_fd_reg
<< "))";
}
void operator()(ValidCall const& a) {
const EbpfHelperPrototype proto = global_program_info->platform->get_helper_prototype(a.func);
_os << "valid call(" << proto.name << ")";
}

std::ostream& operator<<(std::ostream& os, ZeroCtxOffset const& a) {
return os << crab::variable_t::reg(crab::data_kind_t::ctx_offsets, a.reg.v) << " == 0";
}
void operator()(ValidMapKeyValue const& a) {
_os << "within stack(" << a.access_reg << ":" << (a.key ? "key_size" : "value_size") << "(" << a.map_fd_reg
<< "))";
}

std::ostream& operator<<(std::ostream& os, Comparable const& a) {
if (a.or_r2_is_number) {
os << TypeConstraint{a.r2, TypeGroup::number} << " or ";
void operator()(ZeroCtxOffset const& a) {
_os << crab::variable_t::reg(crab::data_kind_t::ctx_offsets, a.reg.v) << " == 0";
}
return os << typereg(a.r1) << " == " << typereg(a.r2) << " in " << TypeGroup::singleton_ptr;
}

std::ostream& operator<<(std::ostream& os, Addable const& a) {
return os << TypeConstraint{a.ptr, TypeGroup::pointer} << " -> " << TypeConstraint{a.num, TypeGroup::number};
}
void operator()(Comparable const& a) {
if (a.or_r2_is_number) {
_os << TypeConstraint{a.r2, TypeGroup::number} << " or ";
}
_os << typereg(a.r1) << " == " << typereg(a.r2) << " in " << TypeGroup::singleton_ptr;
}

std::ostream& operator<<(std::ostream& os, ValidDivisor const& a) { return os << a.reg << " != 0"; }
void operator()(Addable const& a) {
_os << TypeConstraint{a.ptr, TypeGroup::pointer} << " -> " << TypeConstraint{a.num, TypeGroup::number};
}

std::ostream& operator<<(std::ostream& os, TypeConstraint const& tc) {
const string cmp_op = is_singleton_type(tc.types) ? "==" : "in";
return os << typereg(tc.reg) << " " << cmp_op << " " << tc.types;
}
void operator()(ValidDivisor const& a) { _os << a.reg << " != 0"; }

std::ostream& operator<<(std::ostream& os, FuncConstraint const& fc) { return os << typereg(fc.reg) << " is helper"; }
void operator()(TypeConstraint const& tc) {
const string cmp_op = is_singleton_type(tc.types) ? "==" : "in";
_os << typereg(tc.reg) << " " << cmp_op << " " << tc.types;
}

std::ostream& operator<<(std::ostream& os, AssertionConstraint const& a) {
return std::visit([&](const auto& a) -> std::ostream& { return os << a; }, a);
}
void operator()(FuncConstraint const& fc) { _os << typereg(fc.reg) << " is helper"; }
};

// ReSharper disable CppMemberFunctionMayBeConst
struct InstructionPrinterVisitor {
struct CommandPrinterVisitor {
std::ostream& os_;

void visit(const auto& item) { std::visit(*this, item); }
Expand Down Expand Up @@ -259,7 +259,7 @@ struct InstructionPrinterVisitor {
void operator()(Exit const& b) { os_ << "exit"; }

void operator()(Jmp const& b) {
// A "standalone" jump instruction.
// A "standalone" jump Instruction.
// Print the label without offset calculations.
if (b.cond) {
os_ << "if ";
Expand Down Expand Up @@ -351,8 +351,6 @@ struct InstructionPrinterVisitor {
print(b.cond);
}

void operator()(Assert const& a) { os_ << "assert " << a.cst; }

void operator()(IncrementLoopCounter const& a) { os_ << crab::variable_t::loop_counter(to_string(a.name)) << "++"; }
};
// ReSharper restore CppMemberFunctionMayBeConst
Expand All @@ -364,7 +362,7 @@ string to_string(label_t const& label) {
}

std::ostream& operator<<(std::ostream& os, Instruction const& ins) {
std::visit(InstructionPrinterVisitor{os}, ins);
std::visit(CommandPrinterVisitor{os}, ins);
return os;
}

Expand All @@ -374,7 +372,12 @@ string to_string(Instruction const& ins) {
return str.str();
}

string to_string(AssertionConstraint const& constraint) {
std::ostream& operator<<(std::ostream& os, const Assertion& a) {
std::visit(AssertionPrinterVisitor{os}, a);
return os;
}

string to_string(Assertion const& constraint) {
std::stringstream str;
str << constraint;
return str.str();
Expand Down Expand Up @@ -407,10 +410,10 @@ void print(const InstructionSeq& insts, std::ostream& out, const std::optional<c
const auto pc_of_label = get_labels(insts);
pc_t pc = 0;
std::string previous_source;
InstructionPrinterVisitor visitor{out};
CommandPrinterVisitor visitor{out};
for (const LabeledInstruction& labeled_inst : insts) {
const auto& [label, ins, line_info] = labeled_inst;
if (!label_to_print.has_value() || (label == label_to_print)) {
if (!label_to_print.has_value() || label == label_to_print) {
if (line_info.has_value() && print_line_info) {
auto& [file, source, line, column] = line_info.value();
// Only decorate the first instruction associated with a source line.
Expand Down Expand Up @@ -469,7 +472,10 @@ void print_dot(const cfg_t& cfg, std::ostream& out) {

const auto& bb = cfg.get_node(label);
for (const auto& ins : bb) {
out << ins << "\\l";
for (const auto& pre : ins.preconditions) {
out << "assert " << pre << "\\l";
}
out << ins.cmd << "\\l";
}

out << "\"];\n";
Expand All @@ -492,7 +498,11 @@ void print_dot(const cfg_t& cfg, const std::string& outfile) {
std::ostream& operator<<(std::ostream& o, const basic_block_t& bb) {
o << bb.label() << ":\n";
for (const auto& s : bb) {
o << " " << s << ";\n";
for (const auto& pre : s.preconditions) {
o << " "
<< "assert " << pre << ";\n";
}
o << " " << s.cmd << ";\n";
}
auto [it, et] = bb.next_blocks();
if (it != et) {
Expand All @@ -515,7 +525,11 @@ std::ostream& operator<<(std::ostream& o, const basic_block_t& bb) {
std::ostream& operator<<(std::ostream& o, const crab::basic_block_rev_t& bb) {
o << bb.label() << ":\n";
for (const auto& s : bb) {
o << " " << s << ";\n";
for (const auto& pre : s.preconditions) {
o << " "
<< "assert " << pre << ";\n";
}
o << " " << s.cmd << ";\n";
}
o << "--> [";
for (const label_t& label : bb.next_blocks_set()) {
Expand Down
Loading

0 comments on commit f07d66a

Please sign in to comment.