From 6a47472dfb297175b6ca02c6602b38aa259d34ed Mon Sep 17 00:00:00 2001 From: Uncertainty Baselines Team Date: Wed, 31 Aug 2022 15:23:52 -0700 Subject: [PATCH] Internal. PiperOrigin-RevId: 471370306 --- uncertainty_baselines/models/__init__.py | 1 + uncertainty_baselines/models/clip.py | 490 +++++++++++++++++++++++ 2 files changed, 491 insertions(+) create mode 100644 uncertainty_baselines/models/clip.py diff --git a/uncertainty_baselines/models/__init__.py b/uncertainty_baselines/models/__init__.py index 0e2a603a1..48a45373f 100644 --- a/uncertainty_baselines/models/__init__.py +++ b/uncertainty_baselines/models/__init__.py @@ -24,6 +24,7 @@ # break the external build. # ============================================================================== from uncertainty_baselines.models import efficientnet_utils +from uncertainty_baselines.models.clip import clip from uncertainty_baselines.models.criteo_mlp import criteo_mlp from uncertainty_baselines.models.efficientnet import efficientnet from uncertainty_baselines.models.efficientnet_batch_ensemble import efficientnet_batch_ensemble diff --git a/uncertainty_baselines/models/clip.py b/uncertainty_baselines/models/clip.py new file mode 100644 index 000000000..e4e69cd08 --- /dev/null +++ b/uncertainty_baselines/models/clip.py @@ -0,0 +1,490 @@ +# coding=utf-8 +# Copyright 2022 The Uncertainty Baselines Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OpenAI's CLIP models in Flax. + +Originally ported from +https://github.com/google-research/scenic/blob/main/scenic/projects/baselines/clip/layers.py. + +This version was adapted from +https://github.com/google-research/scenic/blob/main/scenic/projects/baselines/clip/layers.py. + +[1] Alec Radford, Jong Wook Kim, et al. + Learning Transferable Visual Models From Natural Language Supervision. + https://arxiv.org/abs/2103.00020. +""" + +import functools +from typing import Sequence, Optional, Union, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp + +# TODO(scenic): Make initialization of all layers identical to official one. +# Note: this doesn't matter for loading pretrained models. + +# Match PyTorch default LayerNorm epsilon of 1e-5 (FLAX defaults to 1e-6). +LayerNorm = functools.partial(nn.LayerNorm, epsilon=1e-5) + + +def quick_gelu(x: jnp.ndarray) -> jnp.ndarray: + return x * jax.nn.sigmoid(1.702 * x) + + +class Shortcut(nn.Module): + """Shortcut in ResNet. + + Attributes: + features: Number of features. + stride: Stride of the down-sampled output. + """ + features: int + stride: int + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + x = nn.avg_pool(x, (self.stride, self.stride), (self.stride, self.stride)) + x = nn.Conv( + self.features, (1, 1), strides=(1, 1), use_bias=False, name='0')(x) + x = nn.BatchNorm(use_running_average=True, name='1')(x) + return x + + +class Bottleneck(nn.Module): + """Bottleneck layer of ResNet. + + Attributes: + features: Number of features. + stride: Stride of the down-sampled output. + expansion: Expansion of feature dimension. + """ + features: int + stride: int = 1 + expansion: int = 4 + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + conv1 = nn.Conv(self.features, (1, 1), use_bias=False, name='conv1') + bn1 = nn.BatchNorm(use_running_average=True, name='bn1') + + conv2 = nn.Conv(self.features, (3, 3), padding=[(1, 1), (1, 1)], + use_bias=False, name='conv2') + bn2 = nn.BatchNorm(use_running_average=True, name='bn2') + + conv3 = nn.Conv( + self.features * self.expansion, (1, 1), use_bias=False, name='conv3') + bn3 = nn.BatchNorm(use_running_average=True, name='bn3') + + out = nn.relu(bn1(conv1(x))) + out = nn.relu(bn2(conv2(out))) + out = nn.avg_pool(out, (self.stride, self.stride), + (self.stride, self.stride)) + out = bn3(conv3(out)) + + downsample = self.stride > 1 or x.shape[-1] != self.features * self.expansion + if downsample: + x = Shortcut(features=self.features * self.expansion, + stride=self.stride, name='downsample')(x) + + out += x + out = nn.relu(out) + return out + + +class AttentionPool(nn.Module): + """Attention pooling layer. + + Attributes: + num_heads: Number of heads. + features: Number of features. + """ + num_heads: int + features: Optional[int] = None + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + x = x.reshape(x.shape[0], -1, x.shape[3]) + + x = jnp.concatenate([x.mean(axis=1, keepdims=True), x], axis=1) + + positional_embedding = self.param( + 'positional_embedding', + jax.nn.initializers.normal(1. / x.shape[-1]**0.5), + (x.shape[1], x.shape[2])) + attn = nn.MultiHeadDotProductAttention( + self.num_heads, + qkv_features=x.shape[-1], + use_bias=True, + out_features=self.features, + name='attn') + + x = x + positional_embedding[jnp.newaxis].astype(x.dtype) + x = attn(x[:, :1], x) + return x[:, 0] + + +class ResNetStage(nn.Module): + """Attention pooling layer. + + Attributes: + features: Number of features. + num_layers: Number of bottleneck blocks. + stride: Stride in the Bottleneck module. + """ + features: int + num_layers: int + stride: int = 1 + + @nn.compact + def __call__(self, x: jnp.array) -> jnp.ndarray: + x = Bottleneck(self.features, self.stride, name='0')(x) + for i in range(1, self.num_layers): + x = Bottleneck(self.features, name=str(i))(x) + return x + + +class ModifiedResNet(nn.Module): + """A ResNet class that is similar to torchvision's with changes. + + - There are now 3 "stem" convolutions as opposed to 1, with an average pool + instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is + prepended to convolutions with stride > 1 - The final pooling layer is a + QKV attention instead of an average pool. + + Attributes: + features: Number of features. + out_features: Number of output features. If None, return resnet feature-map. + num_layers: Number of layers for each block. + num_heads: Number of heads. + """ + features: int + out_features: Optional[int] + num_layers: Sequence[int] + num_heads: Optional[int] + + def setup(self): + # The 3-layer stem. + self.conv1 = nn.Conv( + self.features // 2, + kernel_size=(3, 3), + strides=(2, 2), + padding=[(1, 1), (1, 1)], + use_bias=False, + name='conv1') + self.bn1 = nn.BatchNorm(use_running_average=True, name='bn1') + self.conv2 = nn.Conv( + self.features // 2, + kernel_size=(3, 3), + padding=[(1, 1), (1, 1)], + use_bias=False, + name='conv2') + self.bn2 = nn.BatchNorm(use_running_average=True, name='bn2') + self.conv3 = nn.Conv( + self.features, + kernel_size=(3, 3), + padding=[(1, 1), (1, 1)], + use_bias=False, + name='conv3') + self.bn3 = nn.BatchNorm(use_running_average=True, name='bn3') + + # Residual layers. + self.layer1 = ResNetStage(self.features, self.num_layers[0], name='layer1') + self.layer2 = ResNetStage( + self.features * 2, self.num_layers[1], stride=2, name='layer2') + self.layer3 = ResNetStage( + self.features * 4, self.num_layers[2], stride=2, name='layer3') + self.layer4 = ResNetStage( + self.features * 8, self.num_layers[3], stride=2, name='layer4') + if self.out_features is not None: + self.attnpool = AttentionPool( + self.num_heads, self.out_features, name='attnpool') + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), + (self.conv3, self.bn3)]: + x = nn.relu(bn(conv(x))) + x = nn.avg_pool(x, (2, 2), (2, 2)) + return x + + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = feature_map = self.layer4(x) + + if self.out_features is not None: + x = self.attnpool(x) + + return x, feature_map + + +class MLP(nn.Module): + """Simple MLP for Transformer.""" + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + ch = x.shape[-1] + x = nn.Dense(4 * ch, name='c_fc')(x) + x = quick_gelu(x) + x = nn.Dense(ch, name='c_proj')(x) + return x + + +class ResidualAttentionBlock(nn.Module): + """Self-attention block of Transformer. + + Attributes: + num_heads: Number of heads. + """ + num_heads: int + + @nn.compact + def __call__(self, x: jnp.ndarray, attn_mask=None) -> jnp.ndarray: + xn = LayerNorm(name='ln_1')(x) + x = x + nn.SelfAttention( + self.num_heads, name='attn', deterministic=True)(xn, attn_mask) + xn = LayerNorm(name='ln_2')(x) + x = x + MLP(name='mlp')(xn) + return x + + +class Transformer(nn.Module): + """Transformer module. + + Attributes: + features: Number of features. + num_layers: Number of layers for each block. + num_heads: Number of heads. + """ + features: int + num_layers: int + num_heads: int + + @nn.compact + def __call__(self, + x: jnp.ndarray, + attn_mask: Optional[jnp.ndarray] = None) -> jnp.ndarray: + for i in range(self.num_layers): + x = ResidualAttentionBlock( + num_heads=self.num_heads, name=f'resblocks.{i}')(x, attn_mask) + return x + + +class VisionTransformer(nn.Module): + """Vision Transformer. + + Attributes: + patch_size: The size of the patches to embed. + features: Number of features. + num_layers: Number of transformer blocks (self-attn + MLP). + num_heads: Number of attention heads. + out_features: Number of output features. If None, return transformer output. + """ + patch_size: int + features: int + num_layers: int + num_heads: int + out_features: int + + @nn.compact + def __call__(self, + x: jnp.ndarray, + attn_mask: Optional[jnp.ndarray] = None) -> jnp.ndarray: + x = nn.Conv(self.features, + kernel_size=(self.patch_size, self.patch_size), + strides=(self.patch_size, self.patch_size), + use_bias=False, name='conv1')(x) + x = x.reshape(x.shape[0], -1, x.shape[-1]) + scale = 1.0 / jnp.sqrt(self.features) + class_embedding = self.param('class_embedding', + jax.nn.initializers.normal(stddev=scale), + (self.features,)) + x = jnp.concatenate((jnp.tile(class_embedding[None, None, :], + (x.shape[0], 1, 1)), x), + axis=1) + positional_embedding = self.param('positional_embedding', + jax.nn.initializers.normal(stddev=scale), + (x.shape[1], self.features)) + x = x + positional_embedding[None] + + x = LayerNorm(name='ln_pre')(x) + x = Transformer( + self.features, self.num_layers, self.num_heads, name='transformer')(x) + x = LayerNorm(name='ln_post')(x[:, 0]) + x = nn.Dense(self.out_features, use_bias=False, name='proj')(x) + return x + + +class TextEncoder(nn.Module): + """Text Transformer. + + Attributes: + vocab_size: Size of the vocabulary. + features: Number of features. + num_layers: Number of transformer blocks (self-attn + MLP). + num_heads: Number of attention heads. + out_features: Size of the final text embedding. + """ + vocab_size: int + features: int + num_layers: int + num_heads: int + out_features: int + + @nn.compact + def __call__(self, text: jnp.ndarray) -> jnp.ndarray: + positional_embedding = self.param('positional_embedding', + jax.nn.initializers.zeros, + (text.shape[1], self.features)) + mask = nn.combine_masks( + nn.make_attention_mask(text > 0, text > 0), nn.make_causal_mask(text)) + x = nn.Embed(self.vocab_size, self.features, name='token_embedding')(text) + x = x + positional_embedding[None] + x = Transformer( + self.features, self.num_layers, self.num_heads, name='transformer')( + x, attn_mask=mask) + x = LayerNorm(name='ln_final')(x) + x = x[jnp.arange(x.shape[0]), text.argmax(-1)] + x = nn.Dense(self.out_features, use_bias=False, name='text_projection')(x) + return x + + +class CLIP(nn.Module): + """CLIP model consisting of a vision and text transformer. + + Attributes: + vocab_size: Size of the vocabulary. + embed_dim: Size of the text and vision embeddings. + text_features: Number of features in text transformer. + text_num_layers: Number of text transformer blocks (self-attn + MLP). + text_num_heads: Number of heads in text transformer. + vision_features: Number of features in vision transformer. + vision_num_layers: Number of vision transformer blocks (self-attn + MLP). + vision_patch_size: Size of patches to embed in vision transformer. + """ + vocab_size: int + embed_dim: int + text_features: int + text_num_layers: int + text_num_heads: int + vision_features: int + vision_num_layers: Union[int, Sequence[int]] + vision_patch_size: Optional[int] = None + + def setup(self): + if isinstance(self.vision_num_layers, (tuple, list)): + self.vision_num_heads = self.vision_features * 32 // 64 + self.visual = ModifiedResNet( + num_layers=self.vision_num_layers, + features=self.vision_features, + num_heads=self.vision_num_heads, + out_features=self.embed_dim) + else: + self.vision_num_heads = self.vision_features // 64 + self.visual = VisionTransformer( + patch_size=self.vision_patch_size, + features=self.vision_features, + num_layers=self.vision_num_layers, + num_heads=self.vision_num_heads, + out_features=self.embed_dim) + self.text = TextEncoder( + out_features=self.embed_dim, + vocab_size=self.vocab_size, + features=self.text_features, + num_layers=self.text_num_layers, + num_heads=self.text_num_heads) + self.logit_scale = self.param('logit_scale', jax.nn.initializers.ones, ()) + + def encode_image(self, + image: jnp.ndarray, + normalize: bool = True, + scale_logits: bool = True) -> jnp.ndarray: + x = self.visual(image) + if normalize: + x /= jnp.linalg.norm(x, axis=-1, keepdims=True) + if scale_logits: + # Note: we use sqrt(t) rather than t, because we apply the scaling to zimg + # and ztxt individually, rather than to logits = jnp.dot(zimg, ztxt.T) as + # in the Scenic implementation. + # jnp.dot(sqrt(t)*zimg, sqrt(t)*ztxt.T) = jnp.dot(zimg, ztxt.T) * t + x *= jnp.sqrt(jnp.exp(self.logit_scale)) + return x + + def encode_text(self, + text: jnp.ndarray, + normalize: bool = True, + scale_logits: bool = True) -> jnp.ndarray: + x = self.text(text) + if normalize: + x /= jnp.linalg.norm(x, axis=-1, keepdims=True) + if scale_logits: + x *= jnp.sqrt(jnp.exp(self.logit_scale)) + return x + + def __call__(self, + image: jnp.ndarray, + text: jnp.ndarray, + normalize: bool = True, + scale_logits: bool = True, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + x = y = None + if image is not None: + x = self.encode_image(image, normalize, scale_logits) + if text is not None: + y = self.encode_text(text, normalize, scale_logits) + return x, y + + +def clip( + vocab_size: int, + embed_dim: int, + text_features: int, + text_num_layers: int, + text_num_heads: int, + vision_features: int, + vision_num_layers: Union[int, Sequence[int]], + vision_patch_size: Optional[int] = None, +): + """Builds a CLIP (Contrastive Language-Image Pre-Training) model. + + The model consists of a vision and a text component. The vision component is + a ResNet or a Vision Transformer. The text component is a text Transformer. + + Args: + vocab_size: Size of the vocabulary. + embed_dim: Size of the text and vision embeddings. + text_features: Number of features in text transformer. + text_num_layers: Number of text transformer blocks (self-attn + MLP). + text_num_heads: Number of heads in text transformer. + vision_features: Number of features in vision transformer. + vision_num_layers: Number of vision transformer blocks (self-attn + MLP). + vision_patch_size: Size of patches to embed in vision transformer. + + Returns: + flax.linen.Module. + """ + return CLIP( + vocab_size=vocab_size, + embed_dim=embed_dim, + text_features=text_features, + text_num_layers=text_num_layers, + text_num_heads=text_num_heads, + vision_features=vision_features, + vision_num_layers=vision_num_layers, + vision_patch_size=vision_patch_size)