diff --git a/docs/tutorials/training_agents/FrozenLake_tuto.py b/docs/tutorials/training_agents/FrozenLake_tuto.py index 7830965c3..318efaa47 100644 --- a/docs/tutorials/training_agents/FrozenLake_tuto.py +++ b/docs/tutorials/training_agents/FrozenLake_tuto.py @@ -161,12 +161,10 @@ def choose_action(self, action_space, state, qtable): # Exploitation (taking the biggest Q-value for this state) else: # Break ties randomly - # If all actions are the same for this state we choose a random one - # (otherwise `np.argmax()` would always take the first one) - if np.all(qtable[state, :]) == qtable[state, 0]: - action = action_space.sample() - else: - action = np.argmax(qtable[state, :]) + # Find the indices where the Q-value equals the maximum value + # Choose a random action from the indices where the Q-value is maximum + max_ids = np.where(qtable[state, :] == max(qtable[state, :]))[0] + action = rng.choice(max_ids) return action