diff --git a/scripts/sample.py b/scripts/sample.py index ce3df41..c363a1f 100644 --- a/scripts/sample.py +++ b/scripts/sample.py @@ -12,7 +12,7 @@ from src.data.commonsense_dataset import CommonSenseDataset from src.data.utils import Preprocessor from src.diffusion.model import DiDi, get_components -from src.sampling import sample +from src.pipeline.sampling import sample def configure_arg_parser(): diff --git a/scripts/train.py b/scripts/train.py index 7f3ce9f..bf15d4e 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -12,7 +12,7 @@ from src.data.reddit_dataset import RedditDataset from src.diffusion.model import DiDi from src.diffusion.model import get_components -from src.training import train_model +from src.pipeline.training import train_model from src.utils import filter_warnings, setup_logger, zero_rank_info diff --git a/scripts/train_adapter.py b/scripts/train_adapter.py new file mode 100644 index 0000000..210734c --- /dev/null +++ b/scripts/train_adapter.py @@ -0,0 +1,80 @@ +import argparse +from os import environ +from os.path import join + +import torch +from functools import partial +from omegaconf import OmegaConf +from torch.utils.data import DataLoader + +from src.conditioning.adapter import Adapter +from src.data.convai2_dataset import ConvAI2Dataset +from src.diffusion.model import DiDi +from src.pipeline.training import train_model +from src.utils import filter_warnings, setup_logger, zero_rank_info + + +def configure_arg_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("config_path", type=str, help="Path to YAML config file") + parser.add_argument("dataset_dir", type=str, help="Path to dataset directory") + parser.add_argument("model_path", type=str, help="Path to DiDi model") + return parser + + +def main(config_path: str, dataset_dir: str, model_path: str): + filter_warnings() + setup_logger() + environ["TOKENIZERS_PARALLELISM"] = "false" + + torch.set_float32_matmul_precision("high") + + config = OmegaConf.load(config_path) + zero_rank_info(f"Loaded config:\n{OmegaConf.to_yaml(config, resolve=False, sort_keys=True)}") + + train_dataset = ConvAI2Dataset( + join(dataset_dir, f"train_both_revised_no_cands.txt"), config.base_name, **config.dataset + ) + val_dataset = ConvAI2Dataset( + join(dataset_dir, f"valid_both_revised_no_cands.txt"), config.base_name, **config.dataset + ) + + train_collate_fn = partial( + train_dataset.collate_fn, return_partner_persona=config.persona.partner, return_my_persona=config.persona.my + ) + val_collate_fn = partial( + val_dataset.collate_fn, return_partner_persona=config.persona.partner, return_my_persona=config.persona.my + ) + + train_dataloader = DataLoader( + train_dataset, + batch_size=config.batch_size, + collate_fn=train_collate_fn, + pin_memory=True, + num_workers=1, + ) + + val_dataloader = DataLoader( + val_dataset, + batch_size=config.val_batch_size, + collate_fn=val_collate_fn, + pin_memory=True, + num_workers=1, + ) + + didi = DiDi.load_from_checkpoint(model_path) + model = Adapter(didi) + + train_model( + model, + train_dataloader, + val_dataloader, + config.trainer, + seed=config.seed, + save_interval=config.save_interval, + ) + + +if __name__ == "__main__": + _args = configure_arg_parser().parse_args() + main(**vars(_args)) diff --git a/src/conditioning/__init__.py b/src/conditioning/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/conditioning/adapter.py b/src/conditioning/adapter.py new file mode 100644 index 0000000..6412728 --- /dev/null +++ b/src/conditioning/adapter.py @@ -0,0 +1,119 @@ +import torch +from lightning import LightningModule +from torch import nn +from xformers.components.attention.utils import maybe_merge_masks +from xformers.ops import memory_efficient_attention + +from src.diffusion.model import DiDi +from src.diffusion.utils import get_diffusion_variables, get_x0 +from src.pipeline.utils import calculate_train_step, freeze_params, get_cached_content, get_optimizers + + +class AdapterBlock(nn.Module): + def __init__(self, input_dim: int, num_heads: int = 4): + super().__init__() + self.head_dim = input_dim // num_heads + self.num_heads = num_heads + + self.query = nn.Linear(input_dim, input_dim) + self.key = nn.Linear(input_dim, input_dim) + self.value = nn.Linear(input_dim, input_dim) + self.out = nn.Linear(input_dim, input_dim) + + def split_heads(self, x, batch_size): + x = x.view(batch_size, -1, self.num_heads, self.head_dim) + return x + + def forward(self, hidden_states, encoder_hidden_states, encoder_attention_mask): + batch_size, trg_len, emb_dim = hidden_states.size() + src_len = encoder_hidden_states.shape[1] + + query = self.split_heads(self.query(hidden_states), batch_size) + key = self.split_heads(self.key(encoder_hidden_states), batch_size) + value = self.split_heads(self.value(encoder_hidden_states), batch_size) + + mask = maybe_merge_masks(None, encoder_attention_mask.bool(), batch_size, src_len, self.num_heads, trg_len) + mask = mask.view(batch_size, self.num_heads, src_len, trg_len) + float_mask = torch.where(mask, 0, float("-inf")) + attn_result = memory_efficient_attention(query, key, value, attn_bias=float_mask) + output = self.out(attn_result.view(batch_size, trg_len, emb_dim)) + return output + + +class Adapter(LightningModule): + def __init__(self, didi: DiDi, lr: float = 0.001, warmup_steps: int = 1, min_lr: float = None): + super().__init__() + self.didi = didi + freeze_params(self.didi) + + self.decoder_layers = [] + adapter_layers = [] + for layer in didi.decoder.encoder.layer: + self.decoder_layers.append(layer) + adapter_layers.append(AdapterBlock(layer.output.dense.out_features)) + + self.adapter_layers = nn.ModuleList(adapter_layers) + + self.lr, self.warmup, self.min_lr = lr, warmup_steps, min_lr + + def configure_optimizers(self): + return get_optimizers(self.parameters(), self.lr, self.warmup, self.min_lr) + + def forward( + self, + encoder_input_ids: torch.Tensor = None, + encoder_attention_mask: torch.Tensor = None, + decoder_inputs_embeds: torch.Tensor = None, + condition_input_ids: torch.Tensor = None, + condition_attention_mask: torch.Tensor = None, + time_ids: torch.Tensor = None, + context: torch.Tensor = None, + condition: torch.Tensor = None, + ): + if encoder_input_ids is None and context is None: + raise ValueError("Either `encoder_input_ids` or `context` must be provided.") + + if condition_input_ids is None and condition is None: + raise ValueError("Either `condition_input_ids` or `condition` must be provided.") + + context = context or get_cached_content(self.didi, encoder_input_ids, encoder_attention_mask) + condition = condition or get_cached_content(self.didi, condition_input_ids, condition_attention_mask) + + time_embeds = self.didi.time_embeds(time_ids) + hidden_states = decoder_inputs_embeds + time_embeds + + for decoder_layer, adapter_layer in zip(self.decoder_layers, self.adapter_layers): + output = decoder_layer( + hidden_states=hidden_states, + encoder_hidden_states=context, + encoder_attention_mask=encoder_attention_mask, + )[0] + hidden_states = adapter_layer( + hidden_states=output, + encoder_hidden_states=condition, + encoder_attention_mask=condition_attention_mask, + ) + + return hidden_states, context, condition + + def training_step(self, batch: dict, batch_idx: int): + raw_context, target, condition = batch["context"], batch["target"], batch["condition"] + emb = self.didi.emb(target.input_ids) + x_0 = get_x0(emb, self.didi.std_0) + noise = torch.randn_like(x_0) + + # x: [batch size; seq len; emb dim], t: [batch size] + x_t, t = get_diffusion_variables(self.didi.diffusion_steps, x_0, self.didi.sigmas, noise) + + x_0_hat, *_ = self( + encoder_input_ids=raw_context.input_ids, + encoder_attention_mask=raw_context.attention_mask, + decoder_inputs_embeds=x_t, + time_ids=t, + condition_input_ids=condition.input_ids, + condition_attention_mask=condition.attention_mask, + ) # [batch size; seq len; emb dim] + + loss, metrics = calculate_train_step(self.didi, emb, x_0, x_0_hat, target, t) + self.log_dict(metrics, sync_dist=True, on_step=True, on_epoch=False) + return loss diff --git a/src/data/convai2_dataset.py b/src/data/convai2_dataset.py index a0fa8f7..10dedb7 100644 --- a/src/data/convai2_dataset.py +++ b/src/data/convai2_dataset.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from itertools import zip_longest from typing import Optional from loguru import logger @@ -6,6 +7,8 @@ from tqdm.auto import tqdm from transformers import AutoTokenizer +from src.data.utils import Preprocessor + @dataclass class ConvAI2Dialog: @@ -19,20 +22,28 @@ class ConvAI2Dataset(Dataset): _YOUR_PERSONA_PREFIX = "your persona: " _PARTNER_PERSONA_PREFIX = "partner's persona: " - def __init__(self, path, tokenizer_name, max_context_len, max_target_len=None, have_candidates=True): + def __init__(self, path, tokenizer_name, max_context_len, max_target_len=None, max_condition_len=None): self.dataset = [] self.num_dialogs = 0 self.context_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, truncation_side="left") self.candidate_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - - bos = self.context_tokenizer.bos_token - eos = self.context_tokenizer.eos_token + self.tokenizer_kwargs = { + "padding": "max_length", + "truncation": True, + "return_tensors": "pt", + "add_special_tokens": False, + } + preprocessor = Preprocessor(tokenizer_name) + bos = preprocessor.bos + eos = preprocessor.eos + self.sep = preprocessor.sep self.max_context_len = max_context_len self.max_target_len = max_target_len or max_context_len + self.max_condition_len = max_condition_len or max_context_len - self.have_candidates = have_candidates + self.have_candidates = not "no_cands" in path self.vocab_size = self.context_tokenizer.vocab_size logger.info(f"Loading dataset from '{path}'") @@ -53,7 +64,7 @@ def __init__(self, path, tokenizer_name, max_context_len, max_target_len=None, h partner_persona.append(line[len(self._PARTNER_PERSONA_PREFIX) :]) continue - if have_candidates: + if self.have_candidates: utterance1, utterance2, _, candidates_str = line.split("\t") else: utterance1, utterance2, *_ = line.split("\t") @@ -75,18 +86,20 @@ def __len__(self): def __getitem__(self, idx: int) -> ConvAI2Dialog: return self.dataset[idx] - def collate_fn(self, samples: list[ConvAI2Dialog], return_all_candidates: bool = False): + def collate_fn( + self, + samples: list[ConvAI2Dialog], + return_all_candidates: bool = False, + return_my_persona: bool = False, + return_partner_persona: bool = False, + ): + output = {} return_all_candidates = self.have_candidates & return_all_candidates str_contexts = [" ".join(sample.context) for sample in samples] # [batch size, context seq len] - b_contexts = self.context_tokenizer( - str_contexts, - max_length=self.max_context_len, - padding=True, - truncation=True, - return_tensors="pt", - add_special_tokens=False, - ).input_ids + output["context"] = self.context_tokenizer( + str_contexts, max_length=self.max_context_len, **self.tokenizer_kwargs + ) if return_all_candidates: str_candidates = [it for sample in samples for it in sample.candidates] @@ -94,11 +107,26 @@ def collate_fn(self, samples: list[ConvAI2Dialog], return_all_candidates: bool = str_candidates = [sample.candidates[0] for sample in samples] # Tokenizer truncates on the left, but for candidates we want to truncate on the right - b_candidates = self.candidate_tokenizer( - str_candidates, padding="max_length", return_tensors="pt", add_special_tokens=False - ).input_ids - b_candidates = b_candidates[:, : self.max_target_len] - # [batch size, # candidates, candidates seq len] - b_candidates = b_candidates.view(len(samples), -1, b_candidates.size(1)) - - return b_contexts, b_candidates.squeeze(1) + b_candidates = self.candidate_tokenizer(str_candidates, max_length=self.max_target_len, **self.tokenizer_kwargs) + + if return_all_candidates: + # [batch size, # candidates, candidates seq len] + b_candidates = b_candidates.view(len(samples), -1, b_candidates.size(1)) + output["candidates"] = b_candidates.squeeze(1) + else: + output["target"] = b_candidates + + my_personas, partner_personas = [], [] + if return_my_persona: + my_personas = [" ".join(sample.my_persona) for sample in samples] # type: ignore + if return_partner_persona: + partner_personas = [" ".join(sample.partner_persona) for sample in samples] # type: ignore + + conditions = [] + for my_persona, partner_persona in zip_longest(my_personas, partner_personas, fillvalue=""): + conditions.append(my_persona + self.sep + partner_persona) + output["condition"] = self.candidate_tokenizer( + conditions, max_length=self.max_condition_len, **self.tokenizer_kwargs + ) + + return output diff --git a/src/diffusion/model.py b/src/diffusion/model.py index 8a17b34..fc12cd2 100644 --- a/src/diffusion/model.py +++ b/src/diffusion/model.py @@ -1,17 +1,16 @@ -from functools import partial +from enum import Enum import torch -from enum import Enum from lightning import LightningModule -from math import sqrt from torch import nn from transformers import AutoModel from transformers import BertConfig from transformers import T5EncoderModel -from src.diffusion.utils import configure_schedule, get_x0, get_diffusion_variables, scale_input +from src.diffusion.utils import configure_schedule, get_diffusion_variables, get_x0 from src.metrics import calculate_batch_ce -from src.sampling import sample +from src.pipeline.sampling import sample +from src.pipeline.utils import freeze_params, get_cached_content, get_optimizers, calculate_train_step from src.utils import zero_rank_info @@ -54,15 +53,6 @@ def get_components(name: str, **model_kwargs): return encoder, decoder, enc_dim, decoder_config.hidden_size -def freeze_params(model): - for parameter in model.parameters(): - parameter.requires_grad = False - - -def flat_mean(tensor): - return tensor.mean(dim=list(range(1, len(tensor.shape)))) - - class DiDi(LightningModule): def __init__( self, @@ -122,13 +112,7 @@ def __init__( self.batch_decoder = batch_decoder def configure_optimizers(self): - optimizer = torch.optim.AdamW(self.parameters(), lr=1.0) # Fully control LR from scheduler - scheduler_lambda = partial(rsqrt_with_warmup, max_lr=self.lr, min_lr=self.min_lr, warmup=self.warmup) - lr_scheduler_config = { - "scheduler": torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda), - "interval": "step", - } - return [optimizer], [lr_scheduler_config] + return get_optimizers(self.parameters(), self.lr, self.warmup, self.min_lr) def train(self, mode: bool = True): super().train(mode) @@ -147,13 +131,7 @@ def forward( raise ValueError("Either `encoder_input_ids` or `context` must be provided.") if context is None: - with torch.no_grad(): - context = self.encoder( - input_ids=encoder_input_ids, attention_mask=encoder_attention_mask - ).last_hidden_state - - if self.encoder_dim != self.decoder_dim: - context = self.adapter(context) + context = get_cached_content(self, encoder_input_ids, encoder_attention_mask) time_embeds = self.time_embeds(time_ids) input_embeds = decoder_inputs_embeds + time_embeds @@ -181,27 +159,7 @@ def training_step(self, batch: list, batch_idx: int): time_ids=t, ) # [batch size; seq len; emb dim] - logits = self.classifier(x_0) # [batch size; seq len; vocab size] - ce = calculate_batch_ce(logits, target.input_ids, target.attention_mask) - - non_pad_mask = target.attention_mask.unsqueeze(-1) - mse = torch.where( - t == 1, - flat_mean((x_0_hat - emb) ** 2 * non_pad_mask), - flat_mean((x_0_hat - x_0) ** 2 * non_pad_mask), - ).mean() - - noise, sigma_T = torch.randn_like(x_0), self.sigmas[-1] - x_T = scale_input(x_0 + sigma_T * noise, sigma_T) - t_T_loss = (x_T**2 * non_pad_mask).mean() - - loss = mse + ce + t_T_loss - - with torch.no_grad(): - logits_hat = self.classifier(x_0_hat) - ce_hat = calculate_batch_ce(logits_hat, target.input_ids, target.attention_mask) - - metrics = {"train/mse": mse, "train/ce": ce, "train/t_T": t_T_loss, "train/loss": loss, "train/ce_hat": ce_hat} + loss, metrics = calculate_train_step(self, emb, x_0, x_0_hat, target, t) self.log_dict(metrics, sync_dist=True, on_step=True, on_epoch=False) return loss @@ -229,32 +187,3 @@ def on_validation_epoch_end(self): self.log_dict(metrics, sync_dist=True, on_step=False, on_epoch=True) self.val_ce.clear() self.val_acc.clear() - - -def rsqrt_with_warmup(step: int, max_lr: float, min_lr: float, warmup: int) -> float: - """Scheduler for learning rate with a form of reverse sqrt (known as Noam favorite scheduler): - `lr_t = max_lr * sqrt(1 / t)` - - Warm-up increases learning rate from 0 with square root form and then smoothly decay with reverse square root. - `lr_t = max_lr * sqrt(t / warmup)` if t <= warmup - `lr_t = max_lr * sqrt(warmup / t)` if t > warmup - - Also, there is control of minimum learning rate - - :param step: current step - :param max_lr: maximum learning rate - :param min_lr: minimum learning rate - :param warmup: number of warmup steps - :return: next learning rate - """ - if warmup != 0 and step < warmup: - return max_lr * sqrt(step / warmup) - - if warmup == 0: - lr = max_lr * sqrt(1 / step) - else: - lr = max_lr * sqrt(warmup / step) - - if min_lr is not None: - lr = max(lr, min_lr) - return lr diff --git a/src/pipeline/__init__.py b/src/pipeline/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/sampling.py b/src/pipeline/sampling.py similarity index 100% rename from src/sampling.py rename to src/pipeline/sampling.py diff --git a/src/training.py b/src/pipeline/training.py similarity index 100% rename from src/training.py rename to src/pipeline/training.py diff --git a/src/pipeline/utils.py b/src/pipeline/utils.py new file mode 100644 index 0000000..6040475 --- /dev/null +++ b/src/pipeline/utils.py @@ -0,0 +1,89 @@ +from functools import partial + +import torch +from math import sqrt + +from src.diffusion.utils import scale_input +from src.metrics import calculate_batch_ce + + +def freeze_params(model): + for parameter in model.parameters(): + parameter.requires_grad = False + + +def flat_mean(tensor: torch.Tensor) -> torch.Tensor: + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def rsqrt_with_warmup(step: int, max_lr: float, min_lr: float, warmup: int) -> float: + """Scheduler for learning rate with a form of reverse sqrt (known as Noam favorite scheduler): + `lr_t = max_lr * sqrt(1 / t)` + + Warm-up increases learning rate from 0 with square root form and then smoothly decay with reverse square root. + `lr_t = max_lr * sqrt(t / warmup)` if t <= warmup + `lr_t = max_lr * sqrt(warmup / t)` if t > warmup + + Also, there is control of minimum learning rate + + :param step: current step + :param max_lr: maximum learning rate + :param min_lr: minimum learning rate + :param warmup: number of warmup steps + :return: next learning rate + """ + if warmup != 0 and step < warmup: + return max_lr * sqrt(step / warmup) + + if warmup == 0: + lr = max_lr * sqrt(1 / step) + else: + lr = max_lr * sqrt(warmup / step) + + if min_lr is not None: + lr = max(lr, min_lr) + return lr + + +def get_cached_content(model, encoder_input_ids, encoder_attention_mask): + with torch.no_grad(): + context = model.encoder(input_ids=encoder_input_ids, attention_mask=encoder_attention_mask).last_hidden_state + + if model.encoder_dim != model.decoder_dim: + context = model.adapter(context) + return context + + +def get_optimizers(parameters, lr, warmup, min_lr): + optimizer = torch.optim.AdamW(parameters, lr=1.0) # Fully control LR from scheduler + scheduler_lambda = partial(rsqrt_with_warmup, max_lr=lr, min_lr=min_lr, warmup=warmup) + lr_scheduler_config = { + "scheduler": torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda), + "interval": "step", + } + return [optimizer], [lr_scheduler_config] + + +def calculate_train_step(model, emb, x_0, x_0_hat, target, t): + logits = model.classifier(x_0) # [batch size; seq len; vocab size] + ce = calculate_batch_ce(logits, target.input_ids, target.attention_mask) + + non_pad_mask = target.attention_mask.unsqueeze(-1) + mse = torch.where( + t == 1, + flat_mean((x_0_hat - emb) ** 2 * non_pad_mask), + flat_mean((x_0_hat - x_0) ** 2 * non_pad_mask), + ).mean() + + noise, sigma_T = torch.randn_like(x_0), model.sigmas[-1] + x_T = scale_input(x_0 + sigma_T * noise, sigma_T) + t_T_loss = (x_T**2 * non_pad_mask).mean() + + loss = mse + ce + t_T_loss + + with torch.no_grad(): + logits_hat = model.classifier(x_0_hat) + ce_hat = calculate_batch_ce(logits_hat, target.input_ids, target.attention_mask) + + metrics = {"train/mse": mse, "train/ce": ce, "train/t_T": t_T_loss, "train/loss": loss, "train/ce_hat": ce_hat} + return loss, metrics