diff --git a/experimental/shoshin/models.py b/experimental/shoshin/models.py index 0ffa2f9f2..b9c76ed37 100644 --- a/experimental/shoshin/models.py +++ b/experimental/shoshin/models.py @@ -71,6 +71,63 @@ class ModelTrainingParameters: reweighting_error_percentile_threshold: Optional[float] = 0.2 +@register_model('two_tower_resnet50v2') +class TwoTowerResnet50v2(tf.keras.Model): + """Two tower model based on Resnet50v2.""" + + def __init__( + self, model_params: ModelTrainingParameters + ): + super(TwoTowerResnet50v2, self).__init__(name=model_params.model_name) + self.backbone = tf.keras.applications.resnet50.ResNet50( + include_top=False, + # classes=2, + weights='imagenet', # Also set to None. + input_shape=(RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE, 3), + # input_tensor=None, + pooling='avg', + ) + + inputs = tf.keras.Input((RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE, 6)) + before = inputs[:, :, :, :3] + after = inputs[:, :, :, 3:6] + after_crop = tf.image.resize( + tf.image.central_crop(after, 64 / RESNET_IMAGE_SIZE), + [RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE], + method='bilinear', + ) + before_embedding = self.backbone(before) + after_embedding = self.backbone(after) + after_crop_embedding = self.backbone(after_crop) + combined = tf.concat( + [before_embedding, after_embedding, after_crop_embedding], axis=1 + ) + outputs = tf.keras.layers.Dropout(0.5)(combined) + outputs = tf.keras.layers.Dense(units=256, activation='relu')(outputs) + outputs = tf.keras.layers.Dropout(0.5)(outputs) + outputs = tf.keras.layers.Dense(units=64, activation='relu')(outputs) + outputs = tf.keras.layers.Dropout(0.5)(outputs) + outputs = tf.keras.layers.Dense( + units=model_params.num_classes, activation='sigmoid' + )(outputs) + self.backbone.trainable = False + self.backbone.layers[-1].trainable = True + self.model = tf.keras.Model(inputs, outputs) + self.output_bias = tf.keras.layers.Dense( + model_params.num_classes, + trainable=model_params.train_bias, + activation='softmax', + name='bias', + ) + + def call(self, inputs): + x_1 = self.backbone(inputs[:,:, :, :3]) + x_2 = self.backbone(inputs[:,:, :, 3:]) + x = tf.concat([x_1, x_2], axis=-1) + out_bias = self.output_bias(x) + return {'main': self.model(inputs), 'bias': out_bias} + + @register_model('mlp') class MLP(tf.keras.Model): """Defines a MLP model class with two output heads.