Skip to content

Commit

Permalink
Minor optimisations (keras-team#12627)
Browse files Browse the repository at this point in the history
* Avoid importing copy

* Minor optimisation

* Cleanup

* PEP8

* Revert changes
  • Loading branch information
abhaikollara authored and fchollet committed Apr 8, 2019
1 parent 30fe4ff commit b2771d1
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 14 deletions.
13 changes: 3 additions & 10 deletions keras/engine/base_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from __future__ import absolute_import
from __future__ import division

import copy
import re
from six.moves import zip

Expand Down Expand Up @@ -139,10 +138,7 @@ def __init__(self, **kwargs):
if 'batch_input_shape' in kwargs:
batch_input_shape = tuple(kwargs['batch_input_shape'])
elif 'input_shape' in kwargs:
if 'batch_size' in kwargs:
batch_size = kwargs['batch_size']
else:
batch_size = None
batch_size = kwargs.get('batch_size')
batch_input_shape = (
batch_size,) + tuple(kwargs['input_shape'])
self.batch_input_shape = batch_input_shape
Expand All @@ -155,10 +151,7 @@ def __init__(self, **kwargs):
dtype = K.floatx()
self.dtype = dtype

if 'weights' in kwargs:
self._initial_weights = kwargs['weights']
else:
self._initial_weights = None
self._initial_weights = kwargs.get('weights')

@staticmethod
def _node_key(layer, node_index):
Expand Down Expand Up @@ -441,7 +434,7 @@ def __call__(self, inputs, **kwargs):

# Handle mask propagation.
previous_mask = _collect_previous_mask(inputs)
user_kwargs = copy.copy(kwargs)
user_kwargs = kwargs.copy()
if not is_all_none(previous_mask):
# The previous layer generated a mask.
if has_arg(self.call, 'mask'):
Expand Down
6 changes: 2 additions & 4 deletions keras/engine/training_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,7 @@ def test_loop(model, f, ins,
batch_outs = f(ins)
if isinstance(batch_outs, list):
if step == 0:
for _ in enumerate(batch_outs):
outs.append(0.)
outs.extend([0.] * len(batch_outs))
for i, batch_out in enumerate(batch_outs):
if i in stateful_metric_indices:
outs[i] = float(batch_out)
Expand Down Expand Up @@ -471,8 +470,7 @@ def test_loop(model, f, ins,
batch_outs = f(ins_batch)
if isinstance(batch_outs, list):
if batch_index == 0:
for batch_out in enumerate(batch_outs):
outs.append(0.)
outs.extend([0.] * len(batch_outs))
for i, batch_out in enumerate(batch_outs):
if i in stateful_metric_indices:
outs[i] = batch_out
Expand Down

0 comments on commit b2771d1

Please sign in to comment.