Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/lisbet/cli/commands/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def configure_train_model_parser(parser: argparse.ArgumentParser) -> None:
- order: Temporal Order Classification
- shift: Temporal Shift Classification
- warp: Temporal Warp Classification
- geom : Geometric Consistency Classification

Example:
order,cons
Expand Down Expand Up @@ -132,6 +133,18 @@ def configure_train_model_parser(parser: argparse.ArgumentParser) -> None:
Use Bernoulli(pB) to select which keypoints to ablate.
Simulates sporadic occlusions or tracking failures.

- all_translate: Randomly translate all individuals together in x,y
consistently across all frames in a window.
Translation computed to keep all keypoints in [0,1] bounds.
Provides invariance to location within frame.
- all_mirror_x: Randomly mirror horizontally around x=0.5.
Consistently across all frames in a window.
Provides invariance to lateral orientation.
- all_zoom: Randomly zoom/dezoom around center (0.5, 0.5).
Consistently across all frames in a window.
Scale computed to keep all keypoints in [0,1] bounds.
Provides invariance to depth/distance.


Parameters (optional):
- p=<float>: Probability of applying the transformation (default: 1.0)
Expand Down
11 changes: 11 additions & 0 deletions src/lisbet/config/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ class DataAugmentationConfig(BaseModel):
individuals), sets selected elements to NaN (all space dims).
Simulates missing or occluded keypoints.

Horizontal Transformations:
- all_translate: Randomly translate all individuals together in x,y.
- all_mirror_x: Randomly mirror horizontally around x=0.5.
- all_zoom: Randomly zoom/dezoom around center (0.5, 0.5).



Attributes:
Expand All @@ -95,6 +100,9 @@ class DataAugmentationConfig(BaseModel):
"blk_perm_id",
"gauss_jitter",
"kp_ablation",
"all_translate",
"all_mirror_x",
"all_zoom",
]
p: float = 1.0
pB: float | None = None
Expand All @@ -108,6 +116,9 @@ class DataAugmentationConfig(BaseModel):
"blk_perm_id": {"p", "frac"},
"gauss_jitter": {"p", "sigma"},
"kp_ablation": {"p", "pB"},
"all_translate": {"p"},
"all_mirror_x": {"p"},
"all_zoom": {"p"},
}

@field_validator("p")
Expand Down
2 changes: 2 additions & 0 deletions src/lisbet/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
TemporalOrderDataset,
TemporalShiftDataset,
TemporalWarpDataset,
GeometricInvarianceDataset,
)
from lisbet.datasets.map_style import AnnotatedWindowDataset, WindowDataset

Expand All @@ -13,6 +14,7 @@
"TemporalOrderDataset",
"TemporalShiftDataset",
"TemporalWarpDataset",
"GeometricInvarianceDataset",
"AnnotatedWindowDataset",
"WindowDataset",
]
Expand Down
152 changes: 152 additions & 0 deletions src/lisbet/datasets/iterable_style.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch.utils.data import IterableDataset

from lisbet.datasets.common import AnnotatedWindowSelector, WindowSelector
from lisbet.transforms_extra import RandomMirrorX, RandomTranslate, RandomZoom


class SocialBehaviorDataset(IterableDataset):
Expand Down Expand Up @@ -688,3 +689,154 @@ def __iter__(self):
x = self.transform(x)

yield x, y

class GeometricInvarianceDataset(IterableDataset):
"""
Iterable dataset for the Geometric Invariance self-supervised task.
Generates pairs of windows for contrastive learning, where the model learns that
geometric transformations (translation, flip, zoom) preserve the semantic identity of the
scene. Each sample consists of an original window and a geometrically transformed
version of the same window.
Unlike binary classification tasks, this dataset yields pairs (x_orig, x_transform)
for contrastive learning with InfoNCE loss. The model learns to produce similar
embeddings for different views of the same scene.
Geometric transformations applied:
- Translation: Random translation in x and y directions
- Flip: Random horizontal flip
- Zoom: Random zoom in/out around the center
Notes
-----
1. This is a contrastive learning task, NOT a classification task. The dataset
returns pairs of views without explicit labels.
2. The transformations are applied in the keypoint space (before the general
augmentation pipeline).
3. The same geometric transformation is applied to all individuals in the group
to preserve relative spatial relationships.
4. The transform parameter should only contain the standard augmentation pipeline
(normalization, missing data handling, etc.), NOT the geometric transformations
which are handled internally.
5. Both views (original and transformed) go through the same augmentation pipeline
for consistency.
"""

def __init__(
self,
records,
window_size,
window_offset=0,
fps_scaling=1.0,
transform=None,
base_seed=None,
):
"""
Initialize the GeometricInvarianceDataset.
Parameters
----------
records : list
List of records containing the data.
window_size : int
Size of the window in frames.
window_offset : int, optional
Offset for the window in frames (default is 0).
fps_scaling : float, optional
Scaling factor for the frames per second (default is 1.0).
transform : callable, optional
A function/transform to apply to BOTH views (default is None).
This should contain the standard augmentation pipeline, NOT geometric
transformations.
base_seed : int, optional
Base seed for random number generation (default is None, which generates a
random seed).
"""
super().__init__()

self.window_selector = WindowSelector(
records, window_size, window_offset, fps_scaling
)
self.n_frames = self.window_selector.n_frames
self.transform = transform

self.base_seed = (
base_seed
if base_seed is not None
else torch.randint(0, 2**31 - 1, (1,)).item()
)

# Set random generator for reproducibility
# NOTE: This could be overridden by the worker_init_fn to ensure each worker
# has a different seed for data shuffling.
self.g = torch.Generator().manual_seed(self.base_seed)

# Create geometric transformation functions
# Use a different seed for each transformation to ensure variety
self.translate = RandomTranslate(seed=self.base_seed)
self.mirror_x = RandomMirrorX(seed=self.base_seed + 1)
self.zoom = RandomZoom(seed=self.base_seed + 2)
Comment on lines +772 to +774
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The geometric transformation instances (self.translate, self.mirror_x, self.zoom) are created once in init with fixed seeds (lines 772-774). However, each time they are called, they will produce the same transformation because they use the same generator state. This means all transformed samples will have identical transformations applied, defeating the purpose of data augmentation. Each transformation should either: (1) be re-initialized with a new seed for each sample, or (2) the transforms should be modified to accept and update a shared generator state.

Copilot uses AI. Check for mistakes.

def _apply_geometric_transform(self, x):
"""
Apply random geometric transformations to the window.
Uses the existing data augmentation methods from transforms_extra.py
for consistency with the rest of the codebase.

Randomly selects 1 to 3 transformations and applies them in random order.
Parameters
----------
x : xr.Dataset
Window dataset with "position" variable.
Returns
-------
xr.Dataset
Transformed window with the same shape.
"""
# x is already a Dataset from window_selector.select()
x_ds = x.copy(deep=True)

# Available transformations
available_transforms = [
('translate', self.translate),
('mirror_x', self.mirror_x),
('zoom', self.zoom),
]

# Randomly select how many transformations to apply (1 to 3)
num_transforms = torch.randint(1, 4, (1,), generator=self.g).item()

# Randomly shuffle and select transformations
indices = torch.randperm(len(available_transforms), generator=self.g)[:num_transforms]
selected_transforms = [available_transforms[i] for i in indices]

# Apply selected transformations in random order
for name, transform in selected_transforms:
x_ds = transform(x_ds)

# Store transformation info for debugging
x_ds.attrs["geometric_transforms_applied"] = [name for name, _ in selected_transforms]

return x_ds

def __iter__(self):
while True:
# Select a random window (global frame index)
global_idx = torch.randint(0, self.n_frames, (1,), generator=self.g).item()

# Map global index to (record_index, frame_index)
rec_idx, frame_idx = self.window_selector.global_to_local(global_idx)

# Extract corresponding window
x_orig = self.window_selector.select(rec_idx, frame_idx)

# Apply geometric transformation
x_transform = self._apply_geometric_transform(x_orig)

# Add debugging information
x_orig.attrs["orig_coords"] = [rec_idx, frame_idx]
x_transform.attrs["orig_coords"] = [rec_idx, frame_idx]

# Apply standard augmentation pipeline to BOTH views
if self.transform:
x_orig = self.transform(x_orig)
x_transform = self.transform(x_transform)

# Yield pair of views (NOT x, y like classification tasks)
yield x_orig, x_transform
49 changes: 49 additions & 0 deletions src/lisbet/drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,53 @@ def color_to_bgr(color):
keypoint_marker="^",
keypoint_size=6,
),
"human": BodySpecs(
skeleton_edges=[
# Face
("nose", "left_eye"),
("left_eye", "left_ear"),
("nose", "right_eye"),
("right_eye", "right_ear"),
# Upper body
("left_shoulder", "right_shoulder"),
("left_shoulder", "left_elbow"),
("left_elbow", "left_wrist"),
("right_shoulder", "right_elbow"),
("right_elbow", "right_wrist"),
# Torso
("left_shoulder", "left_hip"),
("right_shoulder", "right_hip"),
("left_hip", "right_hip"),
# Lower body
("left_hip", "left_knee"),
("left_knee", "left_ankle"),
("right_hip", "right_knee"),
("right_knee", "right_ankle"),
],
polygons=[],
keypoint_colors={
"nose": "red",
"left_eye": "orange",
"right_eye": "orange",
"left_ear": "yellow",
"right_ear": "yellow",
"left_shoulder": "lime",
"right_shoulder": "lime",
"left_elbow": "cyan",
"right_elbow": "cyan",
"left_wrist": "blue",
"right_wrist": "blue",
"left_hip": "magenta",
"right_hip": "magenta",
"left_knee": "purple",
"right_knee": "purple",
"left_ankle": "pink",
"right_ankle": "pink",
},
skeleton_color="lime",
polygon_color="cyan",
polygon_alpha=0.3,
keypoint_marker="o",
keypoint_size=4,
),
}
8 changes: 8 additions & 0 deletions src/lisbet/modeling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
from lisbet.modeling.info import model_info
from lisbet.modeling.models import MultiTaskModel

from lisbet.modeling.heads.projection import ProjectionHead
from lisbet.modeling.losses import InfoNCELoss
from lisbet.modeling.metrics import AlignmentMetric, UniformityMetric

__all__ = [
"FrameClassificationHead",
"WindowClassificationHead",
Expand All @@ -16,6 +20,10 @@
"MultiTaskModel",
"LSTMBackbone",
"TransformerBackbone",
"ProjectionHead",
"InfoNCELoss",
"AlignmentMetric",
"UniformityMetric",
]

__doc__ = """
Expand Down
10 changes: 10 additions & 0 deletions src/lisbet/modeling/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
EmbeddingHead,
FrameClassificationHead,
MultiTaskModel,
ProjectionHead,
TransformerBackbone,
WindowClassificationHead,
)
Expand All @@ -38,6 +39,7 @@
"frame_classification": FrameClassificationHead,
"window_classification": WindowClassificationHead,
"embedding": EmbeddingHead,
"projection": ProjectionHead,
}


Expand Down Expand Up @@ -92,6 +94,14 @@ def create_model_from_config(model_config: ModelConfig) -> MultiTaskModel:
num_classes=head_cfg.get("num_classes", 1),
hidden_dim=head_cfg.get("hidden_dim"),
)
elif task_id == "geom":
heads[task_id] = ProjectionHead(
input_dim=backbone_kwargs["embedding_dim"],
projection_dim=head_cfg.get("projection_dim",
head_cfg.get("hidden_dim", 256)),
hidden_dim=head_cfg.get("hidden_dim"),
normalize=head_cfg.get("normalize", True),
)
else:
raise ValueError(f"Unknown task_id: {task_id}")

Expand Down
2 changes: 2 additions & 0 deletions src/lisbet/modeling/heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
WindowClassificationHead,
)
from lisbet.modeling.heads.embedding import EmbeddingHead
from lisbet.modeling.heads.projection import ProjectionHead

__all__ = [
"FrameClassificationHead",
"WindowClassificationHead",
"EmbeddingHead",
"ProjectionHead",
]
Loading
Loading