From cafa88e5ae17dbf074b45a4fe616b407e63fe0a6 Mon Sep 17 00:00:00 2001 From: Austin Garrett Date: Mon, 4 Mar 2024 19:43:35 -0700 Subject: [PATCH] critical.md --- gen-rs/Cargo.toml | 2 +- gen-rs/src/gfi.rs | 1 - gen-rs/src/modeling/triefn.rs | 21 +++++--- gen-rs/tests/test_triefn.rs | 92 +++++++++++++++++++++++++++++++++-- 4 files changed, 104 insertions(+), 12 deletions(-) diff --git a/gen-rs/Cargo.toml b/gen-rs/Cargo.toml index 71d7917..c902263 100644 --- a/gen-rs/Cargo.toml +++ b/gen-rs/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "gen-rs" description = "a experimental library for probabilistic programming in Rust." -version = "0.2.2" +version = "0.2.3" edition = "2021" keywords = ["statistics", "ppl", "mcmc", "importance-sampling", "particle-filtering"] categories = ["science", "simulation"] diff --git a/gen-rs/src/gfi.rs b/gen-rs/src/gfi.rs index 2a08bd2..f463e24 100644 --- a/gen-rs/src/gfi.rs +++ b/gen-rs/src/gfi.rs @@ -62,7 +62,6 @@ pub trait GenFn { constraints: Data // Data := forward choices ) -> (Trace, Data, f64); // Data := backward choices - /// Call a generative function and return the output. fn call(&self, args: Args) -> Ret { self.simulate(args).retv.unwrap() diff --git a/gen-rs/src/modeling/triefn.rs b/gen-rs/src/modeling/triefn.rs index f6980fb..ce50849 100644 --- a/gen-rs/src/modeling/triefn.rs +++ b/gen-rs/src/modeling/triefn.rs @@ -163,6 +163,7 @@ impl TrieFnState { let has_previous = data.has_leaf_node(addr); let constrained = constraints.has_leaf_node(addr); + let logp; let mut prev_logp = 0.; if has_previous { let val = data.remove_leaf_node(addr).unwrap(); @@ -174,22 +175,25 @@ impl TrieFnState { } else { x = prev_x; } + logp = dist.logpdf(x.as_ref(), args); + *weight += logp; + *weight -= prev_logp; } else { if constrained { x = constraints.remove_leaf_node(addr).unwrap().downcast::().ok().unwrap(); + logp = dist.logpdf(x.as_ref(), args); + *weight += logp; } else { x = Rc::new(GLOBAL_RNG.with_borrow_mut(|rng| { dist.random(rng, args.clone()) })); + logp = dist.logpdf(x.as_ref(), args); } } - let logp = dist.logpdf(x.as_ref(), args); - let d_logp = logp - prev_logp; - *weight += d_logp; - data.insert_leaf_node(addr, (x.clone(), logp)); - trace.logp += d_logp; + trace.logp += logp; + trace.logp -= prev_logp; x.as_ref().clone() } @@ -280,10 +284,13 @@ impl TrieFnState { subtrie = subtrace.data; retv = Rc::new(subtrace.retv.unwrap()); logp = new_weight; + *weight += new_weight; } else { + dbg!(prev_subtrie.sum()); subtrie = prev_subtrie; retv = data.remove_leaf_node(addr).unwrap().0.downcast::().ok().unwrap(); } + *weight += logp; } else { if constrained { let subconstraints = Trie::from_unweighted(constraints.remove_internal_node(addr).unwrap()); @@ -291,6 +298,7 @@ impl TrieFnState { subtrie = subtrace.data; retv = Rc::new(subtrace.retv.unwrap()); logp = new_weight; + *weight += logp; } else { let subtrace = gen_fn.simulate(args); subtrie = subtrace.data; @@ -299,7 +307,6 @@ impl TrieFnState { } } - *weight += logp; data.insert_internal_node(addr, subtrie); data.insert_leaf_node(addr, (retv.clone(), 0.)); trace.logp += logp; @@ -350,7 +357,7 @@ impl TrieFnState { let (data, garbage, garbage_weight) = Self::_gc(trace.data, &unvisited); assert!(visitor.all_visited(&data)); // all unvisited nodes garbage-collected Self::Update { - trace: Trace { args: trace.args, data, retv: trace.retv, logp: trace.logp }, + trace: Trace { args: trace.args, data, retv: trace.retv, logp: trace.logp - garbage_weight }, constraints, weight: weight - garbage_weight, discard: discard.merge(garbage), diff --git a/gen-rs/tests/test_triefn.rs b/gen-rs/tests/test_triefn.rs index 5156d16..eeed858 100644 --- a/gen-rs/tests/test_triefn.rs +++ b/gen-rs/tests/test_triefn.rs @@ -1,5 +1,5 @@ use std::{any::Any, rc::Rc}; -use gen_rs::{Trie, GenFn, TrieFn, TrieFnState, normal}; +use gen_rs::{bernoulli, normal, uniform, GenFn, Trie, TrieFn, TrieFnState}; pub fn _triefn_prototype(state: &mut TrieFnState,noise: f64) -> f64 { @@ -10,12 +10,11 @@ pub fn _triefn_prototype(state: &mut TrieFnState,noise: f64) -> f64 { } sum } +const triefn_prototype: TrieFn = TrieFn { func: _triefn_prototype }; #[test] pub fn test_triefn_prototype() { for _ in (0..100).into_iter() { - let triefn_prototype = TrieFn::new(_triefn_prototype); - let _trace = triefn_prototype.simulate(1.); let mut constraints = Trie::>::new(); constraints.insert_leaf_node("1", Rc::new(100.)); @@ -25,4 +24,91 @@ pub fn test_triefn_prototype() { dbg!(trace.logp); dbg!(weight); } +} + + +pub fn _triefn_sample_at_update_weight_regression(state: &mut TrieFnState<(),()>,_: ()) { + let b = state.sample_at(&bernoulli, 0.25, "b"); + if b { + state.sample_at(&normal, (0., 1.), "x"); + } +} +const triefn_sample_at_update_weight_regression: TrieFn<(),()> = TrieFn { func: _triefn_sample_at_update_weight_regression }; + +pub fn _triefn_trace_at_update_weight_regression(state: &mut TrieFnState<(),()>,_: ()) { + let b = state.sample_at(&bernoulli, 0.25, "b"); + if b { + println!("tracing at!"); + state.trace_at(&triefn_prototype, 1.0, "sub"); + println!("shoulda seen ") + } +} +const triefn_trace_at_update_weight_regression: TrieFn<(),()> = TrieFn { func: _triefn_trace_at_update_weight_regression }; + +pub fn _triefn_sample_at_update_weight_regression2(state: &mut TrieFnState<(),()>,_: ()) { + let m = state.sample_at(&uniform, (0.,1.), "m"); + state.sample_at(&normal, (m, 1.), "x"); + state.sample_at(&normal, (m, 1.), "y"); +} +const triefn_sample_at_update_weight_regression2: TrieFn<(),()> = TrieFn { func: _triefn_sample_at_update_weight_regression2 }; + +#[test] +pub fn test_sample_at_update_prev_and_constrained() { + // sample_at + let mut constraints = Trie::>::new(); + constraints.insert_leaf_node("b", Rc::new(true)); + constraints.insert_leaf_node("x", Rc::new(0.0)); + let tr = triefn_sample_at_update_weight_regression.generate((), Trie::from_unweighted(constraints)).0; + let mut constraints = Trie::>::new(); + constraints.insert_leaf_node("x", Rc::new(1.0)); + let w = triefn_sample_at_update_weight_regression.update(tr, (), gen_rs::GfDiff::Unknown, Trie::from_unweighted(constraints)).2; + assert_eq!(w, -0.5); +} + +#[test] +pub fn test_sample_at_update_no_prev_and_constrained() { + // sample_at + let mut constraints = Trie::>::new(); + constraints.insert_leaf_node("b", Rc::new(false)); + let tr = triefn_sample_at_update_weight_regression.generate((), Trie::from_unweighted(constraints)).0; + let mut constraints = Trie::>::new(); + constraints.insert_leaf_node("b", Rc::new(true)); + constraints.insert_leaf_node("x", Rc::new(1.0)); + let w = triefn_sample_at_update_weight_regression.update(tr, (), gen_rs::GfDiff::Unknown, Trie::from_unweighted(constraints)).2; + approx::assert_abs_diff_eq!(w, -2.517551, epsilon = 1e-6); +} + +#[test] +pub fn test_update_sample_at_prev_and_unconstrained() { + // sample_at + let mut constraints = Trie::>::new(); + constraints.insert_leaf_node("m", Rc::new(1.0)); + constraints.insert_leaf_node("x", Rc::new(1.0)); + constraints.insert_leaf_node("y", Rc::new(-0.3)); + let tr = triefn_sample_at_update_weight_regression2.generate((), Trie::from_unweighted(constraints)).0; + let mut constraints = Trie::>::new(); + constraints.insert_leaf_node("m", Rc::new(0.5)); + let w = triefn_sample_at_update_weight_regression2.update(tr, (), gen_rs::GfDiff::Unknown, Trie::from_unweighted(constraints)).2; + approx::assert_abs_diff_eq!(w, 0.4000000, epsilon = 1e-6); +} + +#[test] +pub fn test_update_no_prev_and_unconstrained() { + // sample_at + let mut constraints = Trie::>::new(); + constraints.insert_leaf_node("b", Rc::new(false)); + let tr = triefn_sample_at_update_weight_regression.generate((), Trie::from_unweighted(constraints)).0; + let mut constraints = Trie::>::new(); + constraints.insert_leaf_node("b", Rc::new(true)); + let w = triefn_sample_at_update_weight_regression.update(tr, (), gen_rs::GfDiff::Unknown, Trie::from_unweighted(constraints)).2; + approx::assert_abs_diff_eq!(w, -1.098612, epsilon = 1e-6); + + // trace_at + let mut constraints = Trie::>::new(); + constraints.insert_leaf_node("b", Rc::new(false)); + let tr = triefn_trace_at_update_weight_regression.generate((), Trie::from_unweighted(constraints)).0; + let mut constraints = Trie::>::new(); + constraints.insert_leaf_node("b", Rc::new(true)); + let w = triefn_trace_at_update_weight_regression.update(tr, (), gen_rs::GfDiff::Unknown, Trie::from_unweighted(constraints)).2; + approx::assert_abs_diff_eq!(w, -1.098612, epsilon = 1e-6); } \ No newline at end of file