Skip to content

Commit

Permalink
Change distance function (#2)
Browse files Browse the repository at this point in the history
* Improve distance function

* normalized distance

* test

* double-pointer

* fix depth
  • Loading branch information
kelindar authored Nov 12, 2023
1 parent 0b8319f commit e959537
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 50 deletions.
16 changes: 12 additions & 4 deletions planner.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"sync"
)

const maxDepth = 100

// Action represents an action that can be performed.
type Action interface {

Expand All @@ -24,7 +26,6 @@ func Plan(start, goal *State, actions []Action) ([]Action, error) {
start = start.Clone()
start.node = node{
heuristic: start.Distance(goal),
stateCost: 0,
}

heap := acquireHeap()
Expand All @@ -34,6 +35,14 @@ func Plan(start, goal *State, actions []Action) ([]Action, error) {
for heap.Len() > 0 {
current, _ := heap.Pop()

/*fmt.Printf("- (%d) %s, cost=%v, heuristic=%v, total=%v\n",
current.depth, current.action,
current.stateCost, current.heuristic, current.totalCost)*/

if current.depth >= maxDepth {
return reconstructPlan(current), nil
}

// If we reached the goal, reconstruct the path.
done, err := current.Match(goal)
switch {
Expand All @@ -59,8 +68,6 @@ func Plan(start, goal *State, actions []Action) ([]Action, error) {
return nil, err
}

//fmt.Printf("Action: %s, State: %s, New: %s\n", action.String(), current.String(), newState.String())

// Check if newState is already planned to be visited or if the newCost is lower
newCost := current.stateCost + action.Cost()
node, found := heap.Find(newState.Hash())
Expand All @@ -72,6 +79,7 @@ func Plan(start, goal *State, actions []Action) ([]Action, error) {
newState.heuristic = heuristic
newState.stateCost = newCost
newState.totalCost = newCost + heuristic
newState.depth = current.depth + 1
heap.Push(newState)

// In any of those cases, we need to release the new state
Expand All @@ -92,7 +100,7 @@ func Plan(start, goal *State, actions []Action) ([]Action, error) {

// reconstructPlan reconstructs the plan from the goal node to the start node.
func reconstructPlan(goalNode *State) []Action {
plan := make([]Action, 0, int(goalNode.index))
plan := make([]Action, 0, int(goalNode.depth))
for n := goalNode; n != nil; n = n.parent {
if n.action != nil { // The start node has no action
plan = append(plan, n.action)
Expand Down
41 changes: 25 additions & 16 deletions planner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ BenchmarkPlan/deep-24 380756 3103 ns/op 230 B/op 1
BenchmarkPlan/deep-24 337836 3519 ns/op 230 B/op 1 allocs/op
BenchmarkPlan/deep-24 420907 2831 ns/op 230 B/op 1 allocs/op
BenchmarkPlan/deep-24 444250 2716 ns/op 230 B/op 1 allocs/op
BenchmarkPlan/deep-24 499970 2345 ns/op 211 B/op 1 allocs/op
BenchmarkPlan/maze-24 37 31458708 ns/op 2702894 B/op 80711 allocs/op
BenchmarkPlan/maze-24 63 18643352 ns/op 1569536 B/op 51464 allocs/op
Expand Down Expand Up @@ -93,40 +94,44 @@ func TestNumericPlan(t *testing.T) {

plan, err := Plan(start, goal, actions)
assert.NoError(t, err)
assert.Equal(t, []string{"Forage", "Forage", "Forage", "Sleep", "Forage", "Forage", "Sleep", "Forage", "Forage", "Forage", "Sleep", "Eat", "Forage"},
assert.Equal(t, []string{"Forage", "Forage", "Forage", "Sleep", "Forage", "Forage", "Sleep", "Forage", "Forage", "Forage", "Sleep", "Forage"},
planOf(plan))

//assert.Fail(t, "xxx")
}

func TestMaze(t *testing.T) {
start := StateOf("A")
goal := StateOf("Z")
actions := []Action{
plan, err := Plan(StateOf("A"), StateOf("Z"), []Action{
move("A->B"), move("B->C"), move("C->D"), move("D->E"), move("E->F"), move("F->G"),
move("G->H"), move("H->I"), move("I->J"), move("C->X1"), move("E->X2"), move("G->X3"),
move("X1->D"), move("X2->F"), move("X3->H"), move("B->Y1"), move("D->Y2"), move("F->Y3"),
move("Y1->C"), move("Y2->E"), move("Y3->G"), move("J->K"), move("K->L"), move("L->M"),
move("M->N"), move("N->O"), move("O->P"), move("P->Q"), move("Q->R"), move("R->S"),
move("S->T"), move("T->U"), move("U->V"), move("V->W"), move("W->X"), move("X->Y"),
move("Y->Z"), move("U->Z1"), move("W->Z2"), move("Z1->V"), move("Z2->X"), move("A->Z3"),
}

plan, err := Plan(start, goal, actions)
})
assert.NoError(t, err)
assert.Equal(t, []string{"A->B", "B->C", "C->D", "D->E", "E->F", "F->G", "G->H", "H->I", "I->J",
"J->K", "K->L", "L->M", "M->N", "N->O", "O->P", "P->Q", "Q->R", "R->S", "S->T", "T->U", "U->V",
"V->W", "W->X", "X->Y", "Y->Z"},
planOf(plan))
//assert.Fail(t, "xxx")
}

func TestSimplePlan(t *testing.T) {
start := StateOf("A", "B")
goal := StateOf("C", "D")
actions := []Action{move("A->C"), move("A->D"), move("B->C"), move("B->D")}
func TestWeightedPlan(t *testing.T) {
plan, err := Plan(StateOf("A", "B"), StateOf("C", "D"),
[]Action{move("A->C"), move("A->D", 0.5), move("B->C"), move("B->D", 0.75)},
)
assert.NoError(t, err)
assert.Equal(t, []string{"A->D", "B->C"}, planOf(plan))
}

plan, err := Plan(start, goal, actions)
func TestSimplePlan(t *testing.T) {
plan, err := Plan(StateOf("A", "B"), StateOf("C", "D"),
[]Action{move("A->C"), move("A->D"), move("B->C"), move("B->D")},
)
assert.NoError(t, err)
assert.Equal(t, []string{"A->C", "B->D"},
planOf(plan))
assert.Equal(t, []string{"A->C", "B->D"}, planOf(plan))
}

func TestNoPlanFound(t *testing.T) {
Expand All @@ -139,9 +144,13 @@ func TestNoPlanFound(t *testing.T) {

// ------------------------------------ Test Action ------------------------------------

func move(m string) Action {
func move(m string, w ...float32) Action {
if len(w) == 0 {
w = append(w, 1.0)
}

arr := strings.Split(m, "->")
return actionOf(m, 1, StateOf(arr[0]), StateOf("!"+arr[0], arr[1]))
return actionOf(m, w[0], StateOf(arr[0]), StateOf("!"+arr[0], arr[1]))
}

func planOf(plan []Action) []string {
Expand Down
60 changes: 35 additions & 25 deletions state.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"sync"
)

const linearCutoff = 8 // 1 cache line
const linearCutoff = 16 // 2 cache line

var pool = sync.Pool{
New: func() any {
Expand Down Expand Up @@ -45,6 +45,7 @@ type node struct {
stateCost float32 // Cost from the start state to this state
totalCost float32 // Sum of cost and heuristic
index int // Index of the state in the heap
depth int // Depth of the state in the tree
visited bool // Whether the state was visited
}

Expand Down Expand Up @@ -225,38 +226,47 @@ func (s *State) Apply(effects *State) error {
return nil
}

// Distance estimates the distance to the goal state as the number of differing keys.
// Distance estimates the distance to the goal state.
func (state *State) Distance(goal *State) (diff float32) {
i, j := 0, 0
for i < len(goal.vx) && j < len(state.vx) {
f0 := goal.vx[i].Fact()
f1 := state.vx[j].Fact()
i := 0
for _, g := range goal.vx {
x := g.Expr().Value()
v := float32(0)

// Find the value in the state
for ; i < len(state.vx); i++ {
if state.vx[i].Fact() == g.Fact() {
v = state.vx[i].Expr().Value()
break // Found
}
if state.vx[i].Fact() < g.Fact() {
break // Not found
}
}

switch {
case f1 == f0:
x := goal.vx[i].Expr().Value()
y := state.vx[j].Expr().Value()
// Calculate the difference, normalized
switch g.Expr().Operator() {
case opEqual:
switch {
case x > y:
diff += x - y
case x < y:
diff += y - x
case v < x:
diff += (x - v)
case v > x:
diff += (v - x)
default: // v == x
}

j++
i++
case f1 > f0:
diff += 100
j++
case f1 < f0:
diff += 100
i++
case opLess:
if v > x {
diff += (v - x)
}

case opGreater:
if v < x {
diff += (x - v)
}
}
}

// Add the remaining elements
diff += float32(len(goal.vx)-i) * 100
diff += float32(len(state.vx)-j) * 100
return diff
}

Expand Down
23 changes: 18 additions & 5 deletions state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,19 +151,32 @@ func TestDistance(t *testing.T) {
{[]string{"A"}, []string{"A"}, 0},
{[]string{"A=100"}, []string{"A=10"}, 90},
{[]string{"A=100"}, []string{"A=90"}, 10},
{[]string{"A"}, []string{"B"}, 200},
{[]string{"A=25"}, []string{"A=50"}, 25},
{[]string{"A=0"}, []string{"A=50"}, 50},
{[]string{"A=75"}, []string{"A=50"}, 25},
{[]string{"A"}, []string{"B"}, 100},
{[]string{"A"}, []string{"A", "B"}, 100},
{[]string{"A", "B"}, []string{"A"}, 100},
{[]string{"A", "B"}, []string{"C", "D"}, 400},
{[]string{"A", "B"}, []string{"A"}, 0},
{[]string{"A", "B"}, []string{"C", "D"}, 200},
{[]string{"A", "B"}, []string{"A", "B"}, 0},
{[]string{"A", "B"}, []string{"A", "B", "C"}, 100},
{[]string{"A", "B", "C"}, []string{"D", "B"}, 300},
{[]string{"A", "B", "C"}, []string{"D", "B"}, 100},
{[]string{"A=20"}, []string{"B=10"}, 10},
{[]string{"A=20"}, []string{"B=70"}, 70},
{[]string{"A=20", "C=40"}, []string{"B=5"}, 5},
{[]string{"A=5", "C=40"}, []string{"A=10", "E=40"}, 45},
{[]string{"A=10"}, []string{}, 0},
{[]string{}, []string{"A=10"}, 10},
{[]string{"A=10"}, []string{"A<50"}, 0},
{[]string{"A=75"}, []string{"A<50"}, 25},
{[]string{"A=10"}, []string{"A>50"}, 40},
{[]string{"A=70"}, []string{"A>50"}, 0},
}

for _, test := range tests {
state1 := StateOf(test.state1...)
state2 := StateOf(test.state2...)
assert.Equal(t, test.expect, state1.Distance(state2),
assert.InDelta(t, test.expect, state1.Distance(state2), 0.01,
"state1=%s, state2=%s", state1, state2)
}
}
Expand Down

0 comments on commit e959537

Please sign in to comment.