diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 29efb47..e1b96fb 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -6,8 +6,8 @@ jobs: main: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v3 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 with: python-version: 3.11.6 cache: "pip" @@ -18,7 +18,3 @@ jobs: run: black . --check --diff --color - name: "isort" run: isort . --check --diff - - name: "mypy" - run: mypy - - name: "pytests" - run: pytest diff --git a/.gitignore b/.gitignore index 917b4e5..8e510ff 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ .mypy_cache/ .pytest_cache/ __pycache__/ -.ipynb_checkpoints/ \ No newline at end of file +.ipynb_checkpoints/ +*/.ipynb_checkpoints/ \ No newline at end of file diff --git a/README.md b/README.md index 8b636a1..e6e484d 100644 --- a/README.md +++ b/README.md @@ -1,27 +1,77 @@ -# py_template +# GeRaCl: General Rapid text Classifier -Template repository for Python projects. -Use it to create a new repo, but feel free to adopt for your use-cases. +**GeRaCl** is an open‑source **framework** for building, training, and evaluating efficient zero‑shot text classifiers on top of any BERT‑like sentence-encoder. It is inspired by the [GLiNER](https://github.com/urchade/GLiNER/tree/main) framework. -## Structure +### ✨ Why GeRaCl? -There are several directories to organize your code: -- `src`: Main directory for your modules, e.g., models or dataset implementations, train loops, metrics. -- `scripts`: Directory to define scripts to interact with modules, e.g., run training or evaluation, run data preprocessing, collect statistic. -- `tests`: Directory for tests, this may include multiple unit tests for different parts of logic. +| Feature | What it means for you | +| ------------------------------ | ------------------------------------------------------------------------------------------------- | +| **Zero‑shot by design** | Classify with **arbitrary** label sets that you decide at run‑time — just pass a list of strings. | +| **One forward pass** | As fast as ordinary text classification; no pairwise loops like in NLI‑based approaches. | +| **Model‑agnostic** | Works with any Hugging Face sentence-encoder. | +| **155 M reference checkpoint** | A lean [baseline](https://huggingface.co/deepvk/GeRaCl-USER2-base) (155M parameters) that beats much larger sentence‑encoders (300-500M parameters). | +| **All‑in‑one toolkit** | Training/eval scripts, HF Hub and WandB integration. | -You can create new directories for your need. -For example, you can create a `Notebooks` folder for Jupyter notebooks, such as `EDA.ipynb`. -## Usage +### 🚀 Quick Start -First of all, -navigate to [`pyproject.toml`](./pyproject.toml) and set up `name` and `url` properties according to your project. +Clone and install directly from GitHub: -For correct work of the import system: -1. Use absolute import statements starting from `src`. For example, `from src.model import MySuperModel` -2. Execute scripts as modules, i.e. use `python -m scripts.`. See details about `-m` flag [here](https://docs.python.org/3/using/cmdline.html#cmdoption-m). +```bash +git clone https://github.com/deepvk/zero-shot-classification +cd GeRaCl -To keep your code clean, use `black`, `isort`, and `mypy` -(install everything from [`requirements.dev.txt`](./requirements.dev.txt)). -[`pyproject.toml`](./pyproject.toml) already defines their parameters, but you can change them if you want. +pip install -r requirements.txt +``` + +Verify your installation: + +```python +import geracl +print(geracl.__version__) +``` + +### 🧑‍💻 Usage Examples + +#### Single classification scenario + +```python +from transformers import AutoTokenizer +from geracl import GeraclHF, ZeroShotClassificationPipeline + +model = GeraclHF.from_pretrained('deepvk/GeRaCl-USER2-base').to('cuda').eval() +tokenizer = AutoTokenizer.from_pretrained('deepvk/GeRaCl-USER2-base') + +pipe = ZeroShotClassificationPipeline(model, tokenizer, device="cuda") + +text = "Утилизация катализаторов: как неплохо заработать" +labels = ["экономика", "происшествия", "политика", "культура", "наука", "спорт"] +result = pipe(text, labels, batch_size=1)[0] + +print(labels[result]) +``` + +#### Multiple classification scenarios + +```python +from transformers import AutoTokenizer +from geracl import GeraclHF, ZeroShotClassificationPipeline + +model = GeraclHF.from_pretrained('deepvk/GeRaCl-USER2-base').to('cuda').eval() +tokenizer = AutoTokenizer.from_pretrained('deepvk/GeRaCl-USER2-base') + +pipe = ZeroShotClassificationPipeline(model, tokenizer, device="cuda") + +texts = [ + "Утилизация катализаторов: как неплохо заработать", + "Мне не понравился этот фильм." +] +labels = [ + ["экономика", "происшествия", "политика", "культура", "наука", "спорт"], + ["нейтральный", "позитивный", "негативный"] +] +results = pipe(texts, labels, batch_size=2) + +for i in range(len(labels)): + print(labels[i][results[i]]) +``` diff --git a/geracl/__init__.py b/geracl/__init__.py new file mode 100644 index 0000000..85a0fc2 --- /dev/null +++ b/geracl/__init__.py @@ -0,0 +1,4 @@ +from .model.config import GeraclConfigHF +from .model.geracl import Geracl +from .model.hf_wrapper import GeraclHF +from .pipeline import ZeroShotClassificationPipeline diff --git a/geracl/configs/custom.yaml b/geracl/configs/custom.yaml new file mode 100644 index 0000000..e107cfd --- /dev/null +++ b/geracl/configs/custom.yaml @@ -0,0 +1,43 @@ +model: + embedder_name: "deepvk/USER2-base" + ffn_dim: 2048 + ffn_classes_dropout: 0.4 + ffn_text_dropout: 0.4 + device: "cuda" + unfreeze_embedder: True + loss_args: + loss_type: "bce" +# init_params: +# alpha: +# gamma: + optimizer_args: + class_path: torch.optim.AdamW + init_params: + lr: 0.000005 + weight_decay: 0.1 + scheduler_args: + scheduler: "linear" + total_steps: 35310 + warmup_steps: 1000 + +data_module: + batch_size: 32 + val_batch_size: 32 + tokenizer_name: "deepvk/USER2-base" + config: "real_world_extended_expanded" + model_max_length: 2000 + num_workers: 5 + include_scenarios: False + input_prompt: "classification: " + +trainer: + accelerator: "gpu" + val_check_interval: 7062 + max_epochs: 5 + log_every_n_steps: 100 + # gradient_clip_val: 2.0 + #accumulate_grad_batches: 2 + # overfit_batches: 50 + +other: + checkpoints_dir: "/data/checkpoints/release_user2_base_training" \ No newline at end of file diff --git a/geracl/configs/default.yaml b/geracl/configs/default.yaml new file mode 100644 index 0000000..dfda16a --- /dev/null +++ b/geracl/configs/default.yaml @@ -0,0 +1,28 @@ +model: + embedder_name: "deepvk/USER2-base" + ffn_dim: 2048 + device: "cuda" + unfreeze_embedder: False + pooling_type: "mean" + loss_args: + loss_type: "bce" + +data_module: + batch_size: 16 + val_batch_size: 16 + num_workers: 10 + tokenizer_name: "deepvk/USER2-base" + config: "synthetic_positives_multiclass" + include_scenarios: False + input_prompt: "classification: " + +trainer: +# max_steps: 200 + accelerator: "gpu" + val_check_interval: 1000 + gradient_clip_val: 0.0 + log_every_n_steps: 100 + +other: + wandb_project: "universal_classifier" + checkpoints_dir: "/data/checkpoints" \ No newline at end of file diff --git a/src/configs/sweeps.yaml b/geracl/configs/sweeps.yaml similarity index 95% rename from src/configs/sweeps.yaml rename to geracl/configs/sweeps.yaml index 2e02e8f..7358bba 100644 --- a/src/configs/sweeps.yaml +++ b/geracl/configs/sweeps.yaml @@ -19,7 +19,7 @@ data_module: trainer: # max_steps: 200 accelerator: "gpu" - val_check_interval: 5810 + val_check_interval: 50 gradient_clip_val: 0.0 log_every_n_steps: 50 diff --git a/src/__init__.py b/geracl/data/__init__.py similarity index 100% rename from src/__init__.py rename to geracl/data/__init__.py diff --git a/src/data/data_utils.py b/geracl/data/batch_creation.py similarity index 80% rename from src/data/data_utils.py rename to geracl/data/batch_creation.py index 04ea0e5..e77ce3a 100644 --- a/src/data/data_utils.py +++ b/geracl/data/batch_creation.py @@ -7,7 +7,9 @@ def make_classifier_prompt( input_seq: ndarray, special_token_ids: dict[int], classes_list: list[ndarray], - positive_labels: list[list[int]] = None, + scenario: ndarray = np.array([], dtype=int), + starting_prompt: ndarray = np.array([], dtype=int), + positive_labels: list[int] = None, ) -> tuple[ndarray, ndarray]: if positive_labels: label_mask = [-2] * len(classes_list) @@ -26,10 +28,11 @@ def make_classifier_prompt( for i, (class_name, mask) in enumerate(zip(classes_list, label_mask)) ] ) - result_prompt = np.concatenate( [ [special_token_ids["bos_token"]], + starting_prompt, + scenario, result_prompt, [special_token_ids["sep_token"]], input_seq, @@ -40,9 +43,11 @@ def make_classifier_prompt( extended_label_mask = np.concatenate( [ np.array([-4]), + np.full(len(starting_prompt), -5, dtype=int), + np.full(len(scenario), -5, dtype=int), extended_label_mask, np.array([-4]), - np.full(len(input_seq), -3), + np.full(len(input_seq), -3, dtype=int), np.array([-4]), ] ) @@ -61,7 +66,6 @@ def prepare_batch( max_len = max(len(res_prompt) for res_prompt in result_prompts) if model_max_length is not None: max_len = min(max_len, model_max_length) - input_ids = torch.full((batch_size, max_len), pad_token_id, dtype=torch.long) attention_mask = torch.zeros((batch_size, max_len), dtype=torch.long) classes_mask = torch.full((batch_size, max_len), -4, dtype=torch.long) @@ -85,9 +89,11 @@ def prepare_batch( return input_ids, attention_mask, classes_mask -def prepare_inference_batch(input_texts, classes, tokenizer): +def prepare_inference_batch(input_texts, classes, tokenizer, input_prompt=None): tokenized_texts = tokenizer(input_texts, add_special_tokens=False).input_ids tokenized_classes = [tokenizer(sample_classes, add_special_tokens=False).input_ids for sample_classes in classes] + if input_prompt: + tokenized_prompt = tokenizer(input_prompt, add_special_tokens=False).input_ids result_prompts = [] label_masks = [] @@ -98,9 +104,12 @@ def prepare_inference_batch(input_texts, classes, tokenizer): "sep_token": tokenizer.sep_token_id, "eos_token": tokenizer.eos_token_id, } - - for tokenized_text, tokenized_sample_classes in tokenized_texts: - result_prompt, label_mask = make_classifier_prompt(tokenized_text, special_token_ids, tokenized_sample_classes) + if input_prompt is None: + tokenized_prompt = np.array([], dtype=int) + for tokenized_text, tokenized_sample_classes in zip(tokenized_texts, tokenized_classes): + result_prompt, label_mask = make_classifier_prompt( + tokenized_text, special_token_ids, tokenized_sample_classes, starting_prompt=tokenized_prompt + ) result_prompts.append(result_prompt) label_masks.append(label_mask) diff --git a/geracl/data/data_module.py b/geracl/data/data_module.py new file mode 100644 index 0000000..2472f98 --- /dev/null +++ b/geracl/data/data_module.py @@ -0,0 +1,247 @@ +import json +from functools import partial +from typing import List, Optional, Tuple + +import numpy as np +import torch +from datasets import load_dataset +from loguru import logger +from numpy import ndarray +from pytorch_lightning import LightningDataModule +from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from tokenizers import Tokenizer +from torch.utils.data import DataLoader +from transformers import AutoTokenizer + +from geracl.data.batch_creation import make_classifier_prompt, prepare_batch +from geracl.data.data_processing import choose_synthetic_classes, generate_classes, shuffle_classes +from geracl.data.dataset import ZeroShotClassificationDataset +from geracl.utils import add_required_tokens + + +class ZeroShotClassificationDataModule(LightningDataModule): + """Lightning data module for data handling. + Provides dataloader for all splits, e.g. `train_dataloader` method. + Public methods and attributes allow to retrieve information about texts and their classes. + + And `tokenizers` to tokenize data, e.g. `deepvk/USER-base` + """ + + def __init__( + self, + batch_size: int = 16, + val_batch_size: int = 16, + num_workers: int = 20, + tokenizer_name: str = "deepvk/USER-base", + config: str = "synthetic_positives_multiclass", + model_max_length: int = None, + include_scenarios: bool = False, + input_prompt: str = None, + ): + """Data module constructor. + + :param batch_size: train batch size; + :param val_batch_size: validation and test batch size; + :param num_workers: Number of workers for data loaders; + :param tokenizer_name: name of the tokenizer, "deepvk/USER-base" by default; + :param config: config of the CLAZER HuggingFace dataset; + :param model_max_length: Maximum sequence length of the embedder model; + :param include_scenarios: whether to include classification scenarios in the prompt or not. + Works only when config = "synthetic_classes"; + :param input_prompt: Input prompt to the embedder model, e.g. "classification: ". + """ + super().__init__() + if config not in { + "synthetic_positives_multiclass", + "synthetic_positives_multilabel", + "synthetic_classes", + "ru_mteb_classes", + "ru_mteb_extended_classes", + "real_world_extended_expanded", + }: + raise ValueError("Invalid DataModule config parameter.") + if include_scenarios and config != "synthetic_classes": + raise ValueError( + "Invalid DataModule include_scenarios parameter. It can be true only with the 'synthetic_classes' dataset config." + ) + self._config = config + self._batch_size = batch_size + self._val_batch_size = val_batch_size + self._num_workers = num_workers + self._include_scenarios = include_scenarios + self._tokenizer_name = tokenizer_name + logger.info(f"Downloading and opening '{self._tokenizer_name}' tokenizer") + self._tokenizer = AutoTokenizer.from_pretrained(self._tokenizer_name) + self._tokenizer = add_required_tokens(self._tokenizer) + if model_max_length is not None: + self._tokenizer.model_max_length = model_max_length + self._special_token_ids = { + "bos_token": self._tokenizer.bos_token_id, + "cls_token": self._tokenizer.cls_token_id, + "sep_token": self._tokenizer.sep_token_id, + "eos_token": self._tokenizer.eos_token_id, + } + self._val_test_collate_fn = partial( + self._generic_val_test_collate_fn, include_scenarios=include_scenarios, input_prompt=input_prompt + ) + + if self._config == "synthetic_classes": + self._train_collate_fn = self._synthetic_classes_train_collate_fn + elif self._config in {"synthetic_positives_multiclass", "synthetic_positives_multilabel"}: + self._train_collate_fn = self._synthetic_positives_train_collate_fn + elif self._config == "real_world_extended_expanded": + self._train_collate_fn = self._extended_expanded_train_collate_fn + else: + self._train_collate_fn = self._val_test_collate_fn + + def setup(self, stage: Optional[str] = None): + logger.info("Downloading and opening 'deepvk/synthetic-classes' dataset") + + data = load_dataset("deepvk/synthetic-classes", self._config) + + if self._config == "synthetic_classes": + train_data = load_dataset("deepvk/synthetic-classes", "synthetic_classes_train", split="train") + elif self._config in {"synthetic_positives_multiclass", "synthetic_positives_multilabel"}: + train_data = load_dataset("deepvk/synthetic-classes", "synthetic_positives", split="train") + + self._datasets = {} + self._labels = {} + if self._config not in { + "synthetic_classes", + "synthetic_positives_multiclass", + "synthetic_positives_multilabel", + }: + splits = ["train", "validation", "test"] + else: + splits = ["validation", "test"] + logger.info("Initializing train dataset") + self._datasets["train"] = train_data + + for split in splits: + logger.info(f"Initializing {split} dataset") + self._datasets[split] = ZeroShotClassificationDataset(data[split], self._tokenizer) + + def train_dataloader(self) -> TRAIN_DATALOADERS: + return DataLoader( + self._datasets["train"], + batch_size=self._batch_size, + collate_fn=self._train_collate_fn, + num_workers=self._num_workers, + ) + + def val_dataloader(self) -> EVAL_DATALOADERS: + return DataLoader( + self._datasets["validation"], + batch_size=self._val_batch_size, + collate_fn=self._val_test_collate_fn, + num_workers=self._num_workers, + ) + + def test_dataloader(self) -> EVAL_DATALOADERS: + return DataLoader( + self._datasets["test"], + batch_size=self._val_batch_size, + collate_fn=self._val_test_collate_fn, + num_workers=self._num_workers, + ) + + def _synthetic_classes_train_collate_fn(self, samples) -> Tuple[torch.Tensor, ...]: + if self._include_scenarios: + classes, scenarios = choose_synthetic_classes(samples, include_scenarios=self._include_scenarios) + for i in range(len(scenarios)): + scenarios[i] = scenarios[i] + ": " + else: + classes = choose_synthetic_classes(samples, include_scenarios=self._include_scenarios) + + positive_classes = [[sample_classes[0]] for sample_classes in classes] + new_classes, positive_labels = shuffle_classes(classes, positive_classes) + new_samples = [(samples[i]["text"], new_classes[i], positive_labels[i]) for i in range(len(samples))] + texts = [sample_text for (sample_text, _, _) in new_samples] + tokenized_texts = self._tokenizer(texts, add_special_tokens=False).input_ids + tokenized_classes = [ + self._tokenizer(sample_classes, add_special_tokens=False).input_ids + for (_, sample_classes, _) in new_samples + ] + if self._include_scenarios: + tokenized_scenarios = self._tokenizer(scenarios, add_special_tokens=False).input_ids + new_samples = [ + (tokenized_texts[i], tokenized_classes[i], positive_labels[i], tokenized_scenarios[i]) + for i in range(len(samples)) + ] + else: + new_samples = [(tokenized_texts[i], tokenized_classes[i], positive_labels[i]) for i in range(len(samples))] + + return self._val_test_collate_fn(new_samples) + + def _synthetic_positives_train_collate_fn( + self, samples: list[tuple[ndarray, List[ndarray]]] + ) -> Tuple[torch.Tensor, ...]: + classes = [sample["classes"] for sample in samples] + new_classes, positive_labels = generate_classes(classes, self._config) + new_samples = [(samples[i]["text"], new_classes[i], positive_labels[i]) for i in range(len(samples))] + + texts = [sample_text for (sample_text, _, _) in new_samples] + tokenized_texts = self._tokenizer(texts, add_special_tokens=False).input_ids + tokenized_classes = [ + self._tokenizer(sample_classes, add_special_tokens=False).input_ids + for (_, sample_classes, _) in new_samples + ] + + new_samples = [(tokenized_texts[i], tokenized_classes[i], positive_labels[i]) for i in range(len(samples))] + + return self._val_test_collate_fn(new_samples) + + def _generic_val_test_collate_fn( + self, samples: list[tuple[ndarray, list[ndarray], list[int]]], input_prompt=None, include_scenarios=None + ) -> tuple[torch.Tensor, ...]: + result_prompts = [] + label_masks = [] + if len(input_prompt) != 0: + tokenized_prompt = self._tokenizer(input_prompt, add_special_tokens=False).input_ids + else: + tokenized_prompt = input_prompt + if len(samples[0]) == 3: + classes_count = [len(sample_classes) for (_, sample_classes, _) in samples] + positive_labels = [sample_positive_labels for (_, _, sample_positive_labels) in samples] + for input_seq, sample_classes, sample_positive_labels in samples: + result_prompt, label_mask = make_classifier_prompt( + input_seq, + self._special_token_ids, + sample_classes, + starting_prompt=tokenized_prompt, + positive_labels=sample_positive_labels, + ) + + result_prompts.append(result_prompt) + label_masks.append(label_mask) + else: + classes_count = [len(sample_classes) for (_, sample_classes, _, _) in samples] + positive_labels = [sample_positive_labels for (_, _, sample_positive_labels, _) in samples] + for input_seq, sample_classes, sample_positive_labels, sample_scenario in samples: + if not include_scenarios: + sample_scenario = np.array([], dtype=int) + result_prompt, label_mask = make_classifier_prompt( + input_seq, + self._special_token_ids, + sample_classes, + starting_prompt=tokenized_prompt, + scenario=sample_scenario, + positive_labels=sample_positive_labels, + ) + + result_prompts.append(result_prompt) + label_masks.append(label_mask) + + input_ids, attention_mask, classes_mask = prepare_batch( + result_prompts, + label_masks, + self._tokenizer.pad_token_id, + self._tokenizer.eos_token_id, + self._tokenizer.model_max_length if self._tokenizer.model_max_length else None, + ) + return input_ids, attention_mask, classes_mask, torch.tensor(classes_count), positive_labels + + @property + def tokenizer(self) -> Tokenizer: + """Return current tokenizer instance.""" + return self._tokenizer diff --git a/src/data/data_processing.py b/geracl/data/data_processing.py similarity index 76% rename from src/data/data_processing.py rename to geracl/data/data_processing.py index 61c5237..0fa03cc 100644 --- a/src/data/data_processing.py +++ b/geracl/data/data_processing.py @@ -35,7 +35,7 @@ def generate_classes_multiclass(classes): derangement = random_derangement(len(classes)) fixed = True - max_attempts = 100 + max_attempts = 200 attempts = 0 while True: @@ -56,7 +56,6 @@ def generate_classes_multiclass(classes): # Break if no collisions or we tried too many times if fixed or attempts >= max_attempts: if attempts >= max_attempts: - # print(classes) raise Exception( "Could not distribute one positive class from each sample to different samples without repetition of classes." ) @@ -68,7 +67,6 @@ def generate_classes_multiclass(classes): for i in range(len(new_classes)): existing_set = set(new_classes[i]).union(set(classes[i])) # to ensure uniqueness - # Pick 'classes_count[i]' number of classes from weighted_classes, skipping duplicates added = 0 while added < negatives_count[i]: candidate = random.choice(weighted_classes) @@ -103,7 +101,6 @@ def generate_classes_multilabel(classes): for i in range(len(new_classes)): existing_set = set(new_classes[i]).union(set(classes[i])) # to ensure uniqueness - # Pick 'classes_count[i]' number of classes from weighted_classes, skipping duplicates added = 0 while added < negatives_count[i]: candidate = random.choice(weighted_classes) @@ -114,40 +111,13 @@ def generate_classes_multilabel(classes): return new_classes -def generate_classes_llm_negatives(llm_negatives, positives): - new_classes = [[] for _ in range(len(positives))] - for i in range(len(positives)): - selected_positive = random.choice(positives[i]) - new_classes[i] = [selected_positive] - - for i in range(len(new_classes)): - existing_set = set(new_classes[i]).union(set(positives[i])) # to ensure uniqueness - if new_classes[i][0] not in llm_negatives[i]: - break_flag = False - for positive_classes in positives: - for positive in positive_classes: - if positive not in existing_set: - new_classes[i].append(positive) - break_flag = True - break - if break_flag: - break - continue - negatives = llm_negatives[i][new_classes[i][0]] - for negative_class in negatives: - if negative_class not in existing_set: - new_classes[i].append(negative_class) - existing_set.add(negative_class) - - return new_classes - - def shuffle_classes(classes, positives): rng = np.random.default_rng() labels = [[] for _ in range(len(classes))] for i, sample_classes in enumerate(classes): + lower_positives = [positive.lower() for positive in positives[i]] for sample_class in sample_classes: - if sample_class in positives[i]: + if sample_class.lower() in lower_positives: labels[i].append(1) else: labels[i].append(0) @@ -165,19 +135,35 @@ def shuffle_classes(classes, positives): return classes, new_labels -def generate_classes_task_creation(batch): +def choose_synthetic_classes(batch, include_scenarios=False) -> list | tuple[list, list]: new_classes = [[] for _ in range(len(batch))] + if include_scenarios: + scenarios = [] for i in range(len(batch)): - negative_idx = random.randint(0, 4) - new_classes[i] = batch[i][f"negatives_{negative_idx}"] - - return new_classes + negative_ids = [0, 1, 2, 3, 4] + negative_idx = random.choice(negative_ids) + while not batch[i][f"classes_{negative_idx}"]: + negative_ids.remove(negative_idx) + negative_idx = random.choice(negative_ids) + new_classes[i] = batch[i][f"classes_{negative_idx}"] + if include_scenarios: + scenario_idx = 0 + for k in range(negative_idx): + if not batch[i][f"classes_{negative_idx}"]: + continue + else: + scenario_idx += 1 + scenarios.append(batch[i]["scenarios"][scenario_idx]) + if include_scenarios: + return new_classes, scenarios + else: + return new_classes def generate_classes(classes: list, config: str): - if config == "multiclass": + if config == "synthetic_positives_multiclass": new_classes = generate_classes_multiclass(classes) - else: + elif config == "synthetic_positives_multilabel": new_classes = generate_classes_multilabel(classes) new_classes, labels = shuffle_classes(new_classes.copy(), classes) diff --git a/src/data/dataset.py b/geracl/data/dataset.py similarity index 72% rename from src/data/dataset.py rename to geracl/data/dataset.py index 710ac95..ebfbf59 100644 --- a/src/data/dataset.py +++ b/geracl/data/dataset.py @@ -28,12 +28,21 @@ def __init__(self, data: dict[str, list[str], list[int]], tokenizer: Tokenizer): def tokenize(self, sample): text = sample["text"] classes = sample["classes"] - labels = sample["labels"] + label = sample["labels"] + if isinstance(label, int): + label = [label] + scenario = None + if "scenarios" in sample: + scenario = sample["scenarios"] encoded_text = self._tokenizer(text, add_special_tokens=False).input_ids encoded_classes = self._tokenizer(classes, add_special_tokens=False).input_ids - - return encoded_text, encoded_classes, labels + if scenario is not None: + scenario = scenario + ": " + encoded_scenario = self._tokenizer(scenario, add_special_tokens=False).input_ids + return encoded_text, encoded_classes, label, encoded_scenario + else: + return encoded_text, encoded_classes, label def __len__(self): return len(self._dataset) diff --git a/src/data/__init__.py b/geracl/model/__init__.py similarity index 100% rename from src/data/__init__.py rename to geracl/model/__init__.py diff --git a/geracl/model/config.py b/geracl/model/config.py new file mode 100644 index 0000000..24a8ff7 --- /dev/null +++ b/geracl/model/config.py @@ -0,0 +1,52 @@ +from transformers import AutoConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.models.auto import CONFIG_MAPPING + + +class GeraclConfigHF(PretrainedConfig): + model_type = "GeRaCl" + is_composition = True + + def __init__( + self, + embedder_config=None, + embedder_name=None, + ffn_dim=None, + ffn_classes_dropout=0.4, + ffn_text_dropout=0.4, + device="cuda", + tokenizer_len=None, + pooling_type="mean", + loss_args={"loss_type": "bce"}, + **kwargs, + ): + if isinstance(embedder_config, dict): + embedder_config["model_type"] = ( + embedder_config["model_type"] if "model_type" in embedder_config else "modernbert" + ) + embedder_config = CONFIG_MAPPING[embedder_config["model_type"]](**embedder_config) + elif embedder_config is None: + embedder_config = CONFIG_MAPPING["modernbert"]() + + self.embedder_config = embedder_config + self.embedder_name = embedder_name + + self.hidden_size = self.embedder_config.hidden_size + + if tokenizer_len is not None: + self.tokenizer_len = tokenizer_len + else: + self.tokenizer_len = self.embedder_config.vocab_size + + if ffn_dim is None: + self.ffn_dim = self.hidden_size * 2 + else: + self.ffn_dim = ffn_dim + + self.ffn_classes_dropout = ffn_classes_dropout + self.ffn_text_dropout = ffn_text_dropout + + self.pooling_type = pooling_type + self.device = device + self.loss_args = loss_args + super().__init__(**kwargs) diff --git a/geracl/model/geracl.py b/geracl/model/geracl.py new file mode 100644 index 0000000..9e01c0e --- /dev/null +++ b/geracl/model/geracl.py @@ -0,0 +1,269 @@ +from functools import partial +from itertools import chain +from typing import Tuple + +import torch +import torch.optim +from loguru import logger +from pytorch_lightning import LightningModule +from pytorch_lightning.utilities.types import STEP_OUTPUT +from sklearn.metrics import accuracy_score +from torch import Tensor +from torch.nn.functional import binary_cross_entropy_with_logits +from torchmetrics import MetricCollection +from torchmetrics.classification import BinaryAUROC, BinaryF1Score + +from geracl.model.geracl_core import GeraclCore +from geracl.utils import cosine_lambda, focal_loss_with_logits, linear_lambda + + +class Geracl(LightningModule): + """Lightning module that encapsulate all routines for zero-shot text classification. + + Maybe used as regular Torch module on inference: forward pass returns predicted classes. + Also support training via lightning Trainer, see: https://pytorch-lightning.readthedocs.io/en/stable/. + + Use HuggingFace models as backbone to embed tokens, e.g. "USER-base": + https://huggingface.co/deepvk/USER-base + Reports accuracy, binary F1-score, binary AUROC and either BCE loss or focal loss during training. + """ + + def __init__( + self, + embedder_name: str = "deepvk/USER-base", + *, + unfreeze_embedder: bool = False, + ffn_dim: int = 2048, + ffn_classes_dropout: float = 0.1, + ffn_text_dropout: float = 0.1, + device: str = "cuda", + tokenizer_len: int, + pooling_type: str = "mean", + loss_args: dict = None, + optimizer_args: dict = None, + scheduler_args: dict = None, + ): + """ + :param embedder_name: name of pretrained HuggingFace model to embed tokens. + :param unfreeze_embedder: if `True` then train top mlp layers along with backbone module. + :param ffn_dim: hidden dimension of mlp layers. + :param ffn_classes_dropout: dropout of the mlp layer used for transforming input classes embeddings. + :param ffn_text_dropout: dropout of the mlp layer used for transforming input text embedding. + :param device: name of device to train the model on. + :param tokenizer_len: sumber of tokens in tokenizer. + :param pooling_type: sentence embedding's pooling type (either mean or first). + :param loss_args: dict with arguments to choose appropriate loss function. + :param optimizer_args: dict with arguments to initalize optimizer. + :param scheduler_args: dict with arguments to initalize scheduler. + """ + super().__init__() + self.save_hyperparameters() + + self._optimizer_args = optimizer_args if optimizer_args is not None else None + self._scheduler_args = scheduler_args if scheduler_args is not None else None + self._loss_args = loss_args if loss_args is not None else None + + if self._loss_args["loss_type"] not in {"bce", "focal"}: + raise ValueError("Invalid loss type config parameter.") + + self._device = device + self._classification_core = GeraclCore( + embedder_name=embedder_name, + ffn_dim=ffn_dim, + ffn_classes_dropout=ffn_classes_dropout, + ffn_text_dropout=ffn_text_dropout, + device=device, + tokenizer_len=tokenizer_len, + pooling_type=pooling_type, + loss_args=loss_args, + ) + + if not unfreeze_embedder: + logger.info(f"Freezing embedding model: {self._classification_core._token_embedder.__class__.__name__}") + for param in self._classification_core._token_embedder.parameters(): + param.requires_grad = False + + self._step_outputs = { + f"{split}_{metric}": [] + for split in ["val", "test", "train"] + for metric in ["loss", "predictions", "target"] + } + + self._auroc_metric = MetricCollection({f"{split}_auroc": BinaryAUROC() for split in ["train", "val", "test"]}) + + self._f1_metric = MetricCollection( + {f"{split}_binary_f1": BinaryF1Score(threshold=0.1) for split in ["train", "val", "test"]} + ) + + def configure_optimizers(self): + parameters = chain( + self._classification_core._token_embedder.parameters(), + self._classification_core._mlp_classes.parameters(), + self._classification_core._mlp_text.parameters(), + ) + + if self._optimizer_args: + module_name, class_name = self._optimizer_args["class_path"].rsplit(".", 1) + optimizer_cls = getattr(__import__(module_name, fromlist=[class_name]), class_name) + if "init_params" in self._optimizer_args: + optimizer = optimizer_cls(parameters, **self._optimizer_args["init_params"]) + else: + optimizer = optimizer_cls(parameters) + else: + optimizer = torch.optim.AdamW(parameters) + + if self._scheduler_args is None: + return optimizer + + if self._scheduler_args["scheduler"] == "linear": + linear_lambda_with_total_steps = partial( + linear_lambda, + total_steps=self._scheduler_args["total_steps"], + warmup_steps=self._scheduler_args["warmup_steps"], + ) + lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=linear_lambda_with_total_steps) + else: + cosine_lambda_with_total_steps = partial( + cosine_lambda, + total_steps=self._scheduler_args["total_steps"], + warmup_steps=self._scheduler_args["warmup_steps"], + ) + lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=cosine_lambda_with_total_steps) + + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": lr_scheduler, + "interval": "step", + }, + } + + def forward(self, input_ids: Tensor, attention_mask: Tensor, classes_mask: Tensor, classes_count: Tensor) -> Tensor: # type: ignore + return self._classification_core.forward(input_ids, attention_mask, classes_mask, classes_count) + + def shared_step(self, batch: Tuple[Tensor, ...], split: str) -> STEP_OUTPUT: + """Shared step of them that used during training and evaluation. + Make forward pass of the model, calculate loss and metric and log them. + + :param batch: Tuple of + > input_ids [batch size; seq len] – input tokens ids padded to the same length; + > attention_mask [batch size; seq len] – mask with padding description, 0 means PAD token; + > classes_mask [batch size; seq len] - labels of each token; + > classes_count [batch_size] - classes count for each sample. + > positive_classes [batch_size] - indices of positive classes for each sample. + :param split: name of current split, one of `train`, `val`, or `test`. + :return: loss on the current batch. + """ + + input_ids, attention_mask, classes_mask, classes_count, positive_classes = batch + bs = len(input_ids) + + forward_batch = input_ids, attention_mask, classes_mask, classes_count + + similarities = self.forward(*forward_batch) + + target = torch.zeros(similarities.shape[0]).to(self._device) + idx = 0 + for i, sample_class_count in enumerate(classes_count): + for positive_class in positive_classes[i]: + target[idx + positive_class] = 1 + idx = idx + sample_class_count + + if self._loss_args["loss_type"] == "bce": + batch_loss = binary_cross_entropy_with_logits(similarities, target) + elif self._loss_args["loss_type"] == "focal": + batch_loss = focal_loss_with_logits( + similarities, + target, + alpha=self._loss_args["init_params"]["alpha"], + gamma=self._loss_args["init_params"]["gamma"], + label_smoothing=self._loss_args["init_params"]["label_smoothing"], + ignore_index=self._loss_args["init_params"]["ignore_index"], + reduction="mean", + ) + + with torch.no_grad(): + if split != "train": + self._step_outputs[f"{split}_loss"].append(batch_loss.item()) + + predicted_classes = [] + idx = 0 + for i in range(bs): + predicted_class_idx = similarities[idx : (idx + classes_count[i])].argmax() + predicted_classes.append(predicted_class_idx) + idx = idx + classes_count[i] + + self._step_outputs[f"{split}_predictions"] = self._step_outputs[f"{split}_predictions"] + [ + pred_class.to("cpu") for pred_class in predicted_classes + ] + self._step_outputs[f"{split}_target"] = self._step_outputs[f"{split}_target"] + positive_classes + + probs = torch.sigmoid(similarities) + batch_auroc = self._auroc_metric[f"{split}_auroc"](probs, target) + batch_f1 = self._f1_metric[f"{split}_binary_f1"](probs, target) + + if split == "train": + self.log_dict( + { + f"{split}/step_loss": batch_loss.item(), + } + ) + return batch_loss + + def _report_metrics(self, split: str, loss: list = None): + if split != "train": + epoch_auroc = self._auroc_metric[f"{split}_auroc"].compute() + self._auroc_metric[f"{split}_auroc"].reset() + + epoch_binary_f1 = self._f1_metric[f"{split}_binary_f1"].compute() + self._f1_metric[f"{split}_binary_f1"].reset() + + y_true = self._step_outputs[f"{split}_target"] + y_pred = self._step_outputs[f"{split}_predictions"] + y_true_multiclass = [true_classes[0] for true_classes in y_true] + accuracy = accuracy_score(y_true_multiclass, y_pred) + self.log(f"{split}/epoch_accuracy", accuracy) + + self.log_dict( + { + f"{split}/epoch_auroc": epoch_auroc, + f"{split}/epoch_binary_f1": epoch_binary_f1, + } + ) + + if loss: + epoch_loss = torch.tensor( + loss, + dtype=torch.float32, + device=self._device, + ).mean() + + self.log(f"{split}/epoch_loss", epoch_loss) + + def on_train_epoch_end(self): + self._report_metrics("train") + self._step_outputs["train_predictions"].clear() + self._step_outputs["train_target"].clear() + + def on_validation_epoch_end(self): + self._report_metrics("val", self._step_outputs["val_loss"]) + + self._step_outputs["val_loss"].clear() + self._step_outputs["val_predictions"].clear() + self._step_outputs["val_target"].clear() + + def on_test_epoch_end(self): + self._report_metrics("test", self._step_outputs["test_loss"]) + + self._step_outputs["test_loss"].clear() + self._step_outputs["test_predictions"].clear() + self._step_outputs["test_target"].clear() + + def validation_step(self, batch: Tuple[Tensor, ...], batch_idx: int) -> STEP_OUTPUT: # type: ignore + return self.shared_step(batch, "val") + + def test_step(self, batch: Tuple[Tensor, ...], batch_idx: int) -> STEP_OUTPUT: # type: ignore + return self.shared_step(batch, "test") + + def training_step(self, batch: Tuple[Tensor, ...], batch_idx: int) -> STEP_OUTPUT: # type: ignore + return self.shared_step(batch, "train") diff --git a/geracl/model/geracl_core.py b/geracl/model/geracl_core.py new file mode 100644 index 0000000..5f85aec --- /dev/null +++ b/geracl/model/geracl_core.py @@ -0,0 +1,198 @@ +import torch +import torch.nn as nn +from torch import Tensor +from transformers import AutoModel + + +class GeraclCore(nn.Module): + """Core torch module that encapsulate all routines for zero-shot text classification. + + May be used as regular Torch module on inference: forward pass returns probabilities of choosing each possible class in a sample. + + Use HuggingFace models as backbone to embed tokens, e.g. "USER2-base": + https://huggingface.co/deepvk/USER2-base + """ + + def __init__( + self, + embedder_name: str = "deepvk/USER-base", + *, + ffn_dim: int = 2048, + ffn_classes_dropout: float = 0.1, + ffn_text_dropout: float = 0.1, + device: str = "cuda", + tokenizer_len: int, + pooling_type: str = "mean", + loss_args: dict = None, + ): + """ + :param embedder_name: name of pretrained HuggingFace model to embed tokens. + :param ffn_dim: hidden dimension of mlp layers. + :param ffn_classes_dropout: dropout of the mlp layer used for transforming input classes embeddings. + :param ffn_text_dropout: dropout of the mlp layer used for transforming input text embedding. + :param device: name of device to train the model on. + :param tokenizer_len: Number of tokens in tokenizer. + :param loss_args: Dict with arguments to choose appropriate loss function. + """ + super().__init__() + if pooling_type not in {"mean", "first"}: + raise ValueError("Invalid pooling type config parameter.") + self._pooling_type = pooling_type + + self._device = device + self._token_embedder = AutoModel.from_pretrained(embedder_name).to(self._device) + self._token_embedder.resize_token_embeddings(tokenizer_len) + + self._mlp_classes = nn.Sequential( + nn.Linear(self._token_embedder.config.hidden_size, ffn_dim), + nn.Dropout(ffn_classes_dropout), + nn.GELU(), + nn.Linear(ffn_dim, self._token_embedder.config.hidden_size), + ).to(self._device) + + self._mlp_text = nn.Sequential( + nn.Linear(self._token_embedder.config.hidden_size, ffn_dim), + nn.Dropout(ffn_text_dropout), + nn.GELU(), + nn.Linear(ffn_dim, self._token_embedder.config.hidden_size), + ).to(self._device) + + def _get_text_embeddings( + self, + embeddings: torch.Tensor, + classes_mask: torch.Tensor, + ) -> torch.Tensor: + """ + Aggregate token embeddings that correspond to the *input text* (mask value **-3**) + and project the pooled vector through the text-MLP head. + + Args: + embeddings (Tensor): shape **(batch_size, seq_len, hidden_dim)** + Raw contextual embeddings produced by the encoder. + + classes_mask (Tensor): shape **(batch_size, seq_len)** + Integer mask that tags every token: + * **-3** -- genuine input-text tokens (these are the ones we pool) + * **-5** -- extra prompt / scenario tokens (ignored here) + * **-1/-2 / 0…N** -- other task-specific labels (ignored here) + + Returns: + Tensor: shape **(batch_size, text_proj_dim)** + One embedding per sample, summarising the entire input text after + the MLP projection. + """ + mask = classes_mask == -3 + + text_mask_expanded = mask.unsqueeze(-1).expand(embeddings.size()).float() + text_embeddings = torch.sum(embeddings * text_mask_expanded, 1) / torch.clamp( + text_mask_expanded.sum(1), min=1e-9 + ) + text_embeddings = self._mlp_text(text_embeddings) + return text_embeddings + + def _get_classes_embeddings( + self, embeddings: torch.Tensor, classes_mask: torch.Tensor, classes_count: torch.Tensor + ) -> tuple[torch.Tensor, ...]: + """ + Build a single stack of per-class embeddings for the *entire* batch and project + them through the class-MLP head. + + The exact pooling strategy depends on `self._pooling_type`: + * **"mean"** - mean of token vectors spanning each label span + * **"first"** - first occurrence of each label (HF-style) + + Args: + embeddings (Tensor): shape **(batch_size, seq_len, hidden_dim)** + Contextual encoder output. + + classes_mask (Tensor): shape **(batch_size, seq_len)** + Integer mask assigning every token to either + * a class ID **0 … C-1**, + * the two special “separator” labels **-1** and **-2** (used to distinguish + positive and negative classes and set boundaries between classes), + * or any negative value reserved for non-class tokens (e.g. **-3**, **-5**). + + classes_count (Tensor): shape **(batch_size,)** + How many distinct classes appear in each sample. + The sum of this vector equals the first dimension of the returned tensor. + + Returns: + Tensor: shape **(total_classes_in_batch, class_proj_dim)** + One embedding per class instance across the batch after + the MLP projection, ordered batch-major (all classes of sample 0, + then sample 1, …). + """ + bs, _ = classes_mask.shape + + mlp_input = torch.empty( + (classes_count.sum(), embeddings.shape[-1]), dtype=torch.float, device=embeddings.device + ) + idx = 0 + for i in range(bs): + sample_classes_mask = classes_mask[i] + emb = embeddings[i] + if self._pooling_type == "mean": + sample_class_embeddings = [] + for label in range(classes_count[i]): + positions = torch.nonzero(sample_classes_mask == label, as_tuple=True)[0] + start_idx = positions.min() + end_idx = positions.max() + + # +1 because we should include token on end_idx position + class_mean = emb[start_idx : end_idx + 1].mean(dim=0) + sample_class_embeddings.append(class_mean) + sample_class_embeddings = torch.stack(sample_class_embeddings, dim=0) + elif self._pooling_type == "first": + positions_1 = torch.nonzero(sample_classes_mask == -1).squeeze(1) + positions_2 = torch.nonzero(sample_classes_mask == -2).squeeze(1) + sorted_classes_positions = torch.sort(torch.cat((positions_2, positions_1)))[0] + sample_class_embeddings = emb[sorted_classes_positions] + + mlp_input[idx : (idx + classes_count[i])] = sample_class_embeddings + idx = idx + classes_count[i] + + classes_embeddings = self._mlp_classes(mlp_input.to(self._device)) + return classes_embeddings + + def forward(self, input_ids: Tensor, attention_mask: Tensor, classes_mask: Tensor, classes_count: Tensor) -> Tensor: # type: ignore + """ + End-to-end zero-shot classification pass: + embed the input text, embed each candidate class, and return a similarity + score for every *text ↔ class* pair in the batch. + + Args: + input_ids (Tensor): shape **(batch_size, seq_len)** + Pre-tokenised input IDs. + + attention_mask (Tensor): shape **(batch_size, seq_len)** + Standard HF mask — **1** for real tokens, **0** for padding. + + classes_mask (Tensor): shape **(batch_size, seq_len)** + Per-token label map: + * **0 … C-1** – tokens belonging to that class span + * **-1**, **-2** – class separators (used to distinguish + positive and negative classes and set boundaries between classes) + * **-3** – tokens of the actual input text (pooled by + `_get_text_embeddings`) + * **-5** – prompt / scenario tokens (ignored here) + + classes_count (Tensor): shape **(batch_size,)** + Number of distinct candidate classes in each sample. + + Returns: + Tensor: shape **(batch_total_classes,)** where + `batch_total_classes = classes_count.sum()` + Dot-product similarities between the pooled text embedding + of each sample and each of its class embeddings, ordered + batch-major (all classes of sample 0, then sample 1, …). + """ + embeddings = self._token_embedder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state + text_embeddings = self._get_text_embeddings(embeddings, classes_mask) + classes_embeddings = self._get_classes_embeddings(embeddings, classes_mask, classes_count) + + # [all_classes_in_batch, embedding_dim] + wide_text_embeddings = torch.repeat_interleave(text_embeddings, classes_count, dim=0) + + similarities = torch.sum(classes_embeddings * wide_text_embeddings, dim=-1) + + return similarities diff --git a/geracl/model/hf_wrapper.py b/geracl/model/hf_wrapper.py new file mode 100644 index 0000000..cc2c824 --- /dev/null +++ b/geracl/model/hf_wrapper.py @@ -0,0 +1,26 @@ +from transformers import PreTrainedModel + +from geracl.model.config import GeraclConfigHF +from geracl.model.geracl_core import GeraclCore + + +class GeraclHF(PreTrainedModel): + config_class = GeraclConfigHF + + def __init__(self, config: GeraclConfigHF): + super().__init__(config) + self._classification_core = GeraclCore( + embedder_name=config.embedder_name, + ffn_dim=config.ffn_dim, + ffn_classes_dropout=config.ffn_classes_dropout, + ffn_text_dropout=config.ffn_text_dropout, + device=config.device, + tokenizer_len=config.tokenizer_len, + pooling_type=config.pooling_type, + loss_args=config.loss_args, + ) + self.post_init() + + # delegate to the core + def forward(self, *args, **kwargs): + return self._classification_core(*args, **kwargs) diff --git a/geracl/pipeline.py b/geracl/pipeline.py new file mode 100644 index 0000000..e85de89 --- /dev/null +++ b/geracl/pipeline.py @@ -0,0 +1,86 @@ +import numpy as np +import torch +from tqdm.auto import tqdm +from transformers import AutoTokenizer + +from geracl.data.batch_creation import prepare_inference_batch + + +class ZeroShotClassificationPipeline: + def __init__(self, model, tokenizer, device="cuda", progress_bar=True, input_prompt=None): + self.model = model + if isinstance(tokenizer, str): + self._tokenizer = AutoTokenizer.from_pretrained(tokenizer) + else: + self._tokenizer = tokenizer + self.progress_bar = progress_bar + + if not isinstance(device, torch.device): + if torch.cuda.is_available() and "cuda" in device: + self.device = torch.device(device) + else: + self.device = torch.device("cpu") + else: + self.device = device + + if self.model.device != self.device: + self.model.to(self.device) + + self._rng = np.random.default_rng() + self.input_prompt = input_prompt + + @torch.no_grad() + def get_similarities(self, texts, labels, same_labels=True, batch_size=100): + if isinstance(texts, str): + texts = [texts] + + results = [] + if same_labels: + labels = [labels for _ in range(batch_size)] + + iterable = range(0, len(texts), batch_size) + if self.progress_bar: + iterable = tqdm(iterable) + + for idx in iterable: + batch_texts = texts[idx : idx + batch_size] + + if same_labels: + tokenized_inputs = prepare_inference_batch( + batch_texts, labels, self._tokenizer, input_prompt=self.input_prompt + ) + else: + tokenized_inputs = prepare_inference_batch( + batch_texts, labels[idx : idx + batch_size], self._tokenizer, input_prompt=self.input_prompt + ) + input_ids, attention_mask, classes_mask, classes_count = [x.to(self.device) for x in tokenized_inputs] + similarities = self.model(input_ids, attention_mask, classes_mask, classes_count) + results.append(similarities) + + return results + + @torch.no_grad() + def __call__(self, texts, labels, batch_size=100): + if isinstance(texts, str): + texts = [texts] + if isinstance(labels[0], str): + same_labels = True + else: + same_labels = False + + similarities = self.get_similarities(texts, labels, same_labels, batch_size) + + real_predictions = [] + for i in range(len(similarities)): + if same_labels: + real_predictions.extend(torch.argmax(similarities[i].view(-1, len(labels)), dim=1).tolist()) + else: + label_idx = 0 + for k in range(batch_size): + if label_idx + len(labels[(i * batch_size) + k]) > len(similarities[i]): + break + pred = torch.argmax(similarities[i][label_idx : label_idx + len(labels[(i * batch_size) + k])]) + label_idx += len(labels[(i * batch_size) + k]) + real_predictions.append(pred.tolist()) + + return real_predictions diff --git a/geracl/utils.py b/geracl/utils.py new file mode 100644 index 0000000..d200743 --- /dev/null +++ b/geracl/utils.py @@ -0,0 +1,138 @@ +import math + +import torch +import torch.nn.functional as F +from tokenizers import Tokenizer + + +def add_required_tokens(tokenizer: Tokenizer) -> Tokenizer: + required_token_types = ["bos_token", "eos_token", "cls_token", "sep_token", "pad_token"] + default_tokens = { + "bos_token": "", + "eos_token": "", + "cls_token": "[CLS]", + "sep_token": "[SEP]", + "pad_token": "[PAD]", + } + + # Ensure each required token type exists + tokens_to_add = {} + for token_type in required_token_types: + current_token = getattr(tokenizer, token_type, None) + default_token = default_tokens[token_type] + + if current_token is None: + current_token = default_token + # Check if the token string is in the vocabulary + if current_token not in tokenizer.vocab: + tokens_to_add[token_type] = current_token + # Add missing tokens to the tokenizer + if len(tokens_to_add) > 0: + tokenizer.add_special_tokens(tokens_to_add) + + # Check for bos/cls conflict + if tokenizer.bos_token_id == tokenizer.cls_token_id: + new_token = tokenizer.cls_token + "_1" + tokenizer.add_special_tokens({"cls_token": new_token}) + + # Check for eos/sep conflict + if tokenizer.eos_token_id == tokenizer.sep_token_id: + new_token = tokenizer.sep_token + "_1" + tokenizer.add_special_tokens({"sep_token": new_token}) + + return tokenizer + + +def focal_loss_with_logits( + inputs: torch.Tensor, + targets: torch.Tensor, + alpha: float = 0.25, + gamma: float = 2, + reduction: str = "none", + label_smoothing: float = 0.0, + ignore_index: int = -100, # default value for ignored index + weight: torch.Tensor = None, +) -> torch.Tensor: + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Code is taken from the GliClass GitHub repo: https://github.com/Knowledgator/GLiClass/. + + Args: + inputs (Tensor): A float tensor of arbitrary shape. + The predictions for each example. + targets (Tensor): A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha (float): Weighting factor in range (0,1) to balance + positive vs negative examples or -1 for ignore. Default: ``0.25``. + gamma (float): Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. Default: ``2``. + reduction (string): ``'none'`` | ``'mean'`` | ``'sum'`` + ``'none'``: No reduction will be applied to the output. + ``'mean'``: The output will be averaged. + ``'sum'``: The output will be summed. Default: ``'none'``. + label_smoothing (float): Specifies the amount of smoothing when computing the loss, + where 0.0 means no smoothing. + ignore_index (int): Specifies a target value that is ignored and does not contribute + to the input gradient. Default: ``-100``. + Returns: + Loss tensor with the reduction option applied. + """ + # Create a mask to ignore specified index + valid_mask = targets != ignore_index + + # Apply label smoothing if needed + if label_smoothing != 0: + with torch.no_grad(): + targets = targets * (1 - label_smoothing) + 0.5 * label_smoothing + + # Apply sigmoid activation to inputs + p = torch.sigmoid(inputs) + + # Compute the binary cross-entropy loss without reduction + if weight is not None: + loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none", weight=weight) + else: + loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + + # Apply the valid mask to the loss + loss = loss * valid_mask + + # Apply focal loss modulation if gamma is greater than 0 + if gamma > 0: + p_t = p * targets + (1 - p) * (1 - targets) + loss = loss * ((1 - p_t) ** gamma) + + # Apply alpha weighting if alpha is specified + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + # Apply reduction method + if reduction == "none": + return loss + elif reduction == "mean": + return loss.sum() / valid_mask.sum() # Normalize by the number of valid (non-ignored) elements + elif reduction == "sum": + return loss.sum() + else: + raise ValueError( + f"Invalid value for argument 'reduction': '{reduction}'. " + f"Supported reduction modes: 'none', 'mean', 'sum'" + ) + + +def linear_lambda(current_step: int, total_steps: int, warmup_steps: int): + if current_step < warmup_steps: + return current_step / warmup_steps + else: + progress = float(current_step - warmup_steps) / float(total_steps - warmup_steps) + return max(0.0, (1.0 - progress)) + + +def cosine_lambda(current_step: int, total_steps: int, warmup_steps: int): + if current_step < warmup_steps: + return current_step / warmup_steps + else: + progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps)) + return 0.5 * (1 + math.cos(math.pi * progress)) diff --git a/main.py b/main.py index 8e3ccc1..1bfa7f1 100644 --- a/main.py +++ b/main.py @@ -7,9 +7,8 @@ from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint from pytorch_lightning.loggers import WandbLogger -from src.data.data_module import ZeroShotClassificationDataModule -from src.model.universal_classifier import ZeroShotClassifier -from src.utils import setup_logging +from geracl import Geracl +from geracl.data.data_module import ZeroShotClassificationDataModule def merge_configs(default, override): @@ -51,7 +50,6 @@ def configure_arg_parser() -> ArgumentParser: config = load_config_with_defaults(_args.config_path, default_config) seed_everything(seed) - setup_logging() wandb_logger = WandbLogger(project=config["other"]["wandb_project"]) checkpoint_callback = ModelCheckpoint( @@ -65,7 +63,7 @@ def configure_arg_parser() -> ArgumentParser: lr_logger = LearningRateMonitor("step") data_module = ZeroShotClassificationDataModule(**config["data_module"]) - model = ZeroShotClassifier(**config["model"], tokenizer_len=len(data_module.tokenizer)) + model = Geracl(**config["model"], tokenizer_len=len(data_module.tokenizer)) trainer = Trainer( **config["trainer"], diff --git a/pyproject.toml b/pyproject.toml index 301f8c6..4fccf78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [metadata] -name = "py_template" -author = "deepvk" -url = "https://github.com/deepvk/py_template" +name = "geracl" +author = "deepvk, Vyrodov Mikhail" +url = "https://github.com/deepvk/geracl" [tool.isort] profile = "black" @@ -23,7 +23,3 @@ warn_unused_ignores = "True" allow_redefinition = "True" warn_no_return = "False" no_implicit_optional = "False" - -[tool.pytest.ini_options] -testpaths = ["tests"] -addopts = ["--color=yes", "-s"] \ No newline at end of file diff --git a/requirements.dev.txt b/requirements.dev.txt index 6940cea..794984e 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -3,4 +3,4 @@ mypy==1.7.0 pytest==7.4.3 pytest-subtests==0.11.0 isort==5.12.0 -loguru=0.7.3 +loguru==0.7.3 diff --git a/src/model/__init__.py b/scripts/__init__.py similarity index 100% rename from src/model/__init__.py rename to scripts/__init__.py diff --git a/sweeps.py b/scripts/sweeps.py similarity index 68% rename from sweeps.py rename to scripts/sweeps.py index 49966bd..e0aa986 100644 --- a/sweeps.py +++ b/scripts/sweeps.py @@ -5,10 +5,10 @@ from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.loggers import WandbLogger +from geracl.data.data_module import ZeroShotClassificationDataModule +from geracl.model.geracl import Geracl +from geracl.utils import setup_logging from main import load_config_with_defaults -from src.data.data_module import ZeroShotClassificationDataModule -from src.model.universal_classifier import ZeroShotClassifier -from src.utils import setup_logging sweep_configuration = { "method": "grid", @@ -18,10 +18,8 @@ "goal": "maximize", }, "parameters": { - "max_lr": {"values": [1e-7, 5e-6, 1e-6, 5e-6, 1e-5, 5e-5]}, - "batch_size": {"values": [16]}, - "scheduler": {"values": ["linear", "cosine"]}, - "max_epochs": {"values": [4, 6, 10]}, + "max_lr": {"values": [5e-6, 1e-6, 1e-5, 5e-5]}, + "weight_decay": {"values": [0, 0.01, 0.1]}, }, } @@ -30,7 +28,7 @@ def train_sweep(): with open("src/configs/default.yaml", "r") as file: default_config = yaml.safe_load(file) - config = load_config_with_defaults("src/configs/sweeps.yaml", default_config) + config = load_config_with_defaults("src/configs/custom.yaml", default_config) seed = 7 run = wandb.init(dir=config["other"]["wandb_dir"]) @@ -40,12 +38,8 @@ def train_sweep(): config["model"]["optimizer_args"]["init_params"] = dict() config["model"]["optimizer_args"]["init_params"]["lr"] = sweep_config.max_lr - config["data_module"]["batch_size"] = sweep_config.batch_size - config["data_module"]["val_batch_size"] = sweep_config.batch_size - config["trainer"]["max_epochs"] = sweep_config.max_epochs - config["model"]["scheduler_args"] = dict() - config["model"]["scheduler_args"]["scheduler"] = sweep_config.scheduler - config["model"]["scheduler_args"]["total_steps"] = 5810 * sweep_config.max_epochs + assert "scheduler_args" in config["model"] + config["model"]["scheduler_args"]["total_steps"] = config["builder"]["val_check_interval"] * config.max_epochs seed_everything(seed) setup_logging() @@ -64,7 +58,7 @@ def train_sweep(): data_module = ZeroShotClassificationDataModule() data_module = ZeroShotClassificationDataModule(**config["data_module"]) - model = ZeroShotClassifier(**config["model"], tokenizer_len=len(data_module.tokenizer)) + model = Geracl(**config["model"], tokenizer_len=len(data_module.tokenizer)) trainer = Trainer( **config["trainer"], diff --git a/scripts/train.py b/scripts/train.py deleted file mode 100644 index e69de29..0000000 diff --git a/scripts/validate.py b/scripts/validate.py new file mode 100644 index 0000000..020d421 --- /dev/null +++ b/scripts/validate.py @@ -0,0 +1,162 @@ +import argparse +import json + +import numpy as np +import torch +from datasets import load_dataset +from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score +from tqdm.auto import tqdm +from transformers import AutoTokenizer + +from geracl.data.batch_creation import prepare_inference_batch +from geracl.model.hf_wrapper import GeraclHF +from geracl.utils import add_required_tokens + + +def print_macro_metrics(true_labels, predicted_labels): + accuracy = accuracy_score(true_labels, predicted_labels) + precision = precision_score(true_labels, predicted_labels, average="macro") + recall = recall_score(true_labels, predicted_labels, average="macro") + f1 = f1_score(true_labels, predicted_labels, average="macro") + + print(f"accuracy={accuracy:.2f}\nprecision = {precision:.2f}\nRecall = {recall:.2f}\nF1 = {f1:.2f}") + return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1} + + +DATASETS = [ + # name, split, label_col, class_source, batch, n_classes + ( + "ai-forever/kinopoisk-sentiment-classification", + "validation", + "label", + lambda dataset: ["негативный", "нейтральный", "позитивный"], + 20, + 3, + ), + ( + "ai-forever/headline-classification", + "validation", + "label", + lambda dataset: ["спорт", "происшествия", "политика", "наука", "культура", "экономика"], + 100, + 6, + ), + ( + "ai-forever/ru-scibench-grnti-classification", + "test", + "label_text", + lambda dataset: list(sorted(set(dataset["label_text"]))), + 21, + 28, + ), + ( + "ai-forever/ru-scibench-oecd-classification", + "test", + "label_text", + lambda dataset: list(sorted(set(dataset["label_text"]))), + 29, + 29, + ), + ( + "ai-forever/inappropriateness-classification", + "test", + "label", + lambda dataset: ["приличный", "неприличный"], + 50, + 2, + ), +] + + +def permute_labels(rng, class_list): + perm = rng.permutation(len(class_list)) + return [class_list[i] for i in perm], perm + + +def compute_real_preds(logits, n_classes): + """ + logits: (batch*n_classes,) 1-D tensor that the current script appends + """ + return torch.argmax(logits.view(-1, n_classes), dim=1).tolist() + + +@torch.inference_mode() +def evaluate_one_dataset( + model, tokenizer, name, split, label_col, class_fn, batch_size, n_classes, input_prompt, device="cuda" +): + dataset = load_dataset(name)[split] + class_set = class_fn(dataset) + + rng = np.random.default_rng() + input_classes, target_labels = [], [] + + for gold in dataset[label_col]: + if isinstance(gold, float): + gold = int(gold) + permuted, perm = permute_labels(rng, class_set) + input_classes.append(permuted) + # gold may be int idx or text; handle both + gold_idx = gold if isinstance(gold, int) else class_set.index(gold) + target_labels.append(perm.tolist().index(gold_idx)) + + preds, texts = [], dataset["text"] + for i in tqdm(range(0, len(texts), batch_size), desc=name): + if i + batch_size < len(texts): + batch = prepare_inference_batch( + texts[i : i + batch_size], input_classes[i : i + batch_size], tokenizer, input_prompt=input_prompt + ) + else: + batch = prepare_inference_batch(texts[i:], input_classes[i:], tokenizer, input_prompt=input_prompt) + input_ids, attention_mask, classes_mask, classes_count = [x.to(device) for x in batch] + logits = model(input_ids, attention_mask, classes_mask, classes_count) + preds.extend(compute_real_preds(logits.cpu(), n_classes)) + + # de-permute back to original label space + real_targets, real_preds = [], [] + for classes, gold_result, pred_result in zip(input_classes, target_labels, preds): + perm = [class_set.index(c) for c in classes] + real_targets.append(perm[gold_result]) + real_preds.append(perm[pred_result]) + + return real_targets, real_preds + + +def parse_args(): + parser = argparse.ArgumentParser(description="Validate a model and dump metrics.") + parser.add_argument("--model_path", type=str, required=True, help="Path to the model on HuggingFace.") + parser.add_argument( + "--metrics_path", + type=str, + required=False, + default="metrics.json", + help="Output path where evaluation metrics json file will be written. Default: 'metrics.json'.", + ) + parser.add_argument( + "--input_prompt", type=str, required=False, default=None, help="Input prompt to the model. Default: None." + ) + return parser.parse_args() + + +def run_mteb_evaluation(model, tokenizer, metrics_path, input_prompt): + out = dict() + for dataset_name, split, label_col, class_fn, bs, classes_num in DATASETS: + targets, preds = evaluate_one_dataset( + model, tokenizer, dataset_name, split, label_col, class_fn, bs, classes_num, input_prompt + ) + out[dataset_name] = print_macro_metrics(targets, preds) + json.dump(out, open(metrics_path, "w"), ensure_ascii=False, indent=2) + + +if __name__ == "__main__": + args = parse_args() + model_path = args.model_path + metrics_path = args.metrics_path + input_prompt = args.input_prompt + + model = GeraclHF.from_pretrained(model_path).to("cuda").eval() + tokenizer = AutoTokenizer.from_pretrained(model_path) + tokenizer = add_required_tokens(tokenizer) + metrics = dict() + metrics["model_path"] = model_path + + run_mteb_evaluation(model, tokenizer, metrics_path, input_prompt) diff --git a/src/configs/custom.yaml b/src/configs/custom.yaml deleted file mode 100644 index 2e19bd2..0000000 --- a/src/configs/custom.yaml +++ /dev/null @@ -1,22 +0,0 @@ -model: - embedder_name: "deepvk/USER-bge-m3" - ffn_dim: 2048 - ffn_classes_dropout: 0.3 - ffn_text_dropout: 0.3 - device: "cuda" - unfreeze_embedder: True - optimizer_args: - class_path: torch.optim.AdamW - init_params: - lr: 0.000005 - -data_module: - tokenizer_name: "deepvk/USER-bge-m3" - config: "task_creation_negatives" - -trainer: - accelerator: "gpu" - val_check_interval: 5810 - -other: - checkpoints_dir: "/data/checkpoints/sweeps" \ No newline at end of file diff --git a/src/configs/default.yaml b/src/configs/default.yaml deleted file mode 100644 index 58aff58..0000000 --- a/src/configs/default.yaml +++ /dev/null @@ -1,23 +0,0 @@ -model: - embedder_name: "deepvk/USER-base" - ffn_dim: 2048 - device: "cuda" - unfreeze_embedder: False - -data_module: - batch_size: 16 - val_batch_size: 16 - num_workers: 20 - tokenizer_name: "deepvk/USER-base" - config: "multiclass" - -trainer: -# max_steps: 200 - accelerator: "gpu" - val_check_interval: 100 - gradient_clip_val: 0.0 - log_every_n_steps: 50 - -other: - wandb_project: "universal_classifier" - checkpoints_dir: "/data/checkpoints" \ No newline at end of file diff --git a/src/data/data_module.py b/src/data/data_module.py deleted file mode 100644 index f638986..0000000 --- a/src/data/data_module.py +++ /dev/null @@ -1,184 +0,0 @@ -from typing import List, Optional, Tuple - -import torch -from datasets import load_dataset -from loguru import logger -from numpy import ndarray -from pytorch_lightning import LightningDataModule -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS -from tokenizers import Tokenizer -from torch.utils.data import DataLoader -from transformers import AutoTokenizer - -from src.data.data_processing import ( - generate_classes, - generate_classes_llm_negatives, - generate_classes_task_creation, - shuffle_classes, -) -from src.data.data_utils import make_classifier_prompt, prepare_batch -from src.data.dataset import ZeroShotClassificationDataset -from src.utils import add_required_tokens - - -class ZeroShotClassificationDataModule(LightningDataModule): - """Lightning data module for data handling. - Provides dataloader for all splits, e.g. `train_dataloader` method. - Public methods and attributes allow to retrieve information about texts and their classes. - - And `tokenizers` to tokenize data, e.g. `deepvk/USER-base` - """ - - def __init__( - self, - batch_size: int = 16, - val_batch_size: int = 16, - num_workers: int = 20, - tokenizer_name: str = "deepvk/USER-base", - config: str = "task_creation_negatives", - model_max_length: int = None, - ): - """Data module constructor. - - :param batch_size: train batch size; - :param val_batch_size: validation and test batch size; - :param num_workers: Number of workers for data loaders; - :param tokenizer_name: name of the tokenizer, "deepvk/USER-base" by default. - """ - super().__init__() - if config not in {"multiclass", "multilabel", "llm_negatives", "task_creation_negatives"}: - raise ValueError("Invalid DataModule config parameter.") - self._config = config - self._batch_size = batch_size - self._val_batch_size = val_batch_size - self._num_workers = num_workers - self._tokenizer_name = tokenizer_name - logger.info(f"Downloading and opening '{self._tokenizer_name}' tokenizer") - self._tokenizer = AutoTokenizer.from_pretrained(self._tokenizer_name) - self._tokenizer = add_required_tokens(self._tokenizer) - if model_max_length is not None: - self._tokenizer.model_max_length = model_max_length - self._special_token_ids = { - "bos_token": self._tokenizer.bos_token_id, - "cls_token": self._tokenizer.cls_token_id, - "sep_token": self._tokenizer.sep_token_id, - "eos_token": self._tokenizer.eos_token_id, - } - - def setup(self, stage: Optional[str] = None): - logger.info("Downloading and opening 'deepvk/synthetic_classes' dataset") - data = load_dataset("deepvk/synthetic-classes", self._config) - train_data = load_dataset("deepvk/synthetic-classes", "task_creation_negatives_train", split="train") - - self._datasets = {} - self._labels = {} - - self._datasets["train"] = train_data - - for split in ["validation", "test"]: - logger.info(f"Initializing {split} dataset") - self._datasets[split] = ZeroShotClassificationDataset(data[split], self._tokenizer) - - def train_dataloader(self) -> TRAIN_DATALOADERS: - return DataLoader( - self._datasets["train"], - batch_size=self._batch_size, - collate_fn=self._task_creation_train_collate_fn, - num_workers=self._num_workers, - ) - - def val_dataloader(self) -> EVAL_DATALOADERS: - return DataLoader( - self._datasets["validation"], - batch_size=self._val_batch_size, - collate_fn=self._val_test_collate_fn, - num_workers=self._num_workers, - ) - - def test_dataloader(self) -> EVAL_DATALOADERS: - return DataLoader( - self._datasets["test"], - batch_size=self._val_batch_size, - collate_fn=self._val_test_collate_fn, - num_workers=self._num_workers, - ) - - def _task_creation_train_collate_fn(self, samples) -> Tuple[torch.Tensor, ...]: - new_classes = generate_classes_task_creation(samples) - - positive_classes = [[sample_classes[0]] for sample_classes in new_classes] - new_classes, positive_labels = shuffle_classes(new_classes, positive_classes) - new_samples = [(samples[i]["text"], new_classes[i], positive_labels[i]) for i in range(len(samples))] - texts = [sample_text for (sample_text, _, _) in new_samples] - tokenized_texts = self._tokenizer(texts, add_special_tokens=False).input_ids - tokenized_classes = [ - self._tokenizer(sample_classes, add_special_tokens=False).input_ids - for (_, sample_classes, _) in new_samples - ] - - new_samples = [(tokenized_texts[i], tokenized_classes[i], positive_labels[i]) for i in range(len(samples))] - - return self._val_test_collate_fn(new_samples) - - def _llm_negatives_train_collate_fn(self, samples) -> Tuple[torch.Tensor, ...]: - positive_classes = [sample["classes"] for sample in samples] - llm_negatives = [self._train_negatives["negatives"].loc[sample["idx"]] for sample in samples] - new_classes = generate_classes_llm_negatives(llm_negatives, positive_classes) - - new_classes, positive_labels = shuffle_classes(new_classes, positive_classes) - new_samples = [(samples[i]["text"], new_classes[i], positive_labels[i]) for i in range(len(samples))] - - texts = [sample_text for (sample_text, _, _) in new_samples] - tokenized_texts = self._tokenizer(texts, add_special_tokens=False).input_ids - tokenized_classes = [ - self._tokenizer(sample_classes, add_special_tokens=False).input_ids - for (_, sample_classes, _) in new_samples - ] - - new_samples = [(tokenized_texts[i], tokenized_classes[i], positive_labels[i]) for i in range(len(samples))] - - return self._val_test_collate_fn(new_samples) - - def _train_collate_fn(self, samples: list[tuple[ndarray, List[ndarray]]]) -> Tuple[torch.Tensor, ...]: - classes = [sample["classes"] for sample in samples] - new_classes, positive_labels = generate_classes(classes, self._config) - new_samples = [(samples[i]["text"], new_classes[i], positive_labels[i]) for i in range(len(samples))] - - texts = [sample_text for (sample_text, _, _) in new_samples] - tokenized_texts = self._tokenizer(texts, add_special_tokens=False).input_ids - tokenized_classes = [ - self._tokenizer(sample_classes, add_special_tokens=False).input_ids - for (_, sample_classes, _) in new_samples - ] - - new_samples = [(tokenized_texts[i], tokenized_classes[i], positive_labels[i]) for i in range(len(samples))] - - return self._val_test_collate_fn(new_samples) - - def _val_test_collate_fn(self, samples: list[tuple[ndarray, list[ndarray], list[int]]]) -> tuple[torch.Tensor, ...]: - result_prompts = [] - label_masks = [] - classes_count = [len(sample_classes) for (_, sample_classes, _) in samples] - positive_labels = [sample_positive_labels for (_, _, sample_positive_labels) in samples] - - for input_seq, sample_classes, sample_positive_labels in samples: - result_prompt, label_mask = make_classifier_prompt( - input_seq, self._special_token_ids, sample_classes, sample_positive_labels - ) - - result_prompts.append(result_prompt) - label_masks.append(label_mask) - - input_ids, attention_mask, classes_mask = prepare_batch( - result_prompts, - label_masks, - self._tokenizer.pad_token_id, - self._tokenizer.eos_token_id, - self._tokenizer.model_max_length if self._tokenizer.model_max_length else None, - ) - return input_ids, attention_mask, classes_mask, torch.tensor(classes_count), positive_labels - - @property - def tokenizer(self) -> Tokenizer: - """Return current tokenizer instance.""" - return self._tokenizer diff --git a/src/model/universal_classifier.py b/src/model/universal_classifier.py deleted file mode 100644 index 2e7387d..0000000 --- a/src/model/universal_classifier.py +++ /dev/null @@ -1,339 +0,0 @@ -import math -from functools import partial -from itertools import chain -from typing import Tuple - -import torch -import torch.nn as nn -import torch.optim -from loguru import logger -from pytorch_lightning import LightningModule -from pytorch_lightning.utilities.types import STEP_OUTPUT -from sklearn.metrics import accuracy_score, f1_score -from torch import Tensor -from torch.nn.functional import binary_cross_entropy_with_logits -from torchmetrics import MetricCollection -from torchmetrics.classification import BinaryAUROC, BinaryF1Score -from transformers import AutoModel - -warmup_steps = 1000 - - -def linear_lambda(current_step: int, total_steps: int): - if current_step < warmup_steps: - return current_step / warmup_steps - else: - progress = float(current_step - warmup_steps) / float(total_steps - warmup_steps) - return max(0.0, (1.0 - progress)) - - -def cosine_lambda(current_step: int, total_steps: int): - if current_step < warmup_steps: - return current_step / warmup_steps - else: - progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps)) - return 0.5 * (1 + math.cos(math.pi * progress)) - - -class ZeroShotClassifier(LightningModule): - """Lightning module that encapsulate all routines for zero-shot text classification. - - Maybe used as regular Torch module on inference: forward pass returns predicted classes. - Also support training via lightning Trainer, see: https://pytorch-lightning.readthedocs.io/en/stable/. - - Use HuggingFace models as backbone to embed tokens, e.g. "USER-base": - https://huggingface.co/deepvk/USER-base - Reports binary cross-entropy loss, binary F1-score and binary AUROC during training. - """ - - def __init__( - self, - embedder_name: str = "deepvk/USER-base", - *, - unfreeze_embedder: bool = False, - ffn_dim: int = 2048, - ffn_classes_dropout: float = 0.1, - ffn_text_dropout: float = 0.1, - device: str = "cuda", - tokenizer_len: int, - optimizer_args: dict = None, - scheduler_args: dict = None, - ): - """ - :param embedder_name: name of pretrained HuggingFace model to embed tokens. - :param unfreeze_embedder: if `True` ten train top classifier and backbone module. - :param ffn_dim: hidden dimension of mlp layer. - :param ffn_dropout: dropout of mlp layer. - :param device: name of device to train the model on. - :param tokenizer_len: Number of tokens in tokenizer. - :param optimizer_args: Dict with arguments to initalize optimizer. - :param scheduler_args: Dict with arguments to initalize scheduler. - """ - super().__init__() - self.save_hyperparameters() - - self._device = device - self._token_embedder = AutoModel.from_pretrained(embedder_name).to(self._device) - self._token_embedder.resize_token_embeddings(tokenizer_len) - - self._mlp_classes = nn.Sequential( - nn.Linear(self._token_embedder.config.hidden_size, ffn_dim), - nn.Dropout(ffn_classes_dropout), - nn.ReLU(inplace=True), - nn.Linear(ffn_dim, self._token_embedder.config.hidden_size), - ).to(self._device) - - self._mlp_text = nn.Sequential( - nn.Linear(self._token_embedder.config.hidden_size, ffn_dim), - nn.Dropout(ffn_text_dropout), - nn.ReLU(inplace=True), - nn.Linear(ffn_dim, self._token_embedder.config.hidden_size), - ).to(self._device) - - self._step_outputs = { - f"{split}_{metric}": [] - for split in ["val", "test", "train"] - for metric in ["loss", "predictions", "target"] - } - - if not unfreeze_embedder: - logger.info(f"Freezing embedding model: {self._token_embedder.__class__.__name__}") - for param in self._token_embedder.parameters(): - param.requires_grad = False - - self._auroc_metric = MetricCollection({f"{split}_auroc": BinaryAUROC() for split in ["train", "val", "test"]}) - - self._f1_metric = MetricCollection( - {f"{split}_binary_f1": BinaryF1Score(threshold=0.1) for split in ["train", "val", "test"]} - ) - - self._optimizer_args = optimizer_args if optimizer_args is not None else None - self._scheduler_args = scheduler_args if scheduler_args is not None else None - - def configure_optimizers(self): - """ - :param optimizer_cls: PyTorch optimizer class, e.g. `torch.optim.AdamW`. - :param scheduler_cls: PyTorch scheduler class, e.g. `torch.optim.lr_scheduler.LambdaLR`. - If `None`, then constant lr. - """ - parameters = chain( - self._token_embedder.parameters(), self._mlp_classes.parameters(), self._mlp_text.parameters() - ) - - if self._optimizer_args: - module_name, class_name = self._optimizer_args["class_path"].rsplit(".", 1) - optimizer_cls = getattr(__import__(module_name, fromlist=[class_name]), class_name) - if "init_params" in self._optimizer_args: - optimizer = optimizer_cls(parameters, **self._optimizer_args["init_params"]) - else: - optimizer = optimizer_cls(parameters) - else: - optimizer = torch.optim.AdamW(parameters) - - if self._scheduler_args is None: - return optimizer - - if self._scheduler_args["scheduler"] == "linear": - linear_lambda_with_total_steps = partial(linear_lambda, total_steps=self._scheduler_args["total_steps"]) - lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=linear_lambda_with_total_steps) - else: - cosine_lambda_with_total_steps = partial(cosine_lambda, total_steps=self._scheduler_args["total_steps"]) - lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=cosine_lambda_with_total_steps) - - # module_name, class_name = self._scheduler_args["class_path"].rsplit(".", 1) - # scheduler_cls = getattr(__import__(module_name, fromlist=[class_name]), class_name) - # lr_scheduler = scheduler_cls(optimizer, **self._scheduler_args["init_params"]) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": lr_scheduler, - "interval": "step", - }, - } - - def _get_text_embeddings( - self, - embeddings: torch.Tensor, - classes_mask: torch.Tensor, - ) -> torch.Tensor: - # -3 stands for input text tokens in classes_mask - mask = classes_mask == -3 - masked_values = embeddings * mask.unsqueeze(-1) - num_filtered = mask.sum(dim=1) - text_embeddings = masked_values.sum(dim=1) / num_filtered.unsqueeze(-1) - text_embeddings = self._mlp_text(text_embeddings) - return text_embeddings - - def _get_classes_embeddings( - self, embeddings: torch.Tensor, classes_mask: torch.Tensor, classes_count: torch.Tensor - ) -> tuple[torch.Tensor, ...]: - bs, _ = classes_mask.shape - - mlp_input = torch.empty( - (classes_count.sum(), embeddings.shape[-1]), dtype=torch.float, device=embeddings.device - ) - idx = 0 - for i in range(bs): - sample_classes_mask = classes_mask[i] - emb = embeddings[i] - - sample_class_embeddings = [] - for label in range(classes_count[i]): - positions = torch.nonzero(sample_classes_mask == label, as_tuple=True)[0] - start_idx = positions.min().item() - end_idx = positions.max().item() - - # +1 because we should include token on end_idx position - class_mean = emb[start_idx : end_idx + 1].mean(dim=0) - sample_class_embeddings.append(class_mean) - - sample_class_embeddings = torch.stack(sample_class_embeddings, dim=0) - mlp_input[idx : (idx + classes_count[i])] = sample_class_embeddings - idx += classes_count[i] - - # [all_classes_in_batch, embedding_dim] - classes_embeddings = self._mlp_classes(mlp_input.to(self._device)) - - return classes_embeddings - - def forward(self, input_ids: Tensor, attention_mask: Tensor, classes_mask: Tensor, classes_count: Tensor) -> Tensor: # type: ignore - """Forward pass of zero-shot classification model. - Could be used during inference to classify text. - - :param input_ids: [batch size; seq len] -- batch with pretokenized texts. - :param attention_mask: [batch size; seq len] -- attention mask with 0 for padding tokens. - :param classes_mask: [batch size; seq len] - labels of each token. - :return: [] -- . - """ - bs, seq_len = attention_mask.shape - # print(seq_len) - - embeddings = self._token_embedder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state - text_embeddings = self._get_text_embeddings(embeddings, classes_mask) - classes_embeddings = self._get_classes_embeddings(embeddings, classes_mask, classes_count) - - # [all_classes_in_batch, embedding_dim] - wide_text_embeddings = torch.repeat_interleave(text_embeddings, classes_count, dim=0) - - similarities = torch.sum(classes_embeddings * wide_text_embeddings, dim=-1) - - return similarities - - def shared_step(self, batch: Tuple[Tensor, ...], split: str) -> STEP_OUTPUT: - """Shared step of them that used during training and evaluation. - Make forward pass of the model, calculate loss and metric and log them. - - :param batch: Tuple of - > input_ids [batch size; seq len] – input tokens ids padded to the same length; - > attention_mask [batch size; seq len] – mask with padding description, 0 means PAD token; - > classes_mask [batch size; seq len] - labels of each token; - > classes_count [batch_size] - classes count for each sample. - > positive_classes [batch_size] - indices of positive classes for each sample. - :param split: name of current split, one of `train`, `val`, or `test`. - :return: loss on the current batch. - """ - - input_ids, attention_mask, classes_mask, classes_count, positive_classes = batch - bs = len(input_ids) - - forward_batch = input_ids, attention_mask, classes_mask, classes_count - - similarities = self.forward(*forward_batch) - - target = torch.zeros(similarities.shape[0]).to(self._device) - idx = 0 - for i, sample_class_count in enumerate(classes_count): - for positive_class in positive_classes[i]: - target[idx + positive_class] = 1 - idx += sample_class_count - - batch_loss = binary_cross_entropy_with_logits(similarities, target) - - with torch.no_grad(): - if split != "train": - self._step_outputs[f"{split}_loss"].append(batch_loss.item()) - - predicted_classes = [] - idx = 0 - for i in range(bs): - predicted_class_idx = similarities[idx : (idx + classes_count[i])].argmax() - predicted_classes.append(predicted_class_idx) - idx += classes_count[i] - - self._step_outputs[f"{split}_predictions"] += [pred_class.to("cpu") for pred_class in predicted_classes] - self._step_outputs[f"{split}_target"] += positive_classes - - probs = torch.sigmoid(similarities) - batch_auroc = self._auroc_metric[f"{split}_auroc"](probs, target) - batch_f1 = self._f1_metric[f"{split}_binary_f1"](probs, target) - - if split == "train": - self.log_dict( - { - f"{split}/step_loss": batch_loss.item(), - f"{split}/step_auroc": batch_auroc, - f"{split}/step_f1": batch_f1, - } - ) - return batch_loss - - def training_step(self, batch: Tuple[Tensor, ...], batch_idx: int) -> STEP_OUTPUT: # type: ignore - return self.shared_step(batch, "train") - - def _report_metrics(self, split: str, loss: list = None): - epoch_auroc = self._auroc_metric[f"{split}_auroc"].compute() - self._auroc_metric[f"{split}_auroc"].reset() - - epoch_binary_f1 = self._f1_metric[f"{split}_binary_f1"].compute() - self._f1_metric[f"{split}_binary_f1"].reset() - - y_true = self._step_outputs[f"{split}_target"] - y_pred = self._step_outputs[f"{split}_predictions"] - y_true_multiclass = [true_classes[0] for true_classes in y_true] - accuracy = accuracy_score(y_true_multiclass, y_pred) - self.log(f"{split}/epoch_accuracy", accuracy) - - # if split in {"val", "test"}: - # micro_f1 = f1_score(y_true, y_pred, average="micro") - # self.log(f"{split}/epoch_micro_f1", micro_f1) - - self.log_dict( - { - f"{split}/epoch_auroc": epoch_auroc, - f"{split}/epoch_binary_f1": epoch_binary_f1, - } - ) - - if loss: - epoch_loss = torch.tensor( - loss, - dtype=torch.float32, - device=self._device, - ).mean() - - self.log(f"{split}/epoch_loss", epoch_loss) - - def on_train_epoch_end(self): - self._report_metrics("train") - self._step_outputs["train_predictions"].clear() - self._step_outputs["train_target"].clear() - - def on_validation_epoch_end(self): - self._report_metrics("val", self._step_outputs["val_loss"]) - - self._step_outputs["val_loss"].clear() - self._step_outputs["val_predictions"].clear() - self._step_outputs["val_target"].clear() - - def on_test_epoch_end(self): - self._report_metrics("test", self._step_outputs["test_loss"]) - - self._step_outputs["test_loss"].clear() - self._step_outputs["test_predictions"].clear() - self._step_outputs["test_target"].clear() - - def validation_step(self, batch: Tuple[Tensor, ...], batch_idx: int) -> STEP_OUTPUT: # type: ignore - return self.shared_step(batch, "val") - - def test_step(self, batch: Tuple[Tensor, ...], batch_idx: int) -> STEP_OUTPUT: # type: ignore - return self.shared_step(batch, "test") diff --git a/src/utils.py b/src/utils.py deleted file mode 100644 index 07a6024..0000000 --- a/src/utils.py +++ /dev/null @@ -1,53 +0,0 @@ -import logging as py_logging - -from datasets.utils import logging as ds_logging -from tokenizers import Tokenizer -from transformers.utils import logging as tr_logging - - -def setup_logging(): - py_logging.basicConfig(level=py_logging.INFO) - tr_logging.set_verbosity_info() - tr_logging.disable_progress_bar() - ds_logging.set_verbosity_info() - ds_logging.disable_progress_bar() - - -def add_required_tokens(tokenizer: Tokenizer) -> Tokenizer: - required_token_types = ["bos_token", "eos_token", "cls_token", "sep_token", "pad_token"] - default_tokens = { - "bos_token": "", - "eos_token": "", - "cls_token": "[CLS]", - "sep_token": "[SEP]", - "pad_token": "[PAD]", - } - - # Ensure each required token type exists - tokens_to_add = {} - for token_type in required_token_types: - current_token = getattr(tokenizer, token_type, None) - default_token = default_tokens[token_type] - - if current_token is None: - current_token = default_token - - # Check if the token string is in the vocabulary - if current_token not in tokenizer.vocab: - tokens_to_add[token_type] = current_token - - # Add missing tokens to the tokenizer - if tokens_to_add: - tokenizer.add_special_tokens(tokens_to_add) - - # Check for bos/cls conflict - if tokenizer.bos_token_id == tokenizer.cls_token_id: - new_token = tokenizer.cls_token + "_1" - tokenizer.add_special_tokens({"cls_token": new_token}) - - # Check for eos/sep conflict - if tokenizer.eos_token_id == tokenizer.sep_token_id: - new_token = tokenizer.sep_token + "_1" - tokenizer.add_special_tokens({"sep_token": new_token}) - - return tokenizer diff --git a/tests/test_model.py b/tests/test_model.py deleted file mode 100644 index 039f9cc..0000000 --- a/tests/test_model.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_dummy(): - assert 2 + 2 == 4, "Wow, this really shouldn't have happened."