-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathmodel.py
More file actions
70 lines (52 loc) · 2.68 KB
/
model.py
File metadata and controls
70 lines (52 loc) · 2.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from tensorflow.keras import layers
from tensorflow import keras
def get_model(
    moving_image_shape,
    fixed_image_shape,
    with_label_inputs=True,
    up_filters=(64, 128, 256),
    down_filters=(256, 128, 64, 32),
):
    """Build a residual 3D encoder-decoder mapping an image pair to a 3-channel volume.

    The moving and fixed images are concatenated along the last axis and passed
    through a strided entry convolution, one residual downsampling block per
    entry of ``up_filters``, and one residual upsampling block per entry of
    ``down_filters``. A final linear ``Conv3D`` produces the 3-channel output
    ``out_ddf`` (presumably a dense displacement field, one channel per spatial
    axis -- confirm against the training code).

    NOTE: despite their names, ``up_filters`` is consumed by the *downsampling*
    half of the network and ``down_filters`` by the *upsampling* half. The
    names are kept because they are part of the public signature.

    Args:
        moving_image_shape: Shape passed to ``keras.Input`` for the moving
            image (and the moving label). The ``Conv3D`` layers require a
            per-sample shape of four dimensions, e.g.
            ``(depth, height, width, channels)``.
        fixed_image_shape: Shape passed to ``keras.Input`` for the fixed image
            (and the fixed label). It must be concatenable with
            ``moving_image_shape`` along the last axis.
        with_label_inputs: When true, two extra inputs (moving label, fixed
            label) are created and appended to the model's input list. They
            are not connected to the output graph inside this function.
        up_filters: Iterable of filter counts, one per downsampling block.
            Each block halves the spatial resolution.
        down_filters: Iterable of filter counts, one per upsampling block.
            Each block doubles the spatial resolution.

    Returns:
        A ``keras.Model`` whose inputs are
        ``[moving_image, fixed_image]`` (followed by
        ``[moving_label, fixed_label]`` when ``with_label_inputs`` is true) and
        whose single output is ``out_ddf``.

    With the default filter lists the network downsamples by 2 four times
    (entry block + 3 blocks) and upsamples by 2 four times, so spatial sizes
    divisible by 16 come back at their original size.
    """
    input_moving_image = keras.Input(moving_image_shape)
    input_fixed_image = keras.Input(fixed_image_shape)
    model_inputs = [input_moving_image, input_fixed_image]
    if with_label_inputs:
        # Label inputs only widen the model's input signature; nothing below
        # reads them. Order matters to callers: moving label, then fixed label.
        model_inputs.append(keras.Input(moving_image_shape))
        model_inputs.append(keras.Input(fixed_image_shape))

    concatenate_layer = layers.Concatenate(axis=-1)(
        [input_moving_image, input_fixed_image]
    )

    ### [First half of the network: downsampling inputs] ###

    # Entry block: the stride-2 convolution performs the first halving.
    x = layers.Conv3D(32, 3, strides=2, padding="same")(concatenate_layer)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    # The downsampling blocks are identical apart from the feature depth.
    for filters in up_filters:
        # On the first iteration this ReLU follows the entry block's ReLU and
        # so does not change values; it is kept to leave the layer graph
        # exactly as it was.
        x = layers.Activation("relu")(x)
        x = layers.Conv3D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.Conv3D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling3D(3, strides=2, padding="same")(x)

        # Project residual: a strided 1x1x1 convolution matches both the
        # halved resolution and the new channel count before the add.
        residual = layers.Conv3D(filters, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    ### [Second half of the network: upsampling inputs] ###

    for filters in down_filters:
        x = layers.Activation("relu")(x)
        x = layers.Conv3DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.Conv3DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.UpSampling3D(2)(x)

        # Project residual: upsample first, then a 1x1x1 convolution to match
        # the new channel count before the add.
        residual = layers.UpSampling3D(2)(previous_block_activation)
        residual = layers.Conv3D(filters, 1, padding="same")(residual)
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    # Output head: 3 channels with a linear activation (a regression output,
    # not a classification layer).
    out_ddf = layers.Conv3D(3, 3, activation="linear", padding="same")(x)

    # Define the model
    return keras.Model(inputs=model_inputs, outputs=[out_ddf])