From 523d51b6365d0a48ee824db479b10130f487bc4a Mon Sep 17 00:00:00 2001 From: pkseniya Date: Thu, 13 Jul 2023 19:05:09 +0300 Subject: [PATCH 01/10] Add adapter model for conditioning + refactoring --- scripts/sample.py | 2 +- scripts/train.py | 2 +- scripts/train_adapter.py | 73 ++++++++++++++++++++++++ src/conditioning/__init__.py | 0 src/conditioning/adapter.py | 101 +++++++++++++++++++++++++++++++++ src/data/convai2_dataset.py | 72 ++++++++++++++++------- src/diffusion/model.py | 86 +++------------------------- src/pipeline/__init__.py | 0 src/{ => pipeline}/sampling.py | 0 src/{ => pipeline}/training.py | 0 src/pipeline/utils.py | 91 +++++++++++++++++++++++++++++ 11 files changed, 324 insertions(+), 103 deletions(-) create mode 100644 scripts/train_adapter.py create mode 100644 src/conditioning/__init__.py create mode 100644 src/conditioning/adapter.py create mode 100644 src/pipeline/__init__.py rename src/{ => pipeline}/sampling.py (100%) rename src/{ => pipeline}/training.py (100%) create mode 100644 src/pipeline/utils.py diff --git a/scripts/sample.py b/scripts/sample.py index ce3df41..c363a1f 100644 --- a/scripts/sample.py +++ b/scripts/sample.py @@ -12,7 +12,7 @@ from src.data.commonsense_dataset import CommonSenseDataset from src.data.utils import Preprocessor from src.diffusion.model import DiDi, get_components -from src.sampling import sample +from src.pipeline.sampling import sample def configure_arg_parser(): diff --git a/scripts/train.py b/scripts/train.py index 7f3ce9f..bf15d4e 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -12,7 +12,7 @@ from src.data.reddit_dataset import RedditDataset from src.diffusion.model import DiDi from src.diffusion.model import get_components -from src.training import train_model +from src.pipeline.training import train_model from src.utils import filter_warnings, setup_logger, zero_rank_info diff --git a/scripts/train_adapter.py b/scripts/train_adapter.py new file mode 100644 index 0000000..1c867d2 --- /dev/null +++ b/scripts/train_adapter.py @@ -0,0 +1,73 @@ +import argparse +from os import environ +from os.path import join + +import torch +from omegaconf import OmegaConf +from torch.utils.data import DataLoader + +from src.conditioning.adapter import Adapter +from src.data.convai2_dataset import ConvAI2Dataset +from src.diffusion.model import DiDi +from src.pipeline.training import train_model +from src.utils import filter_warnings, setup_logger, zero_rank_info + + +def configure_arg_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("config_path", type=str, help="Path to YAML config file") + parser.add_argument("dataset_dir", type=str, help="Path to dataset directory") + parser.add_argument("model_path", type=str, help="Path to DiDi model") + parser.add_argument("--condition", type=str, default="other", help="Type of persona") + return parser + + +def main(config_path: str, dataset_dir: str, model_path: str, condition: str): + filter_warnings() + setup_logger() + environ["TOKENIZERS_PARALLELISM"] = "false" + + torch.set_float32_matmul_precision("high") + + config = OmegaConf.load(config_path) + zero_rank_info(f"Loaded config:\n{OmegaConf.to_yaml(config, resolve=False, sort_keys=True)}") + + train_dataset = ConvAI2Dataset( + join(dataset_dir, f"train_{condition}_revised_no_cands.txt"), config.base_name, **config.dataset + ) + val_dataset = ConvAI2Dataset( + join(dataset_dir, f"valid_{condition}_revised_no_cands.txt"), config.base_name, **config.dataset + ) + + train_dataloader = DataLoader( + train_dataset, + batch_size=config.batch_size, + collate_fn=train_dataset.collate_fn, + pin_memory=True, + num_workers=1, + ) + + val_dataloader = DataLoader( + val_dataset, + batch_size=config.val_batch_size, + collate_fn=val_dataset.collate_fn, + pin_memory=True, + num_workers=1, + ) + + didi = DiDi.load_from_checkpoint(model_path) + model = Adapter(didi) + + train_model( + model, + train_dataloader, + val_dataloader, + config.trainer, + seed=config.seed, + save_interval=config.save_interval, + ) + + +if __name__ == "__main__": + _args = configure_arg_parser().parse_args() + main(**vars(_args)) diff --git a/src/conditioning/__init__.py b/src/conditioning/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/conditioning/adapter.py b/src/conditioning/adapter.py new file mode 100644 index 0000000..10db99e --- /dev/null +++ b/src/conditioning/adapter.py @@ -0,0 +1,101 @@ +import torch +from lightning import LightningModule +from torch import nn + +from src.diffusion.model import DiDi +from src.diffusion.utils import get_diffusion_variables, get_x0 +from src.pipeline.utils import calculate_train_step, freeze_params, get_cached_content, get_optimizers + + +class AdapterBlock(nn.Module): + def __init__(self, input_dim: int, num_heads: int): + super().__init__() + self.attention = nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads) + self.query = nn.Linear(input_dim, input_dim) + self.key = nn.Linear(input_dim, input_dim) + self.value = nn.Linear(input_dim, input_dim) + + def forward(self, hidden_states, encoder_hidden_states, encoder_attention_mask): + query = self.query(hidden_states) + key = self.key(encoder_hidden_states) + value = self.value(encoder_hidden_states) + return self.attention(key, query, value, need_weights=False, key_padding_mask=encoder_attention_mask > 0) + + +class Adapter(LightningModule): + def __init__(self, didi: DiDi, lr: float = 0.001, warmup_steps: int = 1, min_lr: float = None): + super().__init__() + self.didi = didi + freeze_params(self.didi) + + self.decoder_layers = [] + adapter_layers = [] + for layer in didi.decoder.encoder.layer: + self.decoder_layers.append(layer) + adapter_layers.append(AdapterBlock(layer.output.dense.out_features, 1)) + + self.adapter_layers = nn.ModuleList(adapter_layers) + + self.lr, self.warmup, self.min_lr = lr, warmup_steps, min_lr + + def configure_optimizers(self): + return get_optimizers(self) + + def forward( + self, + encoder_input_ids: torch.Tensor = None, + encoder_attention_mask: torch.Tensor = None, + decoder_inputs_embeds: torch.Tensor = None, + condition_input_ids: torch.Tensor = None, + condition_attention_mask: torch.Tensor = None, + time_ids: torch.Tensor = None, + context: torch.Tensor = None, + condition: torch.Tensor = None, + ): + if encoder_input_ids is None and context is None: + raise ValueError("Either `encoder_input_ids` or `context` must be provided.") + + if condition_input_ids is None and condition is None: + raise ValueError("Either `condition_input_ids` or `condition` must be provided.") + + context = context or get_cached_content(self.didi, encoder_input_ids, encoder_attention_mask) + condition = condition or get_cached_content(self.didi, condition_input_ids, condition_attention_mask) + + time_embeds = self.didi.time_embeds(time_ids) + hidden_states = decoder_inputs_embeds + time_embeds + + for decoder_layer, adapter_layer in zip(self.decoder_layers, self.adapter_layers): + output = decoder_layer( + hidden_states=hidden_states, + encoder_hidden_states=context, + encoder_attention_mask=encoder_attention_mask, + )[0] + hidden_states = adapter_layer( + hidden_states=output, + encoder_hidden_states=condition, + encoder_attention_mask=condition_attention_mask, + )[0] + + return hidden_states, context, condition + + def training_step(self, batch: list, batch_idx: int): + raw_context, target, condition = batch + emb = self.didi.emb(target.input_ids) + x_0 = get_x0(emb, self.didi.std_0) + noise = torch.randn_like(x_0) + + # x: [batch size; seq len; emb dim], t: [batch size] + x_t, t = get_diffusion_variables(self.didi.diffusion_steps, x_0, self.didi.sigmas, noise) + + x_0_hat, *_ = self( + encoder_input_ids=raw_context.input_ids, + encoder_attention_mask=raw_context.attention_mask, + decoder_inputs_embeds=x_t, + time_ids=t, + condition_input_ids=condition.input_ids, + condition_attention_mask=condition.attention_mask, + ) # [batch size; seq len; emb dim] + + loss, metrics = calculate_train_step(self.didi, emb, x_0, x_0_hat, target, t) + self.log_dict(metrics, sync_dist=True, on_step=True, on_epoch=False) + return loss diff --git a/src/data/convai2_dataset.py b/src/data/convai2_dataset.py index a0fa8f7..d9d16cc 100644 --- a/src/data/convai2_dataset.py +++ b/src/data/convai2_dataset.py @@ -1,10 +1,12 @@ from dataclasses import dataclass +from enum import Enum from typing import Optional from loguru import logger from torch.utils.data import Dataset from tqdm.auto import tqdm from transformers import AutoTokenizer +from src.data.utils import Preprocessor @dataclass @@ -15,25 +17,48 @@ class ConvAI2Dialog: partner_persona: Optional[list[str]] = None +class Conditions(Enum): + NONE = 0 + YOUR = 1 + PARTNERS = 2 + + +def get_condition(path: str): + if "none" in path: + return Conditions.NONE + elif "self" in path: + return Conditions.YOUR + else: + return Conditions.PARTNERS + + class ConvAI2Dataset(Dataset): _YOUR_PERSONA_PREFIX = "your persona: " _PARTNER_PERSONA_PREFIX = "partner's persona: " - def __init__(self, path, tokenizer_name, max_context_len, max_target_len=None, have_candidates=True): + def __init__(self, path, tokenizer_name, max_context_len, max_target_len=None, max_condition_len=None): self.dataset = [] self.num_dialogs = 0 self.context_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, truncation_side="left") self.candidate_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - - bos = self.context_tokenizer.bos_token - eos = self.context_tokenizer.eos_token + self.tokenizer_kwargs = { + "padding": "max_length", + "truncation": True, + "return_tensors": "pt", + "add_special_tokens": False, + } + preprocessor = Preprocessor(tokenizer_name) + bos = preprocessor.bos + eos = preprocessor.eos self.max_context_len = max_context_len self.max_target_len = max_target_len or max_context_len + self.max_condition_len = max_condition_len or max_context_len - self.have_candidates = have_candidates + self.have_candidates = not "no_cands" in path self.vocab_size = self.context_tokenizer.vocab_size + self.condition = get_condition(path) logger.info(f"Loading dataset from '{path}'") with open(path, "r") as f: @@ -53,7 +78,7 @@ def __init__(self, path, tokenizer_name, max_context_len, max_target_len=None, h partner_persona.append(line[len(self._PARTNER_PERSONA_PREFIX) :]) continue - if have_candidates: + if self.have_candidates: utterance1, utterance2, _, candidates_str = line.split("\t") else: utterance1, utterance2, *_ = line.split("\t") @@ -79,14 +104,13 @@ def collate_fn(self, samples: list[ConvAI2Dialog], return_all_candidates: bool = return_all_candidates = self.have_candidates & return_all_candidates str_contexts = [" ".join(sample.context) for sample in samples] # [batch size, context seq len] - b_contexts = self.context_tokenizer( - str_contexts, - max_length=self.max_context_len, - padding=True, - truncation=True, - return_tensors="pt", - add_special_tokens=False, - ).input_ids + b_contexts = self.context_tokenizer(str_contexts, max_length=self.max_context_len, **self.tokenizer_kwargs) + + str_conditions = [] + if self.condition is Conditions.YOUR: + str_conditions = [" ".join(sample.my_persona) for sample in samples] + elif self.condition is Conditions.PARTNERS: + str_conditions = [" ".join(sample.partner_persona) for sample in samples] if return_all_candidates: str_candidates = [it for sample in samples for it in sample.candidates] @@ -94,11 +118,15 @@ def collate_fn(self, samples: list[ConvAI2Dialog], return_all_candidates: bool = str_candidates = [sample.candidates[0] for sample in samples] # Tokenizer truncates on the left, but for candidates we want to truncate on the right - b_candidates = self.candidate_tokenizer( - str_candidates, padding="max_length", return_tensors="pt", add_special_tokens=False - ).input_ids - b_candidates = b_candidates[:, : self.max_target_len] - # [batch size, # candidates, candidates seq len] - b_candidates = b_candidates.view(len(samples), -1, b_candidates.size(1)) - - return b_contexts, b_candidates.squeeze(1) + b_candidates = self.candidate_tokenizer(str_candidates, max_length=self.max_target_len, **self.tokenizer_kwargs) + # b_candidates = b_candidates[:, : self.max_target_len] + # # [batch size, # candidates, candidates seq len] + # b_candidates = b_candidates.view(len(samples), -1, b_candidates.size(1)) + + if self.condition is Conditions.NONE: + return b_contexts, b_candidates # .squeeze(1) + else: + b_conditions = self.candidate_tokenizer( + str_conditions, max_length=self.max_condition_len, **self.tokenizer_kwargs + ) + return b_contexts, b_candidates, b_conditions # .squeeze(1) diff --git a/src/diffusion/model.py b/src/diffusion/model.py index 8a17b34..e5379e9 100644 --- a/src/diffusion/model.py +++ b/src/diffusion/model.py @@ -1,17 +1,16 @@ -from functools import partial +from enum import Enum import torch -from enum import Enum from lightning import LightningModule -from math import sqrt from torch import nn from transformers import AutoModel from transformers import BertConfig from transformers import T5EncoderModel -from src.diffusion.utils import configure_schedule, get_x0, get_diffusion_variables, scale_input +from src.diffusion.utils import configure_schedule, get_diffusion_variables, get_x0 from src.metrics import calculate_batch_ce -from src.sampling import sample +from src.pipeline.sampling import sample +from src.pipeline.utils import freeze_params, get_cached_content, get_optimizers, calculate_train_step from src.utils import zero_rank_info @@ -54,15 +53,6 @@ def get_components(name: str, **model_kwargs): return encoder, decoder, enc_dim, decoder_config.hidden_size -def freeze_params(model): - for parameter in model.parameters(): - parameter.requires_grad = False - - -def flat_mean(tensor): - return tensor.mean(dim=list(range(1, len(tensor.shape)))) - - class DiDi(LightningModule): def __init__( self, @@ -122,13 +112,7 @@ def __init__( self.batch_decoder = batch_decoder def configure_optimizers(self): - optimizer = torch.optim.AdamW(self.parameters(), lr=1.0) # Fully control LR from scheduler - scheduler_lambda = partial(rsqrt_with_warmup, max_lr=self.lr, min_lr=self.min_lr, warmup=self.warmup) - lr_scheduler_config = { - "scheduler": torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda), - "interval": "step", - } - return [optimizer], [lr_scheduler_config] + return get_optimizers(self) def train(self, mode: bool = True): super().train(mode) @@ -146,14 +130,7 @@ def forward( if encoder_input_ids is None and context is None: raise ValueError("Either `encoder_input_ids` or `context` must be provided.") - if context is None: - with torch.no_grad(): - context = self.encoder( - input_ids=encoder_input_ids, attention_mask=encoder_attention_mask - ).last_hidden_state - - if self.encoder_dim != self.decoder_dim: - context = self.adapter(context) + context = context or get_cached_content(self.didi, encoder_input_ids, encoder_attention_mask) time_embeds = self.time_embeds(time_ids) input_embeds = decoder_inputs_embeds + time_embeds @@ -181,27 +158,7 @@ def training_step(self, batch: list, batch_idx: int): time_ids=t, ) # [batch size; seq len; emb dim] - logits = self.classifier(x_0) # [batch size; seq len; vocab size] - ce = calculate_batch_ce(logits, target.input_ids, target.attention_mask) - - non_pad_mask = target.attention_mask.unsqueeze(-1) - mse = torch.where( - t == 1, - flat_mean((x_0_hat - emb) ** 2 * non_pad_mask), - flat_mean((x_0_hat - x_0) ** 2 * non_pad_mask), - ).mean() - - noise, sigma_T = torch.randn_like(x_0), self.sigmas[-1] - x_T = scale_input(x_0 + sigma_T * noise, sigma_T) - t_T_loss = (x_T**2 * non_pad_mask).mean() - - loss = mse + ce + t_T_loss - - with torch.no_grad(): - logits_hat = self.classifier(x_0_hat) - ce_hat = calculate_batch_ce(logits_hat, target.input_ids, target.attention_mask) - - metrics = {"train/mse": mse, "train/ce": ce, "train/t_T": t_T_loss, "train/loss": loss, "train/ce_hat": ce_hat} + loss, metrics = calculate_train_step(self.didi, emb, x_0, x_0_hat, target, t) self.log_dict(metrics, sync_dist=True, on_step=True, on_epoch=False) return loss @@ -229,32 +186,3 @@ def on_validation_epoch_end(self): self.log_dict(metrics, sync_dist=True, on_step=False, on_epoch=True) self.val_ce.clear() self.val_acc.clear() - - -def rsqrt_with_warmup(step: int, max_lr: float, min_lr: float, warmup: int) -> float: - """Scheduler for learning rate with a form of reverse sqrt (known as Noam favorite scheduler): - `lr_t = max_lr * sqrt(1 / t)` - - Warm-up increases learning rate from 0 with square root form and then smoothly decay with reverse square root. - `lr_t = max_lr * sqrt(t / warmup)` if t <= warmup - `lr_t = max_lr * sqrt(warmup / t)` if t > warmup - - Also, there is control of minimum learning rate - - :param step: current step - :param max_lr: maximum learning rate - :param min_lr: minimum learning rate - :param warmup: number of warmup steps - :return: next learning rate - """ - if warmup != 0 and step < warmup: - return max_lr * sqrt(step / warmup) - - if warmup == 0: - lr = max_lr * sqrt(1 / step) - else: - lr = max_lr * sqrt(warmup / step) - - if min_lr is not None: - lr = max(lr, min_lr) - return lr diff --git a/src/pipeline/__init__.py b/src/pipeline/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/sampling.py b/src/pipeline/sampling.py similarity index 100% rename from src/sampling.py rename to src/pipeline/sampling.py diff --git a/src/training.py b/src/pipeline/training.py similarity index 100% rename from src/training.py rename to src/pipeline/training.py diff --git a/src/pipeline/utils.py b/src/pipeline/utils.py new file mode 100644 index 0000000..f225b50 --- /dev/null +++ b/src/pipeline/utils.py @@ -0,0 +1,91 @@ +from functools import partial + +import torch +from math import sqrt + +from src.diffusion.utils import scale_input +from src.metrics import calculate_batch_ce + + +def freeze_params(model): + for parameter in model.parameters(): + parameter.requires_grad = False + + +def flat_mean(tensor: torch.Tensor) -> torch.Tensor: + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def rsqrt_with_warmup(step: int, max_lr: float, min_lr: float, warmup: int) -> float: + """Scheduler for learning rate with a form of reverse sqrt (known as Noam favorite scheduler): + `lr_t = max_lr * sqrt(1 / t)` + + Warm-up increases learning rate from 0 with square root form and then smoothly decay with reverse square root. + `lr_t = max_lr * sqrt(t / warmup)` if t <= warmup + `lr_t = max_lr * sqrt(warmup / t)` if t > warmup + + Also, there is control of minimum learning rate + + :param step: current step + :param max_lr: maximum learning rate + :param min_lr: minimum learning rate + :param warmup: number of warmup steps + :return: next learning rate + """ + if warmup != 0 and step < warmup: + return max_lr * sqrt(step / warmup) + + if warmup == 0: + lr = max_lr * sqrt(1 / step) + else: + lr = max_lr * sqrt(warmup / step) + + if min_lr is not None: + lr = max(lr, min_lr) + return lr + + +def get_cached_content(model, encoder_input_ids, encoder_attention_mask): + with torch.no_grad(): + context = model.encoder( + input_ids=encoder_input_ids, attention_mask=encoder_attention_mask + ).last_hidden_state + + if model.encoder_dim != model.decoder_dim: + context = model.adapter(context) + return context + + +def get_optimizers(model): + optimizer = torch.optim.AdamW(model.parameters(), lr=1.0) # Fully control LR from scheduler + scheduler_lambda = partial(rsqrt_with_warmup, max_lr=model.lr, min_lr=model.min_lr, warmup=model.warmup) + lr_scheduler_config = { + "scheduler": torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda), + "interval": "step", + } + return [optimizer], [lr_scheduler_config] + + +def calculate_train_step(model, emb, x_0, x_0_hat, target, t): + logits = model.classifier(x_0) # [batch size; seq len; vocab size] + ce = calculate_batch_ce(logits, target.input_ids, target.attention_mask) + + non_pad_mask = target.attention_mask.unsqueeze(-1) + mse = torch.where( + t == 1, + flat_mean((x_0_hat - emb) ** 2 * non_pad_mask), + flat_mean((x_0_hat - x_0) ** 2 * non_pad_mask), + ).mean() + + noise, sigma_T = torch.randn_like(x_0), model.sigmas[-1] + x_T = scale_input(x_0 + sigma_T * noise, sigma_T) + t_T_loss = (x_T ** 2 * non_pad_mask).mean() + + loss = mse + ce + t_T_loss + + with torch.no_grad(): + logits_hat = model.classifier(x_0_hat) + ce_hat = calculate_batch_ce(logits_hat, target.input_ids, target.attention_mask) + + metrics = {"train/mse": mse, "train/ce": ce, "train/t_T": t_T_loss, "train/loss": loss, "train/ce_hat": ce_hat} + return loss, metrics From f69100bb13a06c14ec3404d4cc524efd2c6a631e Mon Sep 17 00:00:00 2001 From: pkseniya Date: Mon, 17 Jul 2023 09:07:00 +0300 Subject: [PATCH 02/10] fix black --- src/pipeline/utils.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/pipeline/utils.py b/src/pipeline/utils.py index f225b50..42cb292 100644 --- a/src/pipeline/utils.py +++ b/src/pipeline/utils.py @@ -47,9 +47,7 @@ def rsqrt_with_warmup(step: int, max_lr: float, min_lr: float, warmup: int) -> f def get_cached_content(model, encoder_input_ids, encoder_attention_mask): with torch.no_grad(): - context = model.encoder( - input_ids=encoder_input_ids, attention_mask=encoder_attention_mask - ).last_hidden_state + context = model.encoder(input_ids=encoder_input_ids, attention_mask=encoder_attention_mask).last_hidden_state if model.encoder_dim != model.decoder_dim: context = model.adapter(context) @@ -57,13 +55,13 @@ def get_cached_content(model, encoder_input_ids, encoder_attention_mask): def get_optimizers(model): - optimizer = torch.optim.AdamW(model.parameters(), lr=1.0) # Fully control LR from scheduler - scheduler_lambda = partial(rsqrt_with_warmup, max_lr=model.lr, min_lr=model.min_lr, warmup=model.warmup) - lr_scheduler_config = { - "scheduler": torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda), - "interval": "step", - } - return [optimizer], [lr_scheduler_config] + optimizer = torch.optim.AdamW(model.parameters(), lr=1.0) # Fully control LR from scheduler + scheduler_lambda = partial(rsqrt_with_warmup, max_lr=model.lr, min_lr=model.min_lr, warmup=model.warmup) + lr_scheduler_config = { + "scheduler": torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda), + "interval": "step", + } + return [optimizer], [lr_scheduler_config] def calculate_train_step(model, emb, x_0, x_0_hat, target, t): @@ -79,7 +77,7 @@ def calculate_train_step(model, emb, x_0, x_0_hat, target, t): noise, sigma_T = torch.randn_like(x_0), model.sigmas[-1] x_T = scale_input(x_0 + sigma_T * noise, sigma_T) - t_T_loss = (x_T ** 2 * non_pad_mask).mean() + t_T_loss = (x_T**2 * non_pad_mask).mean() loss = mse + ce + t_T_loss From 0de05a248b629eb11ae6714bfaf310f316e8bbe2 Mon Sep 17 00:00:00 2001 From: pkseniya Date: Mon, 17 Jul 2023 09:24:05 +0300 Subject: [PATCH 03/10] ignore mypy --- src/data/convai2_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/data/convai2_dataset.py b/src/data/convai2_dataset.py index d9d16cc..a90d571 100644 --- a/src/data/convai2_dataset.py +++ b/src/data/convai2_dataset.py @@ -108,9 +108,9 @@ def collate_fn(self, samples: list[ConvAI2Dialog], return_all_candidates: bool = str_conditions = [] if self.condition is Conditions.YOUR: - str_conditions = [" ".join(sample.my_persona) for sample in samples] + str_conditions = [" ".join(sample.my_persona) for sample in samples] # type: ignore elif self.condition is Conditions.PARTNERS: - str_conditions = [" ".join(sample.partner_persona) for sample in samples] + str_conditions = [" ".join(sample.partner_persona) for sample in samples] # type: ignore if return_all_candidates: str_candidates = [it for sample in samples for it in sample.candidates] From 16cd998f6bc50e552031ff69f2b95f7fad6325ca Mon Sep 17 00:00:00 2001 From: pkseniya Date: Wed, 19 Jul 2023 11:21:32 +0300 Subject: [PATCH 04/10] use xformers attention --- src/conditioning/adapter.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/conditioning/adapter.py b/src/conditioning/adapter.py index 10db99e..1cd0ffa 100644 --- a/src/conditioning/adapter.py +++ b/src/conditioning/adapter.py @@ -1,6 +1,8 @@ import torch from lightning import LightningModule from torch import nn +from xformers.components.attention import ScaledDotProduct +from xformers.components.attention.utils import maybe_merge_masks from src.diffusion.model import DiDi from src.diffusion.utils import get_diffusion_variables, get_x0 @@ -8,18 +10,26 @@ class AdapterBlock(nn.Module): - def __init__(self, input_dim: int, num_heads: int): + def __init__(self, input_dim: int, num_heads: int = 4): super().__init__() - self.attention = nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads) + self.attention = ScaledDotProduct() self.query = nn.Linear(input_dim, input_dim) self.key = nn.Linear(input_dim, input_dim) self.value = nn.Linear(input_dim, input_dim) + self.num_heads = num_heads def forward(self, hidden_states, encoder_hidden_states, encoder_attention_mask): query = self.query(hidden_states) key = self.key(encoder_hidden_states) value = self.value(encoder_hidden_states) - return self.attention(key, query, value, need_weights=False, key_padding_mask=encoder_attention_mask > 0) + mask = maybe_merge_masks( + None, + encoder_attention_mask.bool(), + *encoder_hidden_states.shape[:2], + self.num_heads, + hidden_states.shape[1] + ) + return self.attention(query, key, value, mask=mask, num_heads=self.num_heads) class Adapter(LightningModule): @@ -32,7 +42,7 @@ def __init__(self, didi: DiDi, lr: float = 0.001, warmup_steps: int = 1, min_lr: adapter_layers = [] for layer in didi.decoder.encoder.layer: self.decoder_layers.append(layer) - adapter_layers.append(AdapterBlock(layer.output.dense.out_features, 1)) + adapter_layers.append(AdapterBlock(layer.output.dense.out_features)) self.adapter_layers = nn.ModuleList(adapter_layers) @@ -74,7 +84,7 @@ def forward( hidden_states=output, encoder_hidden_states=condition, encoder_attention_mask=condition_attention_mask, - )[0] + ) return hidden_states, context, condition From 198c8a5fbb138c874139e4d6d0ccc3d3a1e104b0 Mon Sep 17 00:00:00 2001 From: pkseniya Date: Wed, 19 Jul 2023 11:28:42 +0300 Subject: [PATCH 05/10] update optimizers --- src/conditioning/adapter.py | 2 +- src/diffusion/model.py | 2 +- src/pipeline/utils.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/conditioning/adapter.py b/src/conditioning/adapter.py index 1cd0ffa..a21315c 100644 --- a/src/conditioning/adapter.py +++ b/src/conditioning/adapter.py @@ -49,7 +49,7 @@ def __init__(self, didi: DiDi, lr: float = 0.001, warmup_steps: int = 1, min_lr: self.lr, self.warmup, self.min_lr = lr, warmup_steps, min_lr def configure_optimizers(self): - return get_optimizers(self) + return get_optimizers(self.parameters(), self.lr, self.warmup, self.min_lr) def forward( self, diff --git a/src/diffusion/model.py b/src/diffusion/model.py index e5379e9..7b75889 100644 --- a/src/diffusion/model.py +++ b/src/diffusion/model.py @@ -112,7 +112,7 @@ def __init__( self.batch_decoder = batch_decoder def configure_optimizers(self): - return get_optimizers(self) + return get_optimizers(self.parameters(), self.lr, self.warmup, self.min_lr) def train(self, mode: bool = True): super().train(mode) diff --git a/src/pipeline/utils.py b/src/pipeline/utils.py index 42cb292..6040475 100644 --- a/src/pipeline/utils.py +++ b/src/pipeline/utils.py @@ -54,9 +54,9 @@ def get_cached_content(model, encoder_input_ids, encoder_attention_mask): return context -def get_optimizers(model): - optimizer = torch.optim.AdamW(model.parameters(), lr=1.0) # Fully control LR from scheduler - scheduler_lambda = partial(rsqrt_with_warmup, max_lr=model.lr, min_lr=model.min_lr, warmup=model.warmup) +def get_optimizers(parameters, lr, warmup, min_lr): + optimizer = torch.optim.AdamW(parameters, lr=1.0) # Fully control LR from scheduler + scheduler_lambda = partial(rsqrt_with_warmup, max_lr=lr, min_lr=min_lr, warmup=warmup) lr_scheduler_config = { "scheduler": torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda), "interval": "step", From 5f8044a3a77242dedcef59a9f536ac9b8b6e9d3e Mon Sep 17 00:00:00 2001 From: pkseniya Date: Mon, 24 Jul 2023 11:34:20 +0300 Subject: [PATCH 06/10] update attention + dataset --- src/conditioning/adapter.py | 34 +++++++++++++++++++++------------- src/data/convai2_dataset.py | 31 +++++++++++++++++-------------- 2 files changed, 38 insertions(+), 27 deletions(-) diff --git a/src/conditioning/adapter.py b/src/conditioning/adapter.py index a21315c..9f9c8ca 100644 --- a/src/conditioning/adapter.py +++ b/src/conditioning/adapter.py @@ -1,8 +1,8 @@ import torch from lightning import LightningModule from torch import nn -from xformers.components.attention import ScaledDotProduct from xformers.components.attention.utils import maybe_merge_masks +from xformers.ops import memory_efficient_attention from src.diffusion.model import DiDi from src.diffusion.utils import get_diffusion_variables, get_x0 @@ -12,24 +12,32 @@ class AdapterBlock(nn.Module): def __init__(self, input_dim: int, num_heads: int = 4): super().__init__() - self.attention = ScaledDotProduct() + self.head_dim = input_dim // num_heads + self.num_heads = num_heads + + self.attention = memory_efficient_attention self.query = nn.Linear(input_dim, input_dim) self.key = nn.Linear(input_dim, input_dim) self.value = nn.Linear(input_dim, input_dim) - self.num_heads = num_heads + self.out = nn.Linear(input_dim, input_dim) + + def split_heads(self, x, batch_size): + x = x.view(batch_size, -1, self.num_heads, self.head_dim) + return x def forward(self, hidden_states, encoder_hidden_states, encoder_attention_mask): - query = self.query(hidden_states) - key = self.key(encoder_hidden_states) - value = self.value(encoder_hidden_states) + batch_size, trg_len, emb_dim = hidden_states.size() + src_len = encoder_hidden_states.shape[1] + + query = self.split_heads(self.query(hidden_states), batch_size) + key = self.split_heads(self.key(encoder_hidden_states), batch_size) + value = self.split_heads(self.value(encoder_hidden_states), batch_size) + mask = maybe_merge_masks( - None, - encoder_attention_mask.bool(), - *encoder_hidden_states.shape[:2], - self.num_heads, - hidden_states.shape[1] - ) - return self.attention(query, key, value, mask=mask, num_heads=self.num_heads) + None, encoder_attention_mask.bool(), batch_size, src_len, self.num_heads, trg_len + ).view(batch_size, self.num_heads, src_len, trg_len) + float_mask = torch.where(mask, 0, float("-inf")) + return self.out(self.attention(query, key, value, attn_bias=float_mask).view(batch_size, trg_len, emb_dim)) class Adapter(LightningModule): diff --git a/src/data/convai2_dataset.py b/src/data/convai2_dataset.py index a90d571..027fc3c 100644 --- a/src/data/convai2_dataset.py +++ b/src/data/convai2_dataset.py @@ -100,18 +100,16 @@ def __len__(self): def __getitem__(self, idx: int) -> ConvAI2Dialog: return self.dataset[idx] - def collate_fn(self, samples: list[ConvAI2Dialog], return_all_candidates: bool = False): + def collate_fn(self, + samples: list[ConvAI2Dialog], + return_all_candidates: bool = False, + return_condition: bool = False, + ): return_all_candidates = self.have_candidates & return_all_candidates str_contexts = [" ".join(sample.context) for sample in samples] # [batch size, context seq len] b_contexts = self.context_tokenizer(str_contexts, max_length=self.max_context_len, **self.tokenizer_kwargs) - str_conditions = [] - if self.condition is Conditions.YOUR: - str_conditions = [" ".join(sample.my_persona) for sample in samples] # type: ignore - elif self.condition is Conditions.PARTNERS: - str_conditions = [" ".join(sample.partner_persona) for sample in samples] # type: ignore - if return_all_candidates: str_candidates = [it for sample in samples for it in sample.candidates] else: @@ -119,14 +117,19 @@ def collate_fn(self, samples: list[ConvAI2Dialog], return_all_candidates: bool = # Tokenizer truncates on the left, but for candidates we want to truncate on the right b_candidates = self.candidate_tokenizer(str_candidates, max_length=self.max_target_len, **self.tokenizer_kwargs) - # b_candidates = b_candidates[:, : self.max_target_len] - # # [batch size, # candidates, candidates seq len] - # b_candidates = b_candidates.view(len(samples), -1, b_candidates.size(1)) - if self.condition is Conditions.NONE: - return b_contexts, b_candidates # .squeeze(1) - else: + if return_condition: + str_conditions = [] + if self.condition is Conditions.YOUR: + str_conditions = [" ".join(sample.my_persona) for sample in samples] # type: ignore + elif self.condition is Conditions.PARTNERS: + str_conditions = [" ".join(sample.partner_persona) for sample in samples] # type: ignore + b_conditions = self.candidate_tokenizer( str_conditions, max_length=self.max_condition_len, **self.tokenizer_kwargs ) - return b_contexts, b_candidates, b_conditions # .squeeze(1) + return b_contexts, b_candidates, b_conditions + else: + # [batch size, # candidates, candidates seq len] + b_candidates = b_candidates.view(len(samples), -1, b_candidates.size(1)) + return b_contexts, b_candidates.squeeze(1) From f9fb4b9a3bd932b93194ebaaad6acfaca8448762 Mon Sep 17 00:00:00 2001 From: pkseniya Date: Mon, 24 Jul 2023 11:40:44 +0300 Subject: [PATCH 07/10] fix black --- src/data/convai2_dataset.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/data/convai2_dataset.py b/src/data/convai2_dataset.py index 027fc3c..f1de538 100644 --- a/src/data/convai2_dataset.py +++ b/src/data/convai2_dataset.py @@ -100,11 +100,12 @@ def __len__(self): def __getitem__(self, idx: int) -> ConvAI2Dialog: return self.dataset[idx] - def collate_fn(self, - samples: list[ConvAI2Dialog], - return_all_candidates: bool = False, - return_condition: bool = False, - ): + def collate_fn( + self, + samples: list[ConvAI2Dialog], + return_all_candidates: bool = False, + return_condition: bool = False, + ): return_all_candidates = self.have_candidates & return_all_candidates str_contexts = [" ".join(sample.context) for sample in samples] # [batch size, context seq len] From 849f449e036f54e444a685fc6841ead2b5e26ac9 Mon Sep 17 00:00:00 2001 From: pkseniya Date: Mon, 24 Jul 2023 11:47:10 +0300 Subject: [PATCH 08/10] fix black second time --- src/data/convai2_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/data/convai2_dataset.py b/src/data/convai2_dataset.py index f1de538..b8df858 100644 --- a/src/data/convai2_dataset.py +++ b/src/data/convai2_dataset.py @@ -101,10 +101,10 @@ def __getitem__(self, idx: int) -> ConvAI2Dialog: return self.dataset[idx] def collate_fn( - self, - samples: list[ConvAI2Dialog], - return_all_candidates: bool = False, - return_condition: bool = False, + self, + samples: list[ConvAI2Dialog], + return_all_candidates: bool = False, + return_condition: bool = False, ): return_all_candidates = self.have_candidates & return_all_candidates str_contexts = [" ".join(sample.context) for sample in samples] From 6a7fcda5dc7e831d9ced22c427fc910e83dc1703 Mon Sep 17 00:00:00 2001 From: pkseniya Date: Mon, 31 Jul 2023 18:11:14 +0300 Subject: [PATCH 09/10] simplify code --- src/conditioning/adapter.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/conditioning/adapter.py b/src/conditioning/adapter.py index 9f9c8ca..fdb53b5 100644 --- a/src/conditioning/adapter.py +++ b/src/conditioning/adapter.py @@ -15,7 +15,6 @@ def __init__(self, input_dim: int, num_heads: int = 4): self.head_dim = input_dim // num_heads self.num_heads = num_heads - self.attention = memory_efficient_attention self.query = nn.Linear(input_dim, input_dim) self.key = nn.Linear(input_dim, input_dim) self.value = nn.Linear(input_dim, input_dim) @@ -33,11 +32,12 @@ def forward(self, hidden_states, encoder_hidden_states, encoder_attention_mask): key = self.split_heads(self.key(encoder_hidden_states), batch_size) value = self.split_heads(self.value(encoder_hidden_states), batch_size) - mask = maybe_merge_masks( - None, encoder_attention_mask.bool(), batch_size, src_len, self.num_heads, trg_len - ).view(batch_size, self.num_heads, src_len, trg_len) + mask = maybe_merge_masks(None, encoder_attention_mask.bool(), batch_size, src_len, self.num_heads, trg_len) + mask = mask.view(batch_size, self.num_heads, src_len, trg_len) float_mask = torch.where(mask, 0, float("-inf")) - return self.out(self.attention(query, key, value, attn_bias=float_mask).view(batch_size, trg_len, emb_dim)) + attn_result = memory_efficient_attention(query, key, value, attn_bias=float_mask) + output = self.out(attn_result.view(batch_size, trg_len, emb_dim)) + return output class Adapter(LightningModule): From 8398864df5fd47de377123981291673c61e2ffa7 Mon Sep 17 00:00:00 2001 From: pkseniya Date: Tue, 1 Aug 2023 21:59:21 +0300 Subject: [PATCH 10/10] update adapter dataset --- scripts/train_adapter.py | 19 ++++++++---- src/conditioning/adapter.py | 4 +-- src/data/convai2_dataset.py | 60 +++++++++++++++++-------------------- src/diffusion/model.py | 5 ++-- 4 files changed, 46 insertions(+), 42 deletions(-) diff --git a/scripts/train_adapter.py b/scripts/train_adapter.py index 1c867d2..210734c 100644 --- a/scripts/train_adapter.py +++ b/scripts/train_adapter.py @@ -3,6 +3,7 @@ from os.path import join import torch +from functools import partial from omegaconf import OmegaConf from torch.utils.data import DataLoader @@ -18,11 +19,10 @@ def configure_arg_parser(): parser.add_argument("config_path", type=str, help="Path to YAML config file") parser.add_argument("dataset_dir", type=str, help="Path to dataset directory") parser.add_argument("model_path", type=str, help="Path to DiDi model") - parser.add_argument("--condition", type=str, default="other", help="Type of persona") return parser -def main(config_path: str, dataset_dir: str, model_path: str, condition: str): +def main(config_path: str, dataset_dir: str, model_path: str): filter_warnings() setup_logger() environ["TOKENIZERS_PARALLELISM"] = "false" @@ -33,16 +33,23 @@ def main(config_path: str, dataset_dir: str, model_path: str, condition: str): zero_rank_info(f"Loaded config:\n{OmegaConf.to_yaml(config, resolve=False, sort_keys=True)}") train_dataset = ConvAI2Dataset( - join(dataset_dir, f"train_{condition}_revised_no_cands.txt"), config.base_name, **config.dataset + join(dataset_dir, f"train_both_revised_no_cands.txt"), config.base_name, **config.dataset ) val_dataset = ConvAI2Dataset( - join(dataset_dir, f"valid_{condition}_revised_no_cands.txt"), config.base_name, **config.dataset + join(dataset_dir, f"valid_both_revised_no_cands.txt"), config.base_name, **config.dataset + ) + + train_collate_fn = partial( + train_dataset.collate_fn, return_partner_persona=config.persona.partner, return_my_persona=config.persona.my + ) + val_collate_fn = partial( + val_dataset.collate_fn, return_partner_persona=config.persona.partner, return_my_persona=config.persona.my ) train_dataloader = DataLoader( train_dataset, batch_size=config.batch_size, - collate_fn=train_dataset.collate_fn, + collate_fn=train_collate_fn, pin_memory=True, num_workers=1, ) @@ -50,7 +57,7 @@ def main(config_path: str, dataset_dir: str, model_path: str, condition: str): val_dataloader = DataLoader( val_dataset, batch_size=config.val_batch_size, - collate_fn=val_dataset.collate_fn, + collate_fn=val_collate_fn, pin_memory=True, num_workers=1, ) diff --git a/src/conditioning/adapter.py b/src/conditioning/adapter.py index fdb53b5..6412728 100644 --- a/src/conditioning/adapter.py +++ b/src/conditioning/adapter.py @@ -96,8 +96,8 @@ def forward( return hidden_states, context, condition - def training_step(self, batch: list, batch_idx: int): - raw_context, target, condition = batch + def training_step(self, batch: dict, batch_idx: int): + raw_context, target, condition = batch["context"], batch["target"], batch["condition"] emb = self.didi.emb(target.input_ids) x_0 = get_x0(emb, self.didi.std_0) noise = torch.randn_like(x_0) diff --git a/src/data/convai2_dataset.py b/src/data/convai2_dataset.py index b8df858..10dedb7 100644 --- a/src/data/convai2_dataset.py +++ b/src/data/convai2_dataset.py @@ -1,11 +1,12 @@ from dataclasses import dataclass -from enum import Enum +from itertools import zip_longest from typing import Optional from loguru import logger from torch.utils.data import Dataset from tqdm.auto import tqdm from transformers import AutoTokenizer + from src.data.utils import Preprocessor @@ -17,21 +18,6 @@ class ConvAI2Dialog: partner_persona: Optional[list[str]] = None -class Conditions(Enum): - NONE = 0 - YOUR = 1 - PARTNERS = 2 - - -def get_condition(path: str): - if "none" in path: - return Conditions.NONE - elif "self" in path: - return Conditions.YOUR - else: - return Conditions.PARTNERS - - class ConvAI2Dataset(Dataset): _YOUR_PERSONA_PREFIX = "your persona: " _PARTNER_PERSONA_PREFIX = "partner's persona: " @@ -51,6 +37,7 @@ def __init__(self, path, tokenizer_name, max_context_len, max_target_len=None, m preprocessor = Preprocessor(tokenizer_name) bos = preprocessor.bos eos = preprocessor.eos + self.sep = preprocessor.sep self.max_context_len = max_context_len self.max_target_len = max_target_len or max_context_len @@ -58,7 +45,6 @@ def __init__(self, path, tokenizer_name, max_context_len, max_target_len=None, m self.have_candidates = not "no_cands" in path self.vocab_size = self.context_tokenizer.vocab_size - self.condition = get_condition(path) logger.info(f"Loading dataset from '{path}'") with open(path, "r") as f: @@ -104,12 +90,16 @@ def collate_fn( self, samples: list[ConvAI2Dialog], return_all_candidates: bool = False, - return_condition: bool = False, + return_my_persona: bool = False, + return_partner_persona: bool = False, ): + output = {} return_all_candidates = self.have_candidates & return_all_candidates str_contexts = [" ".join(sample.context) for sample in samples] # [batch size, context seq len] - b_contexts = self.context_tokenizer(str_contexts, max_length=self.max_context_len, **self.tokenizer_kwargs) + output["context"] = self.context_tokenizer( + str_contexts, max_length=self.max_context_len, **self.tokenizer_kwargs + ) if return_all_candidates: str_candidates = [it for sample in samples for it in sample.candidates] @@ -119,18 +109,24 @@ def collate_fn( # Tokenizer truncates on the left, but for candidates we want to truncate on the right b_candidates = self.candidate_tokenizer(str_candidates, max_length=self.max_target_len, **self.tokenizer_kwargs) - if return_condition: - str_conditions = [] - if self.condition is Conditions.YOUR: - str_conditions = [" ".join(sample.my_persona) for sample in samples] # type: ignore - elif self.condition is Conditions.PARTNERS: - str_conditions = [" ".join(sample.partner_persona) for sample in samples] # type: ignore - - b_conditions = self.candidate_tokenizer( - str_conditions, max_length=self.max_condition_len, **self.tokenizer_kwargs - ) - return b_contexts, b_candidates, b_conditions - else: + if return_all_candidates: # [batch size, # candidates, candidates seq len] b_candidates = b_candidates.view(len(samples), -1, b_candidates.size(1)) - return b_contexts, b_candidates.squeeze(1) + output["candidates"] = b_candidates.squeeze(1) + else: + output["target"] = b_candidates + + my_personas, partner_personas = [], [] + if return_my_persona: + my_personas = [" ".join(sample.my_persona) for sample in samples] # type: ignore + if return_partner_persona: + partner_personas = [" ".join(sample.partner_persona) for sample in samples] # type: ignore + + conditions = [] + for my_persona, partner_persona in zip_longest(my_personas, partner_personas, fillvalue=""): + conditions.append(my_persona + self.sep + partner_persona) + output["condition"] = self.candidate_tokenizer( + conditions, max_length=self.max_condition_len, **self.tokenizer_kwargs + ) + + return output diff --git a/src/diffusion/model.py b/src/diffusion/model.py index 7b75889..fc12cd2 100644 --- a/src/diffusion/model.py +++ b/src/diffusion/model.py @@ -130,7 +130,8 @@ def forward( if encoder_input_ids is None and context is None: raise ValueError("Either `encoder_input_ids` or `context` must be provided.") - context = context or get_cached_content(self.didi, encoder_input_ids, encoder_attention_mask) + if context is None: + context = get_cached_content(self, encoder_input_ids, encoder_attention_mask) time_embeds = self.time_embeds(time_ids) input_embeds = decoder_inputs_embeds + time_embeds @@ -158,7 +159,7 @@ def training_step(self, batch: list, batch_idx: int): time_ids=t, ) # [batch size; seq len; emb dim] - loss, metrics = calculate_train_step(self.didi, emb, x_0, x_0_hat, target, t) + loss, metrics = calculate_train_step(self, emb, x_0, x_0_hat, target, t) self.log_dict(metrics, sync_dist=True, on_step=True, on_epoch=False) return loss