forked from JeongsooP/Community-Forensics
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodels.py
More file actions
50 lines (47 loc) · 2.35 KB
/
models.py
File metadata and controls
50 lines (47 loc) · 2.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import torch.nn as nn
import timm
from huggingface_hub import PyTorchModelHubMixin
class ViTClassifier(nn.Module, PyTorchModelHubMixin):
    """Single-output classifier built on a pretrained timm Vision Transformer.

    The pretrained backbone's ``head`` is replaced with a freshly initialised
    ``nn.Linear(embed_dim, 1)``; ``forward`` returns the output of that head
    with no activation applied.
    """

    # timm checkpoint for every supported (model_size, input_size, patch_size).
    _TIMM_MODEL_NAMES = {
        ("small", 224, 32): 'vit_small_patch32_224.augreg_in21k_ft_in1k',
        ("small", 224, 16): 'vit_small_patch16_224.augreg_in21k_ft_in1k',
        ("small", 384, 32): 'vit_small_patch32_384.augreg_in21k_ft_in1k',
        ("small", 384, 16): 'vit_small_patch16_384.augreg_in21k_ft_in1k',
        ("tiny", 224, 16): 'vit_tiny_patch16_224.augreg_in21k_ft_in1k',
        ("tiny", 384, 16): 'vit_tiny_patch16_384.augreg_in21k_ft_in1k',
    }
    # in_features of the replacement head for each model size.
    _EMBED_DIMS = {"small": 384, "tiny": 192}

    def __init__(self,
                 model_size="small",
                 input_size=384,
                 patch_size=16,
                 freeze_backbone=False,
                 device='cuda', dtype=torch.float32):
        """
        ViT Classifier based on huggingface timm module.

        Args:
            model_size: "small" (head in_features 384) or "tiny" (192).
            input_size: input resolution of the pretrained checkpoint, 224 or 384.
            patch_size: 16 or 32 for "small"; only 16 is available for "tiny".
            freeze_backbone: if True, ``requires_grad`` is set to False on every
                pretrained backbone parameter; the new head stays trainable.
            device: device the backbone is moved to and the head is created on.
            dtype: dtype of the newly created head (the backbone is not cast).

        Raises:
            ValueError: if the (model_size, input_size, patch_size) combination
                has no corresponding pretrained checkpoint.
        """
        super().__init__()
        self.device = device
        self.dtype = dtype

        if model_size == "tiny" and patch_size != 16:
            raise ValueError("Only patch size 16 is available for ViT-Ti")
        model_name = self._TIMM_MODEL_NAMES.get((model_size, input_size, patch_size))
        if model_name is None:
            # Previously this fell through silently and crashed later with an
            # AttributeError on the never-assigned ``self.vit``.
            raise ValueError(
                f"Unsupported ViT configuration: model_size={model_size!r}, "
                f"input_size={input_size!r}, patch_size={patch_size!r}"
            )

        self.vit = timm.create_model(model_name, pretrained=True).to(device)
        if freeze_backbone:
            # Freeze before swapping the head so the new head remains trainable.
            for param in self.vit.parameters():
                param.requires_grad = False
        # NOTE(review): only the head receives ``dtype``; the backbone keeps its
        # default dtype. A non-float32 ``dtype`` presumably relies on autocast or
        # a later ``.to(dtype)`` by the caller — confirm.
        self.vit.head = nn.Linear(
            in_features=self._EMBED_DIMS[model_size],
            out_features=1,
            bias=True,
            device=device,
            dtype=dtype,
        )
        # Internal invariant: a freshly created nn.Linear is always trainable.
        for param in self.vit.head.parameters():
            assert param.requires_grad, "Model head should be trainable."

    def forward(self, x):
        """Run the ViT on ``x`` and return the output of the 1-unit head.

        ``x`` is presumably an image batch of shape
        (batch, channels, input_size, input_size) — TODO confirm against callers.
        """
        return self.vit(x)