This repository has been archived by the owner on Aug 3, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9a222a7
commit 0205d7d
Showing
6 changed files
with
265 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
// Copyright 2021 go-mcts. All rights reserved. | ||
// Use of this source code is governed by a MIT-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package mcts | ||
|
||
import ( | ||
"fmt" | ||
) | ||
|
||
// Struct pointers as map keys is not work correctly. | ||
// see https://abhinavg.net/posts/pointers-as-map-keys/ | ||
// | ||
// use fmt.Sprintf("%v", key) as map keys | ||
type counter struct { | ||
m map[string]*entry | ||
} | ||
|
||
type entry struct { | ||
key interface{} | ||
count float64 | ||
} | ||
|
||
func newCounter() *counter { | ||
return &counter{make(map[string]*entry)} | ||
} | ||
|
||
func (c *counter) incr(key interface{}, count float64) { | ||
s := fmt.Sprintf("%v", key) | ||
if ent, ok := c.m[s]; ok { | ||
ent.count += count | ||
} else { | ||
c.m[s] = &entry{key, 1} | ||
} | ||
} | ||
|
||
func (c *counter) get(key interface{}) float64 { | ||
if ent, ok := c.m[fmt.Sprintf("%v", key)]; ok { | ||
return ent.count | ||
} | ||
return 0 | ||
} | ||
|
||
func (c *counter) rng(f func(key interface{}, count float64)) { | ||
for _, ent := range c.m { | ||
f(ent.key, ent.count) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
// Copyright 2021 go-mcts. All rights reserved. | ||
// Use of this source code is governed by a MIT-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package mcts | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
type someStructPointer struct { | ||
s string | ||
} | ||
|
||
func newPointer(s string) *someStructPointer { | ||
return &someStructPointer{s} | ||
} | ||
|
||
func TestCounter(t *testing.T) { | ||
m := make(map[*someStructPointer]int) | ||
m[newPointer("abc")]++ | ||
m[newPointer("abc")]++ | ||
assert.Equal(t, 2, len(m)) | ||
assert.Equal(t, 0, m[newPointer("abc")]) | ||
|
||
c := newCounter() | ||
c.incr(newPointer("abc"), 1) | ||
c.incr(newPointer("abc"), 1) | ||
assert.Equal(t, float64(2), c.get(newPointer("abc"))) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
// Copyright 2021 go-mcts. All rights reserved. | ||
// Use of this source code is governed by a MIT-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package tictactoe | ||
|
||
import ( | ||
"math/rand" | ||
|
||
"github.com/go-mcts/mcts" | ||
) | ||
|
||
var ( | ||
_ mcts.Move = (*move)(nil) | ||
_ mcts.State = (*state)(nil) | ||
) | ||
|
||
type move struct { | ||
x int | ||
y int | ||
v int | ||
} | ||
|
||
type state struct { | ||
playerToMove int | ||
board [3][3]int | ||
} | ||
|
||
func (s *state) PlayerToMove() int { | ||
return s.playerToMove | ||
} | ||
|
||
func (s *state) HasMoves() bool { | ||
return s.getResult(s.playerToMove) == -1 | ||
} | ||
|
||
func (s *state) GetMoves() []mcts.Move { | ||
moves := make([]mcts.Move, 0) | ||
if s.getResult(s.playerToMove) == -1 { | ||
for i := 0; i < 3; i++ { | ||
for j := 0; j < 3; j++ { | ||
if s.board[i][j] == 0 { | ||
m := &move{ | ||
x: i, | ||
y: j, | ||
v: s.playerToMove, | ||
} | ||
if s.playerToMove == 1 { | ||
m.v = 1 | ||
} else { | ||
m.v = -1 | ||
} | ||
moves = append(moves, m) | ||
} | ||
} | ||
} | ||
} | ||
return moves | ||
} | ||
|
||
func (s *state) DoMove(mctsMove mcts.Move) { | ||
m := mctsMove.(*move) | ||
if m.x < 0 || m.y < 0 || m.x > 2 || m.y > 2 || s.board[m.x][m.y] != 0 { | ||
panic("illegal move") | ||
} | ||
s.board[m.x][m.y] = m.v | ||
s.playerToMove = 3 - s.playerToMove | ||
} | ||
|
||
func (s *state) DoRandomMove(rd *rand.Rand) { | ||
moves := s.GetMoves() | ||
s.DoMove(moves[rd.Intn(len(moves))]) | ||
} | ||
|
||
func (s *state) GetResult(currentPlayerToMove int) float64 { | ||
if result := s.getResult(currentPlayerToMove); result == -1 { | ||
panic("game is not over") | ||
} else { | ||
return result | ||
} | ||
} | ||
|
||
func (s *state) getResult(currentPlayerToMove int) float64 { | ||
zero := 0 | ||
|
||
for i := 0; i < 3; i++ { | ||
row, col := 0, 0 | ||
for j := 0; j < 3; j++ { | ||
if s.board[i][j] == 0 { | ||
zero++ | ||
} | ||
row += s.board[i][j] | ||
col += s.board[j][i] | ||
} | ||
|
||
if row == 3 || row == -3 || col == 3 || col == -3 { | ||
if s.playerToMove == currentPlayerToMove { | ||
return 1 | ||
} | ||
return 0 | ||
} | ||
} | ||
|
||
tl := s.board[0][0] + s.board[1][1] + s.board[2][2] | ||
tr := s.board[0][2] + s.board[1][1] + s.board[2][0] | ||
|
||
if tl == 3 || tr == 3 || tl == -3 || tr == -3 { | ||
if s.playerToMove == currentPlayerToMove { | ||
return 1 | ||
} | ||
return 0 | ||
} | ||
|
||
if zero == 0 { | ||
return 0.5 | ||
} | ||
|
||
return -1 | ||
} | ||
|
||
func (s *state) Clone() mcts.State { | ||
return &state{ | ||
playerToMove: s.playerToMove, | ||
board: s.board, | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
// Copyright 2021 go-mcts. All rights reserved. | ||
// Use of this source code is governed by a MIT-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package tictactoe | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/go-mcts/mcts" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestTicTacToe(t *testing.T) { | ||
rootState := &state{ | ||
playerToMove: 1, | ||
board: [3][3]int{ | ||
{0, 0, 0}, | ||
{0, 0, 0}, | ||
{0, 0, 0}, | ||
}, | ||
} | ||
mctsMove := mcts.ComputeMove(rootState, mcts.MaxIterations(20000), mcts.Verbose(true)) | ||
m := mctsMove.(*move) | ||
assert.Equal(t, 1, m.x) | ||
assert.Equal(t, 1, m.y) | ||
assert.Equal(t, 1, m.v) | ||
|
||
rootState = &state{ | ||
playerToMove: 1, | ||
board: [3][3]int{ | ||
{0, 0, 0}, | ||
{0, 1, 0}, | ||
{0, -1, 0}, | ||
}, | ||
} | ||
mctsMove = mcts.ComputeMove(rootState, mcts.Verbose(true)) | ||
m = mctsMove.(*move) | ||
assert.Equal(t, 1, m.v) | ||
|
||
assert.True(t, m.x == 0 && (m.y == 0 || m.y == 2) || | ||
m.x == 2 && (m.y == 0 || m.y == 2)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters