From 859c9e613f19b6bfcb661c722d6932197aedc944 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Sat, 24 Aug 2024 20:53:36 +0900 Subject: [PATCH 01/17] gc reranker --- src/tevatron/reranker/driver/train.py | 15 +++++---- src/tevatron/reranker/trainer.py | 46 +++++++++++++++++++++++++-- 2 files changed, 53 insertions(+), 8 deletions(-) diff --git a/src/tevatron/reranker/driver/train.py b/src/tevatron/reranker/driver/train.py index eaa17734..7618ab17 100644 --- a/src/tevatron/reranker/driver/train.py +++ b/src/tevatron/reranker/driver/train.py @@ -7,17 +7,18 @@ HfArgumentParser, set_seed, ) -from transformers import TrainingArguments - -from tevatron.reranker.arguments import ModelArguments, DataArguments +from tevatron.reranker.arguments import ModelArguments, DataArguments, \ + TevatronTrainingArguments as TrainingArguments from tevatron.reranker.modeling import RerankerModel from tevatron.reranker.dataset import RerankerTrainDataset -from tevatron.reranker.trainer import RerankerTrainer from tevatron.reranker.collator import RerankerTrainCollator +from tevatron.reranker.trainer import RerankerTrainer +from tevatron.reranker.gc_trainer import GradCacheTrainer as GCTrainer logger = logging.getLogger(__name__) + def main(): parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) @@ -65,6 +66,7 @@ def main(): if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.unk_token_id tokenizer.padding_side = 'right' + model = RerankerModel.build( model_args, training_args, @@ -74,7 +76,8 @@ def main(): train_dataset = RerankerTrainDataset(data_args) train_collator = RerankerTrainCollator(data_args, tokenizer) - trainer = RerankerTrainer( + trainer_cls = GCTrainer if training_args.grad_cache else RerankerTrainer + trainer = trainer_cls( model=model, args=training_args, train_dataset=train_dataset, @@ -89,4 +92,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/src/tevatron/reranker/trainer.py b/src/tevatron/reranker/trainer.py index f49e1baf..f7236d2f 100644 --- a/src/tevatron/reranker/trainer.py +++ b/src/tevatron/reranker/trainer.py @@ -2,19 +2,53 @@ from typing import Optional import torch +from torch import Tensor +from torch.nn import functional as F from transformers.trainer import Trainer from transformers.deepspeed import is_deepspeed_zero3_enabled from peft import get_peft_model_state_dict import logging + logger = logging.getLogger(__name__) +try: + from grad_cache import GradCache + + _grad_cache_available = True +except ModuleNotFoundError: + _grad_cache_available = False + + +def split_inputs(model_input: dict, chunk_size: int): + keys = list(model_input.keys()) + chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys] + return [dict(zip(keys, tt)) for tt in zip(*chunked_tensors)] + + +def get_rep(x): + return x.logits + class RerankerTrainer(Trainer): def __init__(self, *args, **kwargs): super(RerankerTrainer, self).__init__(*args, **kwargs) + if not _grad_cache_available: + raise ValueError( + 'Grad Cache package not available. You can obtain it from https://github.com/luyug/GradCache.') + + self.gc = GradCache( + models=[self.model], + chunk_sizes=[self.args.gc_chunk_size], + loss_fn=self.compute_loss, + split_input_fn=split_inputs, + get_rep_fn=get_rep, + fp16=self.args.fp16, + scaler=self.scaler if self.args.fp16 else None + ) + def _save(self, output_dir: Optional[str] = None, state_dict=None): output_dir = output_dir if output_dir is not None else self.args.output_dir os.makedirs(output_dir, exist_ok=True) @@ -35,6 +69,14 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): torch.save(lora_state_dict, os.path.join(output_dir, "adapter_model.bin")) print(f"Save adapter model at {output_dir}") + def compute_loss(self, model, inputs, return_outputs=False): + outputs = model(inputs) + loss = outputs.loss + return (loss, outputs) if return_outputs else loss - def compute_loss(self, model, inputs): - return model(inputs).loss + def training_step(self, model, inputs): + model.train() + _distributed = self.args.local_rank > -1 + self.gc.models = [model] + loss = self.gc(inputs, no_sync_except_last=_distributed) + return loss From f29930c9cb0c2e41602d824f58116ba8b81a4fa1 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Sat, 24 Aug 2024 20:57:48 +0900 Subject: [PATCH 02/17] reranker arguments --- examples/example_rankllama.md | 1 + src/tevatron/reranker/arguments.py | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/examples/example_rankllama.md b/examples/example_rankllama.md index 31275087..ae7e1b5f 100644 --- a/examples/example_rankllama.md +++ b/examples/example_rankllama.md @@ -18,4 +18,5 @@ deepspeed --include localhost:4,5,6,7 --master_port 60000 --module tevatron.rera --num_train_epochs 1 \ --logging_steps 10 \ --overwrite_output_dir + --gra ``` \ No newline at end of file diff --git a/src/tevatron/reranker/arguments.py b/src/tevatron/reranker/arguments.py index a2089322..d295cfb7 100644 --- a/src/tevatron/reranker/arguments.py +++ b/src/tevatron/reranker/arguments.py @@ -116,3 +116,11 @@ class DataArguments: "enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta)." }, ) + +@dataclass +class TevatronTrainingArguments(TrainingArguments): + warmup_ratio: float = field(default=0.1) + + grad_cache: bool = field(default=False, metadata={"help": "Use gradient cache update"}) + gc_q_chunk_size: int = field(default=4) + gc_p_chunk_size: int = field(default=32) \ No newline at end of file From 3d92da64c2b773c5dba6f999d4996717682ba0e1 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Sat, 24 Aug 2024 20:59:02 +0900 Subject: [PATCH 03/17] fix: training arguments --- src/tevatron/reranker/driver/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tevatron/reranker/driver/train.py b/src/tevatron/reranker/driver/train.py index 7618ab17..b1e09404 100644 --- a/src/tevatron/reranker/driver/train.py +++ b/src/tevatron/reranker/driver/train.py @@ -7,6 +7,7 @@ HfArgumentParser, set_seed, ) +from transformers import TrainingArguments from tevatron.reranker.arguments import ModelArguments, DataArguments, \ TevatronTrainingArguments as TrainingArguments From 0127d00470b40d9b2426ec00d27eb77f8d3a988a Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Sat, 24 Aug 2024 21:00:04 +0900 Subject: [PATCH 04/17] fix: arguments --- src/tevatron/reranker/arguments.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tevatron/reranker/arguments.py b/src/tevatron/reranker/arguments.py index d295cfb7..11480f88 100644 --- a/src/tevatron/reranker/arguments.py +++ b/src/tevatron/reranker/arguments.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from typing import Optional - +from transformers import TrainingArguments @dataclass class ModelArguments: From 52616cdd3a198cb1f7a175d04e1e54f4e791f9a5 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Sat, 24 Aug 2024 21:03:10 +0900 Subject: [PATCH 05/17] fix: trainer --- src/tevatron/reranker/driver/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tevatron/reranker/driver/train.py b/src/tevatron/reranker/driver/train.py index b1e09404..cf241549 100644 --- a/src/tevatron/reranker/driver/train.py +++ b/src/tevatron/reranker/driver/train.py @@ -77,7 +77,7 @@ def main(): train_dataset = RerankerTrainDataset(data_args) train_collator = RerankerTrainCollator(data_args, tokenizer) - trainer_cls = GCTrainer if training_args.grad_cache else RerankerTrainer + trainer_cls = RerankerTrainer trainer = trainer_cls( model=model, args=training_args, From cd09b576bc9bad344276950cd380c8c5b042685d Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Sat, 24 Aug 2024 21:03:47 +0900 Subject: [PATCH 06/17] fix: trainer --- src/tevatron/reranker/driver/train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/tevatron/reranker/driver/train.py b/src/tevatron/reranker/driver/train.py index cf241549..eb56fd64 100644 --- a/src/tevatron/reranker/driver/train.py +++ b/src/tevatron/reranker/driver/train.py @@ -15,7 +15,6 @@ from tevatron.reranker.dataset import RerankerTrainDataset from tevatron.reranker.collator import RerankerTrainCollator from tevatron.reranker.trainer import RerankerTrainer -from tevatron.reranker.gc_trainer import GradCacheTrainer as GCTrainer logger = logging.getLogger(__name__) From 6d111052aecd1a797bf78057466f54e388867a24 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Sat, 24 Aug 2024 21:07:36 +0900 Subject: [PATCH 07/17] parser --- src/tevatron/reranker/driver/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tevatron/reranker/driver/train.py b/src/tevatron/reranker/driver/train.py index eb56fd64..085e8eaf 100644 --- a/src/tevatron/reranker/driver/train.py +++ b/src/tevatron/reranker/driver/train.py @@ -20,7 +20,7 @@ def main(): - parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) + parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, TevatronTrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) From e98bbf37008236ef9296530181afd7c060988693 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Sat, 24 Aug 2024 21:08:40 +0900 Subject: [PATCH 08/17] fix: paserr --- src/tevatron/reranker/driver/train.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/tevatron/reranker/driver/train.py b/src/tevatron/reranker/driver/train.py index 085e8eaf..54d72b24 100644 --- a/src/tevatron/reranker/driver/train.py +++ b/src/tevatron/reranker/driver/train.py @@ -10,7 +10,7 @@ from transformers import TrainingArguments from tevatron.reranker.arguments import ModelArguments, DataArguments, \ - TevatronTrainingArguments as TrainingArguments + TevatronTrainingArguments from tevatron.reranker.modeling import RerankerModel from tevatron.reranker.dataset import RerankerTrainDataset from tevatron.reranker.collator import RerankerTrainCollator @@ -18,17 +18,21 @@ logger = logging.getLogger(__name__) - def main(): parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, TevatronTrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + model_args, data_args, training_args, tevatron_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) else: - model_args, data_args, training_args = parser.parse_args_into_dataclasses() + model_args, data_args, training_args, tevatron_args = parser.parse_args_into_dataclasses() model_args: ModelArguments data_args: DataArguments training_args: TrainingArguments + tevatron_args: TevatronTrainingArguments + + # Combine TrainingArguments and TevatronTrainingArguments + for key, value in vars(tevatron_args).items(): + setattr(training_args, key, value) if ( os.path.exists(training_args.output_dir) @@ -56,6 +60,7 @@ def main(): ) logger.info("Training/evaluation parameters %s", training_args) logger.info("MODEL parameters %s", model_args) + logger.info("Tevatron parameters %s", tevatron_args) set_seed(training_args.seed) @@ -76,8 +81,7 @@ def main(): train_dataset = RerankerTrainDataset(data_args) train_collator = RerankerTrainCollator(data_args, tokenizer) - trainer_cls = RerankerTrainer - trainer = trainer_cls( + trainer = RerankerTrainer( model=model, args=training_args, train_dataset=train_dataset, @@ -90,6 +94,5 @@ def main(): if trainer.is_world_process_zero(): tokenizer.save_pretrained(training_args.output_dir) - if __name__ == "__main__": main() \ No newline at end of file From 851f01b0e167285b61bacbcff5ff9291d4a873b8 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Sat, 24 Aug 2024 21:10:26 +0900 Subject: [PATCH 09/17] fix: parser --- src/tevatron/reranker/arguments.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/tevatron/reranker/arguments.py b/src/tevatron/reranker/arguments.py index 11480f88..ea99bf97 100644 --- a/src/tevatron/reranker/arguments.py +++ b/src/tevatron/reranker/arguments.py @@ -121,6 +121,5 @@ class DataArguments: class TevatronTrainingArguments(TrainingArguments): warmup_ratio: float = field(default=0.1) - grad_cache: bool = field(default=False, metadata={"help": "Use gradient cache update"}) - gc_q_chunk_size: int = field(default=4) - gc_p_chunk_size: int = field(default=32) \ No newline at end of file + grad_cache: bool = field(default=False, metadata={"help": "Use gradient cache"}) + gc_chunk_size: Optional[int] = field(default=None, metadata={"help": "Chunk size for gradient cache"}) From bb5d87cb9ec917a4bc7d2743981f4e3c772e30f8 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Sat, 24 Aug 2024 21:12:36 +0900 Subject: [PATCH 10/17] fix: parser --- src/tevatron/reranker/driver/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tevatron/reranker/driver/train.py b/src/tevatron/reranker/driver/train.py index 54d72b24..f6f9d7cd 100644 --- a/src/tevatron/reranker/driver/train.py +++ b/src/tevatron/reranker/driver/train.py @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) def main(): - parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, TevatronTrainingArguments)) + parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): model_args, data_args, training_args, tevatron_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) From 70d65ab3f4849b87a3ca20859d7a03bc14883428 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Sat, 24 Aug 2024 21:14:50 +0900 Subject: [PATCH 11/17] fix: trainer --- src/tevatron/reranker/driver/train.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/src/tevatron/reranker/driver/train.py b/src/tevatron/reranker/driver/train.py index f6f9d7cd..f0c1efca 100644 --- a/src/tevatron/reranker/driver/train.py +++ b/src/tevatron/reranker/driver/train.py @@ -7,10 +7,8 @@ HfArgumentParser, set_seed, ) -from transformers import TrainingArguments -from tevatron.reranker.arguments import ModelArguments, DataArguments, \ - TevatronTrainingArguments +from tevatron.reranker.arguments import ModelArguments, DataArguments, TevatronTrainingArguments from tevatron.reranker.modeling import RerankerModel from tevatron.reranker.dataset import RerankerTrainDataset from tevatron.reranker.collator import RerankerTrainCollator @@ -22,17 +20,9 @@ def main(): parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - model_args, data_args, training_args, tevatron_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) else: - model_args, data_args, training_args, tevatron_args = parser.parse_args_into_dataclasses() - model_args: ModelArguments - data_args: DataArguments - training_args: TrainingArguments - tevatron_args: TevatronTrainingArguments - - # Combine TrainingArguments and TevatronTrainingArguments - for key, value in vars(tevatron_args).items(): - setattr(training_args, key, value) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() if ( os.path.exists(training_args.output_dir) @@ -60,7 +50,6 @@ def main(): ) logger.info("Training/evaluation parameters %s", training_args) logger.info("MODEL parameters %s", model_args) - logger.info("Tevatron parameters %s", tevatron_args) set_seed(training_args.seed) From bfb8e5a5a9ad3b9e7207d82edf70caac94ed8dc4 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Sat, 24 Aug 2024 21:21:18 +0900 Subject: [PATCH 12/17] fix: trust_remote_code --- src/tevatron/reranker/driver/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tevatron/reranker/driver/train.py b/src/tevatron/reranker/driver/train.py index f0c1efca..910f669f 100644 --- a/src/tevatron/reranker/driver/train.py +++ b/src/tevatron/reranker/driver/train.py @@ -55,7 +55,8 @@ def main(): tokenizer = AutoTokenizer.from_pretrained( model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, - cache_dir=model_args.cache_dir + cache_dir=model_args.cache_dir, + trust_remote_code=True ) if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.unk_token_id From ca4b04b7fb40e4d30df16d12d6ee9c9261091209 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Sat, 24 Aug 2024 21:35:41 +0900 Subject: [PATCH 13/17] hotfix: gradient_checkpointing_enable --- src/tevatron/reranker/modeling.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/tevatron/reranker/modeling.py b/src/tevatron/reranker/modeling.py index 95929390..dde7dcd4 100644 --- a/src/tevatron/reranker/modeling.py +++ b/src/tevatron/reranker/modeling.py @@ -9,7 +9,6 @@ from transformers import TrainingArguments from peft import LoraConfig, PeftModel, TaskType, get_peft_model - from tevatron.reranker.arguments import ModelArguments import logging @@ -22,6 +21,7 @@ class RerankerOutput(ModelOutput): loss: Optional[Tensor] = None scores: Optional[Tensor] = None + class RerankerModel(nn.Module): TRANSFORMER_CLS = AutoModelForSequenceClassification @@ -49,17 +49,18 @@ def forward(self, pair: Dict[str, Tensor] = None): grouped_logits = ranker_logits.view(self.train_batch_size, -1) loss = self.cross_entropy(grouped_logits, self.target_label) return RerankerOutput( - loss = loss, - scores = ranker_logits + loss=loss, + scores=ranker_logits ) return RerankerOutput( - loss = None, - scores = ranker_logits + loss=None, + scores=ranker_logits ) - + def gradient_checkpointing_enable(self, **kwargs): - self.hf_model.base_model.model.gradient_checkpointing_enable(**kwargs) + return False + # self.hf_model.base_model.model.gradient_checkpointing_enable(**kwargs) @classmethod def build( @@ -79,7 +80,9 @@ def build( base_model.enable_input_require_grads() if model_args.lora_name_or_path: lora_config = LoraConfig.from_pretrained(model_args.lora_name_or_path, **hf_kwargs) - lora_model = PeftModel.from_pretrained(base_model, model_args.lora_name_or_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2") + lora_model = PeftModel.from_pretrained(base_model, model_args.lora_name_or_path, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2") else: lora_config = LoraConfig( base_model_name_or_path=model_args.model_name_or_path, @@ -107,7 +110,9 @@ def load(cls, model_name_or_path: str, lora_name_or_path: str = None, **hf_kwargs): - base_model = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, num_labels=1, **hf_kwargs, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2") + base_model = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, num_labels=1, **hf_kwargs, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2") if base_model.config.pad_token_id is None: base_model.config.pad_token_id = 0 if lora_name_or_path: From a79b647c7effb80f5477d0fdcacd4138b4045fb7 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Sat, 24 Aug 2024 21:54:08 +0900 Subject: [PATCH 14/17] fix: forward method --- src/tevatron/reranker/modeling.py | 56 +++++++++++++------------ src/tevatron/reranker/trainer.py | 68 +++++-------------------------- 2 files changed, 42 insertions(+), 82 deletions(-) diff --git a/src/tevatron/reranker/modeling.py b/src/tevatron/reranker/modeling.py index dde7dcd4..229021b6 100644 --- a/src/tevatron/reranker/modeling.py +++ b/src/tevatron/reranker/modeling.py @@ -1,6 +1,6 @@ -import os +import logging from dataclasses import dataclass -from typing import Dict, Optional +from typing import Optional, Dict, Any import torch from torch import nn, Tensor @@ -11,8 +11,6 @@ from tevatron.reranker.arguments import ModelArguments -import logging - logger = logging.getLogger(__name__) @@ -27,6 +25,7 @@ class RerankerModel(nn.Module): def __init__(self, hf_model: PreTrainedModel, train_batch_size: int = None): super().__init__() + logger.info(f"Initializing RerankerModel with train_batch_size: {train_batch_size}") self.config = hf_model.config self.hf_model = hf_model self.train_batch_size = train_batch_size @@ -36,31 +35,26 @@ def __init__(self, hf_model: PreTrainedModel, train_batch_size: int = None): 'target_label', torch.zeros(self.train_batch_size, dtype=torch.long, device=self.hf_model.device) ) - for name, param in self.hf_model.named_parameters(): - # for some reason, ds zero 3 left some weights empty - if 'modules_to_save' in name and param.numel() == 0: - logger.warning(f'parameter {name}, shape {param.shape} is empty') - param.data = nn.Linear(self.hf_model.config.hidden_size, 1).weight.data - logger.warning('{} data: {}'.format(name, param.data.cpu().numpy())) - - def forward(self, pair: Dict[str, Tensor] = None): - ranker_logits = self.hf_model(**pair, return_dict=True).logits - if self.train_batch_size: - grouped_logits = ranker_logits.view(self.train_batch_size, -1) - loss = self.cross_entropy(grouped_logits, self.target_label) - return RerankerOutput( - loss=loss, - scores=ranker_logits - ) + logger.info(f"RerankerModel initialized with config: {self.config}") + + def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, labels: Tensor = None, **kwargs): + logger.debug(f"Forward pass with input shape: {input_ids.shape if input_ids is not None else 'None'}") + outputs = self.hf_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs) + + if labels is not None: + loss = self.cross_entropy(outputs.logits.view(self.train_batch_size, -1), labels) + logger.debug(f"Computed loss: {loss.item()}") + else: + loss = None + logger.debug("No labels provided, skipping loss computation") return RerankerOutput( - loss=None, - scores=ranker_logits + loss=loss, + scores=outputs.logits ) - def gradient_checkpointing_enable(self, **kwargs): + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None): return False - # self.hf_model.base_model.model.gradient_checkpointing_enable(**kwargs) @classmethod def build( @@ -69,21 +63,27 @@ def build( train_args: TrainingArguments, **hf_kwargs, ): + logger.info(f"Building RerankerModel with args: {model_args}") base_model = cls.TRANSFORMER_CLS.from_pretrained( model_args.model_name_or_path, **hf_kwargs, ) if base_model.config.pad_token_id is None: base_model.config.pad_token_id = 0 + logger.info("Set pad_token_id to 0") + if model_args.lora or model_args.lora_name_or_path: + logger.info("Applying LoRA") if train_args.gradient_checkpointing: base_model.enable_input_require_grads() if model_args.lora_name_or_path: + logger.info(f"Loading LoRA from {model_args.lora_name_or_path}") lora_config = LoraConfig.from_pretrained(model_args.lora_name_or_path, **hf_kwargs) lora_model = PeftModel.from_pretrained(base_model, model_args.lora_name_or_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2") else: + logger.info("Initializing new LoRA") lora_config = LoraConfig( base_model_name_or_path=model_args.model_name_or_path, task_type=TaskType.SEQ_CLS, @@ -99,6 +99,7 @@ def build( train_batch_size=train_args.per_device_train_batch_size, ) else: + logger.info("Building model without LoRA") model = cls( hf_model=base_model, train_batch_size=train_args.per_device_train_batch_size, @@ -110,12 +111,15 @@ def load(cls, model_name_or_path: str, lora_name_or_path: str = None, **hf_kwargs): + logger.info(f"Loading RerankerModel from {model_name_or_path}") base_model = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, num_labels=1, **hf_kwargs, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2") if base_model.config.pad_token_id is None: base_model.config.pad_token_id = 0 + logger.info("Set pad_token_id to 0") if lora_name_or_path: + logger.info(f"Loading LoRA from {lora_name_or_path}") lora_config = LoraConfig.from_pretrained(lora_name_or_path, **hf_kwargs) lora_model = PeftModel.from_pretrained(base_model, lora_name_or_path, config=lora_config) lora_model = lora_model.merge_and_unload() @@ -123,10 +127,12 @@ def load(cls, hf_model=lora_model, ) else: + logger.info("Loading model without LoRA") model = cls( hf_model=base_model, ) return model def save(self, output_dir: str): - self.hf_model.save_pretrained(output_dir) + logger.info(f"Saving model to {output_dir}") + self.hf_model.save_pretrained(output_dir) \ No newline at end of file diff --git a/src/tevatron/reranker/trainer.py b/src/tevatron/reranker/trainer.py index f7236d2f..cba2b36e 100644 --- a/src/tevatron/reranker/trainer.py +++ b/src/tevatron/reranker/trainer.py @@ -1,81 +1,35 @@ -import os -from typing import Optional +from tevatron.reranker.modeling import RerankerOutput +from tevatron.retriever.trainer import TevatronTrainer +from grad_cache import GradCache -import torch -from torch import Tensor -from torch.nn import functional as F - -from transformers.trainer import Trainer -from transformers.deepspeed import is_deepspeed_zero3_enabled -from peft import get_peft_model_state_dict - -import logging - -logger = logging.getLogger(__name__) - -try: - from grad_cache import GradCache - - _grad_cache_available = True -except ModuleNotFoundError: - _grad_cache_available = False - - -def split_inputs(model_input: dict, chunk_size: int): +def split_inputs(model_input, chunk_size): keys = list(model_input.keys()) chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys] return [dict(zip(keys, tt)) for tt in zip(*chunked_tensors)] +def get_rep(x: RerankerOutput): + return x.scores -def get_rep(x): - return x.logits - - -class RerankerTrainer(Trainer): +class RerankerTrainer(TevatronTrainer): def __init__(self, *args, **kwargs): - super(RerankerTrainer, self).__init__(*args, **kwargs) - - if not _grad_cache_available: - raise ValueError( - 'Grad Cache package not available. You can obtain it from https://github.com/luyug/GradCache.') - + super().__init__(*args, **kwargs) + loss_fn = lambda x, y: self.compute_loss(self.model, {'input_ids': x, 'labels': y}) self.gc = GradCache( models=[self.model], chunk_sizes=[self.args.gc_chunk_size], - loss_fn=self.compute_loss, + loss_fn=loss_fn, split_input_fn=split_inputs, get_rep_fn=get_rep, fp16=self.args.fp16, scaler=self.scaler if self.args.fp16 else None ) - def _save(self, output_dir: Optional[str] = None, state_dict=None): - output_dir = output_dir if output_dir is not None else self.args.output_dir - os.makedirs(output_dir, exist_ok=True) - logger.info("Saving model checkpoint to %s", output_dir) - self.model.save(output_dir) - - if is_deepspeed_zero3_enabled(): - if state_dict is None: - state_dict = self.model.state_dict() - prefix = 'hf_model.' - assert all( - k.startswith(prefix) or k == "target_label" - for k in state_dict.keys() - ), list(state_dict.keys()) - state_dict = {k[len(prefix):]: v for k, v in state_dict.items()} - lora_state_dict = get_peft_model_state_dict(self.model.hf_model, state_dict) - if self.args.process_index <= 0: - torch.save(lora_state_dict, os.path.join(output_dir, "adapter_model.bin")) - print(f"Save adapter model at {output_dir}") - def compute_loss(self, model, inputs, return_outputs=False): - outputs = model(inputs) + outputs = model(**inputs) loss = outputs.loss return (loss, outputs) if return_outputs else loss def training_step(self, model, inputs): - model.train() _distributed = self.args.local_rank > -1 self.gc.models = [model] loss = self.gc(inputs, no_sync_except_last=_distributed) From 78ff5800d86b38c0a0a9282ced59f7d81726c006 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Sat, 24 Aug 2024 22:01:19 +0900 Subject: [PATCH 15/17] fix: prediction step --- src/tevatron/reranker/modeling.py | 43 +++++------------------- src/tevatron/reranker/trainer.py | 56 +++++++++++++++++++++++++++---- 2 files changed, 58 insertions(+), 41 deletions(-) diff --git a/src/tevatron/reranker/modeling.py b/src/tevatron/reranker/modeling.py index 229021b6..887d6d24 100644 --- a/src/tevatron/reranker/modeling.py +++ b/src/tevatron/reranker/modeling.py @@ -1,6 +1,6 @@ import logging from dataclasses import dataclass -from typing import Optional, Dict, Any +from typing import Optional import torch from torch import nn, Tensor @@ -23,37 +23,22 @@ class RerankerOutput(ModelOutput): class RerankerModel(nn.Module): TRANSFORMER_CLS = AutoModelForSequenceClassification - def __init__(self, hf_model: PreTrainedModel, train_batch_size: int = None): + def __init__(self, hf_model: PreTrainedModel): super().__init__() - logger.info(f"Initializing RerankerModel with train_batch_size: {train_batch_size}") + logger.info("Initializing RerankerModel") self.config = hf_model.config self.hf_model = hf_model - self.train_batch_size = train_batch_size - self.cross_entropy = nn.CrossEntropyLoss(reduction='mean') - if train_batch_size: - self.register_buffer( - 'target_label', - torch.zeros(self.train_batch_size, dtype=torch.long, device=self.hf_model.device) - ) logger.info(f"RerankerModel initialized with config: {self.config}") - def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, labels: Tensor = None, **kwargs): + def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, **kwargs): logger.debug(f"Forward pass with input shape: {input_ids.shape if input_ids is not None else 'None'}") outputs = self.hf_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs) - if labels is not None: - loss = self.cross_entropy(outputs.logits.view(self.train_batch_size, -1), labels) - logger.debug(f"Computed loss: {loss.item()}") - else: - loss = None - logger.debug("No labels provided, skipping loss computation") - return RerankerOutput( - loss=loss, scores=outputs.logits ) - def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None): + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs = None): return False @classmethod @@ -94,16 +79,10 @@ def build( inference_mode=False, ) lora_model = get_peft_model(base_model, lora_config) - model = cls( - hf_model=lora_model, - train_batch_size=train_args.per_device_train_batch_size, - ) + model = cls(hf_model=lora_model) else: logger.info("Building model without LoRA") - model = cls( - hf_model=base_model, - train_batch_size=train_args.per_device_train_batch_size, - ) + model = cls(hf_model=base_model) return model @classmethod @@ -123,14 +102,10 @@ def load(cls, lora_config = LoraConfig.from_pretrained(lora_name_or_path, **hf_kwargs) lora_model = PeftModel.from_pretrained(base_model, lora_name_or_path, config=lora_config) lora_model = lora_model.merge_and_unload() - model = cls( - hf_model=lora_model, - ) + model = cls(hf_model=lora_model) else: logger.info("Loading model without LoRA") - model = cls( - hf_model=base_model, - ) + model = cls(hf_model=base_model) return model def save(self, output_dir: str): diff --git a/src/tevatron/reranker/trainer.py b/src/tevatron/reranker/trainer.py index cba2b36e..f25c723b 100644 --- a/src/tevatron/reranker/trainer.py +++ b/src/tevatron/reranker/trainer.py @@ -1,19 +1,38 @@ -from tevatron.reranker.modeling import RerankerOutput -from tevatron.retriever.trainer import TevatronTrainer +import logging +from typing import Dict, Union, Any + +import torch +from torch import nn +from transformers import Trainer +from transformers.trainer_utils import PredictionOutput + from grad_cache import GradCache +from tevatron.reranker.arguments import TevatronTrainingArguments + +logger = logging.getLogger(__name__) + def split_inputs(model_input, chunk_size): + logger.debug(f"Splitting inputs with chunk size: {chunk_size}") keys = list(model_input.keys()) chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys] return [dict(zip(keys, tt)) for tt in zip(*chunked_tensors)] -def get_rep(x: RerankerOutput): - return x.scores +def get_rep(model_output): + logger.debug(f"Getting representation from model output: {type(model_output)}") + return model_output.scores -class RerankerTrainer(TevatronTrainer): +class RerankerTrainer(Trainer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - loss_fn = lambda x, y: self.compute_loss(self.model, {'input_ids': x, 'labels': y}) + logger.info("Initializing RerankerTrainer") + self.args: TevatronTrainingArguments + + def loss_fn(scores, labels): + grouped_scores = scores.view(self.args.train_group_size, -1) + labels = torch.zeros(self.args.train_group_size, dtype=torch.long, device=scores.device) + return nn.CrossEntropyLoss()(grouped_scores, labels) + self.gc = GradCache( models=[self.model], chunk_sizes=[self.args.gc_chunk_size], @@ -23,14 +42,37 @@ def __init__(self, *args, **kwargs): fp16=self.args.fp16, scaler=self.scaler if self.args.fp16 else None ) + logger.info(f"GradCache initialized with chunk size: {self.args.gc_chunk_size}") def compute_loss(self, model, inputs, return_outputs=False): + logger.debug(f"Computing loss with inputs: {inputs.keys()}") outputs = model(**inputs) loss = outputs.loss + logger.debug(f"Computed loss: {loss.item()}") return (loss, outputs) if return_outputs else loss - def training_step(self, model, inputs): + def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + logger.debug("Entering training step") + model.train() + inputs = self._prepare_inputs(inputs) _distributed = self.args.local_rank > -1 self.gc.models = [model] loss = self.gc(inputs, no_sync_except_last=_distributed) + logger.debug(f"Training step loss: {loss.item()}") return loss + + def prediction_step( + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: bool = None, + ) -> PredictionOutput: + logger.debug("Entering prediction step") + inputs = self._prepare_inputs(inputs) + with torch.no_grad(): + outputs = model(**inputs) + loss = outputs.loss + logits = outputs.scores + logger.debug(f"Prediction step loss: {loss.item() if loss is not None else 'N/A'}") + return PredictionOutput(predictions=logits, label_ids=inputs.get("labels"), metrics=None) From 43a642d3bee6543fd8da81e20541dc0dfeb3df8b Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Sat, 24 Aug 2024 22:11:20 +0900 Subject: [PATCH 16/17] fix: trainer --- src/tevatron/reranker/arguments.py | 2 +- src/tevatron/reranker/driver/train.py | 9 ++-- src/tevatron/reranker/trainer.py | 62 ++++++++++++++++----------- 3 files changed, 44 insertions(+), 29 deletions(-) diff --git a/src/tevatron/reranker/arguments.py b/src/tevatron/reranker/arguments.py index ea99bf97..a48b4468 100644 --- a/src/tevatron/reranker/arguments.py +++ b/src/tevatron/reranker/arguments.py @@ -122,4 +122,4 @@ class TevatronTrainingArguments(TrainingArguments): warmup_ratio: float = field(default=0.1) grad_cache: bool = field(default=False, metadata={"help": "Use gradient cache"}) - gc_chunk_size: Optional[int] = field(default=None, metadata={"help": "Chunk size for gradient cache"}) + gc_chunk_size: Optional[int] = field(default=2, metadata={"help": "Chunk size for gradient cache"}) diff --git a/src/tevatron/reranker/driver/train.py b/src/tevatron/reranker/driver/train.py index 910f669f..8115d911 100644 --- a/src/tevatron/reranker/driver/train.py +++ b/src/tevatron/reranker/driver/train.py @@ -1,21 +1,20 @@ import logging import os import sys - from transformers import AutoTokenizer from transformers import ( HfArgumentParser, set_seed, ) - from tevatron.reranker.arguments import ModelArguments, DataArguments, TevatronTrainingArguments from tevatron.reranker.modeling import RerankerModel from tevatron.reranker.dataset import RerankerTrainDataset from tevatron.reranker.collator import RerankerTrainCollator -from tevatron.reranker.trainer import RerankerTrainer +from tevatron.reranker.trainer import RerankerTrainer # Make sure this is your updated RerankerTrainer logger = logging.getLogger(__name__) + def main(): parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments)) @@ -71,6 +70,9 @@ def main(): train_dataset = RerankerTrainDataset(data_args) train_collator = RerankerTrainCollator(data_args, tokenizer) + # Add GradCache-specific arguments to training_args + training_args.gc_chunk_size = getattr(training_args, 'gc_chunk_size', 2) + trainer = RerankerTrainer( model=model, args=training_args, @@ -84,5 +86,6 @@ def main(): if trainer.is_world_process_zero(): tokenizer.save_pretrained(training_args.output_dir) + if __name__ == "__main__": main() \ No newline at end of file diff --git a/src/tevatron/reranker/trainer.py b/src/tevatron/reranker/trainer.py index f25c723b..79cc2bc1 100644 --- a/src/tevatron/reranker/trainer.py +++ b/src/tevatron/reranker/trainer.py @@ -3,51 +3,64 @@ import torch from torch import nn -from transformers import Trainer +from transformers import Trainer, TrainingArguments from transformers.trainer_utils import PredictionOutput from grad_cache import GradCache -from tevatron.reranker.arguments import TevatronTrainingArguments +from grad_cache.functional import cached, cat_input_tensor +from torch.cuda.amp import autocast logger = logging.getLogger(__name__) + +@cached +@autocast() +def get_model_rep(model, inputs): + outputs = model(**inputs) + return outputs.scores + + +@cat_input_tensor +@autocast() +def contrastive_loss(scores): + batch_size = scores.size(0) // 2 + labels = torch.arange(batch_size, device=scores.device) + return nn.CrossEntropyLoss()(scores, labels) + + def split_inputs(model_input, chunk_size): logger.debug(f"Splitting inputs with chunk size: {chunk_size}") keys = list(model_input.keys()) chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys] return [dict(zip(keys, tt)) for tt in zip(*chunked_tensors)] -def get_rep(model_output): - logger.debug(f"Getting representation from model output: {type(model_output)}") - return model_output.scores class RerankerTrainer(Trainer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - logger.info("Initializing RerankerTrainer") - self.args: TevatronTrainingArguments + logger.info("Initializing RerankerTrainer with GradCache") + self.args: TrainingArguments - def loss_fn(scores, labels): - grouped_scores = scores.view(self.args.train_group_size, -1) - labels = torch.zeros(self.args.train_group_size, dtype=torch.long, device=scores.device) - return nn.CrossEntropyLoss()(grouped_scores, labels) + # Add these lines to include the necessary parameters + self.gc_chunk_size = getattr(self.args, 'gc_chunk_size', 4) # default to 4 if not provided self.gc = GradCache( models=[self.model], - chunk_sizes=[self.args.gc_chunk_size], - loss_fn=loss_fn, + chunk_sizes=self.gc_chunk_size, + loss_fn=contrastive_loss, split_input_fn=split_inputs, - get_rep_fn=get_rep, + get_rep_fn=lambda x: x.scores, fp16=self.args.fp16, scaler=self.scaler if self.args.fp16 else None ) - logger.info(f"GradCache initialized with chunk size: {self.args.gc_chunk_size}") + logger.info(f"GradCache initialized with chunk size: {self.gc_chunk_size}") def compute_loss(self, model, inputs, return_outputs=False): logger.debug(f"Computing loss with inputs: {inputs.keys()}") outputs = model(**inputs) - loss = outputs.loss + scores = outputs.scores + loss = contrastive_loss(scores) logger.debug(f"Computed loss: {loss.item()}") return (loss, outputs) if return_outputs else loss @@ -56,23 +69,22 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, model.train() inputs = self._prepare_inputs(inputs) _distributed = self.args.local_rank > -1 - self.gc.models = [model] loss = self.gc(inputs, no_sync_except_last=_distributed) logger.debug(f"Training step loss: {loss.item()}") return loss def prediction_step( - self, - model: nn.Module, - inputs: Dict[str, Union[torch.Tensor, Any]], - prediction_loss_only: bool, - ignore_keys: bool = None, + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: bool = None, ) -> PredictionOutput: logger.debug("Entering prediction step") inputs = self._prepare_inputs(inputs) with torch.no_grad(): outputs = model(**inputs) - loss = outputs.loss - logits = outputs.scores + scores = outputs.scores + loss = contrastive_loss(scores) logger.debug(f"Prediction step loss: {loss.item() if loss is not None else 'N/A'}") - return PredictionOutput(predictions=logits, label_ids=inputs.get("labels"), metrics=None) + return PredictionOutput(predictions=scores, label_ids=None, metrics=None) From 0380e4ea64b60e969a6bdc8e1c9d954928c8c86e Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Sat, 24 Aug 2024 22:16:09 +0900 Subject: [PATCH 17/17] fix: ddp --- src/tevatron/reranker/driver/train.py | 48 +++++++++++++++++---------- src/tevatron/reranker/modeling.py | 3 -- src/tevatron/reranker/trainer.py | 38 ++++++++++++--------- 3 files changed, 53 insertions(+), 36 deletions(-) diff --git a/src/tevatron/reranker/driver/train.py b/src/tevatron/reranker/driver/train.py index 8115d911..7c924001 100644 --- a/src/tevatron/reranker/driver/train.py +++ b/src/tevatron/reranker/driver/train.py @@ -1,20 +1,36 @@ import logging import os import sys +import torch from transformers import AutoTokenizer from transformers import ( HfArgumentParser, set_seed, ) +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.distributed as dist from tevatron.reranker.arguments import ModelArguments, DataArguments, TevatronTrainingArguments from tevatron.reranker.modeling import RerankerModel from tevatron.reranker.dataset import RerankerTrainDataset from tevatron.reranker.collator import RerankerTrainCollator -from tevatron.reranker.trainer import RerankerTrainer # Make sure this is your updated RerankerTrainer +from tevatron.reranker.trainer import RerankerTrainer logger = logging.getLogger(__name__) +def setup_ddp(): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + # We're running in a distributed environment + import torch.distributed as dist + rank = int(os.environ['RANK']) + world_size = int(os.environ['WORLD_SIZE']) + dist.init_process_group(backend="nccl") + return rank + else: + # We're not running in a distributed environment + return -1 + + def main(): parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments)) @@ -23,29 +39,22 @@ def main(): else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() - if ( - os.path.exists(training_args.output_dir) - and os.listdir(training_args.output_dir) - and training_args.do_train - and not training_args.overwrite_output_dir - ): - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." - ) + local_rank = setup_ddp() + training_args.local_rank = local_rank # Setup logging logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, + level=logging.INFO if local_rank in [-1, 0] else logging.WARN, ) logger.warning( "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", - training_args.local_rank, + local_rank, training_args.device, training_args.n_gpu, - bool(training_args.local_rank != -1), - training_args.fp16, + bool(local_rank != -1), + training_args.fp16 or training_args.bf16, ) logger.info("Training/evaluation parameters %s", training_args) logger.info("MODEL parameters %s", model_args) @@ -67,11 +76,16 @@ def main(): cache_dir=model_args.cache_dir, ) + # Move model to GPU + if local_rank != -1: + model = model.to(local_rank) + model = DDP(model, device_ids=[local_rank], output_device=local_rank) + train_dataset = RerankerTrainDataset(data_args) train_collator = RerankerTrainCollator(data_args, tokenizer) - # Add GradCache-specific arguments to training_args training_args.gc_chunk_size = getattr(training_args, 'gc_chunk_size', 2) + training_args.grad_cache = getattr(training_args, 'grad_cache', False) trainer = RerankerTrainer( model=model, @@ -81,11 +95,11 @@ def main(): ) train_dataset.trainer = trainer - trainer.train() # TODO: resume training + trainer.train() trainer.save_model() if trainer.is_world_process_zero(): tokenizer.save_pretrained(training_args.output_dir) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/tevatron/reranker/modeling.py b/src/tevatron/reranker/modeling.py index 887d6d24..6007dffc 100644 --- a/src/tevatron/reranker/modeling.py +++ b/src/tevatron/reranker/modeling.py @@ -38,9 +38,6 @@ def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, **kwa scores=outputs.logits ) - def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs = None): - return False - @classmethod def build( cls, diff --git a/src/tevatron/reranker/trainer.py b/src/tevatron/reranker/trainer.py index 79cc2bc1..8534b68a 100644 --- a/src/tevatron/reranker/trainer.py +++ b/src/tevatron/reranker/trainer.py @@ -7,7 +7,6 @@ from transformers.trainer_utils import PredictionOutput from grad_cache import GradCache - from grad_cache.functional import cached, cat_input_tensor from torch.cuda.amp import autocast @@ -39,22 +38,26 @@ def split_inputs(model_input, chunk_size): class RerankerTrainer(Trainer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - logger.info("Initializing RerankerTrainer with GradCache") + logger.info("Initializing RerankerTrainer") self.args: TrainingArguments - # Add these lines to include the necessary parameters - self.gc_chunk_size = getattr(self.args, 'gc_chunk_size', 4) # default to 4 if not provided + self.gc_chunk_size = getattr(self.args, 'gc_chunk_size', 4) + self.use_grad_cache = getattr(self.args, 'grad_cache', False) + + if self.use_grad_cache: + # If the model is wrapped in DDP, we need to use the .module attribute + model_for_gc = self.model.module if hasattr(self.model, 'module') else self.model - self.gc = GradCache( - models=[self.model], - chunk_sizes=self.gc_chunk_size, - loss_fn=contrastive_loss, - split_input_fn=split_inputs, - get_rep_fn=lambda x: x.scores, - fp16=self.args.fp16, - scaler=self.scaler if self.args.fp16 else None - ) - logger.info(f"GradCache initialized with chunk size: {self.gc_chunk_size}") + self.gc = GradCache( + models=[model_for_gc], + chunk_sizes=self.gc_chunk_size, + loss_fn=contrastive_loss, + split_input_fn=split_inputs, + get_rep_fn=lambda x: x.scores, + fp16=self.args.fp16, + # scaler: GradScaler = None, + ) + logger.info(f"GradCache initialized with chunk size: {self.gc_chunk_size}") def compute_loss(self, model, inputs, return_outputs=False): logger.debug(f"Computing loss with inputs: {inputs.keys()}") @@ -68,8 +71,11 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, logger.debug("Entering training step") model.train() inputs = self._prepare_inputs(inputs) - _distributed = self.args.local_rank > -1 - loss = self.gc(inputs, no_sync_except_last=_distributed) + if self.use_grad_cache: + _distributed = self.args.local_rank != -1 + loss = self.gc(inputs, no_sync_except_last=_distributed) + else: + loss = self.compute_loss(model, inputs) logger.debug(f"Training step loss: {loss.item()}") return loss