Skip to content

Commit

Permalink
Add depthwise separable convolutions
Browse files Browse the repository at this point in the history
  • Loading branch information
TrentHouliston committed Sep 4, 2020
1 parent 6cb28d0 commit c254432
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@
120
],
"cSpell.language": "en-GB",
"yaml.format.printWidth": 120
"yaml.format.printWidth": 200
}
39 changes: 13 additions & 26 deletions example_net.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -104,32 +104,19 @@ network:
# This variable name if placed anywhere in the structure options will be replaced with the integer number of outputs
# the dataset will produce
structure:
g1: { op: GraphConvolution, inputs: [X, G] }
d1: { op: Dense, inputs: [g1], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
g2: { op: GraphConvolution, inputs: [d1, G] }
d2: { op: Dense, inputs: [g2], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
g3: { op: GraphConvolution, inputs: [d2, G] }
d3: { op: Dense, inputs: [g3], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
g4: { op: GraphConvolution, inputs: [d3, G] }
d4: { op: Dense, inputs: [g4], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
g5: { op: GraphConvolution, inputs: [d4, G] }
d5: { op: Dense, inputs: [g5], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
g6: { op: GraphConvolution, inputs: [d5, G] }
d6: { op: Dense, inputs: [g6], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
g7: { op: GraphConvolution, inputs: [d6, G] }
d7: { op: Dense, inputs: [g7], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
g8: { op: GraphConvolution, inputs: [d7, G] }
d8: { op: Dense, inputs: [g8], options: { units: 8, activation: selu, kernel_initializer: lecun_normal } }
g9: { op: GraphConvolution, inputs: [d8, G] }
d9: { op: Dense, inputs: [g9], options: { units: 8, activation: selu, kernel_initializer: lecun_normal } }
g10: { op: GraphConvolution, inputs: [d9, G] }
d10: { op: Dense, inputs: [g10], options: { units: 8, activation: selu, kernel_initializer: lecun_normal } }
g11: { op: GraphConvolution, inputs: [d10, G] }
d11: { op: Dense, inputs: [g11], options: { units: 8, activation: selu, kernel_initializer: lecun_normal } }
g12: { op: GraphConvolution, inputs: [d11, G] }
d12: { op: Dense, inputs: [g12], options: { units: 8, activation: selu, kernel_initializer: lecun_normal } }
g13: { op: GraphConvolution, inputs: [d12, G] }
output: { op: Dense, inputs: [g13], options: { units: $output_dims, activation: softmax } }
l1: { op: GraphConvolution, inputs: [X, G], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
l2: { op: GraphConvolution, inputs: [l1, G], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
l3: { op: GraphConvolution, inputs: [l2, G], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
l4: { op: GraphConvolution, inputs: [l3, G], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
l5: { op: GraphConvolution, inputs: [l4, G], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
l6: { op: GraphConvolution, inputs: [l5, G], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
l7: { op: GraphConvolution, inputs: [l6, G], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
l8: { op: GraphConvolution, inputs: [l7, G], options: { units: 8, activation: selu, kernel_initializer: lecun_normal } }
l9: { op: GraphConvolution, inputs: [l8, G], options: { units: 8, activation: selu, kernel_initializer: lecun_normal } }
l10: { op: GraphConvolution, inputs: [l9, G], options: { units: 8, activation: selu, kernel_initializer: lecun_normal } }
l11: { op: GraphConvolution, inputs: [l10, G], options: { units: 8, activation: selu, kernel_initializer: lecun_normal } }
l12: { op: GraphConvolution, inputs: [l11, G], options: { units: 8, activation: selu, kernel_initializer: lecun_normal } }
l13: { op: GraphConvolution, inputs: [l12, G], options: { units: $output_dims, activation: softmax } }

# Testing
testing:
Expand Down
1 change: 1 addition & 0 deletions training/layer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from .graph_convolution import GraphConvolution
from .depthwise_seperable_graph_convolution import DepthwiseSeparableGraphConvolution
47 changes: 47 additions & 0 deletions training/layer/depthwise_seperable_graph_convolution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (C) 2017-2020 Trent Houliston <[email protected]>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import tensorflow as tf


class Depthwise(tf.keras.layers.Layer):
    """Depthwise stage of a depthwise-separable graph convolution.

    Applies a learned per-position weighting over the gathered neighbourhood
    axis (the depthwise step), then projects channels with a shared Dense
    layer (the pointwise step).
    """

    def __init__(self, **kwargs):
        # All keyword arguments (units, activation, kernel_initializer, ...)
        # configure the pointwise Dense layer; the depthwise kernel reuses
        # that layer's initializer/regularizer/constraint in build().
        super(Depthwise, self).__init__()
        self.pointwise = tf.keras.layers.Dense(**kwargs)

    def build(self, input_shape):
        # One depthwise weight per (neighbour, channel) position.
        # NOTE(review): assumes a rank-3 input of shape
        # (batch, neighbours, channels), as produced by
        # DepthwiseSeparableGraphConvolution.call — confirm for other callers.
        self.depthwise_weights = self.add_weight(
            "depthwise_kernel",
            input_shape[1:],
            dtype=self.dtype,
            initializer=self.pointwise.kernel_initializer,
            regularizer=self.pointwise.kernel_regularizer,
            constraint=self.pointwise.kernel_constraint,
        )
        # Mark the layer as built; Keras requires this when overriding build().
        super(Depthwise, self).build(input_shape)

    def call(self, X):
        # Weighted sum over the neighbourhood axis j: (i,j,k)·(j,k) -> (i,k)
        depthwise = tf.einsum("ijk,jk->ik", X, self.depthwise_weights)
        # Pointwise projection across channels.
        return self.pointwise(depthwise)


class DepthwiseSeparableGraphConvolution(tf.keras.layers.Layer):
    """Graph convolution built from a depthwise-separable operation.

    Gathers each node's neighbourhood from ``X`` using the index tensor ``G``
    and feeds it through a ``Depthwise`` layer (per-neighbour depthwise
    weighting followed by a pointwise Dense projection).
    """

    def __init__(self, **kwargs):
        super(DepthwiseSeparableGraphConvolution, self).__init__()
        # Every option (units, activation, ...) is forwarded to the
        # separable stage, which passes them on to its pointwise Dense layer.
        self.depthwise = Depthwise(**kwargs)

    def call(self, X, G):
        # Gather neighbourhood features, then shape them as
        # (batch, neighbours, channels) for the depthwise stage.
        gathered = tf.gather(X, G, name="NetworkGather")
        neighbourhood = tf.reshape(gathered, shape=[-1, G.shape[-1], X.shape[-1]])
        return self.depthwise(neighbourhood)
6 changes: 4 additions & 2 deletions training/layer/graph_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

class GraphConvolution(tf.keras.layers.Layer):
    """Plain graph convolution.

    Gathers each node's neighbourhood from ``X`` using the index tensor ``G``,
    flattens the neighbourhood features, and applies a single Dense layer.
    """

    def __init__(self, **kwargs):
        # Keyword arguments (units, activation, kernel_initializer, ...)
        # configure the wrapped Dense layer, not the Layer base class, so
        # they must not be forwarded to super().__init__.
        super(GraphConvolution, self).__init__()
        self.dense = tf.keras.layers.Dense(**kwargs)

    def call(self, X, G):
        # Gather neighbourhood features via G and flatten each neighbourhood
        # to (batch, neighbours * channels) before the dense projection.
        return self.dense(tf.reshape(tf.gather(X, G, name="NetworkGather"), shape=[-1, X.shape[-1] * G.shape[-1]]))
4 changes: 3 additions & 1 deletion training/model/visual_mesh_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import tensorflow as tf
from training.layer import GraphConvolution
from training.layer import GraphConvolution, DepthwiseSeparableGraphConvolution


class VisualMeshModel(tf.keras.Model):
Expand All @@ -40,6 +40,8 @@ def _make_op(self, op, options):

if op == "GraphConvolution":
return GraphConvolution(**options)
elif op == "DepthwiseSeparableGraphConvolution":
return DepthwiseSeparableGraphConvolution(**options)
elif hasattr(tf.keras.layers, op):
return getattr(tf.keras.layers, op)(**options)
else:
Expand Down

0 comments on commit c254432

Please sign in to comment.