diff --git a/README.md b/README.md index 134b9d8..f72f332 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,20 @@ state := &State{ move := mcts.ComputeMove(state, mcts.MaxIterations(100000), mcts.Verbose(true)) ``` +By default, `runtime.NumCPU()` goroutines are started for compute: + +``` +2021-10-16T16:13:33+08:00 DEBUG 100000 games played (493618.37 / second). +2021-10-16T16:13:33+08:00 DEBUG 100000 games played (487598.00 / second). +2021-10-16T16:13:33+08:00 DEBUG 100000 games played (474579.62 / second). +2021-10-16T16:13:33+08:00 DEBUG 100000 games played (470514.17 / second). +2021-10-16T16:13:33+08:00 DEBUG Move: 2 (16% visits) (48% wins) +2021-10-16T16:13:33+08:00 DEBUG Move: 1 (72% visits) (67% wins) +2021-10-16T16:13:33+08:00 DEBUG Move: 3 (13% visits) (48% wins) +2021-10-16T16:13:33+08:00 DEBUG Best: 1 (72% visits) (67% wins) +2021-10-16T16:13:33+08:00 DEBUG 400000 games played in 0.21 s. (1881366.23 / second, 4 parallel jobs). +``` + ## License This project is under the MIT License. See the [LICENSE](LICENSE) file for the full license text. diff --git a/examples/nim/nim.go b/examples/nim/nim.go index 81128bb..ea49dc4 100644 --- a/examples/nim/nim.go +++ b/examples/nim/nim.go @@ -11,27 +11,27 @@ import ( ) var ( - _ mcts.Move = (*Move)(nil) - _ mcts.State = (*State)(nil) + _ mcts.Move = (*move)(nil) + _ mcts.State = (*state)(nil) ) -type Move int +type move int -type State struct { +type state struct { playerToMove int chips int } -func (s *State) PlayerToMove() int { +func (s *state) PlayerToMove() int { return s.playerToMove } -func (s *State) HasMoves() bool { +func (s *state) HasMoves() bool { s.checkInvariant() return s.chips > 0 } -func (s *State) GetMoves() []mcts.Move { +func (s *state) GetMoves() []mcts.Move { s.checkInvariant() var moves []mcts.Move @@ -41,7 +41,7 @@ func (s *State) GetMoves() []mcts.Move { return moves } -func (s *State) DoMove(move mcts.Move) { +func (s *state) DoMove(move mcts.Move) { m := move.(int) if m < 1 || m > 3 { panic("illegal move") @@ -54,7 +54,7 @@ func (s *State) DoMove(move mcts.Move) { s.checkInvariant() } -func (s *State) DoRandomMove(rd *rand.Rand) { +func (s *state) DoRandomMove(rd *rand.Rand) { if s.chips <= 0 { panic("invalid chips") } @@ -66,7 +66,7 @@ func (s *State) DoRandomMove(rd *rand.Rand) { s.checkInvariant() } -func (s *State) GetResult(currentPlayerToMove int) float64 { +func (s *state) GetResult(currentPlayerToMove int) float64 { if s.chips != 0 { panic("game not over") } @@ -78,14 +78,14 @@ func (s *State) GetResult(currentPlayerToMove int) float64 { return 0.0 } -func (s *State) Clone() mcts.State { - return &State{ +func (s *state) Clone() mcts.State { + return &state{ playerToMove: s.playerToMove, chips: s.chips, } } -func (s *State) checkInvariant() { +func (s *state) checkInvariant() { if s.chips < 0 || (s.playerToMove != 1 && s.playerToMove != 2) { panic("illegal state") } diff --git a/examples/nim/nim_test.go b/examples/nim/nim_test.go index dce0625..a4a4f6f 100644 --- a/examples/nim/nim_test.go +++ b/examples/nim/nim_test.go @@ -14,7 +14,7 @@ import ( func TestNim(t *testing.T) { for chips := 4; chips <= 21; chips++ { if chips%4 != 0 { - state := &State{ + state := &state{ playerToMove: 1, chips: chips, } diff --git a/mcts.go b/mcts.go index 204f544..a1a82de 100644 --- a/mcts.go +++ b/mcts.go @@ -8,14 +8,23 @@ import ( "math/rand" ) +// Move must be implemented for different games type Move interface{} +// State must be implemented for different games type State interface { + // PlayerToMove is who next to play PlayerToMove() int + // HasMoves return whether the game is over HasMoves() bool + // GetMoves get all legal moves GetMoves() []Move + // DoMove modify state with the given move DoMove(move Move) + // DoRandomMove do random move with the given random engine DoRandomMove(rd *rand.Rand) + // GetResult return game result GetResult(currentPlayerToMove int) float64 + // Clone is deep copy Clone() State } diff --git a/options.go b/options.go index aff8101..566d1d9 100644 --- a/options.go +++ b/options.go @@ -9,15 +9,16 @@ import ( "time" ) +// Options use functional-option to customize mcts type Options struct { - Groutines int + Goroutines int MaxIterations int MaxTime time.Duration Verbose bool } var defaultOptions = Options{ - Groutines: runtime.NumCPU(), + Goroutines: runtime.NumCPU(), MaxIterations: 10000, MaxTime: -1, Verbose: false, @@ -25,24 +26,28 @@ var defaultOptions = Options{ type Option func(*Options) +// Goroutines number of goroutines, default is runtime.NumCPU() func Goroutines(number int) Option { return func(o *Options) { - o.Groutines = number + o.Goroutines = number } } +// MaxIterations maximum number of iterations, default is 10000 func MaxIterations(iter int) Option { return func(o *Options) { o.MaxIterations = iter } } +// MaxTime search timeout, default is not limit func MaxTime(d time.Duration) Option { return func(o *Options) { o.MaxTime = d } } +// Verbose print details log, default is false func Verbose(v bool) Option { return func(o *Options) { o.Verbose = v diff --git a/options_test.go b/options_test.go index 83b9e7b..a3410b5 100644 --- a/options_test.go +++ b/options_test.go @@ -34,7 +34,7 @@ func Test_newOptions(t *testing.T) { }, }, Options{ - Groutines: runtime.NumCPU(), + Goroutines: runtime.NumCPU(), MaxIterations: 100000, MaxTime: -1, Verbose: true, @@ -51,7 +51,7 @@ func Test_newOptions(t *testing.T) { }, }, Options{ - Groutines: 1, + Goroutines: 1, MaxIterations: 100000, MaxTime: 5 * time.Second, Verbose: true, diff --git a/uct.go b/uct.go index 0f91d58..0d976bf 100644 --- a/uct.go +++ b/uct.go @@ -64,6 +64,7 @@ func computeTree(rootState State, rd *rand.Rand, opts ...Option) *node { return root } +// ComputeMove start multi goroutines to compute best move func ComputeMove(rootState State, opts ...Option) Move { options := newOptions(opts...) @@ -82,8 +83,8 @@ func ComputeMove(rootState State, opts ...Option) Move { startTime := time.Now() - rootFutures := make(chan *node, options.Groutines) - for i := 0; i < options.Groutines; i++ { + rootFutures := make(chan *node, options.Goroutines) + for i := 0; i < options.Goroutines; i++ { go func() { rd := rand.New(rand.NewSource(time.Now().UnixNano())) rootFutures <- computeTree(rootState, rd, opts...) @@ -93,7 +94,7 @@ func ComputeMove(rootState State, opts ...Option) Move { visits := make(map[Move]int) wins := make(map[Move]float64) gamePlayed := 0 - for i := 0; i < options.Groutines; i++ { + for i := 0; i < options.Goroutines; i++ { root := <-rootFutures gamePlayed += root.visits for _, c := range root.children { @@ -133,7 +134,7 @@ func ComputeMove(rootState State, opts ...Option) Move { gamePlayed, now.Sub(startTime).Seconds(), float64(gamePlayed)/now.Sub(startTime).Seconds(), - options.Groutines, + options.Goroutines, ) } return bestMove