Skip to content

Commit c254432

Browse files
Add depthwise separable convolutions
1 parent 6cb28d0 commit c254432

6 files changed

Lines changed: 69 additions & 30 deletions

File tree

.vscode/settings.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@
1111
120
1212
],
1313
"cSpell.language": "en-GB",
14-
"yaml.format.printWidth": 120
14+
"yaml.format.printWidth": 200
1515
}

example_net.yaml

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -104,32 +104,19 @@ network:
104104
# This variable name if placed anywhere in the structure options will be replaced with the integer number of outputs
105105
# the dataset will produce
106106
structure:
107-
g1: { op: GraphConvolution, inputs: [X, G] }
108-
d1: { op: Dense, inputs: [g1], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
109-
g2: { op: GraphConvolution, inputs: [d1, G] }
110-
d2: { op: Dense, inputs: [g2], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
111-
g3: { op: GraphConvolution, inputs: [d2, G] }
112-
d3: { op: Dense, inputs: [g3], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
113-
g4: { op: GraphConvolution, inputs: [d3, G] }
114-
d4: { op: Dense, inputs: [g4], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
115-
g5: { op: GraphConvolution, inputs: [d4, G] }
116-
d5: { op: Dense, inputs: [g5], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
117-
g6: { op: GraphConvolution, inputs: [d5, G] }
118-
d6: { op: Dense, inputs: [g6], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
119-
g7: { op: GraphConvolution, inputs: [d6, G] }
120-
d7: { op: Dense, inputs: [g7], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
121-
g8: { op: GraphConvolution, inputs: [d7, G] }
122-
d8: { op: Dense, inputs: [g8], options: { units: 8, activation: selu, kernel_initializer: lecun_normal } }
123-
g9: { op: GraphConvolution, inputs: [d8, G] }
124-
d9: { op: Dense, inputs: [g9], options: { units: 8, activation: selu, kernel_initializer: lecun_normal } }
125-
g10: { op: GraphConvolution, inputs: [d9, G] }
126-
d10: { op: Dense, inputs: [g10], options: { units: 8, activation: selu, kernel_initializer: lecun_normal } }
127-
g11: { op: GraphConvolution, inputs: [d10, G] }
128-
d11: { op: Dense, inputs: [g11], options: { units: 8, activation: selu, kernel_initializer: lecun_normal } }
129-
g12: { op: GraphConvolution, inputs: [d11, G] }
130-
d12: { op: Dense, inputs: [g12], options: { units: 8, activation: selu, kernel_initializer: lecun_normal } }
131-
g13: { op: GraphConvolution, inputs: [d12, G] }
132-
output: { op: Dense, inputs: [g13], options: { units: $output_dims, activation: softmax } }
107+
l1: { op: GraphConvolution, inputs: [X, G], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
108+
l2: { op: GraphConvolution, inputs: [l1, G], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
109+
l3: { op: GraphConvolution, inputs: [l2, G], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
110+
l4: { op: GraphConvolution, inputs: [l3, G], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
111+
l5: { op: GraphConvolution, inputs: [l4, G], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
112+
l6: { op: GraphConvolution, inputs: [l5, G], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
113+
l7: { op: GraphConvolution, inputs: [l6, G], options: { units: 16, activation: selu, kernel_initializer: lecun_normal } }
114+
l8: { op: GraphConvolution, inputs: [l7, G], options: { units: 8, activation: selu, kernel_initializer: lecun_normal } }
115+
l9: { op: GraphConvolution, inputs: [l8, G], options: { units: 8, activation: selu, kernel_initializer: lecun_normal } }
116+
l10: { op: GraphConvolution, inputs: [l9, G], options: { units: 8, activation: selu, kernel_initializer: lecun_normal } }
117+
l11: { op: GraphConvolution, inputs: [l10, G], options: { units: 8, activation: selu, kernel_initializer: lecun_normal } }
118+
l12: { op: GraphConvolution, inputs: [l11, G], options: { units: 8, activation: selu, kernel_initializer: lecun_normal } }
119+
l13: { op: GraphConvolution, inputs: [l12, G], options: { units: $output_dims, activation: softmax } }
133120

134121
# Testing
135122
testing:

training/layer/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
1515

1616
from .graph_convolution import GraphConvolution
17+
from .depthwise_seperable_graph_convolution import DepthwiseSeparableGraphConvolution
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright (C) 2017-2020 Trent Houliston <trent@houliston.me>
2+
#
3+
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4+
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5+
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
6+
# permit persons to whom the Software is furnished to do so, subject to the following conditions:
7+
#
8+
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
9+
# Software.
10+
#
11+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
12+
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
13+
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
14+
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
15+
16+
import tensorflow as tf
17+
18+
19+
class Depthwise(tf.keras.layers.Layer):
20+
def __init__(self, **kwargs):
21+
super(Depthwise, self).__init__()
22+
self.pointwise = tf.keras.layers.Dense(**kwargs)
23+
24+
def build(self, input_shape):
25+
# Copy whatever we have on our pointwise kernel
26+
self.depthwise_weights = self.add_weight(
27+
"depthwise_kernel",
28+
input_shape[1:],
29+
dtype=self.dtype,
30+
initializer=self.pointwise.kernel_initializer,
31+
regularizer=self.pointwise.kernel_regularizer,
32+
constraint=self.pointwise.kernel_constraint,
33+
)
34+
35+
def call(self, X):
36+
depthwise = tf.einsum("ijk,jk->ik", X, self.depthwise_weights)
37+
return self.pointwise(depthwise)
38+
39+
40+
class DepthwiseSeparableGraphConvolution(tf.keras.layers.Layer):
41+
def __init__(self, **kwargs):
42+
super(DepthwiseSeparableGraphConvolution, self).__init__()
43+
self.depthwise = Depthwise(**kwargs)
44+
45+
def call(self, X, G):
46+
convolved = tf.reshape(tf.gather(X, G, name="NetworkGather"), shape=[-1, G.shape[-1], X.shape[-1]])
47+
return self.depthwise(convolved)

training/layer/graph_convolution.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818

1919
class GraphConvolution(tf.keras.layers.Layer):
2020
def __init__(self, **kwargs):
21-
super(GraphConvolution, self).__init__(**kwargs)
21+
super(GraphConvolution, self).__init__()
22+
self.dense = tf.keras.layers.Dense(**kwargs)
2223

2324
def call(self, X, G):
24-
return tf.reshape(tf.gather(X, G, name="NetworkGather"), shape=[-1, X.shape[-1] * G.shape[-1]])
25+
# Call the dense layer with the gathered data
26+
return self.dense(tf.reshape(tf.gather(X, G, name="NetworkGather"), shape=[-1, X.shape[-1] * G.shape[-1]]))

training/model/visual_mesh_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
1515

1616
import tensorflow as tf
17-
from training.layer import GraphConvolution
17+
from training.layer import GraphConvolution, DepthwiseSeparableGraphConvolution
1818

1919

2020
class VisualMeshModel(tf.keras.Model):
@@ -40,6 +40,8 @@ def _make_op(self, op, options):
4040

4141
if op == "GraphConvolution":
4242
return GraphConvolution(**options)
43+
elif op == "DepthwiseSeparableGraphConvolution":
44+
return DepthwiseSeparableGraphConvolution(**options)
4345
elif hasattr(tf.keras.layers, op):
4446
return getattr(tf.keras.layers, op)(**options)
4547
else:

0 commit comments

Comments
 (0)