From 315535493244f0c1eabc2decdd4998aae8c94958 Mon Sep 17 00:00:00 2001 From: Michael Liang Date: Mon, 22 Jul 2024 15:39:22 -0400 Subject: [PATCH 1/3] confidence score --- spatial_server/confidence.py | 94 ++++++++++++++++ spatial_server/train_model.py | 198 ++++++++++++++++++++++++++++++++++ 2 files changed, 292 insertions(+) create mode 100644 spatial_server/confidence.py create mode 100644 spatial_server/train_model.py diff --git a/spatial_server/confidence.py b/spatial_server/confidence.py new file mode 100644 index 0000000..9bb28ec --- /dev/null +++ b/spatial_server/confidence.py @@ -0,0 +1,94 @@ +import os +from pathlib import Path +import sys +import h5py +from collections import defaultdict + +from . import extract_features, extractors, matchers, pairs_from_retrieval, match_features, visualization +from .extract_features import ImageDataset +from .localize_sfm import QueryLocalizer, pose_from_cluster +from .fast_localize import localize +from .utils import viz_3d, io +from .utils.base_model import dynamic_load +from .utils.io import list_h5_names +from .utils.parsers import names_to_pair + +import pycolmap +import numpy as np +from scipy.spatial.transform import Rotation +import torch +import torch.nn as nn +import torch.optim as optim + +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms +import random +import clip +from PIL import Image +import os +import pandas as pd +import matplotlib.pyplot as plt +from tqdm import tqdm +from torch.cuda.amp import autocast, GradScaler +from .train_model import ProjectionHead + +torch_hub_dir = Path('data/torch_hub') +if not torch_hub_dir.exists(): + torch_hub_dir.mkdir(parents=True) +torch.hub.set_dir(str(torch_hub_dir)) + +def encode_map(map, device, preprocess, model): + model.load_state_dict(torch.load(f"models/{map}_ViTL14-336px.pth")) + # Create and load the projection head + projection_head = ProjectionHead(model.visual.output_dim, 512, 256).to(device) + projection_head.load_state_dict(torch.load(f"models/{map}_projection_head.pth")) + + map_path = f"/code/data/map_data/{map}/images" + + image_list = [] + image_names = [] + + for filename in os.listdir(map_path): + if filename.endswith(".jpg") or filename.endswith(".png") or filename.endswith(".jpeg"): # Adjust file extensions as needed + image_path = os.path.join(map_path, filename) + processed_image = preprocess(Image.open(image_path).convert('RGB')) + image_list.append(processed_image) + image_names.append(filename) + + image_batch = torch.stack(image_list, dim=0).to(device) + + with torch.no_grad(): + features = model.encode_image(map_path) + projected_features = projection_head(features.float()) + projected_features = projected_features / projected_features.norm(dim=-1, keepdim=True) + + #Save embeddings + embeddings = {"image_names": image_names, "projected_features": projected_features} + torch.save(embeddings, f"embeddings/{map}_embeddings.pt") + +def get_confidence(map, query_path, preprocess, model, device): + model.load_state_dict(torch.load(f"models/{map}_ViTL14-336px.pth")) + # Create and load the projection head + projection_head = ProjectionHead(model.visual.output_dim, 512, 256).to(device) + projection_head.load_state_dict(torch.load(f"models/{map}_projection_head.pth")) + + embeddings = torch.load(f"embeddings/{map}_embeddings.pt") + + image = preprocess(Image.open(query_path)).unsqueeze(0).to(device) + with torch.no_grad(): + img_features = model.encode_image(image) + + projected_img_features = projection_head(img_features.float()) + projected_img_features = projected_img_features / projected_img_features.norm(dim=-1, keepdim=True) + # print(projected_img_features.shape) + # print(datasets[data_set_name]['images_features'].shape) + similarity = torch.nn.functional.cosine_similarity(projected_img_features, embeddings["projected_features"]) + top1, idx = torch.topk(similarity, 1, dim=0) + return top1.tolist()[0], embeddings["image_names"][idx.tolist()[0]] + + + + + + + \ No newline at end of file diff --git a/spatial_server/train_model.py b/spatial_server/train_model.py new file mode 100644 index 0000000..c23c36a --- /dev/null +++ b/spatial_server/train_model.py @@ -0,0 +1,198 @@ +import os +from pathlib import Path +import sys +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.optim as optim + +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms +import random +import clip +from PIL import Image +import os +import pandas as pd +import matplotlib.pyplot as plt +from tqdm import tqdm +from torch.cuda.amp import autocast, GradScaler + +torch_hub_dir = Path('data/torch_hub') +if not torch_hub_dir.exists(): + torch_hub_dir.mkdir(parents=True) +torch.hub.set_dir(str(torch_hub_dir)) + + +class ContrastiveDataset(Dataset): + def __init__(self, root_dir, anchor_folder, transform=None): + self.root_dir = root_dir + self.anchor_folder = anchor_folder + self.transform = transform + + self.anchor_images = sorted([f for f in os.listdir(os.path.join(root_dir, anchor_folder)) + if f.lower().endswith(('.png', '.jpg', '.jpeg', '.heif', '.HEIC'))]) + # print(self.anchor_images) + + self.other_folders = [f'{f}/images' for f in os.listdir(root_dir) + if not f.startswith('.') and os.path.isdir(os.path.join(root_dir, f)) and f'{f}/images' != anchor_folder] + print(self.other_folders) + + self.other_images = {} + for folder in self.other_folders: + self.other_images[folder] = [f for f in os.listdir(os.path.join(root_dir, folder)) + if f.lower().endswith(('.png', '.jpg', '.jpeg', '.heif', '.HEIC'))] + + def __len__(self): + return len(self.anchor_images) + + def __getitem__(self, idx): + # Anchor image (always from folder1) + anchor_img = self.anchor_images[idx] + anchor_path = os.path.join(self.root_dir, self.anchor_folder, anchor_img) + anchor = Image.open(anchor_path).convert('RGB') + + # Positive image (different image from folder1) + positive_img = random.choice([img for img in self.anchor_images if img != anchor_img]) + positive_path = os.path.join(self.root_dir, self.anchor_folder, positive_img) + positive = Image.open(positive_path).convert('RGB') + # print(anchor_path, positive_path) + + # Negative image (from a different folder) + negative_folder = random.choice(self.other_folders) + negative_img = random.choice(self.other_images[negative_folder]) + negative_path = os.path.join(self.root_dir, negative_folder, negative_img) + # print(negative_path) + negative = Image.open(negative_path).convert('RGB') + + if self.transform: + anchor = self.transform(anchor) + positive = self.transform(positive) + negative = self.transform(negative) + + return anchor, positive, negative + + +class CosineSimilarityContrastiveLoss(nn.Module): + def __init__(self, margin=0.6, negative_weight=1.2): + super().__init__() + self.margin = margin + self.negative_weight = negative_weight + + def forward(self, anchor, positive, negative): + cos = nn.CosineSimilarity(dim=1) + + similarity_positive = cos(anchor, positive) + similarity_negative = cos(anchor, negative) + + losses = torch.relu(self.margin - (similarity_positive - similarity_negative * self.negative_weight)) + return losses.mean() + + +class ProjectionHead(nn.Module): + def __init__(self, input_dim, hidden_dim, output_dim): + super().__init__() + self.projection = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, output_dim) + ) + + def forward(self, x): + x = self.projection(x.float()) + return x / x.norm(dim=-1, keepdim=True) + + +def train_model(map, num_epochs=20, batch_size=16, lr=1e-5): + dataset_names = [] + for dataset in os.listdir("/code/data/map_data/"): + if not dataset.startswith('.'): dataset_names.append(dataset) + + print(dataset_names) + + datasets = {} + + for dataset_name in dataset_names: + paths = {} + paths['querys_path'] = f'/code/data/query_data/{dataset_name}/images' + paths['imgs_path'] = f'/code/data/map_data/{dataset_name}/images' + datasets[dataset_name] = paths + + print(datasets) + + # Set device + device = "cuda" if torch.cuda.is_available() else "cpu" + + # Load pre-trained CLIP model + model, preprocess = clip.load("ViT-L/14@336px", device=device) + + # Enable gradient checkpointing + model.visual.transformer.grad_checkpointing = True + + # Freeze most of the model, only fine-tune the last few layers + for param in model.parameters(): + param.requires_grad = False + + # Unfreeze the last few layers (adjust as needed) + for param in model.visual.transformer.resblocks[-2:].parameters(): + param.requires_grad = True + + # Create projection head + projection_head = ProjectionHead(model.visual.output_dim, 512, 256).to(device) + + # Prepare dataset and dataloader + root_dir = "/code/data/map_data/" # This should contain your folders + transform = transforms.Compose([ + preprocess, + transforms.Lambda(lambda x: x.squeeze(0)) # Remove batch dimension added by CLIP's preprocess + ]) + anchor_folder = f"{map}/images" + + dataset = ContrastiveDataset(root_dir, anchor_folder, transform=transform) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) # Reduced batch size + + # Initialize contrastive loss + criterion = CosineSimilarityContrastiveLoss() + + # Adjust optimizer to only update unfrozen parameters + params_to_optimize = [p for p in list(model.parameters()) + list(projection_head.parameters()) if p.requires_grad] + optimizer = optim.Adam(params_to_optimize, lr=lr) + + for epoch in range(num_epochs): + model.train() + projection_head.train() + total_loss = 0 + progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch+1}/{num_epochs}") + for i, batch in progress_bar: + anchor, positive, negative = [x.to(device) for x in batch] + + with torch.no_grad(): + anchor_features = model.encode_image(anchor) + positive_features = model.encode_image(positive) + negative_features = model.encode_image(negative) + + anchor_features = projection_head(anchor_features.float()) + positive_features = projection_head(positive_features.float()) + negative_features = projection_head(negative_features.float()) + + # Normalize features + anchor_features = anchor_features / anchor_features.norm(dim=-1, keepdim=True) + positive_features = positive_features / positive_features.norm(dim=-1, keepdim=True) + negative_features = negative_features / negative_features.norm(dim=-1, keepdim=True) + + loss = criterion(anchor_features, positive_features, negative_features) + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(params_to_optimize, max_norm=1.0) + optimizer.step() + + total_loss += loss.item() + progress_bar.set_postfix(loss=loss.item()) + + avg_loss = total_loss / len(dataloader) + print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}") + + # Save the fine-tuned model + torch.save(model.state_dict(), f"spatial_server/models/{map}_ViTL14-336px.pth") + torch.save(projection_head.state_dict(), f"spatial_server/models/{map}_projection_head.pth") From a73364dfb46b09c8e8eaf266b47cd52381269269 Mon Sep 17 00:00:00 2001 From: Michael Liang Date: Mon, 22 Jul 2024 17:18:48 -0400 Subject: [PATCH 2/3] fixed some bugs with encoding images --- spatial_server/confidence.py | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/spatial_server/confidence.py b/spatial_server/confidence.py index 9bb28ec..b812232 100644 --- a/spatial_server/confidence.py +++ b/spatial_server/confidence.py @@ -4,15 +4,6 @@ import h5py from collections import defaultdict -from . import extract_features, extractors, matchers, pairs_from_retrieval, match_features, visualization -from .extract_features import ImageDataset -from .localize_sfm import QueryLocalizer, pose_from_cluster -from .fast_localize import localize -from .utils import viz_3d, io -from .utils.base_model import dynamic_load -from .utils.io import list_h5_names -from .utils.parsers import names_to_pair - import pycolmap import numpy as np from scipy.spatial.transform import Rotation @@ -38,10 +29,10 @@ torch.hub.set_dir(str(torch_hub_dir)) def encode_map(map, device, preprocess, model): - model.load_state_dict(torch.load(f"models/{map}_ViTL14-336px.pth")) + model.load_state_dict(torch.load(f"spatial_server/models/{map}_ViTL14-336px.pth")) # Create and load the projection head projection_head = ProjectionHead(model.visual.output_dim, 512, 256).to(device) - projection_head.load_state_dict(torch.load(f"models/{map}_projection_head.pth")) + projection_head.load_state_dict(torch.load(f"spatial_server/models/{map}_projection_head.pth")) map_path = f"/code/data/map_data/{map}/images" @@ -58,21 +49,21 @@ def encode_map(map, device, preprocess, model): image_batch = torch.stack(image_list, dim=0).to(device) with torch.no_grad(): - features = model.encode_image(map_path) + features = model.encode_image(image_batch) projected_features = projection_head(features.float()) projected_features = projected_features / projected_features.norm(dim=-1, keepdim=True) #Save embeddings embeddings = {"image_names": image_names, "projected_features": projected_features} - torch.save(embeddings, f"embeddings/{map}_embeddings.pt") + torch.save(embeddings, f"spatial_server/embeddings/{map}_embeddings.pt") def get_confidence(map, query_path, preprocess, model, device): - model.load_state_dict(torch.load(f"models/{map}_ViTL14-336px.pth")) + model.load_state_dict(torch.load(f"spatial_server/models/{map}_ViTL14-336px.pth")) # Create and load the projection head projection_head = ProjectionHead(model.visual.output_dim, 512, 256).to(device) - projection_head.load_state_dict(torch.load(f"models/{map}_projection_head.pth")) + projection_head.load_state_dict(torch.load(f"spatial_server/models/{map}_projection_head.pth")) - embeddings = torch.load(f"embeddings/{map}_embeddings.pt") + embeddings = torch.load(f"spatial_server/embeddings/{map}_embeddings.pt") image = preprocess(Image.open(query_path)).unsqueeze(0).to(device) with torch.no_grad(): From 3faf2e7f0127e5aac393e8ab501c5cf5c4068270 Mon Sep 17 00:00:00 2001 From: Michael Liang Date: Thu, 3 Apr 2025 07:03:19 -0400 Subject: [PATCH 3/3] confidence --- spatial_server/.DS_Store | Bin 0 -> 6148 bytes spatial_server/confidence.py | 79 +++++++++++++++++++------- spatial_server/train_model.py | 101 +++++++++++++++++++++++++++------- 3 files changed, 139 insertions(+), 41 deletions(-) create mode 100644 spatial_server/.DS_Store diff --git a/spatial_server/.DS_Store b/spatial_server/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..45907e61c98f5d0b359225fbcb39ef5becf404a7 GIT binary patch literal 6148 zcmeHKO-lnY5Pi`e3L^CAF@Hdd7tc#sZ{kHfYu#EbR4whU;B9}pZ$6Z+E_y3eW+3yD z`A8!PtJmLiv^?r!de$0AbS@j&-^`Bvd z_f`FH)okiDqb_iZ4fnOC9^;v@6`nAqWyx{L87efi%l>PP*>Hv_`wQwNPCffg=NNOe z?Y;Zv{i~{Luj-fD_u~hMDuaPwAQ%V+cFq9rY?0xSVd!8W7zhTw7?ATJvk4ZCt)U(r zRQd!U4ro?kU)~aulN}4k){q`bvQ(m_CVyfiOQ$`1T;bRnS~`-h%z1p}&&NyB)oITf zj#L