-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
65 lines (52 loc) · 2.06 KB
/
utils.py
File metadata and controls
65 lines (52 loc) · 2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from typing import Literal
import numpy as np
from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
def normalize_fast(x: np.ndarray, pmin: float = 1, pmax: float = 99.8):
    """Percentile-normalize a 4-D array independently per entry of its first two axes.

    For every index pair along axes 0 and 1, the ``pmin`` and ``pmax``
    percentiles are estimated over axes 2 and 3 and the data is shifted and
    scaled so that those percentiles map to roughly 0 and 1. Values outside
    that range are not clipped.

    Args:
        x: 4-D array whose second axis has at most 3 entries (presumably the
            channel axis of an (N, C, H, W) batch -- confirm with callers).
        pmin: Lower percentile, mapped to approximately 0.
        pmax: Upper percentile, mapped to approximately 1.

    Returns:
        A new ``float32`` array with the same shape as ``x``; the input array
        is left untouched.

    Raises:
        ValueError: If ``x`` is not 4-D or its second axis is longer than 3.
    """
    # The check must reject the input when EITHER condition fails. The former
    # `not x.ndim==4 and x.shape[1]<=3` only fired when both did, and it
    # indexed shape[1] even for 1-D input.
    if x.ndim != 4 or x.shape[1] > 3:
        raise ValueError(f"x.ndim = {x.ndim} but expected 4 and channel dimension might be off")

    # Percentiles are estimated on every 8th row and column only, which is
    # much cheaper than using the full-resolution data.
    mi, ma = np.percentile(x[:, :, ::8, ::8], (pmin, pmax), axis=(2, 3), keepdims=True)

    # astype() returns a copy, so the in-place arithmetic below never
    # modifies the caller's array.
    x = x.astype(np.float32)
    x -= mi
    # The small epsilon keeps the division finite when both percentiles coincide.
    x /= ma - mi + 1e-6
    return x
def load_data(dataset: Literal['flywing', 'retina'], data_dir="/scratch/denbi/k8s/ADL4IA_flexprojects/flexproject2/data"):
    """Load the training image array of one of the known datasets.

    Args:
        dataset: ``'flywing'`` reads the ``X_train`` entry of
            ``Flywing_n0/train/train_data.npz`` and inserts a new axis at
            position 1; ``'retina'`` reads the ``Y`` entry of
            ``Isotropic_Retina/train_data/data_label.npz`` unchanged.
        data_dir: Directory containing the dataset folders. ``None`` means
            the current working directory.

    Returns:
        The loaded NumPy array.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    print(f'Loading {dataset} dataset... ')
    data_dir = Path(data_dir) if data_dir is not None else Path(".")

    # np.load() on an .npz archive keeps the underlying file open until the
    # returned object is closed, so each archive is read inside a `with`
    # block to release the handle as soon as the array is in memory.
    if dataset == 'flywing':
        with np.load(data_dir / 'Flywing_n0/train/train_data.npz') as archive:
            x = archive['X_train']
        # Insert a length-1 axis at position 1 (presumably the channel axis
        # expected by the other helpers -- confirm with callers).
        x = np.expand_dims(x, axis=1)
    elif dataset == 'retina':
        with np.load(data_dir / 'Isotropic_Retina/train_data/data_label.npz') as archive:
            x = archive['Y']
    else:
        raise ValueError(f"Dataset {dataset} not found")

    print(f'Loaded image array with shape {x.shape}')
    return x
def to_rgb(x):
    """Return a three-channel version of a CHW image or a stack of CHW images.

    A 4-D input is treated as a sequence of 3-D images that are converted one
    by one and stacked again. For a 3-D input the leading axis decides the
    result:

    * one entry    -> that entry repeated three times,
    * two entries  -> entries 0, 1, 0,
    * anything else -> the first three entries (a slice of the input).
    """
    # Accepted layouts: CHW or TCHW.
    assert x.ndim in (3, 4)

    if x.ndim == 4:
        # Convert every 3-D item on its own and re-assemble the stack.
        return np.stack([to_rgb(frame) for frame in x])

    n_channels = len(x)
    if n_channels == 1:
        return np.concatenate([x, x, x], axis=0)
    if n_channels == 2:
        first, second = x[0], x[1]
        return np.stack([first, second, first], axis=0)
    return x[:3]
def render_letter_to_array(letter, image_size=(64, 64), size=None):
    """Render text in black, centred on a white canvas, as a 3-channel array.

    Args:
        letter: Text to draw.
        image_size: Canvas size as ``(height, width)`` in pixels.
        size: Font size passed to Pillow's default font. Defaults to half of
            the smaller canvas side.

    Returns:
        ``float32`` array of shape ``(3, height, width)`` with values in
        ``[0, 1]`` (background 1.0, text 0.0); the three channels are
        identical copies of the grayscale rendering.
    """
    height, width = image_size[0], image_size[1]
    if size is None:
        size = min(image_size) // 2

    # Pillow expects (width, height), the reverse of the array convention.
    image = Image.new("L", (width, height), color=255)
    # NOTE: load_default(size=...) needs a Pillow release that supports the
    # `size` argument (10.1 or newer, as far as documented).
    font = ImageFont.load_default(size=size)
    draw = ImageDraw.Draw(image)

    # textbbox() returns (left, top, right, bottom) of the drawn text relative
    # to the anchor point. left/top are generally not zero, so they must be
    # subtracted from the drawing position; otherwise the glyph ends up
    # shifted away from the centre by exactly that offset.
    left, top, right, bottom = draw.textbbox((0, 0), letter, font=font)
    text_width = right - left
    text_height = bottom - top
    text_x = (width - text_width) // 2 - left
    text_y = (height - text_height) // 2 - top
    draw.text((text_x, text_y), letter, fill=0, font=font)

    # Scale 8-bit grayscale to [0, 1] and replicate it into three channels.
    x = np.array(image)
    x = x.astype(np.float32) / 255
    x = np.stack(3 * [x], axis=0)
    return x