Skip to content

Commit

Permalink
Added AdaDelta algorithm.
Browse files Browse the repository at this point in the history
  • Loading branch information
geosarr committed Jul 26, 2024
1 parent 8671135 commit 8949f05
Show file tree
Hide file tree
Showing 13 changed files with 546 additions and 196 deletions.
40 changes: 37 additions & 3 deletions benches/steepest_descent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ extern crate test;
extern crate tuutal;

use test::Bencher;
use tuutal::{s, steepest_descent, Array, VecType};
use tuutal::{s, steepest_descent, Array, SteepestDescentParameter, VecType};

fn rosenbrock_nd() -> (
impl Fn(&VecType<f32>) -> f32,
Expand Down Expand Up @@ -36,11 +36,45 @@ fn rosenbrock_nd() -> (
}

#[bench]
fn steepest_descent_bench(bench: &mut Bencher) {
fn armijo_bench(bench: &mut Bencher) {
let (f, gradf) = rosenbrock_nd();
static LENGTH: usize = 500;
let x0 = Array::from_vec(vec![0_f32; LENGTH]);
let params = SteepestDescentParameter::new_armijo(0.01, 0.01);
bench.iter(|| {
let _solution = steepest_descent(&f, &gradf, &x0, &Default::default(), 1e-6, 1000);
let _solution = steepest_descent(&f, &gradf, &x0, &params, 1e-6, 1000);
});
}

#[bench]
fn powell_wolfe_bench(bench: &mut Bencher) {
let (f, gradf) = rosenbrock_nd();
static LENGTH: usize = 500;
let x0 = Array::from_vec(vec![0_f32; LENGTH]);
let params = SteepestDescentParameter::new_powell_wolfe(0.01, 0.1);
bench.iter(|| {
let _solution = steepest_descent(&f, &gradf, &x0, &params, 1e-6, 1000);
});
}

#[bench]
fn adagrad_bench(bench: &mut Bencher) {
let (f, gradf) = rosenbrock_nd();
static LENGTH: usize = 500;
let x0 = Array::from_vec(vec![0_f32; LENGTH]);
let params = SteepestDescentParameter::new_adagrad(0.1, 0.0001);
bench.iter(|| {
let _solution = steepest_descent(&f, &gradf, &x0, &params, 1e-6, 1000);
});
}

#[bench]
fn adadelta_bench(bench: &mut Bencher) {
let (f, gradf) = rosenbrock_nd();
static LENGTH: usize = 500;
let x0 = Array::from_vec(vec![0_f32; LENGTH]);
let params = SteepestDescentParameter::new_adadelta(0.1, 0.0001);
bench.iter(|| {
let _solution = steepest_descent(&f, &gradf, &x0, &params, 1e-6, 1000);
});
}
112 changes: 112 additions & 0 deletions py-tuutal/src/first_order.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
use crate::{wrap_vec_func_scalar, wrap_vec_func_vec};
use numpy::{IntoPyArray, PyArray1, PyReadonlyArray1, PyReadonlyArray2};
use pyo3::{conversion::FromPyObjectBound, exceptions::PyRuntimeError, intern, prelude::*};
use tuutal::{steepest_descent, SteepestDescentParameter, TuutalError, VecType};

macro_rules! first_order_method {
($method:ident, $name:ident) => {
#[pyfunction]
pub fn $method<'py>(
py: Python<'py>,
f: PyObject,
g: PyObject,
x0: PyReadonlyArray1<f64>,
gamma: f64,
beta: f64,
gtol: f64,
maxiter: Option<usize>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
match steepest_descent(
wrap_vec_func_scalar!(py, f),
wrap_vec_func_vec!(py, g),
&x0.as_array().to_owned(),
&SteepestDescentParameter::$name(gamma, beta),
gtol,
maxiter.unwrap_or(x0.len().unwrap() * 1000),
) {
Ok(value) => Ok(value.into_pyarray_bound(py)),
Err(error) => match error {
TuutalError::Convergence {
iterate: x,
maxiter: _,
} => {
println!("Maximum number of iterations reached before convergence");
Ok(x.into_pyarray_bound(py))
}
err => Err(PyRuntimeError::new_err(err.to_string())), // Should never come this far.
},
}
}
};
}

first_order_method!(armijo, new_armijo);
first_order_method!(powell_wolfe, new_powell_wolfe);
first_order_method!(adagrad, new_adagrad);
first_order_method!(adadelta, new_adadelta);

// #[pyfunction]
// pub fn armijo<'py>(
// py: Python<'py>,
// f: PyObject,
// g: PyObject,
// x0: PyReadonlyArray1<f64>,
// gamma: f64,
// beta: f64,
// gtol: f64,
// maxiter: Option<usize>,
// ) -> PyResult<Bound<'py, PyArray1<f64>>> {
// match steepest_descent(
// wrap_vec_func_scalar!(py, f),
// wrap_vec_func_vec!(py, g),
// &x0.as_array().to_owned(),
// &SteepestDescentParameter::new_armijo(gamma, beta),
// gtol,
// maxiter.unwrap_or(x0.len().unwrap() * 1000),
// ) {
// Ok(value) => Ok(value.into_pyarray_bound(py)),
// Err(error) => match error {
// TuutalError::Convergence {
// iterate: x,
// maxiter: _,
// } => {
// println!("Maximum number of iterations reached before convergence");
// Ok(x.into_pyarray_bound(py))
// }
// err => Err(PyRuntimeError::new_err(err.to_string())), // Should never come this far.
// },
// }
// }

// #[pyfunction]
// pub fn adagrad<'py>(
// py: Python<'py>,
// f: PyObject,
// g: PyObject,
// x0: PyReadonlyArray1<f64>,
// gamma: f64,
// beta: f64,
// gtol: f64,
// maxiter: Option<usize>,
// ) -> PyResult<Bound<'py, PyArray1<f64>>> {
// match steepest_descent(
// wrap_vec_func_scalar!(py, f),
// wrap_vec_func_vec!(py, g),
// &x0.as_array().to_owned(),
// &SteepestDescentParameter::new_adagr(gamma, beta),
// gtol,
// maxiter.unwrap_or(x0.len().unwrap() * 1000),
// ) {
// Ok(value) => Ok(value.into_pyarray_bound(py)),
// Err(error) => match error {
// TuutalError::Convergence {
// iterate: x,
// maxiter: _,
// } => {
// println!("Maximum number of iterations reached before convergence");
// Ok(x.into_pyarray_bound(py))
// }
// err => Err(PyRuntimeError::new_err(err.to_string())), // Should never come this far.
// },
// }
// }
41 changes: 41 additions & 0 deletions py-tuutal/src/interface.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
macro_rules! wrap_scalar_func_scalar {
($py:expr, $py_func:expr) => {
|x: f64| {
$py_func
.call1($py, (x,))
.expect("python objective function failed.")
.extract::<f64>($py)
.expect("python function should return a float-pointing number")
}
};
}

macro_rules! wrap_vec_func_scalar {
($py:expr, $py_func:expr) => {
|x: &VecType<f64>| {
$py_func
.call1($py, (x.clone().into_pyarray_bound($py),))
.expect("python objective function failed.")
.extract::<f64>($py)
.expect("python function should return a float-pointing number")
}
};
}

macro_rules! wrap_vec_func_vec {
($py:expr, $py_func:expr) => {
|x: &VecType<f64>| {
$py_func
.call1($py, (x.clone().into_pyarray_bound($py),))
.expect("python objective function failed.")
.extract::<PyReadonlyArray1<f64>>($py)
.expect("python function should return a python numpy.ndarray")
.as_array()
.to_owned()
}
};
}

pub(crate) use wrap_scalar_func_scalar;
pub(crate) use wrap_vec_func_scalar;
pub(crate) use wrap_vec_func_vec;
8 changes: 8 additions & 0 deletions py-tuutal/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
mod first_order;
mod interface;
mod zero_order;
pub use first_order::{adadelta, adagrad, armijo, powell_wolfe};
pub(crate) use interface::{wrap_scalar_func_scalar, wrap_vec_func_scalar, wrap_vec_func_vec};
use pyo3::prelude::*;
pub use zero_order::{brent_bounded, brent_root, brent_unbounded, brentq, nelder_mead};

Expand All @@ -9,5 +13,9 @@ fn tuutal(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(brent_bounded, m)?)?;
m.add_function(wrap_pyfunction!(brent_unbounded, m)?)?;
m.add_function(wrap_pyfunction!(nelder_mead, m)?)?;
m.add_function(wrap_pyfunction!(armijo, m)?)?;
m.add_function(wrap_pyfunction!(adadelta, m)?)?;
m.add_function(wrap_pyfunction!(adagrad, m)?)?;
m.add_function(wrap_pyfunction!(powell_wolfe, m)?)?;
Ok(())
}
Loading

0 comments on commit 8949f05

Please sign in to comment.