Sourcery Starbot ⭐ refactored brendanator/atari-rl #27

Open
wants to merge 1 commit into base: main
10 changes: 4 additions & 6 deletions agents/agent.py
@@ -22,14 +22,12 @@ def new_game(self):
return observation, reward, done

def action(self, session, step, observation):
# Epsilon greedy exploration/exploitation even for bootstrapped DQN
if np.random.rand() < self.epsilon(step):
return self.atari.sample_action()
else:
[action] = session.run(
self.policy_network.choose_action,
{self.policy_network.inputs.observations: [observation]})
return action
[action] = session.run(
self.policy_network.choose_action,
{self.policy_network.inputs.observations: [observation]})
return action
Comment on lines -25 to +30
Function Agent.action refactored with the following changes:

This removes the following comments ( why? ):

# Epsilon greedy exploration/exploitation even for bootstrapped DQN


def epsilon(self, step):
"""Epsilon is linearly annealed from an initial exploration value
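Editor's note: the flattened hunk above is easier to follow as plain code. A minimal, runnable sketch of the guard-clause shape of the refactor, using toy stand-ins rather than the repo's Agent, TensorFlow session, or policy network:

```python
import numpy as np

def choose_action(q_values, epsilon):
    # Guard clause: explore with probability epsilon and return early,
    # otherwise fall through to the greedy action with no else branch.
    if np.random.rand() < epsilon:
        return np.random.randint(len(q_values))
    return int(np.argmax(q_values))

print(choose_action([0.1, 0.9, 0.3], epsilon=0.05))
```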
8 changes: 2 additions & 6 deletions agents/exploration_bonus.py
@@ -20,12 +20,8 @@ def bonus(self, observation):
prob = self.update_density_model(frame)
recoding_prob = self.density_model_probability(frame)
pseudo_count = prob * (1 - recoding_prob) / (recoding_prob - prob)
if pseudo_count < 0:
pseudo_count = 0 # Occasionally happens at start of training

# Return exploration bonus
exploration_bonus = self.beta / math.sqrt(pseudo_count + 0.01)
return exploration_bonus
pseudo_count = max(pseudo_count, 0)
return self.beta / math.sqrt(pseudo_count + 0.01)
Comment on lines -23 to +24
Function ExplorationBonus.bonus refactored with the following changes:

This removes the following comments ( why? ):

# Return exploration bonus
# Occasionally happens at start of training


def update_density_model(self, frame):
return self.sum_pixel_probabilities(frame, self.density_model.update)
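Editor's note: a runnable sketch of the pseudo-count bonus after the refactor. The formula mirrors the hunk above; beta and the two probabilities are placeholder values here, whereas the real code obtains them from the density model:

```python
import math

def exploration_bonus(prob, recoding_prob, beta=0.05):
    # Pseudo-count from the model probability before (prob) and after
    # (recoding_prob) the density-model update; clamped because it can dip
    # slightly below zero early in training.
    pseudo_count = prob * (1 - recoding_prob) / (recoding_prob - prob)
    pseudo_count = max(pseudo_count, 0)
    return beta / math.sqrt(pseudo_count + 0.01)

print(exploration_bonus(prob=0.010, recoding_prob=0.011))
```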
6 changes: 3 additions & 3 deletions agents/replay_memory.py
@@ -44,7 +44,7 @@ def __init__(self, config):
elif config.replay_priorities == 'proportional':
self.priorities = ProportionalPriorities(config)
else:
raise Exception('Unknown replay_priorities: ' + config.replay_priorities)
raise Exception(f'Unknown replay_priorities: {config.replay_priorities}')
Function ReplayMemory.__init__ refactored with the following changes:


def store_new_episode(self, observation):
for frame in observation:
@@ -133,7 +133,7 @@ def valid_indices(self, new_indices, input_range, indices=None):
return np.unique(np.append(valid_indices, indices))

def save(self):
name = self.run_dir + 'replay_' + threading.current_thread().name + '.hdf'
name = f'{self.run_dir}replay_{threading.current_thread().name}.hdf'
Function ReplayMemory.save refactored with the following changes:

with h5py.File(name, 'w') as h5f:
for key, value in self.__dict__.items():
if key == 'priorities':
@@ -144,7 +144,7 @@ def save(self):
h5f.create_dataset(key, data=value)

def load(self):
name = self.run_dir + 'replay_' + threading.current_thread().name + '.hdf'
name = f'{self.run_dir}replay_{threading.current_thread().name}.hdf'
Function ReplayMemory.load refactored with the following changes:

with h5py.File(name, 'r') as h5f:
for key in self.__dict__.keys():
if key == 'priorities':
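Editor's note: the save/load edits are straight f-string conversions. A runnable illustration of the resulting filename, with a temporary directory standing in for self.run_dir (which the repo appears to keep with a trailing separator):

```python
import os
import tempfile
import threading

run_dir = tempfile.mkdtemp() + os.sep  # stand-in for self.run_dir
name = f'{run_dir}replay_{threading.current_thread().name}.hdf'
print(name)  # e.g. /tmp/tmpabc123/replay_MainThread.hdf
```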
7 changes: 3 additions & 4 deletions agents/training.py
@@ -81,15 +81,14 @@ def train_agent(self, session, agent):
agent.replay_memory.save()

def reset_target_network(self, session, step):
if self.reset_op:
if step > 0 and step % self.config.target_network_update_period == 0:
if step > 0 and step % self.config.target_network_update_period == 0:
question (llm): The refactoring to combine the conditions into a single if statement is good for readability, but ensure that the logic is equivalent and that self.reset_op is always defined when needed.

if self.reset_op:
Comment on lines -84 to +85
Function Trainer.reset_target_network refactored with the following changes:

session.run(self.reset_op)

def train_batch(self, session, replay_memory, step):
fetches = [self.global_step, self.train_op] + self.summary.operation(step)

batch = replay_memory.sample_batch(fetches, self.config.batch_size)
if batch:
if batch := replay_memory.sample_batch(fetches, self.config.batch_size):
Function Trainer.train_batch refactored with the following changes:

step, priorities, summary = session.run(fetches, batch.feed_dict())
batch.update_priorities(priorities)
self.summary.add_summary(summary, step)
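Editor's note: two stand-in functions (not the repo's Trainer) sketching the reordered target-network check and the walrus form of train_batch. As the question above notes, equivalence of the reordering rests on self.reset_op always being defined; the walrus operator also requires Python 3.8+.

```python
def reset_target_network(session, reset_op, step, update_period):
    # Cheap step test first; same behaviour as the original nesting provided
    # reset_op is always defined (it may be falsy, in which case nothing runs).
    if step > 0 and step % update_period == 0:
        if reset_op:
            session.run(reset_op)

def train_batch(session, replay_memory, fetches, batch_size):
    # := binds the sampled batch and tests it in one expression (Python 3.8+).
    if batch := replay_memory.sample_batch(fetches, batch_size):
        step, priorities, summary = session.run(fetches, batch.feed_dict())
        batch.update_priorities(priorities)
```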
2 changes: 1 addition & 1 deletion atari/atari.py
@@ -44,7 +44,7 @@ def reset(self):
if self.render: self.env.render()
self.frames = []

for i in range(np.random.randint(self.input_frames, self.max_noops + 1)):
for _ in range(np.random.randint(self.input_frames, self.max_noops + 1)):
Function Atari.reset refactored with the following changes:

frame, reward_, done, _ = self.env.step(0)
if self.render: self.env.render()

6 changes: 2 additions & 4 deletions main.py
@@ -155,12 +155,10 @@ def create_config():

if config.async == 'one_step':
config.batch_size = config.train_period
elif config.async == 'n_step':
config.batch_size = 1
elif config.async == 'a3c':
elif config. async in ['n_step', 'a3c']:
issue (llm): There's an extra space between 'config.' and 'async' which could lead to a syntax error. This should be corrected.

config.batch_size = 1
else:
raise Exception('Unknown asynchronous algorithm: ' + config.async)
raise Exception(f'Unknown asynchronous algorithm: {config.async}')
Comment on lines -158 to +161
Function create_config refactored with the following changes:

config.n_step = config.async == 'n_step'
config.actor_critic = config.async == 'a3c'

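Editor's note: a runnable sketch of the merged branch with the stray space removed, as the issue comment above asks. Because async is a reserved word from Python 3.7 onward, the sketch reads the attribute with getattr so it runs on current interpreters; the repo itself spells it config.async directly.

```python
from types import SimpleNamespace

def set_batch_size(config):
    algo = getattr(config, 'async')
    if algo == 'one_step':
        config.batch_size = config.train_period
    elif algo in ['n_step', 'a3c']:
        config.batch_size = 1
    else:
        raise Exception(f'Unknown asynchronous algorithm: {algo}')

config = SimpleNamespace(train_period=4)
setattr(config, 'async', 'n_step')
set_batch_size(config)
print(config.batch_size)  # 1
```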
4 changes: 2 additions & 2 deletions networks/dqn.py
@@ -144,8 +144,8 @@ def variables(self):
def activation_summary(self, tensor):
if self.write_summaries:
tensor_name = tensor.op.name
tf.summary.histogram(tensor_name + '/activations', tensor)
tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(tensor))
tf.summary.histogram(f'{tensor_name}/activations', tensor)
tf.summary.scalar(f'{tensor_name}/sparsity', tf.nn.zero_fraction(tensor))
Comment on lines -147 to +148
Function Network.activation_summary refactored with the following changes:



class ActionValueHead(object):
2 changes: 1 addition & 1 deletion networks/factory.py
@@ -108,7 +108,7 @@ def create_summary_ops(self, loss, variables, gradients):

for grad, var in gradients:
if grad is not None:
tf.summary.histogram('gradient/' + var.name, grad)
tf.summary.histogram(f'gradient/{var.name}', grad)
Function NetworkFactory.create_summary_ops refactored with the following changes:


self.summary.create_summary_op()

37 changes: 15 additions & 22 deletions networks/inputs.py
@@ -66,9 +66,7 @@ def offset_data(t, name):
input_len = shape[0]
if not hasattr(placeholder, 'zero_offset'):
placeholder.zero_offset = tf.placeholder_with_default(
input_len - 1, # If no zero_offset is given assume that t = 0
(),
name + '/zero_offset')
input_len - 1, (), f'{name}/zero_offset')
Function auto_placeholder refactored with the following changes:

This removes the following comments ( why? ):

# If no zero_offset is given assume that t = 0


end = t + 1
start = end - input_len
Expand Down Expand Up @@ -100,11 +98,7 @@ def __init__(self, inputs, t):

class RequiredFeeds(object):
def __init__(self, placeholder=None, time_offsets=0, feeds=None):
if feeds:
self.feeds = feeds
else:
self.feeds = {}

self.feeds = feeds if feeds else {}
Comment on lines -103 to +101
Function RequiredFeeds.__init__ refactored with the following changes:

if placeholder is None:
return

@@ -153,21 +147,20 @@ def required_feeds(cls, tensor):
if hasattr(tensor, 'required_feeds'):
# Return cached result
return tensor.required_feeds
# Get feeds required by all inputs
if isinstance(tensor, list):
input_tensors = tensor
else:
# Get feeds required by all inputs
if isinstance(tensor, list):
input_tensors = tensor
else:
op = tensor if isinstance(tensor, tf.Operation) else tensor.op
input_tensors = list(op.inputs) + list(op.control_inputs)
op = tensor if isinstance(tensor, tf.Operation) else tensor.op
input_tensors = list(op.inputs) + list(op.control_inputs)

from networks import inputs
feeds = inputs.RequiredFeeds()
for input_tensor in input_tensors:
feeds = feeds.merge(cls.required_feeds(input_tensor))
from networks import inputs
feeds = inputs.RequiredFeeds()
for input_tensor in input_tensors:
feeds = feeds.merge(cls.required_feeds(input_tensor))

# Cache results
if not isinstance(tensor, list):
tensor.required_feeds = feeds
# Cache results
if not isinstance(tensor, list):
tensor.required_feeds = feeds

return feeds
return feeds
Comment on lines +150 to +166
Function RequiredFeeds.required_feeds refactored with the following changes:
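Editor's note: the required_feeds hunk is hard to read once flattened, so here is the new version of the method reassembled from the added lines above. This is a reconstruction for readability, not standalone code; it assumes the surrounding RequiredFeeds class and a TensorFlow 1.x tf import as used elsewhere in the repo.

```python
@classmethod
def required_feeds(cls, tensor):
    if hasattr(tensor, 'required_feeds'):
        # Return cached result
        return tensor.required_feeds

    # Get feeds required by all inputs
    if isinstance(tensor, list):
        input_tensors = tensor
    else:
        op = tensor if isinstance(tensor, tf.Operation) else tensor.op
        input_tensors = list(op.inputs) + list(op.control_inputs)

    from networks import inputs
    feeds = inputs.RequiredFeeds()
    for input_tensor in input_tensors:
        feeds = feeds.merge(cls.required_feeds(input_tensor))

    # Cache results
    if not isinstance(tensor, list):
        tensor.required_feeds = feeds

    return feeds
```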

5 changes: 1 addition & 4 deletions networks/reward_scaling.py
@@ -31,10 +31,7 @@ def batch_sigma_squared(self, batch):
self.v = (1 - self.beta) * self.v + self.beta * average_square_reward

sigma_squared = (self.v - self.mu**2) / self.variance
if sigma_squared > 0:
return sigma_squared
else:
return 1.0
return sigma_squared if sigma_squared > 0 else 1.0
Function RewardScaling.batch_sigma_squared refactored with the following changes:


def unnormalize_output(self, output):
return output * self.scale_weight + self.scale_bias
5 changes: 2 additions & 3 deletions test/test_replay_memory.py
@@ -47,9 +47,8 @@ def test_replay_memory(self):
self.assertAllEqual(feed_dict[inputs.alives],
[[True, True], [True, False]])

discounted_reward = sum([
reward * config.discount_rate**(reward - 4) for reward in range(4, 11)
])
discounted_reward = sum(reward * config.discount_rate**(reward - 4)
for reward in range(4, 11))
Comment on lines -50 to +51
Function ReplayMemoryTest.test_replay_memory refactored with the following changes:

self.assertNear(
feed_dict[inputs.discounted_rewards][0], discounted_reward, err=0.0001)

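Editor's note: the change just drops the square brackets, passing a generator expression straight to sum. A runnable illustration with a placeholder discount rate in place of config.discount_rate:

```python
discount_rate = 0.99  # placeholder; the test reads this from config

discounted_reward = sum(reward * discount_rate**(reward - 4)
                        for reward in range(4, 11))
print(round(discounted_reward, 4))
```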
5 changes: 1 addition & 4 deletions util/summary.py
@@ -27,10 +27,7 @@ def episode(self, step, score, steps, duration):
self.summary_writer.add_summary(summary, step)

def operation(self, step):
if self.run_summary(step):
return [self.summary_op]
else:
return [self.dummy_summary_op]
return [self.summary_op] if self.run_summary(step) else [self.dummy_summary_op]
Function Summary.operation refactored with the following changes:


def add_summary(self, summary, step):
if summary:
16 changes: 8 additions & 8 deletions util/util.py
@@ -13,18 +13,18 @@ def find_previous_run(dir):
if os.path.isdir(dir):
runs = [child[4:] for child in os.listdir(dir) if child[:4] == 'run_']
if runs:
return max([int(run) for run in runs])
return max(int(run) for run in runs)

return 0

if config.run_dir == 'latest':
parent_dir = 'runs/%s/' % config.game
parent_dir = f'runs/{config.game}/'
previous_run = find_previous_run(parent_dir)
run_dir = parent_dir + ('run_%d' % previous_run)
elif config.run_dir:
run_dir = config.run_dir
else:
parent_dir = 'runs/%s/' % config.game
parent_dir = f'runs/{config.game}/'
Comment on lines -16 to +27
Function run_directory refactored with the following changes:

previous_run = find_previous_run(parent_dir)
run_dir = parent_dir + ('run_%d' % (previous_run + 1))

@@ -34,18 +34,18 @@ def find_previous_run(dir):
if not os.path.isdir(run_dir):
os.makedirs(run_dir)

log('Checkpoint and summary directory is %s' % run_dir)
log(f'Checkpoint and summary directory is {run_dir}')

return run_dir


def format_offset(prefix, t):
if t > 0:
return prefix + '_t_plus_' + str(t)
return f'{prefix}_t_plus_{str(t)}'
elif t == 0:
return prefix + '_t'
return f'{prefix}_t'
else:
return prefix + '_t_minus_' + str(-t)
return f'{prefix}_t_minus_{str(-t)}'
Comment on lines -44 to +48
Function format_offset refactored with the following changes:



def add_loss_summaries(total_loss):
Expand Down Expand Up @@ -91,7 +91,7 @@ def log(message):
import threading
thread_id = threading.current_thread().name
now = datetime.strftime(datetime.now(), '%F %X')
print('%s %s: %s' % (now, thread_id, message))
print(f'{now} {thread_id}: {message}')
Function log refactored with the following changes:



def memoize(f):
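Editor's note: a runnable sketch of format_offset after the f-string conversion. The explicit str(t) kept in the hunk above is harmless but redundant, since f-string interpolation already applies str; it is dropped in this sketch.

```python
def format_offset(prefix, t):
    if t > 0:
        return f'{prefix}_t_plus_{t}'
    elif t == 0:
        return f'{prefix}_t'
    else:
        return f'{prefix}_t_minus_{-t}'

print(format_offset('reward', 2))   # reward_t_plus_2
print(format_offset('reward', 0))   # reward_t
print(format_offset('reward', -1))  # reward_t_minus_1
```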