forked from OFSkean/FroSSL
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain_linear.py
More file actions
241 lines (210 loc) · 8.64 KB
/
main_linear.py
File metadata and controls
241 lines (210 loc) · 8.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
import inspect
import logging
import os
import hydra
import torch
import torch.nn as nn
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.strategies.ddp import DDPStrategy
from omegaconf import DictConfig, OmegaConf
from timm.data.mixup import Mixup
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from solo.args.linear import parse_cfg
from solo.data.classification_dataloader import prepare_data
from solo.data.pretrain_dataloader import NCropAugmentation, build_transform_pipeline
from solo.methods.bagoffeatures import BagOfFeaturesModel
from solo.methods.base import BaseMethod
from solo.methods.linear import LinearModel
from solo.utils.auto_resumer import AutoResumer
from solo.utils.checkpointer import Checkpointer
from solo.utils.misc import make_contiguous
# DALI support is optional: remember whether its data module could be imported
# so that main() can fail with a clear message when `data.format == "dali"`.
try:
    from solo.data.dali_dataloader import ClassificationDALIDataModule
except ImportError:
    _dali_avaliable = False
else:
    _dali_avaliable = True
@hydra.main(version_base="1.2")
def main(cfg: DictConfig):
    """Fit a classifier model on top of a backbone loaded from a checkpoint.

    Steps, in order:
      1. Parse the config and build the backbone named by ``cfg.backbone.name``.
      2. Load backbone weights from ``cfg.pretrained_feature_extractor`` or, when
         that is unset, from the path stored in ``last_ckpt.txt`` (the file is
         deleted after being read).
      3. Select the loss (mixup/cutmix, label smoothing, or plain cross entropy)
         and the model class (``BagOfFeaturesModel`` for ``empssl``, otherwise
         ``LinearModel``).
      4. Build the data loaders (optionally a DALI data module for training).
      5. Set up resuming, checkpointing and wandb logging, then run
         ``Trainer.fit``.
      6. Write the last checkpoint path to ``last_ckpt.txt``.

    Args:
        cfg: Hydra/OmegaConf configuration; it is passed through ``parse_cfg``
            before use.

    Raises:
        AssertionError: if the checkpoint path does not end in ``.ckpt``,
            ``.pth`` or ``.pt``, or if DALI is requested but unavailable or
            combined with auto augmentation.
        FileNotFoundError: if no checkpoint path is configured and
            ``last_ckpt.txt`` does not exist.
    """
    # hydra doesn't allow us to add new keys for "safety"
    # set_struct(..., False) disables this behavior and allows us to add more parameters
    # without making the user specify every single thing about the model
    OmegaConf.set_struct(cfg, False)
    cfg = parse_cfg(cfg)

    backbone_model = BaseMethod._BACKBONES[cfg.backbone.name]

    # initialize backbone
    backbone = backbone_model(method=cfg.pretrain_method, **cfg.backbone.kwargs)
    if cfg.backbone.name.startswith("resnet"):
        # remove fc layer
        backbone.fc = nn.Identity()
        cifar = cfg.data.dataset in ["cifar10", "cifar100"]
        if cifar:
            # CIFAR variant: 3x3 stride-1 stem and no max-pooling
            backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False)
            backbone.maxpool = nn.Identity()
        elif cfg.data.dataset[-3:] == "msi":  # TODO adapt kernel size, ..
            # datasets whose name ends in "msi" get a 13-input-channel stem
            backbone.conv1 = nn.Conv2d(13, 64, kernel_size=7, stride=2, padding=3, bias=False)

    pretrained_path = cfg.get("pretrained_feature_extractor", None)

    # if no path, read a path from the last_ckpt file
    if pretrained_path in (None, "None", ""):
        with open("last_ckpt.txt", "r") as f:
            pretrained_path = f.read().strip()

        # delete last_ckpt file
        os.remove("last_ckpt.txt")

    print(pretrained_path)
    assert pretrained_path.endswith(
        (".ckpt", ".pth", ".pt")
    ), f"Unsupported checkpoint extension: {pretrained_path}"

    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    state = torch.load(pretrained_path, map_location="cpu")["state_dict"]

    # Keep only backbone weights, with the "backbone." prefix stripped so the keys
    # match the bare backbone module; every original key is dropped afterwards.
    # NOTE(review): keys renamed from "encoder" to "backbone" are added to `state`
    # but are not part of the snapshot being iterated, so they keep their
    # "backbone." prefix -- confirm whether old checkpoints are still expected to load.
    for k in list(state.keys()):
        if "encoder" in k:
            state[k.replace("encoder", "backbone")] = state[k]
            # logging.warn is a deprecated alias (removed in Python 3.13)
            logging.warning(
                "You are using an older checkpoint. Use a new one as some issues might arrise."
            )
        if "backbone" in k:
            state[k.replace("backbone.", "")] = state[k]
        del state[k]
    # strict=False: missing/unexpected keys do not raise
    backbone.load_state_dict(state, strict=False)
    logging.info(f"Loaded {pretrained_path}")

    # check if mixup or cutmix is enabled
    mixup_func = None
    mixup_active = cfg.mixup > 0 or cfg.cutmix > 0
    if mixup_active:
        logging.info("Mixup activated")
        mixup_func = Mixup(
            mixup_alpha=cfg.mixup,
            cutmix_alpha=cfg.cutmix,
            cutmix_minmax=None,
            prob=1.0,
            switch_prob=0.5,
            mode="batch",
            label_smoothing=cfg.label_smoothing,
            num_classes=cfg.data.num_classes,
        )
        # smoothing is handled with mixup label transform
        loss_func = SoftTargetCrossEntropy()
    elif cfg.label_smoothing > 0:
        loss_func = LabelSmoothingCrossEntropy(smoothing=cfg.label_smoothing)
    else:
        loss_func = torch.nn.CrossEntropyLoss()

    if cfg.pretrain_method == "empssl":
        # empssl uses 20 crops of the first augmentation config for both
        # training and validation, together with the bag-of-features model
        aug_cfg = cfg.augmentations[0]
        train_pipeline = NCropAugmentation(
            build_transform_pipeline(cfg.data.dataset, aug_cfg), 20
        )
        val_pipeline = NCropAugmentation(
            build_transform_pipeline(cfg.data.dataset, aug_cfg), 20
        )
        modelClass = BagOfFeaturesModel
    else:
        modelClass = LinearModel
        train_pipeline = None
        val_pipeline = None

    model = modelClass(backbone, loss_func=loss_func, mixup_func=mixup_func, cfg=cfg)
    make_contiguous(model)
    # can provide up to ~20% speed up
    if not cfg.performance.disable_channel_last:
        model = model.to(memory_format=torch.channels_last)

    # with DALI, the torchvision loaders built here read the data as an image folder
    if cfg.data.format == "dali":
        val_data_format = "image_folder"
    else:
        val_data_format = cfg.data.format

    train_loader, val_loader = prepare_data(
        cfg.data.dataset,
        train_data_path=cfg.data.train_path,
        val_data_path=cfg.data.val_path,
        data_format=val_data_format,
        batch_size=cfg.optimizer.batch_size,
        num_workers=cfg.data.num_workers,
        auto_augment=cfg.auto_augment,
        train_pipeline=train_pipeline,
        val_pipeline=val_pipeline,
        precompute_features=cfg.precompute,
        model=model,
    )

    if cfg.data.format == "dali":
        assert (
            _dali_avaliable
        ), "Dali is not currently avaiable, please install it first with pip3 install .[dali]."

        assert not cfg.auto_augment, "Auto augmentation is not supported with Dali."

        dali_datamodule = ClassificationDALIDataModule(
            dataset=cfg.data.dataset,
            train_data_path=cfg.data.train_path,
            val_data_path=cfg.data.val_path,
            num_workers=cfg.data.num_workers,
            batch_size=cfg.optimizer.batch_size,
            data_fraction=cfg.data.fraction,
            dali_device=cfg.dali.device,
        )

        # use normal torchvision dataloader for validation to save memory
        dali_datamodule.val_dataloader = lambda: val_loader

    # 1.7 will deprecate resume_from_checkpoint, but for the moment
    # the argument is the same, but we need to pass it as ckpt_path to trainer.fit
    ckpt_path, wandb_run_id = None, None
    if cfg.auto_resume.enabled and cfg.resume_from_checkpoint is None:
        auto_resumer = AutoResumer(
            checkpoint_dir=os.path.join(cfg.checkpoint.dir, "linear"),
            max_hours=cfg.auto_resume.max_hours,
        )
        resume_from_checkpoint, wandb_run_id = auto_resumer.find_checkpoint(cfg)
        if resume_from_checkpoint is not None:
            print(
                "Resuming from previous checkpoint that matches specifications:",
                f"'{resume_from_checkpoint}'",
            )
            ckpt_path = resume_from_checkpoint
    elif cfg.resume_from_checkpoint is not None:
        ckpt_path = cfg.resume_from_checkpoint
        del cfg.resume_from_checkpoint

    callbacks = []

    # Bound up front: `ckpt` is inspected after training even when checkpointing
    # is disabled (it used to be undefined in that case -> NameError).
    ckpt = None
    if cfg.checkpoint.enabled:
        ckpt = Checkpointer(
            cfg,
            logdir=os.path.join(cfg.checkpoint.dir, "linear"),
            frequency=cfg.checkpoint.frequency,
            keep_prev=cfg.checkpoint.keep_prev,
        )
        callbacks.append(ckpt)

    # wandb logging
    wandb_logger = None
    if cfg.wandb.enabled:
        wandb_logger = WandbLogger(
            name=cfg.name,
            project=cfg.wandb.project,
            entity=cfg.wandb.entity,
            offline=cfg.wandb.offline,
            resume="allow" if wandb_run_id else None,
            id=wandb_run_id,
        )
        wandb_logger.watch(model, log="gradients", log_freq=100)
        wandb_logger.log_hyperparams(OmegaConf.to_container(cfg))

        # lr logging
        lr_monitor = LearningRateMonitor(logging_interval="step")
        callbacks.append(lr_monitor)

    trainer_kwargs = OmegaConf.to_container(cfg)
    # we only want to pass in valid Trainer args, the rest may be user specific
    valid_kwargs = inspect.signature(Trainer.__init__).parameters
    trainer_kwargs = {name: trainer_kwargs[name] for name in valid_kwargs if name in trainer_kwargs}
    trainer_kwargs.update(
        {
            "logger": wandb_logger,
            "callbacks": callbacks,
            # checkpointing is done by the Checkpointer callback (when enabled)
            "enable_checkpointing": False,
            "strategy": DDPStrategy(find_unused_parameters=False)
            if cfg.strategy == "ddp"
            else cfg.strategy,
        }
    )
    trainer = Trainer(**trainer_kwargs)

    if cfg.data.format == "dali":
        trainer.fit(model, ckpt_path=ckpt_path, datamodule=dali_datamodule)
    else:
        trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)

    # Record the checkpoint path; a later run with no configured path reads this file.
    # Falls back to the resume path (possibly "None") when no Checkpointer
    # exists or it has no `last_ckpt` attribute.
    with open("last_ckpt.txt", "w") as f:
        if hasattr(ckpt, "last_ckpt"):
            f.write(str(ckpt.last_ckpt))
        else:
            f.write(str(ckpt_path))
if __name__ == "__main__":
    # main is wrapped by @hydra.main, so it is invoked without arguments here.
    main()