From 6006a5ba4737593ee2e1e9be6c2a48e611975572 Mon Sep 17 00:00:00 2001 From: Eric Hunsberger Date: Wed, 23 Nov 2016 14:59:26 -0500 Subject: [PATCH 1/3] Refactored gui.py image display --- nengo_extras/gui.py | 81 +++++++++++++++++++++++++++++++-------------- 1 file changed, 57 insertions(+), 24 deletions(-) diff --git a/nengo_extras/gui.py b/nengo_extras/gui.py index 16db59c..a75d47d 100644 --- a/nengo_extras/gui.py +++ b/nengo_extras/gui.py @@ -1,29 +1,38 @@ -def preprocess_display(x, image_shape, - transpose=(1, 2, 0), scale=255., offset=0.): +from __future__ import absolute_import + +import nengo + +from .convnet import ShapeParam + + +def preprocess_display(x, transpose=(1, 2, 0), scale=255., offset=0.): """Basic preprocessing that reshapes, transposes, and scales an image""" - y = x.reshape(image_shape) - y = (y + offset) * scale + x = (x + offset) * scale if transpose is not None: - y = y.transpose(transpose) # color channel last - if y.shape[-1] == 1: - y = y[..., 0] - return y.clip(0, 255).astype('uint8') + x = x.transpose(transpose) # color channel last + if x.shape[-1] == 1: + x = x[..., 0] + return x.clip(0, 255).astype('uint8') -def image_display_function(image_shape, preprocess=preprocess_display, - **preprocess_args): - """Make a function to display images in Nengo GUI +def image_html_function(image_shape, preprocess=preprocess_display, + **preprocess_args): + """Make a function to turn an image into HTML to display as an SVG. - Examples - -------- - >>> u = nengo.Node(nengo.processes.PresentInput(images, 0.1)) - >>> display_f = nengo_extras.gui.image_display_function(image_shape) - >>> display_node = nengo.Node(display_f, size_in=u.size_out) - >>> nengo.Connection(u, display_node, synapse=None) + Parameters + ---------- + image_shape : array_like (3,) + The shape of the image: (channels, height, width) + preprocess : callable + Callable that takes an image and preprocesses it to be displayed. + preprocess_args : dict + Optional dictionary of keyword arguments for ``preprocess``. - Requirements - ------------ - pillow (provides PIL, `pip install pillow`) + Returns + ------- + html_function : callable (t, x) + A function that takes time and a flattened image, and returns a string + that defines an SVG object in HTML to display the image. """ import base64 import PIL.Image @@ -31,18 +40,42 @@ def image_display_function(image_shape, preprocess=preprocess_display, assert len(image_shape) == 3 - def display_func(t, x): - y = preprocess(x, image_shape, **preprocess_args) + def html_function(x): + x = x.reshape(image_shape) + y = preprocess(x, **preprocess_args) png = PIL.Image.fromarray(y) buffer = cStringIO.StringIO() png.save(buffer, format="PNG") img_str = base64.b64encode(buffer.getvalue()) - - display_func._nengo_html_ = ''' + return ''' ''' % (''.join(img_str)) + return html_function + + +def image_display_function(image_shape, preprocess=preprocess_display, + **preprocess_args): + """Make a function to display images in Nengo GUI + + Examples + -------- + >>> u = nengo.Node(nengo.processes.PresentInput(images, 0.1)) + >>> display_f = nengo_extras.gui.image_display_function(image_shape) + >>> display_node = nengo.Node(display_f, size_in=u.size_out) + >>> nengo.Connection(u, display_node, synapse=None) + + Requirements + ------------ + pillow (provides PIL, `pip install pillow`) + """ + html_function = image_html_function( + image_shape, preprocess=preprocess, **preprocess_args) + + def display_func(t, x): + display_func._nengo_html_ = html_function(x) + return display_func From 94193ca88425c789944737a3de390d52425ac325 Mon Sep 17 00:00:00 2001 From: Eric Hunsberger Date: Tue, 7 Jun 2016 11:55:17 -0400 Subject: [PATCH 2/3] Added PresentImages process to show images in GUI --- nengo_extras/gui.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/nengo_extras/gui.py b/nengo_extras/gui.py index a75d47d..9f33278 100644 --- a/nengo_extras/gui.py +++ b/nengo_extras/gui.py @@ -79,3 +79,35 @@ def display_func(t, x): display_func._nengo_html_ = html_function(x) return display_func + + +class PresentImages(nengo.processes.PresentInput): + """PresentInput process whose inputs are displayed as images in nengo_gui. + """ + + image_shape = ShapeParam('image_shape', length=3, low=1) + + def __init__(self, images, presentation_time, **kwargs): + self.image_shape = images.shape[1:] + super(PresentImages, self).__init__( + images, presentation_time, **kwargs) + self.configure_display() + + def configure_display(self, preprocess=preprocess_display, + **preprocess_args): + """Configure display parameters for images + + Parameters + ---------- + preprocess : callable + Callable that takes an image and preprocesses it to be displayed. + preprocess_args : dict + Optional dictionary of keyword arguments for ``preprocess``. + """ + html_function = image_html_function( + self.image_shape, preprocess=preprocess, **preprocess_args) + + def _nengo_html_(t, x): + return html_function(x) + + self._nengo_html_ = _nengo_html_ From b3d97e92c453599cb0aecaf274ff790aa83831ac Mon Sep 17 00:00:00 2001 From: Eric Hunsberger Date: Wed, 23 Nov 2016 15:01:10 -0500 Subject: [PATCH 3/3] Revised examples to use PresentImages Instead of ``image_display_function``. --- examples/cuda_convnet/cifar10_spiking_cnn.py | 12 ++++-------- examples/cuda_convnet/imagenet_spiking_cnn.py | 11 ++++------- examples/keras/mnist_spiking_cnn.py | 10 ++-------- examples/vision/mnist_single_layer_gui.py | 13 +++++-------- 4 files changed, 15 insertions(+), 31 deletions(-) diff --git a/examples/cuda_convnet/cifar10_spiking_cnn.py b/examples/cuda_convnet/cifar10_spiking_cnn.py index 18bc028..28ae528 100644 --- a/examples/cuda_convnet/cifar10_spiking_cnn.py +++ b/examples/cuda_convnet/cifar10_spiking_cnn.py @@ -7,7 +7,7 @@ from nengo_extras.data import load_cifar10 from nengo_extras.cuda_convnet import CudaConvnetNetwork, load_model_pickle -from nengo_extras.gui import image_display_function +from nengo_extras.gui import PresentImages (X_train, y_train), (X_test, y_test), label_names = load_cifar10(label_names=True) X_train = X_train.reshape(-1, 3, 32, 32).astype('float32') @@ -30,19 +30,15 @@ model = nengo.Network() with model: - u = nengo.Node(nengo.processes.PresentInput(X_test, presentation_time)) + u = nengo.Node(PresentImages(X_test, presentation_time)) + u.output.configure_display(offset=data_mean, scale=1.) + ccnet = CudaConvnetNetwork(cc_model, synapse=nengo.synapses.Alpha(0.005)) nengo.Connection(u, ccnet.inputs['data'], synapse=None) input_p = nengo.Probe(u) output_p = nengo.Probe(ccnet.output) - # --- image display - image_shape = X_test.shape[1:] - display_f = image_display_function(image_shape, scale=1, offset=data_mean) - display_node = nengo.Node(display_f, size_in=u.size_out) - nengo.Connection(u, display_node, synapse=None) - # --- output spa display vocab_names = [s.upper() for s in label_names] vocab_vectors = np.eye(len(vocab_names)) diff --git a/examples/cuda_convnet/imagenet_spiking_cnn.py b/examples/cuda_convnet/imagenet_spiking_cnn.py index fa7779d..e653999 100644 --- a/examples/cuda_convnet/imagenet_spiking_cnn.py +++ b/examples/cuda_convnet/imagenet_spiking_cnn.py @@ -9,7 +9,7 @@ from nengo_extras.data import load_ilsvrc2012, spasafe_names from nengo_extras.cuda_convnet import CudaConvnetNetwork, load_model_pickle -from nengo_extras.gui import image_display_function +from nengo_extras.gui import PresentImages # retrieve from https://figshare.com/s/cdde71007405eb11a88f filename = 'ilsvrc-2012-batches-test3.tar.gz' @@ -34,18 +34,15 @@ model = nengo.Network() with model: - u = nengo.Node(nengo.processes.PresentInput(X_test, presentation_time)) + u = nengo.Node(PresentImages(X_test, presentation_time)) + u.output.configure_display(offset=data_mean, scale=1.) + ccnet = CudaConvnetNetwork(cc_model, synapse=nengo.synapses.Alpha(0.001)) nengo.Connection(u, ccnet.inputs['data'], synapse=None) # input_p = nengo.Probe(u) output_p = nengo.Probe(ccnet.output) - # --- image display - display_f = image_display_function(image_shape, scale=1., offset=data_mean) - display_node = nengo.Node(display_f, size_in=u.size_out) - nengo.Connection(u, display_node, synapse=None) - # --- output spa display vocab_names = spasafe_names(label_names) vocab_vectors = np.eye(len(vocab_names)) diff --git a/examples/keras/mnist_spiking_cnn.py b/examples/keras/mnist_spiking_cnn.py index 7737a41..47496fa 100644 --- a/examples/keras/mnist_spiking_cnn.py +++ b/examples/keras/mnist_spiking_cnn.py @@ -15,7 +15,7 @@ import nengo from nengo_extras.keras import ( load_model_pair, save_model_pair, SequentialNetwork, SoftLIF) -from nengo_extras.gui import image_display_function +from nengo_extras.gui import PresentImages filename = 'mnist_spiking_cnn' @@ -88,19 +88,13 @@ model = nengo.Network() with model: - u = nengo.Node(nengo.processes.PresentInput(X_test, presentation_time)) + u = nengo.Node(PresentImages(X_test, presentation_time)) seq = SequentialNetwork(kmodel, synapse=nengo.synapses.Alpha(0.005)) nengo.Connection(u, seq.input, synapse=None) input_p = nengo.Probe(u) output_p = nengo.Probe(seq.output) - # --- image display - image_shape = kmodel.input_shape[1:] - display_f = image_display_function(image_shape) - display_node = nengo.Node(display_f, size_in=u.size_out) - nengo.Connection(u, display_node, synapse=None) - # --- output spa display vocab_names = ['ZERO', 'ONE', 'TWO', 'THREE', 'FOUR', 'FIVE', 'SIX', 'SEVEN', 'EIGHT', 'NINE'] diff --git a/examples/vision/mnist_single_layer_gui.py b/examples/vision/mnist_single_layer_gui.py index 642d6cb..74874dc 100644 --- a/examples/vision/mnist_single_layer_gui.py +++ b/examples/vision/mnist_single_layer_gui.py @@ -8,7 +8,7 @@ from nengo_extras.data import load_mnist from nengo_extras.vision import Gabor, Mask -from nengo_extras.gui import image_display_function +from nengo_extras.gui import PresentImages def one_hot(labels, c=None): @@ -57,7 +57,10 @@ def one_hot(labels, c=None): presentation_time = 0.1 with nengo.Network(seed=3) as model: - u = nengo.Node(nengo.processes.PresentInput(X_test, presentation_time)) + u = nengo.Node(PresentImages(X_test.reshape((-1, 1, 28, 28)), + presentation_time)) + u.output.configure_display(offset=1., scale=128.) + a = nengo.Ensemble(n_hid, n_vis, **ens_params) v = nengo.Node(size_in=n_out) nengo.Connection(u, a, synapse=None) @@ -65,12 +68,6 @@ def one_hot(labels, c=None): a, v, synapse=None, eval_points=X_train, function=train_targets, solver=solver) - # --- image display - image_shape = (1, 28, 28) - display_f = image_display_function(image_shape, offset=1, scale=128) - display_node = nengo.Node(display_f, size_in=u.size_out) - nengo.Connection(u, display_node, synapse=None) - # --- output spa display vocab_names = ['ZERO', 'ONE', 'TWO', 'THREE', 'FOUR', 'FIVE', 'SIX', 'SEVEN', 'EIGHT', 'NINE']