Skip to content

Commit

Permalink
critical.md
Browse files Browse the repository at this point in the history
  • Loading branch information
agarret7 committed Mar 5, 2024
1 parent 93f3cac commit cafa88e
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 12 deletions.
2 changes: 1 addition & 1 deletion gen-rs/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "gen-rs"
description = "a experimental library for probabilistic programming in Rust."
version = "0.2.2"
version = "0.2.3"
edition = "2021"
keywords = ["statistics", "ppl", "mcmc", "importance-sampling", "particle-filtering"]
categories = ["science", "simulation"]
Expand Down
1 change: 0 additions & 1 deletion gen-rs/src/gfi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ pub trait GenFn<Args,Data,Ret> {
constraints: Data // Data := forward choices
) -> (Trace<Args,Data,Ret>, Data, f64); // Data := backward choices


/// Call a generative function and return the output.
fn call(&self, args: Args) -> Ret {
self.simulate(args).retv.unwrap()
Expand Down
21 changes: 14 additions & 7 deletions gen-rs/src/modeling/triefn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ impl<A: 'static,T: 'static> TrieFnState<A,T> {

let has_previous = data.has_leaf_node(addr);
let constrained = constraints.has_leaf_node(addr);
let logp;
let mut prev_logp = 0.;
if has_previous {
let val = data.remove_leaf_node(addr).unwrap();
Expand All @@ -174,22 +175,25 @@ impl<A: 'static,T: 'static> TrieFnState<A,T> {
} else {
x = prev_x;
}
logp = dist.logpdf(x.as_ref(), args);
*weight += logp;
*weight -= prev_logp;
} else {
if constrained {
x = constraints.remove_leaf_node(addr).unwrap().downcast::<V>().ok().unwrap();
logp = dist.logpdf(x.as_ref(), args);
*weight += logp;
} else {
x = Rc::new(GLOBAL_RNG.with_borrow_mut(|rng| {
dist.random(rng, args.clone())
}));
logp = dist.logpdf(x.as_ref(), args);
}
}

let logp = dist.logpdf(x.as_ref(), args);
let d_logp = logp - prev_logp;
*weight += d_logp;

data.insert_leaf_node(addr, (x.clone(), logp));
trace.logp += d_logp;
trace.logp += logp;
trace.logp -= prev_logp;

x.as_ref().clone()
}
Expand Down Expand Up @@ -280,17 +284,21 @@ impl<A: 'static,T: 'static> TrieFnState<A,T> {
subtrie = subtrace.data;
retv = Rc::new(subtrace.retv.unwrap());
logp = new_weight;
*weight += new_weight;
} else {
dbg!(prev_subtrie.sum());
subtrie = prev_subtrie;
retv = data.remove_leaf_node(addr).unwrap().0.downcast::<Y>().ok().unwrap();
}
*weight += logp;
} else {
if constrained {
let subconstraints = Trie::from_unweighted(constraints.remove_internal_node(addr).unwrap());
let (subtrace, new_weight) = gen_fn.generate(args, subconstraints);
subtrie = subtrace.data;
retv = Rc::new(subtrace.retv.unwrap());
logp = new_weight;
*weight += logp;
} else {
let subtrace = gen_fn.simulate(args);
subtrie = subtrace.data;
Expand All @@ -299,7 +307,6 @@ impl<A: 'static,T: 'static> TrieFnState<A,T> {
}
}

*weight += logp;
data.insert_internal_node(addr, subtrie);
data.insert_leaf_node(addr, (retv.clone(), 0.));
trace.logp += logp;
Expand Down Expand Up @@ -350,7 +357,7 @@ impl<A: 'static,T: 'static> TrieFnState<A,T> {
let (data, garbage, garbage_weight) = Self::_gc(trace.data, &unvisited);
assert!(visitor.all_visited(&data)); // all unvisited nodes garbage-collected
Self::Update {
trace: Trace { args: trace.args, data, retv: trace.retv, logp: trace.logp },
trace: Trace { args: trace.args, data, retv: trace.retv, logp: trace.logp - garbage_weight },
constraints,
weight: weight - garbage_weight,
discard: discard.merge(garbage),
Expand Down
92 changes: 89 additions & 3 deletions gen-rs/tests/test_triefn.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{any::Any, rc::Rc};
use gen_rs::{Trie, GenFn, TrieFn, TrieFnState, normal};
use gen_rs::{bernoulli, normal, uniform, GenFn, Trie, TrieFn, TrieFnState};


pub fn _triefn_prototype(state: &mut TrieFnState<f64,f64>,noise: f64) -> f64 {
Expand All @@ -10,12 +10,11 @@ pub fn _triefn_prototype(state: &mut TrieFnState<f64,f64>,noise: f64) -> f64 {
}
sum
}
const triefn_prototype: TrieFn<f64,f64> = TrieFn { func: _triefn_prototype };

#[test]
pub fn test_triefn_prototype() {
for _ in (0..100).into_iter() {
let triefn_prototype = TrieFn::new(_triefn_prototype);

let _trace = triefn_prototype.simulate(1.);
let mut constraints = Trie::<Rc<dyn Any>>::new();
constraints.insert_leaf_node("1", Rc::new(100.));
Expand All @@ -25,4 +24,91 @@ pub fn test_triefn_prototype() {
dbg!(trace.logp);
dbg!(weight);
}
}


pub fn _triefn_sample_at_update_weight_regression(state: &mut TrieFnState<(),()>,_: ()) {
let b = state.sample_at(&bernoulli, 0.25, "b");
if b {
state.sample_at(&normal, (0., 1.), "x");
}
}
const triefn_sample_at_update_weight_regression: TrieFn<(),()> = TrieFn { func: _triefn_sample_at_update_weight_regression };

pub fn _triefn_trace_at_update_weight_regression(state: &mut TrieFnState<(),()>,_: ()) {
let b = state.sample_at(&bernoulli, 0.25, "b");
if b {
println!("tracing at!");
state.trace_at(&triefn_prototype, 1.0, "sub");
println!("shoulda seen ")
}
}
const triefn_trace_at_update_weight_regression: TrieFn<(),()> = TrieFn { func: _triefn_trace_at_update_weight_regression };

pub fn _triefn_sample_at_update_weight_regression2(state: &mut TrieFnState<(),()>,_: ()) {
let m = state.sample_at(&uniform, (0.,1.), "m");
state.sample_at(&normal, (m, 1.), "x");
state.sample_at(&normal, (m, 1.), "y");
}
const triefn_sample_at_update_weight_regression2: TrieFn<(),()> = TrieFn { func: _triefn_sample_at_update_weight_regression2 };

#[test]
pub fn test_sample_at_update_prev_and_constrained() {
// sample_at
let mut constraints = Trie::<Rc<dyn Any>>::new();
constraints.insert_leaf_node("b", Rc::new(true));
constraints.insert_leaf_node("x", Rc::new(0.0));
let tr = triefn_sample_at_update_weight_regression.generate((), Trie::from_unweighted(constraints)).0;
let mut constraints = Trie::<Rc<dyn Any>>::new();
constraints.insert_leaf_node("x", Rc::new(1.0));
let w = triefn_sample_at_update_weight_regression.update(tr, (), gen_rs::GfDiff::Unknown, Trie::from_unweighted(constraints)).2;
assert_eq!(w, -0.5);
}

#[test]
pub fn test_sample_at_update_no_prev_and_constrained() {
// sample_at
let mut constraints = Trie::<Rc<dyn Any>>::new();
constraints.insert_leaf_node("b", Rc::new(false));
let tr = triefn_sample_at_update_weight_regression.generate((), Trie::from_unweighted(constraints)).0;
let mut constraints = Trie::<Rc<dyn Any>>::new();
constraints.insert_leaf_node("b", Rc::new(true));
constraints.insert_leaf_node("x", Rc::new(1.0));
let w = triefn_sample_at_update_weight_regression.update(tr, (), gen_rs::GfDiff::Unknown, Trie::from_unweighted(constraints)).2;
approx::assert_abs_diff_eq!(w, -2.517551, epsilon = 1e-6);
}

#[test]
pub fn test_update_sample_at_prev_and_unconstrained() {
// sample_at
let mut constraints = Trie::<Rc<dyn Any>>::new();
constraints.insert_leaf_node("m", Rc::new(1.0));
constraints.insert_leaf_node("x", Rc::new(1.0));
constraints.insert_leaf_node("y", Rc::new(-0.3));
let tr = triefn_sample_at_update_weight_regression2.generate((), Trie::from_unweighted(constraints)).0;
let mut constraints = Trie::<Rc<dyn Any>>::new();
constraints.insert_leaf_node("m", Rc::new(0.5));
let w = triefn_sample_at_update_weight_regression2.update(tr, (), gen_rs::GfDiff::Unknown, Trie::from_unweighted(constraints)).2;
approx::assert_abs_diff_eq!(w, 0.4000000, epsilon = 1e-6);
}

#[test]
pub fn test_update_no_prev_and_unconstrained() {
// sample_at
let mut constraints = Trie::<Rc<dyn Any>>::new();
constraints.insert_leaf_node("b", Rc::new(false));
let tr = triefn_sample_at_update_weight_regression.generate((), Trie::from_unweighted(constraints)).0;
let mut constraints = Trie::<Rc<dyn Any>>::new();
constraints.insert_leaf_node("b", Rc::new(true));
let w = triefn_sample_at_update_weight_regression.update(tr, (), gen_rs::GfDiff::Unknown, Trie::from_unweighted(constraints)).2;
approx::assert_abs_diff_eq!(w, -1.098612, epsilon = 1e-6);

// trace_at
let mut constraints = Trie::<Rc<dyn Any>>::new();
constraints.insert_leaf_node("b", Rc::new(false));
let tr = triefn_trace_at_update_weight_regression.generate((), Trie::from_unweighted(constraints)).0;
let mut constraints = Trie::<Rc<dyn Any>>::new();
constraints.insert_leaf_node("b", Rc::new(true));
let w = triefn_trace_at_update_weight_regression.update(tr, (), gen_rs::GfDiff::Unknown, Trie::from_unweighted(constraints)).2;
approx::assert_abs_diff_eq!(w, -1.098612, epsilon = 1e-6);
}

0 comments on commit cafa88e

Please sign in to comment.