diff --git a/src/bin/bench.rs b/src/bin/bench.rs index a82560e..37c2528 100644 --- a/src/bin/bench.rs +++ b/src/bin/bench.rs @@ -9,8 +9,9 @@ fn main() { let dt = 0.01; let eom = model::Lorenz63::default(); let teo = explicit::rk4(eom, dt); - let mut x = arr1(&[1.0, 0.0, 0.0]); + let mut buf = explicit::RK4Buffer::new_buffer(&teo); + let mut x: Array1 = arr1(&[1.0, 0.0, 0.0]); for _ in 0..100_000_000 { - teo.iterate(&mut x); + teo.iterate_buf(&mut x, &mut buf); } } diff --git a/src/explicit.rs b/src/explicit.rs index e69b92b..a34f974 100644 --- a/src/explicit.rs +++ b/src/explicit.rs @@ -140,3 +140,73 @@ impl TimeEvolutionBase for RK4 k4 } } + +pub struct RK4Buffer { + x: Array, + k1: Array, + k2: Array, + k3: Array, +} + +impl RK4Buffer + where A: Scalar, + D: Dimension +{ + pub fn new_buffer(t: &T) -> RK4Buffer + where T: ModelSize + { + RK4Buffer { + x: Array::zeros(t.model_size()), + k1: Array::zeros(t.model_size()), + k2: Array::zeros(t.model_size()), + k3: Array::zeros(t.model_size()), + } + } +} + +impl TimeEvolutionBufferedBase> for RK4 + where A: Scalar, + S: DataMut, + D: Dimension, + F: Explicit +{ + type Scalar = F::Scalar; + + fn iterate_buf<'a>(&self, + mut x: &'a mut ArrayBase, + mut buf: &mut RK4Buffer) + -> &'a mut ArrayBase { + let dt = self.dt; + let dt_2 = self.dt * into_scalar(0.5); + let dt_6 = self.dt / into_scalar(6.0); + buf.x.zip_mut_with(x, |buf, x| *buf = *x); + // k1 + let mut k1 = self.f.rhs(x); + buf.k1.zip_mut_with(k1, |buf, k1| *buf = *k1); + Zip::from(&mut *k1) + .and(&buf.x) + .apply(|k1, &x| { *k1 = k1.mul_real(dt_2) + x; }); + // k2 + let mut k2 = self.f.rhs(k1); + buf.k2.zip_mut_with(k2, |buf, k| *buf = *k); + Zip::from(&mut *k2) + .and(&buf.x) + .apply(|k2, &x| { *k2 = x + k2.mul_real(dt_2); }); + // k3 + let mut k3 = self.f.rhs(k2); + buf.k3.zip_mut_with(k3, |buf, k| *buf = *k); + Zip::from(&mut *k3) + .and(&buf.x) + .apply(|k3, &x| { *k3 = x + k3.mul_real(dt); }); + let mut k4 = self.f.rhs(k3); + Zip::from(&mut *k4) + .and(&buf.x) + .and(&buf.k1) + .and(&buf.k2) + .and(&buf.k3) + .apply(|k4, &x, &k1, &k2, &k3| { + *k4 = x + (k1 + (k2 + k3).mul_real(into_scalar(2.0)) + *k4).mul_real(dt_6); + }); + k4 + } +} diff --git a/src/traits.rs b/src/traits.rs index 218b455..162c2e4 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -45,7 +45,6 @@ pub trait TimeEvolutionBase: ModelSize + TimeStep fn iterate<'a>(&self, &'a mut ArrayBase) -> &'a mut ArrayBase; } - pub trait TimeEvolution : TimeEvolutionBase, D, Scalar = A, Time = A::Real> + TimeEvolutionBase, D, Scalar = A, Time = A::Real> @@ -54,3 +53,22 @@ pub trait TimeEvolution D: Dimension { } + +/// Time-evolution operator with buffer +pub trait TimeEvolutionBufferedBase: ModelSize + TimeStep + where S: DataMut, + D: Dimension +{ + type Scalar: Scalar; + /// calculate next step + fn iterate_buf<'a>(&self, &'a mut ArrayBase, &mut Buffer) -> &'a mut ArrayBase; +} + +pub trait TimeEvolutionBuffered + : TimeEvolutionBufferedBase, D, Buffer, Scalar = A, Time = A::Real> + + TimeEvolutionBufferedBase, D, Buffer, Scalar = A, Time = A::Real> + + for<'a> TimeEvolutionBufferedBase, D, Buffer, Scalar = A, Time = A::Real> + where A: Scalar, + D: Dimension +{ +}