Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/search_most_dissimilar_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
import torch
from src.download_images import download_images

# Load a pre-trained model (e.g., ResNet) for feature extraction
# Updated to use the correct 'weights' parameter

# Load a pre-trained ResNet50 model for feature extraction
model = models.resnet50(weights='IMAGENET1K_V1')
model = model.eval() # Set the model to evaluation mode

# Remove the final classification layer to extract 2048-dim features from avgpool
feature_extractor = torch.nn.Sequential(*list(model.children())[:-1])
feature_extractor = feature_extractor.eval()

# Transformation for input images (resize, normalize, etc.)
transform = transforms.Compose([
transforms.Resize((224, 224)),
Expand All @@ -23,17 +27,17 @@ def extract_features(image_path):
Extracts features from an image using a pre-trained model.

:param image_path: Path to the image.
:return: Feature vector of the image.
:return: Feature vector of the image (2048-dim from avgpool layer).
"""
try:
image = Image.open(image_path).convert(
'RGB') # Ensure the image is in RGB format
image_tensor = transform(image).unsqueeze(
0) # Transform and add batch dimension
with torch.no_grad():
# Use the model to extract features
# Convert tensor to numpy array and flatten
features = model(image_tensor).flatten().numpy()
# Use the feature extractor to get 2048-dim features from avgpool
features = feature_extractor(image_tensor)
features = features.flatten().numpy()
return features
except Exception as e:
print(f"Error extracting features from {image_path}: {e}")
Expand Down
32 changes: 29 additions & 3 deletions src/train_model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,32 @@
from ultralytics import YOLO
import os
import logging
import torch

logger = logging.getLogger(__name__)


def get_optimal_batch_size(device_index=0):
    """
    Determines optimal batch size based on available VRAM.

    :param device_index: CUDA device index to inspect (default 0, which
        matches the previous hard-coded behavior).
    :return: int — 16 if the device has >= 8 GB of VRAM, 8 if >= 4 GB,
        otherwise 4; also 4 when CUDA is unavailable (CPU training).
    """
    if not torch.cuda.is_available():
        # CPU training — keep memory pressure low with the smallest batch.
        return 4
    # total_memory is reported in bytes; convert to GiB for the thresholds.
    gpu_mem_gb = (
        torch.cuda.get_device_properties(device_index).total_memory / (1024 ** 3)
    )
    if gpu_mem_gb >= 8:
        return 16
    if gpu_mem_gb >= 4:
        return 8
    return 4


def train_model(data_yaml_path, model_type='yolov8'):
"""
Trains the YOLO model using the annotated dataset.
Expand All @@ -20,15 +42,19 @@ def train_model(data_yaml_path, model_type='yolov8'):
# Initialize YOLO model
model = YOLO('yolov8n.pt') # Start with pre-trained model

# Determine optimal batch size based on available VRAM
batch_size = get_optimal_batch_size()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Train with specific parameters
results = model.train(
data=data_yaml_path,
epochs=25, # Reduced epochs for faster training
epochs=25, # Default epochs for training
imgsz=640, # Image size
batch=8, # Batch size (reduce if memory issues)
batch=batch_size, # Auto batch size based on VRAM
patience=10, # Early stopping patience
save=True, # Save model
device='cpu' # Change to 'cuda' if GPU available
device=device # Use GPU if available, else CPU
)

# Get the best model path
Expand Down
Loading