From 4a05ba8a325edbbe28b7d2ad49b2811ecd9b138d Mon Sep 17 00:00:00 2001 From: Kai Schmidt Date: Thu, 21 Nov 2024 07:40:45 -0800 Subject: [PATCH] fix bug in pop astar --- src/algorithm/mod.rs | 116 ++++++++++++++++++++++------------------ src/compile/optimize.rs | 7 +-- src/primitive/defs.rs | 1 + src/primitive/mod.rs | 2 + 4 files changed, 70 insertions(+), 56 deletions(-) diff --git a/src/algorithm/mod.rs b/src/algorithm/mod.rs index dc46137e0..ec1f159ee 100644 --- a/src/algorithm/mod.rs +++ b/src/algorithm/mod.rs @@ -773,14 +773,24 @@ fn fft_impl( } pub fn astar(ops: Ops, env: &mut Uiua) -> UiuaResult { - astar_impl(ops, false, env) + astar_impl(ops, AstarMode::All, env) } pub fn astar_first(ops: Ops, env: &mut Uiua) -> UiuaResult { - astar_impl(ops, true, env) + astar_impl(ops, AstarMode::First, env) } -fn astar_impl(ops: Ops, first_only: bool, env: &mut Uiua) -> UiuaResult { +pub fn astar_pop(ops: Ops, env: &mut Uiua) -> UiuaResult { + astar_impl(ops, AstarMode::CostOnly, env) +} + +enum AstarMode { + All, + First, + CostOnly, +} + +fn astar_impl(ops: Ops, mode: AstarMode, env: &mut Uiua) -> UiuaResult { let start = env.pop("start")?; let [neighbors, heuristic, is_goal] = get_ops(ops, env)?; let nei_sig = neighbors.sig; @@ -935,10 +945,10 @@ fn astar_impl(ops: Ops, first_only: bool, env: &mut Uiua) -> UiuaResult { if env.is_goal(&backing[curr])? { ends.insert(curr); shortest_cost = curr_cost; - if first_only { - break; - } else { + if let AstarMode::All = mode { continue; + } else { + break; } } // Check neighbors @@ -992,57 +1002,61 @@ fn astar_impl(ops: Ops, first_only: bool, env: &mut Uiua) -> UiuaResult { Value::from_row_values(path, env) }; - if first_only { - let mut curr = ends - .into_iter() - .next() - .ok_or_else(|| env.error("No path found"))?; - let mut path = vec![curr]; - while let Some(from) = came_from.get(&curr) { - path.push(from[0]); - curr = from[0]; - } - path.reverse(); - env.push(make_path(path)?); - } else { - for end in ends { - let mut currs = vec![vec![end]]; - let mut these_paths = Vec::new(); - while !currs.is_empty() { - let mut new_paths = Vec::new(); - currs.retain_mut(|path| { - let parents = came_from - .get(path.last().unwrap()) - .map(|p| p.as_slice()) - .unwrap_or(&[]); - match parents { - [] => { - these_paths.push(take(path)); - false - } - &[parent] => { - path.push(parent); - true - } - &[parent, ref rest @ ..] => { - for &parent in rest { - let mut path = path.clone(); + match mode { + AstarMode::All => { + for end in ends { + let mut currs = vec![vec![end]]; + let mut these_paths = Vec::new(); + while !currs.is_empty() { + let mut new_paths = Vec::new(); + currs.retain_mut(|path| { + let parents = came_from + .get(path.last().unwrap()) + .map(|p| p.as_slice()) + .unwrap_or(&[]); + match parents { + [] => { + these_paths.push(take(path)); + false + } + &[parent] => { + path.push(parent); + true + } + &[parent, ref rest @ ..] => { + for &parent in rest { + let mut path = path.clone(); + path.push(parent); + new_paths.push(path); + } path.push(parent); - new_paths.push(path); + true } - path.push(parent); - true } - } - }); - currs.extend(new_paths); + }); + currs.extend(new_paths); + } + for mut path in these_paths { + path.reverse(); + paths.push(Boxed(make_path(path)?)); + } } - for mut path in these_paths { - path.reverse(); - paths.push(Boxed(make_path(path)?)); + env.push(paths); + } + AstarMode::First => { + let mut curr = ends + .into_iter() + .next() + .ok_or_else(|| env.error("No path found"))?; + let mut path = vec![curr]; + while let Some(from) = came_from.get(&curr) { + path.push(from[0]); + curr = from[0]; } + path.reverse(); + env.push(make_path(path)?); } - env.push(paths); + AstarMode::CostOnly => {} } Ok(()) } diff --git a/src/compile/optimize.rs b/src/compile/optimize.rs index 08146439f..748d03b55 100644 --- a/src/compile/optimize.rs +++ b/src/compile/optimize.rs @@ -209,11 +209,8 @@ opt!( ImplMod(AstarFirst, args.clone(), *span) ), ( - [Mod(Astar, args, span), Prim(Pop, pop_span)], - [ - ImplMod(AstarFirst, args.clone(), *span), - Prim(Pop, *pop_span) - ] + [Mod(Astar, args, span), Prim(Pop, _)], + ImplMod(AstarPop, args.clone(), *span) ), ); diff --git a/src/primitive/defs.rs b/src/primitive/defs.rs index 8deae6340..8420b8494 100644 --- a/src/primitive/defs.rs +++ b/src/primitive/defs.rs @@ -3620,6 +3620,7 @@ impl_primitive!( (2[1], RowsWindows), (1, CountUnique), (1(2)[3], AstarFirst), + (1[3], AstarPop), // Implementation details (1[2], RepeatWithInverse), (2(1), ValidateType), diff --git a/src/primitive/mod.rs b/src/primitive/mod.rs index dde4d164c..abd03d70a 100644 --- a/src/primitive/mod.rs +++ b/src/primitive/mod.rs @@ -250,6 +250,7 @@ impl fmt::Display for ImplPrimitive { MatchLe => write!(f, "match ≤"), MatchGe => write!(f, "match ≥"), AstarFirst => write!(f, "{First}{Astar}"), + AstarPop => write!(f, "{Pop}{Astar}"), &ReduceDepth(n) => { for _ in 0..n { write!(f, "{Rows}")?; @@ -1647,6 +1648,7 @@ impl ImplPrimitive { ImplPrimitive::ReduceContent => reduce::reduce_content(ops, env)?, ImplPrimitive::Adjacent => reduce::adjacent(ops, env)?, ImplPrimitive::AstarFirst => algorithm::astar_first(ops, env)?, + ImplPrimitive::AstarPop => algorithm::astar_pop(ops, env)?, &ImplPrimitive::ReduceDepth(depth) => reduce::reduce(ops, depth, env)?, ImplPrimitive::RepeatWithInverse => loops::repeat(ops, true, env)?, ImplPrimitive::UnScan => reduce::unscan(ops, env)?,