From 4aed370c63093ad058430660822daa86b77557fa Mon Sep 17 00:00:00 2001
From: ManuelFay
Date: Tue, 18 Feb 2025 15:52:32 +0100
Subject: [PATCH 01/86] test

---
 .../loss/gradcache_late_interaction_losses.py | 249 ++++++++++++++++++
 1 file changed, 249 insertions(+)
 create mode 100644 colpali_engine/loss/gradcache_late_interaction_losses.py

diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py
new file mode 100644
index 000000000..a8c098c4c
--- /dev/null
+++ b/colpali_engine/loss/gradcache_late_interaction_losses.py
@@ -0,0 +1,249 @@
+from __future__ import annotations
+
+from contextlib import nullcontext
+from functools import partial
+from collections.abc import Iterator
+from typing import Any, Dict
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn import CrossEntropyLoss
+from torch.utils.checkpoint import get_device_states, set_device_states
+
+
+# ------------------------------------------------------------------------------
+# Utility: A context manager that saves/restores RNG state.
+# ------------------------------------------------------------------------------
+class RandContext:
+    def __init__(self, *tensors) -> None:
+        self.fwd_cpu_state = torch.get_rng_state()
+        self.fwd_gpu_devices, self.fwd_gpu_states = get_device_states(*tensors)
+
+    def __enter__(self) -> None:
+        self._fork = torch.random.fork_rng(devices=self.fwd_gpu_devices, enabled=True)
+        self._fork.__enter__()
+        torch.set_rng_state(self.fwd_cpu_state)
+        set_device_states(self.fwd_gpu_devices, self.fwd_gpu_states)
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+        self._fork.__exit__(exc_type, exc_val, exc_tb)
+        self._fork = None
+
+
+# ------------------------------------------------------------------------------
+# Backward hook for multi–input grad cache.
+# ------------------------------------------------------------------------------
+def _backward_hook_multi(grad_output: Tensor, features_dict: Dict[str, dict], loss_obj: CachedColbertLossBase) -> None:
+    """
+    For each input branch (e.g. 'query', 'doc', 'neg_doc'), re–run the forward pass
+    (with gradients) using the saved RNG contexts and then backpropagate the cached gradients.
+    """
+    with torch.enable_grad():
+        for key, feat in features_dict.items():
+            cached_grads = loss_obj.cache[key]  # list (one per mini–batch)
+            rand_states = loss_obj.random_states[key]
+            for (reps_mb, _), grad_mb, rand_state in zip(
+                loss_obj.embed_minibatch_iter(feat, with_grad=True, copy_random_state=False, random_states=rand_states),
+                cached_grads,
+                rand_states
+            ):
+                surrogate = torch.dot(reps_mb.flatten(), grad_mb.flatten()) * grad_output
+                surrogate.backward()
+
+
+# ------------------------------------------------------------------------------
+# Base class that implements grad cache embedding passes.
+# ------------------------------------------------------------------------------
+class CachedColbertLossBase(nn.Module):
+    def __init__(self, model: nn.Module, mini_batch_size: int = 32, show_progress_bar: bool = False) -> None:
+        """
+        model: a SentenceTransformer–like model which, given a features dict,
+               returns a dict with key "sentence_embedding" of shape (bsz, num_tokens, dim)
+        """
+        super().__init__()
+        self.model = model
+        self.mini_batch_size = mini_batch_size
+        self.show_progress_bar = show_progress_bar
+        # These will be dictionaries keyed by input type (e.g. 'query', 'doc', etc.)
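The machinery above is the heart of this file: embeddings are first computed without a graph, the loss is backpropagated only down to the detached embeddings, and a dot-product surrogate then replays those cached gradients through a second, gradient-enabled forward pass. A minimal, self-contained sketch of why the two-pass scheme recovers exactly the full-batch parameter gradients, using a stand-in linear encoder and a toy in-batch loss (every name below is illustrative, not part of the patch):

    import torch
    from torch import nn

    torch.manual_seed(0)
    enc = nn.Linear(4, 4, bias=False)   # stand-in for the real embedding model
    x = torch.randn(8, 4)

    def loss_fn(reps):
        scores = reps @ reps.T          # toy in-batch similarity matrix
        return nn.functional.cross_entropy(scores, torch.arange(reps.size(0)))

    # Reference: one big forward/backward over the full batch.
    loss_fn(enc(x)).backward()
    ref_grad = enc.weight.grad.clone()
    enc.weight.grad = None

    # Grad-cache: (1) embed chunks without a graph, (2) backprop the loss only
    # down to the detached embeddings, (3) replay each chunk with gradients
    # through a dot-product surrogate.
    chunks = [x[i : i + 2] for i in range(0, 8, 2)]
    with torch.no_grad():
        reps = [enc(c) for c in chunks]
    reps = [r.detach().requires_grad_() for r in reps]
    loss_fn(torch.cat(reps)).backward()   # fills r.grad; encoder params untouched
    for c, r in zip(chunks, reps):
        torch.dot(enc(c).flatten(), r.grad.flatten()).backward()
    assert torch.allclose(enc.weight.grad, ref_grad, atol=1e-5)

Peak activation memory then scales with the chunk size, because the full batch never owns a computation graph at once.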
+        self.cache: Dict[str, list[Tensor]] = {}
+        self.random_states: Dict[str, list[RandContext]] = {}
+
+    def embed_minibatch(
+        self,
+        features: dict,
+        begin: int,
+        end: int,
+        with_grad: bool,
+        copy_random_state: bool,
+        random_state: RandContext | None = None,
+    ) -> tuple[Tensor, RandContext | None]:
+        """
+        Run the model on a mini–batch of the features.
+        """
+        grad_context = nullcontext if with_grad else torch.no_grad
+        random_state_context = nullcontext() if random_state is None else random_state
+        features_mb = {k: v[begin:end] for k, v in features.items()}
+        with random_state_context:
+            with grad_context():
+                new_rand_state = RandContext(*features_mb.values()) if copy_random_state else None
+                # Expect model(features) returns a dict with key "sentence_embedding"
+                reps = self.model(features_mb)["sentence_embedding"]
+        return reps, new_rand_state
+
+    def embed_minibatch_iter(
+        self,
+        features: dict,
+        with_grad: bool,
+        copy_random_state: bool,
+        random_states: list[RandContext] | None = None,
+    ) -> Iterator[tuple[Tensor, RandContext | None]]:
+        input_ids: Tensor = features["input_ids"]
+        bsz = input_ids.shape[0]
+        for i in range(0, bsz, self.mini_batch_size):
+            e = i + self.mini_batch_size
+            reps, new_rand_state = self.embed_minibatch(
+                features=features,
+                begin=i,
+                end=e,
+                with_grad=with_grad,
+                copy_random_state=copy_random_state,
+                random_state=None if random_states is None else random_states[i],
+            )
+            yield reps, new_rand_state
+
+    def _embed_all(self, features: dict) -> tuple[list[Tensor], list[RandContext]]:
+        reps_list = []
+        rand_state_list = []
+        for reps_mb, rand_state in self.embed_minibatch_iter(features, with_grad=False, copy_random_state=True):
+            # Detach and mark for gradient in the second pass.
+            reps_list.append(reps_mb.detach().requires_grad_())
+            rand_state_list.append(rand_state)
+        return reps_list, rand_state_list
+
+    def _aggregate_embeddings(self, reps: list[Tensor]) -> Tensor:
+        return torch.cat(reps, dim=0)
+
+    def forward(self, **features: dict) -> Tensor:
+        """
+        Expects keyword–arguments for each input branch.
+        For example:
+            forward(query=..., doc=...)
+        or forward(query=..., doc=..., neg_doc=...)
+        Each input is a features dict (with keys like "input_ids").
+        """
+        reps: dict[str, list[Tensor]] = {}
+        rand_states: dict[str, list[RandContext]] = {}
+        for key, feat in features.items():
+            reps[key], rand_states[key] = self._embed_all(feat)
+        self.random_states = rand_states
+
+        if torch.is_grad_enabled():
+            loss = self._compute_loss_and_cache_gradients(**reps)
+            loss.register_hook(partial(_backward_hook_multi, features_dict=features, loss_obj=self))
+        else:
+            agg = {key: self._aggregate_embeddings(reps[key]) for key in reps}
+            if "neg_doc" in agg:
+                loss = self._compute_loss(agg["query"], agg["doc"], agg["neg_doc"], with_backward=False)
+            else:
+                loss = self._compute_loss(agg["query"], agg["doc"], with_backward=False)
+        return loss
+
+    # The following two methods are meant to be implemented by subclasses:
+    def _compute_loss(self, *args, **kwargs) -> Tensor:
+        raise NotImplementedError
+
+    def _compute_loss_and_cache_gradients(self, **reps: list[Tensor]) -> Tensor:
+        # In our subclasses we first aggregate the mini–batches and then compute the loss in a mini–batch loop.
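For reference while reading the subclasses that follow: they all score query/document pairs with the same late-interaction ("MaxSim") operator, a token-level max followed by a sum over query tokens. A shape walk-through on random stand-in tensors:

    import torch

    B, Nq, Nd, D = 3, 5, 7, 16
    query = torch.randn(B, Nq, D)
    doc = torch.randn(B, Nd, D)

    # (B, B, Nq, Nd): dot product between every query token and every doc token.
    token_scores = torch.einsum("bnd,csd->bcns", query, doc)
    # Best-matching doc token per query token, then sum over query tokens.
    scores = token_scores.max(dim=3).values.sum(dim=2)
    print(scores.shape)  # torch.Size([3, 3]); the diagonal holds the positive pairs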
+        raise NotImplementedError
+
+
+# ------------------------------------------------------------------------------
+# Cached ColBERT Loss (simple cross–entropy over scores)
+# ------------------------------------------------------------------------------
+class CachedColbertLoss(CachedColbertLossBase):
+    def __init__(self, model: nn.Module, mini_batch_size: int = 32, show_progress_bar: bool = False) -> None:
+        super().__init__(model, mini_batch_size, show_progress_bar)
+        self.ce_loss = CrossEntropyLoss()
+
+    def _compute_loss(self, query: Tensor, doc: Tensor, with_backward: bool = False) -> Tensor:
+        # query: (B, Nq, D), doc: (B, Nd, D)
+        batch_size = query.shape[0]
+        # Compute scores:
+        #   scores: (B, B, Nq, Nd)
+        scores = torch.einsum("bnd,csd->bcns", query, doc)
+        # For each query-document pair, take the max over document tokens then sum over query tokens.
+        scores = scores.max(dim=3)[0].sum(dim=2)  # shape: (B, B)
+        labels = torch.arange(batch_size, device=query.device)
+        loss_total = 0.0
+        for i in range(0, batch_size, self.mini_batch_size):
+            j = i + self.mini_batch_size
+            scores_mbatch = scores[i:j]  # (mb, B)
+            loss_mbatch = self.ce_loss(scores_mbatch, labels[i:j])
+            if with_backward:
+                loss_mbatch.backward()
+                loss_mbatch = loss_mbatch.detach()
+            loss_total = loss_total + loss_mbatch * (scores_mbatch.shape[0] / batch_size)
+        return loss_total
+
+    def _compute_loss_and_cache_gradients(self, **reps: list[Tensor]) -> Tensor:
+        agg_query = self._aggregate_embeddings(reps["query"])
+        agg_doc = self._aggregate_embeddings(reps["doc"])
+        loss = self._compute_loss(agg_query, agg_doc, with_backward=True)
+        return loss.detach().requires_grad_()
+
+
+# ------------------------------------------------------------------------------
+# Cached ColBERT Pairwise CE Loss
+# ------------------------------------------------------------------------------
+class CachedColbertPairwiseCELoss(CachedColbertLossBase):
+    def __init__(self, model: nn.Module, mini_batch_size: int = 32, show_progress_bar: bool = False) -> None:
+        super().__init__(model, mini_batch_size, show_progress_bar)
+        self.ce_loss = CrossEntropyLoss()
+
+    def _compute_loss(self, query: Tensor, doc: Tensor, with_backward: bool = False) -> Tensor:
+        batch_size = query.shape[0]
+        scores = torch.einsum("bnd,csd->bcns", query, doc)
+        scores = scores.max(dim=3)[0].sum(dim=2)  # (B, B)
+        pos_scores = scores.diagonal()  # (B,)
+        mask = torch.eye(batch_size, device=scores.device) * 1e6
+        neg_scores = (scores - mask).max(dim=1)[0]
+        loss_total = F.softplus(neg_scores - pos_scores).mean()
+        if with_backward:
+            loss_total.backward()
+            loss_total = loss_total.detach()
+        return loss_total
+
+    def _compute_loss_and_cache_gradients(self, **reps: list[Tensor]) -> Tensor:
+        agg_query = self._aggregate_embeddings(reps["query"])
+        agg_doc = self._aggregate_embeddings(reps["doc"])
+        loss = self._compute_loss(agg_query, agg_doc, with_backward=True)
+        return loss.detach().requires_grad_()
+
+
+# ------------------------------------------------------------------------------
+# Cached ColBERT Pairwise Negative CE Loss
+# ------------------------------------------------------------------------------
+class CachedColbertPairwiseNegativeCELoss(CachedColbertLossBase):
+    def __init__(self, model: nn.Module, in_batch_term: bool = False, mini_batch_size: int = 32, show_progress_bar: bool = False) -> None:
+        super().__init__(model, mini_batch_size, show_progress_bar)
+        self.in_batch_term = in_batch_term
+
+    def _compute_loss(self, query: Tensor, doc: Tensor, neg_doc: Tensor, with_backward: bool = False) -> Tensor:
+        # Compute positive and negative scores using token-level max and sum.
+        pos_scores = torch.einsum("bnd,bsd->bns", query, doc).max(dim=2)[0].sum(dim=1)
+        neg_scores = torch.einsum("bnd,bsd->bns", query, neg_doc).max(dim=2)[0].sum(dim=1)
+        loss_total = F.softplus(neg_scores - pos_scores).mean()
+        if self.in_batch_term:
+            scores = torch.einsum("bnd,csd->bcns", query, doc)
+            scores = scores.max(dim=3)[0].sum(dim=2)
+            pos_scores_ib = scores.diagonal()
+            mask = torch.eye(scores.shape[0], device=scores.device) * 1e6
+            neg_scores_ib = (scores - mask).max(dim=1)[0]
+            loss_total = loss_total + F.softplus(neg_scores_ib - pos_scores_ib).mean()
+            loss_total = loss_total / 2
+        if with_backward:
+            loss_total.backward()
+            loss_total = loss_total.detach()
+        return loss_total
+
+    def _compute_loss_and_cache_gradients(self, **reps: list[Tensor]) -> Tensor:
+        agg_query = self._aggregate_embeddings(reps["query"])
+        agg_doc = self._aggregate_embeddings(reps["doc"])
+        agg_neg_doc = self._aggregate_embeddings(reps["neg_doc"])
+        loss = self._compute_loss(agg_query, agg_doc, agg_neg_doc, with_backward=True)
+        return loss.detach().requires_grad_()

From 2151d53aae70a3d07a1e8718e878952c4a1f61f7 Mon Sep 17 00:00:00 2001
From: ManuelFay
Date: Tue, 18 Feb 2025 16:05:01 +0100
Subject: [PATCH 02/86] fff

---
 colpali_engine/loss/__init__.py               |  7 ++
 .../qwen2/train_colqwen2_gradcache_model.yaml | 72 +++++++++++++++++++
 2 files changed, 79 insertions(+)
 create mode 100644 scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml

diff --git a/colpali_engine/loss/__init__.py b/colpali_engine/loss/__init__.py
index 1e08318bb..db74d099b 100644
--- a/colpali_engine/loss/__init__.py
+++ b/colpali_engine/loss/__init__.py
@@ -3,6 +3,13 @@
     BiPairwiseCELoss,
     BiPairwiseNegativeCELoss,
 )
+
+from .gradcache_late_interaction_losses import (
+    CachedColbertLoss,
+    CachedColbertPairwiseCELoss,
+    CachedColbertPairwiseNegativeCELoss,
+)
+
 from .late_interaction_losses import (
     ColbertLoss,
     ColbertPairwiseCELoss,
diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml
new file mode 100644
index 000000000..cc4d48cb5
--- /dev/null
+++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml
@@ -0,0 +1,72 @@
+config:
+  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
+  output_dir: !path ../../../models/colqwen2-gradcache
+  processor:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.ColQwen2Processor
+    pretrained_model_name_or_path: "./models/colqwen2_base" # "./models/paligemma-3b-mix-448"
+    # num_image_tokens: 2048
+    # max_length: 50
+
+  model:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.ColQwen2
+    pretrained_model_name_or_path: "./models/colqwen2_base"
+    torch_dtype: !ext torch.bfloat16
+    use_cache: false
+    attn_implementation: "flash_attention_2"
+#    device_map: "auto"
+#    quantization_config:
+#      (): transformers.BitsAndBytesConfig
+#      load_in_4bit: true
+#      bnb_4bit_quant_type: "nf4"
+#      bnb_4bit_compute_dtype: "bfloat16"
+#      bnb_4bit_use_double_quant: true
+
+  dataset_loading_func: !ext colpali_engine.utils.dataset_transformation.load_train_set
+  eval_dataset_loader: !import ../data/test_data.yaml
+
+  # max_length: 50
+  run_eval: true
+  loss_func:
+    (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseCELoss
+  tr_args:
+    (): transformers.training_args.TrainingArguments
+    output_dir: null
+    overwrite_output_dir: true
+    num_train_epochs: 1
+    per_device_train_batch_size: 512
+    gradient_checkpointing: true
+    gradient_checkpointing_kwargs: { "use_reentrant": false }
+    # gradient_checkpointing: true
+    # 6 x 8 gpus = 48 batch size
+    # gradient_accumulation_steps: 4
+    per_device_eval_batch_size: 16
+    eval_strategy: "steps"
+    dataloader_num_workers: 8
+    # bf16: true
+    save_steps: 500
+    logging_steps: 10
+    eval_steps: 100
+    warmup_steps: 100
+    learning_rate: 5e-4
+    save_total_limit: 1
+    # resume_from_checkpoint: true
+    # optim: "paged_adamw_8bit"
+    # wandb logging
+    # wandb_project: "colqwen2"
+    # run_name: "colqwen2-ba32-nolora"
+    report_to: "wandb"
+
+
+  peft_config:
+    (): peft.LoraConfig
+    r: 32
+    lora_alpha: 32
+    lora_dropout: 0.1
+    init_lora_weights: "gaussian"
+    bias: "none"
+    task_type: "FEATURE_EXTRACTION"
+    target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+

From 68f8ff679499ac74fd06ff5730195a6bf5275274 Mon Sep 17 00:00:00 2001
From: ManuelFay
Date: Wed, 19 Feb 2025 11:11:00 +0100
Subject: [PATCH 03/86] gg

---
 .../loss/gradcache_late_interaction_losses.py | 342 +++++++-----------
 .../qwen2/colqwen2/modeling_colqwen2.py       |  50 +--
 colpali_engine/trainer/contrastive_trainer.py |  24 +-
 .../qwen2/train_colqwen2_gradacc_model.yaml   |   2 +-
 4 files changed, 168 insertions(+), 250 deletions(-)

diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py
index a8c098c4c..c02717605 100644
--- a/colpali_engine/loss/gradcache_late_interaction_losses.py
+++ b/colpali_engine/loss/gradcache_late_interaction_losses.py
@@ -1,249 +1,161 @@
-from __future__ import annotations
-
-from contextlib import nullcontext
+import torch.nn as nn
 from functools import partial
-from collections.abc import Iterator
-from typing import Any, Dict
+import tqdm
+
-import torch
-import torch.nn.functional as F
-from torch import Tensor, nn
-from torch.nn import CrossEntropyLoss
 from torch.utils.checkpoint import get_device_states, set_device_states
+import torch
+
-# ------------------------------------------------------------------------------
-# Utility: A context manager that saves/restores RNG state.
-# ------------------------------------------------------------------------------
 class RandContext:
+    """
+    Random-state context manager that captures both CPU and GPU random states.
+    This ensures that when re‑executing a forward pass (e.g. in GradCache’s second pass),
+    stochastic operations produce identical outputs.
+    """
     def __init__(self, *tensors) -> None:
+        # Capture CPU RNG state.
         self.fwd_cpu_state = torch.get_rng_state()
+        # Capture GPU states for all devices associated with the provided tensors.
         self.fwd_gpu_devices, self.fwd_gpu_states = get_device_states(*tensors)
 
     def __enter__(self) -> None:
+        # Fork the RNG states on the captured devices.
         self._fork = torch.random.fork_rng(devices=self.fwd_gpu_devices, enabled=True)
         self._fork.__enter__()
+        # Reset the CPU RNG state.
         torch.set_rng_state(self.fwd_cpu_state)
+        # Reset the GPU RNG states.
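What the saved and restored RNG state buys: re-running a stochastic layer under the captured state reproduces the first pass bit for bit, which is what makes the cached gradients valid for the replayed forward. A CPU-only sketch of the idea (the class additionally restores per-device GPU states via get_device_states/set_device_states):

    import torch

    x = torch.randn(4, 8)
    state = torch.get_rng_state()        # the CPU analogue of what RandContext stores
    first = torch.nn.functional.dropout(x, p=0.5, training=True)
    torch.set_rng_state(state)           # the analogue of what __enter__ restores
    second = torch.nn.functional.dropout(x, p=0.5, training=True)
    assert torch.equal(first, second)    # identical dropout masks on both passes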
         set_device_states(self.fwd_gpu_devices, self.fwd_gpu_states)
 
     def __exit__(self, exc_type, exc_val, exc_tb) -> None:
         self._fork.__exit__(exc_type, exc_val, exc_tb)
         self._fork = None
 
-# ------------------------------------------------------------------------------
-# Backward hook for multi–input grad cache.
-# ------------------------------------------------------------------------------
-def _backward_hook_multi(grad_output: Tensor, features_dict: Dict[str, dict], loss_obj: CachedColbertLossBase) -> None:
-    """
-    For each input branch (e.g. 'query', 'doc', 'neg_doc'), re–run the forward pass
-    (with gradients) using the saved RNG contexts and then backpropagate the cached gradients.
-    """
-    with torch.enable_grad():
-        for key, feat in features_dict.items():
-            cached_grads = loss_obj.cache[key]  # list (one per mini–batch)
-            rand_states = loss_obj.random_states[key]
-            for (reps_mb, _), grad_mb, rand_state in zip(
-                loss_obj.embed_minibatch_iter(feat, with_grad=True, copy_random_state=False, random_states=rand_states),
-                cached_grads,
-                rand_states
-            ):
-                surrogate = torch.dot(reps_mb.flatten(), grad_mb.flatten()) * grad_output
-                surrogate.backward()
-
-# ------------------------------------------------------------------------------
-# Base class that implements grad cache embedding passes.
-# ------------------------------------------------------------------------------
-class CachedColbertLossBase(nn.Module):
-    def __init__(self, model: nn.Module, mini_batch_size: int = 32, show_progress_bar: bool = False) -> None:
-        """
-        model: a SentenceTransformer–like model which, given a features dict,
-               returns a dict with key "sentence_embedding" of shape (bsz, num_tokens, dim)
-        """
-        super().__init__()
-        self.model = model
-        self.mini_batch_size = mini_batch_size
-        self.show_progress_bar = show_progress_bar
-        # These will be dictionaries keyed by input type (e.g. 'query', 'doc', etc.)
-        self.cache: Dict[str, list[Tensor]] = {}
-        self.random_states: Dict[str, list[RandContext]] = {}
-
-    def embed_minibatch(
-        self,
-        features: dict,
-        begin: int,
-        end: int,
-        with_grad: bool,
-        copy_random_state: bool,
-        random_state: RandContext | None = None,
-    ) -> tuple[Tensor, RandContext | None]:
-        """
-        Run the model on a mini–batch of the features.
-        """
-        grad_context = nullcontext if with_grad else torch.no_grad
-        random_state_context = nullcontext() if random_state is None else random_state
-        features_mb = {k: v[begin:end] for k, v in features.items()}
-        with random_state_context:
-            with grad_context():
-                new_rand_state = RandContext(*features_mb.values()) if copy_random_state else None
-                # Expect model(features) returns a dict with key "sentence_embedding"
-                reps = self.model(features_mb)["sentence_embedding"]
-        return reps, new_rand_state
-
-    def embed_minibatch_iter(
-        self,
-        features: dict,
-        with_grad: bool,
-        copy_random_state: bool,
-        random_states: list[RandContext] | None = None,
-    ) -> Iterator[tuple[Tensor, RandContext | None]]:
-        input_ids: Tensor = features["input_ids"]
-        bsz = input_ids.shape[0]
-        for i in range(0, bsz, self.mini_batch_size):
-            e = i + self.mini_batch_size
-            reps, new_rand_state = self.embed_minibatch(
-                features=features,
-                begin=i,
-                end=e,
-                with_grad=with_grad,
-                copy_random_state=copy_random_state,
-                random_state=None if random_states is None else random_states[i],
-            )
-            yield reps, new_rand_state
-
-    def _embed_all(self, features: dict) -> tuple[list[Tensor], list[RandContext]]:
-        reps_list = []
-        rand_state_list = []
-        for reps_mb, rand_state in self.embed_minibatch_iter(features, with_grad=False, copy_random_state=True):
-            # Detach and mark for gradient in the second pass.
-            reps_list.append(reps_mb.detach().requires_grad_())
-            rand_state_list.append(rand_state)
-        return reps_list, rand_state_list
-
-    def _aggregate_embeddings(self, reps: list[Tensor]) -> Tensor:
-        return torch.cat(reps, dim=0)
-
-    def forward(self, **features: dict) -> Tensor:
-        """
-        Expects keyword–arguments for each input branch.
-        For example:
-            forward(query=..., doc=...)
-        or forward(query=..., doc=..., neg_doc=...)
-        Each input is a features dict (with keys like "input_ids").
-        """
-        reps: dict[str, list[Tensor]] = {}
-        rand_states: dict[str, list[RandContext]] = {}
-        for key, feat in features.items():
-            reps[key], rand_states[key] = self._embed_all(feat)
-        self.random_states = rand_states
-
-        if torch.is_grad_enabled():
-            loss = self._compute_loss_and_cache_gradients(**reps)
-            loss.register_hook(partial(_backward_hook_multi, features_dict=features, loss_obj=self))
-        else:
-            agg = {key: self._aggregate_embeddings(reps[key]) for key in reps}
-            if "neg_doc" in agg:
-                loss = self._compute_loss(agg["query"], agg["doc"], agg["neg_doc"], with_backward=False)
-            else:
-                loss = self._compute_loss(agg["query"], agg["doc"], with_backward=False)
-        return loss
+
+def _backward_hook(grad_output, sentence_features, random_states, loss_obj, model):
+    """
+    Backward hook that re-computes the embeddings in mini-batches with gradients enabled
+    and uses the cached gradients to backpropagate. This version wraps the forward pass in the
+    corresponding RandContext to reproduce the same randomness.
+    """
+    mini_batch_size = loss_obj.mini_batch_size
+    # sentence_features: a list with two dicts [query_features, doc_features]
+    # random_states: a list with two lists of RandContext objects.
+    for branch_feature, branch_cache, branch_random_states in zip(sentence_features, random_states):
+        input_ids = branch_feature["input_ids"]
+        bsz = input_ids.size(0)
+        # Iterate over mini-batches.
+        for idx, start in enumerate(range(0, bsz, mini_batch_size)):
+            end = start + mini_batch_size
+            mini_feature = {k: v[start:end] for k, v in branch_feature.items()}
+            # Use the stored RandContext if available.
+            r_state = branch_random_states[idx]
+            if r_state is not None:
+                with r_state:
+                    mini_embeds = model.inner_forward(**mini_feature)
+            else:
+                mini_embeds = model.inner_forward(**mini_feature)
+            mini_embeds = mini_embeds.detach().requires_grad_(True)
+            cached_grad = branch_cache[idx]
+            # Compute a surrogate loss that replays the cached gradient.
+            surrogate = torch.dot(mini_embeds.flatten(), cached_grad.flatten()) * grad_output
+            surrogate.backward()
+
+
+class GradCacheColbertLoss(nn.Module):
+    def __init__(self, mini_batch_size: int = 32, scale: float = 1.0, show_progress_bar: bool = False):
+        """
+        GradCache enabled version of the ColBERT loss.
+
+        Args:
+            mini_batch_size: Number of items per mini-batch.
+            scale: Scaling factor for the similarity scores.
+            show_progress_bar: If True, shows progress bars during mini-batch processing.
+        """
+        super().__init__()
+        self.mini_batch_size = mini_batch_size
+        self.scale = scale
+        self.ce_loss = nn.CrossEntropyLoss()
+        self.cache = None
+        self.random_states = None
+        self.show_progress_bar = show_progress_bar
+        self.gradcache_enabled = True  # Flag indicating GradCache is active.
+
+    def embed_minibatch_iter(self, model, sentence_feature: dict, with_grad: bool, copy_random_state: bool):
+        input_ids = sentence_feature["input_ids"]
+        bsz = input_ids.size(0)
+        for start in tqdm.trange(0, bsz, self.mini_batch_size, desc="Embedding minibatches",
+                                 disable=not self.show_progress_bar):
+            end = start + self.mini_batch_size
+            mini_feature = {k: v[start:end] for k, v in sentence_feature.items()}
+            random_state = None
+            if copy_random_state:
+                random_state = RandContext(*mini_feature.values())
+            grad_context = torch.enable_grad() if with_grad else torch.no_grad()
+            with grad_context:
+                mini_embeds = model.inner_forward(**mini_feature)
+            mini_embeds = mini_embeds.detach().requires_grad_(with_grad)
+            yield mini_embeds, random_state
+
+    def calculate_loss(self, reps: list[list[torch.Tensor]], with_backward: bool = False) -> torch.Tensor:
+        """
+        Calculate the ColBERT-style loss.
+        reps: list with two elements – reps[0] for query embeddings, reps[1] for doc embeddings.
+        Each element is a list of mini-batch tensors.
+        """
+        embeddings_query = torch.cat(reps[0], dim=0)  # shape: (total_query, seq_len, dim)
+        embeddings_doc = torch.cat(reps[1], dim=0)  # shape: (total_doc, seq_len, dim)
+        scores = torch.einsum("bnd,csd->bcns", embeddings_query, embeddings_doc).max(dim=3)[0].sum(dim=2)
+        batch_size = scores.size(0)
+        labels = torch.arange(batch_size, device=scores.device)
+        loss = self.ce_loss(scores * self.scale, labels)
+        if with_backward:
+            loss.backward()
+        return loss
+
+    def calculate_loss_and_cache_gradients(self, reps: list[list[torch.Tensor]]) -> torch.Tensor:
+        loss = self.calculate_loss(reps, with_backward=True)
+        loss = loss.detach().requires_grad_()
+        # Cache gradients for each mini-batch.
+        self.cache = []
+        for branch in reps:
+            branch_cache = []
+            for r in branch:
+                branch_cache.append(r.grad)
+            self.cache.append(branch_cache)
+        return loss
+
+    def forward(self, model, inputs: dict) -> torch.Tensor:
+        """
+        inputs: dict containing keys with prefixes "query_" and "doc_".
+        """
+        # Remove prefixes.
+        query_features = {k.replace("query_", ""): v for k, v in inputs.items() if k.startswith("query_")}
+        doc_features = {k.replace("doc_", ""): v for k, v in inputs.items() if k.startswith("doc_")}
+
+        # === First Pass: Get embeddings without gradients, capturing RandContext.
+        reps_query = []
+        rs_query = []
+        for mini_embeds, rs in self.embed_minibatch_iter(model, query_features, with_grad=False,
+                                                         copy_random_state=True):
+            reps_query.append(mini_embeds)
+            rs_query.append(rs)
+        reps_doc = []
+        rs_doc = []
+        for mini_embeds, rs in self.embed_minibatch_iter(model, doc_features, with_grad=False, copy_random_state=True):
+            reps_doc.append(mini_embeds)
+            rs_doc.append(rs)
+        reps = [reps_query, reps_doc]
+        self.random_states = [rs_query, rs_doc]
 
+        if torch.is_grad_enabled():
+            # Step (2): Compute loss and cache gradients.
+            loss = self.calculate_loss_and_cache_gradients(reps)
+            # Step (3): Re-run embeddings with gradients enabled and register a backward hook that uses the cached gradients.
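The registered hook is the only link between the detached loss returned here and the encoder parameters: when the surrounding trainer later calls loss.backward(), computing the gradient of that leaf tensor fires the hook, which then drives the second forward pass and the surrogate backward. The bare mechanism in isolation (toy scalar, illustrative only):

    import torch

    events = []
    loss = torch.tensor(2.0, requires_grad=True)
    loss.register_hook(lambda grad: events.append(float(grad)))
    loss.backward()      # all the caller ever does
    print(events)        # [1.0] -> the hook ran with d(loss)/d(loss) = 1

This is why the trainer needs no special backward logic: the usual loss.backward() call transparently triggers the cached replay.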
+            loss.register_hook(partial(_backward_hook, sentence_features=[query_features, doc_features],
+                                       random_states=self.random_states, loss_obj=self, model=model))
+        else:
+            loss = self.calculate_loss(reps, with_backward=False)
+        return loss
-
-    # The following two methods are meant to be implemented by subclasses:
-    def _compute_loss(self, *args, **kwargs) -> Tensor:
-        raise NotImplementedError
-
-    def _compute_loss_and_cache_gradients(self, **reps: list[Tensor]) -> Tensor:
-        # In our subclasses we first aggregate the mini–batches and then compute the loss in a mini–batch loop.
-        raise NotImplementedError
-
-# ------------------------------------------------------------------------------
-# Cached ColBERT Loss (simple cross–entropy over scores)
-# ------------------------------------------------------------------------------
-class CachedColbertLoss(CachedColbertLossBase):
-    def __init__(self, model: nn.Module, mini_batch_size: int = 32, show_progress_bar: bool = False) -> None:
-        super().__init__(model, mini_batch_size, show_progress_bar)
-        self.ce_loss = CrossEntropyLoss()
-
-    def _compute_loss(self, query: Tensor, doc: Tensor, with_backward: bool = False) -> Tensor:
-        # query: (B, Nq, D), doc: (B, Nd, D)
-        batch_size = query.shape[0]
-        # Compute scores:
-        #   scores: (B, B, Nq, Nd)
-        scores = torch.einsum("bnd,csd->bcns", query, doc)
-        # For each query-document pair, take the max over document tokens then sum over query tokens.
-        scores = scores.max(dim=3)[0].sum(dim=2)  # shape: (B, B)
-        labels = torch.arange(batch_size, device=query.device)
-        loss_total = 0.0
-        for i in range(0, batch_size, self.mini_batch_size):
-            j = i + self.mini_batch_size
-            scores_mbatch = scores[i:j]  # (mb, B)
-            loss_mbatch = self.ce_loss(scores_mbatch, labels[i:j])
-            if with_backward:
-                loss_mbatch.backward()
-                loss_mbatch = loss_mbatch.detach()
-            loss_total = loss_total + loss_mbatch * (scores_mbatch.shape[0] / batch_size)
-        return loss_total
-
-    def _compute_loss_and_cache_gradients(self, **reps: list[Tensor]) -> Tensor:
-        agg_query = self._aggregate_embeddings(reps["query"])
-        agg_doc = self._aggregate_embeddings(reps["doc"])
-        loss = self._compute_loss(agg_query, agg_doc, with_backward=True)
-        return loss.detach().requires_grad_()
-
-# ------------------------------------------------------------------------------
-# Cached ColBERT Pairwise CE Loss
-# ------------------------------------------------------------------------------
-class CachedColbertPairwiseCELoss(CachedColbertLossBase):
-    def __init__(self, model: nn.Module, mini_batch_size: int = 32, show_progress_bar: bool = False) -> None:
-        super().__init__(model, mini_batch_size, show_progress_bar)
-        self.ce_loss = CrossEntropyLoss()
-
-    def _compute_loss(self, query: Tensor, doc: Tensor, with_backward: bool = False) -> Tensor:
-        batch_size = query.shape[0]
-        scores = torch.einsum("bnd,csd->bcns", query, doc)
-        scores = scores.max(dim=3)[0].sum(dim=2)  # (B, B)
-        pos_scores = scores.diagonal()  # (B,)
-        mask = torch.eye(batch_size, device=scores.device) * 1e6
-        neg_scores = (scores - mask).max(dim=1)[0]
-        loss_total = F.softplus(neg_scores - pos_scores).mean()
-        if with_backward:
-            loss_total.backward()
-            loss_total = loss_total.detach()
-        return loss_total
-
-    def _compute_loss_and_cache_gradients(self, **reps: list[Tensor]) -> Tensor:
-        agg_query = self._aggregate_embeddings(reps["query"])
-        agg_doc = self._aggregate_embeddings(reps["doc"])
-        loss = self._compute_loss(agg_query, agg_doc, with_backward=True)
-        return loss.detach().requires_grad_()
-
-# ------------------------------------------------------------------------------
-# Cached ColBERT Pairwise Negative CE Loss
-# ------------------------------------------------------------------------------
-class CachedColbertPairwiseNegativeCELoss(CachedColbertLossBase):
-    def __init__(self, model: nn.Module, in_batch_term: bool = False, mini_batch_size: int = 32, show_progress_bar: bool = False) -> None:
-        super().__init__(model, mini_batch_size, show_progress_bar)
-        self.in_batch_term = in_batch_term
-
-    def _compute_loss(self, query: Tensor, doc: Tensor, neg_doc: Tensor, with_backward: bool = False) -> Tensor:
-        # Compute positive and negative scores using token-level max and sum.
-        pos_scores = torch.einsum("bnd,bsd->bns", query, doc).max(dim=2)[0].sum(dim=1)
-        neg_scores = torch.einsum("bnd,bsd->bns", query, neg_doc).max(dim=2)[0].sum(dim=1)
-        loss_total = F.softplus(neg_scores - pos_scores).mean()
-        if self.in_batch_term:
-            scores = torch.einsum("bnd,csd->bcns", query, doc)
-            scores = scores.max(dim=3)[0].sum(dim=2)
-            pos_scores_ib = scores.diagonal()
-            mask = torch.eye(scores.shape[0], device=scores.device) * 1e6
-            neg_scores_ib = (scores - mask).max(dim=1)[0]
-            loss_total = loss_total + F.softplus(neg_scores_ib - pos_scores_ib).mean()
-            loss_total = loss_total / 2
-        if with_backward:
-            loss_total.backward()
-            loss_total = loss_total.detach()
-        return loss_total
-
-    def _compute_loss_and_cache_gradients(self, **reps: list[Tensor]) -> Tensor:
-        agg_query = self._aggregate_embeddings(reps["query"])
-        agg_doc = self._aggregate_embeddings(reps["doc"])
-        agg_neg_doc = self._aggregate_embeddings(reps["neg_doc"])
-        loss = self._compute_loss(agg_query, agg_doc, agg_neg_doc, with_backward=True)
-        return loss.detach().requires_grad_()
diff --git a/colpali_engine/models/qwen2/colqwen2/modeling_colqwen2.py b/colpali_engine/models/qwen2/colqwen2/modeling_colqwen2.py
index 6b43d13b5..29abaa805 100644
--- a/colpali_engine/models/qwen2/colqwen2/modeling_colqwen2.py
+++ b/colpali_engine/models/qwen2/colqwen2/modeling_colqwen2.py
@@ -41,21 +41,22 @@ def from_pretrained(
         )
 
     def inner_forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        pixel_values: Optional[torch.Tensor] = None,
-        pixel_values_videos: Optional[torch.FloatTensor] = None,
-        image_grid_thw: Optional[torch.LongTensor] = None,
-        video_grid_thw: Optional[torch.LongTensor] = None,
+            self,
+            input_ids: torch.LongTensor = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[List[torch.FloatTensor]] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            pixel_values: Optional[torch.Tensor] = None,
+            pixel_values_videos: Optional[torch.FloatTensor] = None,
+            image_grid_thw: Optional[torch.LongTensor] = None,
+            video_grid_thw: Optional[torch.LongTensor] = None,
     ) -> torch.Tensor:
+
         if inputs_embeds is None:
             inputs_embeds = self.model.embed_tokens(input_ids)
         if pixel_values is not None:
@@ -90,11 +91,13 @@ def inner_forward(
         hidden_states = outputs[0]
         return hidden_states
 
+
+
     def forward(self, *args, **kwargs) -> torch.Tensor:
         # Delete output_hidden_states from kwargs
         kwargs.pop("output_hidden_states", None)
-        # The following code is a hack to make sure the scatter in DDP is done correctly when training on multiple GPUs
+        # Hack to make sure scatter in DDP is done correctly on multiple GPUs.
         if "pixel_values" in kwargs:
             # compute pixel_values offsets
             offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2]
@@ -109,15 +112,14 @@ def forward(self, *args, **kwargs) -> torch.Tensor:
             video_grid_thw=None,
             attention_mask=kwargs.get("attention_mask", None),
         )
-        last_hidden_states = self.inner_forward(
-            *args, **kwargs, position_ids=position_ids, use_cache=False, output_hidden_states=True
-        )  # (batch_size, sequence_length, hidden_size)
-
-        proj = self.custom_text_proj(last_hidden_states)  # (batch_size, sequence_length, dim)
-
-        # L2 normalization
-        proj = proj / proj.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
-        proj = proj * kwargs["attention_mask"].unsqueeze(-1)  # (batch_size, sequence_length, dim)
+        last_hidden_states = self.inner_forward(*args,
+                                                **kwargs,
+                                                position_ids=position_ids,
+                                                use_cache=False,
+                                                output_hidden_states=True)
+        proj = self.custom_text_proj(last_hidden_states)
+        proj = proj / proj.norm(dim=-1, keepdim=True)  # L2 normalization
+        proj = proj * kwargs["attention_mask"].unsqueeze(-1)
         return proj
 
     @property
diff --git a/colpali_engine/trainer/contrastive_trainer.py b/colpali_engine/trainer/contrastive_trainer.py
index abb479b2c..3e24d4251 100644
--- a/colpali_engine/trainer/contrastive_trainer.py
+++ b/colpali_engine/trainer/contrastive_trainer.py
@@ -1,7 +1,6 @@
 import torch
 from transformers import Trainer
 
-
 class ContrastiveTrainer(Trainer):
     def __init__(self, loss_func, is_vision_model, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -9,16 +8,21 @@ def __init__(self, loss_func, is_vision_model, *args, **kwargs):
         self.is_vision_model = is_vision_model  # Unused argument, will be removed in 0.4.0
 
     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
-        query_outputs = model(input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"])
-        # feed only kwargs with 'doc_' prefix
-        doc_outputs = model(**{k[4:]: v for k, v in inputs.items() if k.startswith("doc")})
-        if "neg_doc_input_ids" in inputs:
-            neg_doc_outputs = model(**{k[8:]: v for k, v in inputs.items() if k.startswith("neg_doc")})
-            loss = self.loss_func(query_outputs, doc_outputs, neg_doc_outputs)
-            return (loss, (query_outputs, doc_outputs, neg_doc_outputs)) if return_outputs else loss
+        # If the loss function supports gradcache, delegate the computation.
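The check on the next line is plain duck typing: any loss object that sets gradcache_enabled takes over the whole forward/backward flow, while the classic losses keep receiving pre-computed embeddings. A stripped-down sketch of the same dispatch, runnable on its own (toy loss objects; only the flag name mirrors this PR):

    import torch

    class PlainLoss:
        def __call__(self, query_embeds, doc_embeds):
            return (query_embeds - doc_embeds).pow(2).mean()

    class CachedLoss:
        gradcache_enabled = True
        def __call__(self, model, inputs):
            return torch.tensor(0.0)  # placeholder for the two-pass computation

    def compute_loss(loss_func, model, inputs, q, d):
        if getattr(loss_func, "gradcache_enabled", False):
            return loss_func(model, inputs)   # the loss drives the model itself
        return loss_func(q, d)                # embeddings were already computed

    q, d = torch.randn(2, 3), torch.randn(2, 3)
    print(compute_loss(PlainLoss(), None, {}, q, d))
    print(compute_loss(CachedLoss(), None, {}, q, d))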
+        if hasattr(self.loss_func, "gradcache_enabled") and self.loss_func.gradcache_enabled:
+            loss = self.loss_func(model, inputs)
+            return (loss, None) if return_outputs else loss
+        else:
+            query_outputs = model(input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"])
+            # feed only kwargs with 'doc_' prefix
+            doc_outputs = model(**{k[4:]: v for k, v in inputs.items() if k.startswith("doc")})
+            if "neg_doc_input_ids" in inputs:
+                neg_doc_outputs = model(**{k[8:]: v for k, v in inputs.items() if k.startswith("neg_doc")})
+                loss = self.loss_func(query_outputs, doc_outputs, neg_doc_outputs)
+                return (loss, (query_outputs, doc_outputs, neg_doc_outputs)) if return_outputs else loss
 
-        loss = self.loss_func(query_outputs, doc_outputs)
-        return (loss, (query_outputs, doc_outputs)) if return_outputs else loss
+            loss = self.loss_func(query_outputs, doc_outputs)
+            return (loss, (query_outputs, doc_outputs)) if return_outputs else loss
 
     def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=True):
         """This function is used to generate predictions and return the loss for the given inputs."""
diff --git a/scripts/configs/qwen2/train_colqwen2_gradacc_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradacc_model.yaml
index 5ed04cae6..c83b6a492 100644
--- a/scripts/configs/qwen2/train_colqwen2_gradacc_model.yaml
+++ b/scripts/configs/qwen2/train_colqwen2_gradacc_model.yaml
@@ -28,7 +28,7 @@ config:
 
   run_eval: true
   loss_func:
-    (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseCELoss
+    (): colpali_engine.loss.gradcache_late_interaction_losses.GradCacheColbertLoss
   tr_args:
     (): transformers.training_args.TrainingArguments
     output_dir: null

From 2079dd35ce5dc20480fb827eb55b564bdb6f9b4b Mon Sep 17 00:00:00 2001
From: ManuelFay
Date: Wed, 19 Feb 2025 11:26:36 +0100
Subject: [PATCH 04/86] lesgo

---
 scripts/configs/qwen2/train_colqwen2_gradacc_model.yaml   | 2 +-
 scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/scripts/configs/qwen2/train_colqwen2_gradacc_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradacc_model.yaml
index c83b6a492..5ed04cae6 100644
--- a/scripts/configs/qwen2/train_colqwen2_gradacc_model.yaml
+++ b/scripts/configs/qwen2/train_colqwen2_gradacc_model.yaml
@@ -28,7 +28,7 @@ config:
 
   run_eval: true
   loss_func:
-    (): colpali_engine.loss.gradcache_late_interaction_losses.GradCacheColbertLoss
+    (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseCELoss
   tr_args:
     (): transformers.training_args.TrainingArguments
     output_dir: null
diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml
index cc4d48cb5..a3c7cf715 100644
--- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml
+++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml
@@ -29,7 +29,7 @@ config:
   # max_length: 50
   run_eval: true
   loss_func:
-    (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseCELoss
+    (): colpali_engine.loss.gradcache_late_interaction_losses.GradCacheColbertLoss
   tr_args:
     (): transformers.training_args.TrainingArguments
     output_dir: null

From 2a2605b331c6204556f91e13b2e1f07a4071234e Mon Sep 17 00:00:00 2001
From: ManuelFay
Date: Wed, 19 Feb 2025 11:40:41 +0100
Subject: [PATCH 05/86] fix

---
 colpali_engine/loss/__init__.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/colpali_engine/loss/__init__.py b/colpali_engine/loss/__init__.py
index db74d099b..4a8fd1209 100644
--- a/colpali_engine/loss/__init__.py
+++ b/colpali_engine/loss/__init__.py
@@ -5,9 +5,7 @@
 )
 
 from .gradcache_late_interaction_losses import (
-    CachedColbertLoss,
-    CachedColbertPairwiseCELoss,
-    CachedColbertPairwiseNegativeCELoss,
+    GradCacheColbertLoss
 )
 
 from .late_interaction_losses import (

From 89e9c720cc62e7b4727620a7b42f679d91f130bd Mon Sep 17 00:00:00 2001
From: ManuelFay
Date: Wed, 19 Feb 2025 11:50:42 +0100
Subject: [PATCH 06/86] etst

---
 scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml
index a3c7cf715..ace556fad 100644
--- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml
+++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml
@@ -35,7 +35,7 @@ config:
     output_dir: null
     overwrite_output_dir: true
     num_train_epochs: 1
-    per_device_train_batch_size: 512
+    per_device_train_batch_size: 128
     gradient_checkpointing: true
     gradient_checkpointing_kwargs: { "use_reentrant": false }
     # gradient_checkpointing: true

From a4ae460b39235d157171479115ac1bf1cb2c3129 Mon Sep 17 00:00:00 2001
From: ManuelFay
Date: Wed, 19 Feb 2025 15:04:26 +0100
Subject: [PATCH 07/86] fr

---
 colpali_engine/trainer/contrastive_trainer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/colpali_engine/trainer/contrastive_trainer.py b/colpali_engine/trainer/contrastive_trainer.py
index 3e24d4251..2b26b5ddd 100644
--- a/colpali_engine/trainer/contrastive_trainer.py
+++ b/colpali_engine/trainer/contrastive_trainer.py
@@ -10,6 +10,7 @@ def __init__(self, loss_func, is_vision_model, *args, **kwargs):
     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         # If the loss function supports gradcache, delegate the computation.
         if hasattr(self.loss_func, "gradcache_enabled") and self.loss_func.gradcache_enabled:
+            breakpoint()
             loss = self.loss_func(model, inputs)
             return (loss, None) if return_outputs else loss
         else:

From 82894566ded2644ec9c8ad54aa6a4a4cb38f4705 Mon Sep 17 00:00:00 2001
From: ManuelFay
Date: Wed, 19 Feb 2025 15:39:01 +0100
Subject: [PATCH 08/86] brr

---
 colpali_engine/loss/gradcache_late_interaction_losses.py | 1 +
 colpali_engine/trainer/contrastive_trainer.py            | 1 -
 2 files changed, 1 insertion(+), 1 deletion(-)

diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py
index c02717605..f25595efe 100644
--- a/colpali_engine/loss/gradcache_late_interaction_losses.py
+++ b/colpali_engine/loss/gradcache_late_interaction_losses.py
@@ -88,6 +88,7 @@ def embed_minibatch_iter(self, model, sentence_feature: dict, with_grad: bool, c
         bsz = input_ids.size(0)
         for start in tqdm.trange(0, bsz, self.mini_batch_size, desc="Embedding minibatches",
                                  disable=not self.show_progress_bar):
+            breakpoint()
             end = start + self.mini_batch_size
diff --git a/colpali_engine/trainer/contrastive_trainer.py b/colpali_engine/trainer/contrastive_trainer.py
index 2b26b5ddd..3e24d4251 100644
--- a/colpali_engine/trainer/contrastive_trainer.py
+++ b/colpali_engine/trainer/contrastive_trainer.py
@@ -10,7 +10,6 @@ def __init__(self, loss_func, is_vision_model, *args, **kwargs):
     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         # If the loss function supports gradcache, delegate the computation.
         if hasattr(self.loss_func, "gradcache_enabled") and self.loss_func.gradcache_enabled:
-            breakpoint()
             loss = self.loss_func(model, inputs)
             return (loss, None) if return_outputs else loss
         else:

From 96064d924aaf86e1a546c10cd8e631076d96f9b7 Mon Sep 17 00:00:00 2001
From: ManuelFay
Date: Wed, 19 Feb 2025 15:51:52 +0100
Subject: [PATCH 09/86] debugf

---
 colpali_engine/loss/gradcache_late_interaction_losses.py  | 2 +-
 colpali_engine/models/qwen2/colqwen2/modeling_colqwen2.py | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py
index f25595efe..23869d29b 100644
--- a/colpali_engine/loss/gradcache_late_interaction_losses.py
+++ b/colpali_engine/loss/gradcache_late_interaction_losses.py
@@ -88,7 +88,6 @@ def embed_minibatch_iter(self, model, sentence_feature: dict, with_grad: bool, c
         bsz = input_ids.size(0)
         for start in tqdm.trange(0, bsz, self.mini_batch_size, desc="Embedding minibatches",
                                  disable=not self.show_progress_bar):
-            breakpoint()
             end = start + self.mini_batch_size
             mini_feature = {k: v[start:end] for k, v in sentence_feature.items()}
             random_state = None
@@ -96,6 +95,7 @@ def embed_minibatch_iter(self, model, sentence_feature: dict, with_grad: bool, c
                 random_state = RandContext(*mini_feature.values())
             grad_context = torch.enable_grad() if with_grad else torch.no_grad()
             with grad_context:
+                breakpoint()
                 mini_embeds = model.inner_forward(**mini_feature)
             mini_embeds = mini_embeds.detach().requires_grad_(with_grad)
             yield mini_embeds, random_state
diff --git a/colpali_engine/models/qwen2/colqwen2/modeling_colqwen2.py b/colpali_engine/models/qwen2/colqwen2/modeling_colqwen2.py
index 29abaa805..7ddade3e0 100644
--- a/colpali_engine/models/qwen2/colqwen2/modeling_colqwen2.py
+++ b/colpali_engine/models/qwen2/colqwen2/modeling_colqwen2.py
@@ -57,6 +57,7 @@ def inner_forward(
         video_grid_thw: Optional[torch.LongTensor] = None,
     ) -> torch.Tensor:
 
+        breakpoint()
         if inputs_embeds is None:
             inputs_embeds = self.model.embed_tokens(input_ids)
         if pixel_values is not None:

From 3b28dd48cee023189f6f77c70c6caf533cec66ee Mon Sep 17 00:00:00 2001
From: ManuelFay
Date: Wed, 19 Feb 2025 15:57:50 +0100
Subject: [PATCH 10/86] fic

---
 colpali_engine/loss/gradcache_late_interaction_losses.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py
index 23869d29b..061afa46f 100644
--- a/colpali_engine/loss/gradcache_late_interaction_losses.py
+++ b/colpali_engine/loss/gradcache_late_interaction_losses.py
@@ -96,7 +96,7 @@ def embed_minibatch_iter(self, model, sentence_feature: dict, with_grad: bool, c
             grad_context = torch.enable_grad() if with_grad else torch.no_grad()
             with grad_context:
                 breakpoint()
-                mini_embeds = model.inner_forward(**mini_feature)
+                mini_embeds = model.forward(**mini_feature)
             mini_embeds = mini_embeds.detach().requires_grad_(with_grad)
             yield mini_embeds, random_state

From df26bd6989c1041c00593ef9d1815822e7a876a7 Mon Sep 17 00:00:00 2001
From: ManuelFay
Date: Wed, 19 Feb 2025 16:01:49 +0100
Subject: [PATCH 11/86] etst

---
 colpali_engine/loss/gradcache_late_interaction_losses.py  | 1 -
 colpali_engine/models/qwen2/colqwen2/modeling_colqwen2.py | 1 -
 2 files changed, 2 deletions(-)

diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py
index 061afa46f..7c00e0c48 100644
--- a/colpali_engine/loss/gradcache_late_interaction_losses.py
+++ b/colpali_engine/loss/gradcache_late_interaction_losses.py
@@ -95,7 +95,6 @@ def embed_minibatch_iter(self, model, sentence_feature: dict, with_grad: bool, c
                 random_state = RandContext(*mini_feature.values())
             grad_context = torch.enable_grad() if with_grad else torch.no_grad()
             with grad_context:
-                breakpoint()
                 mini_embeds = model.forward(**mini_feature)
             mini_embeds = mini_embeds.detach().requires_grad_(with_grad)
             yield mini_embeds, random_state
diff --git a/colpali_engine/models/qwen2/colqwen2/modeling_colqwen2.py b/colpali_engine/models/qwen2/colqwen2/modeling_colqwen2.py
index 7ddade3e0..29abaa805 100644
--- a/colpali_engine/models/qwen2/colqwen2/modeling_colqwen2.py
+++ b/colpali_engine/models/qwen2/colqwen2/modeling_colqwen2.py
@@ -57,7 +57,6 @@ def inner_forward(
         video_grid_thw: Optional[torch.LongTensor] = None,
     ) -> torch.Tensor:
 
-        breakpoint()
         if inputs_embeds is None:
             inputs_embeds = self.model.embed_tokens(input_ids)
         if pixel_values is not None:

From 12df70928a9c483f33fa95b0ef4f8d60dccdc2a4 Mon Sep 17 00:00:00 2001
From: ManuelFay
Date: Wed, 19 Feb 2025 16:10:18 +0100
Subject: [PATCH 12/86] debug

---
 colpali_engine/loss/gradcache_late_interaction_losses.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py
index 7c00e0c48..af090c9a4 100644
--- a/colpali_engine/loss/gradcache_late_interaction_losses.py
+++ b/colpali_engine/loss/gradcache_late_interaction_losses.py
@@ -110,6 +110,7 @@ def calculate_loss(self, reps: list[list[torch.Tensor]], with_backward: bool = F
         scores = torch.einsum("bnd,csd->bcns", embeddings_query, embeddings_doc).max(dim=3)[0].sum(dim=2)
         batch_size = scores.size(0)
         labels = torch.arange(batch_size, device=scores.device)
+        breakpoint()
         loss = self.ce_loss(scores * self.scale, labels)
         if with_backward:
             loss.backward()

From 33aadab8f9089813dcba6cc67cd7acf44fff6a95 Mon Sep 17 00:00:00 2001
From: ManuelFay
Date: Wed, 19 Feb 2025 16:27:27 +0100
Subject: [PATCH 13/86] etst

---
 colpali_engine/loss/gradcache_late_interaction_losses.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py
index af090c9a4..0dcd4b7db 100644
--- a/colpali_engine/loss/gradcache_late_interaction_losses.py
+++ b/colpali_engine/loss/gradcache_late_interaction_losses.py
@@ -96,7 +96,7 @@ def embed_minibatch_iter(self, model, sentence_feature: dict, with_grad: bool, c
             grad_context = torch.enable_grad() if with_grad else torch.no_grad()
             with grad_context:
                 mini_embeds = model.forward(**mini_feature)
-            mini_embeds = mini_embeds.detach().requires_grad_(with_grad)
+            mini_embeds = mini_embeds.detach().requires_grad_(True)  # is this the key ?
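On the question in the comment above: detach().requires_grad_(True) is indeed the key step. It cuts the graph back to the encoder while creating a leaf tensor whose .grad the subsequent loss pass can fill, and that .grad is exactly what gets cached for the replay. In isolation:

    import torch
    from torch import nn

    enc = nn.Linear(3, 3)
    out = enc(torch.randn(2, 3))
    leaf = out.detach().requires_grad_(True)
    leaf.sum().backward()
    print(leaf.grad is not None)    # True  -> this is what the cache stores
    print(enc.weight.grad is None)  # True  -> the encoder stays untouched, as intended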
             yield mini_embeds, random_state

From 318c02759088f08b990e565bee20b3b703a01938 Mon Sep 17 00:00:00 2001
From: ManuelFay
Date: Wed, 19 Feb 2025 16:29:47 +0100
Subject: [PATCH 14/86] test

---
 colpali_engine/loss/gradcache_late_interaction_losses.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py
index 0dcd4b7db..b01f7b27f 100644
--- a/colpali_engine/loss/gradcache_late_interaction_losses.py
+++ b/colpali_engine/loss/gradcache_late_interaction_losses.py
@@ -110,7 +110,6 @@ def calculate_loss(self, reps: list[list[torch.Tensor]], with_backward: bool = F
         scores = torch.einsum("bnd,csd->bcns", embeddings_query, embeddings_doc).max(dim=3)[0].sum(dim=2)
         batch_size = scores.size(0)
         labels = torch.arange(batch_size, device=scores.device)
-        breakpoint()
         loss = self.ce_loss(scores * self.scale, labels)
         if with_backward:
             loss.backward()

From 6827d2d2de052c3bff72574c11a7e54fe72ab56c Mon Sep 17 00:00:00 2001
From: ManuelFay
Date: Wed, 19 Feb 2025 16:43:57 +0100
Subject: [PATCH 15/86] fix

---
 colpali_engine/loss/gradcache_late_interaction_losses.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py
index b01f7b27f..f0fd2a725 100644
--- a/colpali_engine/loss/gradcache_late_interaction_losses.py
+++ b/colpali_engine/loss/gradcache_late_interaction_losses.py
@@ -43,7 +43,7 @@ def _backward_hook(grad_output, sentence_features, random_states, loss_obj, mode
     mini_batch_size = loss_obj.mini_batch_size
     # sentence_features: a list with two dicts [query_features, doc_features]
     # random_states: a list with two lists of RandContext objects.
-    for branch_feature, branch_cache, branch_random_states in zip(sentence_features, random_states):
+    for branch_feature, branch_cache, branch_random_states in zip(sentence_features, loss_obj.cache, random_states):
         input_ids = branch_feature["input_ids"]
         bsz = input_ids.size(0)
         # Iterate over mini-batches.

From 6bb085b4b8ee7ad748ac53a879f90832e23619e2 Mon Sep 17 00:00:00 2001
From: ManuelFay
Date: Wed, 19 Feb 2025 16:47:10 +0100
Subject: [PATCH 16/86] test

---
 colpali_engine/loss/gradcache_late_interaction_losses.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py
index f0fd2a725..86f401ee0 100644
--- a/colpali_engine/loss/gradcache_late_interaction_losses.py
+++ b/colpali_engine/loss/gradcache_late_interaction_losses.py
@@ -43,9 +43,11 @@ def _backward_hook(grad_output, sentence_features, random_states, loss_obj, mode
     mini_batch_size = loss_obj.mini_batch_size
     # sentence_features: a list with two dicts [query_features, doc_features]
     # random_states: a list with two lists of RandContext objects.
+    breakpoint()
     for branch_feature, branch_cache, branch_random_states in zip(sentence_features, loss_obj.cache, random_states):
         input_ids = branch_feature["input_ids"]
         bsz = input_ids.size(0)
+        breakpoint()
         # Iterate over mini-batches.
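Both passes slice the whole feature dict with the same [start:end] window, so input_ids, attention masks, and pixel values stay aligned within each mini-batch. The pattern on its own, with toy tensors:

    import torch

    features = {"input_ids": torch.arange(10).view(5, 2), "attention_mask": torch.ones(5, 2)}
    mini_batch_size = 2
    for start in range(0, features["input_ids"].size(0), mini_batch_size):
        mini = {k: v[start : start + mini_batch_size] for k, v in features.items()}
        print(start, mini["input_ids"].shape)  # windows of 2, 2, then 1 rows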
         for idx, start in enumerate(range(0, bsz, mini_batch_size)):
             end = start + mini_batch_size

From 12b4df485c87285d46aeb1dd72bd90fbe6ba8fe0 Mon Sep 17 00:00:00 2001
From: ManuelFay
Date: Wed, 19 Feb 2025 16:56:19 +0100
Subject: [PATCH 17/86] fff

---
 colpali_engine/loss/gradcache_late_interaction_losses.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py
index 86f401ee0..3745d52b8 100644
--- a/colpali_engine/loss/gradcache_late_interaction_losses.py
+++ b/colpali_engine/loss/gradcache_late_interaction_losses.py
@@ -43,11 +43,9 @@ def _backward_hook(grad_output, sentence_features, random_states, loss_obj, mode
     mini_batch_size = loss_obj.mini_batch_size
     # sentence_features: a list with two dicts [query_features, doc_features]
     # random_states: a list with two lists of RandContext objects.
-    breakpoint()
     for branch_feature, branch_cache, branch_random_states in zip(sentence_features, loss_obj.cache, random_states):
         input_ids = branch_feature["input_ids"]
         bsz = input_ids.size(0)
-        breakpoint()
         # Iterate over mini-batches.
         for idx, start in enumerate(range(0, bsz, mini_batch_size)):
             end = start + mini_batch_size
@@ -59,6 +57,7 @@ def _backward_hook(grad_output, sentence_features, random_states, loss_obj, mode
                 mini_embeds = model.inner_forward(**mini_feature)
             else:
                 mini_embeds = model.inner_forward(**mini_feature)
+            breakpoint()
             mini_embeds = mini_embeds.detach().requires_grad_(True)
             cached_grad = branch_cache[idx]
             # Compute a surrogate loss that replays the cached gradient.

From e48842660e29565c64eca3430559b86590af4830 Mon Sep 17 00:00:00 2001
From: ManuelFay
Date: Wed, 19 Feb 2025 17:02:24 +0100
Subject: [PATCH 18/86] fix

---
 colpali_engine/loss/gradcache_late_interaction_losses.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py
index 3745d52b8..2aaaba8af 100644
--- a/colpali_engine/loss/gradcache_late_interaction_losses.py
+++ b/colpali_engine/loss/gradcache_late_interaction_losses.py
@@ -54,10 +54,9 @@ def _backward_hook(grad_output, sentence_features, random_states, loss_obj, mode
             r_state = branch_random_states[idx]
             if r_state is not None:
                 with r_state:
-                    mini_embeds = model.inner_forward(**mini_feature)
+                    mini_embeds = model.forward(**mini_feature)
             else:
-                mini_embeds = model.inner_forward(**mini_feature)
-            breakpoint()
+                mini_embeds = model.forward(**mini_feature)
             mini_embeds = mini_embeds.detach().requires_grad_(True)
             cached_grad = branch_cache[idx]
             # Compute a surrogate loss that replays the cached gradient.

From 05f2d494aea6b54cdf8718fa4cdd8135b29284e0 Mon Sep 17 00:00:00 2001
From: ManuelFay
Date: Wed, 19 Feb 2025 17:14:22 +0100
Subject: [PATCH 19/86] test

---
 colpali_engine/loss/gradcache_late_interaction_losses.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py
index 2aaaba8af..ff8600832 100644
--- a/colpali_engine/loss/gradcache_late_interaction_losses.py
+++ b/colpali_engine/loss/gradcache_late_interaction_losses.py
@@ -60,6 +60,7 @@ def _backward_hook(grad_output, sentence_features, random_states, loss_obj, mode
             mini_embeds = mini_embeds.detach().requires_grad_(True)
             cached_grad = branch_cache[idx]
             # Compute a surrogate loss that replays the cached gradient.
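Why the dot-product surrogate on the next lines works: for any cached gradient g, the parameter gradient of the scalar <enc(x), g> equals the gradient the parameters would receive if g flowed into enc(x) during an ordinary backward pass, and the grad_output factor merely rescales the replay when the trainer scales the loss gradient itself. A numeric check with a stand-in encoder:

    import torch
    from torch import nn

    torch.manual_seed(0)
    enc = nn.Linear(3, 2, bias=False)
    x, g = torch.randn(4, 3), torch.randn(4, 2)

    enc(x).backward(g)                   # reference: feed g straight into backward
    ref = enc.weight.grad.clone()
    enc.weight.grad = None

    torch.dot(enc(x).flatten(), g.flatten()).backward()   # the surrogate replay
    assert torch.allclose(enc.weight.grad, ref, atol=1e-6)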
+ breakpoint() surrogate = torch.dot(mini_embeds.flatten(), cached_grad.flatten()) * grad_output surrogate.backward() From 2983cf898d35d733576f36c3743704e621608ee1 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Wed, 19 Feb 2025 17:33:43 +0100 Subject: [PATCH 20/86] fff --- .../loss/gradcache_late_interaction_losses.py | 43 ++++++++++--------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py index ff8600832..758884920 100644 --- a/colpali_engine/loss/gradcache_late_interaction_losses.py +++ b/colpali_engine/loss/gradcache_late_interaction_losses.py @@ -42,27 +42,30 @@ def _backward_hook(grad_output, sentence_features, random_states, loss_obj, mode """ mini_batch_size = loss_obj.mini_batch_size # sentence_features: a list with two dicts [query_features, doc_features] - # random_states: a list with two lists of RandContext objects. - for branch_feature, branch_cache, branch_random_states in zip(sentence_features, loss_obj.cache, random_states): - input_ids = branch_feature["input_ids"] - bsz = input_ids.size(0) - # Iterate over mini-batches. - for idx, start in enumerate(range(0, bsz, mini_batch_size)): - end = start + mini_batch_size - mini_feature = {k: v[start:end] for k, v in branch_feature.items()} - # Use the stored RandContext if available. - r_state = branch_random_states[idx] - if r_state is not None: - with r_state: + # random_states: a list with two lists of RandContext objects.1 + assert loss_obj.cache is not None + assert random_states is not None + with torch.enable_grad(): + for branch_feature, branch_cache, branch_random_states in zip(sentence_features, loss_obj.cache, random_states): + input_ids = branch_feature["input_ids"] + bsz = input_ids.size(0) + # Iterate over mini-batches. + for idx, start in enumerate(range(0, bsz, mini_batch_size)): + end = start + mini_batch_size + mini_feature = {k: v[start:end] for k, v in branch_feature.items()} + # Use the stored RandContext if available. + r_state = branch_random_states[idx] + if r_state is not None: + with r_state: + mini_embeds = model.forward(**mini_feature) + else: mini_embeds = model.forward(**mini_feature) - else: - mini_embeds = model.forward(**mini_feature) - mini_embeds = mini_embeds.detach().requires_grad_(True) - cached_grad = branch_cache[idx] - # Compute a surrogate loss that replays the cached gradient. - breakpoint() - surrogate = torch.dot(mini_embeds.flatten(), cached_grad.flatten()) * grad_output - surrogate.backward() + # mini_embeds = mini_embeds.detach().requires_grad_(True) + cached_grad = branch_cache[idx] + # Compute a surrogate loss that replays the cached gradient. 
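# Aside: the surrogate dot-product used just below, in isolation (toy 1-D tensor, not part
# of the patch). If g is the cached d(loss)/d(reps) from the cheap first pass, then
# backpropagating dot(reps, g) pushes exactly g into the encoder parameters, because
# d(dot(reps, g))/d(reps) == g; the grad_output factor merely rescales it by
# d(final_loss)/d(hooked_loss), which is usually 1.
import torch
reps = torch.randn(6, requires_grad=True)  # stands in for re-encoded minibatch embeddings
g = torch.randn(6)                         # stands in for the cached gradient
torch.dot(reps, g).backward()
assert torch.allclose(reps.grad, g)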
+ breakpoint() + surrogate = torch.dot(mini_embeds.flatten(), cached_grad.flatten()) * grad_output + surrogate.backward() class GradCacheColbertLoss(nn.Module): From 92ffe94fdc22df44f743f19e54cc98f1914a646a Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Wed, 19 Feb 2025 17:36:01 +0100 Subject: [PATCH 21/86] fff --- colpali_engine/loss/gradcache_late_interaction_losses.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py index 758884920..492516ba4 100644 --- a/colpali_engine/loss/gradcache_late_interaction_losses.py +++ b/colpali_engine/loss/gradcache_late_interaction_losses.py @@ -63,7 +63,6 @@ def _backward_hook(grad_output, sentence_features, random_states, loss_obj, mode # mini_embeds = mini_embeds.detach().requires_grad_(True) cached_grad = branch_cache[idx] # Compute a surrogate loss that replays the cached gradient. - breakpoint() surrogate = torch.dot(mini_embeds.flatten(), cached_grad.flatten()) * grad_output surrogate.backward() From dd1aa137caff571386adca0bde2be03bea190e30 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Wed, 19 Feb 2025 18:32:01 +0100 Subject: [PATCH 22/86] pred step --- colpali_engine/trainer/contrastive_trainer.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/colpali_engine/trainer/contrastive_trainer.py b/colpali_engine/trainer/contrastive_trainer.py index 3e24d4251..c38d64bdb 100644 --- a/colpali_engine/trainer/contrastive_trainer.py +++ b/colpali_engine/trainer/contrastive_trainer.py @@ -30,13 +30,15 @@ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=True) raise ValueError("prediction_step is only called with prediction_loss_only=True") with torch.no_grad(): - # feed only kwargs with 'doc_' prefix - doc_outputs = model(**{k[4:]: v for k, v in inputs.items() if k.startswith("doc")}) - query_outputs = model(input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"]) - if "neg_doc_input_ids" in inputs: - neg_doc_outputs = model(**{k[8:]: v for k, v in inputs.items() if k.startswith("neg_doc")}) - loss = self.loss_func(query_outputs, doc_outputs, neg_doc_outputs) - return loss, None, None - - loss = self.loss_func(query_outputs, doc_outputs) + if hasattr(self.loss_func, "gradcache_enabled") and self.loss_func.gradcache_enabled: + loss = self.loss_func(model, inputs) + else: + # feed only kwargs with 'doc_' prefix + doc_outputs = model(**{k[4:]: v for k, v in inputs.items() if k.startswith("doc")}) + query_outputs = model(input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"]) + if "neg_doc_input_ids" in inputs: + neg_doc_outputs = model(**{k[8:]: v for k, v in inputs.items() if k.startswith("neg_doc")}) + loss = self.loss_func(query_outputs, doc_outputs, neg_doc_outputs) + return loss, None, None + loss = self.loss_func(query_outputs, doc_outputs) return loss, None, None From 7d553e37ec0cb7cf409242f9a1931b6931f4b9b1 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Wed, 19 Feb 2025 18:32:24 +0100 Subject: [PATCH 23/86] tets --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index ace556fad..c4f03411e 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -47,7 
+47,7 @@ config: # bf16: true save_steps: 500 logging_steps: 10 - eval_steps: 100 + eval_steps: 20 warmup_steps: 100 learning_rate: 5e-4 save_total_limit: 1 From ae6f978b11cf2c25cc77a21e577a47e4ec87a868 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Wed, 19 Feb 2025 18:33:26 +0100 Subject: [PATCH 24/86] fff --- colpali_engine/loss/__init__.py | 6 +----- colpali_engine/loss/gradcache_late_interaction_losses.py | 9 ++++----- colpali_engine/trainer/contrastive_trainer.py | 4 +++- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/colpali_engine/loss/__init__.py b/colpali_engine/loss/__init__.py index 4a8fd1209..f8713d827 100644 --- a/colpali_engine/loss/__init__.py +++ b/colpali_engine/loss/__init__.py @@ -3,11 +3,7 @@ BiPairwiseCELoss, BiPairwiseNegativeCELoss, ) - -from .gradcache_late_interaction_losses import ( - GradCacheColbertLoss -) - +from .gradcache_late_interaction_losses import GradCacheColbertLoss from .late_interaction_losses import ( ColbertLoss, ColbertPairwiseCELoss, diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py index 492516ba4..88cbec99d 100644 --- a/colpali_engine/loss/gradcache_late_interaction_losses.py +++ b/colpali_engine/loss/gradcache_late_interaction_losses.py @@ -1,10 +1,9 @@ -import torch.nn as nn from functools import partial -import tqdm - -from torch.utils.checkpoint import get_device_states, set_device_states import torch +import torch.nn as nn +import tqdm +from torch.utils.checkpoint import get_device_states, set_device_states class RandContext: @@ -156,7 +155,7 @@ def forward(self, model, inputs: dict) -> torch.Tensor: if torch.is_grad_enabled(): # Step (2): Compute loss and cache gradients. loss = self.calculate_loss_and_cache_gradients(reps) - # Step (3): Re-run embeddings with gradients enabled and register a backward hook that uses the cached gradients. + # Step (3): Re-run embeddings with gradients enabled and register a hook that uses the cached gradients. 
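# Aside: how the register_hook dispatch used just below works, on a toy graph (illustrative
# tensors). A hook registered on a tensor fires when autograd computes that tensor's
# gradient, so hooking the detached loss lets the Trainer's ordinary loss.backward() trigger
# the expensive gradient-replay pass at exactly the right moment.
import torch
t = torch.randn(3, requires_grad=True)
y = (t * 2).sum()
fired = []
y.register_hook(lambda grad_output: fired.append(grad_output))
y.backward()
assert fired and float(fired[0]) == 1.0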
loss.register_hook(partial(_backward_hook, sentence_features=[query_features, doc_features], random_states=self.random_states, loss_obj=self, model=model)) else: diff --git a/colpali_engine/trainer/contrastive_trainer.py b/colpali_engine/trainer/contrastive_trainer.py index c38d64bdb..7712c27a6 100644 --- a/colpali_engine/trainer/contrastive_trainer.py +++ b/colpali_engine/trainer/contrastive_trainer.py @@ -1,6 +1,7 @@ import torch from transformers import Trainer + class ContrastiveTrainer(Trainer): def __init__(self, loss_func, is_vision_model, *args, **kwargs): super().__init__(*args, **kwargs) @@ -35,7 +36,8 @@ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=True) else: # feed only kwargs with 'doc_' prefix doc_outputs = model(**{k[4:]: v for k, v in inputs.items() if k.startswith("doc")}) - query_outputs = model(input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"]) + query_outputs = model(input_ids=inputs["query_input_ids"], + attention_mask=inputs["query_attention_mask"]) if "neg_doc_input_ids" in inputs: neg_doc_outputs = model(**{k[8:]: v for k, v in inputs.items() if k.startswith("neg_doc")}) loss = self.loss_func(query_outputs, doc_outputs, neg_doc_outputs) From f4ca45d19482e6e9bcc65190472d46904ee5b374 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Wed, 19 Feb 2025 18:40:20 +0100 Subject: [PATCH 25/86] test --- .../loss/gradcache_late_interaction_losses.py | 192 ++++++++++++++++++ 1 file changed, 192 insertions(+) diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py index 88cbec99d..126241a6c 100644 --- a/colpali_engine/loss/gradcache_late_interaction_losses.py +++ b/colpali_engine/loss/gradcache_late_interaction_losses.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F # noqa: N812 import tqdm from torch.utils.checkpoint import get_device_states, set_device_states @@ -161,3 +162,194 @@ def forward(self, model, inputs: dict) -> torch.Tensor: else: loss = self.calculate_loss(reps, with_backward=False) return loss + + + +class GradCacheColbertPairwiseCELoss(nn.Module): + def __init__(self, mini_batch_size: int = 32, scale: float = 1.0, show_progress_bar: bool = False): + """ + GradCache-enabled version of the ColBERTPairwiseCELoss. + """ + super().__init__() + self.mini_batch_size = mini_batch_size + self.scale = scale + self.ce_loss = nn.CrossEntropyLoss() + self.cache = None + self.random_states = None + self.show_progress_bar = show_progress_bar + self.gradcache_enabled = True + + def embed_minibatch_iter(self, model, sentence_feature: dict, with_grad: bool, copy_random_state: bool): + input_ids = sentence_feature["input_ids"] + bsz = input_ids.size(0) + for start in tqdm.trange(0, bsz, self.mini_batch_size, desc="Embedding minibatches", + disable=not self.show_progress_bar): + end = start + self.mini_batch_size + mini_feature = {k: v[start:end] for k, v in sentence_feature.items()} + random_state = RandContext(*mini_feature.values()) if copy_random_state else None + grad_context = torch.enable_grad() if with_grad else torch.no_grad() + with grad_context: + mini_embeds = model.forward(**mini_feature) + mini_embeds = mini_embeds.detach().requires_grad_(True) + yield mini_embeds, random_state + + def calculate_loss(self, reps: list[list[torch.Tensor]], with_backward: bool = False) -> torch.Tensor: + """ + Compute the ColBERTPairwiseCELoss using cached embeddings. 
+ reps is a list with two elements: reps[0] for query embeddings and reps[1] for doc embeddings. + """ + embeddings_query = torch.cat(reps[0], dim=0) # shape: (batch, num_query_tokens, dim) + embeddings_doc = torch.cat(reps[1], dim=0) # shape: (batch, num_doc_tokens, dim) + scores = torch.einsum("bnd,csd->bcns", embeddings_query, embeddings_doc) \ + .max(dim=3)[0].sum(dim=2) # (batch, batch) + pos_scores = scores.diagonal() + neg_scores = scores - torch.eye(scores.shape[0], device=scores.device) * 1e6 + neg_scores = neg_scores.max(dim=1)[0] + loss = F.softplus(neg_scores - pos_scores).mean() + if with_backward: + loss.backward() + return loss + + def calculate_loss_and_cache_gradients(self, reps: list[list[torch.Tensor]]) -> torch.Tensor: + loss = self.calculate_loss(reps, with_backward=True) + loss = loss.detach().requires_grad_() + self.cache = [] + for branch in reps: + branch_cache = [r.grad for r in branch] + self.cache.append(branch_cache) + return loss + + def forward(self, model, inputs: dict) -> torch.Tensor: + # Remove prefixes. + query_features = {k.replace("query_", ""): v for k, v in inputs.items() if k.startswith("query_")} + doc_features = {k.replace("doc_", ""): v for k, v in inputs.items() if k.startswith("doc_")} + + # First pass: get embeddings without gradients (and capture RandContext). + reps_query, rs_query = [], [] + for mini_embeds, rs in self.embed_minibatch_iter(model, query_features, with_grad=False, + copy_random_state=True): + reps_query.append(mini_embeds) + rs_query.append(rs) + reps_doc, rs_doc = [], [] + for mini_embeds, rs in self.embed_minibatch_iter(model, doc_features, with_grad=False, copy_random_state=True): + reps_doc.append(mini_embeds) + rs_doc.append(rs) + reps = [reps_query, reps_doc] + self.random_states = [rs_query, rs_doc] + + if torch.is_grad_enabled(): + loss = self.calculate_loss_and_cache_gradients(reps) + loss.register_hook(partial(_backward_hook, + sentence_features=[query_features, doc_features], + random_states=self.random_states, + loss_obj=self, model=model)) + else: + loss = self.calculate_loss(reps, with_backward=False) + return loss + + + +class GradCacheColbertPairwiseNegativeCELoss(nn.Module): + def __init__(self, mini_batch_size: int = 32, in_batch_term: bool = False, show_progress_bar: bool = False): + """ + GradCache-enabled version of the ColBERTPairwiseNegativeCELoss. + + Args: + in_batch_term: If True, includes an additional in-batch loss term. 
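# Aside: what in_batch_term adds, on hypothetical scalar scores. With it enabled, the
# objective computed further below averages a softplus term against the explicit hard
# negative with a softplus term against the hardest in-batch negative:
import torch
import torch.nn.functional as F
pos, neg_explicit, neg_in_batch = torch.tensor(9.0), torch.tensor(7.5), torch.tensor(8.2)
loss = (F.softplus(neg_explicit - pos) + F.softplus(neg_in_batch - pos)) / 2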
+ """ + super().__init__() + self.mini_batch_size = mini_batch_size + self.in_batch_term = in_batch_term + self.cache = None + self.random_states = None + self.show_progress_bar = show_progress_bar + self.gradcache_enabled = True + + def embed_minibatch_iter(self, model, sentence_feature: dict, with_grad: bool, copy_random_state: bool): + input_ids = sentence_feature["input_ids"] + bsz = input_ids.size(0) + for start in tqdm.trange(0, bsz, self.mini_batch_size, desc="Embedding minibatches", + disable=not self.show_progress_bar): + end = start + self.mini_batch_size + mini_feature = {k: v[start:end] for k, v in sentence_feature.items()} + random_state = RandContext(*mini_feature.values()) if copy_random_state else None + grad_context = torch.enable_grad() if with_grad else torch.no_grad() + with grad_context: + mini_embeds = model.forward(**mini_feature) + mini_embeds = mini_embeds.detach().requires_grad_(True) + yield mini_embeds, random_state + + def calculate_loss(self, reps: list[list[torch.Tensor]], with_backward: bool = False) -> torch.Tensor: + """ + Compute the ColBERTPairwiseNegativeCELoss. + reps is a list with three elements: + reps[0]: query embeddings, + reps[1]: positive doc embeddings, + reps[2]: negative doc embeddings. + """ + embeddings_query = torch.cat(reps[0], dim=0) # (batch, num_query_tokens, dim) + embeddings_doc = torch.cat(reps[1], dim=0) # (batch, num_doc_tokens, dim) + embeddings_neg_doc = torch.cat(reps[2], dim=0) # (batch, num_neg_doc_tokens, dim) + + # Compute scores for positive and negative documents. + pos_scores = torch.einsum("bnd,bsd->bns", embeddings_query, embeddings_doc) \ + .max(dim=2)[0].sum(dim=1) + neg_scores = torch.einsum("bnd,bsd->bns", embeddings_query, embeddings_neg_doc) \ + .max(dim=2)[0].sum(dim=1) + loss = F.softplus(neg_scores - pos_scores).mean() + + if self.in_batch_term: + scores = torch.einsum("bnd,csd->bcns", embeddings_query, embeddings_doc) \ + .max(dim=3)[0].sum(dim=2) + pos_scores_in = scores.diagonal() + neg_scores_in = scores - torch.eye(scores.shape[0], device=scores.device) * 1e6 + neg_scores_in = neg_scores_in.max(dim=1)[0] + loss_in = F.softplus(neg_scores_in - pos_scores_in).mean() + loss = (loss + loss_in) / 2 + + if with_backward: + loss.backward() + return loss + + def calculate_loss_and_cache_gradients(self, reps: list[list[torch.Tensor]]) -> torch.Tensor: + loss = self.calculate_loss(reps, with_backward=True) + loss = loss.detach().requires_grad_() + self.cache = [] + for branch in reps: + branch_cache = [r.grad for r in branch] + self.cache.append(branch_cache) + return loss + + def forward(self, model, inputs: dict) -> torch.Tensor: + # Remove prefixes. + query_features = {k.replace("query_", ""): v for k, v in inputs.items() if k.startswith("query_")} + doc_features = {k.replace("doc_", ""): v for k, v in inputs.items() if k.startswith("doc_")} + neg_doc_features = {k.replace("neg_doc_", ""): v for k, v in inputs.items() if k.startswith("neg_doc_")} + + # First pass: get embeddings without gradients and capture RandContext. 
+ reps_query, rs_query = [], [] + for mini_embeds, rs in self.embed_minibatch_iter(model, query_features, with_grad=False, + copy_random_state=True): + reps_query.append(mini_embeds) + rs_query.append(rs) + reps_doc, rs_doc = [], [] + for mini_embeds, rs in self.embed_minibatch_iter(model, doc_features, with_grad=False, copy_random_state=True): + reps_doc.append(mini_embeds) + rs_doc.append(rs) + reps_neg_doc, rs_neg_doc = [], [] + for mini_embeds, rs in self.embed_minibatch_iter(model, neg_doc_features, with_grad=False, + copy_random_state=True): + reps_neg_doc.append(mini_embeds) + rs_neg_doc.append(rs) + reps = [reps_query, reps_doc, reps_neg_doc] + self.random_states = [rs_query, rs_doc, rs_neg_doc] + + if torch.is_grad_enabled(): + loss = self.calculate_loss_and_cache_gradients(reps) + loss.register_hook(partial(_backward_hook, + sentence_features=[query_features, doc_features, neg_doc_features], + random_states=self.random_states, + loss_obj=self, model=model)) + else: + loss = self.calculate_loss(reps, with_backward=False) + return loss From b1e7d685cb90048b33cf86ac0e8d30991637573f Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Wed, 19 Feb 2025 18:42:24 +0100 Subject: [PATCH 26/86] vfdg --- colpali_engine/loss/__init__.py | 6 +++++- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/colpali_engine/loss/__init__.py b/colpali_engine/loss/__init__.py index f8713d827..4dc461473 100644 --- a/colpali_engine/loss/__init__.py +++ b/colpali_engine/loss/__init__.py @@ -3,7 +3,11 @@ BiPairwiseCELoss, BiPairwiseNegativeCELoss, ) -from .gradcache_late_interaction_losses import GradCacheColbertLoss +from .gradcache_late_interaction_losses import ( + GradCacheColbertLoss, + GradCacheColbertPairwiseCELoss, + GradCacheColbertPairwiseNegativeCELoss, +) from .late_interaction_losses import ( ColbertLoss, ColbertPairwiseCELoss, diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index c4f03411e..915eb8da9 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -29,13 +29,13 @@ config: # max_length: 50 run_eval: true loss_func: - (): colpali_engine.loss.gradcache_late_interaction_losses.GradCacheColbertLoss + (): colpali_engine.loss.gradcache_late_interaction_losses.GradCacheColbertLoss # GradCacheColbertPairwiseCELoss tr_args: (): transformers.training_args.TrainingArguments output_dir: null overwrite_output_dir: true num_train_epochs: 1 - per_device_train_batch_size: 128 + per_device_train_batch_size: 512 gradient_checkpointing: true gradient_checkpointing_kwargs: { "use_reentrant": false } # gradient_checkpointing: true From 964f96bfea755587a6534818d3f6c3e5faf3dafa Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Wed, 19 Feb 2025 19:06:59 +0100 Subject: [PATCH 27/86] fgr --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index 915eb8da9..ddaf2d81f 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -29,13 +29,13 @@ config: # max_length: 50 run_eval: true loss_func: - (): colpali_engine.loss.gradcache_late_interaction_losses.GradCacheColbertLoss # GradCacheColbertPairwiseCELoss + (): 
colpali_engine.loss.gradcache_late_interaction_losses.GradCacheColbertPairwiseCELoss # GradCacheColbertLoss # tr_args: (): transformers.training_args.TrainingArguments output_dir: null overwrite_output_dir: true num_train_epochs: 1 - per_device_train_batch_size: 512 + per_device_train_batch_size: 256 gradient_checkpointing: true gradient_checkpointing_kwargs: { "use_reentrant": false } # gradient_checkpointing: true From 4c1e8f3c215104fd119d67634ed99b45f7a8dc0c Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Wed, 19 Feb 2025 23:24:49 +0100 Subject: [PATCH 28/86] dataloadr --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index ddaf2d81f..879bc1322 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -43,7 +43,7 @@ config: # gradient_accumulation_steps: 4 per_device_eval_batch_size: 16 eval_strategy: "steps" - dataloader_num_workers: 8 + dataloader_num_workers: 32 # bf16: true save_steps: 500 logging_steps: 10 From 88b657f8723a1efa608e4f1c6f9f2f93f7000e11 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Wed, 19 Feb 2025 23:29:11 +0100 Subject: [PATCH 29/86] gg --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index 879bc1322..fa9f3cf45 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -35,7 +35,7 @@ config: output_dir: null overwrite_output_dir: true num_train_epochs: 1 - per_device_train_batch_size: 256 + per_device_train_batch_size: 128 gradient_checkpointing: true gradient_checkpointing_kwargs: { "use_reentrant": false } # gradient_checkpointing: true From efe6fea83598e69fe6836045cac4e12e4d24b25f Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Wed, 19 Feb 2025 23:33:18 +0100 Subject: [PATCH 30/86] fff --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 2 ++ scripts/configs/qwen2/train_colqwen2_model.yaml | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index fa9f3cf45..22090d0a5 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -56,6 +56,8 @@ config: # wandb logging # wandb_project: "colqwen2" # run_name: "colqwen2-ba32-nolora" + dataloader_pin_memory: true + torch_compile: true report_to: "wandb" diff --git a/scripts/configs/qwen2/train_colqwen2_model.yaml b/scripts/configs/qwen2/train_colqwen2_model.yaml index 6ed38b7e2..684d9767a 100644 --- a/scripts/configs/qwen2/train_colqwen2_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_model.yaml @@ -35,7 +35,7 @@ config: output_dir: null overwrite_output_dir: true num_train_epochs: 5 - per_device_train_batch_size: 64 + per_device_train_batch_size: 128 gradient_checkpointing: true gradient_checkpointing_kwargs: { "use_reentrant": false } # gradient_checkpointing: true @@ -56,6 +56,8 @@ config: # wandb logging # wandb_project: "colqwen2" # run_name: "colqwen2-ba32-nolora" + dataloader_pin_memory: true + torch_compile: true report_to: "wandb" From 
b9d0f48facd37c9407281f30aadf915c019b2f2d Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Wed, 19 Feb 2025 23:37:06 +0100 Subject: [PATCH 31/86] low worker --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index 22090d0a5..0e452b954 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -43,7 +43,7 @@ config: # gradient_accumulation_steps: 4 per_device_eval_batch_size: 16 eval_strategy: "steps" - dataloader_num_workers: 32 + dataloader_num_workers: 4 # bf16: true save_steps: 500 logging_steps: 10 From 2c3aaae2b82c5524b334a28b20715257d76d71a1 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Wed, 19 Feb 2025 23:37:21 +0100 Subject: [PATCH 32/86] low worker --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index 0e452b954..81bbdccab 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -56,7 +56,7 @@ config: # wandb logging # wandb_project: "colqwen2" # run_name: "colqwen2-ba32-nolora" - dataloader_pin_memory: true + dataloader_pin_memory: false # true torch_compile: true report_to: "wandb" From 79d14e5af0a46c053e34f933ed267fdde915e8f5 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Wed, 19 Feb 2025 23:38:03 +0100 Subject: [PATCH 33/86] ff --- scripts/configs/qwen2/train_colqwen2_model.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/configs/qwen2/train_colqwen2_model.yaml b/scripts/configs/qwen2/train_colqwen2_model.yaml index 684d9767a..3efed497d 100644 --- a/scripts/configs/qwen2/train_colqwen2_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_model.yaml @@ -35,7 +35,7 @@ config: output_dir: null overwrite_output_dir: true num_train_epochs: 5 - per_device_train_batch_size: 128 + per_device_train_batch_size: 64 gradient_checkpointing: true gradient_checkpointing_kwargs: { "use_reentrant": false } # gradient_checkpointing: true From fa014081e27c8b3c8c56ccb388809866080b304c Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Wed, 19 Feb 2025 23:50:14 +0100 Subject: [PATCH 34/86] 512 --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index 81bbdccab..2c3c54631 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -35,7 +35,7 @@ config: output_dir: null overwrite_output_dir: true num_train_epochs: 1 - per_device_train_batch_size: 128 + per_device_train_batch_size: 512 gradient_checkpointing: true gradient_checkpointing_kwargs: { "use_reentrant": false } # gradient_checkpointing: true From 59869ff038972ec7ff2bad45d9187577223f35b4 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 00:05:08 +0100 Subject: [PATCH 35/86] 512 --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml 
b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index 2c3c54631..16d4c7176 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -35,7 +35,7 @@ config: output_dir: null overwrite_output_dir: true num_train_epochs: 1 - per_device_train_batch_size: 512 + per_device_train_batch_size: 256 gradient_checkpointing: true gradient_checkpointing_kwargs: { "use_reentrant": false } # gradient_checkpointing: true @@ -56,7 +56,7 @@ config: # wandb logging # wandb_project: "colqwen2" # run_name: "colqwen2-ba32-nolora" - dataloader_pin_memory: false # true + dataloader_pin_memory: true # false torch_compile: true report_to: "wandb" From b143122f38ad733fba0175154a25ab7e8ad0f4cf Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 00:05:57 +0100 Subject: [PATCH 36/86] tt --- scripts/configs/qwen2/train_colqwen2_model.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/configs/qwen2/train_colqwen2_model.yaml b/scripts/configs/qwen2/train_colqwen2_model.yaml index 3efed497d..684d9767a 100644 --- a/scripts/configs/qwen2/train_colqwen2_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_model.yaml @@ -35,7 +35,7 @@ config: output_dir: null overwrite_output_dir: true num_train_epochs: 5 - per_device_train_batch_size: 64 + per_device_train_batch_size: 128 gradient_checkpointing: true gradient_checkpointing_kwargs: { "use_reentrant": false } # gradient_checkpointing: true From 33d6cbedb9730a0b906c0ea97d86a9120d83c54d Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 00:09:20 +0100 Subject: [PATCH 37/86] test --- scripts/configs/qwen2/train_colqwen2_model.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/configs/qwen2/train_colqwen2_model.yaml b/scripts/configs/qwen2/train_colqwen2_model.yaml index 684d9767a..9e18e239c 100644 --- a/scripts/configs/qwen2/train_colqwen2_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_model.yaml @@ -38,6 +38,7 @@ config: per_device_train_batch_size: 128 gradient_checkpointing: true gradient_checkpointing_kwargs: { "use_reentrant": false } + ddp_find_unused_parameters: false # gradient_checkpointing: true # 6 x 8 gpus = 48 batch size # gradient_accumulation_steps: 4 From f546674e17c6feac9d51283e30611828a55b93c7 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 00:25:08 +0100 Subject: [PATCH 38/86] ff --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index 16d4c7176..2af411bbf 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -43,11 +43,11 @@ config: # gradient_accumulation_steps: 4 per_device_eval_batch_size: 16 eval_strategy: "steps" - dataloader_num_workers: 4 + dataloader_num_workers: 8 # bf16: true save_steps: 500 logging_steps: 10 - eval_steps: 20 + eval_steps: 100 warmup_steps: 100 learning_rate: 5e-4 save_total_limit: 1 From 2426507af68918d0088c99f3113106a5aafda7a3 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 10:00:18 +0100 Subject: [PATCH 39/86] gradcache --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml 
index 2af411bbf..2521070ea 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -35,7 +35,7 @@ config: output_dir: null overwrite_output_dir: true num_train_epochs: 1 - per_device_train_batch_size: 256 + per_device_train_batch_size: 512 gradient_checkpointing: true gradient_checkpointing_kwargs: { "use_reentrant": false } # gradient_checkpointing: true From 7f35d0fd1b6c5b69ab70ad8c4a8823d54d4e64b1 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 10:13:54 +0100 Subject: [PATCH 40/86] tt --- colpali_engine/loss/gradcache_late_interaction_losses.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py index 126241a6c..a66b78c5f 100644 --- a/colpali_engine/loss/gradcache_late_interaction_losses.py +++ b/colpali_engine/loss/gradcache_late_interaction_losses.py @@ -199,7 +199,10 @@ def calculate_loss(self, reps: list[list[torch.Tensor]], with_backward: bool = F reps is a list with two elements: reps[0] for query embeddings and reps[1] for doc embeddings. """ embeddings_query = torch.cat(reps[0], dim=0) # shape: (batch, num_query_tokens, dim) - embeddings_doc = torch.cat(reps[1], dim=0) # shape: (batch, num_doc_tokens, dim) + embeddings_doc = torch.cat(reps[1], dim=0) + # shape: (batch, num_doc_tokens, dim) + print(f"embeddings_query.shape: {embeddings_query.shape}; embeddings_doc.shape: {embeddings_doc.shape}") + breakpoint() scores = torch.einsum("bnd,csd->bcns", embeddings_query, embeddings_doc) \ .max(dim=3)[0].sum(dim=2) # (batch, batch) pos_scores = scores.diagonal() From bfced75deaa67e14039f7935a08c19664f88bd03 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 10:14:18 +0100 Subject: [PATCH 41/86] debug --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index 2521070ea..2af411bbf 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -35,7 +35,7 @@ config: output_dir: null overwrite_output_dir: true num_train_epochs: 1 - per_device_train_batch_size: 512 + per_device_train_batch_size: 256 gradient_checkpointing: true gradient_checkpointing_kwargs: { "use_reentrant": false } # gradient_checkpointing: true From 17450e3637f36347d4893bb28ec0f0c2622c82c5 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 11:11:01 +0100 Subject: [PATCH 42/86] ff --- colpali_engine/loss/gradcache_late_interaction_losses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py index a66b78c5f..ed6a69a90 100644 --- a/colpali_engine/loss/gradcache_late_interaction_losses.py +++ b/colpali_engine/loss/gradcache_late_interaction_losses.py @@ -202,7 +202,7 @@ def calculate_loss(self, reps: list[list[torch.Tensor]], with_backward: bool = F embeddings_doc = torch.cat(reps[1], dim=0) # shape: (batch, num_doc_tokens, dim) print(f"embeddings_query.shape: {embeddings_query.shape}; embeddings_doc.shape: {embeddings_doc.shape}") - breakpoint() + # breakpoint() scores = torch.einsum("bnd,csd->bcns", embeddings_query, embeddings_doc) \ .max(dim=3)[0].sum(dim=2) # (batch, batch) 
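# Aside: the pairwise objective computed on the score matrix above, as a toy standalone
# sketch (random scores, illustrative only). Positives sit on the diagonal; the hardest
# in-batch negative is the row-wise max once the diagonal is masked out with a large
# constant.
import torch
import torch.nn.functional as F
scores = torch.randn(4, 4)                          # (batch, batch) late-interaction scores
pos = scores.diagonal()                             # score of query i against doc i
neg = (scores - torch.eye(4) * 1e6).max(dim=1)[0]   # hardest remaining candidate per query
loss = F.softplus(neg - pos).mean()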
pos_scores = scores.diagonal() From e19b79621de06077a552916e1a823a89025ac8cc Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 11:45:10 +0100 Subject: [PATCH 43/86] fff --- .../loss/gradcache_late_interaction_losses.py | 76 ++++++++++++++++--- .../qwen2/train_colqwen2_gradcache_model.yaml | 2 +- 2 files changed, 65 insertions(+), 13 deletions(-) diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py index ed6a69a90..8ad6b95f6 100644 --- a/colpali_engine/loss/gradcache_late_interaction_losses.py +++ b/colpali_engine/loss/gradcache_late_interaction_losses.py @@ -193,22 +193,74 @@ def embed_minibatch_iter(self, model, sentence_feature: dict, with_grad: bool, c mini_embeds = mini_embeds.detach().requires_grad_(True) yield mini_embeds, random_state + # def calculate_loss(self, reps: list[list[torch.Tensor]], with_backward: bool = False) -> torch.Tensor: + # """ + # Compute the ColBERTPairwiseCELoss using cached embeddings. + # reps is a list with two elements: reps[0] for query embeddings and reps[1] for doc embeddings. + # """ + # embeddings_query = torch.cat(reps[0], dim=0) # shape: (batch, num_query_tokens, dim) + # embeddings_doc = torch.cat(reps[1], dim=0) + # # shape: (batch, num_doc_tokens, dim) + # print(f"embeddings_query.shape: {embeddings_query.shape}; embeddings_doc.shape: {embeddings_doc.shape}") + # # breakpoint() + # scores = torch.einsum("bnd,csd->bcns", embeddings_query, embeddings_doc) \ + # .max(dim=3)[0].sum(dim=2) # (batch, batch) + # pos_scores = scores.diagonal() + # neg_scores = scores - torch.eye(scores.shape[0], device=scores.device) * 1e6 + # neg_scores = neg_scores.max(dim=1)[0] + # loss = F.softplus(neg_scores - pos_scores).mean() + # if with_backward: + # loss.backward() + # return loss + def calculate_loss(self, reps: list[list[torch.Tensor]], with_backward: bool = False) -> torch.Tensor: """ - Compute the ColBERTPairwiseCELoss using cached embeddings. - reps is a list with two elements: reps[0] for query embeddings and reps[1] for doc embeddings. + Compute the ColBERTPairwiseCELoss using cached embeddings without concatenating query embeddings. + reps[0] contains query embedding chunks (each of shape: (chunk_size, num_query_tokens, dim)), + while reps[1] contains doc embeddings, which we concatenate. + + For each query chunk, we: + - Compute scores with all docs using an einsum. + - Reduce the scores by taking a max over the doc tokens and summing over query tokens. + - Extract the positive score for each query based on its overall index (assuming query i matches doc i). + - Mask out the positive score and take the max over negatives. + - Compute the softplus loss over the difference (neg_score - pos_score). + + The overall loss is the average over all queries, and remains differentiable. 
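# Aside: the chunked loop described above is numerically equivalent to the dense version;
# splitting the query axis only changes peak memory, never the score matrix. A toy check
# with illustrative shapes:
import torch
q, d = torch.randn(8, 5, 16), torch.randn(8, 30, 16)
dense = torch.einsum("bnd,csd->bcns", q, d).max(dim=3)[0].sum(dim=2)
chunks = [torch.einsum("bnd,csd->bcns", q[i:i + 2], d).max(dim=3)[0].sum(dim=2) for i in range(0, 8, 2)]
assert torch.allclose(dense, torch.cat(chunks, dim=0), atol=1e-5)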
""" - embeddings_query = torch.cat(reps[0], dim=0) # shape: (batch, num_query_tokens, dim) + # Concatenate document embeddings (shape: (total_docs, num_doc_tokens, dim)) embeddings_doc = torch.cat(reps[1], dim=0) - # shape: (batch, num_doc_tokens, dim) - print(f"embeddings_query.shape: {embeddings_query.shape}; embeddings_doc.shape: {embeddings_doc.shape}") - # breakpoint() - scores = torch.einsum("bnd,csd->bcns", embeddings_query, embeddings_doc) \ - .max(dim=3)[0].sum(dim=2) # (batch, batch) - pos_scores = scores.diagonal() - neg_scores = scores - torch.eye(scores.shape[0], device=scores.device) * 1e6 - neg_scores = neg_scores.max(dim=1)[0] - loss = F.softplus(neg_scores - pos_scores).mean() + + total_loss = 0.0 + total_queries = 0 + global_index = 0 # Tracks the overall index for positive pairing + + # Loop over query chunks + for query_chunk in reps[0]: + chunk_size = query_chunk.size(0) + # Compute pairwise scores: + # Resulting shape: (chunk_size, total_docs, num_query_tokens, num_doc_tokens) + scores_chunk = torch.einsum("bnd,csd->bcns", query_chunk, embeddings_doc) + # Reduce: max over document tokens then sum over query tokens -> shape: (chunk_size, total_docs) + scores_chunk = scores_chunk.max(dim=3)[0].sum(dim=2) + + # For each query in the chunk, the positive doc index is global_index + local_index + row_idx = torch.arange(chunk_size, device=scores_chunk.device) + pos_idx = torch.arange(global_index, global_index + chunk_size, device=scores_chunk.device) + pos_scores = scores_chunk[row_idx, pos_idx] + + # Mask out the positive scores by setting them to a very low value, then take the max over negatives + scores_masked = scores_chunk.clone() + scores_masked[row_idx, pos_idx] = -1e6 + neg_scores = scores_masked.max(dim=1)[0] + + # Compute loss for this chunk (sum over the chunk's queries) + chunk_loss = F.softplus(neg_scores - pos_scores).sum() + total_loss += chunk_loss + total_queries += chunk_size + global_index += chunk_size + + loss = total_loss / total_queries if with_backward: loss.backward() return loss diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index 2af411bbf..2521070ea 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -35,7 +35,7 @@ config: output_dir: null overwrite_output_dir: true num_train_epochs: 1 - per_device_train_batch_size: 256 + per_device_train_batch_size: 512 gradient_checkpointing: true gradient_checkpointing_kwargs: { "use_reentrant": false } # gradient_checkpointing: true From a3c5935f5979874e1fbb743291cd6d61c60fb479 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 12:06:51 +0100 Subject: [PATCH 44/86] mini bs --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index 2521070ea..d3cccff63 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -30,12 +30,13 @@ config: run_eval: true loss_func: (): colpali_engine.loss.gradcache_late_interaction_losses.GradCacheColbertPairwiseCELoss # GradCacheColbertLoss # + mini_batch_size: 64 tr_args: (): transformers.training_args.TrainingArguments output_dir: null overwrite_output_dir: true num_train_epochs: 1 - per_device_train_batch_size: 512 + 
per_device_train_batch_size: 1024 gradient_checkpointing: true gradient_checkpointing_kwargs: { "use_reentrant": false } # gradient_checkpointing: true From 1414a8e25457d234556d72bd25e55ef8c3fbf3fc Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 12:12:43 +0100 Subject: [PATCH 45/86] fff --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index d3cccff63..2647088f6 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -30,7 +30,7 @@ config: run_eval: true loss_func: (): colpali_engine.loss.gradcache_late_interaction_losses.GradCacheColbertPairwiseCELoss # GradCacheColbertLoss # - mini_batch_size: 64 + mini_batch_size: 32 tr_args: (): transformers.training_args.TrainingArguments output_dir: null From b5dce315d20d1a2caab4d689d72b2e3ef33df848 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 12:15:50 +0100 Subject: [PATCH 46/86] ffff --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index 2647088f6..95c01b1dd 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -30,13 +30,13 @@ config: run_eval: true loss_func: (): colpali_engine.loss.gradcache_late_interaction_losses.GradCacheColbertPairwiseCELoss # GradCacheColbertLoss # - mini_batch_size: 32 + mini_batch_size: 64 tr_args: (): transformers.training_args.TrainingArguments output_dir: null overwrite_output_dir: true num_train_epochs: 1 - per_device_train_batch_size: 1024 + per_device_train_batch_size: 512 gradient_checkpointing: true gradient_checkpointing_kwargs: { "use_reentrant": false } # gradient_checkpointing: true From 524b6fd273a617fe636718192d51cc1833b35ff7 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 15:56:38 +0100 Subject: [PATCH 47/86] ffg --- colpali_engine/loss/gradcache_late_interaction_losses.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py index 8ad6b95f6..c51a11a4f 100644 --- a/colpali_engine/loss/gradcache_late_interaction_losses.py +++ b/colpali_engine/loss/gradcache_late_interaction_losses.py @@ -250,9 +250,9 @@ def calculate_loss(self, reps: list[list[torch.Tensor]], with_backward: bool = F pos_scores = scores_chunk[row_idx, pos_idx] # Mask out the positive scores by setting them to a very low value, then take the max over negatives - scores_masked = scores_chunk.clone() - scores_masked[row_idx, pos_idx] = -1e6 - neg_scores = scores_masked.max(dim=1)[0] + # scores_masked = scores_chunk.clone() + scores_chunk[row_idx, pos_idx] = -1e6 + neg_scores = scores_chunk.max(dim=1)[0] # Compute loss for this chunk (sum over the chunk's queries) chunk_loss = F.softplus(neg_scores - pos_scores).sum() From a76abe3f0ea2e888de11d749165006be1f2d3f4f Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 17:44:10 +0100 Subject: [PATCH 48/86] break --- .../collators/visual_retriever_collator.py | 57 ++++++++----------- 1 file changed, 24 insertions(+), 33 deletions(-) diff --git 
a/colpali_engine/collators/visual_retriever_collator.py b/colpali_engine/collators/visual_retriever_collator.py index 070115198..a4f1dc06a 100644 --- a/colpali_engine/collators/visual_retriever_collator.py +++ b/colpali_engine/collators/visual_retriever_collator.py @@ -20,6 +20,7 @@ def __init__( self.processor = processor self.image_token_id = None self.max_length = max_length + self.minibatch_size = 32 if isinstance(self.processor, ColPaliProcessor) or isinstance(self.processor, ColIdefics2Processor): self.image_token_id = self.processor.tokenizer.additional_special_tokens_ids[ @@ -41,34 +42,27 @@ def __call__( """ # Placeholders texts_query: Union[List[str], List[None], List[Union[str, None]]] = [] # some documents don't have a query - images: List[Image] = [] - neg_images: List[Image] = [] - if self.processor is None or not isinstance(self.processor, BaseVisualRetrieverProcessor): - raise ValueError("Processor should be provided for vision collator.") + # if self.processor is None or not isinstance(self.processor, BaseVisualRetrieverProcessor): + # raise ValueError("Processor should be provided for vision collator.") # Process each example - for example in examples: - texts_query.append(example["query"]) - if example["image"] is None: - raise ValueError("Image is None - This collator does not support None images yet.") - - images.append(cast(Image, example["image"])) - - if "neg_image" in example and example["neg_image"] is not None: - neg_images.append(cast(Image, example["neg_image"])) - - # Process the documents - batch_doc = self.processor.process_images( - images=images, - ) - - # Process the negative documents (if available) - batch_neg_doc = None - if len(neg_images) > 0: - batch_neg_doc = self.processor.process_images( - images=neg_images, - ) + batch_doc = [] + batch_neg_doc = [] + for i in range(0, len(examples), self.minibatch_size): + # Process the documents + batch_doc += [self.processor.process_images( + images=examples[i : i + self.minibatch_size]["image"], + )] + + # Process the negative documents (if available) + batch_neg_doc = None + if "neg_image" in examples[i]: + batch_neg_doc += [self.processor.process_images( + images=examples[i : i + self.minibatch_size]["neg_image"], + )] + + breakpoint() # Process the queries batch_query = None @@ -80,21 +74,18 @@ def __call__( # If it's the first query that is not None but the rest are None, then it's hard negatives. raise ValueError("Some queries are None. 
This collator does not support None queries yet.") else: - texts_query = cast(List[str], texts_query) batch_query = self.processor.process_queries( - queries=texts_query, + queries=examples["query"], max_length=self.max_length, ) # Prefix each key with "doc_" or "query_" to avoid key conflicts batch_all = {f"doc_{k}": v for k, v in batch_doc.items()} - del batch_doc + if batch_query is not None: - batch_query = {f"query_{k}": v for k, v in batch_query.items()} - batch_all.update(batch_query) - del batch_query + batch_all.update({f"query_{k}": v for k, v in batch_query.items()}) + if batch_neg_doc is not None: - batch_neg_doc = {f"neg_doc_{k}": v for k, v in batch_neg_doc.items()} - batch_all.update(batch_neg_doc) + batch_all.update({f"neg_doc_{k}": v for k, v in batch_neg_doc.items()}) return batch_all From aa27ae1e1c40c29fb110a17d1a3ce425feee1155 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 18:00:08 +0100 Subject: [PATCH 49/86] fff --- .../collators/visual_retriever_collator.py | 33 +++++++++++++++++-- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/colpali_engine/collators/visual_retriever_collator.py b/colpali_engine/collators/visual_retriever_collator.py index a4f1dc06a..5551a9525 100644 --- a/colpali_engine/collators/visual_retriever_collator.py +++ b/colpali_engine/collators/visual_retriever_collator.py @@ -1,7 +1,5 @@ from typing import Any, Dict, List, Union, cast -from PIL.Image import Image - from colpali_engine.models.idefics_2 import ColIdefics2Processor from colpali_engine.models.paligemma import ColPaliProcessor from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor @@ -16,11 +14,12 @@ def __init__( self, processor: BaseVisualRetrieverProcessor, max_length: int = 2048, + minibatch_size: int = 32, ): self.processor = processor self.image_token_id = None self.max_length = max_length - self.minibatch_size = 32 + self.minibatch_size = minibatch_size if isinstance(self.processor, ColPaliProcessor) or isinstance(self.processor, ColIdefics2Processor): self.image_token_id = self.processor.tokenizer.additional_special_tokens_ids[ @@ -51,6 +50,7 @@ def __call__( batch_neg_doc = [] for i in range(0, len(examples), self.minibatch_size): # Process the documents + breakpoint() batch_doc += [self.processor.process_images( images=examples[i : i + self.minibatch_size]["image"], )] @@ -89,3 +89,30 @@ def __call__( batch_all.update({f"neg_doc_{k}": v for k, v in batch_neg_doc.items()}) return batch_all + + +if __name__ == "__main__": + from PIL import Image + + processor = ColPaliProcessor.from_pretrained("vidore/colpali") + collator = VisualRetrieverCollator(processor=processor, minibatch_size=2) + examples = [ + { + "image": Image.new("RGB", (100, 100)), + "query": "What is this?", + }, + { + "image": Image.new("RGB", (150, 100)), + "query": "What is this?", + }, + { + "image": Image.new("RGB", (200, 100)), + "query": "What is this?", + }, + { + "image": Image.new("RGB", (100, 200)), + "query": "What is this?", + }, + ] + from datasets import Dataset + collator(Dataset.from_list(examples)) \ No newline at end of file From 7563254dacbd58c1199f1f73c2f8d96732a721e0 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 18:14:05 +0100 Subject: [PATCH 50/86] pad --- .../collators/visual_retriever_collator.py | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/colpali_engine/collators/visual_retriever_collator.py b/colpali_engine/collators/visual_retriever_collator.py index 5551a9525..7d26eceb3 
100644 --- a/colpali_engine/collators/visual_retriever_collator.py +++ b/colpali_engine/collators/visual_retriever_collator.py @@ -1,5 +1,9 @@ +import torch + from typing import Any, Dict, List, Union, cast +from transformers.image_processing_base import BatchFeature + from colpali_engine.models.idefics_2 import ColIdefics2Processor from colpali_engine.models.paligemma import ColPaliProcessor from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor @@ -46,23 +50,33 @@ def __call__( # raise ValueError("Processor should be provided for vision collator.") # Process each example - batch_doc = [] - batch_neg_doc = [] + tmp_batch_doc = [] + tmp_batch_neg_doc = [] for i in range(0, len(examples), self.minibatch_size): # Process the documents - breakpoint() - batch_doc += [self.processor.process_images( + tmp_batch_doc += [self.processor.process_images( images=examples[i : i + self.minibatch_size]["image"], )] # Process the negative documents (if available) batch_neg_doc = None if "neg_image" in examples[i]: - batch_neg_doc += [self.processor.process_images( + tmp_batch_neg_doc += [self.processor.process_images( images=examples[i : i + self.minibatch_size]["neg_image"], )] - breakpoint() + batch_doc = {} + batch_neg_doc = None if tmp_batch_neg_doc is None else {} + for key in tmp_batch_doc[0].keys(): + batch_doc[key] = torch.nn.utils.rnn.pad_sequence([a for b in tmp_batch_doc for a in b[key]], batch_first=True, padding_value=0) + + batch_doc = BatchFeature(batch_doc) + + if tmp_batch_neg_doc is not None: + for key in tmp_batch_neg_doc[0].keys(): + batch_neg_doc[key] = torch.nn.utils.rnn.pad_sequence([a for b in tmp_batch_neg_doc for a in b[key]], batch_first=True, padding_value=0) + + batch_neg_doc = BatchFeature(batch_neg_doc) # Process the queries batch_query = None From cf831d3e56a8ba5109944dfeb75be61450781a4a Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 18:15:01 +0100 Subject: [PATCH 51/86] fff --- colpali_engine/collators/visual_retriever_collator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colpali_engine/collators/visual_retriever_collator.py b/colpali_engine/collators/visual_retriever_collator.py index 7d26eceb3..9343b1fcc 100644 --- a/colpali_engine/collators/visual_retriever_collator.py +++ b/colpali_engine/collators/visual_retriever_collator.py @@ -54,6 +54,7 @@ def __call__( tmp_batch_neg_doc = [] for i in range(0, len(examples), self.minibatch_size): # Process the documents + breakpoint() tmp_batch_doc += [self.processor.process_images( images=examples[i : i + self.minibatch_size]["image"], )] From c637e6f6fba20eb6bd83e794a6e9f97d8e7bbaaa Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 18:17:15 +0100 Subject: [PATCH 52/86] fff --- colpali_engine/collators/visual_retriever_collator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colpali_engine/collators/visual_retriever_collator.py b/colpali_engine/collators/visual_retriever_collator.py index 9343b1fcc..a71becf80 100644 --- a/colpali_engine/collators/visual_retriever_collator.py +++ b/colpali_engine/collators/visual_retriever_collator.py @@ -52,6 +52,7 @@ def __call__( # Process each example tmp_batch_doc = [] tmp_batch_neg_doc = [] + breakpoint() for i in range(0, len(examples), self.minibatch_size): # Process the documents breakpoint() From 88d096fd7aa91315a7b25784fae777ed7ba3ac57 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 18:25:03 +0100 Subject: [PATCH 53/86] fff --- .../collators/visual_retriever_collator.py | 76 +++---------------- 1 file 
changed, 12 insertions(+), 64 deletions(-) diff --git a/colpali_engine/collators/visual_retriever_collator.py b/colpali_engine/collators/visual_retriever_collator.py index a71becf80..b17739e22 100644 --- a/colpali_engine/collators/visual_retriever_collator.py +++ b/colpali_engine/collators/visual_retriever_collator.py @@ -1,9 +1,5 @@ -import torch - from typing import Any, Dict, List, Union, cast -from transformers.image_processing_base import BatchFeature - from colpali_engine.models.idefics_2 import ColIdefics2Processor from colpali_engine.models.paligemma import ColPaliProcessor from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor @@ -18,12 +14,10 @@ def __init__( self, processor: BaseVisualRetrieverProcessor, max_length: int = 2048, - minibatch_size: int = 32, ): self.processor = processor self.image_token_id = None self.max_length = max_length - self.minibatch_size = minibatch_size if isinstance(self.processor, ColPaliProcessor) or isinstance(self.processor, ColIdefics2Processor): self.image_token_id = self.processor.tokenizer.additional_special_tokens_ids[ @@ -46,39 +40,20 @@ def __call__( # Placeholders texts_query: Union[List[str], List[None], List[Union[str, None]]] = [] # some documents don't have a query - # if self.processor is None or not isinstance(self.processor, BaseVisualRetrieverProcessor): - # raise ValueError("Processor should be provided for vision collator.") - - # Process each example - tmp_batch_doc = [] - tmp_batch_neg_doc = [] - breakpoint() - for i in range(0, len(examples), self.minibatch_size): - # Process the documents - breakpoint() - tmp_batch_doc += [self.processor.process_images( - images=examples[i : i + self.minibatch_size]["image"], - )] - - # Process the negative documents (if available) - batch_neg_doc = None - if "neg_image" in examples[i]: - tmp_batch_neg_doc += [self.processor.process_images( - images=examples[i : i + self.minibatch_size]["neg_image"], - )] - - batch_doc = {} - batch_neg_doc = None if tmp_batch_neg_doc is None else {} - for key in tmp_batch_doc[0].keys(): - batch_doc[key] = torch.nn.utils.rnn.pad_sequence([a for b in tmp_batch_doc for a in b[key]], batch_first=True, padding_value=0) + if self.processor is None or not isinstance(self.processor, BaseVisualRetrieverProcessor): + raise ValueError("Processor should be provided for vision collator.") - batch_doc = BatchFeature(batch_doc) + # Process the documents + batch_doc = self.processor.process_images( + images=examples["image"] + ) - if tmp_batch_neg_doc is not None: - for key in tmp_batch_neg_doc[0].keys(): - batch_neg_doc[key] = torch.nn.utils.rnn.pad_sequence([a for b in tmp_batch_neg_doc for a in b[key]], batch_first=True, padding_value=0) - - batch_neg_doc = BatchFeature(batch_neg_doc) + # Process the negative documents (if available) + batch_neg_doc = None + if "neg_image" in examples: + batch_neg_doc = self.processor.process_images( + images=examples["neg_image"] + ) # Process the queries batch_query = None @@ -105,30 +80,3 @@ def __call__( batch_all.update({f"neg_doc_{k}": v for k, v in batch_neg_doc.items()}) return batch_all - - -if __name__ == "__main__": - from PIL import Image - - processor = ColPaliProcessor.from_pretrained("vidore/colpali") - collator = VisualRetrieverCollator(processor=processor, minibatch_size=2) - examples = [ - { - "image": Image.new("RGB", (100, 100)), - "query": "What is this?", - }, - { - "image": Image.new("RGB", (150, 100)), - "query": "What is this?", - }, - { - "image": Image.new("RGB", (200, 100)), - "query": 
"What is this?", - }, - { - "image": Image.new("RGB", (100, 200)), - "query": "What is this?", - }, - ] - from datasets import Dataset - collator(Dataset.from_list(examples)) \ No newline at end of file From d3f11ab44c061304d6b6d0ccf614a0341f58d845 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 18:27:06 +0100 Subject: [PATCH 54/86] fff --- colpali_engine/collators/visual_retriever_collator.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/colpali_engine/collators/visual_retriever_collator.py b/colpali_engine/collators/visual_retriever_collator.py index b17739e22..ff132fb86 100644 --- a/colpali_engine/collators/visual_retriever_collator.py +++ b/colpali_engine/collators/visual_retriever_collator.py @@ -37,13 +37,11 @@ def __call__( """ Collate function for the vision retriever associated to the collator's processor. """ - # Placeholders - texts_query: Union[List[str], List[None], List[Union[str, None]]] = [] # some documents don't have a query - - if self.processor is None or not isinstance(self.processor, BaseVisualRetrieverProcessor): - raise ValueError("Processor should be provided for vision collator.") + # if self.processor is None or not isinstance(self.processor, BaseVisualRetrieverProcessor): + # raise ValueError("Processor should be provided for vision collator.") # Process the documents + breakpoint() batch_doc = self.processor.process_images( images=examples["image"] ) From 605c0e84cc7bc3551f1b1b74257d4b5562f32d4b Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 18:29:53 +0100 Subject: [PATCH 55/86] revert --- .../collators/visual_retriever_collator.py | 42 ++++++++++++++----- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/colpali_engine/collators/visual_retriever_collator.py b/colpali_engine/collators/visual_retriever_collator.py index ff132fb86..070115198 100644 --- a/colpali_engine/collators/visual_retriever_collator.py +++ b/colpali_engine/collators/visual_retriever_collator.py @@ -1,5 +1,7 @@ from typing import Any, Dict, List, Union, cast +from PIL.Image import Image + from colpali_engine.models.idefics_2 import ColIdefics2Processor from colpali_engine.models.paligemma import ColPaliProcessor from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor @@ -37,20 +39,35 @@ def __call__( """ Collate function for the vision retriever associated to the collator's processor. 
""" - # if self.processor is None or not isinstance(self.processor, BaseVisualRetrieverProcessor): - # raise ValueError("Processor should be provided for vision collator.") + # Placeholders + texts_query: Union[List[str], List[None], List[Union[str, None]]] = [] # some documents don't have a query + images: List[Image] = [] + neg_images: List[Image] = [] + + if self.processor is None or not isinstance(self.processor, BaseVisualRetrieverProcessor): + raise ValueError("Processor should be provided for vision collator.") + + # Process each example + for example in examples: + texts_query.append(example["query"]) + if example["image"] is None: + raise ValueError("Image is None - This collator does not support None images yet.") + + images.append(cast(Image, example["image"])) + + if "neg_image" in example and example["neg_image"] is not None: + neg_images.append(cast(Image, example["neg_image"])) # Process the documents - breakpoint() batch_doc = self.processor.process_images( - images=examples["image"] + images=images, ) # Process the negative documents (if available) batch_neg_doc = None - if "neg_image" in examples: + if len(neg_images) > 0: batch_neg_doc = self.processor.process_images( - images=examples["neg_image"] + images=neg_images, ) # Process the queries @@ -63,18 +80,21 @@ def __call__( # If it's the first query that is not None but the rest are None, then it's hard negatives. raise ValueError("Some queries are None. This collator does not support None queries yet.") else: + texts_query = cast(List[str], texts_query) batch_query = self.processor.process_queries( - queries=examples["query"], + queries=texts_query, max_length=self.max_length, ) # Prefix each key with "doc_" or "query_" to avoid key conflicts batch_all = {f"doc_{k}": v for k, v in batch_doc.items()} - + del batch_doc if batch_query is not None: - batch_all.update({f"query_{k}": v for k, v in batch_query.items()}) - + batch_query = {f"query_{k}": v for k, v in batch_query.items()} + batch_all.update(batch_query) + del batch_query if batch_neg_doc is not None: - batch_all.update({f"neg_doc_{k}": v for k, v in batch_neg_doc.items()}) + batch_neg_doc = {f"neg_doc_{k}": v for k, v in batch_neg_doc.items()} + batch_all.update(batch_neg_doc) return batch_all From 6c9d73105cb34a91639b11c5c52a7256f991bda0 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 18:36:55 +0100 Subject: [PATCH 56/86] test --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index 95c01b1dd..de4c6cac8 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -36,7 +36,7 @@ config: output_dir: null overwrite_output_dir: true num_train_epochs: 1 - per_device_train_batch_size: 512 + per_device_train_batch_size: 64 gradient_checkpointing: true gradient_checkpointing_kwargs: { "use_reentrant": false } # gradient_checkpointing: true @@ -44,7 +44,7 @@ config: # gradient_accumulation_steps: 4 per_device_eval_batch_size: 16 eval_strategy: "steps" - dataloader_num_workers: 8 + dataloader_num_workers: 1 # bf16: true save_steps: 500 logging_steps: 10 From 38009d0372109ce75374702cb1f4a9f5bff94c81 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 18:37:34 +0100 Subject: [PATCH 57/86] fff --- colpali_engine/collators/visual_retriever_collator.py | 1 + 1 file changed, 1 
insertion(+) diff --git a/colpali_engine/collators/visual_retriever_collator.py b/colpali_engine/collators/visual_retriever_collator.py index 070115198..c9f098827 100644 --- a/colpali_engine/collators/visual_retriever_collator.py +++ b/colpali_engine/collators/visual_retriever_collator.py @@ -43,6 +43,7 @@ def __call__( texts_query: Union[List[str], List[None], List[Union[str, None]]] = [] # some documents don't have a query images: List[Image] = [] neg_images: List[Image] = [] + breakpoint() if self.processor is None or not isinstance(self.processor, BaseVisualRetrieverProcessor): raise ValueError("Processor should be provided for vision collator.") From c391de31a4249e82f7be13dfb7231814ad1aed36 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 18:45:21 +0100 Subject: [PATCH 58/86] fff --- colpali_engine/collators/visual_retriever_collator.py | 2 +- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/colpali_engine/collators/visual_retriever_collator.py b/colpali_engine/collators/visual_retriever_collator.py index c9f098827..6917368e9 100644 --- a/colpali_engine/collators/visual_retriever_collator.py +++ b/colpali_engine/collators/visual_retriever_collator.py @@ -43,7 +43,7 @@ def __call__( texts_query: Union[List[str], List[None], List[Union[str, None]]] = [] # some documents don't have a query images: List[Image] = [] neg_images: List[Image] = [] - breakpoint() + # breakpoint() if self.processor is None or not isinstance(self.processor, BaseVisualRetrieverProcessor): raise ValueError("Processor should be provided for vision collator.") diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index de4c6cac8..95c01b1dd 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -36,7 +36,7 @@ config: output_dir: null overwrite_output_dir: true num_train_epochs: 1 - per_device_train_batch_size: 64 + per_device_train_batch_size: 512 gradient_checkpointing: true gradient_checkpointing_kwargs: { "use_reentrant": false } # gradient_checkpointing: true @@ -44,7 +44,7 @@ config: # gradient_accumulation_steps: 4 per_device_eval_batch_size: 16 eval_strategy: "steps" - dataloader_num_workers: 1 + dataloader_num_workers: 8 # bf16: true save_steps: 500 logging_steps: 10 From 921a681f82c2975ed1ac2274b6c53da8fff40603 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 23:45:21 +0100 Subject: [PATCH 59/86] goo --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index 95c01b1dd..e8ef6b940 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -30,13 +30,13 @@ config: run_eval: true loss_func: (): colpali_engine.loss.gradcache_late_interaction_losses.GradCacheColbertPairwiseCELoss # GradCacheColbertLoss # - mini_batch_size: 64 + mini_batch_size: 32 tr_args: (): transformers.training_args.TrainingArguments output_dir: null overwrite_output_dir: true num_train_epochs: 1 - per_device_train_batch_size: 512 + per_device_train_batch_size: 256 gradient_checkpointing: true gradient_checkpointing_kwargs: { "use_reentrant": false } # gradient_checkpointing: true
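
For reference: patches 50-58 above collate images in chunks of `minibatch_size` and then re-assemble a single batch. Because each `process_images` call pads with `padding="longest"` only within its own chunk, the per-chunk outputs must be re-padded to a common length before concatenation. Below is a minimal sketch of that re-assembly step, assuming every key holds tensors of shape (mb, seq_len, ...) and that 0 is an acceptable pad value for all keys; the helper name is illustrative, not code from the series:

from typing import Dict, List

import torch
from transformers.image_processing_base import BatchFeature


def merge_minibatches(minibatches: List[BatchFeature]) -> BatchFeature:
    # Flatten each chunk into per-example tensors, then right-pad every key
    # to the longest sequence seen across all chunks (assumes pad value 0 is
    # valid for every key, which holds for masks but not for all tokenizers).
    merged: Dict[str, torch.Tensor] = {}
    for key in minibatches[0].keys():
        per_example = [row for mb in minibatches for row in mb[key]]
        merged[key] = torch.nn.utils.rnn.pad_sequence(per_example, batch_first=True, padding_value=0)
    return BatchFeature(merged)

Padding everything with 0 is correct for attention masks but only correct for input_ids when the tokenizer's pad token id is 0; a robust version would pass a per-key padding value.
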
From 6f766305f25141a2434057656c8d051cd1e271f7 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 23:48:57 +0100 Subject: [PATCH 60/86] test --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index e8ef6b940..cc72dded3 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -36,7 +36,7 @@ config: output_dir: null overwrite_output_dir: true num_train_epochs: 1 - per_device_train_batch_size: 256 + per_device_train_batch_size: 64 gradient_checkpointing: true gradient_checkpointing_kwargs: { "use_reentrant": false } # gradient_checkpointing: true From 13c36781c8c6e81902ae785531401f7d4534e884 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Thu, 20 Feb 2025 23:58:08 +0100 Subject: [PATCH 61/86] shape --- colpali_engine/loss/gradcache_late_interaction_losses.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py index c51a11a4f..541a3df1b 100644 --- a/colpali_engine/loss/gradcache_late_interaction_losses.py +++ b/colpali_engine/loss/gradcache_late_interaction_losses.py @@ -238,6 +238,7 @@ def calculate_loss(self, reps: list[list[torch.Tensor]], with_backward: bool = F # Loop over query chunks for query_chunk in reps[0]: chunk_size = query_chunk.size(0) + print(f"Shape of query chunk: {query_chunk.shape}, Shape of embeddings doc: {embeddings_doc.shape}") # Compute pairwise scores: # Resulting shape: (chunk_size, total_docs, num_query_tokens, num_doc_tokens) scores_chunk = torch.einsum("bnd,csd->bcns", query_chunk, embeddings_doc) From bad7a537a6be1c542898ae4d17746449190d3f36 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 00:06:53 +0100 Subject: [PATCH 62/86] gooo --- colpali_engine/loss/gradcache_late_interaction_losses.py | 2 +- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py index 541a3df1b..05f2a3878 100644 --- a/colpali_engine/loss/gradcache_late_interaction_losses.py +++ b/colpali_engine/loss/gradcache_late_interaction_losses.py @@ -238,7 +238,7 @@ def calculate_loss(self, reps: list[list[torch.Tensor]], with_backward: bool = F # Loop over query chunks for query_chunk in reps[0]: chunk_size = query_chunk.size(0) - print(f"Shape of query chunk: {query_chunk.shape}, Shape of embeddings doc: {embeddings_doc.shape}") + # print(f"Shape of query chunk: {query_chunk.shape}, Shape of embeddings doc: {embeddings_doc.shape}") # Compute pairwise scores: # Resulting shape: (chunk_size, total_docs, num_query_tokens, num_doc_tokens) scores_chunk = torch.einsum("bnd,csd->bcns", query_chunk, embeddings_doc) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index cc72dded3..387d11409 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -36,7 +36,7 @@ config: output_dir: null overwrite_output_dir: true num_train_epochs: 1 - per_device_train_batch_size: 64 + per_device_train_batch_size: 256 gradient_checkpointing: true
gradient_checkpointing_kwargs: { "use_reentrant": false } # gradient_checkpointing: true @@ -44,7 +44,7 @@ config: # gradient_accumulation_steps: 4 per_device_eval_batch_size: 16 eval_strategy: "steps" - dataloader_num_workers: 8 + dataloader_num_workers: 4 # bf16: true save_steps: 500 logging_steps: 10 @@ -57,7 +57,7 @@ config: # wandb logging # wandb_project: "colqwen2" # run_name: "colqwen2-ba32-nolora" - dataloader_pin_memory: true # false + dataloader_pin_memory: false # true # false torch_compile: true report_to: "wandb" From 33a52b640fd1f4f5b9c23390eaec6b25d0596160 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 09:26:20 +0100 Subject: [PATCH 63/86] train --- .../collators/corpus_query_collator.py | 3 + .../collators/visual_retriever_collator.py | 109 ++++++++++++++++-- colpali_engine/trainer/colmodel_training.py | 15 +++ 3 files changed, 115 insertions(+), 12 deletions(-) diff --git a/colpali_engine/collators/corpus_query_collator.py b/colpali_engine/collators/corpus_query_collator.py index 8a95e3ebe..74684b81c 100644 --- a/colpali_engine/collators/corpus_query_collator.py +++ b/colpali_engine/collators/corpus_query_collator.py @@ -15,16 +15,19 @@ def __init__( image_dataset: Optional["Dataset"] = None, # noqa: F821 mined_negatives: bool = True, corpus_format: str = "wikiss", + process_images_before_training: bool = False, ): super().__init__( processor=processor, max_length=max_length, + process_images_before_training=process_images_before_training, ) if image_dataset is None: raise ValueError("`image_dataset` must be provided") self.image_dataset = image_dataset self.mined_negatives = mined_negatives self.corpus_format = corpus_format + self.process_images_before_training = process_images_before_training if self.corpus_format == "wikiss": print("Mapping docids to indices") diff --git a/colpali_engine/collators/visual_retriever_collator.py b/colpali_engine/collators/visual_retriever_collator.py index 6917368e9..d5b5d1ac1 100644 --- a/colpali_engine/collators/visual_retriever_collator.py +++ b/colpali_engine/collators/visual_retriever_collator.py @@ -1,5 +1,6 @@ -from typing import Any, Dict, List, Union, cast +from typing import Any, Dict, List, Union +import torch from PIL.Image import Image from colpali_engine.models.idefics_2 import ColIdefics2Processor @@ -16,10 +17,12 @@ def __init__( self, processor: BaseVisualRetrieverProcessor, max_length: int = 2048, + process_images_before_training: bool = False, ): self.processor = processor self.image_token_id = None self.max_length = max_length + self.process_images_before_training = process_images_before_training if isinstance(self.processor, ColPaliProcessor) or isinstance(self.processor, ColIdefics2Processor): self.image_token_id = self.processor.tokenizer.additional_special_tokens_ids[ @@ -35,6 +38,15 @@ def __init__( def __call__( self, examples: List[Dict[str, Any]], + ) -> Dict[str, Any]: + if self.process_images_before_training: + return self.offline_processing(examples) + return self.online_processing(examples) + + + def online_processing( + self, + examples: List[Dict[str, Any]], ) -> Dict[str, Any]: """ Collate function for the vision retriever associated to the collator's processor. 
@@ -45,19 +57,14 @@ def __call__( neg_images: List[Image] = [] # breakpoint() - if self.processor is None or not isinstance(self.processor, BaseVisualRetrieverProcessor): - raise ValueError("Processor should be provided for vision collator.") # Process each example for example in examples: texts_query.append(example["query"]) - if example["image"] is None: - raise ValueError("Image is None - This collator does not support None images yet.") - - images.append(cast(Image, example["image"])) + images.append(example["image"]) if "neg_image" in example and example["neg_image"] is not None: - neg_images.append(cast(Image, example["neg_image"])) + neg_images.append(example["neg_image"]) # Process the documents batch_doc = self.processor.process_images( @@ -77,11 +84,8 @@ def __call__( if all([t is None for t in texts_query]): # print("All queries are `None`. Returning `None` for all queries.") pass - elif any([t is None for t in texts_query]): - # If it's the first query that is not None but the rest are None, then it's hard negatives. - raise ValueError("Some queries are None. This collator does not support None queries yet.") else: - texts_query = cast(List[str], texts_query) + texts_query: List[str] = texts_query batch_query = self.processor.process_queries( queries=texts_query, max_length=self.max_length, @@ -99,3 +103,84 @@ def __call__( batch_all.update(batch_neg_doc) return batch_all + + + def offline_procesing( + self, + examples: List[Dict[str, Any]], + ) -> Dict[str, Any]: + """ + Collate function for the vision retriever associated to the collator's processor. + """ + # Placeholders + texts_query = [] + pixel_values = [] + image_grid_thw = [] + input_ids = [] + attention_mask = [] + neg_pixel_values = [] + neg_image_grid_thw = [] + neg_input_ids = [] + neg_attention_mask = [] + + breakpoint() + + for example in examples: + texts_query.append(example["query"]) + pixel_values.append(example["pixel_values"]) + image_grid_thw.append(example["image_grid_thw"]) + input_ids.append(example["input_ids"]) + attention_mask.append(example["attention_mask"]) + + if "neg_pixel_values" in example: + neg_pixel_values.append(example["neg_pixel_values"]) + neg_image_grid_thw.append(example["neg_image_grid_thw"]) + neg_input_ids.append(example["neg_input_ids"]) + neg_attention_mask.append(example["neg_attention_mask"]) + + # Pad pixel values + pixel_values = torch.nn.utils.rnn.pad_sequence(pixel_values, batch_first=True, padding_value=0) + image_grid_thw = torch.stack(image_grid_thw) + + # Pad input sequences + batch_doc = self.processor.tokenizer.pad( + {"input_ids": input_ids, "attention_mask": attention_mask}, + padding=True, + return_tensors="pt" + ) + + batch_all = { + "doc_pixel_values": pixel_values, + "doc_image_grid_thw": image_grid_thw, + "doc_input_ids": batch_doc["input_ids"], + "doc_attention_mask": batch_doc["attention_mask"], + } + + # Process queries + if any(texts_query): # Ensure there are valid queries + batch_query = self.processor.process_queries( + queries=texts_query, + max_length=self.max_length + ) + batch_all["query_input_ids"] = batch_query["input_ids"] + batch_all["query_attention_mask"] = batch_query["attention_mask"] + + # Process negatives if present + if neg_pixel_values: + neg_pixel_values = torch.nn.utils.rnn.pad_sequence(neg_pixel_values, batch_first=True, padding_value=0) + neg_image_grid_thw = torch.stack(neg_image_grid_thw) + + batch_neg_doc = self.processor.tokenizer.pad( + {"input_ids": neg_input_ids, "attention_mask": neg_attention_mask}, + padding=True, + 
return_tensors="pt" + ) + + batch_all.update({ + "neg_doc_pixel_values": neg_pixel_values, + "neg_doc_image_grid_thw": neg_image_grid_thw, + "neg_doc_input_ids": batch_neg_doc["input_ids"], + "neg_doc_attention_mask": batch_neg_doc["attention_mask"], + }) + + return batch_all diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index ead037d65..0d401f1b2 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -109,9 +109,24 @@ def __init__(self, config: ColModelTrainingConfig) -> None: processor=self.config.processor, max_length=self.config.max_length, ) + + breakpoint() + self.dataset = self.dataset.map(self.preprocess_example, num_proc=self.config.tr_args.dataloader_num_workers) + breakpoint() self.current_git_hash = os.popen("git rev-parse HEAD").read().strip() self.retrieval_evaluator = CustomRetrievalEvaluator() + def preprocess_example(self, example: Dict): + processed = self.config.processor.process_images([example["image"]]) + for key in processed: + example[key] = processed[key].squeeze(0) + if "neg_image" in example and example["neg_image"] is not None: + neg_processed = self.config.processor.process_images([example["neg_image"]]) + for key in neg_processed: + example[f"neg_{key}"] = neg_processed[key].squeeze(0) + + return example + def train(self) -> None: if isinstance(self.collator, CorpusQueryCollator) and self.collator.mined_negatives: print("Training with hard negatives") From 54e3e092dc2bc3b15e51d2bf155e55e768596bf0 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 11:07:22 +0100 Subject: [PATCH 64/86] fff --- colpali_engine/trainer/colmodel_training.py | 11 ++++++----- .../configs/qwen2/train_colqwen2_gradcache_model.yaml | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index 0d401f1b2..db44ef217 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -110,18 +110,19 @@ def __init__(self, config: ColModelTrainingConfig) -> None: max_length=self.config.max_length, ) - breakpoint() - self.dataset = self.dataset.map(self.preprocess_example, num_proc=self.config.tr_args.dataloader_num_workers) + self.dataset = self.dataset.map(lambda x: self.preprocess_example(x, self.config.processor), + num_proc=self.config.tr_args.dataloader_num_workers) breakpoint() self.current_git_hash = os.popen("git rev-parse HEAD").read().strip() self.retrieval_evaluator = CustomRetrievalEvaluator() - def preprocess_example(self, example: Dict): - processed = self.config.processor.process_images([example["image"]]) + @staticmethod + def preprocess_example(example: Dict, processor): + processed = processor.process_images([example["image"]]) for key in processed: example[key] = processed[key].squeeze(0) if "neg_image" in example and example["neg_image"] is not None: - neg_processed = self.config.processor.process_images([example["neg_image"]]) + neg_processed = processor.process_images([example["neg_image"]]) for key in neg_processed: example[f"neg_{key}"] = neg_processed[key].squeeze(0) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index 387d11409..f418baf1c 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -44,7 +44,7 @@ config: # 
gradient_accumulation_steps: 4 per_device_eval_batch_size: 16 eval_strategy: "steps" - dataloader_num_workers: 4 + dataloader_num_workers: 16 # bf16: true save_steps: 500 logging_steps: 10 From a29a5b75e481194c2913f46c24f5c8f76cfa1ae4 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 11:19:31 +0100 Subject: [PATCH 65/86] fff --- colpali_engine/trainer/colmodel_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index db44ef217..a674d532d 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -111,7 +111,7 @@ def __init__(self, config: ColModelTrainingConfig) -> None: ) self.dataset = self.dataset.map(lambda x: self.preprocess_example(x, self.config.processor), - num_proc=self.config.tr_args.dataloader_num_workers) + num_proc=1) # self.config.tr_args.dataloader_num_workers) breakpoint() self.current_git_hash = os.popen("git rev-parse HEAD").read().strip() self.retrieval_evaluator = CustomRetrievalEvaluator() From 784873ce5b4891e877046ebe9fa33ad8a547d279 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 11:40:48 +0100 Subject: [PATCH 66/86] debug --- colpali_engine/trainer/colmodel_training.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index a674d532d..b85fc57c2 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -110,15 +110,14 @@ def __init__(self, config: ColModelTrainingConfig) -> None: max_length=self.config.max_length, ) - self.dataset = self.dataset.map(lambda x: self.preprocess_example(x, self.config.processor), - num_proc=1) # self.config.tr_args.dataloader_num_workers) - breakpoint() self.current_git_hash = os.popen("git rev-parse HEAD").read().strip() self.retrieval_evaluator = CustomRetrievalEvaluator() @staticmethod def preprocess_example(example: Dict, processor): + breakpoint() processed = processor.process_images([example["image"]]) + breakpoint() for key in processed: example[key] = processed[key].squeeze(0) if "neg_image" in example and example["neg_image"] is not None: @@ -134,6 +133,11 @@ def train(self) -> None: else: print("Training with in-batch negatives") + + self.dataset = self.dataset.map(lambda x: self.preprocess_example(x, self.config.processor), + num_proc=1) # self.config.tr_args.dataloader_num_workers) + + trainer = ContrastiveTrainer( model=self.model, train_dataset=self.dataset["train"], From e39c0b8137e47b5340d621d6ab64e6822b6afd76 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 11:46:42 +0100 Subject: [PATCH 67/86] fff --- colpali_engine/trainer/colmodel_training.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index b85fc57c2..7ec35e6ab 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -115,17 +115,16 @@ def __init__(self, config: ColModelTrainingConfig) -> None: @staticmethod def preprocess_example(example: Dict, processor): - breakpoint() processed = processor.process_images([example["image"]]) - breakpoint() + new_example = {} for key in processed: - example[key] = processed[key].squeeze(0) + new_example[key] = processed[key].squeeze(0) if "neg_image" in example and example["neg_image"] 
is not None: neg_processed = processor.process_images([example["neg_image"]]) for key in neg_processed: - example[f"neg_{key}"] = neg_processed[key].squeeze(0) + new_example[f"neg_{key}"] = neg_processed[key].squeeze(0) - return example + return new_example def train(self) -> None: if isinstance(self.collator, CorpusQueryCollator) and self.collator.mined_negatives: From 427a0c8108d4b2ccbe534cae3f521dfc8bcfa5c4 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 11:49:40 +0100 Subject: [PATCH 68/86] ff --- colpali_engine/trainer/colmodel_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index 7ec35e6ab..05c13a320 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -134,7 +134,7 @@ def train(self) -> None: self.dataset = self.dataset.map(lambda x: self.preprocess_example(x, self.config.processor), - num_proc=1) # self.config.tr_args.dataloader_num_workers) + num_proc=4) # self.config.tr_args.dataloader_num_workers) trainer = ContrastiveTrainer( From 023c33e9b12fcff0bc6fd238c086a6d68715ce69 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 11:52:29 +0100 Subject: [PATCH 69/86] fff --- colpali_engine/models/qwen2/colqwen2/processing_colqwen2.py | 1 + colpali_engine/trainer/colmodel_training.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/colpali_engine/models/qwen2/colqwen2/processing_colqwen2.py b/colpali_engine/models/qwen2/colqwen2/processing_colqwen2.py index db9c7455d..7e21e7b70 100644 --- a/colpali_engine/models/qwen2/colqwen2/processing_colqwen2.py +++ b/colpali_engine/models/qwen2/colqwen2/processing_colqwen2.py @@ -124,6 +124,7 @@ def process_images( # separate pixel_values for each image pixel_values = torch.split(batch_doc["pixel_values"], offsets.tolist()) + breakpoint() # pad pixel_values to the same length to be able to make it into a tensor max_length = max([len(pv) for pv in pixel_values]) diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index 05c13a320..7ec35e6ab 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -134,7 +134,7 @@ def train(self) -> None: self.dataset = self.dataset.map(lambda x: self.preprocess_example(x, self.config.processor), - num_proc=4) # self.config.tr_args.dataloader_num_workers) + num_proc=1) # self.config.tr_args.dataloader_num_workers) trainer = ContrastiveTrainer( From 23a9d21c5d606a97c5a57934d29687067f0fda81 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 11:57:22 +0100 Subject: [PATCH 70/86] fff --- colpali_engine/collators/visual_retriever_collator.py | 2 +- colpali_engine/models/qwen2/colqwen2/processing_colqwen2.py | 1 - colpali_engine/trainer/colmodel_training.py | 4 ++-- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 4 ++-- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/colpali_engine/collators/visual_retriever_collator.py b/colpali_engine/collators/visual_retriever_collator.py index d5b5d1ac1..49e08db20 100644 --- a/colpali_engine/collators/visual_retriever_collator.py +++ b/colpali_engine/collators/visual_retriever_collator.py @@ -105,7 +105,7 @@ def online_processing( return batch_all - def offline_procesing( + def offline_processing( self, examples: List[Dict[str, Any]], ) -> Dict[str, Any]: diff --git a/colpali_engine/models/qwen2/colqwen2/processing_colqwen2.py 
b/colpali_engine/models/qwen2/colqwen2/processing_colqwen2.py index 7e21e7b70..db9c7455d 100644 --- a/colpali_engine/models/qwen2/colqwen2/processing_colqwen2.py +++ b/colpali_engine/models/qwen2/colqwen2/processing_colqwen2.py @@ -124,7 +124,6 @@ def process_images( # separate pixel_values for each image pixel_values = torch.split(batch_doc["pixel_values"], offsets.tolist()) - breakpoint() # pad pixel_values to the same length to be able to make it into a tensor max_length = max([len(pv) for pv in pixel_values]) diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index 7ec35e6ab..e68dd0c60 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -133,8 +133,8 @@ def train(self) -> None: print("Training with in-batch negatives") - self.dataset = self.dataset.map(lambda x: self.preprocess_example(x, self.config.processor), - num_proc=1) # self.config.tr_args.dataloader_num_workers) + # self.dataset = self.dataset.map(lambda x: self.preprocess_example(x, self.config.processor), + # num_proc=self.config.tr_args.dataloader_num_workers) trainer = ContrastiveTrainer( diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index f418baf1c..83ac17f80 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -36,7 +36,7 @@ config: output_dir: null overwrite_output_dir: true num_train_epochs: 1 - per_device_train_batch_size: 256 + per_device_train_batch_size: 2048 gradient_checkpointing: true gradient_checkpointing_kwargs: { "use_reentrant": false } # gradient_checkpointing: true @@ -44,7 +44,7 @@ config: # gradient_accumulation_steps: 4 per_device_eval_batch_size: 16 eval_strategy: "steps" - dataloader_num_workers: 16 + dataloader_num_workers: 1 # bf16: true save_steps: 500 logging_steps: 10 From 3d2266e031f5115a32fb9803c122f5a8f140049c Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 11:57:53 +0100 Subject: [PATCH 71/86] ff --- colpali_engine/collators/visual_retriever_collator.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/colpali_engine/collators/visual_retriever_collator.py b/colpali_engine/collators/visual_retriever_collator.py index 49e08db20..5c5fc4827 100644 --- a/colpali_engine/collators/visual_retriever_collator.py +++ b/colpali_engine/collators/visual_retriever_collator.py @@ -123,8 +123,6 @@ def offline_processing( neg_input_ids = [] neg_attention_mask = [] - breakpoint() - for example in examples: texts_query.append(example["query"]) pixel_values.append(example["pixel_values"]) From 1dc3499fdb273fc09070a05e2611664887d4ef43 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 12:13:51 +0100 Subject: [PATCH 72/86] 1024 --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index 83ac17f80..c79cc74e2 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -36,7 +36,7 @@ config: output_dir: null overwrite_output_dir: true num_train_epochs: 1 - per_device_train_batch_size: 2048 + per_device_train_batch_size: 1024 gradient_checkpointing: true gradient_checkpointing_kwargs: { "use_reentrant": false } # gradient_checkpointing: true
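
Patches 63-72 flip between collating images online and precomputing processor features offline with `datasets.map`. A minimal sketch of the offline route, mirroring the `preprocess_example` pattern above; the function name and default arguments are illustrative, and it assumes `processor.process_images` returns a dict-like of tensors with a singleton batch dimension, as in those patches:

from datasets import Dataset


def precompute_features(dataset: Dataset, processor, num_proc: int = 4) -> Dataset:
    def _encode(example):
        # Squeeze the singleton batch dim so each dataset row stores one example's tensors.
        out = {k: v.squeeze(0) for k, v in processor.process_images([example["image"]]).items()}
        if example.get("neg_image") is not None:
            neg = processor.process_images([example["neg_image"]])
            out.update({f"neg_{k}": v.squeeze(0) for k, v in neg.items()})
        return out

    # A small writer_batch_size bounds the Arrow writer's memory when rows carry
    # large pixel_values tensors; num_proc parallelises the CPU-side processing.
    return dataset.map(_encode, num_proc=num_proc, writer_batch_size=32)

The trade-off the series wrestles with: mapping moves the image-processor cost out of the training loop but writes large tensors to disk and fixes the resize at preprocessing time, which is why the later patches (74-81) try mapping only the resize step before reverting to the processor-side `smart_resize`.
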
From 3a13aea9acd8f34952d61b8efc8846921ff05436 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 12:14:57 +0100 Subject: [PATCH 73/86] test --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index c79cc74e2..23c7cfad0 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -30,7 +30,7 @@ config: run_eval: true loss_func: (): colpali_engine.loss.gradcache_late_interaction_losses.GradCacheColbertPairwiseCELoss # GradCacheColbertLoss # - mini_batch_size: 32 + mini_batch_size: 128 tr_args: (): transformers.training_args.TrainingArguments output_dir: null @@ -44,7 +44,7 @@ config: # gradient_accumulation_steps: 4 per_device_eval_batch_size: 16 eval_strategy: "steps" - dataloader_num_workers: 1 + dataloader_num_workers: 2 # bf16: true save_steps: 500 logging_steps: 10 From df5a89a84fd4912299544a2bcb997c0b532e5cfe Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 12:54:42 +0100 Subject: [PATCH 74/86] smart resize --- colpali_engine/trainer/colmodel_training.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index e68dd0c60..e2b1ab41d 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -136,6 +136,8 @@ def train(self) -> None: # self.dataset = self.dataset.map(lambda x: self.preprocess_example(x, self.config.processor), # num_proc=self.config.tr_args.dataloader_num_workers) + self.dataset = self.dataset.map(lambda x: {"image": self.config.processor.smart_resize(x["image"])}, num_proc=1) + trainer = ContrastiveTrainer( model=self.model, From 0fd33c3b81d8e26c64d2dd764fa99dff14373d4d Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 13:01:16 +0100 Subject: [PATCH 75/86] writer batch size --- colpali_engine/trainer/colmodel_training.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index e2b1ab41d..6b0733a7f 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -136,7 +136,9 @@ def train(self) -> None: # self.dataset = self.dataset.map(lambda x: self.preprocess_example(x, self.config.processor), # num_proc=self.config.tr_args.dataloader_num_workers) - self.dataset = self.dataset.map(lambda x: {"image": self.config.processor.smart_resize(x["image"])}, num_proc=1) + self.dataset = self.dataset.map(lambda x: {"image": self.config.processor.smart_resize(x["image"])}, + num_proc=self.config.tr_args.dataloader_num_workers, + writer_batch_size=32) trainer = ContrastiveTrainer( From e29b9bd5cc1987be534528a2e0f177587a78b193 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 13:03:43 +0100 Subject: [PATCH 76/86] fff --- colpali_engine/trainer/colmodel_training.py | 2 +- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index 6b0733a7f..4ced97320 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -134,7 +134,7 @@ def train(self) -> None: #
self.dataset = self.dataset.map(lambda x: self.preprocess_example(x, self.config.processor), - # num_proc=self.config.tr_args.dataloader_num_workers) + # num_proc=self.config.tr_args.dataloader_num_workers, writer_batch_size=32) self.dataset = self.dataset.map(lambda x: {"image": self.config.processor.smart_resize(x["image"])}, num_proc=self.config.tr_args.dataloader_num_workers, diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index 23c7cfad0..dd4dececb 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -30,7 +30,7 @@ config: run_eval: true loss_func: (): colpali_engine.loss.gradcache_late_interaction_losses.GradCacheColbertPairwiseCELoss # GradCacheColbertLoss # - mini_batch_size: 128 + mini_batch_size: 64 tr_args: (): transformers.training_args.TrainingArguments output_dir: null @@ -44,7 +44,7 @@ config: # gradient_accumulation_steps: 4 per_device_eval_batch_size: 16 eval_strategy: "steps" - dataloader_num_workers: 2 + dataloader_num_workers: 16 # bf16: true save_steps: 500 logging_steps: 10 From ff4888172cc312678a006704e22ad6f6ab4999d0 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 13:28:15 +0100 Subject: [PATCH 77/86] workers --- colpali_engine/trainer/colmodel_training.py | 2 +- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index 4ced97320..2293eb7d1 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -137,7 +137,7 @@ def train(self) -> None: # num_proc=self.config.tr_args.dataloader_num_workers, writer_batch_size=32) self.dataset = self.dataset.map(lambda x: {"image": self.config.processor.smart_resize(x["image"])}, - num_proc=self.config.tr_args.dataloader_num_workers, + num_proc=self.config.tr_args.dataloader_num_workers*16, writer_batch_size=32) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index dd4dececb..55dfcac4a 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -44,7 +44,7 @@ config: # gradient_accumulation_steps: 4 per_device_eval_batch_size: 16 eval_strategy: "steps" - dataloader_num_workers: 16 + dataloader_num_workers: 4 # bf16: true save_steps: 500 logging_steps: 10 From 62388622b5aa21a5e93bd92b1a626899b0f021af Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 13:44:39 +0100 Subject: [PATCH 78/86] dd --- colpali_engine/trainer/colmodel_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index 2293eb7d1..85c33f0a8 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -137,7 +137,7 @@ def train(self) -> None: # num_proc=self.config.tr_args.dataloader_num_workers, writer_batch_size=32) self.dataset = self.dataset.map(lambda x: {"image": self.config.processor.smart_resize(x["image"])}, - num_proc=self.config.tr_args.dataloader_num_workers*16, + num_proc=self.config.tr_args.dataloader_num_workers*8, writer_batch_size=32) From 1adc6b85d3eb363a0b1b7161bc240192efa6f134 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: 
Fri, 21 Feb 2025 14:26:17 +0100 Subject: [PATCH 79/86] test --- colpali_engine/models/qwen2/colqwen2/processing_colqwen2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colpali_engine/models/qwen2/colqwen2/processing_colqwen2.py b/colpali_engine/models/qwen2/colqwen2/processing_colqwen2.py index db9c7455d..8e54b2e5a 100644 --- a/colpali_engine/models/qwen2/colqwen2/processing_colqwen2.py +++ b/colpali_engine/models/qwen2/colqwen2/processing_colqwen2.py @@ -109,11 +109,11 @@ def process_images( """ texts_doc = [self.visual_prompt_prefix] * len(images) - resized_images: List[Image.Image] = [self.smart_resize(image) for image in images] + # resized_images: List[Image.Image] = [self.smart_resize(image) for image in images] batch_doc = self( text=texts_doc, - images=resized_images, + images=images, padding="longest", return_tensors="pt", ) From 2db198d4dd83359e3b0a98cfdf0e9229b1c1cfa3 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 14:35:54 +0100 Subject: [PATCH 80/86] call --- colpali_engine/trainer/colmodel_training.py | 67 ++++++++++++++++++++- 1 file changed, 66 insertions(+), 1 deletion(-) diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index 85c33f0a8..5ffeba032 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -136,7 +136,72 @@ def train(self) -> None: # self.dataset = self.dataset.map(lambda x: self.preprocess_example(x, self.config.processor), # num_proc=self.config.tr_args.dataloader_num_workers, writer_batch_size=32) - self.dataset = self.dataset.map(lambda x: {"image": self.config.processor.smart_resize(x["image"])}, + from PIL import Image + import math + + def round_by_factor(number: float, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return round(number / factor) * factor + + def ceil_by_factor(number: float, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + def floor_by_factor(number: float, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + + def smart_resize_helper( + width: int, + height: int, + factor: int, + max_ratio: int, + min_pixels: int, + max_pixels: int, + ) -> Tuple[int, int]: + """ + Returns the image size so that the following conditions are met: + 1. Both dimensions (height and width) are divisible by 'factor'. + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + 3. The aspect ratio of the image is maintained as closely as possible. 
+ """ + + if max(height, width) / min(height, width) > max_ratio: + raise ValueError( + f"absolute aspect ratio must be smaller than {max_ratio}, " + f"got {max(height, width) / min(height, width)}" + ) + + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + + return h_bar, w_bar + + def smart_resize(image: Image.Image) -> Image.Image: + """ + Resize and convert the image to the required format. + """ + image_size = image.size + resized_height, resized_width = smart_resize_helper( + width=image_size[0], + height=image_size[1], + factor=28, + max_ratio=200, + min_pixels=self.config.processor.min_pixels, + max_pixels=self.config.processor.max_pixels, + ) + return image.convert("RGB").resize((resized_width, resized_height)) + + self.dataset = self.dataset.map(lambda x: {"image": smart_resize(x["image"])}, num_proc=self.config.tr_args.dataloader_num_workers*8, writer_batch_size=32) From c1770960480566c90f6f976fd16ee5caeb3d5be1 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 15:02:48 +0100 Subject: [PATCH 81/86] revert --- .../qwen2/colqwen2/processing_colqwen2.py | 4 +- colpali_engine/trainer/colmodel_training.py | 70 ------------------- 2 files changed, 2 insertions(+), 72 deletions(-) diff --git a/colpali_engine/models/qwen2/colqwen2/processing_colqwen2.py b/colpali_engine/models/qwen2/colqwen2/processing_colqwen2.py index 8e54b2e5a..db9c7455d 100644 --- a/colpali_engine/models/qwen2/colqwen2/processing_colqwen2.py +++ b/colpali_engine/models/qwen2/colqwen2/processing_colqwen2.py @@ -109,11 +109,11 @@ def process_images( """ texts_doc = [self.visual_prompt_prefix] * len(images) - # resized_images: List[Image.Image] = [self.smart_resize(image) for image in images] + resized_images: List[Image.Image] = [self.smart_resize(image) for image in images] batch_doc = self( text=texts_doc, - images=images, + images=resized_images, padding="longest", return_tensors="pt", ) diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index 5ffeba032..f8312bdd6 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -136,76 +136,6 @@ def train(self) -> None: # self.dataset = self.dataset.map(lambda x: self.preprocess_example(x, self.config.processor), # num_proc=self.config.tr_args.dataloader_num_workers, writer_batch_size=32) - from PIL import Image - import math - - def round_by_factor(number: float, factor: int) -> int: - """Returns the closest integer to 'number' that is divisible by 'factor'.""" - return round(number / factor) * factor - - def ceil_by_factor(number: float, factor: int) -> int: - """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" - return math.ceil(number / factor) * factor - - def floor_by_factor(number: float, factor: int) -> int: - """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" - return math.floor(number / factor) * factor - - def smart_resize_helper( - width: int, - height: int, - factor: int, - max_ratio: int, - min_pixels: int, - max_pixels: int, - ) -> 
Tuple[int, int]: - """ - Returns the image size so that the following conditions are met: - 1. Both dimensions (height and width) are divisible by 'factor'. - 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. - 3. The aspect ratio of the image is maintained as closely as possible. - """ - - if max(height, width) / min(height, width) > max_ratio: - raise ValueError( - f"absolute aspect ratio must be smaller than {max_ratio}, " - f"got {max(height, width) / min(height, width)}" - ) - - h_bar = max(factor, round_by_factor(height, factor)) - w_bar = max(factor, round_by_factor(width, factor)) - - if h_bar * w_bar > max_pixels: - beta = math.sqrt((height * width) / max_pixels) - h_bar = floor_by_factor(height / beta, factor) - w_bar = floor_by_factor(width / beta, factor) - elif h_bar * w_bar < min_pixels: - beta = math.sqrt(min_pixels / (height * width)) - h_bar = ceil_by_factor(height * beta, factor) - w_bar = ceil_by_factor(width * beta, factor) - - return h_bar, w_bar - - def smart_resize(image: Image.Image) -> Image.Image: - """ - Resize and convert the image to the required format. - """ - image_size = image.size - resized_height, resized_width = smart_resize_helper( - width=image_size[0], - height=image_size[1], - factor=28, - max_ratio=200, - min_pixels=self.config.processor.min_pixels, - max_pixels=self.config.processor.max_pixels, - ) - return image.convert("RGB").resize((resized_width, resized_height)) - - self.dataset = self.dataset.map(lambda x: {"image": smart_resize(x["image"])}, - num_proc=self.config.tr_args.dataloader_num_workers*8, - writer_batch_size=32) - - trainer = ContrastiveTrainer( model=self.model, train_dataset=self.dataset["train"], From a4ecc36f3960aeade595e3cbcae63278d1105565 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 15:06:43 +0100 Subject: [PATCH 82/86] fff --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index 55dfcac4a..1e0b5c71f 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -1,6 +1,6 @@ config: (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig - output_dir: !path ../../../models/colqwen2-gradcache + output_dir: !path ../../../models/colqwen2-gradcache-2048-3e processor: (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper class_to_instanciate: !ext colpali_engine.models.ColQwen2Processor @@ -35,8 +35,8 @@ config: (): transformers.training_args.TrainingArguments output_dir: null overwrite_output_dir: true - num_train_epochs: 1 - per_device_train_batch_size: 1024 + num_train_epochs: 3 + per_device_train_batch_size: 512 gradient_checkpointing: true gradient_checkpointing_kwargs: { "use_reentrant": false } # gradient_checkpointing: true @@ -49,7 +49,7 @@ config: save_steps: 500 logging_steps: 10 eval_steps: 100 - warmup_steps: 100 + warmup_steps: 20 learning_rate: 5e-4 save_total_limit: 1 # resume_from_checkpoint: true From 7d9decfdd1d522de8c4e7a65cfe2bbf27fd7e081 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 15:22:58 +0100 Subject: [PATCH 83/86] gradcache --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml 
b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index 1e0b5c71f..1fbd2275c 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -30,7 +30,7 @@ config: run_eval: true loss_func: (): colpali_engine.loss.gradcache_late_interaction_losses.GradCacheColbertPairwiseCELoss # GradCacheColbertLoss # - mini_batch_size: 64 + mini_batch_size: 32 tr_args: (): transformers.training_args.TrainingArguments output_dir: null From 369fe1a15d0c516834590c04be88e9a11601838e Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 15:35:42 +0100 Subject: [PATCH 84/86] test --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index 1fbd2275c..5f9074a9b 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -44,7 +44,7 @@ config: # gradient_accumulation_steps: 4 per_device_eval_batch_size: 16 eval_strategy: "steps" - dataloader_num_workers: 4 + dataloader_num_workers: 2 # 4 # bf16: true save_steps: 500 logging_steps: 10 @@ -57,7 +57,7 @@ config: # wandb logging # wandb_project: "colqwen2" # run_name: "colqwen2-ba32-nolora" - dataloader_pin_memory: false # true # false + dataloader_pin_memory: true # true # false torch_compile: true report_to: "wandb" From 3748e8b1e32154a56a6e57e05f43a3afb7a46c23 Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 16:59:23 +0100 Subject: [PATCH 85/86] fff --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index 5f9074a9b..0011d2fd3 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -30,13 +30,13 @@ config: run_eval: true loss_func: (): colpali_engine.loss.gradcache_late_interaction_losses.GradCacheColbertPairwiseCELoss # GradCacheColbertLoss # - mini_batch_size: 32 + mini_batch_size: 64 tr_args: (): transformers.training_args.TrainingArguments output_dir: null overwrite_output_dir: true num_train_epochs: 3 - per_device_train_batch_size: 512 + per_device_train_batch_size: 256 gradient_checkpointing: true gradient_checkpointing_kwargs: { "use_reentrant": false } # gradient_checkpointing: true @@ -44,7 +44,7 @@ config: # gradient_accumulation_steps: 4 per_device_eval_batch_size: 16 eval_strategy: "steps" - dataloader_num_workers: 2 # 4 + dataloader_num_workers: 8 # 4 # bf16: true save_steps: 500 logging_steps: 10 From 487085deb8aa8f0590ff9e6182307eee058bc0bc Mon Sep 17 00:00:00 2001 From: ManuelFay Date: Fri, 21 Feb 2025 17:39:52 +0100 Subject: [PATCH 86/86] ff --- scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml index 0011d2fd3..70bd39bb6 100644 --- a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -30,7 +30,7 @@ config: run_eval: true loss_func: (): colpali_engine.loss.gradcache_late_interaction_losses.GradCacheColbertPairwiseCELoss # 
GradCacheColbertLoss # - mini_batch_size: 64 + mini_batch_size: 32 tr_args: (): transformers.training_args.TrainingArguments output_dir: null @@ -44,7 +44,7 @@ config: # gradient_accumulation_steps: 4 per_device_eval_batch_size: 16 eval_strategy: "steps" - dataloader_num_workers: 8 # 4 + dataloader_num_workers: 4 # 4 # bf16: true save_steps: 500 logging_steps: 10
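
The chunked scoring loop touched in patches 61-62 is the late-interaction (ColBERT-style) step: each query chunk is scored against all cached document embeddings, so activation memory scales with `mini_batch_size` while the contrastive batch stays at `per_device_train_batch_size`. A standalone sketch with illustrative shapes; the max-over-document-tokens then sum-over-query-tokens reduction is the standard MaxSim and is assumed here, since those lines fall outside the hunks shown:

import torch

b, c = 3, 5           # queries in the current chunk, candidate documents
n, s, d = 7, 11, 128  # query tokens, doc tokens, embedding dim

query_chunk = torch.randn(b, n, d)
embeddings_doc = torch.randn(c, s, d)

# (b, c, n, s): similarity of every query token against every doc token.
scores_chunk = torch.einsum("bnd,csd->bcns", query_chunk, embeddings_doc)

# MaxSim: keep each query token's best-matching doc token, then sum over
# query tokens, yielding one score per (query, document) pair.
late_interaction = scores_chunk.amax(dim=3).sum(dim=2)
assert late_interaction.shape == (b, c)

With in-batch negatives, this (b, c) score matrix is what the configured loss (`GradCacheColbertPairwiseCELoss`) consumes; how it reduces the matrix to a scalar loss is outside the hunks shown here.
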