diff --git a/cleanrl/pqn.py b/cleanrl/pqn.py index 6ed6e205..fd0eb8e2 100644 --- a/cleanrl/pqn.py +++ b/cleanrl/pqn.py @@ -212,8 +212,9 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int): else: nextnonterminal = 1.0 - dones[t + 1] next_value = values[t + 1] - returns[t] = rewards[t] + args.gamma * ( - args.q_lambda * returns[t + 1] + (1 - args.q_lambda) * next_value * nextnonterminal + returns[t] = ( + rewards[t] + + args.gamma * (args.q_lambda * returns[t + 1] + (1 - args.q_lambda) * next_value) * nextnonterminal ) # flatten the batch diff --git a/cleanrl/pqn_atari_envpool.py b/cleanrl/pqn_atari_envpool.py index 45fd5a4c..9602640e 100644 --- a/cleanrl/pqn_atari_envpool.py +++ b/cleanrl/pqn_atari_envpool.py @@ -255,8 +255,9 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int): else: nextnonterminal = 1.0 - dones[t + 1] next_value = values[t + 1] - returns[t] = rewards[t] + args.gamma * ( - args.q_lambda * returns[t + 1] + (1 - args.q_lambda) * next_value * nextnonterminal + returns[t] = ( + rewards[t] + + args.gamma * (args.q_lambda * returns[t + 1] + (1 - args.q_lambda) * next_value) * nextnonterminal ) # flatten the batch diff --git a/cleanrl/pqn_atari_envpool_lstm.py b/cleanrl/pqn_atari_envpool_lstm.py index 6b348b0a..10d0de4e 100644 --- a/cleanrl/pqn_atari_envpool_lstm.py +++ b/cleanrl/pqn_atari_envpool_lstm.py @@ -290,8 +290,9 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int): else: nextnonterminal = 1.0 - dones[t + 1] next_value = values[t + 1] - returns[t] = rewards[t] + args.gamma * ( - args.q_lambda * returns[t + 1] + (1 - args.q_lambda) * next_value * nextnonterminal + returns[t] = ( + rewards[t] + + args.gamma * (args.q_lambda * returns[t + 1] + (1 - args.q_lambda) * next_value) * nextnonterminal ) # flatten the batch