-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
73 lines (69 loc) · 2.48 KB
/
utils.py
File metadata and controls
73 lines (69 loc) · 2.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from pathlib import Path
import clip
import cub_dataset
import imagenetv2_pytorch
import places365_dataset
import stanford_dogs_dataset
import torchvision
from torch.utils.data import DataLoader
# Factories for every supported ``config["dataset"]`` value. Each takes the
# dataset root and the image transform and returns the split that
# ``get_dataloader`` evaluates on. They are lambdas so a dataset class is only
# looked up and constructed when that dataset is actually requested.
_DATASET_FACTORIES = {
    "cub": lambda root, transform: cub_dataset.Cub2011(
        root=root, train=False, transform=transform
    ),
    "dtd": lambda root, transform: torchvision.datasets.DTD(
        root=root, split="test", transform=transform
    ),
    # NOTE(review): no split argument is passed here, so whatever EuroSAT
    # provides by default is evaluated — confirm that is intended.
    "eurosat": lambda root, transform: torchvision.datasets.EuroSAT(
        root=root, transform=transform
    ),
    "fgvc_aircraft": lambda root, transform: torchvision.datasets.FGVCAircraft(
        root=root, transform=transform, split="test"
    ),
    "flowers": lambda root, transform: torchvision.datasets.Flowers102(
        root=root, transform=transform, split="test"
    ),
    "food101": lambda root, transform: torchvision.datasets.Food101(
        root=root, transform=transform, split="test"
    ),
    "imagenet": lambda root, transform: torchvision.datasets.ImageNet(
        root=root, transform=transform, split="val"
    ),
    "imagenet_v2": lambda root, transform: imagenetv2_pytorch.ImageNetV2Dataset(
        location=root, transform=transform
    ),
    "pets": lambda root, transform: torchvision.datasets.OxfordIIITPet(
        root=root, transform=transform, split="test"
    ),
    "places365": lambda root, transform: places365_dataset.Places365(
        root=root, transform=transform, split="val"
    ),
    "stanford_cars": lambda root, transform: torchvision.datasets.StanfordCars(
        root=root, transform=transform, split="test"
    ),
    "stanford_dogs": lambda root, transform: stanford_dogs_dataset.StanfordDogs(
        root=root, transform=transform, split="test"
    ),
}


def get_dataloader(config):
    """Build an evaluation DataLoader for the dataset named in *config*.

    The image transform is the CLIP preprocessing pipeline returned by
    ``clip.load`` for ``config["backbone"]``; the CLIP model itself is
    discarded.

    Args:
        config: Mapping with the keys ``"dataset"`` (one of the names in
            ``_DATASET_FACTORIES``), ``"backbone"`` (passed to ``clip.load``),
            ``"root"`` (dataset location), and ``"batch_size"`` and
            ``"num_workers"`` (forwarded to ``DataLoader``).

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected split. No
        ``shuffle`` argument is passed, so DataLoader's default applies.

    Raises:
        ValueError: If ``config["dataset"]`` is not a supported name. This is
            checked before ``clip.load`` so a bad config fails fast, without
            loading a model.
    """
    name = config["dataset"]
    try:
        build_dataset = _DATASET_FACTORIES[name]
    except (KeyError, TypeError):
        # TypeError: an unhashable value can never match a supported name.
        supported = ", ".join(sorted(_DATASET_FACTORIES))
        raise ValueError(
            f"Unknown dataset {name!r}; expected one of: {supported}"
        ) from None

    # Only the preprocessing transform is needed. device="cpu" keeps this
    # from requiring a GPU just to obtain it.
    _, preprocess = clip.load(config["backbone"], device="cpu")
    dataset = build_dataset(config["root"], preprocess)
    return DataLoader(
        dataset=dataset,
        batch_size=config["batch_size"],
        num_workers=config["num_workers"],
    )