-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
79 lines (66 loc) · 2.83 KB
/
utils.py
File metadata and controls
79 lines (66 loc) · 2.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
try:
import umap
except ImportError:
umap = None
def get_device():
    """Return the preferred torch device: CUDA first, then Apple MPS, else CPU."""
    # Decide on the device name first; the torch.device is built once below.
    if torch.cuda.is_available():
        device_name = "cuda"
    elif torch.backends.mps.is_available():
        device_name = "mps"
    else:
        device_name = "cpu"
    return torch.device(device_name)
def get_mnist_dataloaders(batch_size=128, data_dir='./data', download=True):
    """
    Build the MNIST train/test datasets together with their dataloaders.
    Shared by train_ae.py and train_vae.py.

    Args:
        batch_size: Batch size used by both loaders.
        data_dir: Root directory passed to the MNIST dataset.
        download: Passed through as the MNIST dataset's ``download`` flag.

    Returns:
        A 4-tuple ``(train_loader, test_loader, train_dataset, test_dataset)``.
        Only the train loader shuffles; the test loader keeps dataset order.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])

    def _load_split(is_train):
        # Both splits share the root, transform and download flag.
        return datasets.MNIST(
            root=data_dir,
            train=is_train,
            transform=to_tensor,
            download=download,
        )

    train_dataset = _load_split(True)
    test_dataset = _load_split(False)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader, train_dataset, test_dataset
def tensor_to_image(tensor):
    """
    Turn a PyTorch tensor into a NumPy array for display (e.g. Streamlit or Matplotlib).

    The tensor is detached from the autograd graph and moved to the CPU before
    conversion; values are passed through unchanged. A single-channel image of
    shape (1, H, W), such as an MNIST digit, is returned as (H, W). Any other
    shape is returned as-is.
    """
    array = tensor.detach().cpu().numpy()
    # Only a leading channel axis of size 1 is dropped, e.g. (1, 28, 28) -> (28, 28).
    is_single_channel = array.ndim == 3 and array.shape[0] == 1
    return array.squeeze(0) if is_single_channel else array
def get_2d_projections(embeddings, cluster_centers=None, method="PCA"):
    """
    Reduces high-dimensional embeddings to 2D using PCA, t-SNE, or UMAP.

    If cluster_centers are provided (like in a GM-VAE), they are stacked below
    the embeddings and run through the same fit, so data points and centers
    end up in one shared 2D coordinate system.

    Args:
        embeddings: 2D array-like of points, one row per sample. Must expose
            ``.shape`` when cluster_centers is given.
        cluster_centers: Optional 2D array-like with the same number of
            columns as ``embeddings``.
        method: Label naming the reducer. Matched by substring, checked in the
            order "PCA", "t-SNE", "UMAP".

    Returns:
        ``(coords_2d_data, coords_2d_centers)``. The second item is ``None``
        when no cluster_centers were passed.

    Raises:
        ValueError: If ``method`` names none of the supported reducers.
        ImportError: If UMAP is requested but the ``umap`` package is missing.
    """
    if cluster_centers is not None:
        n_samples = embeddings.shape[0]
        all_points = np.vstack([embeddings, cluster_centers])
    else:
        all_points = embeddings

    if "PCA" in method:
        reducer = PCA(n_components=2)
    elif "t-SNE" in method:
        # scikit-learn requires perplexity < number of fitted points, so the
        # usual value of 30 is lowered for small inputs instead of raising.
        n_points = len(all_points)
        perplexity = min(30, max(n_points - 1, 1))
        reducer = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    elif "UMAP" in method:
        if umap is None:
            raise ImportError("UMAP is not installed. Please run `uv add umap-learn` in your terminal.")
        reducer = umap.UMAP(n_components=2, random_state=42)
    else:
        raise ValueError(f"Unknown reduction method: {method}")

    coords_2d_all = reducer.fit_transform(all_points)

    if cluster_centers is not None:
        # Rows were stacked as [data; centers], so split at the data row count.
        coords_2d_data = coords_2d_all[:n_samples]
        coords_2d_centers = coords_2d_all[n_samples:]
        return coords_2d_data, coords_2d_centers
    return coords_2d_all, None