Skip to content

Commit f82e7ab

Browse files
committed
Initial commit
0 parents  commit f82e7ab

68 files changed

Lines changed: 8521 additions & 0 deletions

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

LICENSE

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Li Yuan, Xinrui Zhai, and Fangzhou Ma

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md

Lines changed: 908 additions & 0 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
[project]
name = "garmentiq"
version = "0.0.4.8"
authors = [
  { name="Li Yuan", email="lyuan@gd.edu.kg" },
  { name="Xinrui Zhai", email="zxr86272731@gmail.com" },
  { name="Fangzhou Ma", email="fangzhou.ma029@gmail.com" },
]
maintainers = [
  { name="Li Yuan", email="lyuan@gd.edu.kg" },
  { name="Xinrui Zhai", email="zxr86272731@gmail.com" },
  { name="Fangzhou Ma", email="fangzhou.ma029@gmail.com" },
]
description = "Automated Garment Measurement for Fashion Retail"
readme = "README.md"
requires-python = ">=3.11"
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: MIT License",
]
dependencies = [
    "tqdm==4.67.1",
    "pandas==2.2.2",
    "numpy==2.0.0",
    "torch==2.7.0",
    "torchvision==0.22.0",
    "scikit-learn==1.6.1",
    "scipy==1.15.2",
    "Pillow==11.1.0",
    "matplotlib==3.10.0",
    "transformers==4.50.3",
    "kornia==0.8.0",
    "timm==1.0.15",
    "einops==0.8.1",
    "shapely==2.1.1",
    "opencv-python==4.11.0.86",
]

[project.urls]
Homepage = "https://garmentiq.ly.gd.edu.kg/"
Issues = "https://github.com/lygitdata/GarmentIQ/issues"

[tool.setuptools.package-data]
"garmentiq.instruction" = ["*.json"]

src/garmentiq/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# garmentiq/__init__.py
2+
from .tailor import tailor
3+
from . import utils
4+
from . import classification
5+
from . import segmentation
6+
from . import landmark
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# garmentiq/classification/__init__.py
2+
from .train_test_split import train_test_split
3+
from .load_data import load_data
4+
from .load_model import load_model
5+
from .train_pytorch_nn import train_pytorch_nn
6+
from .fine_tune_pytorch_nn import fine_tune_pytorch_nn
7+
from .test_pytorch_nn import test_pytorch_nn
8+
from .predict import predict
9+
from .utils import (
10+
CachedDataset,
11+
seed_worker,
12+
train_epoch,
13+
validate_epoch,
14+
save_best_model,
15+
validate_train_param,
16+
validate_test_param,
17+
)
18+
from .model_definition import (
19+
CNN3,
20+
CNN4,
21+
tinyViT,
22+
)
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
import torch
2+
import torch.nn as nn
3+
from torch.utils.data import DataLoader
4+
from typing import Callable, Type
5+
from tqdm.notebook import tqdm
6+
import os
7+
from sklearn.model_selection import StratifiedKFold
8+
from garmentiq.classification.utils import (
9+
CachedDataset,
10+
seed_worker,
11+
train_epoch,
12+
validate_epoch,
13+
save_best_model,
14+
validate_train_param,
15+
validate_test_param,
16+
)
17+
18+
def fine_tune_pytorch_nn(
19+
model_class: Type[torch.nn.Module],
20+
model_args: dict,
21+
dataset_class: Callable,
22+
dataset_args: dict,
23+
param: dict,
24+
):
25+
"""
26+
Fine-tunes a pretrained PyTorch model using k-fold cross-validation, early stopping, and checkpointing.
27+
28+
This function loads pretrained weights, optionally freezes specified layers, and trains the model on a new dataset
29+
while preserving original learned features. It performs stratified k-fold CV, monitors validation loss, and saves
30+
the best performing model.
31+
32+
Args:
33+
model_class (Type[torch.nn.Module]): Class of the PyTorch model (inherits from `torch.nn.Module`).
34+
model_args (dict): Arguments for model initialization.
35+
dataset_class (Callable): Callable that returns a Dataset given indices and cached tensors.
36+
dataset_args (dict): Dict containing:
37+
- 'metadata_df': DataFrame for stratification
38+
- 'raw_labels': Labels array for KFold
39+
- 'cached_images': Tensor of images
40+
- 'cached_labels': Tensor of labels
41+
param (dict): Training configuration dict. Must include:
42+
- 'pretrained_path' (str): Path to pretrained weights (.pt)
43+
- 'freeze_layers' (bool): Whether to freeze base layers
44+
- 'optimizer_class', 'optimizer_args'
45+
- optional: 'device', 'n_fold', 'n_epoch', 'patience',
46+
'batch_size', 'model_save_dir', 'seed',
47+
'seed_worker', 'max_workers', 'pin_memory',
48+
'persistent_workers', 'best_model_name'
49+
50+
Raises:
51+
ValueError: If required keys are missing.
52+
Returns: None
53+
"""
54+
# Validate parameters
55+
validate_train_param(param)
56+
os.makedirs(param.get("model_save_dir", "./models"), exist_ok=True)
57+
overall_best_loss = float("inf")
58+
best_model_path = os.path.join(param["model_save_dir"], param["best_model_name"])
59+
60+
# Stratified KFold
61+
kfold = StratifiedKFold(
62+
n_splits=param.get("n_fold", 5), shuffle=True, random_state=param.get("seed", 88)
63+
)
64+
65+
for fold, (train_idx, val_idx) in enumerate(
66+
kfold.split(dataset_args["metadata_df"], dataset_args["raw_labels"])
67+
):
68+
print(f"\nFold {fold+1}/{param.get('n_fold',5)}")
69+
70+
# Prepare data loaders
71+
train_dataset = dataset_class(
72+
train_idx, dataset_args["cached_images"], dataset_args["cached_labels"]
73+
)
74+
val_dataset = dataset_class(
75+
val_idx, dataset_args["cached_images"], dataset_args["cached_labels"]
76+
)
77+
78+
g = torch.Generator()
79+
g.manual_seed(param.get("seed", 88))
80+
81+
train_loader = DataLoader(
82+
train_dataset,
83+
batch_size=param.get("batch_size", 64),
84+
shuffle=True,
85+
num_workers=param.get("max_workers", 1),
86+
worker_init_fn=param.get("seed_worker", seed_worker),
87+
generator=g,
88+
pin_memory=param.get("pin_memory", True),
89+
persistent_workers=param.get("persistent_workers", False),
90+
)
91+
val_loader = DataLoader(
92+
val_dataset,
93+
batch_size=param.get("batch_size", 64),
94+
shuffle=False,
95+
num_workers=param.get("max_workers", 1),
96+
worker_init_fn=param.get("seed_worker", seed_worker),
97+
generator=g,
98+
pin_memory=param.get("pin_memory", True),
99+
persistent_workers=param.get("persistent_workers", False),
100+
)
101+
102+
# Initialize model and load pretrained weights
103+
device = param.get("device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
104+
model = model_class(**model_args).to(device)
105+
106+
# Load pretrained weights
107+
state_dict = torch.load(param["pretrained_path"], map_location=device)
108+
cleaned = {k.replace("module.", ""): v for k, v in state_dict.items()}
109+
model.load_state_dict(cleaned, strict=False)
110+
111+
# Freeze base layers if requested
112+
if param.get("freeze_layers", False):
113+
for name, p in model.named_parameters():
114+
if not any(x in name for x in param.get("unfreeze_patterns", [])):
115+
p.requires_grad = False
116+
117+
# DataParallel if multiple GPUs
118+
if device.type == "cuda" and torch.cuda.device_count() > 1:
119+
model = nn.DataParallel(model)
120+
121+
optimizer = param["optimizer_class"](
122+
filter(lambda p: p.requires_grad, model.parameters()),
123+
**param["optimizer_args"]
124+
)
125+
torch.cuda.empty_cache()
126+
127+
best_fold_loss = float("inf")
128+
patience_counter = 0
129+
epoch_pbar = tqdm(range(param.get("n_epoch", 100)), desc="Epoch", leave=False)
130+
131+
# Training loop
132+
for epoch in epoch_pbar:
133+
train_loss = train_epoch(model, train_loader, optimizer, param)
134+
val_loss, val_f1, val_acc = validate_epoch(model, val_loader, param)
135+
136+
best_fold_loss, patience_counter, overall_best_loss = save_best_model(
137+
model, val_loss, best_fold_loss, patience_counter,
138+
overall_best_loss, param, fold, best_model_path
139+
)
140+
141+
epoch_pbar.set_postfix({
142+
'train_loss': f"{train_loss:.4f}",
143+
'val_loss': f"{val_loss:.4f}",
144+
'val_acc': f"{val_acc:.4f}",
145+
'val_f1': f"{val_f1:.4f}",
146+
'patience': patience_counter,
147+
})
148+
149+
print(f"Fold {fold+1} | Epoch {epoch+1} | Val Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f}")
150+
if patience_counter >= param.get("patience", 5):
151+
print(f"Early stopping at epoch {epoch+1}")
152+
break
153+
154+
torch.cuda.empty_cache()
155+
print(f"\nFine-tuning completed. Best model saved at: {best_model_path}")
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from torchvision import transforms
2+
import os
3+
from PIL import Image
4+
import torch
5+
from garmentiq.classification.utils import (
6+
CachedDataset,
7+
seed_worker,
8+
train_epoch,
9+
validate_epoch,
10+
save_best_model,
11+
validate_train_param,
12+
validate_test_param,
13+
)
14+
from tqdm.notebook import tqdm
15+
16+
17+
def load_data(
18+
df,
19+
img_dir,
20+
label_column,
21+
resize_dim=(120, 184),
22+
normalize_mean=[0.8047, 0.7808, 0.7769],
23+
normalize_std=[0.2957, 0.3077, 0.3081],
24+
):
25+
"""
26+
Loads and preprocesses image data into memory from a DataFrame of filenames and labels.
27+
28+
This function reads images from the specified directory, applies resizing, normalization,
29+
and tensor conversion, and encodes labels from a specified column. It returns tensors for
30+
images and labels, along with the transform pipeline used.
31+
32+
Args:
33+
df (pandas.DataFrame): A pandas DataFrame containing at least a 'filename' column and a label column.
34+
img_dir (str): Path to the directory containing image files.
35+
label_column (str): Name of the column in `df` containing class labels.
36+
resize_dim (tuple[int, int]): Tuple indicating the dimensions (height, width) to resize each image to.
37+
Defaults to (120, 184).
38+
normalize_mean (list[float]): Mean values for normalization (per channel).
39+
Defaults to `[0.8047, 0.7808, 0.7769]`.
40+
normalize_std (list[float]): Standard deviation values for normalization (per channel).
41+
Defaults to `[0.2957, 0.3077, 0.3081]`.
42+
43+
Returns:
44+
tuple[torch.Tensor, torch.Tensor, torchvision.transforms.Compose]: A tuple containing:
45+
- cached_images (torch.Tensor): Tensor containing all preprocessed images.
46+
- cached_labels (torch.Tensor): Tensor containing all encoded labels.
47+
- transform (torchvision.transforms.Compose): The transformation pipeline used.
48+
"""
49+
transform = transforms.Compose(
50+
[
51+
transforms.Resize(resize_dim),
52+
transforms.ToTensor(),
53+
transforms.Normalize(mean=normalize_mean, std=normalize_std),
54+
]
55+
)
56+
57+
classes = sorted(df[label_column].unique())
58+
class_to_idx = {c: i for i, c in enumerate(classes)}
59+
60+
cached_images = []
61+
cached_labels = []
62+
63+
for _, row in tqdm(df.iterrows(), total=len(df), desc="Loading data into memory"):
64+
img_path = os.path.join(img_dir, row["filename"])
65+
image = Image.open(img_path).convert("RGB")
66+
image = transform(image)
67+
68+
label = class_to_idx[row[label_column]]
69+
70+
cached_images.append(image)
71+
cached_labels.append(label)
72+
73+
cached_images = torch.stack(cached_images)
74+
cached_labels = torch.tensor(cached_labels)
75+
76+
return cached_images, cached_labels, transform
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
from typing import Type, List
5+
6+
7+
def load_model(model_path: str, model_class: Type[nn.Module], model_args: dict):
    """
    Restore a PyTorch model from a checkpoint and prepare it for inference.

    Instantiates ``model_class`` with ``model_args``, loads the weights stored
    at ``model_path`` onto the best available device (GPU when present,
    otherwise CPU), and switches the model to evaluation mode.

    Args:
        model_path (str): Path to the saved checkpoint (.pth or .pt file).
        model_class (Type[nn.Module]): Model class to instantiate; must be a
            subclass of `torch.nn.Module`.
        model_args (dict): Keyword arguments for the model's constructor.

    Returns:
        torch.nn.Module: The loaded model, in eval mode, on the chosen device.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"

    net = model_class(**model_args).to(target_device)

    # weights_only=True restricts deserialization to tensors/primitives (no
    # arbitrary pickled objects). Strip any "module." prefixes left over from
    # DataParallel training; strict=False tolerates key mismatches.
    raw_state = torch.load(model_path, map_location=target_device, weights_only=True)
    stripped = {key.replace("module.", ""): tensor for key, tensor in raw_state.items()}
    net.load_state_dict(stripped, strict=False)

    net.eval()
    return net

0 commit comments

Comments
 (0)