Split discrete and continuous grid learning rates
- Return struct PatternMatch on match iteration
benruijl committed Jul 24, 2024
1 parent 48c5b99 commit 6623a94
Showing 9 changed files with 156 additions and 40 deletions.
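
Taken together, the two changes surface at call sites as follows. This is a minimal sketch distilled from the updated examples below, assuming a `grid` and a pattern-match iterator `it` set up as in those examples:

    // Grids now take separate learning rates for their discrete and
    // continuous parts (both 1.5 here, as in the examples).
    grid.update(1.5, 1.5);

    // Match iteration now yields a `PatternMatch` struct instead of a
    // `(location, used_flags, atom, match_stack)` tuple.
    while let Some(m) = it.next() {
        println!("match at {:?} ({:?}): {}", m.position, m.used_flags, m.target);
        // the wildcard identifications live in `m.match_stack`
    }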
examples/numerical_integration.rs (2 changes: 1 addition & 1 deletion)

@@ -36,7 +36,7 @@ fn main() {
             }
         }
 
-        grid.update(1.5);
+        grid.update(1.5, 1.5);
 
         println!(
            "Integral at iteration {:2}: {:.6} ± {:.6}",
examples/pattern_match.rs (9 changes: 6 additions & 3 deletions)

@@ -16,9 +16,12 @@ fn main() {
     println!("> Matching pattern {} to {}:", pat_expr, expr.as_view());
 
     let mut it = pattern.pattern_match(expr.as_view(), &conditions, &settings);
-    while let Some((location, used_flags, _atom, match_stack)) = it.next() {
-        println!("\t Match at location {:?} - {:?}:", location, used_flags);
-        for (id, v) in match_stack {
+    while let Some(m) = it.next() {
+        println!(
+            "\t Match at location {:?} - {:?}:",
+            m.position, m.used_flags
+        );
+        for (id, v) in m.match_stack {
             print!("\t\t{} = ", State::get_name(*id));
             match v {
                 Match::Single(s) => {
examples/pattern_restrictions.rs (6 changes: 3 additions & 3 deletions)

@@ -59,9 +59,9 @@ fn main() {
     );
 
     let mut it = pattern.pattern_match(expr.as_view(), &conditions, &settings);
-    while let Some((location, used_flags, _atom, match_stack)) = it.next() {
-        println!("\tMatch at location {:?} - {:?}:", location, used_flags);
-        for (id, v) in match_stack {
+    while let Some(m) = it.next() {
+        println!("\tMatch at location {:?} - {:?}:", m.position, m.used_flags);
+        for (id, v) in m.match_stack {
             print!("\t\t{} = ", State::get_name(*id));
             match v {
                 Match::Single(s) => {
examples/tree_replace.rs (6 changes: 3 additions & 3 deletions)

@@ -15,9 +15,9 @@ fn main() {
     println!("> Matching pattern {} to {}:", pat_expr, expr);
 
     let mut it = PatternAtomTreeIterator::new(&pattern, expr.as_view(), &restrictions, &settings);
-    while let Some((location, used_flags, _atom, match_stack)) = it.next() {
-        println!("\tMatch at location {:?} - {:?}:", location, used_flags);
-        for (id, v) in match_stack {
+    while let Some(m) = it.next() {
+        println!("\tMatch at location {:?} - {:?}:", m.position, m.used_flags);
+        for (id, v) in m.match_stack {
             print!("\t\t{} = ", State::get_name(*id));
             match v {
                 Match::Single(s) => {
src/api/python.rs (19 changes: 12 additions & 7 deletions)

@@ -4136,8 +4136,8 @@ impl PythonMatchIterator {
     /// Return the next match.
     fn __next__(&mut self) -> Option<HashMap<PythonExpression, PythonExpression>> {
         self.with_dependent_mut(|_, i| {
-            i.next().map(|(_, _, _, matches)| {
-                matches
+            i.next().map(|m| {
+                m.match_stack
                     .into_iter()
                     .map(|m| {
                         (Atom::new_var(m.0).into(), {
@@ -9293,7 +9293,7 @@ impl PythonNumericalIntegrator {
             .map_err(|e| pyo3::exceptions::PyAssertionError::new_err(e))
     }
 
-    /// Update the grid using the `learning_rate`.
+    /// Update the grid using the `discrete_learning_rate` and `continuous_learning_rate`.
     /// Examples
     /// --------
     /// >>> from symbolica import NumericalIntegrator, Sample
@@ -9309,10 +9309,15 @@
     /// >>> samples = integrator.sample(10000 + i * 1000)
     /// >>> res = integrand(samples)
     /// >>> integrator.add_training_samples(samples, res)
-    /// >>> avg, err, chi_sq = integrator.update(1.5)
+    /// >>> avg, err, chi_sq = integrator.update(1.5, 1.5)
     /// >>> print('Iteration {}: {:.6} +- {:.6}, chi={:.6}'.format(i+1, avg, err, chi_sq))
-    fn update(&mut self, learing_rate: f64) -> PyResult<(f64, f64, f64)> {
-        self.grid.update(learing_rate);
+    fn update(
+        &mut self,
+        discrete_learning_rate: f64,
+        continuous_learning_rate: f64,
+    ) -> PyResult<(f64, f64, f64)> {
+        self.grid
+            .update(discrete_learning_rate, continuous_learning_rate);
 
         let stats = self.grid.get_statistics();
         Ok((stats.avg, stats.err, stats.chi_sq / stats.cur_iter as f64))
@@ -9381,7 +9386,7 @@ impl PythonNumericalIntegrator {
             self.grid.add_training_sample(s, r).unwrap();
         }
 
-        self.grid.update(1.5);
+        self.grid.update(1.5, 1.5);
 
         let stats = self.grid.get_statistics();
         if show_stats {
src/coefficient.rs (47 changes: 42 additions & 5 deletions)

@@ -20,7 +20,7 @@ use crate::{
         integer::{Integer, IntegerRing, Z},
         rational::{Rational, Q},
         rational_polynomial::RationalPolynomial,
-        EuclideanDomain, Field, Ring,
+        EuclideanDomain, Field, InternalOrdering, Ring,
     },
     poly::{polynomial::MultivariatePolynomial, Variable, INLINED_EXPONENTS},
     state::{FiniteFieldIndex, State, Workspace},
@@ -75,6 +75,13 @@ impl From<(i64, i64)> for Coefficient {
     }
 }
 
+impl<'a> From<(i64, i64)> for CoefficientView<'a> {
+    #[inline]
+    fn from(r: (i64, i64)) -> Self {
+        CoefficientView::Natural(r.0, r.1)
+    }
+}
+
 impl From<Integer> for Coefficient {
     fn from(value: Integer) -> Self {
         Coefficient::Rational(value.into())
@@ -117,6 +124,33 @@ impl Default for Coefficient {
     }
 }
 
+impl PartialOrd for Coefficient {
+    fn partial_cmp(&self, other: &Coefficient) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl Ord for Coefficient {
+    fn cmp(&self, other: &Coefficient) -> Ordering {
+        match (self, other) {
+            (Coefficient::Rational(r1), Coefficient::Rational(r2)) => r1.cmp(r2),
+            (Coefficient::FiniteField(n1, _), Coefficient::FiniteField(n2, _)) => n1.0.cmp(&n2.0),
+            (Coefficient::Float(f1), Coefficient::Float(f2)) => {
+                f1.partial_cmp(&f2).unwrap_or(Ordering::Equal)
+            }
+            (Coefficient::RationalPolynomial(n1), Coefficient::RationalPolynomial(n2)) => {
+                n1.internal_cmp(&n2)
+            }
+            (Coefficient::Rational(_), _) => Ordering::Less,
+            (_, Coefficient::Rational(_)) => Ordering::Greater,
+            (Coefficient::Float(_), _) => Ordering::Less,
+            (_, Coefficient::Float(_)) => Ordering::Greater,
+            (Coefficient::FiniteField(_, _), _) => Ordering::Less,
+            (_, Coefficient::FiniteField(_, _)) => Ordering::Greater,
+        }
+    }
+}
+
 impl Coefficient {
     pub fn new() -> Coefficient {
         Coefficient::zero()
@@ -603,8 +637,8 @@ impl PartialOrd for CoefficientView<'_> {
 impl Ord for CoefficientView<'_> {
     fn cmp(&self, other: &CoefficientView) -> Ordering {
         match (self, other) {
-            (&CoefficientView::Natural(n1, d1), &CoefficientView::Natural(n2, d2)) => {
-                Rational::from_unchecked(n1, d1).cmp(&Rational::from_unchecked(n2, d2))
+            (CoefficientView::Natural(n1, d1), CoefficientView::Natural(n2, d2)) => {
+                Rational::from_unchecked(*n1, *d1).cmp(&Rational::from_unchecked(*n2, *d2))
             }
             (CoefficientView::Large(n1), CoefficientView::Large(n2)) => {
                 n1.to_rat().cmp(&n2.to_rat())
@@ -622,14 +656,17 @@
                 .to_float()
                 .partial_cmp(&f2.to_float())
                 .unwrap_or(Ordering::Equal),
+            (CoefficientView::RationalPolynomial(n1), CoefficientView::RationalPolynomial(n2)) => {
+                n1.deserialize().internal_cmp(&n2.deserialize())
+            }
             (CoefficientView::Natural(_, _), _) => Ordering::Less,
             (_, CoefficientView::Natural(_, _)) => Ordering::Greater,
             (CoefficientView::Large(_), _) => Ordering::Less,
             (_, CoefficientView::Large(_)) => Ordering::Greater,
             (CoefficientView::Float(_), _) => Ordering::Less,
             (_, CoefficientView::Float(_)) => Ordering::Greater,
             (CoefficientView::FiniteField(_, _), _) => Ordering::Less,
             (_, CoefficientView::FiniteField(_, _)) => Ordering::Greater,
+            (CoefficientView::RationalPolynomial(_), _) => Ordering::Less,
+            (_, CoefficientView::RationalPolynomial(_)) => Ordering::Greater,
         }
     }
 }
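The new `Ord` implementation gives `Coefficient` a total order: equal variants compare by value, and the cross-variant arms sort `Rational` before `Float`, `Float` before `FiniteField`, and `FiniteField` before `RationalPolynomial`. A small illustrative check (hypothetical test code; it relies only on the existing `From<(i64, i64)>` conversion for `Coefficient` shown above):

    use std::cmp::Ordering;

    // Two rational coefficients compare by value: 1/3 < 1/2.
    let half: Coefficient = (1, 2).into();
    let third: Coefficient = (1, 3).into();
    assert_eq!(third.cmp(&half), Ordering::Less);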
src/id.rs (87 changes: 79 additions & 8 deletions)

@@ -1399,13 +1399,54 @@ impl std::fmt::Debug for PatternRestriction {
     }
 }
 
+/// A part of an expression that was matched to a wildcard.
 #[derive(Clone, PartialEq)]
 pub enum Match<'a> {
+    /// A matched single atom.
     Single(AtomView<'a>),
+    /// A matched subexpression of atoms of the same type.
     Multiple(SliceType, Vec<AtomView<'a>>),
+    /// A matched function name.
     FunctionName(Symbol),
 }
 
+impl<'a> std::fmt::Display for Match<'a> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::Single(a) => a.fmt(f),
+            Self::Multiple(t, list) => match t {
+                SliceType::Add | SliceType::Mul | SliceType::Arg | SliceType::Pow => {
+                    f.write_str("(")?;
+                    for (i, a) in list.iter().enumerate() {
+                        if i > 0 {
+                            match t {
+                                SliceType::Add => {
+                                    f.write_str("+")?;
+                                }
+                                SliceType::Mul => {
+                                    f.write_str("*")?;
+                                }
+                                SliceType::Arg => {
+                                    f.write_str(",")?;
+                                }
+                                SliceType::Pow => {
+                                    f.write_str("^")?;
+                                }
+                                _ => unreachable!(),
+                            }
+                        }
+                        a.fmt(f)?;
+                    }
+                    f.write_str(")")
+                }
+                SliceType::One => list[0].fmt(f),
+                SliceType::Empty => f.write_str("()"),
+            },
+            Self::FunctionName(name) => name.fmt(f),
+        }
+    }
+}
+
 impl<'a> std::fmt::Debug for Match<'a> {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
@@ -1491,6 +1532,20 @@ pub struct MatchStack<'a, 'b> {
     settings: &'b MatchSettings,
 }
 
+impl<'a, 'b> std::fmt::Display for MatchStack<'a, 'b> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str("[")?;
+        for (i, (k, v)) in self.stack.iter().enumerate() {
+            if i > 0 {
+                f.write_str(", ")?;
+            }
+            f.write_fmt(format_args!("{}: {}", k, v))?;
+        }
+
+        f.write_str("]")
+    }
+}
+
 impl<'a, 'b> std::fmt::Debug for MatchStack<'a, 'b> {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("MatchStack")
@@ -2373,6 +2428,18 @@ pub struct PatternAtomTreeIterator<'a, 'b> {
     first_match: bool,
 }
 
+/// A part of an expression with its position that yields a match.
+pub struct PatternMatch<'a, 'b, 'c> {
+    /// The position (branch) of the match in the tree.
+    pub position: &'a [usize],
+    /// Flags which subexpressions are matched in case of matching a range.
+    pub used_flags: Vec<bool>,
+    /// The matched target.
+    pub target: AtomView<'b>,
+    /// The list of identifications of matched wildcards.
+    pub match_stack: &'a MatchStack<'b, 'c>,
+}
+
 impl<'a: 'b, 'b> PatternAtomTreeIterator<'a, 'b> {
     pub fn new(
         pattern: &'b Pattern,
@@ -2391,15 +2458,21 @@ impl<'a: 'b, 'b> PatternAtomTreeIterator<'a, 'b> {
         }
     }
 
-    pub fn next(&mut self) -> Option<(&[usize], Vec<bool>, AtomView<'a>, &MatchStack<'a, 'b>)> {
+    /// Generate the next match if it exists.
+    pub fn next(&mut self) -> Option<PatternMatch<'_, 'a, 'b>> {
        loop {
            if let Some(ct) = self.current_target {
                if let Some(it) = self.pattern_iter.as_mut() {
                    if let Some((_, used_flags)) = it.next(&mut self.match_stack) {
                        let a = used_flags.to_vec();
 
                        self.first_match = true;
-                        return Some((&self.tree_pos, a, ct, &self.match_stack));
+                        return Some(PatternMatch {
+                            position: &self.tree_pos,
+                            used_flags: a,
+                            target: ct,
+                            match_stack: &self.match_stack,
+                        });
                    } else {
                        // no match: bail
                        self.current_target = None;
@@ -2570,21 +2643,19 @@ impl<'a: 'b, 'b> ReplaceIterator<'a, 'b> {
 
     /// Return the next replacement.
     pub fn next(&mut self, out: &mut Atom) -> Option<()> {
-        if let Some((position, used_flags, _target, match_stack)) =
-            self.pattern_tree_iterator.next()
-        {
+        if let Some(pattern_match) = self.pattern_tree_iterator.next() {
             Workspace::get_local().with(|ws| {
                 let mut new_rhs = ws.new_atom();
 
                 self.rhs
-                    .substitute_wildcards(ws, &mut new_rhs, match_stack)
+                    .substitute_wildcards(ws, &mut new_rhs, pattern_match.match_stack)
                     .unwrap(); // TODO: escalate?
 
                 let mut h = ws.new_atom();
                 ReplaceIterator::copy_and_replace(
                     &mut h,
-                    position,
-                    &used_flags,
+                    pattern_match.position,
+                    &pattern_match.used_flags,
                     self.target,
                     new_rhs.as_view(),
                     ws,
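With `PatternMatch` and the new `Display` implementations in place, a consumer can print an entire match in one go. A minimal sketch, assuming `pattern`, `expr`, `restrictions`, and `settings` are built as in examples/tree_replace.rs above:

    let mut it =
        PatternAtomTreeIterator::new(&pattern, expr.as_view(), &restrictions, &settings);
    while let Some(m) = it.next() {
        // `MatchStack` now implements Display, so the wildcard
        // identifications print in one piece instead of needing a
        // manual loop over `(id, v)` pairs.
        println!("{:?} -> {}", m.position, m.match_stack);
    }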
src/numerical_integration.rs (14 changes: 7 additions & 7 deletions)

@@ -413,10 +413,10 @@ impl<T: Real + ConstructibleFloat + Copy + NumericalFloatComparison> Grid<T> {
     }
 
     /// Update the grid based on the samples added through [`Grid::add_training_sample`].
-    pub fn update(&mut self, learning_rate: T) {
+    pub fn update(&mut self, discrete_learning_rate: T, continuous_learning_rate: T) {
         match self {
-            Grid::Continuous(g) => g.update(learning_rate),
-            Grid::Discrete(g) => g.update(learning_rate),
+            Grid::Continuous(g) => g.update(continuous_learning_rate),
+            Grid::Discrete(g) => g.update(discrete_learning_rate, continuous_learning_rate),
         }
     }
 
@@ -527,11 +527,11 @@ impl<T: Real + ConstructibleFloat + Copy + NumericalFloatComparison> DiscreteGrid<T> {
     /// and adapt all sub-grids based on the new training samples.
     ///
     /// If `learning_rate` is set to 0, no training happens.
-    pub fn update(&mut self, learning_rate: T) {
+    pub fn update(&mut self, discrete_learning_rate: T, continuous_learning_rate: T) {
         let mut err_sum = T::new_zero();
         for bin in &mut self.bins {
             if let Some(sub_grid) = &mut bin.sub_grid {
-                sub_grid.update(learning_rate);
+                sub_grid.update(discrete_learning_rate, continuous_learning_rate);
             }
 
             let acc = &mut bin.accumulator;
@@ -542,7 +542,7 @@
             }
         }
 
-        if learning_rate.is_zero()
+        if discrete_learning_rate.is_zero()
             || self.bins.iter().all(|x| {
                 if self.train_on_avg {
                     x.accumulator.avg == T::new_zero()
@@ -1126,7 +1126,7 @@ mod test {
                 }
             }
 
-            grid.update(1.5);
+            grid.update(1.5, 1.5);
         }
 
         assert_eq!(grid.accumulator.avg, 0.9713543844460519);
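Note the update order in `DiscreteGrid::update` above: each bin's sub-grid is updated before the zero-rate check (`if discrete_learning_rate.is_zero() || ...`), so passing a zero discrete rate presumably freezes the discrete bin weights while nested continuous grids keep training. A sketch of that (assumed) usage:

    // Keep the discrete layer fixed; keep adapting the continuous sub-grids.
    grid.update(0.0, 1.5);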
symbolica.pyi (6 changes: 3 additions & 3 deletions)

@@ -2867,8 +2867,8 @@ class NumericalIntegrator:
         """Add the samples and their corresponding function evaluations to the grid.
         Call `update` after to update the grid and to obtain the new expected value for the integral."""
 
-    def update(self, learning_rate: float) -> Tuple[float, float, float]:
-        """Update the grid using the `learning_rate`.
+    def update(self, discrete_learning_rate: float, continuous_learning_rate: float) -> Tuple[float, float, float]:
+        """Update the grid using the `discrete_learning_rate` and `continuous_learning_rate`.
         Examples
         --------
         >>> from symbolica import NumericalIntegrator, Sample
@@ -2884,7 +2884,7 @@
         >>> samples = integrator.sample(10000 + i * 1000)
         >>> res = integrand(samples)
         >>> integrator.add_training_samples(samples, res)
-        >>> avg, err, chi_sq = integrator.update(1.5)
+        >>> avg, err, chi_sq = integrator.update(1.5, 1.5)
         >>> print('Iteration {}: {:.6} +- {:.6}, chi={:.6}'.format(i+1, avg, err, chi_sq))
         """
