-
Notifications
You must be signed in to change notification settings - Fork 33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sourcery Starbot ⭐ refactored brendanator/atari-rl #27
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,12 +20,8 @@ def bonus(self, observation): | |
prob = self.update_density_model(frame) | ||
recoding_prob = self.density_model_probability(frame) | ||
pseudo_count = prob * (1 - recoding_prob) / (recoding_prob - prob) | ||
if pseudo_count < 0: | ||
pseudo_count = 0 # Occasionally happens at start of training | ||
|
||
# Return exploration bonus | ||
exploration_bonus = self.beta / math.sqrt(pseudo_count + 0.01) | ||
return exploration_bonus | ||
pseudo_count = max(pseudo_count, 0) | ||
return self.beta / math.sqrt(pseudo_count + 0.01) | ||
Comment on lines
-23
to
+24
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
This removes the following comments ( why? ):
|
||
|
||
def update_density_model(self, frame): | ||
return self.sum_pixel_probabilities(frame, self.density_model.update) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,7 +44,7 @@ def __init__(self, config): | |
elif config.replay_priorities == 'proportional': | ||
self.priorities = ProportionalPriorities(config) | ||
else: | ||
raise Exception('Unknown replay_priorities: ' + config.replay_priorities) | ||
raise Exception(f'Unknown replay_priorities: {config.replay_priorities}') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
def store_new_episode(self, observation): | ||
for frame in observation: | ||
|
@@ -133,7 +133,7 @@ def valid_indices(self, new_indices, input_range, indices=None): | |
return np.unique(np.append(valid_indices, indices)) | ||
|
||
def save(self): | ||
name = self.run_dir + 'replay_' + threading.current_thread().name + '.hdf' | ||
name = f'{self.run_dir}replay_{threading.current_thread().name}.hdf' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
with h5py.File(name, 'w') as h5f: | ||
for key, value in self.__dict__.items(): | ||
if key == 'priorities': | ||
|
@@ -144,7 +144,7 @@ def save(self): | |
h5f.create_dataset(key, data=value) | ||
|
||
def load(self): | ||
name = self.run_dir + 'replay_' + threading.current_thread().name + '.hdf' | ||
name = f'{self.run_dir}replay_{threading.current_thread().name}.hdf' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
with h5py.File(name, 'r') as h5f: | ||
for key in self.__dict__.keys(): | ||
if key == 'priorities': | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -81,15 +81,14 @@ def train_agent(self, session, agent): | |
agent.replay_memory.save() | ||
|
||
def reset_target_network(self, session, step): | ||
if self.reset_op: | ||
if step > 0 and step % self.config.target_network_update_period == 0: | ||
if step > 0 and step % self.config.target_network_update_period == 0: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. question (llm): The refactoring to combine the conditions into a single if statement is good for readability, but ensure that the logic is equivalent and that self.reset_op is always defined when needed. |
||
if self.reset_op: | ||
Comment on lines
-84
to
+85
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
session.run(self.reset_op) | ||
|
||
def train_batch(self, session, replay_memory, step): | ||
fetches = [self.global_step, self.train_op] + self.summary.operation(step) | ||
|
||
batch = replay_memory.sample_batch(fetches, self.config.batch_size) | ||
if batch: | ||
if batch := replay_memory.sample_batch(fetches, self.config.batch_size): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
step, priorities, summary = session.run(fetches, batch.feed_dict()) | ||
batch.update_priorities(priorities) | ||
self.summary.add_summary(summary, step) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,7 +44,7 @@ def reset(self): | |
if self.render: self.env.render() | ||
self.frames = [] | ||
|
||
for i in range(np.random.randint(self.input_frames, self.max_noops + 1)): | ||
for _ in range(np.random.randint(self.input_frames, self.max_noops + 1)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
frame, reward_, done, _ = self.env.step(0) | ||
if self.render: self.env.render() | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -155,12 +155,10 @@ def create_config(): | |
|
||
if config.async == 'one_step': | ||
config.batch_size = config.train_period | ||
elif config.async == 'n_step': | ||
config.batch_size = 1 | ||
elif config.async == 'a3c': | ||
elif config. async in ['n_step', 'a3c']: | ||
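There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. issue (llm): There's an extra space between 'config.' and 'async', which should be corrected. Note that whitespace after the dot is legal Python; the syntax error arises because `async` is a reserved keyword from Python 3.7 onward, so `config.async` fails to parse there with or without the space. |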
||
config.batch_size = 1 | ||
else: | ||
raise Exception('Unknown asynchronous algorithm: ' + config.async) | ||
raise Exception(f'Unknown asynchronous algorithm: {config.async}') | ||
Comment on lines
-158
to
+161
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
config.n_step = config.async == 'n_step' | ||
config.actor_critic = config.async == 'a3c' | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -144,8 +144,8 @@ def variables(self): | |
def activation_summary(self, tensor): | ||
if self.write_summaries: | ||
tensor_name = tensor.op.name | ||
tf.summary.histogram(tensor_name + '/activations', tensor) | ||
tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(tensor)) | ||
tf.summary.histogram(f'{tensor_name}/activations', tensor) | ||
tf.summary.scalar(f'{tensor_name}/sparsity', tf.nn.zero_fraction(tensor)) | ||
Comment on lines
-147
to
+148
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
|
||
class ActionValueHead(object): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -108,7 +108,7 @@ def create_summary_ops(self, loss, variables, gradients): | |
|
||
for grad, var in gradients: | ||
if grad is not None: | ||
tf.summary.histogram('gradient/' + var.name, grad) | ||
tf.summary.histogram(f'gradient/{var.name}', grad) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
self.summary.create_summary_op() | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -66,9 +66,7 @@ def offset_data(t, name): | |
input_len = shape[0] | ||
if not hasattr(placeholder, 'zero_offset'): | ||
placeholder.zero_offset = tf.placeholder_with_default( | ||
input_len - 1, # If no zero_offset is given assume that t = 0 | ||
(), | ||
name + '/zero_offset') | ||
input_len - 1, (), f'{name}/zero_offset') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
This removes the following comments ( why? ):
|
||
|
||
end = t + 1 | ||
start = end - input_len | ||
|
@@ -100,11 +98,7 @@ def __init__(self, inputs, t): | |
|
||
class RequiredFeeds(object): | ||
def __init__(self, placeholder=None, time_offsets=0, feeds=None): | ||
if feeds: | ||
self.feeds = feeds | ||
else: | ||
self.feeds = {} | ||
|
||
self.feeds = feeds if feeds else {} | ||
Comment on lines
-103
to
+101
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
if placeholder is None: | ||
return | ||
|
||
|
@@ -153,21 +147,20 @@ def required_feeds(cls, tensor): | |
if hasattr(tensor, 'required_feeds'): | ||
# Return cached result | ||
return tensor.required_feeds | ||
# Get feeds required by all inputs | ||
if isinstance(tensor, list): | ||
input_tensors = tensor | ||
else: | ||
# Get feeds required by all inputs | ||
if isinstance(tensor, list): | ||
input_tensors = tensor | ||
else: | ||
op = tensor if isinstance(tensor, tf.Operation) else tensor.op | ||
input_tensors = list(op.inputs) + list(op.control_inputs) | ||
op = tensor if isinstance(tensor, tf.Operation) else tensor.op | ||
input_tensors = list(op.inputs) + list(op.control_inputs) | ||
|
||
from networks import inputs | ||
feeds = inputs.RequiredFeeds() | ||
for input_tensor in input_tensors: | ||
feeds = feeds.merge(cls.required_feeds(input_tensor)) | ||
from networks import inputs | ||
feeds = inputs.RequiredFeeds() | ||
for input_tensor in input_tensors: | ||
feeds = feeds.merge(cls.required_feeds(input_tensor)) | ||
|
||
# Cache results | ||
if not isinstance(tensor, list): | ||
tensor.required_feeds = feeds | ||
# Cache results | ||
if not isinstance(tensor, list): | ||
tensor.required_feeds = feeds | ||
|
||
return feeds | ||
return feeds | ||
Comment on lines
+150
to
+166
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,10 +31,7 @@ def batch_sigma_squared(self, batch): | |
self.v = (1 - self.beta) * self.v + self.beta * average_square_reward | ||
|
||
sigma_squared = (self.v - self.mu**2) / self.variance | ||
if sigma_squared > 0: | ||
return sigma_squared | ||
else: | ||
return 1.0 | ||
return sigma_squared if sigma_squared > 0 else 1.0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
def unnormalize_output(self, output): | ||
return output * self.scale_weight + self.scale_bias | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,9 +47,8 @@ def test_replay_memory(self): | |
self.assertAllEqual(feed_dict[inputs.alives], | ||
[[True, True], [True, False]]) | ||
|
||
discounted_reward = sum([ | ||
reward * config.discount_rate**(reward - 4) for reward in range(4, 11) | ||
]) | ||
discounted_reward = sum(reward * config.discount_rate**(reward - 4) | ||
for reward in range(4, 11)) | ||
Comment on lines
-50
to
+51
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
self.assertNear( | ||
feed_dict[inputs.discounted_rewards][0], discounted_reward, err=0.0001) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,10 +27,7 @@ def episode(self, step, score, steps, duration): | |
self.summary_writer.add_summary(summary, step) | ||
|
||
def operation(self, step): | ||
if self.run_summary(step): | ||
return [self.summary_op] | ||
else: | ||
return [self.dummy_summary_op] | ||
return [self.summary_op] if self.run_summary(step) else [self.dummy_summary_op] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
def add_summary(self, summary, step): | ||
if summary: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,18 +13,18 @@ def find_previous_run(dir): | |
if os.path.isdir(dir): | ||
runs = [child[4:] for child in os.listdir(dir) if child[:4] == 'run_'] | ||
if runs: | ||
return max([int(run) for run in runs]) | ||
return max(int(run) for run in runs) | ||
|
||
return 0 | ||
|
||
if config.run_dir == 'latest': | ||
parent_dir = 'runs/%s/' % config.game | ||
parent_dir = f'runs/{config.game}/' | ||
previous_run = find_previous_run(parent_dir) | ||
run_dir = parent_dir + ('run_%d' % previous_run) | ||
elif config.run_dir: | ||
run_dir = config.run_dir | ||
else: | ||
parent_dir = 'runs/%s/' % config.game | ||
parent_dir = f'runs/{config.game}/' | ||
Comment on lines
-16
to
+27
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
previous_run = find_previous_run(parent_dir) | ||
run_dir = parent_dir + ('run_%d' % (previous_run + 1)) | ||
|
||
|
@@ -34,18 +34,18 @@ def find_previous_run(dir): | |
if not os.path.isdir(run_dir): | ||
os.makedirs(run_dir) | ||
|
||
log('Checkpoint and summary directory is %s' % run_dir) | ||
log(f'Checkpoint and summary directory is {run_dir}') | ||
|
||
return run_dir | ||
|
||
|
||
def format_offset(prefix, t): | ||
if t > 0: | ||
return prefix + '_t_plus_' + str(t) | ||
return f'{prefix}_t_plus_{str(t)}' | ||
elif t == 0: | ||
return prefix + '_t' | ||
return f'{prefix}_t' | ||
else: | ||
return prefix + '_t_minus_' + str(-t) | ||
return f'{prefix}_t_minus_{str(-t)}' | ||
Comment on lines
-44
to
+48
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
|
||
def add_loss_summaries(total_loss): | ||
|
@@ -91,7 +91,7 @@ def log(message): | |
import threading | ||
thread_id = threading.current_thread().name | ||
now = datetime.strftime(datetime.now(), '%F %X') | ||
print('%s %s: %s' % (now, thread_id, message)) | ||
print(f'{now} {thread_id}: {message}') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
|
||
def memoize(f): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function
Agent.action
refactored with the following changes: remove-unnecessary-else
This removes the following comments (why?):