Skip to content

Commit

Permalink
Create buffer once instead of allocating array
Browse files Browse the repository at this point in the history
  • Loading branch information
benruijl committed Aug 13, 2024
1 parent e7d4af1 commit 38b050f
Showing 1 changed file with 105 additions and 22 deletions.
127 changes: 105 additions & 22 deletions src/evaluate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1185,6 +1185,13 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
String::new()
};

res += &format!(
"extern \"C\" std::complex<double> *{}_create_buffer_complex()\n{{\n\treturn new std::complex<double>[{}];\n}}\n\n",
function_name,
self.stack.len()
);
res += &"extern \"C\" void drop_buffer_complex(std::complex<double> *buffer)\n{\n\tdelete[] buffer;\n}\n\n";

res += &format!(
"static const std::complex<double> CONSTANTS_complex[{}] = {{{}}};\n\n",
self.reserved_indices - self.param_count + 1,
Expand All @@ -1197,15 +1204,20 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
}
);

res += &format!("extern \"C\" void {}_complex(const std::complex<double> *params, std::complex<double> *out)\n{{\n", function_name);

// TODO: pass as argument to prevent stack reallocation
res += &format!("\tstd::complex<double> Z[{}];\n", self.stack.len());
res += &format!("extern \"C\" void {}_complex(const std::complex<double> *params, std::complex<double> *Z, std::complex<double> *out)\n{{\n", function_name);

self.export_asm_complex_impl(&self.instructions, &mut res);

res += "\treturn;\n}\n\n";

res += &format!(
"extern \"C\" double *{}_create_buffer_double()\n{{\n\treturn new double[{}];\n}}\n\n",
function_name,
self.stack.len()
);
res +=
&"extern \"C\" void drop_buffer_double(double *buffer)\n{\n\tdelete[] buffer;\n}\n\n";

res += &format!(
"static const double CONSTANTS_double[{}] = {{{}}};\n\n",
self.reserved_indices - self.param_count + 1,
Expand All @@ -1219,12 +1231,10 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
);

res += &format!(
"extern \"C\" void {}_double(const double *params, double *out)\n{{\n",
"extern \"C\" void {}_double(const double *params, double* Z, double *out)\n{{\n",
function_name
);

res += &format!("\tdouble Z[{}];\n", self.stack.len());

self.export_asm_double_impl(&self.instructions, &mut res);

res += "\treturn;\n}\n";
Expand Down Expand Up @@ -1867,7 +1877,7 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
\t\t\"mulpd xmm2, xmm2\\n\\t\"
\t\t\"haddpd xmm2, xmm2\\n\\t\"
\t\t\"divpd xmm0, xmm2\\n\\t\"
\t\t\"movupd XMMWORD {}, xmm0\\n\\t\"",
\t\t\"movupd XMMWORD {}, xmm0\\n\\t\"\n",
format_addr!(*b),
(self.reserved_indices - self.param_count) * 16,
format_addr!(*o)
Expand Down Expand Up @@ -3220,16 +3230,33 @@ type L = std::sync::Arc<libloading::Library>;

/// Function symbols resolved from a compiled evaluator library.
///
/// Each field is a `libloading::Symbol` borrowed from the library for `'a`.
/// NOTE(review): this listing shows both the removed and the added lines of the
/// commit, so `fn_name` and the two-pointer signatures appear next to their
/// replacements.
#[derive(Debug)]
struct EvaluatorFunctions<'a> {
// Moved to `CompiledEvaluator::fn_name` by this commit.
fn_name: String,
/// Double-precision entry point, looked up as `<function_name>_double`.
/// The replacement signature takes a scratch `buffer` between `params` and `out`.
eval_double: libloading::Symbol<'a, unsafe extern "C" fn(params: *const f64, out: *mut f64)>,
eval_double: libloading::Symbol<
'a,
unsafe extern "C" fn(params: *const f64, buffer: *mut f64, out: *mut f64),
>,
/// Complex entry point, looked up as `<function_name>_complex`.
/// The replacement signature takes a scratch `buffer` between `params` and `out`.
eval_complex: libloading::Symbol<
'a,
unsafe extern "C" fn(params: *const Complex<f64>, out: *mut Complex<f64>),
unsafe extern "C" fn(
params: *const Complex<f64>,
buffer: *mut Complex<f64>,
out: *mut Complex<f64>,
),
>,
/// Allocates the scratch buffer for `eval_double`; looked up as
/// `<function_name>_create_buffer_double`.
create_buffer_double: libloading::Symbol<'a, unsafe extern "C" fn() -> *mut f64>,
/// Allocates the scratch buffer for `eval_complex`; looked up as
/// `<function_name>_create_buffer_complex`.
create_buffer_complex: libloading::Symbol<'a, unsafe extern "C" fn() -> *mut Complex<f64>>,
/// Releases a buffer obtained from `create_buffer_double`. Looked up under the
/// fixed name `drop_buffer_double` (not prefixed with the function name).
drop_buffer_double: libloading::Symbol<'a, unsafe extern "C" fn(buffer: *mut f64)>,
/// Releases a buffer obtained from `create_buffer_complex`. Looked up under the
/// fixed name `drop_buffer_complex` (not prefixed with the function name).
drop_buffer_complex: libloading::Symbol<'a, unsafe extern "C" fn(buffer: *mut Complex<f64>)>,
}

/// A compiled evaluator: the loaded library with its resolved function symbols,
/// plus the two scratch buffers handed to the evaluation functions.
///
/// The buffers are created when the evaluator is constructed and released in
/// the `Drop` implementation.
///
/// NOTE(review): raw-pointer fields mean this type no longer implements
/// `Send`/`Sync` automatically; confirm that this is intended.
pub struct CompiledEvaluator {
/// Base name of the loaded function; `Clone` uses it to load the same function again.
fn_name: String,
/// `self_cell` wrapper pairing the shared library (owner) with the symbols resolved from it.
library: Library,
/// Scratch buffer returned by `create_buffer_double`, passed to `eval_double`
/// and released with `drop_buffer_double`.
buffer_double: *mut f64,
/// Scratch buffer returned by `create_buffer_complex`, passed to `eval_complex`
/// and released with `drop_buffer_complex`.
buffer_complex: *mut Complex<f64>,
}

self_cell!(
pub struct CompiledEvaluator {
struct Library {
owner: L,

#[covariant]
Expand All @@ -3241,8 +3268,7 @@ self_cell!(

impl Clone for CompiledEvaluator {
/// Clones by loading the same function again from the same library through
/// `load_new_function`, rather than copying fields.
///
/// # Panics
///
/// Panics (via `unwrap`) if `load_new_function` returns an error.
///
/// NOTE(review): both the removed call (reading `fn_name` through
/// `with_dependent`) and its replacement (reading `self.fn_name`) are shown.
fn clone(&self) -> Self {
self.load_new_function(&self.with_dependent(|_, d| &d.fn_name))
.unwrap()
self.load_new_function(&self.fn_name).unwrap()
}
}

Expand All @@ -3265,22 +3291,49 @@ impl CompiledEvaluatorFloat for Complex<f64> {
}
}

impl Drop for CompiledEvaluator {
    /// Hands both scratch buffers back to the deallocators exported by the
    /// loaded library.
    fn drop(&mut self) {
        // Look the symbol table up once; both deallocators live in it.
        let symbols = self.library.borrow_dependent();
        let (real_buffer, complex_buffer) = (self.buffer_double, self.buffer_complex);

        // SAFETY: in the constructors visible in this file the pointers are
        // produced by the same library's `create_buffer_*` functions, and they
        // are released here exactly once. The library is still loaded because
        // `self.library` is only dropped after this method returns.
        unsafe {
            (symbols.drop_buffer_double)(real_buffer);
            (symbols.drop_buffer_complex)(complex_buffer);
        }
    }
}

impl CompiledEvaluator {
/// Load a new function from the same library.
///
/// Resolves the evaluation and buffer-management symbols for `function_name`
/// from the library this evaluator already holds (the owner `Arc` is cloned,
/// not the library itself) and creates fresh scratch buffers for the result.
///
/// # Errors
///
/// Returns the `libloading` error message as a `String` if any of the
/// symbols cannot be found.
///
/// NOTE(review): this listing shows both the removed and the added lines of
/// the commit (e.g. the old `CompiledEvaluator::try_new` call next to the new
/// `Library::try_new` call).
pub fn load_new_function(&self, function_name: &str) -> Result<CompiledEvaluator, String> {
unsafe {
CompiledEvaluator::try_new(self.borrow_owner().clone(), |lib| {
let library = unsafe {
Library::try_new::<String>(self.library.borrow_owner().clone(), |lib| {
Ok(EvaluatorFunctions {
fn_name: function_name.to_string(),
// Evaluation entry points are exported as `<function_name>_double` / `_complex`.
eval_double: lib
.get(format!("{}_double", function_name).as_bytes())
.map_err(|e| e.to_string())?,
eval_complex: lib
.get(format!("{}_complex", function_name).as_bytes())
.map_err(|e| e.to_string())?,
// Buffer constructors are exported per function name.
create_buffer_double: lib
.get(format!("{}_create_buffer_double", function_name).as_bytes())
.map_err(|e| e.to_string())?,
create_buffer_complex: lib
.get(format!("{}_create_buffer_complex", function_name).as_bytes())
.map_err(|e| e.to_string())?,
// Buffer destructors use fixed names that do not include `function_name`.
drop_buffer_double: lib
.get("drop_buffer_double".as_bytes())
.map_err(|e| e.to_string())?,
drop_buffer_complex: lib
.get("drop_buffer_complex".as_bytes())
.map_err(|e| e.to_string())?,
})
})
}
}?;

// Allocate this evaluator's own scratch buffers with the freshly resolved constructors.
// NOTE(review): relies on the library exporting these symbols with the
// signatures declared in `EvaluatorFunctions`.
Ok(CompiledEvaluator {
fn_name: function_name.to_string(),
buffer_double: unsafe { (library.borrow_dependent().create_buffer_double)() },
buffer_complex: unsafe { (library.borrow_dependent().create_buffer_complex)() },
library,
})
}

/// Load a compiled evaluator from a shared library.
Expand All @@ -3293,16 +3346,34 @@ impl CompiledEvaluator {
}
};

CompiledEvaluator::try_new(std::sync::Arc::new(lib), |lib| {
let library = Library::try_new::<String>(std::sync::Arc::new(lib), |lib| {
Ok(EvaluatorFunctions {
fn_name: function_name.to_string(),
eval_double: lib
.get(format!("{}_double", function_name).as_bytes())
.map_err(|e| e.to_string())?,
eval_complex: lib
.get(format!("{}_complex", function_name).as_bytes())
.map_err(|e| e.to_string())?,
create_buffer_double: lib
.get(format!("{}_create_buffer_double", function_name).as_bytes())
.map_err(|e| e.to_string())?,
create_buffer_complex: lib
.get(format!("{}_create_buffer_complex", function_name).as_bytes())
.map_err(|e| e.to_string())?,
drop_buffer_double: lib
.get("drop_buffer_double".as_bytes())
.map_err(|e| e.to_string())?,
drop_buffer_complex: lib
.get("drop_buffer_complex".as_bytes())
.map_err(|e| e.to_string())?,
})
})?;

Ok(CompiledEvaluator {
fn_name: function_name.to_string(),
buffer_double: (library.borrow_dependent().create_buffer_double)(),
buffer_complex: (library.borrow_dependent().create_buffer_complex)(),
library,
})
}
}
Expand All @@ -3316,13 +3387,25 @@ impl CompiledEvaluator {
/// Evaluate the compiled code with double-precision floating point numbers.
///
/// Passes raw pointers to `args`, to the evaluator's `buffer_double` scratch
/// buffer and to `out` into the loaded `eval_double` function.
///
/// NOTE(review): the lengths of `args` and `out` are not checked here; the
/// compiled function presumably expects sizes matching the exported
/// expression — confirm at the call sites.
/// NOTE(review): the scratch buffer is handed out as `*mut` from `&self`, so
/// calls on the same evaluator share it.
/// NOTE(review): both the removed two-argument call and the added
/// three-argument call are shown below.
#[inline(always)]
pub fn evaluate_double(&self, args: &[f64], out: &mut [f64]) {
unsafe { (self.borrow_dependent().eval_double)(args.as_ptr(), out.as_mut_ptr()) }
unsafe {
(self.library.borrow_dependent().eval_double)(
args.as_ptr(),
self.buffer_double,
out.as_mut_ptr(),
)
}
}

/// Evaluate the compiled code with complex numbers.
///
/// Passes raw pointers to `args`, to the evaluator's `buffer_complex` scratch
/// buffer and to `out` into the loaded `eval_complex` function.
///
/// NOTE(review): the lengths of `args` and `out` are not checked here; the
/// compiled function presumably expects sizes matching the exported
/// expression — confirm at the call sites.
/// NOTE(review): the scratch buffer is handed out as `*mut` from `&self`, so
/// calls on the same evaluator share it.
/// NOTE(review): both the removed two-argument call and the added
/// three-argument call are shown below.
#[inline(always)]
pub fn evaluate_complex(&self, args: &[Complex<f64>], out: &mut [Complex<f64>]) {
unsafe { (self.borrow_dependent().eval_complex)(args.as_ptr(), out.as_mut_ptr()) }
unsafe {
(self.library.borrow_dependent().eval_complex)(
args.as_ptr(),
self.buffer_complex,
out.as_mut_ptr(),
)
}
}
}

Expand Down

0 comments on commit 38b050f

Please sign in to comment.