This is the official PyTorch implementation of SUE from the paper "Learning Shared Representations from Unpaired Data".
To run the project, clone this repo and then create a conda environment via:

    conda env create -f environment.yml

Subsequently, activate this environment:

    conda activate sue

To run an example of the project on the retrieval task, follow these steps:
- Download the model checkpoints and data encodings from here.
- Unzip the downloaded files.
- Locate:
  - The model checkpoint file: checkpoints_flickr30.pth (inside the checkpoints folder).
  - The data encodings: found under data/flickr30.
- Run the following command:

    python retrieval.py --test flickr30

- If you want to train the model from scratch, use the following command:

    python retrieval.py --train flickr30

If you want to bypass the command-line interface or apply SUE to your own data pipeline, you can interact with the Trainer class directly.
Before training, ensure that your training data is formatted as weakly paired. You can do this with the create_weakly_parallel_data function.
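For intuition, here is a minimal sketch of what "weakly paired" could mean in practice: a small number of samples keep their cross-modal correspondence, while the rest are decoupled by shuffling one modality. This is an assumption made for illustration. The helper `weak_pairing_order` and the toy `images` / `captions` lists are hypothetical and are not part of this repository, and the actual behavior of `create_weakly_parallel_data` may differ.

```python
import random


def weak_pairing_order(n_samples, n_parallel, seed=0):
    """Return an index order that keeps the first `n_parallel` positions fixed
    and shuffles the remaining ones.

    Hypothetical helper for illustration only; not part of this repository.
    """
    if not 0 <= n_parallel <= n_samples:
        raise ValueError("n_parallel must be between 0 and n_samples")
    tail = list(range(n_parallel, n_samples))
    random.Random(seed).shuffle(tail)  # seeded so the order is reproducible
    return list(range(n_parallel)) + tail


# Toy stand-ins for two modalities: item i of `images` matches item i of `captions`.
images = [f"img_{i}" for i in range(8)]
captions = [f"cap_{i}" for i in range(8)]

order = weak_pairing_order(len(captions), n_parallel=3)
weak_captions = [captions[i] for i in order]

# The first 3 (image, caption) pairs still correspond; the rest are reordered.
print(list(zip(images, weak_captions))[:3])
```

The same index order could be applied to arrays or tensors of encodings (for example `y[order]`). In this repository, the pairing step is handled by `create_weakly_parallel_data`, which is used together with the `Trainer` class as follows: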
    from data import load_dataset, create_weakly_parallel_data
    from trainer import Trainer

    # 1. Load your configuration and dataset
    train_set, test_set = load_dataset("your_dataset_name", n_test=400)

    # 2. Make your data weakly paired (crucial step for SUE)
    train_set = create_weakly_parallel_data(train_set, n_parallel=100)

    # 3. Initialize the Trainer
    trainer = Trainer(
        dataset_name="your_dataset_name",
        n_parallel=100,
        n_components=30,  # Adjust based on your config
        configs=your_config_dict,
    )

    # 4. Fit the model using the weakly paired data
    trainer.fit(train_set=train_set)

    # 5. Test the model
    trainer.test(test_set=test_set)

If you find our work useful, please cite it:
    @inproceedings{yacobi2025sue,
      title={Learning Shared Representations from Unpaired Data},
      author={Yacobi, Amitai and Ben-Ari, Nir and Talmon, Ronen and Shaham, Uri},
      booktitle={Advances in Neural Information Processing Systems},
      year={2025}
    }