Conversation
Pull request overview
This pull request adds support for geometric invariance learning through contrastive learning and introduces three new spatial data augmentation transformations (translation, horizontal mirroring, and zoom) to improve model robustness to side-view recording variations.
Changes:
- Adds three new geometric data augmentation transforms, RandomTranslate, RandomMirrorX, and RandomZoom, which operate consistently across all frames in a window
- Introduces a new self-supervised task (geom) using contrastive learning with InfoNCE loss and alignment metrics (a generic sketch of the loss follows this list)
- Adds ProjectionHead for contrastive learning, InfoNCELoss, AlignmentMetric, and UniformityMetric components
- Implements GeometricInvarianceDataset that generates pairs of original and transformed views for contrastive training
- Updates training loop to handle contrastive tasks that return pairs instead of (data, target) tuples
- Adds "human" body specifications to the drawing module
Reviewed changes
Copilot reviewed 19 out of 19 changed files in this pull request and generated 15 comments.
Summary per file:
| File | Description |
|---|---|
| tests/test_transforms_extra.py | Comprehensive tests for new spatial augmentations; updates ablation test comments from NaN to 0.0 |
| tests/test_training_helpers.py | Comments out two test functions for load_multi_records and save_model_config without explanation |
| tests/test_geometric_invariance.py | New test suite for contrastive learning components (ProjectionHead, InfoNCE, metrics, dataset) |
| tests/test_data_augmentation_config.py | Adds config tests for new augmentations; removes validation tests for frac/sigma parameters |
| tests/test_augmentation_integration.py | Adds integration tests for translate, mirror_x, zoom augmentations individually and combined |
| src/lisbet/transforms_extra.py | Implements RandomTranslate, RandomMirrorX, RandomZoom with proper NaN handling and bounds checking |
| src/lisbet/training/tasks.py | Adds geometric invariance task configuration and transform builders for new augmentations |
| src/lisbet/training/core.py | Modifies training and evaluation loops to handle contrastive tasks with paired data |
| src/lisbet/modeling/metrics.py | New file implementing AlignmentMetric and UniformityMetric for contrastive learning evaluation (a generic sketch of these metrics follows the table) |
| src/lisbet/modeling/losses.py | New file implementing InfoNCELoss for contrastive learning |
| src/lisbet/modeling/heads/projection.py | New file implementing ProjectionHead with optional batch normalization and L2 normalization |
| `src/lisbet/modeling/heads/__init__.py` | Exports ProjectionHead |
| src/lisbet/modeling/factory.py | Adds factory support for creating ProjectionHead instances |
| `src/lisbet/modeling/__init__.py` | Exports contrastive learning components |
| src/lisbet/datasets/iterable_style.py | Implements GeometricInvarianceDataset for generating augmented pairs |
| `src/lisbet/datasets/__init__.py` | Exports GeometricInvarianceDataset |
| src/lisbet/config/schemas.py | Adds all_translate, all_mirror_x, all_zoom to valid augmentation names |
| src/lisbet/cli/commands/train.py | Updates CLI help text to document new augmentations and geom task |
| src/lisbet/drawing.py | Adds human body specifications for visualization |
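Regarding src/lisbet/modeling/metrics.py above: alignment and uniformity are commonly defined (following Wang and Isola, 2020) as sketched below for L2-normalized embeddings. This is a generic illustration rather than the actual code in metrics.py, and the default `alpha` and `t` values are conventional choices, not values taken from this PR.

```python
import torch


def alignment(z1: torch.Tensor, z2: torch.Tensor, alpha: float = 2.0) -> torch.Tensor:
    """Average distance between positive pairs; lower means matched views embed close together."""
    return (z1 - z2).norm(dim=-1).pow(alpha).mean()


def uniformity(z: torch.Tensor, t: float = 2.0) -> torch.Tensor:
    """Log of the mean Gaussian potential over all pairs; lower means embeddings spread over the sphere."""
    return torch.pdist(z).pow(2).mul(-t).exp().mean().log()
```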
```python
train_dataset = datasets.GeometricInvarianceDataset(
    records=train_rec["geom"],
    window_size=window_size,
    window_offset=window_offset,
    transform=train_transform,
    base_seed=run_seeds["dataset_geom"],
)
```
The GeometricInvarianceDataset is initialized without the fps_scaling parameter (lines 377-383 and 400-406), but the dataset's `__init__` method expects it (see iterable_style.py line 727). This will cause a TypeError when the dataset is instantiated. The fps_scaling parameter should be added to these dataset initializations, similar to how it's done in the _configure_selfsupervised_task function.
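A minimal sketch of the suggested fix is shown below. It assumes an fps_scaling variable is in scope at this point (as in _configure_selfsupervised_task); that assumption comes from the review comment, not from verified code.

```python
train_dataset = datasets.GeometricInvarianceDataset(
    records=train_rec["geom"],
    window_size=window_size,
    window_offset=window_offset,
    fps_scaling=fps_scaling,  # assumed in-scope variable; forwards the argument __init__ expects
    transform=train_transform,
    base_seed=run_seeds["dataset_geom"],
)
```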
```python
print(type(x_orig.position.values))
print(x_orig.position.values)
```
Debug print statements should be removed or replaced with proper logging before merging; they clutter test output.
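If these values are still useful while debugging, one hedged alternative is to route them through the standard logging module so they only appear at debug verbosity. The helper below is illustrative only and not part of the PR.

```python
import logging

logger = logging.getLogger(__name__)


def describe_array(name: str, values) -> None:
    """Emit an array summary at DEBUG level instead of printing it unconditionally."""
    logger.debug("%s: type=%s values=%s", name, type(values), values)
```

In the quoted test, the two prints could become a single call such as `describe_array("x_orig.position", x_orig.position.values)`.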
```python
x_orig, x_transform = next(iterator)

# Check both views are xarray DataArrays
print(type(x_orig))
```
Debug print statement should be removed or replaced with proper logging before merging; it clutters test output.
```python
# def test_load_multi_records_success(tmp_path):
#     """Test _load_multi_records succeeds with consistent features across datasets."""
#     root1 = make_dummy_dataset(tmp_path / "ds1", keypoints=("nose", "tail"))
#     root2 = make_dummy_dataset(tmp_path / "ds2", keypoints=("nose", "tail"))
#     records = load_multi_records(
#         data_format="movement,movement",
#         data_path=f"{root1},{root2}",
#         data_scale=None,
#         data_filter=None,
#         select_coords=None,
#         rename_coords=None,
#     )
#     assert len(records) == 2


# def test_load_multi_records_inconsistent_features_raises(tmp_path):
#     """
#     Test _load_multi_records raises ValueError if features are inconsistent across
#     datasets.
#     """
#     root1 = make_dummy_dataset(tmp_path / "ds1", keypoints=("nose", "tail"))
#     root2 = make_dummy_dataset(tmp_path / "ds2", keypoints=("nose",))
#     with pytest.raises(
#         ValueError, match="Inconsistent posetracks coordinates in loaded records"
#     ):
#         load_multi_records(
#             data_format="movement,movement",
#             data_path=f"{root1},{root2}",
#             data_scale=None,
#             data_filter=None,
#             select_coords=None,
#             rename_coords=None,
#         )
```
Two tests (test_load_multi_records_success and test_load_multi_records_inconsistent_features_raises) have been commented out without explanation. If these tests are no longer needed or are being replaced, they should be removed entirely. If they're temporarily disabled due to an issue, a TODO comment or issue reference should be added explaining why and when they will be re-enabled.
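If the intent is to disable these tests temporarily rather than delete them, a hedged alternative is pytest's skip marker, which keeps them visible in test reports and records the reason. The reason text below is a placeholder, not an actual issue reference.

```python
import pytest


@pytest.mark.skip(reason="TODO: re-enable once the load_multi_records issue is resolved")
def test_load_multi_records_success(tmp_path):
    ...
```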
```python
# Create random input
x = torch.randn(batch_size, sequence_length, input_dim)
print("Input shape:", x.shape)
```
Debug print statement should be removed or replaced with proper logging before merging, unless it serves a specific testing purpose; unconditional prints clutter test output.
```python
task.dev_dataset = datasets.GeometricInvarianceDataset(
    records=dev_rec["geom"],
    window_size=window_size,
    window_offset=window_offset,
    transform=dev_transform,
    base_seed=run_seeds["dataset_geom"],
)
```
The GeometricInvarianceDataset is initialized without the fps_scaling parameter, but the dataset's `__init__` method expects it (see iterable_style.py line 727). This will cause a TypeError when the dataset is instantiated. The fps_scaling parameter should be added to this dataset initialization.
```python
# Create random input
x = torch.randn(batch_size, sequence_length, input_dim)
print("Input shape:", x.shape)
```
Debug print statement should be removed or replaced with proper logging before merging; it clutters test output.
```python
# head = ProjectionHead(
#     input_dim=256,
#     hidden_dim=512,
#     projection_dim=128,
#     normalize=False,
# )

# x = torch.randn(8, 256)
# output = head(x)

# assert output.shape == (8, 128)
```
Commented-out code should be removed. Lines 103-113 contain commented-out code that appears to be an older version of the test. Dead code should be removed to improve maintainability and reduce confusion.
Suggested change (delete the dead code):

```diff
-# head = ProjectionHead(
-#     input_dim=256,
-#     hidden_dim=512,
-#     projection_dim=128,
-#     normalize=False,
-# )
-# x = torch.randn(8, 256)
-# output = head(x)
-# assert output.shape == (8, 128)
```
```python
# def test_save_model_config(tmp_path):
#     run_id = "testrun"
#     # Use Task dataclass for tasks
#     task1 = Task(
#         task_id="multiclass",
#         head=None,
#         out_dim=3,
#         loss_function=None,
#         train_dataset=None,
#         train_loss=None,
#         train_score=None,
#     )
#     task2 = Task(
#         task_id="order",
#         head=None,
#         out_dim=1,
#         loss_function=None,
#         train_dataset=None,
#         train_loss=None,
#         train_score=None,
#     )
#     tasks = [task1, task2]
#     input_features = [["mouse", "nose", "x"], ["mouse", "nose", "y"]]
#     dump_model_config(
#         tmp_path, run_id, 200, 0, -1, 8, 32, 128, 4, 4, 200, tasks, input_features
#     )
#     config_path = tmp_path / "models" / run_id / "model_config.yml"
#     assert config_path.exists()

#     # Check input_features in config
#     with open(config_path, encoding="utf-8") as f:
#         config = yaml.safe_load(f)
#     assert "input_features" in config
#     assert config["input_features"] == input_features
```
The test test_save_model_config has been commented out without explanation. If this test is no longer needed or is being replaced, it should be removed entirely. If it's temporarily disabled due to an issue, a TODO comment or issue reference should be added explaining why and when it will be re-enabled.
```python
cfg1 = DataAugmentationConfig(name="zoom")
assert cfg1.name == "zoom"
assert cfg1.p == 1.0
assert cfg1.frac is None

# zoom with probability
cfg2 = DataAugmentationConfig(name="zoom", p=0.3)
```
The test uses DataAugmentationConfig(name="zoom") but the valid augmentation name is "all_zoom" (with the "all_" prefix) as defined in config/schemas.py lines 103-105. This test will fail during validation. The test should use "all_zoom" instead of "zoom".
Suggested change:

```diff
-cfg1 = DataAugmentationConfig(name="zoom")
-assert cfg1.name == "zoom"
-assert cfg1.p == 1.0
-assert cfg1.frac is None
-# zoom with probability
-cfg2 = DataAugmentationConfig(name="zoom", p=0.3)
+cfg1 = DataAugmentationConfig(name="all_zoom")
+assert cfg1.name == "all_zoom"
+assert cfg1.p == 1.0
+assert cfg1.frac is None
+# zoom with probability
+cfg2 = DataAugmentationConfig(name="all_zoom", p=0.3)
```
Add data augmentations and a new task for robustness to side-view recordings.
Data augmentation: three window-consistent spatial transforms (RandomTranslate, RandomMirrorX, RandomZoom), exposed in the configuration as all_translate, all_mirror_x, and all_zoom (see the example below).
Task: a new self-supervised geometric invariance task (geom) trained with a contrastive InfoNCE loss and evaluated with alignment and uniformity metrics.
Has been tested: unit and integration tests cover the new transforms, the contrastive components (ProjectionHead, InfoNCE loss, metrics, dataset), and the augmentation configuration.
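As a usage illustration, enabling one of the new augmentations in configuration might look like the sketch below. The import path is inferred from the file list above, and the reading of `p` as an application probability is assumed from the test comment; treat both as assumptions rather than documented API.

```python
from lisbet.config.schemas import DataAugmentationConfig

# "all_zoom" applied with probability 0.3 (probability semantics assumed from the test above).
zoom_cfg = DataAugmentationConfig(name="all_zoom", p=0.3)
```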