# DeepCFR.jl

An implementation of Deep Counterfactual Regret Minimization (Brown et al., 2019).

```julia
using CounterfactualRegret
using CounterfactualRegret.Games
import CounterfactualRegret as CFR
using StaticArrays
using DeepCFR

# Rock-Paper-Scissors is the default CounterfactualRegret.jl matrix game
RPS = MatrixGame()

#=
The information state type of a matrix game is `Int`,
so extend the `vectorized` method to convert it to a vector
that can be passed through a Flux.jl network.
=#
DeepCFR.vectorized(::MatrixGame, I) = SA[Float32(I)]
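
# The same pattern applies to games with richer information states: flatten
# whatever the game returns into a `Float32` vector. A hypothetical sketch,
# where `MyTupleGame` is a made-up game whose infostates are (card, pot) tuples:
struct MyTupleGame end
DeepCFR.vectorized(::MyTupleGame, I) = SA[Float32(I[1]), Float32(I[2])]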


sol = DeepCFRSolver(
        RPS;
        buffer_size = 100*10^3, # capacity of the advantage/strategy memory buffers
        batch_size  = 128,      # minibatch size used when training the networks
        traversals  = 10,       # game-tree traversals per CFR iteration
        on_gpu      = false     # set to `true` to train the networks on a GPU
)

# Train the Deep CFR solver for 1,000 iterations
train!(sol, 1_000, show_progress=true)

I0 = DeepCFR.vectorized(RPS, 0) # information state corresponding to the first player's turn
I1 = DeepCFR.vectorized(RPS, 1) # information state corresponding to the second player's turn

sol(I0) # return strategy for player 1 
sol(I1) # return strategy for player 2
```
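
Rock-Paper-Scissors has a unique Nash equilibrium, the uniform strategy, so after training both returned strategies should be close to `[1/3, 1/3, 1/3]`. A minimal sanity check (plain Julia, no additional DeepCFR API assumed):

```julia
σ0 = sol(I0) # trained strategy for player 1
σ1 = sol(I1) # trained strategy for player 2

# Each strategy is a probability distribution over {rock, paper, scissors}
# and should be near uniform after sufficient training.
@assert isapprox(sum(σ0), 1; atol = 1e-3)
println("max deviation from uniform: ", maximum(abs.(σ0 .- 1/3)))
```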

## Define custom Flux.jl networks

```julia
using Flux

in_size = 1 # information state vector is of length 1 (ref `DeepCFR.vectorized`)
out_size = 3 # 3 actions: rock, paper, scissors

#=
A strategy is a probability distribution, so the network output must sum to 1.
A simple way to enforce this is a `softmax` output layer.
=#
strategy_network = Chain(Dense(in_size, 40), Dense(40, out_size), softmax)

# regret/value does not need to be normalized
value_network = Chain(Dense(in_size, 20), Dense(20, out_size))
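
# Note: for brevity these networks have no hidden activations, so each is
# effectively linear up to the softmax; in practice you would usually add
# one, e.g. `Dense(in_size, 40, relu)`.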

sol = DeepCFRSolver(
        RPS; 
        strategy = strategy_network,
        values = (value_network, deepcopy(value_network)) 
) # DeepCFR requires as many value networks as there are players (2 here)
```
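
Training and querying work the same as with the default networks; a brief usage sketch:

```julia
train!(sol, 1_000, show_progress=true)
sol(DeepCFR.vectorized(RPS, 0)) # strategy for player 1 under the custom networks
```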