Skip to content

Commit

Permalink
Keep function in builder struct
Browse files Browse the repository at this point in the history
Signed-off-by: Sean Young <[email protected]>
  • Loading branch information
seanyoung committed Apr 8, 2024
1 parent ce56539 commit e975c56
Showing 1 changed file with 43 additions and 25 deletions.
68 changes: 43 additions & 25 deletions irp/src/build_bpf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use inkwell::{
module::Module,
targets::{CodeModel, FileType, RelocMode, Target, TargetTriple},
types::{BasicType, StructType},
values::{BasicValue, GlobalValue, IntValue, PointerValue},
values::{BasicValue, FunctionValue, GlobalValue, IntValue, PointerValue},
AddressSpace, IntPredicate, OptimizationLevel,
};
use once_cell::sync::OnceCell;
Expand All @@ -36,12 +36,14 @@ impl DFA {
let (map, decoder_state_ty) = define_map_def(&module, &vars, &context);
define_license(&module, &context);

let function = define_function(&module, &context, options.name);
let builder = context.create_builder();

let mut builder = Builder {
dfa: self,
options,
module,
function,
builder,
vars,
decoder_state_ty,
Expand All @@ -51,7 +53,7 @@ impl DFA {
.const_null(),
};

builder.define_function(map, &context);
builder.define_function_body(map, &context);

if let Some(path) = options.llvm_ir {
builder.module.print_to_file(path).unwrap();
Expand Down Expand Up @@ -122,6 +124,7 @@ struct Builder<'a> {
options: &'a Options<'a>,
dfa: &'a DFA,
module: Module<'a>,
function: FunctionValue<'a>,
builder: builder::Builder<'a>,
decoder_state_ty: StructType<'a>,
decoder_state: PointerValue<'a>,
Expand All @@ -136,19 +139,13 @@ struct VarValue<'a> {
}

impl<'a> Builder<'a> {
fn define_function(&mut self, map_def: GlobalValue<'a>, context: &'a Context) {
fn define_function_body(&mut self, map_def: GlobalValue<'a>, context: &'a Context) {
let i32 = context.i32_type();
let i64 = context.i64_type();
let i32_ptr = context.i32_type().ptr_type(AddressSpace::default());
let i64_ptr = context.i64_type().ptr_type(AddressSpace::default());

let fn_type = i32.fn_type(&[i32_ptr.into()], false);

let function = self.module.add_function("bpf_decoder", fn_type, None);

function.set_section(Some(&format!("lirc_mode2/{}", self.options.name)));

let entry = context.append_basic_block(function, "entry");
let entry = context.append_basic_block(self.function, "entry");
self.builder.position_at_end(entry);

// get map entry 0 (which we use as the decoder state)
Expand All @@ -173,8 +170,8 @@ impl<'a> Builder<'a> {
.unwrap()
.into_pointer_value();

let non_zero_key = context.append_basic_block(function, "non_zero_key");
let zero_key = context.append_basic_block(function, "zero_key");
let non_zero_key = context.append_basic_block(self.function, "non_zero_key");
let zero_key = context.append_basic_block(self.function, "zero_key");

let res = self
.builder
Expand Down Expand Up @@ -205,7 +202,10 @@ impl<'a> Builder<'a> {
.builder
.build_load(
i32,
function.get_first_param().unwrap().into_pointer_value(),
self.function
.get_first_param()
.unwrap()
.into_pointer_value(),
"lirc_mode2",
)
.unwrap()
Expand All @@ -216,8 +216,8 @@ impl<'a> Builder<'a> {
.build_right_shift(lirc_mode2, i32.const_int(24, false), false, "lirc_mode2_ty")
.unwrap();

let lirc_ok = context.append_basic_block(function, "lirc_ok");
let error = context.append_basic_block(function, "error");
let lirc_ok = context.append_basic_block(self.function, "lirc_ok");
let error = context.append_basic_block(self.function, "error");

self.builder
.build_switch(
Expand Down Expand Up @@ -278,29 +278,29 @@ impl<'a> Builder<'a> {

// we will add a switch statement at the end of lirc_ok block once we have built all the cases
for (state_no, vert) in self.dfa.verts.iter().enumerate() {
let block = context.append_basic_block(function, &format!("state_{state_no}"));
let block = context.append_basic_block(self.function, &format!("state_{state_no}"));
self.builder.position_at_end(block);

cases.push((i64.const_int(state_no as u64, false), block));

for edge in &vert.edges {
let next = context.append_basic_block(function, "next");
let next = context.append_basic_block(self.function, "next");

for action in &edge.actions {
match action {
Action::Flash {
length: Length::Range(min, max),
..
} => {
let ok = context.append_basic_block(function, "ok");
let ok = context.append_basic_block(self.function, "ok");

self.builder
.build_conditional_branch(is_pulse, ok, next)
.unwrap();

self.builder.position_at_end(ok);

let ok = context.append_basic_block(function, "ok");
let ok = context.append_basic_block(self.function, "ok");

let res = self
.builder
Expand All @@ -319,7 +319,7 @@ impl<'a> Builder<'a> {
self.builder.position_at_end(ok);

if let Some(max) = max {
let ok = context.append_basic_block(function, "ok");
let ok = context.append_basic_block(self.function, "ok");

let res = self
.builder
Expand All @@ -342,15 +342,15 @@ impl<'a> Builder<'a> {
length: Length::Range(min, max),
..
} => {
let ok = context.append_basic_block(function, "ok");
let ok = context.append_basic_block(self.function, "ok");

self.builder
.build_conditional_branch(is_pulse, next, ok)
.unwrap();

self.builder.position_at_end(ok);

let ok = context.append_basic_block(function, "ok");
let ok = context.append_basic_block(self.function, "ok");

let res = self
.builder
Expand All @@ -369,7 +369,7 @@ impl<'a> Builder<'a> {
self.builder.position_at_end(ok);

if let Some(max) = max {
let ok = context.append_basic_block(function, "ok");
let ok = context.append_basic_block(self.function, "ok");

let res = self
.builder
Expand All @@ -396,7 +396,7 @@ impl<'a> Builder<'a> {
let left = self.expression(left, context);
let right = self.expression(right, context);

let ok = context.append_basic_block(function, "ok");
let ok = context.append_basic_block(self.function, "ok");

let res = self
.builder
Expand Down Expand Up @@ -751,8 +751,9 @@ fn define_map_def<'ctx>(
i32.const_int(1, false).into(),
// map_flags
i32.const_zero().into(),
// padding
// id
i32.const_zero().into(),
// pinning type
i32.const_zero().into(),
]);

Expand All @@ -777,3 +778,20 @@ fn define_license<'ctx>(module: &Module<'ctx>, context: &'ctx Context) {
gv.set_initializer(&context.const_string(b"GPL", true));
gv.set_section(Some("license"));
}

fn define_function<'ctx>(
module: &Module<'ctx>,
context: &'ctx Context,
name: &'ctx str,
) -> FunctionValue<'ctx> {
let i32 = context.i32_type();
let i32_ptr = context.i32_type().ptr_type(AddressSpace::default());

let fn_type = i32.fn_type(&[i32_ptr.into()], false);

let function = module.add_function("bpf_decoder", fn_type, None);

function.set_section(Some(&format!("lirc_mode2/{}", name)));

function
}

0 comments on commit e975c56

Please sign in to comment.