Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added PresentImages process to show images in GUI #22

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions examples/cuda_convnet/cifar10_spiking_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from nengo_extras.data import load_cifar10
from nengo_extras.cuda_convnet import CudaConvnetNetwork, load_model_pickle
from nengo_extras.gui import image_display_function
from nengo_extras.gui import PresentImages

(X_train, y_train), (X_test, y_test), label_names = load_cifar10(label_names=True)
X_train = X_train.reshape(-1, 3, 32, 32).astype('float32')
Expand All @@ -30,19 +30,15 @@

model = nengo.Network()
with model:
u = nengo.Node(nengo.processes.PresentInput(X_test, presentation_time))
u = nengo.Node(PresentImages(X_test, presentation_time))
u.output.configure_display(offset=data_mean, scale=1.)

ccnet = CudaConvnetNetwork(cc_model, synapse=nengo.synapses.Alpha(0.005))
nengo.Connection(u, ccnet.inputs['data'], synapse=None)

input_p = nengo.Probe(u)
output_p = nengo.Probe(ccnet.output)

# --- image display
image_shape = X_test.shape[1:]
display_f = image_display_function(image_shape, scale=1, offset=data_mean)
display_node = nengo.Node(display_f, size_in=u.size_out)
nengo.Connection(u, display_node, synapse=None)

# --- output spa display
vocab_names = [s.upper() for s in label_names]
vocab_vectors = np.eye(len(vocab_names))
Expand Down
11 changes: 4 additions & 7 deletions examples/cuda_convnet/imagenet_spiking_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from nengo_extras.data import load_ilsvrc2012, spasafe_names
from nengo_extras.cuda_convnet import CudaConvnetNetwork, load_model_pickle
from nengo_extras.gui import image_display_function
from nengo_extras.gui import PresentImages

# retrieve from https://figshare.com/s/cdde71007405eb11a88f
filename = 'ilsvrc-2012-batches-test3.tar.gz'
Expand All @@ -34,18 +34,15 @@

model = nengo.Network()
with model:
u = nengo.Node(nengo.processes.PresentInput(X_test, presentation_time))
u = nengo.Node(PresentImages(X_test, presentation_time))
u.output.configure_display(offset=data_mean, scale=1.)

ccnet = CudaConvnetNetwork(cc_model, synapse=nengo.synapses.Alpha(0.001))
nengo.Connection(u, ccnet.inputs['data'], synapse=None)

# input_p = nengo.Probe(u)
output_p = nengo.Probe(ccnet.output)

# --- image display
display_f = image_display_function(image_shape, scale=1., offset=data_mean)
display_node = nengo.Node(display_f, size_in=u.size_out)
nengo.Connection(u, display_node, synapse=None)

# --- output spa display
vocab_names = spasafe_names(label_names)
vocab_vectors = np.eye(len(vocab_names))
Expand Down
10 changes: 2 additions & 8 deletions examples/keras/mnist_spiking_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import nengo
from nengo_extras.keras import (
load_model_pair, save_model_pair, SequentialNetwork, SoftLIF)
from nengo_extras.gui import image_display_function
from nengo_extras.gui import PresentImages

filename = 'mnist_spiking_cnn'

Expand Down Expand Up @@ -88,19 +88,13 @@

model = nengo.Network()
with model:
u = nengo.Node(nengo.processes.PresentInput(X_test, presentation_time))
u = nengo.Node(PresentImages(X_test, presentation_time))
seq = SequentialNetwork(kmodel, synapse=nengo.synapses.Alpha(0.005))
nengo.Connection(u, seq.input, synapse=None)

input_p = nengo.Probe(u)
output_p = nengo.Probe(seq.output)

# --- image display
image_shape = kmodel.input_shape[1:]
display_f = image_display_function(image_shape)
display_node = nengo.Node(display_f, size_in=u.size_out)
nengo.Connection(u, display_node, synapse=None)

# --- output spa display
vocab_names = ['ZERO', 'ONE', 'TWO', 'THREE', 'FOUR',
'FIVE', 'SIX', 'SEVEN', 'EIGHT', 'NINE']
Expand Down
13 changes: 5 additions & 8 deletions examples/vision/mnist_single_layer_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from nengo_extras.data import load_mnist
from nengo_extras.vision import Gabor, Mask
from nengo_extras.gui import image_display_function
from nengo_extras.gui import PresentImages


def one_hot(labels, c=None):
Expand Down Expand Up @@ -57,20 +57,17 @@ def one_hot(labels, c=None):
presentation_time = 0.1

with nengo.Network(seed=3) as model:
u = nengo.Node(nengo.processes.PresentInput(X_test, presentation_time))
u = nengo.Node(PresentImages(X_test.reshape((-1, 1, 28, 28)),
presentation_time))
u.output.configure_display(offset=1., scale=128.)

a = nengo.Ensemble(n_hid, n_vis, **ens_params)
v = nengo.Node(size_in=n_out)
nengo.Connection(u, a, synapse=None)
conn = nengo.Connection(
a, v, synapse=None,
eval_points=X_train, function=train_targets, solver=solver)

# --- image display
image_shape = (1, 28, 28)
display_f = image_display_function(image_shape, offset=1, scale=128)
display_node = nengo.Node(display_f, size_in=u.size_out)
nengo.Connection(u, display_node, synapse=None)

# --- output spa display
vocab_names = ['ZERO', 'ONE', 'TWO', 'THREE', 'FOUR',
'FIVE', 'SIX', 'SEVEN', 'EIGHT', 'NINE']
Expand Down
113 changes: 89 additions & 24 deletions nengo_extras/gui.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,113 @@
def preprocess_display(x, image_shape,
transpose=(1, 2, 0), scale=255., offset=0.):
from __future__ import absolute_import

import nengo

from .convnet import ShapeParam


def preprocess_display(x, transpose=(1, 2, 0), scale=255., offset=0.):
"""Basic preprocessing that reshapes, transposes, and scales an image"""
y = x.reshape(image_shape)
y = (y + offset) * scale
x = (x + offset) * scale
if transpose is not None:
y = y.transpose(transpose) # color channel last
if y.shape[-1] == 1:
y = y[..., 0]
return y.clip(0, 255).astype('uint8')
x = x.transpose(transpose) # color channel last
if x.shape[-1] == 1:
x = x[..., 0]
return x.clip(0, 255).astype('uint8')


def image_display_function(image_shape, preprocess=preprocess_display,
**preprocess_args):
"""Make a function to display images in Nengo GUI
def image_html_function(image_shape, preprocess=preprocess_display,
**preprocess_args):
"""Make a function to turn an image into HTML to display as an SVG.

Examples
--------
>>> u = nengo.Node(nengo.processes.PresentInput(images, 0.1))
>>> display_f = nengo_extras.gui.image_display_function(image_shape)
>>> display_node = nengo.Node(display_f, size_in=u.size_out)
>>> nengo.Connection(u, display_node, synapse=None)
Parameters
----------
image_shape : array_like (3,)
The shape of the image: (channels, height, width)
preprocess : callable
Callable that takes an image and preprocesses it to be displayed.
preprocess_args : dict
Optional dictionary of keyword arguments for ``preprocess``.

Requirements
------------
pillow (provides PIL, `pip install pillow`)
Returns
-------
html_function : callable (t, x)
A function that takes time and a flattened image, and returns a string
that defines an SVG object in HTML to display the image.
"""
import base64
import PIL.Image
import cStringIO

assert len(image_shape) == 3

def display_func(t, x):
y = preprocess(x, image_shape, **preprocess_args)
def html_function(x):
x = x.reshape(image_shape)
y = preprocess(x, **preprocess_args)
png = PIL.Image.fromarray(y)
buffer = cStringIO.StringIO()
png.save(buffer, format="PNG")
img_str = base64.b64encode(buffer.getvalue())

display_func._nengo_html_ = '''
return '''
<svg width="100%%" height="100%%" viewbox="0 0 100 100">
<image width="100%%" height="100%%"
xlink:href="data:image/png;base64,%s"
style="image-rendering: pixelated;">
</svg>''' % (''.join(img_str))

return html_function


def image_display_function(image_shape, preprocess=preprocess_display,
**preprocess_args):
"""Make a function to display images in Nengo GUI

Examples
--------
>>> u = nengo.Node(nengo.processes.PresentInput(images, 0.1))
>>> display_f = nengo_extras.gui.image_display_function(image_shape)
>>> display_node = nengo.Node(display_f, size_in=u.size_out)
>>> nengo.Connection(u, display_node, synapse=None)

Requirements
------------
pillow (provides PIL, `pip install pillow`)
"""
html_function = image_html_function(
image_shape, preprocess=preprocess, **preprocess_args)

def display_func(t, x):
display_func._nengo_html_ = html_function(x)

return display_func


class PresentImages(nengo.processes.PresentInput):
"""PresentInput process whose inputs are displayed as images in nengo_gui.
"""

image_shape = ShapeParam('image_shape', length=3, low=1)

def __init__(self, images, presentation_time, **kwargs):
self.image_shape = images.shape[1:]
super(PresentImages, self).__init__(
images, presentation_time, **kwargs)
self.configure_display()

def configure_display(self, preprocess=preprocess_display,
**preprocess_args):
"""Configure display parameters for images

Parameters
----------
preprocess : callable
Callable that takes an image and preprocesses it to be displayed.
preprocess_args : dict
Optional dictionary of keyword arguments for ``preprocess``.
"""
html_function = image_html_function(
self.image_shape, preprocess=preprocess, **preprocess_args)

def _nengo_html_(t, x):
return html_function(x)

self._nengo_html_ = _nengo_html_