diff --git a/colpali_engine/collators/corpus_query_collator.py b/colpali_engine/collators/corpus_query_collator.py
index 8a95e3ebe..74684b81c 100644
--- a/colpali_engine/collators/corpus_query_collator.py
+++ b/colpali_engine/collators/corpus_query_collator.py
@@ -15,16 +15,19 @@ def __init__(
         image_dataset: Optional["Dataset"] = None,  # noqa: F821
         mined_negatives: bool = True,
         corpus_format: str = "wikiss",
+        process_images_before_training: bool = False,
     ):
         super().__init__(
             processor=processor,
             max_length=max_length,
+            process_images_before_training=process_images_before_training,
         )
         if image_dataset is None:
             raise ValueError("`image_dataset` must be provided")
         self.image_dataset = image_dataset
         self.mined_negatives = mined_negatives
         self.corpus_format = corpus_format
+        self.process_images_before_training = process_images_before_training
 
         if self.corpus_format == "wikiss":
             print("Mapping docids to indices")
diff --git a/colpali_engine/collators/visual_retriever_collator.py b/colpali_engine/collators/visual_retriever_collator.py
index 070115198..5c5fc4827 100644
--- a/colpali_engine/collators/visual_retriever_collator.py
+++ b/colpali_engine/collators/visual_retriever_collator.py
@@ -1,5 +1,6 @@
-from typing import Any, Dict, List, Union, cast
+from typing import Any, Dict, List, Union
 
+import torch
 from PIL.Image import Image
 
 from colpali_engine.models.idefics_2 import ColIdefics2Processor
@@ -16,10 +17,12 @@ def __init__(
         self,
         processor: BaseVisualRetrieverProcessor,
         max_length: int = 2048,
+        process_images_before_training: bool = False,
     ):
         self.processor = processor
         self.image_token_id = None
         self.max_length = max_length
+        self.process_images_before_training = process_images_before_training
 
         if isinstance(self.processor, ColPaliProcessor) or isinstance(self.processor, ColIdefics2Processor):
             self.image_token_id = self.processor.tokenizer.additional_special_tokens_ids[
@@ -35,6 +38,15 @@ def __init__(
     def __call__(
         self,
         examples: List[Dict[str, Any]],
+    ) -> Dict[str, Any]:
+        if self.process_images_before_training:
+            return self.offline_processing(examples)
+        return self.online_processing(examples)
+
+
+    def online_processing(
+        self,
+        examples: List[Dict[str, Any]],
     ) -> Dict[str, Any]:
         """
         Collate function for the vision retriever associated to the collator's processor.
@@ -43,20 +55,16 @@ def __call__(
         texts_query: Union[List[str], List[None], List[Union[str, None]]] = []  # some documents don't have a query
         images: List[Image] = []
         neg_images: List[Image] = []
 
-        if self.processor is None or not isinstance(self.processor, BaseVisualRetrieverProcessor):
-            raise ValueError("Processor should be provided for vision collator.")
 
         # Process each example
         for example in examples:
             texts_query.append(example["query"])
 
-            if example["image"] is None:
-                raise ValueError("Image is None - This collator does not support None images yet.")
-
-            images.append(cast(Image, example["image"]))
+            images.append(example["image"])
 
             if "neg_image" in example and example["neg_image"] is not None:
-                neg_images.append(cast(Image, example["neg_image"]))
+                neg_images.append(example["neg_image"])
 
         # Process the documents
         batch_doc = self.processor.process_images(
@@ -76,11 +84,8 @@ def __call__(
         if all([t is None for t in texts_query]):
             # print("All queries are `None`. Returning `None` for all queries.")
             pass
-        elif any([t is None for t in texts_query]):
-            # If it's the first query that is not None but the rest are None, then it's hard negatives.
-            raise ValueError("Some queries are None. 
This collator does not support None queries yet.") else: - texts_query = cast(List[str], texts_query) + texts_query: List[str] = texts_query batch_query = self.processor.process_queries( queries=texts_query, max_length=self.max_length, @@ -98,3 +103,82 @@ def __call__( batch_all.update(batch_neg_doc) return batch_all + + + def offline_processing( + self, + examples: List[Dict[str, Any]], + ) -> Dict[str, Any]: + """ + Collate function for the vision retriever associated to the collator's processor. + """ + # Placeholders + texts_query = [] + pixel_values = [] + image_grid_thw = [] + input_ids = [] + attention_mask = [] + neg_pixel_values = [] + neg_image_grid_thw = [] + neg_input_ids = [] + neg_attention_mask = [] + + for example in examples: + texts_query.append(example["query"]) + pixel_values.append(example["pixel_values"]) + image_grid_thw.append(example["image_grid_thw"]) + input_ids.append(example["input_ids"]) + attention_mask.append(example["attention_mask"]) + + if "neg_pixel_values" in example: + neg_pixel_values.append(example["neg_pixel_values"]) + neg_image_grid_thw.append(example["neg_image_grid_thw"]) + neg_input_ids.append(example["neg_input_ids"]) + neg_attention_mask.append(example["neg_attention_mask"]) + + # Pad pixel values + pixel_values = torch.nn.utils.rnn.pad_sequence(pixel_values, batch_first=True, padding_value=0) + image_grid_thw = torch.stack(image_grid_thw) + + # Pad input sequences + batch_doc = self.processor.tokenizer.pad( + {"input_ids": input_ids, "attention_mask": attention_mask}, + padding=True, + return_tensors="pt" + ) + + batch_all = { + "doc_pixel_values": pixel_values, + "doc_image_grid_thw": image_grid_thw, + "doc_input_ids": batch_doc["input_ids"], + "doc_attention_mask": batch_doc["attention_mask"], + } + + # Process queries + if any(texts_query): # Ensure there are valid queries + batch_query = self.processor.process_queries( + queries=texts_query, + max_length=self.max_length + ) + batch_all["query_input_ids"] = batch_query["input_ids"] + batch_all["query_attention_mask"] = batch_query["attention_mask"] + + # Process negatives if present + if neg_pixel_values: + neg_pixel_values = torch.nn.utils.rnn.pad_sequence(neg_pixel_values, batch_first=True, padding_value=0) + neg_image_grid_thw = torch.stack(neg_image_grid_thw) + + batch_neg_doc = self.processor.tokenizer.pad( + {"input_ids": neg_input_ids, "attention_mask": neg_attention_mask}, + padding=True, + return_tensors="pt" + ) + + batch_all.update({ + "neg_doc_pixel_values": neg_pixel_values, + "neg_doc_image_grid_thw": neg_image_grid_thw, + "neg_doc_input_ids": batch_neg_doc["input_ids"], + "neg_doc_attention_mask": batch_neg_doc["attention_mask"], + }) + + return batch_all diff --git a/colpali_engine/loss/__init__.py b/colpali_engine/loss/__init__.py index 1e08318bb..4dc461473 100644 --- a/colpali_engine/loss/__init__.py +++ b/colpali_engine/loss/__init__.py @@ -3,6 +3,11 @@ BiPairwiseCELoss, BiPairwiseNegativeCELoss, ) +from .gradcache_late_interaction_losses import ( + GradCacheColbertLoss, + GradCacheColbertPairwiseCELoss, + GradCacheColbertPairwiseNegativeCELoss, +) from .late_interaction_losses import ( ColbertLoss, ColbertPairwiseCELoss, diff --git a/colpali_engine/loss/gradcache_late_interaction_losses.py b/colpali_engine/loss/gradcache_late_interaction_losses.py new file mode 100644 index 000000000..05f2a3878 --- /dev/null +++ b/colpali_engine/loss/gradcache_late_interaction_losses.py @@ -0,0 +1,411 @@ +from functools import partial + +import torch +import torch.nn as nn +import 
torch.nn.functional as F  # noqa: N812
+import tqdm
+from torch.utils.checkpoint import get_device_states, set_device_states
+
+
+class RandContext:
+    """
+    Random-state context manager that captures both CPU and GPU random states.
+    This ensures that when re-executing a forward pass (e.g. in GradCache's second pass),
+    stochastic operations produce identical outputs.
+    """
+
+    def __init__(self, *tensors) -> None:
+        # Capture the CPU RNG state.
+        self.fwd_cpu_state = torch.get_rng_state()
+        # Capture GPU states for all devices associated with the provided tensors.
+        self.fwd_gpu_devices, self.fwd_gpu_states = get_device_states(*tensors)
+
+    def __enter__(self) -> None:
+        # Fork the RNG states on the captured devices.
+        self._fork = torch.random.fork_rng(devices=self.fwd_gpu_devices, enabled=True)
+        self._fork.__enter__()
+        # Reset the CPU RNG state.
+        torch.set_rng_state(self.fwd_cpu_state)
+        # Reset the GPU RNG states.
+        set_device_states(self.fwd_gpu_devices, self.fwd_gpu_states)
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+        self._fork.__exit__(exc_type, exc_val, exc_tb)
+        self._fork = None
+
+
+def _backward_hook(grad_output, sentence_features, random_states, loss_obj, model):
+    """
+    Backward hook that re-computes the embeddings in mini-batches with gradients enabled
+    and uses the cached gradients to backpropagate. Each mini-batch forward pass is wrapped
+    in its RandContext to reproduce the same randomness as the first pass.
+    """
+    mini_batch_size = loss_obj.mini_batch_size
+    # sentence_features: a list of feature dicts, e.g. [query_features, doc_features].
+    # random_states: a list with one list of RandContext objects per branch.
+    assert loss_obj.cache is not None
+    assert random_states is not None
+    with torch.enable_grad():
+        for branch_feature, branch_cache, branch_random_states in zip(sentence_features, loss_obj.cache, random_states):
+            input_ids = branch_feature["input_ids"]
+            bsz = input_ids.size(0)
+            # Iterate over mini-batches.
+            for idx, start in enumerate(range(0, bsz, mini_batch_size)):
+                end = start + mini_batch_size
+                mini_feature = {k: v[start:end] for k, v in branch_feature.items()}
+                # Use the stored RandContext if available.
+                r_state = branch_random_states[idx]
+                if r_state is not None:
+                    with r_state:
+                        mini_embeds = model.forward(**mini_feature)
+                else:
+                    mini_embeds = model.forward(**mini_feature)
+                cached_grad = branch_cache[idx]
+                # Compute a surrogate loss that replays the cached gradient.
+                surrogate = torch.dot(mini_embeds.flatten(), cached_grad.flatten()) * grad_output
+                surrogate.backward()
+
+
+class GradCacheColbertLoss(nn.Module):
+    def __init__(self, mini_batch_size: int = 32, scale: float = 1.0, show_progress_bar: bool = False):
+        """
+        GradCache-enabled version of the ColBERT loss.
+
+        Args:
+            mini_batch_size: Number of items per mini-batch.
+            scale: Scaling factor for the similarity scores.
+            show_progress_bar: If True, shows progress bars during mini-batch processing.
+        """
+        super().__init__()
+        self.mini_batch_size = mini_batch_size
+        self.scale = scale
+        self.ce_loss = nn.CrossEntropyLoss()
+        self.cache = None
+        self.random_states = None
+        self.show_progress_bar = show_progress_bar
+        self.gradcache_enabled = True  # Flag indicating GradCache is active.
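+
+    # The iterator below implements GradCache's first pass: it embeds the batch in
+    # mini-batches without keeping the autograd graph, capturing per-chunk RNG state
+    # so the gradient pass in `_backward_hook` can replay identical randomness.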
+    def embed_minibatch_iter(self, model, sentence_feature: dict, with_grad: bool, copy_random_state: bool):
+        input_ids = sentence_feature["input_ids"]
+        bsz = input_ids.size(0)
+        for start in tqdm.trange(0, bsz, self.mini_batch_size, desc="Embedding minibatches",
+                                 disable=not self.show_progress_bar):
+            end = start + self.mini_batch_size
+            mini_feature = {k: v[start:end] for k, v in sentence_feature.items()}
+            random_state = None
+            if copy_random_state:
+                random_state = RandContext(*mini_feature.values())
+            grad_context = torch.enable_grad() if with_grad else torch.no_grad()
+            with grad_context:
+                mini_embeds = model.forward(**mini_feature)
+            # Detach to drop the forward graph, but re-enable grad on the leaf so the
+            # loss backward can deposit a cached gradient on this chunk.
+            mini_embeds = mini_embeds.detach().requires_grad_(True)
+            yield mini_embeds, random_state
+
+    def calculate_loss(self, reps: list[list[torch.Tensor]], with_backward: bool = False) -> torch.Tensor:
+        """
+        Calculate the ColBERT-style loss.
+        reps: list with two elements, reps[0] for query embeddings and reps[1] for doc embeddings.
+        Each element is a list of mini-batch tensors.
+        """
+        embeddings_query = torch.cat(reps[0], dim=0)  # shape: (total_query, seq_len, dim)
+        embeddings_doc = torch.cat(reps[1], dim=0)  # shape: (total_doc, seq_len, dim)
+        scores = torch.einsum("bnd,csd->bcns", embeddings_query, embeddings_doc).max(dim=3)[0].sum(dim=2)
+        batch_size = scores.size(0)
+        labels = torch.arange(batch_size, device=scores.device)
+        loss = self.ce_loss(scores * self.scale, labels)
+        if with_backward:
+            loss.backward()
+        return loss
+
+    def calculate_loss_and_cache_gradients(self, reps: list[list[torch.Tensor]]) -> torch.Tensor:
+        loss = self.calculate_loss(reps, with_backward=True)
+        loss = loss.detach().requires_grad_()
+        # Cache the gradient of each mini-batch chunk.
+        self.cache = []
+        for branch in reps:
+            branch_cache = []
+            for r in branch:
+                branch_cache.append(r.grad)
+            self.cache.append(branch_cache)
+        return loss
+
+    def forward(self, model, inputs: dict) -> torch.Tensor:
+        """
+        inputs: dict containing keys with prefixes "query_" and "doc_".
+        """
+        # Remove prefixes.
+        query_features = {k.replace("query_", ""): v for k, v in inputs.items() if k.startswith("query_")}
+        doc_features = {k.replace("doc_", ""): v for k, v in inputs.items() if k.startswith("doc_")}
+
+        # Step (1): Get embeddings without gradients, capturing the RandContext per mini-batch.
+        reps_query = []
+        rs_query = []
+        for mini_embeds, rs in self.embed_minibatch_iter(model, query_features, with_grad=False,
+                                                         copy_random_state=True):
+            reps_query.append(mini_embeds)
+            rs_query.append(rs)
+        reps_doc = []
+        rs_doc = []
+        for mini_embeds, rs in self.embed_minibatch_iter(model, doc_features, with_grad=False, copy_random_state=True):
+            reps_doc.append(mini_embeds)
+            rs_doc.append(rs)
+        reps = [reps_query, reps_doc]
+        self.random_states = [rs_query, rs_doc]
+
+        if torch.is_grad_enabled():
+            # Step (2): Compute the loss and cache gradients.
+            loss = self.calculate_loss_and_cache_gradients(reps)
+            # Step (3): Re-run embeddings with gradients enabled and register a hook that uses the cached gradients.
+            loss.register_hook(partial(_backward_hook, sentence_features=[query_features, doc_features],
+                                       random_states=self.random_states, loss_obj=self, model=model))
+        else:
+            loss = self.calculate_loss(reps, with_backward=False)
+        return loss
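+
+
+# Worked trace of the three GradCache steps above (illustrative numbers): with a batch
+# of 256 and mini_batch_size=32, step (1) runs 8 no-grad forward chunks per branch;
+# step (2) computes one loss over the concatenated 256x256 score matrix, and its
+# backward fills `.grad` on each detached chunk; step (3), triggered by the trainer's
+# loss.backward(), re-runs the 8 chunks with gradients enabled and backpropagates the
+# cached gradients through dot-product surrogate losses.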
+
+
+class GradCacheColbertPairwiseCELoss(nn.Module):
+    def __init__(self, mini_batch_size: int = 32, scale: float = 1.0, show_progress_bar: bool = False):
+        """
+        GradCache-enabled version of the ColBERTPairwiseCELoss.
+        """
+        super().__init__()
+        self.mini_batch_size = mini_batch_size
+        self.scale = scale
+        self.ce_loss = nn.CrossEntropyLoss()
+        self.cache = None
+        self.random_states = None
+        self.show_progress_bar = show_progress_bar
+        self.gradcache_enabled = True
+
+    def embed_minibatch_iter(self, model, sentence_feature: dict, with_grad: bool, copy_random_state: bool):
+        input_ids = sentence_feature["input_ids"]
+        bsz = input_ids.size(0)
+        for start in tqdm.trange(0, bsz, self.mini_batch_size, desc="Embedding minibatches",
+                                 disable=not self.show_progress_bar):
+            end = start + self.mini_batch_size
+            mini_feature = {k: v[start:end] for k, v in sentence_feature.items()}
+            random_state = RandContext(*mini_feature.values()) if copy_random_state else None
+            grad_context = torch.enable_grad() if with_grad else torch.no_grad()
+            with grad_context:
+                mini_embeds = model.forward(**mini_feature)
+            mini_embeds = mini_embeds.detach().requires_grad_(True)
+            yield mini_embeds, random_state
+
+    def calculate_loss(self, reps: list[list[torch.Tensor]], with_backward: bool = False) -> torch.Tensor:
+        """
+        Compute the ColBERTPairwiseCELoss from cached embeddings without concatenating the
+        query embeddings. reps[0] contains query embedding chunks (each of shape
+        (chunk_size, num_query_tokens, dim)), while reps[1] contains doc embeddings,
+        which we concatenate.
+
+        For each query chunk, we:
+          - Compute scores against all docs using an einsum.
+          - Reduce the scores by taking a max over doc tokens and summing over query tokens.
+          - Extract the positive score for each query from its overall index (query i matches doc i).
+          - Mask out the positive score and take the max over the negatives.
+          - Compute the softplus loss over the difference (neg_score - pos_score).
+
+        The overall loss is the average over all queries, and remains differentiable.
+        """
+        # Concatenate document embeddings (shape: (total_docs, num_doc_tokens, dim)).
+        embeddings_doc = torch.cat(reps[1], dim=0)
+
+        total_loss = 0.0
+        total_queries = 0
+        global_index = 0  # Tracks the overall index for positive pairing.
+
+        # Loop over query chunks.
+        for query_chunk in reps[0]:
+            chunk_size = query_chunk.size(0)
+            # Compute pairwise scores:
+            # resulting shape: (chunk_size, total_docs, num_query_tokens, num_doc_tokens).
+            scores_chunk = torch.einsum("bnd,csd->bcns", query_chunk, embeddings_doc)
+            # Reduce: max over document tokens, then sum over query tokens -> shape: (chunk_size, total_docs).
+            scores_chunk = scores_chunk.max(dim=3)[0].sum(dim=2)
+
+            # For each query in the chunk, the positive doc index is global_index + local_index.
+            row_idx = torch.arange(chunk_size, device=scores_chunk.device)
+            pos_idx = torch.arange(global_index, global_index + chunk_size, device=scores_chunk.device)
+            pos_scores = scores_chunk[row_idx, pos_idx]
+
+            # Mask out the positive scores in place by setting them to a very low value,
+            # then take the max over the negatives.
+            scores_chunk[row_idx, pos_idx] = -1e6
+            neg_scores = scores_chunk.max(dim=1)[0]
+
+            # Compute the loss for this chunk (sum over the chunk's queries).
+            chunk_loss = F.softplus(neg_scores - pos_scores).sum()
+            total_loss += chunk_loss
+            total_queries += chunk_size
+            global_index += chunk_size
+
+        loss = total_loss / total_queries
+        if with_backward:
+            loss.backward()
+        return loss
+
+    def calculate_loss_and_cache_gradients(self, reps: list[list[torch.Tensor]]) -> torch.Tensor:
+        loss = self.calculate_loss(reps, with_backward=True)
+        loss = loss.detach().requires_grad_()
+        self.cache = []
+        for branch in reps:
+            branch_cache = [r.grad for r in branch]
+            self.cache.append(branch_cache)
+        return loss
+
+    def forward(self, model, inputs: dict) -> torch.Tensor:
+        # Remove prefixes.
+        query_features = {k.replace("query_", ""): v for k, v in inputs.items() if k.startswith("query_")}
+        doc_features = {k.replace("doc_", ""): v for k, v in inputs.items() if k.startswith("doc_")}
+
+        # First pass: get embeddings without gradients (and capture RandContext).
+        reps_query, rs_query = [], []
+        for mini_embeds, rs in self.embed_minibatch_iter(model, query_features, with_grad=False,
+                                                         copy_random_state=True):
+            reps_query.append(mini_embeds)
+            rs_query.append(rs)
+        reps_doc, rs_doc = [], []
+        for mini_embeds, rs in self.embed_minibatch_iter(model, doc_features, with_grad=False, copy_random_state=True):
+            reps_doc.append(mini_embeds)
+            rs_doc.append(rs)
+        reps = [reps_query, reps_doc]
+        self.random_states = [rs_query, rs_doc]
+
+        if torch.is_grad_enabled():
+            loss = self.calculate_loss_and_cache_gradients(reps)
+            loss.register_hook(partial(_backward_hook,
+                                       sentence_features=[query_features, doc_features],
+                                       random_states=self.random_states,
+                                       loss_obj=self, model=model))
+        else:
+            loss = self.calculate_loss(reps, with_backward=False)
+        return loss
+
+
+class GradCacheColbertPairwiseNegativeCELoss(nn.Module):
+    def __init__(self, mini_batch_size: int = 32, in_batch_term: bool = False, show_progress_bar: bool = False):
+        """
+        GradCache-enabled version of the ColBERTPairwiseNegativeCELoss.
+
+        Args:
+            in_batch_term: If True, includes an additional in-batch loss term.
+ """ + super().__init__() + self.mini_batch_size = mini_batch_size + self.in_batch_term = in_batch_term + self.cache = None + self.random_states = None + self.show_progress_bar = show_progress_bar + self.gradcache_enabled = True + + def embed_minibatch_iter(self, model, sentence_feature: dict, with_grad: bool, copy_random_state: bool): + input_ids = sentence_feature["input_ids"] + bsz = input_ids.size(0) + for start in tqdm.trange(0, bsz, self.mini_batch_size, desc="Embedding minibatches", + disable=not self.show_progress_bar): + end = start + self.mini_batch_size + mini_feature = {k: v[start:end] for k, v in sentence_feature.items()} + random_state = RandContext(*mini_feature.values()) if copy_random_state else None + grad_context = torch.enable_grad() if with_grad else torch.no_grad() + with grad_context: + mini_embeds = model.forward(**mini_feature) + mini_embeds = mini_embeds.detach().requires_grad_(True) + yield mini_embeds, random_state + + def calculate_loss(self, reps: list[list[torch.Tensor]], with_backward: bool = False) -> torch.Tensor: + """ + Compute the ColBERTPairwiseNegativeCELoss. + reps is a list with three elements: + reps[0]: query embeddings, + reps[1]: positive doc embeddings, + reps[2]: negative doc embeddings. + """ + embeddings_query = torch.cat(reps[0], dim=0) # (batch, num_query_tokens, dim) + embeddings_doc = torch.cat(reps[1], dim=0) # (batch, num_doc_tokens, dim) + embeddings_neg_doc = torch.cat(reps[2], dim=0) # (batch, num_neg_doc_tokens, dim) + + # Compute scores for positive and negative documents. + pos_scores = torch.einsum("bnd,bsd->bns", embeddings_query, embeddings_doc) \ + .max(dim=2)[0].sum(dim=1) + neg_scores = torch.einsum("bnd,bsd->bns", embeddings_query, embeddings_neg_doc) \ + .max(dim=2)[0].sum(dim=1) + loss = F.softplus(neg_scores - pos_scores).mean() + + if self.in_batch_term: + scores = torch.einsum("bnd,csd->bcns", embeddings_query, embeddings_doc) \ + .max(dim=3)[0].sum(dim=2) + pos_scores_in = scores.diagonal() + neg_scores_in = scores - torch.eye(scores.shape[0], device=scores.device) * 1e6 + neg_scores_in = neg_scores_in.max(dim=1)[0] + loss_in = F.softplus(neg_scores_in - pos_scores_in).mean() + loss = (loss + loss_in) / 2 + + if with_backward: + loss.backward() + return loss + + def calculate_loss_and_cache_gradients(self, reps: list[list[torch.Tensor]]) -> torch.Tensor: + loss = self.calculate_loss(reps, with_backward=True) + loss = loss.detach().requires_grad_() + self.cache = [] + for branch in reps: + branch_cache = [r.grad for r in branch] + self.cache.append(branch_cache) + return loss + + def forward(self, model, inputs: dict) -> torch.Tensor: + # Remove prefixes. + query_features = {k.replace("query_", ""): v for k, v in inputs.items() if k.startswith("query_")} + doc_features = {k.replace("doc_", ""): v for k, v in inputs.items() if k.startswith("doc_")} + neg_doc_features = {k.replace("neg_doc_", ""): v for k, v in inputs.items() if k.startswith("neg_doc_")} + + # First pass: get embeddings without gradients and capture RandContext. 
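+        # Each chunk is detached as it is produced, so peak activation memory stays
+        # bounded by mini_batch_size rather than by the full effective batch size.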
+ reps_query, rs_query = [], [] + for mini_embeds, rs in self.embed_minibatch_iter(model, query_features, with_grad=False, + copy_random_state=True): + reps_query.append(mini_embeds) + rs_query.append(rs) + reps_doc, rs_doc = [], [] + for mini_embeds, rs in self.embed_minibatch_iter(model, doc_features, with_grad=False, copy_random_state=True): + reps_doc.append(mini_embeds) + rs_doc.append(rs) + reps_neg_doc, rs_neg_doc = [], [] + for mini_embeds, rs in self.embed_minibatch_iter(model, neg_doc_features, with_grad=False, + copy_random_state=True): + reps_neg_doc.append(mini_embeds) + rs_neg_doc.append(rs) + reps = [reps_query, reps_doc, reps_neg_doc] + self.random_states = [rs_query, rs_doc, rs_neg_doc] + + if torch.is_grad_enabled(): + loss = self.calculate_loss_and_cache_gradients(reps) + loss.register_hook(partial(_backward_hook, + sentence_features=[query_features, doc_features, neg_doc_features], + random_states=self.random_states, + loss_obj=self, model=model)) + else: + loss = self.calculate_loss(reps, with_backward=False) + return loss diff --git a/colpali_engine/models/qwen2/colqwen2/modeling_colqwen2.py b/colpali_engine/models/qwen2/colqwen2/modeling_colqwen2.py index 6b43d13b5..29abaa805 100644 --- a/colpali_engine/models/qwen2/colqwen2/modeling_colqwen2.py +++ b/colpali_engine/models/qwen2/colqwen2/modeling_colqwen2.py @@ -41,21 +41,22 @@ def from_pretrained( ) def inner_forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, ) -> torch.Tensor: + if inputs_embeds is None: inputs_embeds = self.model.embed_tokens(input_ids) if pixel_values is not None: @@ -90,11 +91,13 @@ def inner_forward( hidden_states = outputs[0] return hidden_states + + def forward(self, *args, **kwargs) -> torch.Tensor: # Delete output_hidden_states from kwargs kwargs.pop("output_hidden_states", None) - # The following code is a hack to make sure the scatter in DDP is done correctly when training on multiple GPUs + # Hack to make sure scatter in DDP is done correctly on multiple GPUs. 
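+        # image_grid_thw stores (t, h, w) per image, so h * w gives each image's number
+        # of flattened patches; the offsets below let pixel_values be re-grouped per
+        # example so that DDP can split the batch along dim 0.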
if "pixel_values" in kwargs: # compute pixel_values offsets offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2] @@ -109,15 +112,14 @@ def forward(self, *args, **kwargs) -> torch.Tensor: video_grid_thw=None, attention_mask=kwargs.get("attention_mask", None), ) - last_hidden_states = self.inner_forward( - *args, **kwargs, position_ids=position_ids, use_cache=False, output_hidden_states=True - ) # (batch_size, sequence_length, hidden_size) - - proj = self.custom_text_proj(last_hidden_states) # (batch_size, sequence_length, dim) - - # L2 normalization - proj = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim) - proj = proj * kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, dim) + last_hidden_states = self.inner_forward(*args, + **kwargs, + position_ids=position_ids, + use_cache=False, + output_hidden_states=True) + proj = self.custom_text_proj(last_hidden_states) + proj = proj / proj.norm(dim=-1, keepdim=True) # L2 normalization + proj = proj * kwargs["attention_mask"].unsqueeze(-1) return proj @property diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index ead037d65..f8312bdd6 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -109,15 +109,33 @@ def __init__(self, config: ColModelTrainingConfig) -> None: processor=self.config.processor, max_length=self.config.max_length, ) + self.current_git_hash = os.popen("git rev-parse HEAD").read().strip() self.retrieval_evaluator = CustomRetrievalEvaluator() + @staticmethod + def preprocess_example(example: Dict, processor): + processed = processor.process_images([example["image"]]) + new_example = {} + for key in processed: + new_example[key] = processed[key].squeeze(0) + if "neg_image" in example and example["neg_image"] is not None: + neg_processed = processor.process_images([example["neg_image"]]) + for key in neg_processed: + new_example[f"neg_{key}"] = neg_processed[key].squeeze(0) + + return new_example + def train(self) -> None: if isinstance(self.collator, CorpusQueryCollator) and self.collator.mined_negatives: print("Training with hard negatives") else: print("Training with in-batch negatives") + + # self.dataset = self.dataset.map(lambda x: self.preprocess_example(x, self.config.processor), + # num_proc=self.config.tr_args.dataloader_num_workers, writer_batch_size=32) + trainer = ContrastiveTrainer( model=self.model, train_dataset=self.dataset["train"], diff --git a/colpali_engine/trainer/contrastive_trainer.py b/colpali_engine/trainer/contrastive_trainer.py index abb479b2c..7712c27a6 100644 --- a/colpali_engine/trainer/contrastive_trainer.py +++ b/colpali_engine/trainer/contrastive_trainer.py @@ -9,16 +9,21 @@ def __init__(self, loss_func, is_vision_model, *args, **kwargs): self.is_vision_model = is_vision_model # Unused argument, will be removed in 0.4.0 def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): - query_outputs = model(input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"]) - # feed only kwargs with 'doc_' prefix - doc_outputs = model(**{k[4:]: v for k, v in inputs.items() if k.startswith("doc")}) - if "neg_doc_input_ids" in inputs: - neg_doc_outputs = model(**{k[8:]: v for k, v in inputs.items() if k.startswith("neg_doc")}) - loss = self.loss_func(query_outputs, doc_outputs, neg_doc_outputs) - return (loss, (query_outputs, doc_outputs, neg_doc_outputs)) if return_outputs else loss + # If 
the loss function supports gradcache, delegate the computation. + if hasattr(self.loss_func, "gradcache_enabled") and self.loss_func.gradcache_enabled: + loss = self.loss_func(model, inputs) + return (loss, None) if return_outputs else loss + else: + query_outputs = model(input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"]) + # feed only kwargs with 'doc_' prefix + doc_outputs = model(**{k[4:]: v for k, v in inputs.items() if k.startswith("doc")}) + if "neg_doc_input_ids" in inputs: + neg_doc_outputs = model(**{k[8:]: v for k, v in inputs.items() if k.startswith("neg_doc")}) + loss = self.loss_func(query_outputs, doc_outputs, neg_doc_outputs) + return (loss, (query_outputs, doc_outputs, neg_doc_outputs)) if return_outputs else loss - loss = self.loss_func(query_outputs, doc_outputs) - return (loss, (query_outputs, doc_outputs)) if return_outputs else loss + loss = self.loss_func(query_outputs, doc_outputs) + return (loss, (query_outputs, doc_outputs)) if return_outputs else loss def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=True): """This function is used to generate predictions and return the loss for the given inputs.""" @@ -26,13 +31,16 @@ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=True) raise ValueError("prediction_step is only called with prediction_loss_only=True") with torch.no_grad(): - # feed only kwargs with 'doc_' prefix - doc_outputs = model(**{k[4:]: v for k, v in inputs.items() if k.startswith("doc")}) - query_outputs = model(input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"]) - if "neg_doc_input_ids" in inputs: - neg_doc_outputs = model(**{k[8:]: v for k, v in inputs.items() if k.startswith("neg_doc")}) - loss = self.loss_func(query_outputs, doc_outputs, neg_doc_outputs) - return loss, None, None - - loss = self.loss_func(query_outputs, doc_outputs) + if hasattr(self.loss_func, "gradcache_enabled") and self.loss_func.gradcache_enabled: + loss = self.loss_func(model, inputs) + else: + # feed only kwargs with 'doc_' prefix + doc_outputs = model(**{k[4:]: v for k, v in inputs.items() if k.startswith("doc")}) + query_outputs = model(input_ids=inputs["query_input_ids"], + attention_mask=inputs["query_attention_mask"]) + if "neg_doc_input_ids" in inputs: + neg_doc_outputs = model(**{k[8:]: v for k, v in inputs.items() if k.startswith("neg_doc")}) + loss = self.loss_func(query_outputs, doc_outputs, neg_doc_outputs) + return loss, None, None + loss = self.loss_func(query_outputs, doc_outputs) return loss, None, None diff --git a/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml new file mode 100644 index 000000000..70bd39bb6 --- /dev/null +++ b/scripts/configs/qwen2/train_colqwen2_gradcache_model.yaml @@ -0,0 +1,75 @@ +config: + (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig + output_dir: !path ../../../models/colqwen2-gradcache-2048-3e + processor: + (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper + class_to_instanciate: !ext colpali_engine.models.ColQwen2Processor + pretrained_model_name_or_path: "./models/colqwen2_base" # "./models/paligemma-3b-mix-448" + # num_image_tokens: 2048 + # max_length: 50 + + model: + (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper + class_to_instanciate: !ext colpali_engine.models.ColQwen2 + pretrained_model_name_or_path: "./models/colqwen2_base" + torch_dtype: !ext torch.bfloat16 + use_cache: false + 
attn_implementation: "flash_attention_2" +# device_map: "auto" +# quantization_config: +# (): transformers.BitsAndBytesConfig +# load_in_4bit: true +# bnb_4bit_quant_type: "nf4" +# bnb_4bit_compute_dtype: "bfloat16" +# bnb_4bit_use_double_quant: true + + dataset_loading_func: !ext colpali_engine.utils.dataset_transformation.load_train_set + eval_dataset_loader: !import ../data/test_data.yaml + + # max_length: 50 + run_eval: true + loss_func: + (): colpali_engine.loss.gradcache_late_interaction_losses.GradCacheColbertPairwiseCELoss # GradCacheColbertLoss # + mini_batch_size: 32 + tr_args: + (): transformers.training_args.TrainingArguments + output_dir: null + overwrite_output_dir: true + num_train_epochs: 3 + per_device_train_batch_size: 256 + gradient_checkpointing: true + gradient_checkpointing_kwargs: { "use_reentrant": false } + # gradient_checkpointing: true + # 6 x 8 gpus = 48 batch size + # gradient_accumulation_steps: 4 + per_device_eval_batch_size: 16 + eval_strategy: "steps" + dataloader_num_workers: 4 # 4 + # bf16: true + save_steps: 500 + logging_steps: 10 + eval_steps: 100 + warmup_steps: 20 + learning_rate: 5e-4 + save_total_limit: 1 + # resume_from_checkpoint: true + # optim: "paged_adamw_8bit" + # wandb logging + # wandb_project: "colqwen2" + # run_name: "colqwen2-ba32-nolora" + dataloader_pin_memory: true # true # false + torch_compile: true + report_to: "wandb" + + + peft_config: + (): peft.LoraConfig + r: 32 + lora_alpha: 32 + lora_dropout: 0.1 + init_lora_weights: "gaussian" + bias: "none" + task_type: "FEATURE_EXTRACTION" + target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' + # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' + diff --git a/scripts/configs/qwen2/train_colqwen2_model.yaml b/scripts/configs/qwen2/train_colqwen2_model.yaml index 6ed38b7e2..9e18e239c 100644 --- a/scripts/configs/qwen2/train_colqwen2_model.yaml +++ b/scripts/configs/qwen2/train_colqwen2_model.yaml @@ -35,9 +35,10 @@ config: output_dir: null overwrite_output_dir: true num_train_epochs: 5 - per_device_train_batch_size: 64 + per_device_train_batch_size: 128 gradient_checkpointing: true gradient_checkpointing_kwargs: { "use_reentrant": false } + ddp_find_unused_parameters: false # gradient_checkpointing: true # 6 x 8 gpus = 48 batch size # gradient_accumulation_steps: 4 @@ -56,6 +57,8 @@ config: # wandb logging # wandb_project: "colqwen2" # run_name: "colqwen2-ba32-nolora" + dataloader_pin_memory: true + torch_compile: true report_to: "wandb"
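
Note: a minimal sketch of how the new losses are driven end to end after this change.
The trainer delegation is taken from the diff; the standalone loop below is illustrative
and assumes `batch` matches what VisualRetrieverCollator emits (prefixed "query_*" and
"doc_*" tensors):

    from colpali_engine.loss import GradCacheColbertPairwiseCELoss

    loss_fn = GradCacheColbertPairwiseCELoss(mini_batch_size=32)

    def training_step(model, batch, optimizer):
        # ContrastiveTrainer.compute_loss delegates exactly like this when
        # loss_func.gradcache_enabled is True.
        loss = loss_fn(model, batch)  # no-grad pass + loss + hook registration
        loss.backward()               # hook replays chunks with grad enabled
        optimizer.step()
        optimizer.zero_grad()
        return loss.detach()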