From 1477deac2d421c15b06bc948e9184c2c18edffa7 Mon Sep 17 00:00:00 2001
From: Xi Bai <82581439+baixiac@users.noreply.github.com>
Date: Wed, 17 Sep 2025 15:52:05 +0100
Subject: [PATCH 1/9] Revert "feat: add the 4-bit quantisation option and
remove unnecessary base model copying"
This reverts commit d03f4875d323d005b43d6e3e9e60abf1c93ff07d.
---
.github/workflows/main.yaml | 2 +-
app/cli/cli.py | 8 +---
app/model_services/base.py | 6 +--
app/model_services/huggingface_llm_model.py | 12 ++---
app/model_services/huggingface_ner_model.py | 9 +---
app/model_services/medcat_model.py | 9 +---
app/model_services/medcat_model_deid.py | 9 +---
app/model_services/trf_model_deid.py | 2 +-
app/trainers/huggingface_llm_trainer.py | 52 +++++++++------------
app/utils.py | 5 ++
10 files changed, 42 insertions(+), 72 deletions(-)
diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml
index 6c3a9dd..ffce1f8 100644
--- a/.github/workflows/main.yaml
+++ b/.github/workflows/main.yaml
@@ -24,7 +24,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
- uv sync --group dev --group docs
+ uv sync --group dev --group docs --group vllm
- name: Check types
run: |
uv run mypy app
diff --git a/app/cli/cli.py b/app/cli/cli.py
index 6003407..8a94647 100644
--- a/app/cli/cli.py
+++ b/app/cli/cli.py
@@ -67,7 +67,6 @@ def serve_model(
streamable: bool = typer.Option(False, help="Serve the streamable endpoints only"),
device: Device = typer.Option(Device.DEFAULT.value, help="The device to serve the model on"),
llm_engine: Optional[LlmEngine] = typer.Option(LlmEngine.CMS.value, help="The engine to use for text generation"),
- load_in_4bit: Optional[bool] = typer.Option(False, help="Load the model in 4-bit precision, used by 'huggingface_llm' models"),
debug: Optional[bool] = typer.Option(None, help="Run in the debug mode"),
) -> None:
"""
@@ -85,7 +84,6 @@ def serve_model(
streamable (bool): Serve the streamable endpoints only. Defaults to False.
device (Device): The device to serve the model on. Defaults to Device.DEFAULT.
llm_engine (LlmEngine): The inference engine to use. Defaults to LlmEngine.CMS.
- load_in_4bit (bool): Load the model in 4-bit precision, used by 'huggingface_llm' models. Defaults to False.
debug (Optional[bool]): Run in debug mode if set to True.
"""
@@ -137,7 +135,7 @@ def serve_model(
if model_path:
model_service = model_service_dep()
model_service.model_name = model_name
- model_service.init_model(load_in_4bit=load_in_4bit)
+ model_service.init_model()
cms_globals.model_manager_dep = ModelManagerDep(model_service)
elif mlflow_model_uri:
model_service = ModelManager.retrieve_model_service_from_uri(mlflow_model_uri, config, dst_model_path)
@@ -189,7 +187,6 @@ def train_model(
description: Optional[str] = typer.Option(None, help="The description of the training or change logs"),
model_name: Optional[str] = typer.Option(None, help="The string representation of the model name"),
device: Device = typer.Option(Device.DEFAULT.value, help="The device to train the model on"),
- load_in_4bit: Optional[bool] = typer.Option(False, help="Load the model in 4-bit precision, used by 'huggingface_llm' models"),
debug: Optional[bool] = typer.Option(None, help="Run in the debug mode"),
) -> None:
"""
@@ -209,7 +206,6 @@ def train_model(
description (Optional[str]): The optional description of the training or change logs.
model_name (Optional[str]): The optional string representation of the model name.
device (Device): The device to train the model on. Defaults to Device.DEFAULT.
- load_in_4bit (bool): Load the model in 4-bit precision, used by 'huggingface_llm' models. Defaults to False.
debug (Optional[bool]): Run in debug mode if set to True.
"""
@@ -233,7 +229,7 @@ def train_model(
pass
model_service = model_service_dep()
model_service.model_name = model_name if model_name is not None else "CMS model"
- model_service.init_model(load_in_4bit=load_in_4bit)
+ model_service.init_model()
elif mlflow_model_uri:
model_service = ModelManager.retrieve_model_service_from_uri(mlflow_model_uri, config, dst_model_path)
model_service.model_name = model_name if model_name is not None else "CMS model"
diff --git a/app/model_services/base.py b/app/model_services/base.py
index a7b6323..dfde491 100644
--- a/app/model_services/base.py
+++ b/app/model_services/base.py
@@ -154,14 +154,10 @@ def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]:
raise NotImplementedError
@abstractmethod
- def init_model(self, *args: Any, **kwargs: Any) -> None:
+ def init_model(self) -> None:
"""
Initialises the model and auxiliary resources.
- Args:
- *args (Any): Additional positional arguments to be passed to this method.
- **kwargs (Any): Additional keyword arguments to be passed to this method.
-
Raises:
NotImplementedError: If the method is not implemented by the subclass.
"""
diff --git a/app/model_services/huggingface_llm_model.py b/app/model_services/huggingface_llm_model.py
index a747739..98f9bea 100644
--- a/app/model_services/huggingface_llm_model.py
+++ b/app/model_services/huggingface_llm_model.py
@@ -174,14 +174,8 @@ def load_model(
else:
raise ConfigurationException(f"Model package archive format is not supported: {model_file_path}")
- def init_model(self, load_in_4bit: bool = False, *args: Any, **kwargs: Any) -> None:
- """Initialises the HuggingFace model and its tokenizer based on the configuration.
-
- Args:
- load_in_4bit (bool): Whether to load the model in 4-bit precision. Defaults to False.
- *args (Any): Additional positional arguments to be passed to this method.
- **kwargs (Any): Additional keyword arguments to be passed to this method.
- """
+ def init_model(self) -> None:
+ """Initialises the HuggingFace model and its tokenizer based on the configuration."""
if all([
hasattr(self, "_model"),
@@ -191,7 +185,7 @@ def init_model(self, load_in_4bit: bool = False, *args: Any, **kwargs: Any) -> N
]):
logger.warning("Model service is already initialised and can be initialised only once")
else:
- self._model, self._tokenizer = self.load_model(self._model_pack_path, load_in_4bit=load_in_4bit)
+ self._model, self._tokenizer = self.load_model(self._model_pack_path)
if non_default_device_is_available(get_settings().DEVICE):
self._model.to(get_settings().DEVICE)
if self._enable_trainer:
diff --git a/app/model_services/huggingface_ner_model.py b/app/model_services/huggingface_ner_model.py
index e741705..d982836 100644
--- a/app/model_services/huggingface_ner_model.py
+++ b/app/model_services/huggingface_ner_model.py
@@ -175,13 +175,8 @@ def load_model(model_file_path: str, *args: Tuple, **kwargs: Dict[str, Any]) ->
else:
raise ConfigurationException(f"Model package archive format is not supported: {model_file_path}")
- def init_model(self, *args: Any, **kwargs: Any) -> None:
- """Initialises the HuggingFace model, its tokenizer and a NER pipeline based on the configuration.
-
- Args:
- *args (Any): Additional positional arguments to be passed to this method.
- **kwargs (Any): Additional keyword arguments to be passed to this method.
- """
+ def init_model(self) -> None:
+ """Initialises the HuggingFace model, its tokenizer and a NER pipeline based on the configuration."""
if all([
hasattr(self, "_model"),
diff --git a/app/model_services/medcat_model.py b/app/model_services/medcat_model.py
index 9ab5235..a3a6f2c 100644
--- a/app/model_services/medcat_model.py
+++ b/app/model_services/medcat_model.py
@@ -119,13 +119,8 @@ def load_model(model_file_path: str, *args: Tuple, **kwargs: Dict[str, Any]) ->
else:
raise ConfigurationException("Model package archive format is not supported")
- def init_model(self, *args: Any, **kwargs: Any) -> None:
- """Initializes the MedCAT model based on the configuration.
-
- Args:
- *args (Any): Additional positional arguments to be passed to this method.
- **kwargs (Any): Additional keyword arguments to be passed to this method.
- """
+ def init_model(self) -> None:
+ """Initializes the MedCAT model based on the configuration."""
if hasattr(self, "_model") and isinstance(self._model, CAT):
logger.warning("Model service is already initialised and can be initialised only once")
diff --git a/app/model_services/medcat_model_deid.py b/app/model_services/medcat_model_deid.py
index 9ec1248..fe94dde 100644
--- a/app/model_services/medcat_model_deid.py
+++ b/app/model_services/medcat_model_deid.py
@@ -178,13 +178,8 @@ def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]:
return annotations_list
- def init_model(self, *args: Any, **kwargs: Any) -> None:
- """Initializes the MedCAT De-Identification (AnonCAT) model based on the configuration.
-
- Args:
- *args (Any): Additional positional arguments to be passed to this method.
- **kwargs (Any): Additional keyword arguments to be passed to this method.
- """
+ def init_model(self) -> None:
+ """Initializes the MedCAT De-Identification (AnonCAT) model based on the configuration."""
if hasattr(self, "_model") and isinstance(self._model, CAT):
logger.warning("Model service is already initialised and can be initialised only once")
diff --git a/app/model_services/trf_model_deid.py b/app/model_services/trf_model_deid.py
index fbf6290..fb8e3ac 100644
--- a/app/model_services/trf_model_deid.py
+++ b/app/model_services/trf_model_deid.py
@@ -86,7 +86,7 @@ def load_model(
logger.info("Model loaded from %s", unpacked_model_dir)
return tokenizer, model
- def init_model(self, *args: Any, **kwargs: Any) -> None:
+ def init_model(self) -> None:
if hasattr(self, "_model") and isinstance(self._model, PreTrainedModel):
logger.warning("Model service is already initialised and can be initialised only once")
else:
diff --git a/app/trainers/huggingface_llm_trainer.py b/app/trainers/huggingface_llm_trainer.py
index b85f44b..895a58e 100644
--- a/app/trainers/huggingface_llm_trainer.py
+++ b/app/trainers/huggingface_llm_trainer.py
@@ -88,11 +88,8 @@ def __init__(self, model_service: "HuggingFaceLlmModel") -> None:
self._model_service = model_service
self._model_name = model_service.model_name
self._model_pack_path = model_service._model_pack_path
- self._retrained_models_dir = os.path.join(
- model_service._model_parent_dir,
- "retrained",
- self._model_name.replace(" ", "_"),
- )
+ self._retrained_models_dir = os.path.join(model_service._model_parent_dir, "retrained",
+ self._model_name.replace(" ", "_"))
self._model_manager = ModelManager(type(model_service), model_service._config)
self._max_length = model_service.model.config.max_position_embeddings
os.makedirs(self._retrained_models_dir, exist_ok=True)
@@ -309,7 +306,7 @@ def run(
logger.error("Cannot import the GRPO Trainer. Please install it with `pip install cms[vllm]`.")
raise ExtraDependencyRequiredException("Cannot import the GRPO Trainer. Please install it with `pip install cms[vllm]`.")
- trained_model_pack_path = None
+ copied_model_pack_path = None
redeploy = self._config.REDEPLOY_TRAINED_MODEL == "true"
skip_save_model = self._config.SKIP_SAVE_MODEL == "true"
results_path = os.path.abspath(os.path.join(self._config.TRAINING_CACHE_DIR, "results"))
@@ -322,16 +319,15 @@ def run(
if not eval_mode:
try:
- logger.info("Loading a PEFT model for training...")
- model_pack_file_ext = get_model_data_package_extension(self._model_pack_path)
- trained_model_pack_path = self._model_pack_path.replace(
- model_pack_file_ext,
- f"_trained_{run_id}{model_pack_file_ext}",
+ logger.info("Loading a new model copy for training...")
+ copied_model_pack_path = self._make_model_file_copy(self._model_pack_path, run_id)
+ model, tokenizer = self._model_service.load_model(
+ copied_model_pack_path,
+ load_in_4bit=True, # for memory efficient training
)
- model, tokenizer = self._model_service.model, self._model_service.tokenizer
- trained_model_directory = os.path.join(
- os.path.dirname(trained_model_pack_path),
- get_model_data_package_base_name(trained_model_pack_path),
+ copied_model_directory = os.path.join(
+ os.path.dirname(copied_model_pack_path),
+ get_model_data_package_base_name(copied_model_pack_path),
)
if non_default_device_is_available(self._config.DEVICE):
@@ -359,7 +355,7 @@ def run(
],
)
- peft_model = get_peft_model(model, lora_config)
+ model = get_peft_model(model, lora_config)
mlflow_logging_callback = MLflowLoggingCallback(self._tracker_client)
cancel_event_check_callback = CancelEventCheckCallback(self._cancel_event)
@@ -382,7 +378,6 @@ def run(
training_args = GRPOConfig(
output_dir=results_path,
logging_dir=logs_path,
- logging_steps=log_frequency,
learning_rate=5e-6,
adam_beta1=0.9,
adam_beta2=0.99,
@@ -390,18 +385,20 @@ def run(
warmup_ratio=0.1,
lr_scheduler_type="cosine",
optim="paged_adamw_8bit",
+ logging_steps=1,
per_device_train_batch_size=6, # This global batch size must be divisible by the number of generations
gradient_accumulation_steps=1,
num_generations=6,
max_prompt_length=max_prompt_length,
max_completion_length=max_seq_length - max_prompt_length,
num_train_epochs = training_params["nepochs"],
+ max_steps=250,
save_steps=250,
max_grad_norm=0.1,
report_to="none",
)
trainer = GRPOTrainer(
- model=peft_model,
+ model=model,
processing_class=tokenizer,
reward_funcs=self._get_reward_functions(),
args=training_args,
@@ -412,7 +409,7 @@ def run(
else:
raise ConfigurationException(f"Unsupported trainer type: {trainer_type}")
- self._tracker_client.log_model_config({**model.config.to_dict(), **peft_model.peft_config})
+ self._tracker_client.log_model_config(model.config.to_dict())
self._tracker_client.log_trainer_version(TrainerBackend.TRANSFORMERS, transformers_version)
logger.info(f"Performing {trainer_type.upper()} training...")
@@ -425,13 +422,11 @@ def run(
model_pack_file_ext = get_model_data_package_extension(self._config.BASE_MODEL_FILE)
model_pack_file_name = f"{ModelType.HUGGINGFACE_LLM.value}_{run_id}{model_pack_file_ext}"
retrained_model_pack_path = os.path.join(self._retrained_models_dir, model_pack_file_name)
- model = peft_model.merge_and_unload()
model.save_pretrained(
- trained_model_directory,
+ copied_model_directory,
safe_serialization=(self._config.TRAINING_SAFE_MODEL_SERIALISATION == "true"),
)
- tokenizer.save_pretrained(trained_model_directory)
- create_model_data_package(trained_model_directory, retrained_model_pack_path)
+ create_model_data_package(copied_model_directory, retrained_model_pack_path)
model_uri = self._tracker_client.save_model(
retrained_model_pack_path,
self._model_name,
@@ -480,7 +475,7 @@ def run(
with self._training_lock:
self._training_in_progress = False
self._clean_up_training_cache()
- self._housekeep_file(trained_model_pack_path)
+ self._housekeep_file(copied_model_pack_path)
if trainer is not None:
del trainer
gc.collect()
@@ -510,7 +505,6 @@ def run(
training_args = GRPOConfig(
output_dir=results_path,
logging_dir=logs_path,
- logging_steps=log_frequency,
per_device_eval_batch_size=6,
num_generations=2,
max_prompt_length=max_prompt_length,
@@ -613,19 +607,19 @@ def correctness_reward_func(
)
return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
- def int_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]:
+ def int_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]:
responses = [completion[0]["content"] for completion in completions]
extracted_responses = [extract_xml_answer(r) for r in responses]
return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
- def strict_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]:
+ def strict_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]:
"""Reward function that checks if the completion has a specific format."""
pattern = r"^\n.*?\n\n\n.*?\n\n$"
responses = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, r) for r in responses]
return [0.5 if match else 0.0 for match in matches]
- def soft_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]:
+ def soft_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]:
"""Reward function that checks if the completion has a specific format."""
pattern = r".*?\s*.*?"
responses = [completion[0]["content"] for completion in completions]
@@ -646,7 +640,7 @@ def count_xml(text: str) -> float:
count -= (len(text.split("\n")[-1]) - 1) * 0.001
return count
- def xmlcount_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]:
+ def xmlcount_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]:
contents = [completion[0]["content"] for completion in completions]
return [count_xml(c) for c in contents]
diff --git a/app/utils.py b/app/utils.py
index 47faefa..370c4e2 100644
--- a/app/utils.py
+++ b/app/utils.py
@@ -547,6 +547,11 @@ def unpack_model_data_package(model_data_file_path: str, model_data_folder_path:
elif model_data_file_path.endswith(".tar.gz"):
with tarfile.open(model_data_file_path, "r:gz") as f:
for member in f.getmembers():
+ path_parts = member.name.split(os.sep)
+ stripped_path = os.sep.join(path_parts[1:])
+ if not stripped_path:
+ continue
+ member.name = stripped_path
f.extract(member, path=model_data_folder_path)
return True
else:
From e4966c223c606a2159ac6503927606dbcc9d2f8f Mon Sep 17 00:00:00 2001
From: Xi Bai <82581439+baixiac@users.noreply.github.com>
Date: Wed, 17 Sep 2025 15:52:05 +0100
Subject: [PATCH 2/9] Revert "feat: add the option to include rewards metrics"
This reverts commit d1ff2fb180178e16e73bf7be06bcbe8000d4e44d.
---
app/trainers/huggingface_llm_trainer.py | 91 +------------------------
1 file changed, 3 insertions(+), 88 deletions(-)
diff --git a/app/trainers/huggingface_llm_trainer.py b/app/trainers/huggingface_llm_trainer.py
index 895a58e..b41bdd1 100644
--- a/app/trainers/huggingface_llm_trainer.py
+++ b/app/trainers/huggingface_llm_trainer.py
@@ -7,7 +7,6 @@
import re
import threading
import json
-import inspect
import pandas as pd
from typing import final, Dict, TextIO, Optional, Any, List, Tuple, TYPE_CHECKING, Callable
from transformers import __version__ as transformers_version
@@ -483,7 +482,6 @@ def run(
else:
try:
logger.info("Evaluating the running model...")
- include_rewards_metrics = training_params.get("include_rewards_metrics", False)
model, tokenizer = self._model_service.model, self._model_service.tokenizer
if non_default_device_is_available(self._config.DEVICE):
model.to(self._config.DEVICE)
@@ -530,23 +528,8 @@ def run(
)
eval_metrics = trainer.evaluate()
- if "perplexity" not in eval_metrics and "eval_loss" in eval_metrics:
- eval_metrics.update({"perplexity": math.exp(eval_metrics["eval_loss"])})
logger.info(f"Evaluation metrics: {eval_metrics}")
self._tracker_client.send_hf_metrics_logs(eval_metrics, 0)
- if include_rewards_metrics:
- try:
- reward_metrics = self._evaluate_with_rewards(
- model=model,
- tokenizer=tokenizer,
- eval_dataset=eval_dataset,
- max_new_tokens=training_args.max_completion_length,
- )
- if reward_metrics:
- logger.info(f"Reward metrics: {reward_metrics}")
- self._tracker_client.send_hf_metrics_logs(reward_metrics, 0)
- except Exception as e:
- logger.warning(f"Failed to compute reward-based metrics: {e}")
self._tracker_client.end_with_success()
logger.info("Model evaluation finished")
except torch.OutOfMemoryError as e:
@@ -594,8 +577,8 @@ def correctness_reward_func(
answer: List,
**kwargs: Dict[str, Any]
) -> List[float]:
- responses = [completion[0]["content"] for completion in completions]
- q = prompts[0][-1]["content"]
+ responses = [completion[0]['content'] for completion in completions]
+ q = prompts[0][-1]['content']
extracted_responses = [extract_xml_answer(r) for r in responses]
logger.debug(
"%s\nQuestion:\n%s\nAnswer:\n%s\nResponse:\n%s\nExtracted:\n%s",
@@ -608,7 +591,7 @@ def correctness_reward_func(
return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
def int_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]:
- responses = [completion[0]["content"] for completion in completions]
+ responses = [completion[0]['content'] for completion in completions]
extracted_responses = [extract_xml_answer(r) for r in responses]
return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
@@ -652,74 +635,6 @@ def xmlcount_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> l
correctness_reward_func,
]
- def _evaluate_with_rewards(
- self,
- model: PreTrainedModel,
- tokenizer: PreTrainedTokenizerBase,
- eval_dataset: datasets.Dataset,
- max_new_tokens: int,
- ) -> Dict[str, float]:
- model.eval()
- if non_default_device_is_available(self._config.DEVICE):
- model.to(self._config.DEVICE)
-
- reward_funcs = self._get_reward_functions()
- reward_sums: Dict[str, float] = {fn.__name__: 0.0 for fn in reward_funcs}
- count = 0
-
- for example in eval_dataset:
- if "prompt" not in example:
- continue
- messages = example["prompt"]
- answer = example.get("answer", "")
-
- prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
- inputs = tokenizer(prompt_text, return_tensors="pt")
- input_ids = inputs["input_ids"]
- attention_mask = inputs.get("attention_mask")
- if non_default_device_is_available(self._config.DEVICE):
- input_ids = input_ids.to(self._config.DEVICE)
- attention_mask = attention_mask.to(self._config.DEVICE)
-
- with torch.no_grad():
- generated = model.generate(
- input_ids=input_ids,
- attention_mask=attention_mask,
- max_new_tokens=max_new_tokens,
- do_sample=False,
- temperature=0.0,
- eos_token_id=getattr(tokenizer, "eos_token_id", None),
- pad_token_id=getattr(tokenizer, "pad_token_id", 0),
- )
-
- completion_text = tokenizer.decode(generated[0][input_ids.shape[1]:], skip_special_tokens=True)
- for fn in reward_funcs:
- sig = inspect.signature(fn)
- kwargs: Dict[str, Any] = {}
- if "prompts" in sig.parameters:
- kwargs["prompts"] = [messages]
- if "completions" in sig.parameters:
- kwargs["completions"] = [({"content": completion_text},)]
- if "answer" in sig.parameters:
- kwargs["answer"] = [answer]
-
- try:
- rewards = fn(**kwargs) # type: ignore
- value = float(rewards[0]) if isinstance(rewards, (list, tuple)) and rewards else float(rewards)
- except Exception:
- value = 0.0
-
- reward_sums[fn.__name__] += value
- count += 1
- if count == 0:
- return {}
-
- reward_avgs = {f"reward_{name}": total / count for name, total in reward_sums.items()}
- reward_overall_mean = sum(reward_avgs.values()) / len(reward_avgs) if reward_avgs else 0.0
- reward_avgs["reward_overall_mean"] = reward_overall_mean
- reward_avgs["reward_samples"] = float(count)
- return reward_avgs
-
@final
class MLflowLoggingCallback(TrainerCallback):
From ed1e0fa166339254996b880250006b2e5d25e6b9 Mon Sep 17 00:00:00 2001
From: Xi Bai <82581439+baixiac@users.noreply.github.com>
Date: Wed, 17 Sep 2025 15:52:05 +0100
Subject: [PATCH 3/9] Revert "fix: use the GRPO trainer for evaluation"
This reverts commit 5075fa3cd1f2cc32ab4dbbb3ce2ff22e4b1d056e.
---
app/api/utils.py | 30 +--
app/exception.py | 4 -
app/model_services/huggingface_llm_model.py | 3 +-
app/trainers/huggingface_llm_trainer.py | 261 +++++++++-----------
4 files changed, 116 insertions(+), 182 deletions(-)
diff --git a/app/api/utils.py b/app/api/utils.py
index 87cea26..e73425d 100644
--- a/app/api/utils.py
+++ b/app/api/utils.py
@@ -27,13 +27,7 @@
from fastapi_users.jwt import decode_jwt
from app.config import Settings
from app.domain import TagsGenerative
-from app.exception import (
- StartTrainingException,
- AnnotationException,
- ConfigurationException,
- ClientException,
- ExtraDependencyRequiredException,
-)
+from app.exception import StartTrainingException, AnnotationException, ConfigurationException, ClientException
logger = logging.getLogger("cms")
@@ -124,24 +118,6 @@ async def configuration_exception_handler(_: Request, exception: ConfigurationEx
logger.exception(exception)
return JSONResponse(status_code=HTTP_500_INTERNAL_SERVER_ERROR, content={"message": str(exception)})
- @app.exception_handler(ExtraDependencyRequiredException)
- async def extra_dependency_exception_handler(
- _: Request,
- exception: ExtraDependencyRequiredException
- ) -> JSONResponse:
- """
- Handles extra dependency required exceptions.
-
- Args:
- _ (Request): The request object.
- exception (ExtraDependencyRequiredException): The extra dependency required exception.
-
- Returns:
- JSONResponse: A JSON response with a 500 status code and an error message.
- """
- logger.exception(exception)
- return JSONResponse(status_code=HTTP_500_INTERNAL_SERVER_ERROR, content={"message": str(exception)})
-
@app.exception_handler(ClientException)
async def client_exception_handler(_: Request, exception: ClientException) -> JSONResponse:
"""
@@ -323,8 +299,8 @@ async def init_vllm_engine(app: FastAPI,
)
from vllm import SamplingParams, TokensPrompt
except ImportError:
- logger.error("Cannot import the vLLM engine. Please install it with `pip install cms[vllm]`.")
- raise ExtraDependencyRequiredException("Cannot import the vLLM engine. Please install it with `pip install cms[vllm]`.")
+ # Raise a custom exception if vLLM is not installed
+ raise ConfigurationException("Cannot import the vLLM engine. Please install it with `pip install vllm`.")
parser = FlexibleArgumentParser()
parser = make_arg_parser(parser)
diff --git a/app/exception.py b/app/exception.py
index ddba71b..99b87a6 100644
--- a/app/exception.py
+++ b/app/exception.py
@@ -32,7 +32,3 @@ class DatasetException(Exception):
class DeviceNotAvailableError(RuntimeError):
"""An exception raised when a specificy device is required but not available."""
-
-
-class ExtraDependencyRequiredException(Exception):
- """An exception raised when an extra dependency is required but not found."""
diff --git a/app/model_services/huggingface_llm_model.py b/app/model_services/huggingface_llm_model.py
index 98f9bea..13838d7 100644
--- a/app/model_services/huggingface_llm_model.py
+++ b/app/model_services/huggingface_llm_model.py
@@ -369,9 +369,8 @@ def create_embeddings(
sum_hidden_states = masked_hidden_states.sum(dim=1)
num_tokens = attention_mask.sum(dim=1, keepdim=True)
embeddings = sum_hidden_states / num_tokens
- l2_normalised = torch.nn.functional.normalize(embeddings, p=2, dim=1)
- results = l2_normalised.cpu().numpy().tolist()
+ results = embeddings.cpu().numpy().tolist()
return results[0] if isinstance(text, str) else results
def train_supervised(
diff --git a/app/trainers/huggingface_llm_trainer.py b/app/trainers/huggingface_llm_trainer.py
index b41bdd1..49082d5 100644
--- a/app/trainers/huggingface_llm_trainer.py
+++ b/app/trainers/huggingface_llm_trainer.py
@@ -18,6 +18,7 @@
TrainerCallback,
TrainerState,
TrainerControl,
+ Trainer
)
from peft import LoraConfig, get_peft_model # type: ignore
from app.management.model_manager import ModelManager
@@ -40,7 +41,6 @@
DatasetException,
ConfigurationException,
DeviceNotAvailableError,
- ExtraDependencyRequiredException,
)
if TYPE_CHECKING:
from app.model_services.huggingface_llm_model import HuggingFaceLlmModel
@@ -93,6 +93,25 @@ def __init__(self, model_service: "HuggingFaceLlmModel") -> None:
self._max_length = model_service.model.config.max_position_embeddings
os.makedirs(self._retrained_models_dir, exist_ok=True)
+ class _LocalDataCollator:
+
+ def __init__(self, max_length: int, pad_token_id: int) -> None:
+ self.max_length = max_length
+ self.pad_token_id = pad_token_id
+
+ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
+ return {
+ "input_ids": torch.tensor([self._add_padding(f["input_ids"], self.max_length, self.pad_token_id) for f in features], dtype=torch.long),
+ "labels": torch.tensor([self._add_padding(f["labels"], self.max_length, HuggingFaceLlmSupervisedTrainer.PAD_LABEL_ID) for f in features], dtype=torch.long),
+ "attention_mask": torch.tensor([self._add_padding(f["attention_mask"], self.max_length, 0) for f in features], dtype=torch.long),
+ }
+
+ @staticmethod
+ def _add_padding(target: List[int], max_length: int, pad_token_id: int) -> List[int]:
+ padding_length = max(0, max_length - len(target))
+ paddings = [pad_token_id] * padding_length
+ return target + paddings
+
def _load_dataset_from_config(self, data_file: TextIO, training_params: Dict) -> Tuple[datasets.Dataset, datasets.Dataset]:
"""
Loads training and validation datasets based on configuration in training_params.
@@ -137,6 +156,8 @@ def _set_dataset_format(train_dataset: datasets.Dataset, test_dataset: datasets.
else:
raise DatasetException("Unsupported dataset format")
+
+
def _load_huggingface_dataset(self, training_params: Dict) -> Tuple[datasets.Dataset, datasets.Dataset]:
"""Loads dataset from HuggingFace Hub."""
@@ -299,12 +320,6 @@ def run(
if self._config.DEVICE is not Device.GPU.value:
raise DeviceNotAvailableError("This trainer currently requires a CUDA device")
- try:
- from trl import GRPOConfig, GRPOTrainer # , PPOConfig, PPOTrainer
- except ImportError:
- logger.error("Cannot import the GRPO Trainer. Please install it with `pip install cms[vllm]`.")
- raise ExtraDependencyRequiredException("Cannot import the GRPO Trainer. Please install it with `pip install cms[vllm]`.")
-
copied_model_pack_path = None
redeploy = self._config.REDEPLOY_TRAINED_MODEL == "true"
skip_save_model = self._config.SKIP_SAVE_MODEL == "true"
@@ -313,9 +328,6 @@ def run(
reset_random_seed()
eval_mode = training_params["nepochs"] == 0
self._tracker_client.log_trainer_mode(not eval_mode)
- trainer = None
- max_seq_length = 1024
-
if not eval_mode:
try:
logger.info("Loading a new model copy for training...")
@@ -356,27 +368,79 @@ def run(
model = get_peft_model(model, lora_config)
+ def extract_xml_answer(text: str) -> str:
+ answer = text.split("")[-1]
+ answer = answer.split("")[0]
+ return answer.strip()
+
+ # Reward functions
+ def correctness_reward_func(
+ prompts: List,
+ completions: List,
+ answer: List,
+ **kwargs: Dict[str, Any]
+ ) -> List[float]:
+ responses = [completion[0]['content'] for completion in completions]
+ q = prompts[0][-1]['content']
+ extracted_responses = [extract_xml_answer(r) for r in responses]
+ print('-' * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}",
+ f"\nExtracted:\n{extracted_responses[0]}")
+ return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+
+ def int_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]:
+ responses = [completion[0]['content'] for completion in completions]
+ extracted_responses = [extract_xml_answer(r) for r in responses]
+ return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
+
+ def strict_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]:
+ """Reward function that checks if the completion has a specific format."""
+ pattern = r"^\n.*?\n\n\n.*?\n\n$"
+ responses = [completion[0]["content"] for completion in completions]
+ matches = [re.match(pattern, r) for r in responses]
+ return [0.5 if match else 0.0 for match in matches]
+
+ def soft_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]:
+ """Reward function that checks if the completion has a specific format."""
+ pattern = r".*?\s*.*?"
+ responses = [completion[0]["content"] for completion in completions]
+ matches = [re.match(pattern, r) for r in responses]
+ return [0.5 if match else 0.0 for match in matches]
+
+ def count_xml(text: str) -> float:
+ count = 0.0
+ if text.count("\n") == 1:
+ count += 0.125
+ if text.count("\n\n") == 1:
+ count += 0.125
+ if text.count("\n\n") == 1:
+ count += 0.125
+ count -= len(text.split("\n\n")[-1]) * 0.001
+ if text.count("\n") == 1:
+ count += 0.125
+ count -= (len(text.split("\n")[-1]) - 1) * 0.001
+ return count
+
+ def xmlcount_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]:
+ contents = [completion[0]["content"] for completion in completions]
+ return [count_xml(c) for c in contents]
+
mlflow_logging_callback = MLflowLoggingCallback(self._tracker_client)
cancel_event_check_callback = CancelEventCheckCallback(self._cancel_event)
trainer_callbacks = [mlflow_logging_callback, cancel_event_check_callback]
+ max_prompt_length = 256
+ max_seq_length = 1024
+
+ try:
+ from trl import GRPOConfig, GRPOTrainer #, PPOConfig, PPOTrainer
+ except ImportError:
+ logger.error("Cannot import the GRPO Trainer. Please install it with `pip install cms[vllm]`.")
+
trainer_type = training_params.get("trainer_type", LlmTrainerType.GRPO.value).lower()
- max_prompt_length = max(train_dataset.map(
- lambda x: {
- "tokens": tokenizer.apply_chat_template(
- x["prompt"],
- add_generation_prompt=True,
- tokenize=True
- )
- },
- batched=True,
- ).map(lambda x: {"length": len(x["tokens"])})["length"]) + 1
if trainer_type == LlmTrainerType.PPO.value:
raise NotImplementedError("PPO training is not yet supported for HuggingFace LLM models")
elif trainer_type == LlmTrainerType.GRPO.value:
training_args = GRPOConfig(
- output_dir=results_path,
- logging_dir=logs_path,
learning_rate=5e-6,
adam_beta1=0.9,
adam_beta2=0.99,
@@ -395,11 +459,18 @@ def run(
save_steps=250,
max_grad_norm=0.1,
report_to="none",
+ output_dir="outputs",
)
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
- reward_funcs=self._get_reward_functions(),
+ reward_funcs=[
+ xmlcount_reward_func,
+ soft_format_reward_func,
+ strict_format_reward_func,
+ int_reward_func,
+ correctness_reward_func,
+ ],
args=training_args,
train_dataset=train_dataset,
eval_dataset=test_dataset,
@@ -475,56 +546,41 @@ def run(
self._training_in_progress = False
self._clean_up_training_cache()
self._housekeep_file(copied_model_pack_path)
- if trainer is not None:
- del trainer
- gc.collect()
- torch.cuda.empty_cache()
+ del trainer
+ gc.collect()
+ torch.cuda.empty_cache()
else:
try:
logger.info("Evaluating the running model...")
- model, tokenizer = self._model_service.model, self._model_service.tokenizer
+ model, tokenizer = self._model_service.load_model(self._model_pack_path)
if non_default_device_is_available(self._config.DEVICE):
model.to(self._config.DEVICE)
eval_dataset, _ = self._load_dataset_from_config(data_file, training_params)
make_conversation = self._create_conversation_formatter(training_params)
eval_dataset = eval_dataset.map(make_conversation)
- max_prompt_length = max(eval_dataset.map(
- lambda x: {
- "tokens": tokenizer.apply_chat_template(
- x["prompt"],
- add_generation_prompt=True,
- tokenize=True
- )
- },
- batched=True,
- ).map(lambda x: {"length": len(x["tokens"])})["length"]) + 1
-
- training_args = GRPOConfig(
+
+ data_collator = self._LocalDataCollator(
+ max_length=self._max_length,
+ pad_token_id=tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else 0
+ )
+
+ training_args = TrainingArguments(
output_dir=results_path,
logging_dir=logs_path,
- per_device_eval_batch_size=6,
- num_generations=2,
- max_prompt_length=max_prompt_length,
- max_completion_length=max_seq_length - max_prompt_length,
- num_train_epochs=training_params["nepochs"],
- report_to="none",
+ per_device_eval_batch_size=1,
do_train=False,
do_eval=True,
+ report_to="none",
+ dataloader_drop_last=False,
)
- mlflow_logging_callback = MLflowLoggingCallback(self._tracker_client)
- cancel_event_check_callback = CancelEventCheckCallback(self._cancel_event)
- trainer_callbacks = [mlflow_logging_callback, cancel_event_check_callback]
-
- trainer = GRPOTrainer(
+ trainer = Trainer(
model=model,
- processing_class=tokenizer,
args=training_args,
- reward_funcs=self._get_reward_functions(),
- train_dataset=None,
+ data_collator=data_collator,
eval_dataset=eval_dataset,
- callbacks=trainer_callbacks,
+ tokenizer=tokenizer,
)
eval_metrics = trainer.evaluate()
@@ -532,24 +588,8 @@ def run(
self._tracker_client.send_hf_metrics_logs(eval_metrics, 0)
self._tracker_client.end_with_success()
logger.info("Model evaluation finished")
- except torch.OutOfMemoryError as e:
- logger.exception("Evaluation failed on CUDA OOM")
- try:
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- torch.cuda.ipc_collect()
- try:
- torch.cuda.reset_peak_memory_stats()
- torch.cuda.reset_accumulated_memory_stats()
- except Exception:
- pass
- torch.cuda.synchronize()
- except Exception:
- pass
- self._tracker_client.log_exceptions(e)
- self._tracker_client.end_with_failure()
except Exception as e:
- logger.exception("Evaluation failed")
+ logger.exception("Model evaluation failed")
self._tracker_client.log_exceptions(e)
self._tracker_client.end_with_failure()
finally:
@@ -557,83 +597,6 @@ def run(
with self._training_lock:
self._training_in_progress = False
self._clean_up_training_cache()
- if trainer is not None:
- del trainer
- gc.collect()
- torch.cuda.empty_cache()
-
- @staticmethod
- def _get_reward_functions() -> List:
-
- def extract_xml_answer(text: str) -> str:
- answer = text.split("")[-1]
- answer = answer.split("")[0]
- return answer.strip()
-
- # Reward functions
- def correctness_reward_func(
- prompts: List,
- completions: List,
- answer: List,
- **kwargs: Dict[str, Any]
- ) -> List[float]:
- responses = [completion[0]['content'] for completion in completions]
- q = prompts[0][-1]['content']
- extracted_responses = [extract_xml_answer(r) for r in responses]
- logger.debug(
- "%s\nQuestion:\n%s\nAnswer:\n%s\nResponse:\n%s\nExtracted:\n%s",
- "-" * 20,
- q,
- answer[0],
- responses[0],
- extracted_responses[0]
- )
- return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
-
- def int_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]:
- responses = [completion[0]['content'] for completion in completions]
- extracted_responses = [extract_xml_answer(r) for r in responses]
- return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
-
- def strict_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]:
- """Reward function that checks if the completion has a specific format."""
- pattern = r"^\n.*?\n\n\n.*?\n\n$"
- responses = [completion[0]["content"] for completion in completions]
- matches = [re.match(pattern, r) for r in responses]
- return [0.5 if match else 0.0 for match in matches]
-
- def soft_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]:
- """Reward function that checks if the completion has a specific format."""
- pattern = r".*?\s*.*?"
- responses = [completion[0]["content"] for completion in completions]
- matches = [re.match(pattern, r) for r in responses]
- return [0.5 if match else 0.0 for match in matches]
-
- def count_xml(text: str) -> float:
- count = 0.0
- if text.count("\n") == 1:
- count += 0.125
- if text.count("\n\n") == 1:
- count += 0.125
- if text.count("\n\n") == 1:
- count += 0.125
- count -= len(text.split("\n\n")[-1]) * 0.001
- if text.count("\n") == 1:
- count += 0.125
- count -= (len(text.split("\n")[-1]) - 1) * 0.001
- return count
-
- def xmlcount_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]:
- contents = [completion[0]["content"] for completion in completions]
- return [count_xml(c) for c in contents]
-
- return [
- xmlcount_reward_func,
- soft_format_reward_func,
- strict_format_reward_func,
- int_reward_func,
- correctness_reward_func,
- ]
@final
From 69da1e3c2efab802a6a9c44117527a76cdbe0df8 Mon Sep 17 00:00:00 2001
From: Xi Bai <82581439+baixiac@users.noreply.github.com>
Date: Wed, 17 Sep 2025 15:52:05 +0100
Subject: [PATCH 4/9] Revert "feat: add the trainer for HF LLMs"
This reverts commit 994f88d55233a733274b34f2952538243811ed5f.
---
app/api/api.py | 24 +-
app/api/routers/generative.py | 5 +-
app/api/routers/supervised_training.py | 24 +-
app/api/utils.py | 4 +-
app/domain.py | 19 -
app/exception.py | 6 +-
app/model_services/huggingface_llm_model.py | 71 +-
app/processors/metrics_collector.py | 30 -
app/trainers/huggingface_llm_trainer.py | 674 ------------------
app/utils.py | 50 --
pyproject.toml | 4 -
tests/app/api/test_api.py | 4 +-
.../app/processors/test_metrics_collector.py | 54 --
uv.lock | 177 +----
14 files changed, 26 insertions(+), 1120 deletions(-)
delete mode 100644 app/trainers/huggingface_llm_trainer.py
diff --git a/app/api/api.py b/app/api/api.py
index 5dd1522..b9874ab 100644
--- a/app/api/api.py
+++ b/app/api/api.py
@@ -4,7 +4,7 @@
import os.path
import app.api.globals as cms_globals
-from typing import Dict, Any, Optional, Union, Type
+from typing import Dict, Any, Optional
from concurrent.futures import ThreadPoolExecutor
from anyio.lowlevel import RunVar
from anyio import CapacityLimiter
@@ -20,7 +20,7 @@
from app.api.dependencies import ModelServiceDep
from app.api.utils import add_exception_handlers, add_rate_limiter, init_vllm_engine
from app.config import Settings
-from app.domain import Tags, TagsStreamable, TagsGenerative
+from app.domain import Tags, TagsStreamable
from app.management.tracker_client import TrackerClient
from app.utils import get_settings, unpack_model_data_package, get_model_data_package_base_name
from app.exception import ConfigurationException
@@ -131,11 +131,6 @@ def get_generative_server(config: Settings, msd_overwritten: Optional[ModelServi
app = _load_health_check_router(app)
logger.debug("Health check router loaded")
- if config.ENABLE_TRAINING_APIS == "true":
- app = _load_supervised_training_router(app)
- logger.debug("Supervised training router loaded")
- app = _load_training_operations(app)
-
if config.AUTH_USER_ENABLED == "true":
app = _load_auth_router(app)
logger.debug("Auth router loaded")
@@ -203,18 +198,11 @@ def _get_app(
streamable: bool = False,
generative: bool = False,
) -> FastAPI:
- config = get_settings()
- tags: Union[Type[Tags], Type[TagsStreamable], Type[TagsGenerative]]
- if generative:
- tags = TagsGenerative
- elif streamable:
- tags = TagsStreamable
- else:
- tags = Tags
tags_metadata = [{ # type: ignore
- "name": tag.name, # type: ignore
- "description": tag.value # type: ignore
- } for tag in tags]
+ "name": tag.name,
+ "description": tag.value
+ } for tag in (Tags if not streamable else TagsStreamable)]
+ config = get_settings()
app = FastAPI(
title="CogStack ModelServe",
summary="A model serving and governance system for CogStack NLP solutions",
diff --git a/app/api/routers/generative.py b/app/api/routers/generative.py
index b2d9454..26f4fd1 100644
--- a/app/api/routers/generative.py
+++ b/app/api/routers/generative.py
@@ -13,7 +13,6 @@
from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR
from app.domain import (
Tags,
- TagsGenerative,
OpenAIChatRequest,
OpenAIChatResponse,
OpenAIEmbeddingsRequest,
@@ -42,7 +41,7 @@
@router.post(
PATH_GENERATE,
- tags=[TagsGenerative.Generative],
+ tags=[Tags.Generative.name],
response_class=PlainTextResponse,
dependencies=[Depends(cms_globals.props.current_active_user)],
description="Generate text",
@@ -92,7 +91,7 @@ def generate_text(
@router.post(
PATH_GENERATE_ASYNC,
- tags=[TagsGenerative.Generative],
+ tags=[Tags.Generative.name],
response_class=StreamingResponse,
dependencies=[Depends(cms_globals.props.current_active_user)],
description="Generate a stream of texts",
diff --git a/app/api/routers/supervised_training.py b/app/api/routers/supervised_training.py
index 1ee7c9b..89d0d17 100644
--- a/app/api/routers/supervised_training.py
+++ b/app/api/routers/supervised_training.py
@@ -12,9 +12,9 @@
import app.api.globals as cms_globals
from app.api.dependencies import validate_tracking_id
-from app.domain import Tags, ModelType
+from app.domain import Tags
from app.model_services.base import AbstractModelService
-from app.processors.metrics_collector import concat_json_lists, concat_trainer_exports
+from app.processors.metrics_collector import concat_trainer_exports
from app.utils import filter_by_concept_ids
router = APIRouter()
@@ -72,19 +72,12 @@ async def train_supervised(
files.append(temp_te)
file_names.append("" if te.filename is None else te.filename)
- if model_service.info().model_type is not ModelType.HUGGINGFACE_LLM:
- concatenated_te = concat_trainer_exports([file.name for file in files], allow_recurring_doc_ids=False)
- logger.debug("Training exports concatenated")
- data_file = tempfile.NamedTemporaryFile(mode="w+")
- concatenated_te = filter_by_concept_ids(cast(Dict[str, Any], concatenated_te), model_service.info().model_type)
- logger.debug("Training exports filtered by concept IDs")
- json.dump(concatenated_te, data_file)
- else:
- concatenated = concat_json_lists([file.name for file in files])
- logger.debug("Training exports concatenated")
- data_file = tempfile.NamedTemporaryFile(mode="w+")
- json.dump(concatenated, data_file)
-
+ concatenated = concat_trainer_exports([file.name for file in files], allow_recurring_doc_ids=False)
+ logger.debug("Training exports concatenated")
+ data_file = tempfile.NamedTemporaryFile(mode="w")
+ concatenated = filter_by_concept_ids(cast(Dict[str, Any], concatenated), model_service.info().model_type)
+ logger.debug("Training exports filtered by concept IDs")
+ json.dump(concatenated, data_file)
data_file.flush()
data_file.seek(0)
training_id = tracking_id or str(uuid.uuid4())
@@ -109,7 +102,6 @@ async def train_supervised(
return _get_training_response(training_response, training_id)
-
def _get_training_response(training_response: Tuple[bool, str, str], training_id: str) -> JSONResponse:
training_accepted, experiment_id, run_id = training_response
if training_accepted:
diff --git a/app/api/utils.py b/app/api/utils.py
index e73425d..05726ba 100644
--- a/app/api/utils.py
+++ b/app/api/utils.py
@@ -26,7 +26,7 @@
from slowapi.errors import RateLimitExceeded
from fastapi_users.jwt import decode_jwt
from app.config import Settings
-from app.domain import TagsGenerative
+from app.domain import Tags
from app.exception import StartTrainingException, AnnotationException, ConfigurationException, ClientException
logger = logging.getLogger("cms")
@@ -376,7 +376,7 @@ async def _stream() -> AsyncGenerator[bytes, None]:
endpoint=endpoint,
methods=methods,
include_in_schema=True,
- tags=[TagsGenerative.Generative.name],
+ tags=[Tags.Generative],
)
app.include_router(router)
diff --git a/app/domain.py b/app/domain.py
index 6be1564..de098a8 100644
--- a/app/domain.py
+++ b/app/domain.py
@@ -31,15 +31,9 @@ class Tags(str, Enum):
class TagsStreamable(str, Enum):
- Metadata = "Get the model card"
Streaming = "Retrieve NER entities as a stream by running the model"
-class TagsGenerative(str, Enum):
- Metadata = "Get the model card"
- Generative = "Generate text based on the input prompt"
-
-
class CodeType(str, Enum):
SNOMED = "SNOMED"
UMLS = "UMLS"
@@ -110,19 +104,6 @@ class LlmEngine(Enum):
CMS = "CMS"
VLLM = "vLLM"
-class LlmRole(Enum):
- SYSTEM = "system"
- USER = "user"
- ASSISTANT = "assistant"
- TOOL = "tool"
-
-class LlmTrainerType(Enum):
- GRPO = "grpo"
- PPO = "ppo"
-
-class LlmDatasetType(Enum):
- JSON = "json"
- CSV = "csv"
class Annotation(BaseModel):
doc_name: Optional[str] = Field(default=None, description="The name of the document to which the annotation belongs")
diff --git a/app/exception.py b/app/exception.py
index 99b87a6..1b8f9bc 100644
--- a/app/exception.py
+++ b/app/exception.py
@@ -27,8 +27,4 @@ class ClientException(Exception):
class DatasetException(Exception):
- """An exception raised due to dataset errors"""
-
-
-class DeviceNotAvailableError(RuntimeError):
- """An exception raised when a specificy device is required but not available."""
+ """ An exception raised due to dataset errors"""
diff --git a/app/model_services/huggingface_llm_model.py b/app/model_services/huggingface_llm_model.py
index 13838d7..25cd032 100644
--- a/app/model_services/huggingface_llm_model.py
+++ b/app/model_services/huggingface_llm_model.py
@@ -3,19 +3,17 @@
import asyncio
import torch
from concurrent.futures import ThreadPoolExecutor
-from typing import Dict, List, Optional, Tuple, Any, AsyncIterable, TextIO, Callable, Union
+from typing import Dict, List, Optional, Tuple, Any, AsyncIterable, Callable, Union
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizerBase,
TextIteratorStreamer,
- BitsAndBytesConfig,
)
from app import __version__ as app_version
from app.exception import ConfigurationException
from app.model_services.base import AbstractModelService
-from app.trainers.huggingface_llm_trainer import HuggingFaceLlmSupervisedTrainer
from app.domain import ModelCard, ModelType, Annotation
from app.config import Settings
from app.utils import (
@@ -125,19 +123,13 @@ def from_model(cls, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase)
return model_service
@staticmethod
- def load_model(
- model_file_path: str,
- *args: Tuple,
- load_in_4bit: bool = False,
- **kwargs: Dict[str, Any]
- ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
+ def load_model(model_file_path: str, *args: Tuple, **kwargs: Dict[str, Any]) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
"""
Loads a pre-trained model and its tokenizer from a model package file.
Args:
model_file_path (str): The path to the model package file.
*args (Tuple): Additional positional arguments.
- load_in_4bit (bool): Whether to load the model in 4-bit precision. Defaults to False.
**kwargs (Dict[str, Any]): Additional keyword arguments.
Returns:
@@ -150,16 +142,7 @@ def load_model(
model_path = os.path.join(os.path.dirname(model_file_path), get_model_data_package_base_name(model_file_path))
if unpack_model_data_package(model_file_path, model_path):
try:
- if load_in_4bit:
- bnb_config = BitsAndBytesConfig(
- load_in_4bit=True,
- bnb_4bit_quant_type="nf4",
- bnb_4bit_compute_dtype=torch.bfloat16,
- bnb_4bit_use_double_quant=True,
- )
- model = AutoModelForCausalLM.from_pretrained(model_path, quantization_config=bnb_config)
- else:
- model = AutoModelForCausalLM.from_pretrained(model_path)
+ model = AutoModelForCausalLM.from_pretrained(model_path)
ensure_tensor_contiguity(model)
tokenizer = AutoTokenizer.from_pretrained(
model_path,
@@ -189,7 +172,7 @@ def init_model(self) -> None:
if non_default_device_is_available(get_settings().DEVICE):
self._model.to(get_settings().DEVICE)
if self._enable_trainer:
- self._supervised_trainer = HuggingFaceLlmSupervisedTrainer(self)
+ logger.error("Trainers are not yet implemented for HuggingFace Generative models")
def info(self) -> ModelCard:
"""
@@ -372,49 +355,3 @@ def create_embeddings(
results = embeddings.cpu().numpy().tolist()
return results[0] if isinstance(text, str) else results
-
- def train_supervised(
- self,
- data_file: TextIO,
- epochs: int,
- log_frequency: int,
- training_id: str,
- input_file_name: str,
- raw_data_files: Optional[List[TextIO]] = None,
- description: Optional[str] = None,
- synchronised: bool = False,
- **hyperparams: Dict[str, Any],
- ) -> Tuple[bool, str, str]:
- """
- Initiates supervised training on the model.
-
- Args:
- data_file (TextIO): The file containing the trainer export data.
- epochs (int): The number of training epochs.
- log_frequency (int): The number of epochs after which training metrics will be logged.
- training_id (str): A unique identifier for the training process.
- input_file_name (str): The name of the input file to be logged.
- raw_data_files (Optional[List[TextIO]]): Additional raw data files to be logged. Defaults to None.
- description (Optional[str]): The description of the training or change logs. Defaults to empty.
- synchronised (bool): Whether to wait for the training to complete.
- **hyperparams (Dict[str, Any]): Additional hyperparameters for training.
-
- Returns:
- Tuple[bool, str, str]: A tuple with the first element indicating success or failure.
-
- Raises:
- ConfigurationException: If the supervised trainer is not enabled.
- """
- if self._supervised_trainer is None:
- raise ConfigurationException("The supervised trainer is not enabled")
- return self._supervised_trainer.train(
- data_file,
- epochs,
- log_frequency,
- training_id,
- input_file_name,
- raw_data_files,
- description,
- synchronised,
- **hyperparams,
- )
diff --git a/app/processors/metrics_collector.py b/app/processors/metrics_collector.py
index 84f74da..07f0592 100644
--- a/app/processors/metrics_collector.py
+++ b/app/processors/metrics_collector.py
@@ -194,36 +194,6 @@ def concat_trainer_exports(
return combined
-def concat_json_lists(
- data_file_paths: List[str],
- combined_data_file_path: Optional[str] = None,
-) -> Union[List[Dict[str, Any]], str]:
- """
- Concatenates multiple json list files into a single combined file.
-
- Args:
- data_file_paths (List[str]): List of paths to files each containing a json list.
- combined_data_file_path (Optional[str]): The file path where the combined data will be saved. If None, the combined data will be returned as a list.
-
-
- Returns:
- Union[List[Dict[str, Any]], str]: The path to the combined data file if `combined_data_file_path` is provided, or the combined data as a list otherwise.
- """
- combined: List = []
- for path in data_file_paths:
- with open(path, "r") as f:
- data = json.load(f)
- combined.extend(data)
-
- if isinstance(combined_data_file_path, str):
- with open(combined_data_file_path, "w") as f:
- json.dump(combined, f)
-
- return combined_data_file_path
- else:
- return combined
-
-
def get_stats_from_trainer_export(
trainer_export: Union[str, IO, Dict],
return_df: bool = False,
diff --git a/app/trainers/huggingface_llm_trainer.py b/app/trainers/huggingface_llm_trainer.py
deleted file mode 100644
index 49082d5..0000000
--- a/app/trainers/huggingface_llm_trainer.py
+++ /dev/null
@@ -1,674 +0,0 @@
-import os
-import logging
-import math
-import torch
-import gc
-import datasets
-import re
-import threading
-import json
-import pandas as pd
-from typing import final, Dict, TextIO, Optional, Any, List, Tuple, TYPE_CHECKING, Callable
-from transformers import __version__ as transformers_version
-from transformers import (
- TrainingArguments,
- PreTrainedModel,
- PreTrainedTokenizerBase,
- PreTrainedTokenizerFast,
- TrainerCallback,
- TrainerState,
- TrainerControl,
- Trainer
-)
-from peft import LoraConfig, get_peft_model # type: ignore
-from app.management.model_manager import ModelManager
-from app.management.tracker_client import TrackerClient
-from app.utils import (
- reset_random_seed,
- non_default_device_is_available,
- create_model_data_package,
- get_model_data_package_extension,
- load_pydantic_object_from_dict,
- get_default_chat_template,
- get_default_system_prompt,
- get_model_data_package_base_name,
-)
-from app.trainers.base import SupervisedTrainer
-from app.domain import ModelType, TrainerBackend, LlmRole, LlmTrainerType, LlmDatasetType, PromptMessage, Device
-from app.exception import (
- TrainingCancelledException,
- ManagedModelException,
- DatasetException,
- ConfigurationException,
- DeviceNotAvailableError,
-)
-if TYPE_CHECKING:
- from app.model_services.huggingface_llm_model import HuggingFaceLlmModel
-
-logger = logging.getLogger("cms")
-
-
-class _HuggingFaceLlmTrainerCommon(object):
-
- @staticmethod
- def deploy_model(
- model_service: "HuggingFaceLlmModel",
- model: PreTrainedModel,
- tokenizer: PreTrainedTokenizerBase,
- ) -> None:
- del model_service.model
- del model_service.tokenizer
- gc.collect()
- model_service.model = model
- model_service.tokenizer = tokenizer
- logger.info("Retrained model deployed")
-
-
-@final
-class HuggingFaceLlmSupervisedTrainer(SupervisedTrainer, _HuggingFaceLlmTrainerCommon):
- """
- A supervised trainer class for HuggingFace LLM models.
-
- Args:
- model_service (HuggingFaceLlmModel): An instance of the HuggingFace LLM model service.
- """
-
- MIN_EXAMPLE_COUNT_FOR_TRAINABLE_CONCEPT = 5
- MAX_CONCEPTS_TO_TRACK = 20
- PAD_LABEL_ID = -100
- DEFAULT_LABEL_ID = 0
- CONTINUING_TOKEN_LABEL_ID = 1
-
- def __init__(self, model_service: "HuggingFaceLlmModel") -> None:
- if not isinstance(model_service.tokenizer, PreTrainedTokenizerFast):
- logger.error("The supervised trainer requires a fast tokenizer to function correctly")
- raise ManagedModelException("The supervised trainer requires a fast tokenizer to function correctly")
- SupervisedTrainer.__init__(self, model_service._config, model_service.model_name)
- self._model_service = model_service
- self._model_name = model_service.model_name
- self._model_pack_path = model_service._model_pack_path
- self._retrained_models_dir = os.path.join(model_service._model_parent_dir, "retrained",
- self._model_name.replace(" ", "_"))
- self._model_manager = ModelManager(type(model_service), model_service._config)
- self._max_length = model_service.model.config.max_position_embeddings
- os.makedirs(self._retrained_models_dir, exist_ok=True)
-
- class _LocalDataCollator:
-
- def __init__(self, max_length: int, pad_token_id: int) -> None:
- self.max_length = max_length
- self.pad_token_id = pad_token_id
-
- def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
- return {
- "input_ids": torch.tensor([self._add_padding(f["input_ids"], self.max_length, self.pad_token_id) for f in features], dtype=torch.long),
- "labels": torch.tensor([self._add_padding(f["labels"], self.max_length, HuggingFaceLlmSupervisedTrainer.PAD_LABEL_ID) for f in features], dtype=torch.long),
- "attention_mask": torch.tensor([self._add_padding(f["attention_mask"], self.max_length, 0) for f in features], dtype=torch.long),
- }
-
- @staticmethod
- def _add_padding(target: List[int], max_length: int, pad_token_id: int) -> List[int]:
- padding_length = max(0, max_length - len(target))
- paddings = [pad_token_id] * padding_length
- return target + paddings
-
- def _load_dataset_from_config(self, data_file: TextIO, training_params: Dict) -> Tuple[datasets.Dataset, datasets.Dataset]:
- """
- Loads training and validation datasets based on configuration in training_params.
-
- Args:
- data_file: The training data file
- training_params: Dictionary containing dataset configuration
-
- Returns:
- Tuple of (train_dataset, validation_dataset)
- """
- dataset_type = training_params.get("dataset_type", "json")
-
- # if dataset_type == "huggingface":
- # return self._load_huggingface_dataset(training_params)
- if dataset_type == LlmDatasetType.JSON.value:
- return self._load_json_dataset(data_file, training_params)
- elif dataset_type == LlmDatasetType.CSV.value:
- return self._load_csv_dataset(data_file, training_params)
- else:
- raise DatasetException(f"Unsupported dataset type: {dataset_type}")
-
- @staticmethod
- def _set_dataset_format(train_dataset: datasets.Dataset, test_dataset: datasets.Dataset) -> None:
- """Sets the format of the datasets based on the dataset structure."""
-
- if "messages" in train_dataset.column_names:
- train_dataset.set_format(type=None, columns=["messages"])
- test_dataset.set_format(type=None, columns=["messages"])
- elif "question" in train_dataset.column_names and "answer" in train_dataset.column_names:
- train_dataset.set_format(type=None, columns=["question", "answer"])
- test_dataset.set_format(type=None, columns=["question", "answer"])
- elif "input" in train_dataset.column_names and "output" in train_dataset.column_names:
- train_dataset.set_format(type=None, columns=["input", "output"])
- test_dataset.set_format(type=None, columns=["input", "output"])
- elif "prompt" in train_dataset.column_names and "completion" in train_dataset.column_names:
- train_dataset.set_format(type=None, columns=["prompt", "completion"])
- test_dataset.set_format(type=None, columns=["prompt", "completion"])
- elif "problem" in train_dataset.column_names and "solution" in train_dataset.column_names:
- train_dataset.set_format(type=None, columns=["problem", "solution"])
- test_dataset.set_format(type=None, columns=["problem", "solution"])
- else:
- raise DatasetException("Unsupported dataset format")
-
-
-
- def _load_huggingface_dataset(self, training_params: Dict) -> Tuple[datasets.Dataset, datasets.Dataset]:
- """Loads dataset from HuggingFace Hub."""
-
- dataset_id = training_params.get("dataset_id", "AI-MO/NuminaMath-TIR")
- test_size = 0.2 if training_params.get("test_size") is None else training_params["test_size"]
- split_ratio = 1 - test_size
- train_percentage = int(split_ratio * 100)
- test_percentage = 100 - train_percentage
- train_split = training_params.get("train_split", f"train[:{train_percentage}%]")
- test_split = training_params.get("test_split", f"test[:{test_percentage}%]")
-
- logger.info(f"Loading HuggingFace dataset: {dataset_id}")
- train_dataset, test_dataset = datasets.load_dataset(dataset_id, split=[train_split, test_split])
- self._set_dataset_format(train_dataset, test_dataset)
-
- return train_dataset, test_dataset
-
-
- def _load_json_dataset(self, data_file: TextIO, training_params: Dict) -> Tuple[datasets.Dataset, datasets.Dataset]:
- """Loads dataset from JSON file."""
-
- data = json.load(data_file)
- test_size = 0.2 if training_params.get("test_size") is None else training_params["test_size"]
- split_ratio = 1 - test_size
-
- if isinstance(data, list):
- examples = data
- split_idx = int(len(examples) * split_ratio)
- train_examples = examples[:split_idx]
- test_examples = examples[split_idx:]
- elif isinstance(data, dict) and "train" in data and "test" in data:
- train_examples = data["train"]
- test_examples = data["test"]
- elif isinstance(data, dict) and "examples" in data:
- examples = data["examples"]
- split_idx = int(len(examples) * split_ratio)
- train_examples = examples[:split_idx]
- test_examples = examples[split_idx:]
- else:
- raise DatasetException("Unsupported JSON format")
-
- train_dataset = datasets.Dataset.from_list(train_examples)
- test_dataset = datasets.Dataset.from_list(test_examples)
- self._set_dataset_format(train_dataset, test_dataset)
-
- return train_dataset, test_dataset
-
- def _load_csv_dataset(self, data_file: TextIO, training_params: Dict) -> Tuple[datasets.Dataset, datasets.Dataset]:
- """Loads dataset from CSV file."""
-
- df = pd.read_csv(data_file)
- test_size = 0.2 if training_params.get("test_size") is None else training_params["test_size"]
- split_ratio = 1 - test_size
- split_idx = int(len(df) * split_ratio)
-
- train_df = df[:split_idx]
- test_df = df[split_idx:]
-
- train_dataset = datasets.Dataset.from_pandas(train_df)
- test_dataset = datasets.Dataset.from_pandas(test_df)
- self._set_dataset_format(train_dataset, test_dataset)
-
- return train_dataset, test_dataset
-
- def _create_conversation_formatter(self, training_params: Dict) -> Callable:
- """
- Creates a conversation formatter based on training parameters.
-
- Args:
- training_params: Dictionary containing formatting configuration
-
- Returns:
- Function that formats examples into conversations
- """
- format_config = training_params.get("format_config", {})
- system_prompt = format_config.get("system_prompt", get_default_system_prompt())
-
- def make_conversation(example: Dict[str, Any]) -> Dict[str, Any]:
- # Handle different input formats
- if "messages" in example:
- system_content = None
- question_content = None
- answer_content = None
- for message in example.get("messages", []):
- msg = load_pydantic_object_from_dict(PromptMessage, message)
- if msg.role == LlmRole.SYSTEM:
- system_content = msg.content
- elif msg.role == LlmRole.USER:
- question_content = msg.content
- elif msg.role == LlmRole.ASSISTANT:
- answer_content = msg.content
-
- return {
- "prompt": [
- {"role": "system", "content": system_prompt if system_content is None else system_content},
- {"role": "user", "content": question_content if question_content is not None else ""},
- ],
- "answer": answer_content if answer_content is not None else "",
- }
- elif "question" in example and "answer" in example:
- # Question/Answer format
- return {
- "prompt": [
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": example.get("question")},
- ],
- "answer": example["answer"],
- }
- elif "input" in example and "output" in example:
- # Input/Output format
- return {
- "prompt": [
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": example.get("input")},
- ],
- "answer": example["output"],
- }
- elif "prompt" in example and "completion" in example:
- # Prompt/Completion format
- return {
- "prompt": [
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": example.get("prompt")},
- ],
- "answer": example["completion"],
- }
- elif "problem" in example and "solution" in example:
- # Problem/Solution format
- return {
- "prompt": [
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": example.get("problem")},
- ],
- "answer": example["solution"],
- }
- else:
- raise DatasetException(f"Cannot determine the conversation format from example: {example}")
-
- return make_conversation
-
- def run(
- self,
- training_params: Dict,
- data_file: TextIO,
- log_frequency: int,
- run_id: str,
- description: Optional[str] = None,
- ) -> None:
- """
- Runs the supervised training loop for HuggingFace LLM models.
-
- Args:
- training_params (Dict): A dictionary containing parameters for the training.
- data_file (Union[TextIO, tempfile.TemporaryDirectory]): The file-like object or temporary directory containing the training data.
- log_frequency (int): The frequency at which logs should be recorded (e.g, the number of processed documents or finished epochs).
- run_id (str): The run ID of the training job.
- description (Optional[str]): The optional description of the training or change logs.
- """
-
- if self._config.DEVICE is not Device.GPU.value:
- raise DeviceNotAvailableError("This trainer currently requires a CUDA device")
-
- copied_model_pack_path = None
- redeploy = self._config.REDEPLOY_TRAINED_MODEL == "true"
- skip_save_model = self._config.SKIP_SAVE_MODEL == "true"
- results_path = os.path.abspath(os.path.join(self._config.TRAINING_CACHE_DIR, "results"))
- logs_path = os.path.abspath(os.path.join(self._config.TRAINING_CACHE_DIR, "logs"))
- reset_random_seed()
- eval_mode = training_params["nepochs"] == 0
- self._tracker_client.log_trainer_mode(not eval_mode)
- if not eval_mode:
- try:
- logger.info("Loading a new model copy for training...")
- copied_model_pack_path = self._make_model_file_copy(self._model_pack_path, run_id)
- model, tokenizer = self._model_service.load_model(
- copied_model_pack_path,
- load_in_4bit=True, # for memory efficient training
- )
- copied_model_directory = os.path.join(
- os.path.dirname(copied_model_pack_path),
- get_model_data_package_base_name(copied_model_pack_path),
- )
-
- if non_default_device_is_available(self._config.DEVICE):
- model.to(self._config.DEVICE)
-
- train_dataset, test_dataset = self._load_dataset_from_config(data_file, training_params)
- make_conversation = self._create_conversation_formatter(training_params)
- train_dataset = train_dataset.map(make_conversation)
- test_dataset = test_dataset.map(make_conversation)
-
- if hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None:
- logger.warning("The tokenizer does not have a chat template. Using the default one.")
- tokenizer.chat_template = get_default_chat_template()
- else:
- logger.debug(f"Found a chat template in the tokenizer:\n {tokenizer.chat_template}")
-
- lora_config = LoraConfig(
- task_type="CAUSAL_LM",
- r=8,
- lora_alpha=32,
- lora_dropout=0.1,
- target_modules=[
- "q_proj", "k_proj", "v_proj", "o_proj",
- "gate_proj", "up_proj", "down_proj",
- ],
- )
-
- model = get_peft_model(model, lora_config)
-
- def extract_xml_answer(text: str) -> str:
- answer = text.split("")[-1]
- answer = answer.split("")[0]
- return answer.strip()
-
- # Reward functions
- def correctness_reward_func(
- prompts: List,
- completions: List,
- answer: List,
- **kwargs: Dict[str, Any]
- ) -> List[float]:
- responses = [completion[0]['content'] for completion in completions]
- q = prompts[0][-1]['content']
- extracted_responses = [extract_xml_answer(r) for r in responses]
- print('-' * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}",
- f"\nExtracted:\n{extracted_responses[0]}")
- return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
-
- def int_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]:
- responses = [completion[0]['content'] for completion in completions]
- extracted_responses = [extract_xml_answer(r) for r in responses]
- return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
-
- def strict_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]:
- """Reward function that checks if the completion has a specific format."""
- pattern = r"^\n.*?\n\n\n.*?\n\n$"
- responses = [completion[0]["content"] for completion in completions]
- matches = [re.match(pattern, r) for r in responses]
- return [0.5 if match else 0.0 for match in matches]
-
- def soft_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]:
- """Reward function that checks if the completion has a specific format."""
- pattern = r".*?\s*.*?"
- responses = [completion[0]["content"] for completion in completions]
- matches = [re.match(pattern, r) for r in responses]
- return [0.5 if match else 0.0 for match in matches]
-
- def count_xml(text: str) -> float:
- count = 0.0
- if text.count("\n") == 1:
- count += 0.125
- if text.count("\n\n") == 1:
- count += 0.125
- if text.count("\n\n") == 1:
- count += 0.125
- count -= len(text.split("\n\n")[-1]) * 0.001
- if text.count("\n") == 1:
- count += 0.125
- count -= (len(text.split("\n")[-1]) - 1) * 0.001
- return count
-
- def xmlcount_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]:
- contents = [completion[0]["content"] for completion in completions]
- return [count_xml(c) for c in contents]
-
- mlflow_logging_callback = MLflowLoggingCallback(self._tracker_client)
- cancel_event_check_callback = CancelEventCheckCallback(self._cancel_event)
- trainer_callbacks = [mlflow_logging_callback, cancel_event_check_callback]
-
- max_prompt_length = 256
- max_seq_length = 1024
-
- try:
- from trl import GRPOConfig, GRPOTrainer #, PPOConfig, PPOTrainer
- except ImportError:
- logger.error("Cannot import the GRPO Trainer. Please install it with `pip install cms[vllm]`.")
-
- trainer_type = training_params.get("trainer_type", LlmTrainerType.GRPO.value).lower()
- if trainer_type == LlmTrainerType.PPO.value:
- raise NotImplementedError("PPO training is not yet supported for HuggingFace LLM models")
- elif trainer_type == LlmTrainerType.GRPO.value:
- training_args = GRPOConfig(
- learning_rate=5e-6,
- adam_beta1=0.9,
- adam_beta2=0.99,
- weight_decay=0.1,
- warmup_ratio=0.1,
- lr_scheduler_type="cosine",
- optim="paged_adamw_8bit",
- logging_steps=1,
- per_device_train_batch_size=6, # This global batch size must be divisible by the number of generations
- gradient_accumulation_steps=1,
- num_generations=6,
- max_prompt_length=max_prompt_length,
- max_completion_length=max_seq_length - max_prompt_length,
- num_train_epochs = training_params["nepochs"],
- max_steps=250,
- save_steps=250,
- max_grad_norm=0.1,
- report_to="none",
- output_dir="outputs",
- )
- trainer = GRPOTrainer(
- model=model,
- processing_class=tokenizer,
- reward_funcs=[
- xmlcount_reward_func,
- soft_format_reward_func,
- strict_format_reward_func,
- int_reward_func,
- correctness_reward_func,
- ],
- args=training_args,
- train_dataset=train_dataset,
- eval_dataset=test_dataset,
- callbacks=trainer_callbacks,
- )
- else:
- raise ConfigurationException(f"Unsupported trainer type: {trainer_type}")
-
- self._tracker_client.log_model_config(model.config.to_dict())
- self._tracker_client.log_trainer_version(TrainerBackend.TRANSFORMERS, transformers_version)
-
- logger.info(f"Performing {trainer_type.upper()} training...")
- trainer.train()
-
- if cancel_event_check_callback.training_cancelled:
- raise TrainingCancelledException("Training was cancelled by the user")
-
- if not skip_save_model:
- model_pack_file_ext = get_model_data_package_extension(self._config.BASE_MODEL_FILE)
- model_pack_file_name = f"{ModelType.HUGGINGFACE_LLM.value}_{run_id}{model_pack_file_ext}"
- retrained_model_pack_path = os.path.join(self._retrained_models_dir, model_pack_file_name)
- model.save_pretrained(
- copied_model_directory,
- safe_serialization=(self._config.TRAINING_SAFE_MODEL_SERIALISATION == "true"),
- )
- create_model_data_package(copied_model_directory, retrained_model_pack_path)
- model_uri = self._tracker_client.save_model(
- retrained_model_pack_path,
- self._model_name,
- self._model_manager,
- )
- logger.info(f"Retrained model saved: {model_uri}")
- else:
- logger.info("Skipped saving on the retrained model")
- if redeploy:
- self.deploy_model(self._model_service, model, tokenizer)
- else:
- del model
- del tokenizer
- gc.collect()
- logger.info("Skipped deployment on the retrained model")
- logger.info("Supervised training finished")
- self._tracker_client.end_with_success()
- except TrainingCancelledException as e:
- logger.exception(e)
- logger.info("Supervised training was cancelled")
- del model
- gc.collect()
- self._tracker_client.end_with_interruption()
- except torch.OutOfMemoryError as e:
- logger.exception("Supervised training failed on CUDA OOM")
- try:
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- torch.cuda.ipc_collect()
- try:
- torch.cuda.reset_peak_memory_stats()
- torch.cuda.reset_accumulated_memory_stats()
- except Exception:
- pass
- torch.cuda.synchronize()
- except Exception:
- pass
- self._tracker_client.log_exceptions(e)
- self._tracker_client.end_with_failure()
- except Exception as e:
- logger.exception("Supervised training failed")
- self._tracker_client.log_exceptions(e)
- self._tracker_client.end_with_failure()
- finally:
- data_file.close()
- with self._training_lock:
- self._training_in_progress = False
- self._clean_up_training_cache()
- self._housekeep_file(copied_model_pack_path)
- del trainer
- gc.collect()
- torch.cuda.empty_cache()
- else:
- try:
- logger.info("Evaluating the running model...")
- model, tokenizer = self._model_service.load_model(self._model_pack_path)
- if non_default_device_is_available(self._config.DEVICE):
- model.to(self._config.DEVICE)
-
- eval_dataset, _ = self._load_dataset_from_config(data_file, training_params)
- make_conversation = self._create_conversation_formatter(training_params)
- eval_dataset = eval_dataset.map(make_conversation)
-
- data_collator = self._LocalDataCollator(
- max_length=self._max_length,
- pad_token_id=tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else 0
- )
-
- training_args = TrainingArguments(
- output_dir=results_path,
- logging_dir=logs_path,
- per_device_eval_batch_size=1,
- do_train=False,
- do_eval=True,
- report_to="none",
- dataloader_drop_last=False,
- )
-
- trainer = Trainer(
- model=model,
- args=training_args,
- data_collator=data_collator,
- eval_dataset=eval_dataset,
- tokenizer=tokenizer,
- )
-
- eval_metrics = trainer.evaluate()
- logger.info(f"Evaluation metrics: {eval_metrics}")
- self._tracker_client.send_hf_metrics_logs(eval_metrics, 0)
- self._tracker_client.end_with_success()
- logger.info("Model evaluation finished")
- except Exception as e:
- logger.exception("Model evaluation failed")
- self._tracker_client.log_exceptions(e)
- self._tracker_client.end_with_failure()
- finally:
- data_file.close()
- with self._training_lock:
- self._training_in_progress = False
- self._clean_up_training_cache()
-
-
-@final
-class MLflowLoggingCallback(TrainerCallback):
- """
- A callback class for logging training metrics to MLflow.
-
- Args:
- tracker_client (TrackerClient): An instance of TrackerClient used for logging.
- """
-
- def __init__(self, tracker_client: TrackerClient) -> None:
- self.tracker_client = tracker_client
- self.epoch = 0
-
- def on_log(
- self,
- args: TrainingArguments,
- state: TrainerState,
- control: TrainerControl,
- logs: Dict[str, float],
- **kwargs: Dict[str, Any],
- ) -> None:
- """
- Logs metrics at the end of each epoch.
-
- Args:
- args (TrainingArguments): The arguments used for training.
- state (TrainerState): The current state of the Trainer.
- control (TrainerControl): The current control of the Trainer.
- logs (Dict[str, float]): A dictionary containing the metrics to log.
- **kwargs (Dict[str, Any]): Additional keyword arguments.
- """
-
- if logs is not None:
- if logs.get("eval_loss", None) is not None:
- logs["perplexity"] = math.exp(logs["eval_loss"])
- self.tracker_client.send_hf_metrics_logs(logs, self.epoch)
- self.epoch += 1
-
-
-@final
-class CancelEventCheckCallback(TrainerCallback):
- """
- A callback class for checking a cancellation event during training.
-
- Args:
- cancel_event (threading.Event): A threading event that signals whether training should be cancelled.
- """
-
- def __init__(self, cancel_event: threading.Event) -> None:
- self.cancel_event = cancel_event
- self.training_cancelled = False
-
- def on_step_end(
- self,
- args: TrainingArguments,
- state: TrainerState,
- control: TrainerControl,
- **kwargs: Dict[str, Any],
- ) -> None:
- """
- Checks if the training should be cancelled at the end of each training step.
-
- Args:
- args (TrainingArguments): The arguments used for training.
- state (TrainerState): The current state of the Trainer.
- control (TrainerControl): The current control of the Trainer.
- **kwargs (Dict[str, Any]): Additional keyword arguments.
- """
-
- if self.cancel_event.is_set():
- control.should_training_stop = True
- self.cancel_event.clear()
- self.training_cancelled = True
diff --git a/app/utils.py b/app/utils.py
index 370c4e2..fcc231e 100644
--- a/app/utils.py
+++ b/app/utils.py
@@ -740,55 +740,6 @@ def download_model_package(
retry_delay *= 2
-def get_default_chat_template() -> str:
- """
- Gets the default chat template.
-
- Returns:
- str: The default chat template.
- """
-
- return (
- "{% if messages[0]['role'] == 'system' %}"
- "{% set loop_messages = messages[1:] %}"
- "{% set system_message = messages[0]['content'] %}"
- "{% else %}"
- "{% set loop_messages = messages %}"
- "{% set system_message = false %}"
- "{% endif %}"
- "{% for message in loop_messages %}"
- "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
- "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
- "{% endif %}"
- "{% if loop.index0 == 0 and system_message != false %}"
- "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}"
- "{% else %}"
- "{% set content = message['content'] %}"
- "{% endif %}"
- "{% if message['role'] == 'user' %}"
- "{{ '[INST] ' + content + ' [/INST]' }}"
- "{% elif message['role'] == 'assistant' %}"
- "{{ ' ' + content + ' ' }}"
- "{% endif %}"
- "{% endfor %}"
- )
-
-
-def get_default_system_prompt() -> str:
- """
- Gets the default system prompt.
-
- Returns:
- str: The default system prompt.
- """
- return (
- "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
- "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
- "process and answer are enclosed within and tags, respectively, i.e., "
- " reasoning process here answer here "
- )
-
-
def get_prompt_from_messages(
tokenizer: PreTrainedTokenizer,
messages: List[PromptMessage],
@@ -845,7 +796,6 @@ def get_prompt_from_messages(
)
return prompt
-
TYPE_ID_TO_NAME_PATCH = {
"32816260": "physical object",
"2680757": "observable entity",
diff --git a/pyproject.toml b/pyproject.toml
index 0aaaaea..eaba606 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -82,8 +82,6 @@ docs = [
vllm = [
"vllm~=0.8.5; python_version >= '3.9'",
- "trl>=0.11.4",
- "bitsandbytes>=0.45.5",
]
# For pip versions not supporting PEP 735
@@ -113,8 +111,6 @@ docs = [
vllm = [
"vllm~=0.8.5; python_version >= '3.9'",
- "trl>=0.11.4",
- "bitsandbytes>=0.45.5",
]
[tool.setuptools]
diff --git a/tests/app/api/test_api.py b/tests/app/api/test_api.py
index 1d5f077..1a0399b 100644
--- a/tests/app/api/test_api.py
+++ b/tests/app/api/test_api.py
@@ -27,6 +27,7 @@ def test_get_model_server():
assert {"name": "Training", "description": "Trigger model training on input annotations"} in tags
assert {"name": "Evaluating", "description": "Evaluate the deployed model with trainer export"} in tags
assert {"name": "Authentication", "description": "Authenticate registered users"} in tags
+ assert {"name": "Generative", "description": "Generate text based on the input prompt"} in tags
assert "/info" in paths
assert "/process" in paths
assert "/process_jsonl" in paths
@@ -90,8 +91,7 @@ def test_get_generative_server():
assert isinstance(info["title"], str)
assert isinstance(info["summary"], str)
assert isinstance(info["version"], str)
- assert {"name": "Metadata", "description": "Get the model card"} in tags
- assert {"name": "Generative", "description": "Generate text based on the input prompt"} in tags
+ assert {"name": "Streaming", "description": "Retrieve NER entities as a stream by running the model"} in tags
assert "/info" in paths
assert "/generate" in paths
assert "/stream/generate" in paths
diff --git a/tests/app/processors/test_metrics_collector.py b/tests/app/processors/test_metrics_collector.py
index 6e53275..bd46fff 100644
--- a/tests/app/processors/test_metrics_collector.py
+++ b/tests/app/processors/test_metrics_collector.py
@@ -7,7 +7,6 @@
from app.processors.metrics_collector import (
sanity_check_model_with_trainer_export,
concat_trainer_exports,
- concat_json_lists,
get_stats_from_trainer_export,
get_iaa_scores_per_concept,
get_iaa_scores_per_doc,
@@ -333,56 +332,3 @@ def test_get_iaa_scores_per_span_and_return_dataframe():
assert len(result["cohens_kappa"]) == 30
assert len(result["iaa_percentage_meta"]) == 30
assert len(result["cohens_kappa_meta"]) == 30
-
-
-def test_concat_json_lists_return_list():
- with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f1:
- json.dump([{"question": "question_1", "answer": "answer_1"}, {"question": "question_2", "answer": "answer_2"}], f1)
- file1_path = f1.name
-
- with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f2:
- json.dump([{"question": "question_3", "answer": "answer_3"}], f2)
- file2_path = f2.name
-
- try:
- result = concat_json_lists([file1_path, file2_path])
-
- assert isinstance(result, list)
- assert len(result) == 3
- assert result[0] == {"question": "question_1", "answer": "answer_1"}
- assert result[1] == {"question": "question_2", "answer": "answer_2"}
- assert result[2] == {"question": "question_3", "answer": "answer_3"}
- finally:
- os.unlink(file1_path)
- os.unlink(file2_path)
-
-
-def test_concat_json_lists_save_to_file():
- with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f1:
- json.dump([{"question": "question_1", "answer": "answer_1"}], f1)
- file1_path = f1.name
-
- with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f2:
- json.dump([{"question": "question_2", "answer": "answer_2"}], f2)
- file2_path = f2.name
-
- with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as output_file:
- output_path = output_file.name
-
- try:
- result = concat_json_lists([file1_path, file2_path], output_path)
-
- assert isinstance(result, str)
- assert result == output_path
-
- with open(output_path, 'r') as f:
- saved_data = json.load(f)
-
- assert isinstance(saved_data, list)
- assert len(saved_data) == 2
- assert saved_data[0] == {"question": "question_1", "answer": "answer_1"}
- assert saved_data[1] == {"question": "question_2", "answer": "answer_2"}
- finally:
- os.unlink(file1_path)
- os.unlink(file2_path)
- os.unlink(output_path)
diff --git a/uv.lock b/uv.lock
index 255a329..cc97b01 100644
--- a/uv.lock
+++ b/uv.lock
@@ -734,55 +734,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/d6/59/831b66ba317496332d4e9e1a33bcdd14922d6cfecc411dc315a229b67127/bcrypt-4.1.2-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:ba4e4cc26610581a6329b3937e02d319f5ad4b85b074846bf4fef8a8cf51e7bb", size = 698384 },
]
-[[package]]
-name = "bitsandbytes"
-version = "0.45.5"
-source = { registry = "https://pypi.org/simple" }
-resolution-markers = [
- "python_full_version < '3.9' and sys_platform != 'win32'",
- "python_full_version < '3.9' and sys_platform == 'win32'",
-]
-dependencies = [
- { name = "numpy", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" },
- { name = "torch", version = "2.4.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" },
-]
-wheels = [
- { url = "https://files.pythonhosted.org/packages/07/b7/cb5ce4d1a382cf53c19ef06c5fc29e85f5e129b4da6527dd207d90a5b8ad/bitsandbytes-0.45.5-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:a5453f30cc6aab6ccaac364e6bf51a7808d3da5f71763dffeb6d9694c59136e4", size = 76059261 },
- { url = "https://files.pythonhosted.org/packages/a6/4c/77b535e025ce780d2ada8271c1e481fb7337c1df2588a52fe1c9bd87d2e8/bitsandbytes-0.45.5-py3-none-win_amd64.whl", hash = "sha256:ed1c61b91d989d6a33fd05737d6edbf5086d8ebc89235ee632c7a19144085da2", size = 75430204 },
-]
-
-[[package]]
-name = "bitsandbytes"
-version = "0.47.0"
-source = { registry = "https://pypi.org/simple" }
-resolution-markers = [
- "python_full_version > '3.11' and sys_platform == 'darwin'",
- "python_full_version > '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'",
- "(python_full_version > '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version > '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')",
- "python_full_version == '3.11' and sys_platform == 'darwin'",
- "python_full_version == '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'",
- "(python_full_version == '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')",
- "python_full_version == '3.10.*' and sys_platform == 'darwin'",
- "python_full_version == '3.10.*' and platform_machine == 'aarch64' and sys_platform == 'linux'",
- "(python_full_version == '3.10.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.10.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')",
- "python_full_version == '3.9.*' and platform_machine == 'arm64' and sys_platform == 'darwin'",
- "python_full_version == '3.9.*' and platform_machine == 'aarch64' and sys_platform == 'linux'",
- "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.9.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.9.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')",
- "python_full_version > '3.11' and sys_platform == 'win32'",
- "python_full_version == '3.11' and sys_platform == 'win32'",
- "python_full_version == '3.10.*' and sys_platform == 'win32'",
- "python_full_version == '3.9.*' and sys_platform == 'win32'",
-]
-dependencies = [
- { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" },
- { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" },
-]
-wheels = [
- { url = "https://files.pythonhosted.org/packages/aa/eb/477d6b5602f469c7305fd43eec71d890c39909f615c1d7138f6e7d226eff/bitsandbytes-0.47.0-py3-none-manylinux_2_24_aarch64.whl", hash = "sha256:2f805b76891a596025e9e13318b675d08481b9ee650d65e5d2f9d844084c6521", size = 30004641 },
- { url = "https://files.pythonhosted.org/packages/9c/40/91f1a5a694f434bc13cba160045fdc4e867032e627b001bf411048fefd9c/bitsandbytes-0.47.0-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:68f3fffd494a47ed1fd7593bfc5dd2ac69b68260599b71b4c4b3a32f90f3b184", size = 61284639 },
- { url = "https://files.pythonhosted.org/packages/18/a9/e07a227f1cd6562844cea2f05ee576b0991a9a91f45965c06034178ba0f6/bitsandbytes-0.47.0-py3-none-win_amd64.whl", hash = "sha256:4880a6d42ca9628b5a571c8cc3093dc3f5f52511e5a9e47d52d569807975531a", size = 60725121 },
-]
-
[[package]]
name = "blake3"
version = "1.0.5"
@@ -1283,10 +1234,6 @@ docs = [
{ name = "sphinx-rtd-theme" },
]
vllm = [
- { name = "bitsandbytes", version = "0.45.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" },
- { name = "bitsandbytes", version = "0.47.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" },
- { name = "trl", version = "0.11.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" },
- { name = "trl", version = "0.15.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" },
{ name = "vllm", marker = "python_full_version >= '3.9'" },
]
@@ -1315,10 +1262,6 @@ docs = [
{ name = "sphinx-rtd-theme" },
]
vllm = [
- { name = "bitsandbytes", version = "0.45.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" },
- { name = "bitsandbytes", version = "0.47.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" },
- { name = "trl", version = "0.11.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" },
- { name = "trl", version = "0.15.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" },
{ name = "vllm", marker = "python_full_version >= '3.9'" },
]
@@ -1326,7 +1269,6 @@ vllm = [
requires-dist = [
{ name = "aiosqlite", specifier = "~=0.19.0" },
{ name = "asyncpg", specifier = "~=0.27.0" },
- { name = "bitsandbytes", marker = "extra == 'vllm'", specifier = ">=0.45.5" },
{ name = "blis", specifier = "<1.0.0" },
{ name = "boto3", specifier = "~=1.28.84" },
{ name = "click", specifier = "<8.2.0" },
@@ -1367,7 +1309,6 @@ requires-dist = [
{ name = "sphinx-rtd-theme", marker = "extra == 'docs'", specifier = "~=3.0.2" },
{ name = "toml", specifier = "~=0.10.2" },
{ name = "torch", marker = "python_full_version < '3.9'", specifier = "<2.5.0" },
- { name = "trl", marker = "extra == 'vllm'", specifier = ">=0.11.4" },
{ name = "typer", specifier = "~=0.15.1" },
{ name = "typer-cli", marker = "extra == 'dev'", specifier = "~=0.15.1" },
{ name = "types-toml", marker = "extra == 'dev'", specifier = "==0.10.8.20240310" },
@@ -1400,11 +1341,7 @@ docs = [
{ name = "sphinx-autodoc-typehints", specifier = "~=2.0.1" },
{ name = "sphinx-rtd-theme", specifier = "~=3.0.2" },
]
-vllm = [
- { name = "bitsandbytes", specifier = ">=0.45.5" },
- { name = "trl", specifier = ">=0.11.4" },
- { name = "vllm", marker = "python_full_version >= '3.9'", specifier = "~=0.8.5" },
-]
+vllm = [{ name = "vllm", marker = "python_full_version >= '3.9'", specifier = "~=0.8.5" }]
[[package]]
name = "colorama"
@@ -2010,15 +1947,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774 },
]
-[[package]]
-name = "docstring-parser"
-version = "0.16"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/08/12/9c22a58c0b1e29271051222d8906257616da84135af9ed167c9e28f85cb3/docstring_parser-0.16.tar.gz", hash = "sha256:538beabd0af1e2db0146b6bd3caa526c35a34d61af9fd2887f3a8a27a739aa6e", size = 26565 }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/d5/7c/e9fcff7623954d86bdc17782036cbf715ecab1bec4847c008557affe1ca8/docstring_parser-0.16-py3-none-any.whl", hash = "sha256:bf0a1387354d3691d102edef7ec124f219ef639982d096e26e3b60aeffa90637", size = 36533 },
-]
-
[[package]]
name = "docutils"
version = "0.20.1"
@@ -2051,15 +1979,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a3/05/8b171626b850e870fc4433225cd6d5bec5a9916b1c39b3d7c67a60492aeb/email_validator-2.1.2-py3-none-any.whl", hash = "sha256:d89f6324e13b1e39889eab7f9ca2f91dc9aebb6fa50a6d8bd4329ab50f251115", size = 30739 },
]
-[[package]]
-name = "eval-type-backport"
-version = "0.2.2"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/30/ea/8b0ac4469d4c347c6a385ff09dc3c048c2d021696664e26c7ee6791631b5/eval_type_backport-0.2.2.tar.gz", hash = "sha256:f0576b4cf01ebb5bd358d02314d31846af5e07678387486e2c798af0e7d849c1", size = 9079 }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/ce/31/55cd413eaccd39125368be33c46de24a1f639f2e12349b0361b4678f3915/eval_type_backport-0.2.2-py3-none-any.whl", hash = "sha256:cb6ad7c393517f476f96d456d0412ea80f0a8cf96f6892834cd9340149111b0a", size = 5830 },
-]
-
[[package]]
name = "evaluate"
version = "0.4.3"
@@ -8179,15 +8098,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755 },
]
-[[package]]
-name = "shtab"
-version = "1.7.2"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/5a/3e/837067b970c1d2ffa936c72f384a63fdec4e186b74da781e921354a94024/shtab-1.7.2.tar.gz", hash = "sha256:8c16673ade76a2d42417f03e57acf239bfb5968e842204c17990cae357d07d6f", size = 45751 }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/74/03/3271b7bb470fbab4adf5bd30b0d32143909d96f3608d815b447357f47f2b/shtab-1.7.2-py3-none-any.whl", hash = "sha256:858a5805f6c137bb0cda4f282d27d08fd44ca487ab4a6a36d2a400263cd0b5c1", size = 14214 },
-]
-
[[package]]
name = "six"
version = "1.17.0"
@@ -9444,73 +9354,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/bc/74/9f12bdedeb110242d8bb1bd621f6605e753ee0cbf73cf7f3a62b8173f190/triton-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:30ceed0eff2c4a73b14eb63e052992f44bbdf175f3fad21e1ac8097a772de7ee", size = 253057866 },
]
-[[package]]
-name = "trl"
-version = "0.11.4"
-source = { registry = "https://pypi.org/simple" }
-resolution-markers = [
- "python_full_version < '3.9' and sys_platform != 'win32'",
- "python_full_version < '3.9' and sys_platform == 'win32'",
-]
-dependencies = [
- { name = "accelerate", version = "1.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" },
- { name = "datasets", marker = "python_full_version < '3.9'" },
- { name = "numpy", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" },
- { name = "torch", version = "2.4.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" },
- { name = "transformers", version = "4.46.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" },
- { name = "tyro", marker = "python_full_version < '3.9'" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/07/39/a78c0608190cc412c49631dfb8c3e57f5c5b2fb0d79709071c992e707aa4/trl-0.11.4.tar.gz", hash = "sha256:de52a023fc35d580ab809fd74cd4f362a259e463bb968580e0e97e1b98a0fe79", size = 307304 }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/e4/dd/d2cf3dbc1013cee71ceef584f5ab69915fc05d209ef1e276f8652058c350/trl-0.11.4-py3-none-any.whl", hash = "sha256:071d64164c152ef65b44d15f878793b28d3340310c9e157dc3608bbe5fa549a9", size = 316575 },
-]
-
-[[package]]
-name = "trl"
-version = "0.15.2"
-source = { registry = "https://pypi.org/simple" }
-resolution-markers = [
- "python_full_version > '3.11' and sys_platform == 'darwin'",
- "python_full_version > '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'",
- "(python_full_version > '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version > '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')",
- "python_full_version == '3.11' and sys_platform == 'darwin'",
- "python_full_version == '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'",
- "(python_full_version == '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')",
- "python_full_version == '3.10.*' and sys_platform == 'darwin'",
- "python_full_version == '3.10.*' and platform_machine == 'aarch64' and sys_platform == 'linux'",
- "(python_full_version == '3.10.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.10.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')",
- "python_full_version == '3.9.*' and platform_machine == 'arm64' and sys_platform == 'darwin'",
- "python_full_version == '3.9.*' and platform_machine == 'aarch64' and sys_platform == 'linux'",
- "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.9.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.9.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')",
- "python_full_version > '3.11' and sys_platform == 'win32'",
- "python_full_version == '3.11' and sys_platform == 'win32'",
- "python_full_version == '3.10.*' and sys_platform == 'win32'",
- "python_full_version == '3.9.*' and sys_platform == 'win32'",
-]
-dependencies = [
- { name = "accelerate", version = "1.7.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" },
- { name = "datasets", marker = "python_full_version >= '3.9'" },
- { name = "rich", marker = "python_full_version >= '3.9'" },
- { name = "transformers", version = "4.51.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/95/fe/ae0d782c48eef56d0ec125ebd05998539ede7cbf0e307a48f9323998b9e7/trl-0.15.2.tar.gz", hash = "sha256:0f82190a058a0a194dbcfae1fe9548b68a0a05b2f4d1824f8db1ae7d949cdd47", size = 333962 }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/7b/29/25378447c48359843de0e4ce1995d367210601c3b437ddf1c779b6393d74/trl-0.15.2-py3-none-any.whl", hash = "sha256:bf2b88e3cf5da08cd533dc03273d977965bd5d86c5878f76285fba45d9cb9634", size = 318931 },
-]
-
-[[package]]
-name = "typeguard"
-version = "4.4.0"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
- { name = "importlib-metadata", version = "8.5.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" },
- { name = "typing-extensions", marker = "python_full_version < '3.9'" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/79/5a/91b7c8cfc2e96962442abc9d65c650436dd831910b4d7878980d6596fb98/typeguard-4.4.0.tar.gz", hash = "sha256:463bd8697a65a4aa576a63767c369b1ecfba8a5ba735edfe3223127b6ecfa28c", size = 74399 }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/61/a3/00203767544b597a9e3c57b29a84967b3230f00bdd9aa6a52a73187043b4/typeguard-4.4.0-py3-none-any.whl", hash = "sha256:8ca34c14043f53b2caae7040549ba431770869bcd6287cfa8239db7ecb882b4a", size = 35736 },
-]
-
[[package]]
name = "typer"
version = "0.15.4"
@@ -9568,24 +9411,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/31/08/aa4fdfb71f7de5176385bd9e90852eaf6b5d622735020ad600f2bab54385/typing_inspection-0.4.0-py3-none-any.whl", hash = "sha256:50e72559fcd2a6367a19f7a7e610e6afcb9fac940c650290eed893d61386832f", size = 14125 },
]
-[[package]]
-name = "tyro"
-version = "0.9.24"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
- { name = "colorama", marker = "python_full_version < '3.9' and sys_platform == 'win32'" },
- { name = "docstring-parser", marker = "python_full_version < '3.9'" },
- { name = "eval-type-backport", marker = "python_full_version < '3.9'" },
- { name = "rich", marker = "python_full_version < '3.9'" },
- { name = "shtab", marker = "python_full_version < '3.9'" },
- { name = "typeguard", marker = "python_full_version < '3.9'" },
- { name = "typing-extensions", marker = "python_full_version < '3.9'" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/57/49/ca1698fcc5479fe9c7eff48861ebb671c5a6afba0245ea7cd560a939f281/tyro-0.9.24.tar.gz", hash = "sha256:5a9ef93d1b8e93cff2c5d82789a571d905d152e92af82a3ec96a17d668194df3", size = 303651 }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/ac/59/4c865b56babef1aa6a9662879c94507dc62d0173ac7433579d7a2728f7e5/tyro-0.9.24-py3-none-any.whl", hash = "sha256:d8152e47375419752210da455226007b4bb9bd9c65af1de8fb12daf0658c91dc", size = 128326 },
-]
-
[[package]]
name = "tzdata"
version = "2025.2"
From 9d18564f0ee023d28cd5afb94e8a0fdfd63ba01a Mon Sep 17 00:00:00 2001
From: Xi Bai <82581439+baixiac@users.noreply.github.com>
Date: Wed, 17 Sep 2025 15:52:05 +0100
Subject: [PATCH 5/9] Revert "feat: add the text embedding endpoint for LLM
serving"
This reverts commit d76098609662430b74e02c1dcc463b090d7a1dac.
---
app/api/routers/generative.py | 110 ++----------------
app/domain.py | 15 +--
app/model_services/base.py | 27 +----
app/model_services/huggingface_llm_model.py | 50 +-------
tests/app/api/test_serving_hf_llm.py | 22 ----
.../test_huggingface_llm_model.py | 104 +----------------
6 files changed, 13 insertions(+), 315 deletions(-)
diff --git a/app/api/routers/generative.py b/app/api/routers/generative.py
index 26f4fd1..763c717 100644
--- a/app/api/routers/generative.py
+++ b/app/api/routers/generative.py
@@ -10,16 +10,8 @@
from fastapi import APIRouter, Depends, Request, Body, Query
from fastapi.encoders import jsonable_encoder
from fastapi.responses import PlainTextResponse, StreamingResponse, JSONResponse
-from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR
-from app.domain import (
- Tags,
- OpenAIChatRequest,
- OpenAIChatResponse,
- OpenAIEmbeddingsRequest,
- OpenAIEmbeddingsResponse,
- PromptMessage,
- PromptRole,
-)
+from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST
+from app.domain import Tags, OpenAIChatRequest, OpenAIChatResponse, PromptMessage, PromptRole
from app.model_services.base import AbstractModelService
from app.utils import get_settings, get_prompt_from_messages
from app.api.utils import get_rate_limiter
@@ -29,7 +21,6 @@
PATH_GENERATE = "/generate"
PATH_GENERATE_ASYNC = "/stream/generate"
PATH_OPENAI_COMPLETIONS = "/v1/chat/completions"
-PATH_OPENAI_EMBEDDINGS = "/v1/embeddings"
router = APIRouter()
config = get_settings()
@@ -143,7 +134,7 @@ async def generate_text_stream(
@router.post(
PATH_OPENAI_COMPLETIONS,
- tags=[Tags.OpenAICompatible.name],
+ tags=[Tags.Generative.name],
response_model=None,
dependencies=[Depends(cms_globals.props.current_active_user)],
description="Generate chat response based on messages, similar to OpenAI's /v1/chat/completions",
@@ -171,7 +162,6 @@ def generate_chat_completions(
"""
messages = request_data.messages
- model = model_service.model_name if request_data.model != model_service.model_name else request_data.model
stream = request_data.stream
max_tokens = request_data.max_tokens
temperature = request_data.temperature
@@ -234,7 +224,7 @@ async def _stream(prompt: str, max_tokens: int, temperature: float) -> AsyncGene
id=tracking_id,
object="chat.completion",
created=int(time.time()),
- model=model,
+ model=model_service.model_name,
choices=[
{
"index": 0,
@@ -249,100 +239,14 @@ async def _stream(prompt: str, max_tokens: int, temperature: float) -> AsyncGene
return JSONResponse(content=jsonable_encoder(completion), headers={"x-cms-tracking-id": tracking_id})
-@router.post(
- PATH_OPENAI_EMBEDDINGS,
- tags=[Tags.OpenAICompatible.name],
- response_model=None,
- dependencies=[Depends(cms_globals.props.current_active_user)],
- description="Create embeddings based on text(s), similar to OpenAI's /v1/embeddings endpoint",
-)
-def embed_texts(
- request: Request,
- request_data: Annotated[OpenAIEmbeddingsRequest, Body(
- description="Text(s) to be embedded", media_type="application/json"
- )],
- tracking_id: Union[str, None] = Depends(validate_tracking_id),
- model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
-) -> JSONResponse:
- """
- Embeds text or a list of texts, mimicking OpenAI's /v1/embeddings endpoint.
-
- Args:
- request (Request): The request object.
- request_data (OpenAIEmbeddingsRequest): The request data containing model and input text(s).
- tracking_id (Union[str, None]): An optional tracking ID of the requested task.
- model_service (AbstractModelService): The model service dependency.
-
- Returns:
- JSONResponse: A response containing the embeddings of the text(s).
- """
- tracking_id = tracking_id or str(uuid.uuid4())
-
- if not hasattr(model_service, "create_embeddings"):
- error_response = {
- "error": {
- "message": "Model does not support embeddings",
- "type": "invalid_request_error",
- "param": "model",
- "code": "model_not_supported",
- }
- }
- return JSONResponse(
- content=error_response,
- status_code=HTTP_500_INTERNAL_SERVER_ERROR,
- headers={"x-cms-tracking-id": tracking_id},
- )
-
- input_text = request_data.input
- model = model_service.model_name if request_data.model != model_service.model_name else request_data.model
-
- if isinstance(input_text, str):
- input_texts = [input_text]
- else:
- input_texts = input_text
-
- try:
- embeddings_data = []
-
- for i, embedding in enumerate(model_service.create_embeddings(input_texts)):
- embeddings_data.append({
- "object": "embedding",
- "embedding": embedding,
- "index": i,
- })
-
- response = OpenAIEmbeddingsResponse(object="list", data=embeddings_data, model=model)
-
- return JSONResponse(
- content=jsonable_encoder(response),
- headers={"x-cms-tracking-id": tracking_id},
- )
-
- except Exception as e:
- logger.error("Failed to create embeddings")
- logger.exception(e)
- error_response = {
- "error": {
- "message": f"Failed to create embeddings: {str(e)}",
- "type": "server_error",
- "code": "internal_error",
- }
- }
- return JSONResponse(
- content=error_response,
- status_code=HTTP_500_INTERNAL_SERVER_ERROR,
- headers={"x-cms-tracking-id": tracking_id},
- )
-
-
def _empty_prompt_error() -> Iterable[str]:
yield "ERROR: No prompt text provided\n"
def _send_usage_metrics(handler: str, prompt_token_num: int, completion_token_num: int) -> None:
cms_prompt_tokens.labels(handler=handler).observe(prompt_token_num)
- logger.debug("Sent prompt tokens usage: %s", prompt_token_num)
+ logger.debug(f"Sent prompt tokens usage: {prompt_token_num}")
cms_completion_tokens.labels(handler=handler).observe(completion_token_num)
- logger.debug("Sent completion tokens usage: %s", completion_token_num)
+ logger.debug(f"Sent completion tokens usage: {completion_token_num}")
cms_total_tokens.labels(handler=handler).observe(prompt_token_num + completion_token_num)
- logger.debug("Sent total tokens usage: %s", prompt_token_num + completion_token_num)
+ logger.debug(f"Sent total tokens usage: {prompt_token_num + completion_token_num}")
diff --git a/app/domain.py b/app/domain.py
index de098a8..2ddd6e3 100644
--- a/app/domain.py
+++ b/app/domain.py
@@ -1,5 +1,5 @@
from enum import Enum
-from typing import List, Optional, Dict, Any, Union
+from typing import List, Optional, Dict, Any
from fastapi import HTTPException
from starlette.status import HTTP_400_BAD_REQUEST
@@ -27,7 +27,6 @@ class Tags(str, Enum):
Evaluating = "Evaluate the deployed model with trainer export"
Authentication = "Authenticate registered users"
Generative = "Generate text based on the input prompt"
- OpenAICompatible = "Compatible with OpenAI APIs"
class TagsStreamable(str, Enum):
@@ -186,7 +185,6 @@ class OpenAIChatRequest(BaseModel):
messages: List[PromptMessage] = Field(..., description="A list of messages to be sent to the model")
stream: bool = Field(..., description="Whether to stream the response")
max_tokens: int = Field(512, description="The maximum number of tokens to generate", gt=0)
- model: str = Field(..., description="The name of the model used for generating the completion")
temperature: float = Field(0.7, description="The temperature of the generated text", ge=0.0, le=1.0)
@@ -196,14 +194,3 @@ class OpenAIChatResponse(BaseModel):
created: int = Field(..., description="The timestamp when the completion was generated")
model: str = Field(..., description="The name of the model used for generating the completion")
choices: List = Field(..., description="The generated messages and their metadata")
-
-
-class OpenAIEmbeddingsRequest(BaseModel):
- input: Union[str, List[str]] = Field(..., description="Input text or list of texts to embed")
- model: str = Field(..., description="The name of the model used for creating the embeddings")
-
-
-class OpenAIEmbeddingsResponse(BaseModel):
- object: str = Field(..., description="The type of the response")
- data: List[Dict[str, Any]] = Field(..., description="List of embedding objects")
- model: str = Field(..., description="The name of the model used for creating the embeddings")
diff --git a/app/model_services/base.py b/app/model_services/base.py
index dfde491..a3c1ccc 100644
--- a/app/model_services/base.py
+++ b/app/model_services/base.py
@@ -1,6 +1,6 @@
import asyncio
from abc import ABC, abstractmethod
-from typing import Any, List, Iterable, Tuple, final, Optional, Generic, TypeVar, Protocol, AsyncIterable, Union
+from typing import Any, List, Iterable, Tuple, final, Optional, Generic, TypeVar, Protocol, AsyncIterable
from app.config import Settings
from app.domain import ModelCard, Annotation
@@ -17,7 +17,7 @@ def tracker_client(self) -> Any:
T = TypeVar("T", bound=_TrainerCommon)
class AbstractModelService(ABC, Generic[T]):
- """An abstract base class defining the common interface for NER model services."""
+ """An abstract base class defining the common interface for all model services."""
@abstractmethod
def __init__(self, config: Settings, *args: Any, **kwargs: Any) -> None:
@@ -200,29 +200,6 @@ def generate_async(self, prompt: str, *args: Any, **kwargs: Any) -> AsyncIterabl
raise NotImplementedError
- def create_embeddings(
- self,
- text: Union[str, List[str]],
- *args: Any,
- **kwargs: Any
- ) -> Union[List[float], List[List[float]]]:
- """
- Creates embeddings for a given text or list of texts.
-
- Args:
- text (Union[str, List[str]]): The text(s) to be embedded.
- *args (Any): Additional positional arguments to be passed to this method.
- **kwargs (Any): Additional keyword arguments to be passed to this method.
-
- Returns:
- Union[List[float], List[List[float]]]: The embedding vector(s) for the text(s).
-
- Raises:
- NotImplementedError: If the method is not implemented by the subclass.
- """
-
- raise NotImplementedError
-
def train_supervised(self, *args: Any, **kwargs: Any) -> Tuple[bool, str, str]:
"""
Initiates supervised training on the model.
diff --git a/app/model_services/huggingface_llm_model.py b/app/model_services/huggingface_llm_model.py
index 25cd032..7340f67 100644
--- a/app/model_services/huggingface_llm_model.py
+++ b/app/model_services/huggingface_llm_model.py
@@ -1,9 +1,8 @@
import os
import logging
import asyncio
-import torch
from concurrent.futures import ThreadPoolExecutor
-from typing import Dict, List, Optional, Tuple, Any, AsyncIterable, Callable, Union
+from typing import Dict, List, Optional, Tuple, Any, AsyncIterable, Callable
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
@@ -308,50 +307,3 @@ async def generate_async(
return
finally:
logger.debug("Chat response generation completed")
-
- def create_embeddings(
- self,
- text: Union[str, List[str]],
- *args: Any,
- **kwargs: Any
- ) -> Union[List[float], List[List[float]]]:
- """
- Creates embeddings for a given text or list of texts using the model's hidden states.
-
- Args:
- text (Union[str, List[str]]): The text(s) to be embedded.
- *args (Any): Additional positional arguments to be passed to this method.
- **kwargs (Any): Additional keyword arguments to be passed to this method.
-
- Returns:
- List[float], List[List[float]]: The embedding vector(s) for the text(s).
-
- Raises:
- NotImplementedError: If the model doesn't support embeddings.
- """
-
- self.model.eval()
-
- inputs = self.tokenizer(
- text,
- add_special_tokens=False,
- return_tensors="pt",
- padding=True,
- truncation=True,
- )
-
- if non_default_device_is_available(self._config.DEVICE):
- inputs.to(get_settings().DEVICE)
-
- with torch.no_grad():
- outputs = self.model(**inputs, output_hidden_states=True)
-
- last_hidden_state = outputs.hidden_states[-1]
- attention_mask = inputs["attention_mask"]
- masked_hidden_states = last_hidden_state * attention_mask.unsqueeze(-1)
- sum_hidden_states = masked_hidden_states.sum(dim=1)
- num_tokens = attention_mask.sum(dim=1, keepdim=True)
- embeddings = sum_hidden_states / num_tokens
-
- results = embeddings.cpu().numpy().tolist()
- return results[0] if isinstance(text, str) else results
diff --git a/tests/app/api/test_serving_hf_llm.py b/tests/app/api/test_serving_hf_llm.py
index 9b7ea9f..4bd4e9d 100644
--- a/tests/app/api/test_serving_hf_llm.py
+++ b/tests/app/api/test_serving_hf_llm.py
@@ -31,9 +31,7 @@ def llm_app(llm_model_service):
@pytest.fixture(scope="function")
def client(llm_model_service):
- llm_model_service.model_name = "HuggingFace LLM model"
llm_model_service.generate.return_value = "Yeah."
- llm_model_service.create_embeddings.return_value = [[1.0, 2.0, 3.0]]
app = get_generative_server(config, msd_overwritten=lambda: llm_model_service)
app.dependency_overrides[cms_globals.props.current_active_user] = lambda: None
client = TestClient(app)
@@ -84,7 +82,6 @@ async def test_generate_chat_completions(llm_model_service, llm_app):
"content": "Who are you?"
}
],
- "model": "HuggingFace LLM model",
"stream": True,
"max_tokens": 128,
"temperature": 0.7
@@ -101,22 +98,3 @@ async def test_generate_chat_completions(llm_model_service, llm_app):
assert response.text.startswith("data:")
assert "id" in response.text
assert "chat.completion.chunk" in response.text
-
-
-def test_create_embeddings(client):
- request_data = {
- "input": ["Alright"],
- "model": "HuggingFace LLM model",
- }
- response = client.post(
- "/v1/embeddings",
- data=json.dumps(request_data),
- headers={"Content-Type": "application/json"},
- )
- assert response.status_code == 200
- assert response.headers["content-type"] == "application/json"
- assert response.json() == {
- "object": "list",
- "data": [{"object": "embedding", "embedding": [1.0, 2.0, 3.0], "index": 0}],
- "model": "HuggingFace LLM model"
- }
diff --git a/tests/app/model_services/test_huggingface_llm_model.py b/tests/app/model_services/test_huggingface_llm_model.py
index e6ee637..7cd4941 100644
--- a/tests/app/model_services/test_huggingface_llm_model.py
+++ b/tests/app/model_services/test_huggingface_llm_model.py
@@ -1,5 +1,5 @@
import os
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock
from tests.app.conftest import MODEL_PARENT_DIR
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from app import __version__
@@ -129,104 +129,4 @@ async def test_generate_async(huggingface_llm_model):
prompt_token_num=2,
completion_token_num=2,
)
- assert result == "Yeah."
-
-
-def test_create_embeddings_single_text(huggingface_llm_model):
- """Test create_embeddings with single text input."""
- huggingface_llm_model.init_model()
- huggingface_llm_model.model = MagicMock()
- huggingface_llm_model.tokenizer = MagicMock()
- mock_hidden_states = [MagicMock(), MagicMock(), MagicMock()]
- mock_outputs = MagicMock()
- mock_outputs.hidden_states = mock_hidden_states
- mock_last_hidden_state = MagicMock()
- mock_last_hidden_state.shape = [1, 3, 768]
- mock_hidden_states[-1] = mock_last_hidden_state
- mock_attention_mask = MagicMock()
- mock_attention_mask.shape = [1, 3]
- mock_attention_mask.sum.return_value = MagicMock()
- mock_attention_mask.sum.return_value.unsqueeze.return_value = MagicMock()
- mock_inputs = MagicMock()
- mock_inputs.__getitem__.side_effect = lambda key: mock_attention_mask if key == "attention_mask" else MagicMock()
- huggingface_llm_model.tokenizer.return_value = mock_inputs
- huggingface_llm_model.model.return_value = mock_outputs
- expected_result = [0.1, 0.2, 0.3]
- mock_embeddings_batch = MagicMock()
- mock_first_embedding = MagicMock()
- mock_cpu_tensor = MagicMock()
- mock_numpy_array = MagicMock()
- mock_numpy_array.tolist.return_value = expected_result
- mock_embeddings_batch.__getitem__.return_value = mock_first_embedding
- mock_first_embedding.cpu.return_value = mock_cpu_tensor
- mock_cpu_tensor.numpy.return_value = mock_numpy_array
- mock_masked_hidden_states = MagicMock()
- mock_sum_hidden_states = MagicMock()
- mock_num_tokens = MagicMock()
- mock_last_hidden_state.__mul__.return_value = mock_masked_hidden_states
- mock_masked_hidden_states.sum.return_value = mock_sum_hidden_states
- mock_attention_mask.sum.return_value = mock_num_tokens
- mock_sum_hidden_states.__truediv__.return_value = mock_embeddings_batch
-
- result = huggingface_llm_model.create_embeddings("Alright")
-
- huggingface_llm_model.model.eval.assert_called_once()
- huggingface_llm_model.tokenizer.assert_called_once_with(
- "Alright",
- add_special_tokens=False,
- return_tensors="pt",
- padding=True,
- truncation=True
- )
- huggingface_llm_model.model.assert_called_once_with(
- **mock_inputs,
- output_hidden_states=True
- )
-
- assert result is not None
-
-
-def test_create_embeddings_list_text(huggingface_llm_model):
- huggingface_llm_model.init_model()
- huggingface_llm_model.model = MagicMock()
- huggingface_llm_model.tokenizer = MagicMock()
- mock_hidden_states = [MagicMock(), MagicMock(), MagicMock()]
- mock_outputs = MagicMock()
- mock_outputs.hidden_states = mock_hidden_states
- mock_last_hidden_state = MagicMock()
- mock_last_hidden_state.shape = [2, 3, 768]
- mock_hidden_states[-1] = mock_last_hidden_state
- mock_attention_mask = MagicMock()
- mock_attention_mask.shape = [2, 3]
- mock_attention_mask.sum.return_value = MagicMock()
- mock_attention_mask.sum.return_value.unsqueeze.return_value = MagicMock()
- mock_inputs = MagicMock()
- mock_inputs.__getitem__.side_effect = lambda key: mock_attention_mask if key == "attention_mask" else MagicMock()
- huggingface_llm_model.tokenizer.return_value = mock_inputs
- huggingface_llm_model.model.return_value = mock_outputs
- mock_embeddings_batch = MagicMock()
- mock_first_embedding = MagicMock()
- mock_cpu_tensor = MagicMock()
- mock_numpy_array = MagicMock()
- mock_numpy_array.tolist.return_value = [[0.1, 0.2, 0.3],[0.1, 0.2, 0.3]]
- mock_embeddings_batch.__getitem__.return_value = mock_first_embedding
- mock_first_embedding.cpu.return_value = mock_cpu_tensor
- mock_cpu_tensor.numpy.return_value = mock_numpy_array
- mock_masked_hidden_states = MagicMock()
- mock_sum_hidden_states = MagicMock()
- mock_num_tokens = MagicMock()
- mock_last_hidden_state.__mul__.return_value = mock_masked_hidden_states
- mock_masked_hidden_states.sum.return_value = mock_sum_hidden_states
- mock_attention_mask.sum.return_value = mock_num_tokens
- mock_sum_hidden_states.__truediv__.return_value = mock_embeddings_batch
-
- result = huggingface_llm_model.create_embeddings(["Alright", "Alright"])
-
- huggingface_llm_model.tokenizer.assert_called_once_with(
- ["Alright", "Alright"],
- add_special_tokens=False,
- return_tensors="pt",
- padding=True,
- truncation=True,
- )
- assert result is not None
+ assert result == "Yeah."
\ No newline at end of file
From eb422a3b2aceaf9f9047c008010b5e9ad2320d29 Mon Sep 17 00:00:00 2001
From: Xi Bai <82581439+baixiac@users.noreply.github.com>
Date: Wed, 17 Sep 2025 15:52:05 +0100
Subject: [PATCH 6/9] Revert "feat: add the chat template factory"
This reverts commit 587e9ea3c5b87bca3f668fd641d60e5b3d23057a.
---
app/processors/prompt_factory.py | 262 --------------------
app/utils.py | 66 ++---
tests/app/processors/test_prompt_factory.py | 228 -----------------
3 files changed, 26 insertions(+), 530 deletions(-)
delete mode 100644 app/processors/prompt_factory.py
delete mode 100644 tests/app/processors/test_prompt_factory.py
diff --git a/app/processors/prompt_factory.py b/app/processors/prompt_factory.py
deleted file mode 100644
index ee1a45c..0000000
--- a/app/processors/prompt_factory.py
+++ /dev/null
@@ -1,262 +0,0 @@
-class PromptFactory:
-
- _ALPACA = (
- "{% if messages[0]['role'] == 'system' %}"
- "{% set loop_messages = messages[1:] %}"
- "{% set system_message = messages[0]['content'].strip() + '\n' %}"
- "{% else %}"
- "{% set loop_messages = messages %}"
- "{% set system_message = '' %}"
- "{% endif %}"
- "{% for message in loop_messages %}"
- "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
- "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
- "{% endif %}"
- "{% if loop.index0 == 0 %}"
- "{% set content = system_message + message['content'] %}"
- "{% else %}"
- "{% set content = message['content'] %}"
- "{% endif %}"
-
- "{% if message['role'] == 'user' %}"
- "{{ '### Instruction:\n' + content.strip() + '\n\n'}}"
- "{% elif message['role'] == 'assistant' %}"
- "{{ '### Response:\n' + content.strip() + '\n\n' }}"
- "{% endif %}"
- "{% endfor %}"
- "{% if add_generation_prompt %}"
- "{{ '### Response:\n' }}"
- "{% endif %}"
- )
-
- _CHAT_ML = (
- "{% for message in messages %}"
- "{{'<|im_start|>' + message['role'] + '\n' + message['content'].strip() + '<|im_end|>' + '\n'}}"
- "{% endfor %}"
- "{% if add_generation_prompt %}"
- "{{'<|im_start|>assistant\n'}}"
- "{% endif %}"
- )
-
- _DEFAULT = (
- "{% for message in messages %}"
- "{% if message['role'] == 'user' %}"
- "{{'<|user|>\n' + message['content'] + eos_token}}"
- "{% elif message['role'] == 'system' %}"
- "{{'<|system|>\n' + message['content'] + eos_token}}"
- "{% elif message['role'] == 'assistant' %}"
- "{{'<|assistant|>\n' + message['content'] + eos_token}}"
- "{% endif %}"
- "{% if loop.last and add_generation_prompt %}"
- "{{'<|assistant|>'}}"
- "{% endif %}"
- "{% endfor %}"
- )
-
- _FALCON = (
- "{% if messages[0]['role'] == 'system' %}"
- "{% set loop_messages = messages[1:] %}"
- "{% set system_message = messages[0]['content'] %}"
- "{% else %}"
- "{% set loop_messages = messages %}"
- "{% set system_message = '' %}"
- "{% endif %}"
- "{% for message in loop_messages %}"
- "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
- "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
- "{% endif %}"
- "{% if loop.index0 == 0 %}"
- "{{ system_message.strip() }}"
- "{% endif %}"
- "{{ '\n\n' + message['role'].title() + ': ' + message['content'].strip().replace('\r\n', '\n').replace('\n\n', '\n') }}"
- "{% endfor %}"
- "{% if add_generation_prompt %}"
- "{ '\n\nAssistant:' }}"
- "{% endif %}"
- )
-
- _GEMMA = (
- "{% if messages[0]['role'] == 'system' %}"
- "{% set loop_messages = messages[1:] %}"
- "{% set system_message = messages[0]['content'].strip() + '\n\n' %}"
- "{% else %}"
- "{% set loop_messages = messages %}"
- "{% set system_message = '' %}"
- "{% endif %}"
- "{% for message in loop_messages %}"
- "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
- "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
- "{% endif %}"
- "{% if loop.index0 == 0 %}"
- "{% set content = system_message + message['content'] %}"
- "{% else %}"
- "{% set content = message['content'] %}"
- "{% endif %}"
- "{% if (message['role'] == 'assistant') %}"
- "{% set role = 'model' %}"
- "{% else %}"
- "{% set role = message['role'] %}"
- "{% endif %}"
- "{{ '' + role + '\n' + content.strip() + '\n' }}"
- "{% endfor %}"
- "{% if add_generation_prompt %}"
- "{{'model\n'}}"
- "{% endif %}"
- )
-
- _LLAMA_2 = (
- "{% if messages[0]['role'] == 'system' %}"
- "{% set loop_messages = messages[1:] %}"
- "{% set system_message = '<>\n' + messages[0]['content'].strip() + '\n<>\n\n' %}"
- "{% else %}"
- "{% set loop_messages = messages %}"
- "{% set system_message = '' %}"
- "{% endif %}"
- "{% for message in loop_messages %}"
- "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
- "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
- "{% endif %}"
- "{% if loop.index0 == 0 %}"
- "{% set content = system_message + message['content'] %}"
- "{% else %}"
- "{% set content = message['content'] %}"
- "{% endif %}"
- "{% if message['role'] == 'user' %}"
- "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
- "{% elif message['role'] == 'assistant' %}"
- "{{ ' ' + content.strip() + ' ' + eos_token }}"
- "{% endif %}"
- "{% endfor %}"
- )
-
- _LLAMA_3 = (
- "{{ bos_token }}"
- "{% if messages[0]['role'] == 'system' %}"
- "{% set loop_messages = messages[1:] %}"
- "{% set system_message = '<|start_header_id|>' + 'system' + '<|end_header_id|>\n\n' + messages[0]['content'].strip() + '<|eot_id|>' %}"
- "{% else %}"
- "{% set loop_messages = messages %}"
- "{% set system_message = '' %}"
- "{% endif %}"
- "{% for message in loop_messages %}"
- "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
- "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
- "{% endif %}"
- "{% if loop.index0 == 0 %}"
- "{{ system_message }}"
- "{% endif %}"
- "{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'].strip() + '<|eot_id|>' }}"
- "{% if loop.last and message['role'] == 'user' and add_generation_prompt %}"
- "{{ '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}"
- "{% endif %}"
- "{% endfor %}"
- )
-
- _MISTRAL = (
- "{{ bos_token }}"
- "{% if messages[0]['role'] == 'system' %}"
- "{% set loop_messages = messages[1:] %}"
- "{% set system_message = messages[0]['content'].strip() + '\n\n' %}"
- "{% else %}"
- "{% set loop_messages = messages %}"
- "{% set system_message = '' %}"
- "{% endif %}"
- "{% for message in loop_messages %}"
- "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
- "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
- "{% endif %}"
- "{% if loop.index0 == 0 %}"
- "{% set content = system_message + message['content'] %}"
- "{% else %}"
- "{% set content = message['content'] %}"
- "{% endif %}"
- "{% if message['role'] == 'user' %}"
- "{{ '[INST] ' + content.strip() + ' [/INST]' }}"
- "{% elif message['role'] == 'assistant' %}"
- "{{ content.strip() + eos_token}}"
- "{% endif %}"
- "{% endfor %}"
- )
-
- _PHI_2 = (
- "{% if messages[0]['role'] == 'system' %}"
- "{% set loop_messages = messages[1:] %}"
- "{% set system_message = messages[0]['content'].strip() + '\n\n' %}"
- "{% else %}"
- "{% set loop_messages = messages %}"
- "{% set system_message = '' %}"
- "{% endif %}"
- "{% for message in loop_messages %}"
- "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
- "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
- "{% endif %}"
- "{% if loop.index0 == 0 %}"
- "{% set content = system_message + message['content'] %}"
- "{% else %}"
- "{% set content = message['content'] %}"
- "{% endif %}"
- "{% if message['role'] == 'user' %}"
- "{{ 'Instruct: ' + content.strip() + '\n' }}"
- "{% elif message['role'] == 'assistant' %}"
- "{{ 'Output: ' + content.strip() + '\n' }}"
- "{% endif %}"
- "{% endfor %}"
- "{% if add_generation_prompt %}"
- "{{ 'Output:' }}"
- "{% endif %}"
- )
-
- _PHI_3 = (
- "{{ bos_token }}"
- "{% for message in messages %}"
- "{% if (message['role'] == 'system') %}"
- "{{'<|system|>' + '\n' + message['content'].strip() + '<|end|>' + '\n'}}"
- "{% elif (message['role'] == 'user') %}"
- "{{'<|user|>' + '\n' + message['content'].strip() + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}"
- "{% elif message['role'] == 'assistant' %}"
- "{{message['content'].strip() + '<|end|>' + '\n'}}"
- "{% endif %}"
- "{% endfor %}"
- )
-
- _QWEN = (
- "{% for message in messages %}"
- "{% if loop.first and messages[0]['role'] != 'system' %}"
- "{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}"
- "{% endif %}"
- "{{'<|im_start|>' + message['role'] + '\n' + message['content'].strip() }}"
- "{% if (loop.last and add_generation_prompt) or not loop.last %}"
- "{{ '<|im_end|>' + '\n'}}"
- "{% endif %}"
- "{% endfor %}"
- "{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}"
- "{{ '<|im_start|>assistant\n' }}"
- "{% endif %}"
- )
-
- @classmethod
- def create_chat_template(cls, name: str = "default") -> str:
- if name.lower() == "default":
- return cls._DEFAULT
- elif name.lower() == "alpaca":
- return cls._ALPACA
- elif name.lower() == "chat_ml":
- return cls._CHAT_ML
- elif name.lower() == "falcon":
- return cls._FALCON
- elif name.lower() == "gemma":
- return cls._GEMMA
- elif name.lower() == "llama_2":
- return cls._LLAMA_2
- elif name.lower() == "llama_3":
- return cls._LLAMA_3
- elif name.lower() == "mistral":
- return cls._MISTRAL
- elif name.lower() == "phi_2":
- return cls._PHI_2
- elif name.lower() == "phi_3":
- return cls._PHI_3
- elif name.lower() == "qwen":
- return cls._QWEN
- else:
- raise ValueError("Invalid template name")
diff --git a/app/utils.py b/app/utils.py
index fcc231e..ee97d51 100644
--- a/app/utils.py
+++ b/app/utils.py
@@ -27,7 +27,6 @@
from app.config import Settings
from app.domain import Annotation, Entity, CodeType, ModelType, Device, PromptMessage, PromptRole
from app.exception import ManagedModelException
-from app.processors.prompt_factory import PromptFactory
@lru_cache
@@ -740,60 +739,47 @@ def download_model_package(
retry_delay *= 2
-def get_prompt_from_messages(
- tokenizer: PreTrainedTokenizer,
- messages: List[PromptMessage],
- override_template: Optional[str] = None,
-) -> str:
+def get_prompt_from_messages(tokenizer: PreTrainedTokenizer, messages: List[PromptMessage]) -> str:
"""
Generates a prompt from a list of prompt messages.
Args:
tokenizer (PreTrainedTokenizer): The tokenizer to use for applying the chat template.
messages (List[PromptMessage]): The list of prompt messages to use for generating the prompt.
- override_template (str): The name of the chat template to use for generating the prompt.
Returns:
str: The generated prompt.
"""
- if override_template is None:
- if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
- prompt = tokenizer.apply_chat_template(
- [dump_pydantic_object_to_dict(message) for message in messages],
- tokenize=False,
- add_generation_prompt=True,
- )
- elif hasattr(tokenizer, "default_chat_template") and tokenizer.default_chat_template:
- # This largely depends on how older versions of HF tokenizers behave and may not work universally
- tokenizer.chat_template = tokenizer.default_chat_template
- prompt = tokenizer.apply_chat_template(
- [dump_pydantic_object_to_dict(message) for message in messages],
- tokenize=False,
- add_generation_prompt=True,
- )
- else:
- system_content = ""
- prompt_parts: List[str] = []
- for message in messages:
- content = message.content.strip()
- if message.role == PromptRole.SYSTEM:
- system_content = content
- elif message.role == PromptRole.USER:
- prompt_parts.append(f"<|user|>\n{content}")
- elif message.role == PromptRole.ASSISTANT:
- prompt_parts.append(f"<|assistant|>\n{content}")
- if system_content:
- prompt = f"<|system|>\n{system_content}\n" + "\n".join(prompt_parts)
- else:
- prompt = "\n".join(prompt_parts)
- prompt += "\n<|assistant|>\n"
- else:
- tokenizer.chat_template = PromptFactory.create_chat_template(name=override_template)
+ if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
+ prompt = tokenizer.apply_chat_template(
+ [dump_pydantic_object_to_dict(message) for message in messages],
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+ elif hasattr(tokenizer, "default_chat_template") and tokenizer.default_chat_template:
+ # This largely depends on how older versions of HF tokenizers behave and may not work universally
+ tokenizer.chat_template = tokenizer.default_chat_template
prompt = tokenizer.apply_chat_template(
[dump_pydantic_object_to_dict(message) for message in messages],
tokenize=False,
add_generation_prompt=True,
)
+ else:
+ system_content = ""
+ prompt_parts: List[str] = []
+ for message in messages:
+ content = message.content.strip()
+ if message.role == PromptRole.SYSTEM:
+ system_content = content
+ elif message.role == PromptRole.USER:
+ prompt_parts.append(f"<|user|>\n{content}")
+ elif message.role == PromptRole.ASSISTANT:
+ prompt_parts.append(f"<|assistant|>\n{content}")
+ if system_content:
+ prompt = f"<|system|>\n{system_content}\n" + "\n".join(prompt_parts)
+ else:
+ prompt = "\n".join(prompt_parts)
+ prompt += "\n<|assistant|>\n"
return prompt
TYPE_ID_TO_NAME_PATCH = {
diff --git a/tests/app/processors/test_prompt_factory.py b/tests/app/processors/test_prompt_factory.py
deleted file mode 100644
index ab03388..0000000
--- a/tests/app/processors/test_prompt_factory.py
+++ /dev/null
@@ -1,228 +0,0 @@
-from jinja2.sandbox import ImmutableSandboxedEnvironment
-from app.processors.prompt_factory import PromptFactory
-
-
-def test_create_default_chat_template():
- env = ImmutableSandboxedEnvironment()
- messages = [
- {
- "role": "system",
- "content": "Alright?"
- },
- {
- "role": "user",
- "content": "Yeah."
- },
- ]
- template = env.from_string(PromptFactory.create_chat_template("default"))
- prompt = template.render(
- messages=messages,
- bos_token="<|system|>",
- eos_token="<|end|>",
- add_generation_prompt=True,
- )
- assert prompt == "<|system|>\nAlright?<|end|><|user|>\nYeah.<|end|><|assistant|>"
-
-
-def test_create_alpaca_chat_template():
- env = ImmutableSandboxedEnvironment()
- messages = [
- {
- "role": "system",
- "content": "Alright?"
- },
- {
- "role": "user",
- "content": "Yeah."
- },
- ]
- template = env.from_string(PromptFactory.create_chat_template("alpaca"))
- prompt = template.render(
- messages=messages,
- add_generation_prompt=True,
- )
- assert prompt == "### Instruction:\nAlright?\nYeah.\n\n### Response:\n"
-
-
-def test_create_chat_ml_chat_template():
- env = ImmutableSandboxedEnvironment()
- messages = [
- {
- "role": "system",
- "content": "Alright?"
- },
- {
- "role": "user",
- "content": "Yeah."
- },
- ]
- template = env.from_string(PromptFactory.create_chat_template("chat_ml"))
- prompt = template.render(
- messages=messages,
- add_generation_prompt=True,
- )
- assert prompt == "<|im_start|>system\nAlright?<|im_end|>\n<|im_start|>user\nYeah.<|im_end|>\n<|im_start|>assistant\n"
-
-def test_create_falcon_chat_template():
- env = ImmutableSandboxedEnvironment()
- messages = [
- {
- "role": "system",
- "content": "Alright?"
- },
- {
- "role": "user",
- "content": "Yeah."
- },
- ]
- template = env.from_string(PromptFactory.create_chat_template("falcon"))
- prompt = template.render(
- messages=messages,
- add_generation_prompt=True,
- )
- assert prompt == "Alright?\n\nUser: Yeah.{ '\n\nAssistant:' }}"
-
-def test_create_gemma_chat_template():
- env = ImmutableSandboxedEnvironment()
- messages = [
- {
- "role": "system",
- "content": "Alright?"
- },
- {
- "role": "user",
- "content": "Yeah."
- },
- ]
- template = env.from_string(PromptFactory.create_chat_template("gemma"))
- prompt = template.render(
- messages=messages,
- add_generation_prompt=True,
- )
- assert prompt == "user\nAlright?\n\nYeah.\nmodel\n"
-
-def test_create_llama_2_chat_template():
- env = ImmutableSandboxedEnvironment()
- messages = [
- {
- "role": "system",
- "content": "Alright?"
- },
- {
- "role": "user",
- "content": "Yeah."
- },
- ]
- template = env.from_string(PromptFactory.create_chat_template("LLAMA_2"))
- prompt = template.render(
- messages=messages,
- bos_token="",
- eos_token="",
- add_generation_prompt=True,
- )
- assert prompt == "[INST] <>\nAlright?\n<>\n\nYeah. [/INST]"
-
-def test_create_llama_3_chat_template():
- env = ImmutableSandboxedEnvironment()
- messages = [
- {
- "role": "system",
- "content": "Alright?"
- },
- {
- "role": "user",
- "content": "Yeah."
- },
- ]
- template = env.from_string(PromptFactory.create_chat_template("LLAMA_2"))
- prompt = template.render(
- messages=messages,
- bos_token="",
- eos_token="",
- add_generation_prompt=True,
- )
- assert prompt == "[INST] <>\nAlright?\n<>\n\nYeah. [/INST]"
-
-def test_create_mistral_chat_template():
- env = ImmutableSandboxedEnvironment()
- messages = [
- {
- "role": "system",
- "content": "Alright?"
- },
- {
- "role": "user",
- "content": "Yeah."
- },
- ]
- template = env.from_string(PromptFactory.create_chat_template("mistral"))
- prompt = template.render(
- messages=messages,
- bos_token="",
- eos_token="",
- add_generation_prompt=True,
- )
- assert prompt == "[INST] Alright?\n\nYeah. [/INST]"
-
-def test_create_phi_2_chat_template():
- env = ImmutableSandboxedEnvironment()
- messages = [
- {
- "role": "system",
- "content": "Alright?"
- },
- {
- "role": "user",
- "content": "Yeah."
- },
- ]
- template = env.from_string(PromptFactory.create_chat_template("phi_2"))
- prompt = template.render(
- messages=messages,
- bos_token="<|endoftext|>",
- eos_token="<|endoftext|>",
- add_generation_prompt=True,
- )
- assert prompt == "Instruct: Alright?\n\nYeah.\nOutput:"
-
-def test_create_phi_3_chat_template():
- env = ImmutableSandboxedEnvironment()
- messages = [
- {
- "role": "system",
- "content": "Alright?"
- },
- {
- "role": "user",
- "content": "Yeah."
- },
- ]
- template = env.from_string(PromptFactory.create_chat_template("phi_3"))
- prompt = template.render(
- messages=messages,
- bos_token="",
- eos_token="<|end|>",
- add_generation_prompt=True,
- )
- assert prompt == "<|system|>\nAlright?<|end|>\n<|user|>\nYeah.<|end|>\n<|assistant|>\n"
-
-def test_create_qwen_chat_template():
- env = ImmutableSandboxedEnvironment()
- messages = [
- {
- "role": "system",
- "content": "Alright?"
- },
- {
- "role": "user",
- "content": "Yeah."
- },
- ]
- template = env.from_string(PromptFactory.create_chat_template("qwen"))
- prompt = template.render(
- messages=messages,
- bos_token="",
- eos_token="<|end|>",
- add_generation_prompt=True,
- )
- assert prompt == "<|im_start|>system\nAlright?<|im_end|>\n<|im_start|>user\nYeah.<|im_end|>\n<|im_start|>assistant\n"
From 1a5c2334b4fe110cb92146b261d9ff4865e0e63b Mon Sep 17 00:00:00 2001
From: Xi Bai <82581439+baixiac@users.noreply.github.com>
Date: Wed, 17 Sep 2025 15:52:05 +0100
Subject: [PATCH 7/9] Revert "feat: handle the deprecated default chat
template"
This reverts commit 6c05ee1c5c1e5a3c1daf4c70b6ba1238ed2777df.
---
app/api/utils.py | 2 +-
app/utils.py | 9 +--------
tests/app/test_utils.py | 21 ++-------------------
3 files changed, 4 insertions(+), 28 deletions(-)
diff --git a/app/api/utils.py b/app/api/utils.py
index 05726ba..ba99eed 100644
--- a/app/api/utils.py
+++ b/app/api/utils.py
@@ -352,7 +352,7 @@ async def generate_text(
chat_template="{% for message in messages %}\n{% if message['role'] == 'user' %}\nUser: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nAssistant: {{ message['content'] }}\n{% endif %}\n{% endfor %}\nAssistant:",
tokenize=True,
)
- prompt_obj = TokensPrompt(prompt_token_ids=prompt_tokens) # type: ignore
+ prompt_obj = TokensPrompt(prompt_token_ids=prompt_tokens)
async def _stream() -> AsyncGenerator[bytes, None]:
start = 0
diff --git a/app/utils.py b/app/utils.py
index ee97d51..23968b2 100644
--- a/app/utils.py
+++ b/app/utils.py
@@ -756,14 +756,6 @@ def get_prompt_from_messages(tokenizer: PreTrainedTokenizer, messages: List[Prom
tokenize=False,
add_generation_prompt=True,
)
- elif hasattr(tokenizer, "default_chat_template") and tokenizer.default_chat_template:
- # This largely depends on how older versions of HF tokenizers behave and may not work universally
- tokenizer.chat_template = tokenizer.default_chat_template
- prompt = tokenizer.apply_chat_template(
- [dump_pydantic_object_to_dict(message) for message in messages],
- tokenize=False,
- add_generation_prompt=True,
- )
else:
system_content = ""
prompt_parts: List[str] = []
@@ -843,3 +835,4 @@ def get_prompt_from_messages(tokenizer: PreTrainedTokenizer, messages: List[Prom
"25624495": '© 2002-2020 International Health Terminology Standards Development Organisation (IHTSDO). All rights reserved. SNOMED CT®, was originally created by The College of American Pathologists. "SNOMED" and "SNOMED CT" are registered trademarks of the IHTSDO.',
"55540447": "linkage concept"
}
+
diff --git a/tests/app/test_utils.py b/tests/app/test_utils.py
index 2f00e10..470d0d8 100644
--- a/tests/app/test_utils.py
+++ b/tests/app/test_utils.py
@@ -412,27 +412,10 @@ def test_get_prompt_with_chat_template():
assert prompt == "Mock chat template applied"
-def test_get_prompt_with_default_chat_template():
- with patch('transformers.PreTrainedTokenizer') as tok:
- mock_tokenizer = tok.return_value
- mock_tokenizer.chat_template = None
- mock_tokenizer.default_chat_template = "Mock default chat template"
- mock_tokenizer.apply_chat_template.return_value = "Mock default chat template applied"
- messages = [
- PromptMessage(content="Alright?", role=PromptRole.USER.value),
- PromptMessage(content="Yeah.", role=PromptRole.ASSISTANT.value),
- ]
-
- prompt = get_prompt_from_messages(mock_tokenizer, messages)
-
- assert prompt == "Mock default chat template applied"
-
-
def test_get_prompt_without_chat_template():
with patch('transformers.PreTrainedTokenizer') as tok:
mock_tokenizer = tok.return_value
mock_tokenizer.chat_template = None
- mock_tokenizer.default_chat_template = None
messages = [
PromptMessage(content="You are a helpful assistant.", role=PromptRole.SYSTEM.value),
PromptMessage(content="Alright?", role=PromptRole.USER.value),
@@ -449,9 +432,9 @@ def test_get_prompt_with_no_messages():
with patch('transformers.PreTrainedTokenizer') as tok:
mock_tokenizer = tok.return_value
mock_tokenizer.chat_template = None
- mock_tokenizer.default_chat_template = None
messages = []
prompt = get_prompt_from_messages(mock_tokenizer, messages)
- assert prompt == "\n<|assistant|>\n"
+ expected_prompt = "\n<|assistant|>\n"
+ assert prompt == expected_prompt
\ No newline at end of file
From a4ff387eaf4bb3dac9152852c47834d7f3dbe45b Mon Sep 17 00:00:00 2001
From: Xi Bai <82581439+baixiac@users.noreply.github.com>
Date: Wed, 17 Sep 2025 15:52:05 +0100
Subject: [PATCH 8/9] Revert "feat: add metrics for usages of prompt and
completion tokens"
This reverts commit b338b4f827d107aaab27ac9bed060664171f8dbc.
---
app/api/routers/generative.py | 67 +++++--------------
app/domain.py | 10 +--
app/management/prometheus_metrics.py | 21 ------
app/model_services/huggingface_llm_model.py | 25 +------
app/utils.py | 3 +-
.../test_huggingface_llm_model.py | 12 ----
6 files changed, 24 insertions(+), 114 deletions(-)
diff --git a/app/api/routers/generative.py b/app/api/routers/generative.py
index 763c717..5f9ff9f 100644
--- a/app/api/routers/generative.py
+++ b/app/api/routers/generative.py
@@ -6,7 +6,6 @@
from typing import Union, Iterable, AsyncGenerator
from typing_extensions import Annotated
-from functools import partial
from fastapi import APIRouter, Depends, Request, Body, Query
from fastapi.encoders import jsonable_encoder
from fastapi.responses import PlainTextResponse, StreamingResponse, JSONResponse
@@ -16,7 +15,6 @@
from app.utils import get_settings, get_prompt_from_messages
from app.api.utils import get_rate_limiter
from app.api.dependencies import validate_tracking_id
-from app.management.prometheus_metrics import cms_prompt_tokens, cms_completion_tokens, cms_total_tokens
PATH_GENERATE = "/generate"
PATH_GENERATE_ASYNC = "/stream/generate"
@@ -46,7 +44,7 @@ def generate_text(
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
) -> PlainTextResponse:
"""
- Generates text based on the prompt provided.
+ Generate text based on the prompt provided.
Args:
request (Request): The request object.
@@ -63,12 +61,7 @@ def generate_text(
tracking_id = tracking_id or str(uuid.uuid4())
if prompt:
return PlainTextResponse(
- model_service.generate(
- prompt,
- max_tokens=max_tokens,
- temperature=temperature,
- report_tokens=partial(_send_usage_metrics, handler=PATH_GENERATE),
- ),
+ model_service.generate(prompt, max_tokens=max_tokens, temperature=temperature),
headers={"x-cms-tracking-id": tracking_id},
status_code=HTTP_200_OK,
)
@@ -96,7 +89,7 @@ async def generate_text_stream(
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
) -> StreamingResponse:
"""
- Generates a stream of texts in near real-time.
+ Generate a stream of texts in near real-time.
Args:
request (Request): The request object.
@@ -113,12 +106,7 @@ async def generate_text_stream(
tracking_id = tracking_id or str(uuid.uuid4())
if prompt:
return StreamingResponse(
- model_service.generate_async(
- prompt,
- max_tokens=max_tokens,
- temperature=temperature,
- report_tokens=partial(_send_usage_metrics, handler=PATH_GENERATE_ASYNC),
- ),
+ model_service.generate_async(prompt, max_tokens=max_tokens, temperature=temperature),
media_type="text/event-stream",
headers={"x-cms-tracking-id": tracking_id},
status_code=HTTP_200_OK,
@@ -133,7 +121,7 @@ async def generate_text_stream(
@router.post(
- PATH_OPENAI_COMPLETIONS,
+ "/v1/chat/completions",
tags=[Tags.Generative.name],
response_model=None,
dependencies=[Depends(cms_globals.props.current_active_user)],
@@ -148,7 +136,7 @@ def generate_chat_completions(
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
) -> Union[StreamingResponse, JSONResponse]:
"""
- Generates chat response based on messages, mimicking OpenAI's /v1/chat/completions endpoint.
+ Generate chat response based on messages, mimicking OpenAI's /v1/chat/completions endpoint.
Args:
request (Request): The request object.
@@ -165,7 +153,6 @@ def generate_chat_completions(
stream = request_data.stream
max_tokens = request_data.max_tokens
temperature = request_data.temperature
- tracking_id = tracking_id or str(uuid.uuid4())
if not messages:
error_response = {
@@ -176,25 +163,16 @@ def generate_chat_completions(
"code": "missing_field",
}
}
- return JSONResponse(
- content=error_response,
- status_code=HTTP_400_BAD_REQUEST,
- headers={"x-cms-tracking-id": tracking_id},
- )
+ return JSONResponse(content=error_response, status_code=HTTP_400_BAD_REQUEST)
- async def _stream(prompt: str, max_tokens: int, temperature: float) -> AsyncGenerator:
+ async def _stream(p: str, mt: int, t: float) -> AsyncGenerator:
data = {
- "id": tracking_id,
+ "id": tracking_id or str(uuid.uuid4()),
"object": "chat.completion.chunk",
"choices": [{"delta": {"role": PromptRole.ASSISTANT.value}}],
}
yield f"data: {json.dumps(data)}\n\n"
- async for chunk in model_service.generate_async(
- prompt,
- max_tokens=max_tokens,
- temperature=temperature,
- report_tokens=partial(_send_usage_metrics, handler=PATH_OPENAI_COMPLETIONS)
- ):
+ async for chunk in model_service.generate_async(p, max_tokens=mt, temperature=t):
data = {
"choices": [
{
@@ -210,18 +188,12 @@ async def _stream(prompt: str, max_tokens: int, temperature: float) -> AsyncGene
if stream:
return StreamingResponse(
_stream(prompt, max_tokens, temperature),
- media_type="text/event-stream",
- headers={"x-cms-tracking-id": tracking_id},
+ media_type="text/event-stream"
)
else:
- generated_text = model_service.generate(
- prompt,
- max_tokens=max_tokens,
- temperature=temperature,
- send_metrics=partial(_send_usage_metrics, handler=PATH_OPENAI_COMPLETIONS),
- )
+ generated_text = model_service.generate(prompt, max_tokens=max_tokens, temperature=temperature)
completion = OpenAIChatResponse(
- id=tracking_id,
+ id=str(uuid.uuid4()),
object="chat.completion",
created=int(time.time()),
model=model_service.model_name,
@@ -234,19 +206,10 @@ async def _stream(prompt: str, max_tokens: int, temperature: float) -> AsyncGene
),
"finish_reason": "stop",
}
- ],
+ ]
)
- return JSONResponse(content=jsonable_encoder(completion), headers={"x-cms-tracking-id": tracking_id})
+ return JSONResponse(content=jsonable_encoder(completion))
def _empty_prompt_error() -> Iterable[str]:
yield "ERROR: No prompt text provided\n"
-
-
-def _send_usage_metrics(handler: str, prompt_token_num: int, completion_token_num: int) -> None:
- cms_prompt_tokens.labels(handler=handler).observe(prompt_token_num)
- logger.debug(f"Sent prompt tokens usage: {prompt_token_num}")
- cms_completion_tokens.labels(handler=handler).observe(completion_token_num)
- logger.debug(f"Sent completion tokens usage: {completion_token_num}")
- cms_total_tokens.labels(handler=handler).observe(prompt_token_num + completion_token_num)
- logger.debug(f"Sent total tokens usage: {prompt_token_num + completion_token_num}")
diff --git a/app/domain.py b/app/domain.py
index 2ddd6e3..29a779a 100644
--- a/app/domain.py
+++ b/app/domain.py
@@ -189,8 +189,8 @@ class OpenAIChatRequest(BaseModel):
class OpenAIChatResponse(BaseModel):
- id: str = Field(..., description="The unique identifier for the chat completion request")
- object: str = Field(..., description="The type of the response")
- created: int = Field(..., description="The timestamp when the completion was generated")
- model: str = Field(..., description="The name of the model used for generating the completion")
- choices: List = Field(..., description="The generated messages and their metadata")
+ id: str
+ object: str
+ created: int
+ model: str
+ choices: List
diff --git a/app/management/prometheus_metrics.py b/app/management/prometheus_metrics.py
index 78c5698..3f48858 100644
--- a/app/management/prometheus_metrics.py
+++ b/app/management/prometheus_metrics.py
@@ -34,24 +34,3 @@
"Number of bulk-processed documents",
["handler"],
)
-
-# The histogram metric to track the number of tokens in the messages of the input prompt
-cms_prompt_tokens = Histogram(
- "cms_prompt_tokens",
- "Number of tokens in the messages of the input prompt",
- ["handler"],
-)
-
-# The histogram metric to track the number of tokens in the generated assistant reply
-cms_completion_tokens = Histogram(
- "cms_completion_tokens",
- "Number of tokens in the generated assistant reply",
- ["handler"],
-)
-
-# The histogram metric to track the total number of tokens used in the prompt and the completion
-cms_total_tokens = Histogram(
- "cms_total_tokens",
- "Number of tokens used in the prompt and the completion",
- ["handler"],
-)
diff --git a/app/model_services/huggingface_llm_model.py b/app/model_services/huggingface_llm_model.py
index 7340f67..848b02c 100644
--- a/app/model_services/huggingface_llm_model.py
+++ b/app/model_services/huggingface_llm_model.py
@@ -2,7 +2,7 @@
import logging
import asyncio
from concurrent.futures import ThreadPoolExecutor
-from typing import Dict, List, Optional, Tuple, Any, AsyncIterable, Callable
+from typing import Dict, List, Optional, Tuple, Any, AsyncIterable
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
@@ -198,7 +198,6 @@ def generate(
prompt: str,
max_tokens: int = 512,
temperature: float = 0.7,
- report_tokens: Optional[Callable[[str], None]] = None,
**kwargs: Any
) -> str:
"""
@@ -208,7 +207,6 @@ def generate(
prompt (str): The prompt for the text generation
max_tokens (int): The maximum number of tokens to generate. Defaults to 512.
temperature (float): The temperature for the text generation. Defaults to 0.7.
- report_tokens (Optional[Callable[[str], None]]): The callback function to send metrics. Defaults to None.
**kwargs (Any): Additional keyword arguments to be passed to this method.
Returns:
@@ -232,13 +230,9 @@ def generate(
outputs = self.model.generate(**generation_kwargs)
generated_text = self.tokenizer.decode(outputs[0], skip_prompt=True, skip_special_tokens=True)
- logger.debug("Response generation completed")
- if report_tokens:
- report_tokens(
- prompt_token_num=inputs.input_ids.shape[-1], # type: ignore
- completion_token_num=outputs[0].shape[-1], # type: ignore
- )
+
+ logger.debug("Response generation completed")
return generated_text
@@ -247,7 +241,6 @@ async def generate_async(
prompt: str,
max_tokens: int = 512,
temperature: float = 0.7,
- report_tokens: Optional[Callable[[str], None]] = None,
**kwargs: Any
) -> AsyncIterable:
"""
@@ -257,7 +250,6 @@ async def generate_async(
prompt (str): The prompt for the text generation.
max_tokens (int): The maximum number of tokens to generate. Defaults to 512.
temperature (float): The temperature for the text generation. Defaults to 0.7.
- report_tokens (Optional[Callable[[str], None]]): The callback function to send metrics. Defaults to None.
**kwargs (Any): Additional keyword arguments to be passed to the model loader.
Returns:
@@ -287,20 +279,9 @@ async def generate_async(
try:
_ = self._text_generator.submit(self.model.generate, **generation_kwargs)
- output = ""
for content in streamer:
yield content
- output += content
await asyncio.sleep(0.01)
- if report_tokens:
- report_tokens(
- prompt_token_num=inputs.input_ids.shape[-1], # type: ignore
- completion_token_num=self.tokenizer( # type: ignore
- output,
- add_special_tokens=False,
- return_tensors="pt"
- ).input_ids.shape[-1],
- )
except Exception as e:
logger.error("An error occurred while generating the response")
logger.exception(e)
diff --git a/app/utils.py b/app/utils.py
index 23968b2..705539b 100644
--- a/app/utils.py
+++ b/app/utils.py
@@ -694,7 +694,7 @@ def dump_pydantic_object_to_dict(model: BaseModel) -> Dict:
"""
if hasattr(model, "model_dump"):
- return model.model_dump(mode="json") # type: ignore
+ return model.model_dump() # type: ignore
elif hasattr(model, "dict"):
return model.dict() # type: ignore
else:
@@ -835,4 +835,3 @@ def get_prompt_from_messages(tokenizer: PreTrainedTokenizer, messages: List[Prom
"25624495": '© 2002-2020 International Health Terminology Standards Development Organisation (IHTSDO). All rights reserved. SNOMED CT®, was originally created by The College of American Pathologists. "SNOMED" and "SNOMED CT" are registered trademarks of the IHTSDO.',
"55540447": "linkage concept"
}
-
diff --git a/tests/app/model_services/test_huggingface_llm_model.py b/tests/app/model_services/test_huggingface_llm_model.py
index 7cd4941..0262c41 100644
--- a/tests/app/model_services/test_huggingface_llm_model.py
+++ b/tests/app/model_services/test_huggingface_llm_model.py
@@ -46,7 +46,6 @@ def test_generate(huggingface_llm_model):
huggingface_llm_model.init_model()
huggingface_llm_model.model = MagicMock()
huggingface_llm_model.tokenizer = MagicMock()
- mock_send_metrics = MagicMock()
inputs = MagicMock()
inputs.input_ids = MagicMock(shape=[1, 2])
inputs.attention_mask = MagicMock()
@@ -59,7 +58,6 @@ def test_generate(huggingface_llm_model):
prompt="Alright?",
max_tokens=128,
temperature=0.5,
- report_tokens=mock_send_metrics
)
huggingface_llm_model.tokenizer.assert_called_once_with(
@@ -80,10 +78,6 @@ def test_generate(huggingface_llm_model):
skip_prompt=True,
skip_special_tokens=True,
)
- mock_send_metrics.assert_called_once_with(
- prompt_token_num=2,
- completion_token_num=2,
- )
assert result == "Yeah."
@@ -91,7 +85,6 @@ async def test_generate_async(huggingface_llm_model):
huggingface_llm_model.init_model()
huggingface_llm_model.model = MagicMock()
huggingface_llm_model.tokenizer = MagicMock()
- mock_send_metrics = MagicMock()
inputs = MagicMock()
inputs.input_ids = MagicMock(shape=[1, 2])
inputs.attention_mask = MagicMock()
@@ -104,7 +97,6 @@ async def test_generate_async(huggingface_llm_model):
prompt="Alright?",
max_tokens=128,
temperature=0.5,
- report_tokens=mock_send_metrics
)
huggingface_llm_model.tokenizer.assert_called_once_with(
@@ -125,8 +117,4 @@ async def test_generate_async(huggingface_llm_model):
skip_prompt=True,
skip_special_tokens=True,
)
- mock_send_metrics.assert_called_once_with(
- prompt_token_num=2,
- completion_token_num=2,
- )
assert result == "Yeah."
\ No newline at end of file
From 3b043eea03316f6f951a00d15cecb54b18ac7b32 Mon Sep 17 00:00:00 2001
From: Xi Bai <82581439+baixiac@users.noreply.github.com>
Date: Wed, 17 Sep 2025 15:52:05 +0100
Subject: [PATCH 9/9] Revert "feat: add the endpoint compatible with OpenAI
client protocols"
This reverts commit bcea8fe447795ea0fd4cd8d61904b877e518d42e.
---
app/api/routers/generative.py | 150 +-----------------
app/api/utils.py | 39 ++---
app/domain.py | 27 ----
app/model_services/huggingface_llm_model.py | 26 +--
app/utils.py | 57 +------
pyproject.toml | 4 +-
tests/app/api/test_serving_hf_llm.py | 64 +-------
.../test_huggingface_llm_model.py | 77 +--------
tests/app/test_utils.py | 49 +-----
uv.lock | 133 +++-------------
10 files changed, 60 insertions(+), 566 deletions(-)
diff --git a/app/api/routers/generative.py b/app/api/routers/generative.py
index 5f9ff9f..007eb0b 100644
--- a/app/api/routers/generative.py
+++ b/app/api/routers/generative.py
@@ -1,24 +1,16 @@
-import json
import logging
-import time
-import uuid
import app.api.globals as cms_globals
-from typing import Union, Iterable, AsyncGenerator
from typing_extensions import Annotated
from fastapi import APIRouter, Depends, Request, Body, Query
-from fastapi.encoders import jsonable_encoder
-from fastapi.responses import PlainTextResponse, StreamingResponse, JSONResponse
-from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST
-from app.domain import Tags, OpenAIChatRequest, OpenAIChatResponse, PromptMessage, PromptRole
+from fastapi.responses import PlainTextResponse, StreamingResponse
+from app.domain import Tags
from app.model_services.base import AbstractModelService
-from app.utils import get_settings, get_prompt_from_messages
+from app.utils import get_settings
from app.api.utils import get_rate_limiter
-from app.api.dependencies import validate_tracking_id
PATH_GENERATE = "/generate"
PATH_GENERATE_ASYNC = "/stream/generate"
-PATH_OPENAI_COMPLETIONS = "/v1/chat/completions"
router = APIRouter()
config = get_settings()
@@ -39,8 +31,6 @@ def generate_text(
request: Request,
prompt: Annotated[str, Body(description="The prompt to be sent to the model", media_type="text/plain")],
max_tokens: Annotated[int, Query(description="The maximum number of tokens to generate", gt=0)] = 512,
- temperature: Annotated[float, Query(description="The temperature of the generated text", gt=0.0, lt=1.0)] = 0.7,
- tracking_id: Union[str, None] = Depends(validate_tracking_id),
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
) -> PlainTextResponse:
"""
@@ -50,27 +40,13 @@ def generate_text(
request (Request): The request object.
prompt (str): The prompt to be sent to the model.
max_tokens (int): The maximum number of tokens to generate.
- temperature (float): The temperature of the generated text.
- tracking_id (Union[str, None]): An optional tracking ID of the requested task.
model_service (AbstractModelService): The model service dependency.
Returns:
PlainTextResponse: A response containing the generated text.
"""
- tracking_id = tracking_id or str(uuid.uuid4())
- if prompt:
- return PlainTextResponse(
- model_service.generate(prompt, max_tokens=max_tokens, temperature=temperature),
- headers={"x-cms-tracking-id": tracking_id},
- status_code=HTTP_200_OK,
- )
- else:
- return PlainTextResponse(
- _empty_prompt_error(),
- headers={"x-cms-tracking-id": tracking_id},
- status_code=HTTP_400_BAD_REQUEST,
- )
+ return PlainTextResponse(model_service.generate(prompt, max_tokens=max_tokens))
@router.post(
@@ -84,8 +60,6 @@ async def generate_text_stream(
request: Request,
prompt: Annotated[str, Body(description="The prompt to be sent to the model", media_type="text/plain")],
max_tokens: Annotated[int, Query(description="The maximum number of tokens to generate", gt=0)] = 512,
- temperature: Annotated[float, Query(description="The temperature of the generated text", gt=0.0, lt=1.0)] = 0.7,
- tracking_id: Union[str, None] = Depends(validate_tracking_id),
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
) -> StreamingResponse:
"""
@@ -95,121 +69,13 @@ async def generate_text_stream(
request (Request): The request object.
prompt (str): The prompt to be sent to the model.
max_tokens (int): The maximum number of tokens to generate.
- temperature (float): The temperature of the generated text.
- tracking_id (Union[str, None]): An optional tracking ID of the requested task.
model_service (AbstractModelService): The model service dependency.
Returns:
StreamingResponse: A streaming response containing the text generated in near real-time.
"""
- tracking_id = tracking_id or str(uuid.uuid4())
- if prompt:
- return StreamingResponse(
- model_service.generate_async(prompt, max_tokens=max_tokens, temperature=temperature),
- media_type="text/event-stream",
- headers={"x-cms-tracking-id": tracking_id},
- status_code=HTTP_200_OK,
- )
- else:
- return StreamingResponse(
- _empty_prompt_error(),
- media_type="text/event-stream",
- headers={"x-cms-tracking-id": tracking_id},
- status_code=HTTP_400_BAD_REQUEST,
- )
-
-
-@router.post(
- "/v1/chat/completions",
- tags=[Tags.Generative.name],
- response_model=None,
- dependencies=[Depends(cms_globals.props.current_active_user)],
- description="Generate chat response based on messages, similar to OpenAI's /v1/chat/completions",
-)
-def generate_chat_completions(
- request: Request,
- request_data: Annotated[OpenAIChatRequest, Body(
- description="OpenAI-like completion request", media_type="application/json"
- )],
- tracking_id: Union[str, None] = Depends(validate_tracking_id),
- model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
-) -> Union[StreamingResponse, JSONResponse]:
- """
- Generate chat response based on messages, mimicking OpenAI's /v1/chat/completions endpoint.
-
- Args:
- request (Request): The request object.
- request_data (OpenAIChatRequest): The request data containing model, messages, and stream.
- tracking_id (Union[str, None]): An optional tracking ID of the requested task.
- model_service (AbstractModelService): The model service dependency.
-
- Returns:
- StreamingResponse: A OpenAI-like response containing the text generated in near real-time.
- JSONResponse: A response containing an error message if the prompt messages are empty.
- """
-
- messages = request_data.messages
- stream = request_data.stream
- max_tokens = request_data.max_tokens
- temperature = request_data.temperature
-
- if not messages:
- error_response = {
- "error": {
- "message": "No prompt messages provided",
- "type": "invalid_request_error",
- "param": "messages",
- "code": "missing_field",
- }
- }
- return JSONResponse(content=error_response, status_code=HTTP_400_BAD_REQUEST)
-
- async def _stream(p: str, mt: int, t: float) -> AsyncGenerator:
- data = {
- "id": tracking_id or str(uuid.uuid4()),
- "object": "chat.completion.chunk",
- "choices": [{"delta": {"role": PromptRole.ASSISTANT.value}}],
- }
- yield f"data: {json.dumps(data)}\n\n"
- async for chunk in model_service.generate_async(p, max_tokens=mt, temperature=t):
- data = {
- "choices": [
- {
- "delta": {"content": chunk}
- }
- ],
- "object": "chat.completion.chunk",
- }
- yield f"data: {json.dumps(data)}\n\n"
- yield "data: [DONE]\n\n"
-
- prompt = get_prompt_from_messages(model_service.tokenizer, messages) # type: ignore
- if stream:
- return StreamingResponse(
- _stream(prompt, max_tokens, temperature),
- media_type="text/event-stream"
- )
- else:
- generated_text = model_service.generate(prompt, max_tokens=max_tokens, temperature=temperature)
- completion = OpenAIChatResponse(
- id=str(uuid.uuid4()),
- object="chat.completion",
- created=int(time.time()),
- model=model_service.model_name,
- choices=[
- {
- "index": 0,
- "message": PromptMessage(
- role=PromptRole.ASSISTANT,
- content=generated_text,
- ),
- "finish_reason": "stop",
- }
- ]
- )
- return JSONResponse(content=jsonable_encoder(completion))
-
-
-def _empty_prompt_error() -> Iterable[str]:
- yield "ERROR: No prompt text provided\n"
+ return StreamingResponse(
+ model_service.generate_async(prompt, max_tokens=max_tokens),
+ media_type="text/event-stream"
+ )
diff --git a/app/api/utils.py b/app/api/utils.py
index ba99eed..a14714c 100644
--- a/app/api/utils.py
+++ b/app/api/utils.py
@@ -286,7 +286,6 @@ async def init_vllm_engine(app: FastAPI,
"""
try:
- # Import necessary vLLM components
from vllm.utils import FlexibleArgumentParser
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
@@ -299,19 +298,16 @@ async def init_vllm_engine(app: FastAPI,
)
from vllm import SamplingParams, TokensPrompt
except ImportError:
- # Raise a custom exception if vLLM is not installed
- raise ConfigurationException("Cannot import the vLLM engine. Please install it with `pip install vllm`.")
+ logger.error("Cannot import the vLLM engine. Please install it with `pip install cms[vllm]`.")
parser = FlexibleArgumentParser()
parser = make_arg_parser(parser)
args = parser.parse_args([])
validate_parsed_serve_args(args)
-
args.model = model_dir_path
args.dtype = "float16"
args.served_model_name = [model_name]
- args.max_model_len = 2048 # The default batched length (2048) needs to be higher than max_model_len.
- # args.tokenizer = model_dir_path # Uncomment if your tokenizer is in a different path or needs explicit setting.
+ # args.tokenizer = model_dir_path
args.log_level = log_level
exit_stack = contextlib.AsyncExitStack()
@@ -321,11 +317,9 @@ async def init_vllm_engine(app: FastAPI,
disable_frontend_multiprocessing=True,
)
)
-
tokenizer = await engine.get_tokenizer()
vllm_config = await engine.get_vllm_config()
model_config = await engine.get_model_config()
-
await init_app_state(engine, vllm_config, app.state, args)
async def generate_text(
@@ -333,32 +327,27 @@ async def generate_text(
prompt: Annotated[str, Body(description="The prompt to be sent to the model", media_type="text/plain")],
max_tokens: Annotated[int, Query(description="The maximum number of tokens to generate", gt=0)] = 512
) -> StreamingResponse:
- """
- Custom endpoint for streaming text generation.
- This endpoint takes a raw text prompt and streams back the generated text.
- It applies a chat template to the prompt internally for model compatibility.
- """
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
params = SamplingParams(max_tokens=max_tokens)
-
conversation, _ = parse_chat_messages(messages, model_config, tokenizer, content_format="string") # type: ignore
- prompt_tokens = apply_hf_chat_template( # type: ignore
- tokenizer,
- conversation=conversation,
- tools=None,
- add_generation_prompt=True,
- continue_final_message=False,
- chat_template="{% for message in messages %}\n{% if message['role'] == 'user' %}\nUser: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nAssistant: {{ message['content'] }}\n{% endif %}\n{% endfor %}\nAssistant:",
- tokenize=True,
+ prompt = TokensPrompt(
+ prompt_token_ids=apply_hf_chat_template( # type: ignore
+ tokenizer,
+ conversation=conversation,
+ tools=None,
+ add_generation_prompt=True,
+ continue_final_message=False,
+ chat_template="{% for message in messages %}\n{% if message['role'] == 'user' %}\nUser: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nAssistant: {{ message['content'] }}\n{% endif %}\n{% endfor %}\nAssistant:",
+ tokenize=True,
+ )
)
- prompt_obj = TokensPrompt(prompt_token_ids=prompt_tokens)
async def _stream() -> AsyncGenerator[bytes, None]:
start = 0
- async for output in engine.generate(request_id=uuid.uuid4().hex, prompt=prompt_obj, sampling_params=params):
+ async for output in engine.generate(request_id=uuid.uuid4().hex, prompt=prompt, sampling_params=params):
text = output.outputs[0].text
- yield text[start:].encode("utf-8")
+ yield text[start:] # type: ignore
start = len(text)
return StreamingResponse(_stream(), media_type="text/event-stream")
diff --git a/app/domain.py b/app/domain.py
index 29a779a..c9d38cf 100644
--- a/app/domain.py
+++ b/app/domain.py
@@ -167,30 +167,3 @@ class Doc(BaseModel):
text: str = Field(description="The text from which the entities are extracted")
ents: List[Entity] = Field(description="The list of extracted entities")
title: Optional[str] = Field(default=None, description="The headline of the text")
-
-
-class PromptRole(Enum):
- SYSTEM = "system"
- USER = "user"
- ASSISTANT = "assistant"
- TOOL = "tool"
-
-
-class PromptMessage(BaseModel):
- role: PromptRole = Field(description="The role who generates the message")
- content: str = Field(description="The actual text of the message")
-
-
-class OpenAIChatRequest(BaseModel):
- messages: List[PromptMessage] = Field(..., description="A list of messages to be sent to the model")
- stream: bool = Field(..., description="Whether to stream the response")
- max_tokens: int = Field(512, description="The maximum number of tokens to generate", gt=0)
- temperature: float = Field(0.7, description="The temperature of the generated text", ge=0.0, le=1.0)
-
-
-class OpenAIChatResponse(BaseModel):
- id: str
- object: str
- created: int
- model: str
- choices: List
diff --git a/app/model_services/huggingface_llm_model.py b/app/model_services/huggingface_llm_model.py
index 848b02c..566bbf9 100644
--- a/app/model_services/huggingface_llm_model.py
+++ b/app/model_services/huggingface_llm_model.py
@@ -168,8 +168,6 @@ def init_model(self) -> None:
logger.warning("Model service is already initialised and can be initialised only once")
else:
self._model, self._tokenizer = self.load_model(self._model_pack_path)
- if non_default_device_is_available(get_settings().DEVICE):
- self._model.to(get_settings().DEVICE)
if self._enable_trainer:
logger.error("Trainers are not yet implemented for HuggingFace Generative models")
@@ -193,20 +191,13 @@ def annotate(self, text: str) -> List[Annotation]:
def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]:
raise NotImplementedError("Batch annotation is not yet implemented for HuggingFace Generative models")
- def generate(
- self,
- prompt: str,
- max_tokens: int = 512,
- temperature: float = 0.7,
- **kwargs: Any
- ) -> str:
+ def generate(self, prompt: str, max_tokens: int = 512, **kwargs: Any) -> str:
"""
Generates text based on the prompt.
Args:
prompt (str): The prompt for the text generation
max_tokens (int): The maximum number of tokens to generate. Defaults to 512.
- temperature (float): The temperature for the text generation. Defaults to 0.7.
**kwargs (Any): Additional keyword arguments to be passed to this method.
Returns:
@@ -223,8 +214,8 @@ def generate(
inputs=inputs.input_ids,
attention_mask=inputs.attention_mask,
max_new_tokens=max_tokens,
- do_sample=False,
- temperature=temperature,
+ do_sample=True,
+ temperature=0.7,
top_p=0.9,
)
@@ -236,20 +227,13 @@ def generate(
return generated_text
- async def generate_async(
- self,
- prompt: str,
- max_tokens: int = 512,
- temperature: float = 0.7,
- **kwargs: Any
- ) -> AsyncIterable:
+ async def generate_async(self, prompt: str, max_tokens: int = 512, **kwargs: Any) -> AsyncIterable:
"""
Asynchronously generates text stream based on the prompt.
Args:
prompt (str): The prompt for the text generation.
max_tokens (int): The maximum number of tokens to generate. Defaults to 512.
- temperature (float): The temperature for the text generation. Defaults to 0.7.
**kwargs (Any): Additional keyword arguments to be passed to the model loader.
Returns:
@@ -273,7 +257,7 @@ async def generate_async(
streamer=streamer,
max_new_tokens=max_tokens,
do_sample=True,
- temperature=temperature,
+ temperature=0.7,
top_p=0.9,
)
diff --git a/app/utils.py b/app/utils.py
index 705539b..245da97 100644
--- a/app/utils.py
+++ b/app/utils.py
@@ -20,12 +20,12 @@
from spacy.lang.en import English
from spacy.util import filter_spans
from safetensors.torch import load_file
-from transformers import PreTrainedModel, PreTrainedTokenizer
+from transformers import PreTrainedModel
from urllib.parse import ParseResult
from functools import lru_cache
from typing import List, Optional, Dict, Callable, Any, Union, Type, TypeVar
from app.config import Settings
-from app.domain import Annotation, Entity, CodeType, ModelType, Device, PromptMessage, PromptRole
+from app.domain import Annotation, Entity, CodeType, ModelType, Device
from app.exception import ManagedModelException
@@ -682,24 +682,6 @@ def load_pydantic_object_from_dict(model: Type[T], obj: Dict) -> T:
raise TypeError("Model must have a known method for parsing objects.")
-def dump_pydantic_object_to_dict(model: BaseModel) -> Dict:
- """
- Dumps the pydantic model object to a dictionary.
-
- Args:
- model (BaseModel): The pydantic model to dump.
-
- Returns:
- Dict: The dictionary object.
- """
-
- if hasattr(model, "model_dump"):
- return model.model_dump() # type: ignore
- elif hasattr(model, "dict"):
- return model.dict() # type: ignore
- else:
- raise TypeError("Model must have a known method for dumping objects.")
-
def download_model_package(
model_package_url: str,
destination_path: str,
@@ -739,41 +721,6 @@ def download_model_package(
retry_delay *= 2
-def get_prompt_from_messages(tokenizer: PreTrainedTokenizer, messages: List[PromptMessage]) -> str:
- """
- Generates a prompt from a list of prompt messages.
-
- Args:
- tokenizer (PreTrainedTokenizer): The tokenizer to use for applying the chat template.
- messages (List[PromptMessage]): The list of prompt messages to use for generating the prompt.
-
- Returns:
- str: The generated prompt.
- """
- if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
- prompt = tokenizer.apply_chat_template(
- [dump_pydantic_object_to_dict(message) for message in messages],
- tokenize=False,
- add_generation_prompt=True,
- )
- else:
- system_content = ""
- prompt_parts: List[str] = []
- for message in messages:
- content = message.content.strip()
- if message.role == PromptRole.SYSTEM:
- system_content = content
- elif message.role == PromptRole.USER:
- prompt_parts.append(f"<|user|>\n{content}")
- elif message.role == PromptRole.ASSISTANT:
- prompt_parts.append(f"<|assistant|>\n{content}")
- if system_content:
- prompt = f"<|system|>\n{system_content}\n" + "\n".join(prompt_parts)
- else:
- prompt = "\n".join(prompt_parts)
- prompt += "\n<|assistant|>\n"
- return prompt
-
TYPE_ID_TO_NAME_PATCH = {
"32816260": "physical object",
"2680757": "observable entity",
diff --git a/pyproject.toml b/pyproject.toml
index eaba606..01b8d3d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -37,7 +37,7 @@ dependencies = [
"pynvml~=11.5.3",
"toml~=0.10.2",
"peft<0.14.0",
- "huggingface-hub~=0.33.0",
+ "huggingface-hub~=0.32.0",
]
readme = "README.md"
keywords = ["natural-language-processing", "electronic-health-records", "clinical-data"]
@@ -71,7 +71,6 @@ dev = [
"locust<2.32.0",
"typer-cli~=0.15.1",
"types-toml==0.10.8.20240310",
- "openai>=1.84.0",
]
docs = [
"sphinx~=7.1.2",
@@ -100,7 +99,6 @@ dev = [
"locust<2.32.0",
"typer-cli~=0.15.1",
"types-toml==0.10.8.20240310",
- "openai>=1.84.0",
]
docs = [
"sphinx~=7.1.2",
diff --git a/tests/app/api/test_serving_hf_llm.py b/tests/app/api/test_serving_hf_llm.py
index 4bd4e9d..39e82bf 100644
--- a/tests/app/api/test_serving_hf_llm.py
+++ b/tests/app/api/test_serving_hf_llm.py
@@ -1,10 +1,8 @@
import httpx
-import json
import pytest
import app.api.globals as cms_globals
from unittest.mock import create_autospec
-from fastapi.testclient import TestClient
from app.api.api import get_generative_server
from app.model_services.huggingface_llm_model import HuggingFaceLlmModel
from app.utils import get_settings
@@ -29,72 +27,14 @@ def llm_app(llm_model_service):
yield app
app.dependency_overrides.clear()
-@pytest.fixture(scope="function")
-def client(llm_model_service):
- llm_model_service.generate.return_value = "Yeah."
- app = get_generative_server(config, msd_overwritten=lambda: llm_model_service)
- app.dependency_overrides[cms_globals.props.current_active_user] = lambda: None
- client = TestClient(app)
- yield client
- client.app.dependency_overrides.clear()
-
-
-def test_generate(client):
- response = client.post(
- "/generate?max_tokens=128&temperature=0.7",
- data="Alright?",
- headers={"Content-Type": "text/plain"},
- )
-
- assert response.status_code == 200
- assert response.headers["x-cms-tracking-id"], "x-cms-tracking-id header is missing"
- assert response.headers["content-type"] == "text/plain; charset=utf-8"
- assert response.text == "Yeah."
-
@pytest.mark.asyncio
async def test_stream_generate(llm_model_service, llm_app):
- llm_model_service.generate_async.return_value = "Fine."
async with httpx.AsyncClient(app=llm_app, base_url="http://test") as ac:
response = await ac.post(
- "/stream/generate?max_tokens=32&temperature=0.7",
+ "/stream/generate?max_tokens=32",
data="How are you doing?",
headers={"Content-Type": "text/plain"},
)
- assert response.status_code == 200
- assert response.headers["x-cms-tracking-id"], "x-cms-tracking-id header is missing"
- assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
- assert response.text == "Fine."
-
-
-@pytest.mark.asyncio
-async def test_generate_chat_completions(llm_model_service, llm_app):
- llm_model_service.generate.return_value = "I'm a chat bot."
- request_data = {
- "messages": [
- {
- "role": "system",
- "content": "You are a chat bot."
- },
- {
- "role": "user",
- "content": "Who are you?"
- }
- ],
- "stream": True,
- "max_tokens": 128,
- "temperature": 0.7
- }
- async with httpx.AsyncClient(app=llm_app, base_url="http://test") as ac:
- response = await ac.post(
- "/v1/chat/completions?max_tokens=128&temperature=0.7",
- data=json.dumps(request_data),
- headers={"Content-Type": "application/json"},
- )
-
- assert response.status_code == 200
- assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
- assert response.text.startswith("data:")
- assert "id" in response.text
- assert "chat.completion.chunk" in response.text
+ assert response.status_code == 200
\ No newline at end of file
diff --git a/tests/app/model_services/test_huggingface_llm_model.py b/tests/app/model_services/test_huggingface_llm_model.py
index 0262c41..1f33572 100644
--- a/tests/app/model_services/test_huggingface_llm_model.py
+++ b/tests/app/model_services/test_huggingface_llm_model.py
@@ -1,5 +1,4 @@
import os
-from unittest.mock import MagicMock
from tests.app.conftest import MODEL_PARENT_DIR
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from app import __version__
@@ -44,77 +43,5 @@ def test_info(huggingface_llm_model):
def test_generate(huggingface_llm_model):
huggingface_llm_model.init_model()
- huggingface_llm_model.model = MagicMock()
- huggingface_llm_model.tokenizer = MagicMock()
- inputs = MagicMock()
- inputs.input_ids = MagicMock(shape=[1, 2])
- inputs.attention_mask = MagicMock()
- huggingface_llm_model.tokenizer.return_value = inputs
- outputs = [MagicMock(shape=[2])]
- huggingface_llm_model.model.generate.return_value = outputs
- huggingface_llm_model.tokenizer.decode.return_value = "Yeah."
-
- result = huggingface_llm_model.generate(
- prompt="Alright?",
- max_tokens=128,
- temperature=0.5,
- )
-
- huggingface_llm_model.tokenizer.assert_called_once_with(
- "Alright?",
- add_special_tokens=False,
- return_tensors="pt",
- )
- huggingface_llm_model.model.generate.assert_called_once_with(
- inputs=inputs.input_ids,
- attention_mask=inputs.attention_mask,
- max_new_tokens=128,
- do_sample=False,
- temperature=0.5,
- top_p=0.9,
- )
- huggingface_llm_model.tokenizer.decode.assert_called_once_with(
- outputs[0],
- skip_prompt=True,
- skip_special_tokens=True,
- )
- assert result == "Yeah."
-
-
-async def test_generate_async(huggingface_llm_model):
- huggingface_llm_model.init_model()
- huggingface_llm_model.model = MagicMock()
- huggingface_llm_model.tokenizer = MagicMock()
- inputs = MagicMock()
- inputs.input_ids = MagicMock(shape=[1, 2])
- inputs.attention_mask = MagicMock()
- huggingface_llm_model.tokenizer.return_value = inputs
- outputs = [MagicMock(shape=[2])]
- huggingface_llm_model.model.generate.return_value = outputs
- huggingface_llm_model.tokenizer.decode.return_value = "Yeah."
-
- result = await huggingface_llm_model.generate_async(
- prompt="Alright?",
- max_tokens=128,
- temperature=0.5,
- )
-
- huggingface_llm_model.tokenizer.assert_called_once_with(
- "Alright?",
- add_special_tokens=False,
- return_tensors="pt",
- )
- huggingface_llm_model.model.generate_async.assert_called_once_with(
- inputs=inputs.input_ids,
- attention_mask=inputs.attention_mask,
- max_new_tokens=128,
- do_sample=False,
- temperature=0.5,
- top_p=0.9,
- )
- huggingface_llm_model.tokenizer.decode.assert_called_once_with(
- outputs[0],
- skip_prompt=True,
- skip_special_tokens=True,
- )
- assert result == "Yeah."
\ No newline at end of file
+ output = huggingface_llm_model.generate("How are you doing?")
+ assert isinstance(output, str)
diff --git a/tests/app/test_utils.py b/tests/app/test_utils.py
index 470d0d8..4519350 100644
--- a/tests/app/test_utils.py
+++ b/tests/app/test_utils.py
@@ -6,7 +6,7 @@
import zipfile
import tarfile
import unittest
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock
from safetensors.torch import save_file
from transformers import PreTrainedModel
from urllib.parse import urlparse
@@ -33,10 +33,8 @@
pyproject_dependencies_to_pip_requirements,
get_model_data_package_base_name,
load_pydantic_object_from_dict,
- dump_pydantic_object_to_dict,
- get_prompt_from_messages,
)
-from app.domain import Annotation, Entity, PromptMessage, PromptRole
+from app.domain import Annotation, Entity
def test_get_code_base_uri():
@@ -395,46 +393,3 @@ def __init__(self):
def forward(self, x):
return self.linear(x)
-
-
-def test_get_prompt_with_chat_template():
- with patch('transformers.PreTrainedTokenizer') as tok:
- mock_tokenizer = tok.return_value
- mock_tokenizer.chat_template = "Mock chat template"
- mock_tokenizer.apply_chat_template.return_value = "Mock chat template applied"
- messages = [
- PromptMessage(content="Alright?", role=PromptRole.USER.value),
- PromptMessage(content="Yeah.", role=PromptRole.ASSISTANT.value),
- ]
-
- prompt = get_prompt_from_messages(mock_tokenizer, messages)
-
- assert prompt == "Mock chat template applied"
-
-
-def test_get_prompt_without_chat_template():
- with patch('transformers.PreTrainedTokenizer') as tok:
- mock_tokenizer = tok.return_value
- mock_tokenizer.chat_template = None
- messages = [
- PromptMessage(content="You are a helpful assistant.", role=PromptRole.SYSTEM.value),
- PromptMessage(content="Alright?", role=PromptRole.USER.value),
- PromptMessage(content="Yeah.", role=PromptRole.ASSISTANT.value),
- ]
-
- prompt = get_prompt_from_messages(mock_tokenizer, messages)
-
- expected_prompt = "<|system|>\nYou are a helpful assistant.\n<|user|>\nAlright?\n<|assistant|>\nYeah.\n<|assistant|>\n"
- assert prompt == expected_prompt
-
-
-def test_get_prompt_with_no_messages():
- with patch('transformers.PreTrainedTokenizer') as tok:
- mock_tokenizer = tok.return_value
- mock_tokenizer.chat_template = None
- messages = []
-
- prompt = get_prompt_from_messages(mock_tokenizer, messages)
-
- expected_prompt = "\n<|assistant|>\n"
- assert prompt == expected_prompt
\ No newline at end of file
diff --git a/uv.lock b/uv.lock
index cc97b01..0b032d7 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1215,7 +1215,6 @@ dev = [
{ name = "locust", version = "2.25.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" },
{ name = "locust", version = "2.31.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" },
{ name = "mypy" },
- { name = "openai" },
{ name = "pytest" },
{ name = "pytest-asyncio" },
{ name = "pytest-bdd" },
@@ -1243,7 +1242,6 @@ dev = [
{ name = "locust", version = "2.25.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" },
{ name = "locust", version = "2.31.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" },
{ name = "mypy" },
- { name = "openai" },
{ name = "pytest" },
{ name = "pytest-asyncio" },
{ name = "pytest-bdd" },
@@ -1279,14 +1277,13 @@ requires-dist = [
{ name = "fastapi-users-db-sqlalchemy", specifier = "~=5.0.0" },
{ name = "graypy", specifier = "~=2.1.0" },
{ name = "httpx", marker = "extra == 'dev'", specifier = "~=0.24.1" },
- { name = "huggingface-hub", specifier = "~=0.33.0" },
+ { name = "huggingface-hub", specifier = "~=0.32.0" },
{ name = "ijson", specifier = "~=3.1.4" },
{ name = "locust", marker = "extra == 'dev'", specifier = "<2.32.0" },
{ name = "medcat", marker = "python_full_version < '3.9'", specifier = "~=1.13.1" },
{ name = "medcat", marker = "python_full_version >= '3.9'", specifier = "~=1.16.0" },
{ name = "mlflow", specifier = "~=2.16.2" },
{ name = "mypy", marker = "extra == 'dev'", specifier = "~=1.14.0" },
- { name = "openai", marker = "extra == 'dev'", specifier = ">=1.84.0" },
{ name = "peft", specifier = "<0.14.0" },
{ name = "prometheus-fastapi-instrumentator", specifier = "~=7.0.0" },
{ name = "psycopg2-binary", specifier = "~=2.9.4" },
@@ -1323,7 +1320,6 @@ dev = [
{ name = "httpx", specifier = "~=0.24.1" },
{ name = "locust", specifier = "<2.32.0" },
{ name = "mypy", specifier = "~=1.14.0" },
- { name = "openai", specifier = ">=1.84.0" },
{ name = "pytest", specifier = "~=7.1.2" },
{ name = "pytest-asyncio", specifier = "~=0.23.7" },
{ name = "pytest-bdd", specifier = "~=7.2.0" },
@@ -3267,7 +3263,7 @@ wheels = [
[[package]]
name = "huggingface-hub"
-version = "0.33.5"
+version = "0.32.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock", version = "3.16.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" },
@@ -3280,9 +3276,9 @@ dependencies = [
{ name = "tqdm" },
{ name = "typing-extensions" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/02/16/5716d03e2b48bcc8e32d9b18ed7e55d2ae52e3d5df146cced9fe0581b5ff/huggingface_hub-0.33.5.tar.gz", hash = "sha256:814097e475646d170c44be4c38f7d381ccc4539156a5ac62a54f53aaf1602ed8", size = 427075 }
+sdist = { url = "https://files.pythonhosted.org/packages/60/c8/4f7d270285c46324fd66f62159eb16739aa5696f422dba57678a8c6b78e9/huggingface_hub-0.32.4.tar.gz", hash = "sha256:f61d45cd338736f59fb0e97550b74c24ee771bcc92c05ae0766b9116abe720be", size = 424494 }
wheels = [
- { url = "https://files.pythonhosted.org/packages/33/d5/d9e9b75d8dc9cf125fff16fb0cd51d864a29e8b46b6880d8808940989405/huggingface_hub-0.33.5-py3-none-any.whl", hash = "sha256:29b4e64982c2064006021af297e1b17d44c85a8aaf90a0d7efeff7e7d2426296", size = 515705 },
+ { url = "https://files.pythonhosted.org/packages/67/8b/222140f3cfb6f17b0dd8c4b9a0b36bd4ebefe9fb0098ba35d6960abcda0f/huggingface_hub-0.32.4-py3-none-any.whl", hash = "sha256:37abf8826b38d971f60d3625229221c36e53fe58060286db9baf619cfbf39767", size = 512101 },
]
[package.optional-dependencies]
@@ -3624,88 +3620,10 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899 },
]
-[[package]]
-name = "jiter"
-version = "0.9.1"
-source = { registry = "https://pypi.org/simple" }
-resolution-markers = [
- "python_full_version < '3.9' and sys_platform != 'win32'",
- "python_full_version < '3.9' and sys_platform == 'win32'",
-]
-sdist = { url = "https://files.pythonhosted.org/packages/84/72/c28662416d9807bb5a38625eadedb82d4bd14fd2700c308ece7acdb8e89f/jiter-0.9.1.tar.gz", hash = "sha256:7852990068b6e06102ecdc44c1619855a2af63347bfb5e7e009928dcacf04fdd", size = 162540 }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/2b/5f/7f6aaca7943c644b4fd220650771f39dbfb74f9690efc6fb8c0d4092a399/jiter-0.9.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:c0163baa7ee85860fdc14cc39263014500df901eeffdf94c1eab9a2d713b2a9d", size = 312882 },
- { url = "https://files.pythonhosted.org/packages/86/0d/aac9eafc5d46bdf5c4f127ac1ce85e434d003bb5e3ae886f5e726a988cf6/jiter-0.9.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:514d4dd845e0af4da15112502e6fcb952f0721f27f17e530454e379472b90c14", size = 311743 },
- { url = "https://files.pythonhosted.org/packages/b8/54/fab1f4d8634af7bb1ad6dc49bee50ea9f649de0e5309c80192ace739f968/jiter-0.9.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b879faee1cc1a67fde3f3f370041239fd260ac452bd53e861aa4a94a51e3fd02", size = 1085889 },
- { url = "https://files.pythonhosted.org/packages/bd/86/bf4ed251d8035d5d72a46c8f9969bd5054fad052371cbea0cb161060e660/jiter-0.9.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:20a5ce641f93bfb8d8e336f8c4a045e491652f41eaacc707b15b245ece611e72", size = 1117896 },
- { url = "https://files.pythonhosted.org/packages/62/40/b04c40deccd5edd5f2a3853f4a80dc0ddbe157d1d523a573fb3d224315fc/jiter-0.9.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8575b1d2b49df04ca82d658882f4a432b7ed315a69126a379df4d10aeb416021", size = 1211956 },
- { url = "https://files.pythonhosted.org/packages/85/f0/114e9893e4ef5b423718efe9b3da01117539c333f06ef19543c68c8b7ed1/jiter-0.9.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cc61831699904e0c58e82943f529713833db87acd13f95a3c0feb791f862d47b", size = 1219691 },
- { url = "https://files.pythonhosted.org/packages/02/9a/1aeac4541ce1c59c65dc76dbab642232da3d8db0581df3e61b8943033bd7/jiter-0.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fb733faf4d0e730d6663873249c1acb572fc8bd9dae3836ceda69751f27c5be", size = 352604 },
- { url = "https://files.pythonhosted.org/packages/6b/27/446ec6ca0a25d9d2f45ad546633a2b4a1b6a7f28fb6819c7056b163c5aee/jiter-0.9.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d903b3bb917c0df24f2ef62f587c8f32f6003cb2f97264109ca56c023262557f", size = 1147136 },
- { url = "https://files.pythonhosted.org/packages/09/9d/c8540bc097b07e106d060c21395c6fa6561223e7366c948a04ef0aa39979/jiter-0.9.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:eac3eb5206845b170142c016ae467eca523a25459dc9c53fcd8e154ea263406c", size = 1255843 },
- { url = "https://files.pythonhosted.org/packages/d3/61/9b377ecf4e09e325e90f77a7a4859ec933162f58ff5c6b7730aff6352033/jiter-0.9.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7ea0c20cfc61acc5335bb8ee36d639e6a4ded03f34f878b2b3038bb9f3bb553c", size = 1257536 },
- { url = "https://files.pythonhosted.org/packages/ed/f6/b6754e11ac9d02f05a2d713c0846ce813a69c1f6f7de7f1ae216c4e35ace/jiter-0.9.1-cp310-cp310-win32.whl", hash = "sha256:0f8f812dd6d2b4112db9ab4c1079c4fe73e553a500e936657fdda394fa2517e1", size = 214064 },
- { url = "https://files.pythonhosted.org/packages/1d/cb/7b9c5d6f73499d1fb5e97e36e8078f3bea00d7541a973117eccf9db1e079/jiter-0.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:f7f0198889170e7af6210509803e6527b402efc6c26f42e2896883597a10426f", size = 209952 },
- { url = "https://files.pythonhosted.org/packages/ee/3b/9f9deaef471e346354c832b6627e0d1b9ba3d9611d0e0fd394c2acf2a615/jiter-0.9.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:6b8564e3198c4c8d835fc95cc54d6bcbd2fd8dc33a047fecc12c208491196995", size = 312737 },
- { url = "https://files.pythonhosted.org/packages/36/00/76fa6d519f8289aad32ec1caf3716eb700ba48e3212d1dda71e74c385a5c/jiter-0.9.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:90b92044588d14efe89b394eca735adc4ac096eba82dc75d93c3083b1eebce8d", size = 313357 },
- { url = "https://files.pythonhosted.org/packages/b3/e9/f864ebe9ddf07761d5bdd3148b45a5d433c6cbce7c7e8be29baf806fa612/jiter-0.9.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3505f7f419b355c7788fcaae0dfc4c6ccbc50c0dc3633a2da797e841c5a423dc", size = 1085946 },
- { url = "https://files.pythonhosted.org/packages/82/a1/ed02d4c86d620989dcd392366daa67198961eedaf2e66f7a68f0d3846dba/jiter-0.9.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:93af8c3f4a3bf145c690e857a945eb5c655534bf95c67e1447d85c02e5af64d7", size = 1118090 },
- { url = "https://files.pythonhosted.org/packages/d3/01/d107531d215a57cda3cbc4adfcf3119166dd32adc1c332c1f3f36efd3484/jiter-0.9.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:43b81dd21e260a249780764921b1f9a6379cb31e24e7b61e6bf0799f38ec4b91", size = 1212231 },
- { url = "https://files.pythonhosted.org/packages/45/1e/6801a81a2ef1f917fe9a7d2139e576dd4f53497c309dab9461136922709c/jiter-0.9.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:db639fad5631b3d1692609f6dd77b64e8578321b7aeec07a026acd2c867c04a5", size = 1219263 },
- { url = "https://files.pythonhosted.org/packages/a5/d4/40082e8666cfdb24461855e9bb29fe77f063cc65a6c903291f2e5225f780/jiter-0.9.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15356b943e70ca7ab3b587ffaffadc0158467f6c4e0b491e52a0743c4bdf5ba1", size = 350364 },
- { url = "https://files.pythonhosted.org/packages/c4/09/09bc72dd143f76acd55e04c3a45b9f9ee3ed28e00b49924e3702ad041812/jiter-0.9.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:53a7033a46141ff815518a6972d657c75d8f5946b9315e1c25b07e9677c1ff6c", size = 1146802 },
- { url = "https://files.pythonhosted.org/packages/5b/34/9d15a9c04d5760537b432134447bde94b936ec73dc922b4d14a48def2e1f/jiter-0.9.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:68cf519a6f00b8127f9be64a37e97e978094438abced5adebe088a98c64bdcff", size = 1256019 },
- { url = "https://files.pythonhosted.org/packages/8f/01/1fcd165fb28968a54bb46a209d5919f7649b96608eef7dc4622ea378b95a/jiter-0.9.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9098abdd34cd9ddeb04768cc4f5fc725ebd9a52978c488da74e58a837ce93506", size = 1257610 },
- { url = "https://files.pythonhosted.org/packages/9f/87/93ac6a57331dd90e4c896ac852bf8ce6b28b40dace4b9698a207dbb99af2/jiter-0.9.1-cp311-cp311-win32.whl", hash = "sha256:7179ce96aecd096af890dd57b84133e47a59fbde32a77734f09bafa6a4da619e", size = 214515 },
- { url = "https://files.pythonhosted.org/packages/bb/ee/3678b8a3bd5f6471d0a492540e7ff9c63db278d844214458ec5cfb22adb2/jiter-0.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:e6517f5b7b6f60fd77fc1099572f445be19553c6f61b907ab5b413fb7179663f", size = 212258 },
- { url = "https://files.pythonhosted.org/packages/26/ca/1c7438d66969a13938266492de65daf752754ec59f2a3f3716027c7d708f/jiter-0.9.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:95065923a49ae387bab62b1bf5f798beb12e6fb4469a079fdd0ecad64b40b272", size = 313516 },
- { url = "https://files.pythonhosted.org/packages/e8/d9/3a6300309e312f8ed529ae57d565f69abdb520e4f12460cefa7996d0716c/jiter-0.9.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a179fbc5c7922844a673be35099a3036a7276dc63753c6c81a77c3cb525f2f8d", size = 308161 },
- { url = "https://files.pythonhosted.org/packages/b3/91/2aca15be38514daf8f1a1460fd9c4b652ed09148fe109520298858be7928/jiter-0.9.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:abd30dc5c0183d31faf30ce8279d723809c54b3fe6d95d922d4a4b31bc462799", size = 1086100 },
- { url = "https://files.pythonhosted.org/packages/9f/6f/f7ba3dfe7be08bf58939324e0bb4f4aa605eff7f2c2ac140a41221cf50a4/jiter-0.9.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9765512bdeae269843e6615377f48123432da247e18048d05e9c5685377c241c", size = 1118922 },
- { url = "https://files.pythonhosted.org/packages/b5/4e/b1f4d9bdba81de293e1b8672598300a9195cf3d77b0acc5f331a75695b58/jiter-0.9.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6f15cdbdc1e1e89e0d9ea581de63e03975043a4b40ab87d5554fdc440357b771", size = 1212327 },
- { url = "https://files.pythonhosted.org/packages/3e/ab/e417aaf5a62067bd91c5f7ed4e5ab83bd46f349449adde1159ad8e2d3a21/jiter-0.9.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b1a639b2cfe56b5b687c678ed45d68f46dfb922c2f338fdfb227eb500053929d", size = 1220860 },
- { url = "https://files.pythonhosted.org/packages/1e/50/c5ba756c641ca8ebc1e4ff07c03ce5c8ef5052b0238f514436f8de3c9fc4/jiter-0.9.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41955c9d83c8470de9cc64c97b04a3ffd2f32815bb2c4307f44d8e21542b74df", size = 344077 },
- { url = "https://files.pythonhosted.org/packages/c6/b3/bd7d8d4bad65aa1f4a20562233080054149785c0d7f7b9027e761335d882/jiter-0.9.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f26f6d42c330e26a6ba3471b390364faad96f3ca965a6c579957810b0c078efa", size = 1148785 },
- { url = "https://files.pythonhosted.org/packages/c0/12/bfd9a167709f96171312d1e0ae2c1be70a167abcc3bff6f3441967e3626a/jiter-0.9.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6a23e01bd7e918f27f02d3df8721b8a395211070a8a65aeb353209b8c72720cf", size = 1255962 },
- { url = "https://files.pythonhosted.org/packages/5f/3c/3a79020862d2511b854b350bc9229cf228fd38b836e94f274ca940e22e95/jiter-0.9.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8a96ad217989dd9df661711c3fa2e6fb2601c4bbb482e89718110bdafbc16c9e", size = 1257561 },
- { url = "https://files.pythonhosted.org/packages/93/d3/7f6f8e57613d4947a872980befa6af19de9252e310ea4a512eed0fe1e064/jiter-0.9.1-cp38-cp38-win32.whl", hash = "sha256:4b180e7baa4747b3834c5a9202b1ba30dc64797f45236d9142cdb2a8807763cf", size = 215019 },
- { url = "https://files.pythonhosted.org/packages/9b/5d/b6f0cd60c8f702936f253644a92dee19e2c82010290e4607af462033351f/jiter-0.9.1-cp38-cp38-win_amd64.whl", hash = "sha256:baf881de1fbc7b3343cce24f75a2ab6350e03fc13d16d00f452929788a6cdc3f", size = 199563 },
- { url = "https://files.pythonhosted.org/packages/4f/3a/a8a4768af26578c87894bb130bcd6fb6c97f4cb36ed7a20a664412d41935/jiter-0.9.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:ec95aa1b433c50b2b129456b4680b239ec93206ea3f86cfd41b6a70be5beb2f3", size = 313942 },
- { url = "https://files.pythonhosted.org/packages/63/74/05977891db48000d985a5f573493c43adf0f190eada670e51b92c9ed9139/jiter-0.9.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5d92cb50d135dbdd33b638fa2e0c6af25e1d635d38da13aa9ab05d021fb0c869", size = 308160 },
- { url = "https://files.pythonhosted.org/packages/21/54/75f529e90442c8ad41acd8cf08323a4f3dcaa105710b2c8a1fda56e3a462/jiter-0.9.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b146dc2464f1d96007271d08bdf79288a5f1aa4aae5329eb79dcffb1181c703e", size = 1086503 },
- { url = "https://files.pythonhosted.org/packages/bf/fa/02532a7ce7b712c576125d4f2614e77bc897c95b2b15e21ee25f42b3ff34/jiter-0.9.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fcf20ba858658ecd54b4710172d92009afa66d41d967c86d11607592a3c220fa", size = 1120444 },
- { url = "https://files.pythonhosted.org/packages/91/c2/ab8cebaea6f2691eddcc5b6c67deb1399adbd85f12ad836f7cd77be78bf8/jiter-0.9.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:147fccc44bebdb672d4c601e9312730488b840d415e201e89c8ea0929a63dacf", size = 1212370 },
- { url = "https://files.pythonhosted.org/packages/13/e3/90dddb7877b67cc0e1ddb864c2ca74314def26ff6542431a6e3061e0f805/jiter-0.9.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a428061aae26efaa6fb690ef9e7d6224aefe4eef7524165d073beb3cdad75f6f", size = 1221210 },
- { url = "https://files.pythonhosted.org/packages/81/76/90ee847519a94a4a1a8bad7addce7019f424aea03c55eacf068469226760/jiter-0.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7164d92bb901784bd3c098ac0b0beae4306ea6c741dbd3a375449a8affc5366", size = 353774 },
- { url = "https://files.pythonhosted.org/packages/59/a6/614a5d672d4b9c6bc9ad34579f0522577a0a78cc265069fca96543a832ca/jiter-0.9.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:93049a562233808914a2b938b0c745d7049db1667b3f42f0f5cf48e617393ba5", size = 1148581 },
- { url = "https://files.pythonhosted.org/packages/2d/94/c100147c310361fa83e25c4c6ce17723532147580252962b89e6085795c2/jiter-0.9.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f6dcf2cb16cc15d82a018e20eeaf169e6f6cd8c426f4c312ebe11710c623bed2", size = 1256636 },
- { url = "https://files.pythonhosted.org/packages/51/9a/dc82e218ba839052899df555e34f16b8ad1d7da9c01be208f65a5bf0083c/jiter-0.9.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2da9d485a7c526817cde9ff8b3394fa50ff5b782b86b6896378a3ba8844550f2", size = 1258099 },
- { url = "https://files.pythonhosted.org/packages/58/d5/d853e069624038950265ac0e877985b249049b624e925dab6cd11035140c/jiter-0.9.1-cp39-cp39-win32.whl", hash = "sha256:ea58c155d827d24e5ba8d7958ec4738b26be0894c0881a91d88b39ff48bb06c9", size = 214611 },
- { url = "https://files.pythonhosted.org/packages/cb/8d/7b6b1ee6e3d9d1a06237bbdfe4c6bb21baf323d3f70a0cc8f203de40c6b2/jiter-0.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:be2e911ecdb438951290c2079fe4190e7cc5be9e849df4caeb085b83ed620ff6", size = 211171 },
-]
-
[[package]]
name = "jiter"
version = "0.10.0"
source = { registry = "https://pypi.org/simple" }
-resolution-markers = [
- "python_full_version > '3.11' and sys_platform == 'darwin'",
- "python_full_version > '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'",
- "(python_full_version > '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version > '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')",
- "python_full_version == '3.11' and sys_platform == 'darwin'",
- "python_full_version == '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'",
- "(python_full_version == '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')",
- "python_full_version == '3.10.*' and sys_platform == 'darwin'",
- "python_full_version == '3.10.*' and platform_machine == 'aarch64' and sys_platform == 'linux'",
- "(python_full_version == '3.10.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.10.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')",
- "python_full_version == '3.9.*' and platform_machine == 'arm64' and sys_platform == 'darwin'",
- "python_full_version == '3.9.*' and platform_machine == 'aarch64' and sys_platform == 'linux'",
- "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.9.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.9.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')",
- "python_full_version > '3.11' and sys_platform == 'win32'",
- "python_full_version == '3.11' and sys_platform == 'win32'",
- "python_full_version == '3.10.*' and sys_platform == 'win32'",
- "python_full_version == '3.9.*' and sys_platform == 'win32'",
-]
sdist = { url = "https://files.pythonhosted.org/packages/ee/9d/ae7ddb4b8ab3fb1b51faf4deb36cb48a4fbbd7cb36bad6a5fca4741306f7/jiter-0.10.0.tar.gz", hash = "sha256:07a7142c38aacc85194391108dc91b5b57093c978a9932bd86a36862759d9500", size = 162759 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/be/7e/4011b5c77bec97cb2b572f566220364e3e21b51c48c5bd9c4a9c26b41b67/jiter-0.10.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:cd2fb72b02478f06a900a5782de2ef47e0396b3e1f7d5aba30daeb1fce66f303", size = 317215 },
@@ -5586,7 +5504,7 @@ version = "9.1.0.70"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12", version = "12.1.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9' and sys_platform != 'win32'" },
- { name = "nvidia-cublas-cu12", version = "12.4.5.8", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" },
+ { name = "nvidia-cublas-cu12", version = "12.4.5.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 },
@@ -5614,7 +5532,7 @@ resolution-markers = [
"(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.9.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.9.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')",
]
dependencies = [
- { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" },
+ { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 },
@@ -5672,9 +5590,9 @@ resolution-markers = [
"(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.9.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.9.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')",
]
dependencies = [
- { name = "nvidia-cublas-cu12", version = "12.4.5.8", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" },
- { name = "nvidia-cusparse-cu12", version = "12.3.1.170", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" },
- { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" },
+ { name = "nvidia-cublas-cu12", version = "12.4.5.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" },
+ { name = "nvidia-cusparse-cu12", version = "12.3.1.170", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" },
+ { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 },
@@ -5705,7 +5623,7 @@ resolution-markers = [
"(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.9.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.9.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')",
]
dependencies = [
- { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" },
+ { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 },
@@ -5799,17 +5717,14 @@ name = "openai"
version = "1.84.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "anyio", version = "4.5.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" },
{ name = "anyio", version = "4.9.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" },
- { name = "distro" },
- { name = "httpx" },
- { name = "jiter", version = "0.9.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" },
- { name = "jiter", version = "0.10.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" },
- { name = "pydantic", version = "1.10.22", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" },
+ { name = "distro", marker = "python_full_version >= '3.9'" },
+ { name = "httpx", marker = "python_full_version >= '3.9'" },
+ { name = "jiter", marker = "python_full_version >= '3.9'" },
{ name = "pydantic", version = "2.11.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" },
- { name = "sniffio" },
- { name = "tqdm" },
- { name = "typing-extensions" },
+ { name = "sniffio", marker = "python_full_version >= '3.9'" },
+ { name = "tqdm", marker = "python_full_version >= '3.9'" },
+ { name = "typing-extensions", marker = "python_full_version >= '3.9'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/91/a3/128caf24e116f48fad3e4d5122cdf84db06c5127911849d51663c66158c8/openai-1.84.0.tar.gz", hash = "sha256:4caa43bdab262cc75680ce1a2322cfc01626204074f7e8d9939ab372acf61698", size = 467066 }
wheels = [
@@ -9890,8 +9805,8 @@ name = "xformers"
version = "0.0.29.post2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" },
- { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" },
+ { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" },
+ { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/27/ed/04ec7ef97a7e1c836add41ef5a2aef8cbdd45c0190ca42cc08f3c21e2b7b/xformers-0.0.29.post2.tar.gz", hash = "sha256:6ca3d1a6db6f2abff25c1154adee96987f77f4dfd5141771805afa5fc13e9395", size = 8468494 }
wheels = [
@@ -9905,12 +9820,12 @@ name = "xgrammar"
version = "0.1.18"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "ninja", marker = "(python_full_version >= '3.9' and platform_machine != 'arm64') or (python_full_version >= '3.10' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and sys_platform != 'darwin')" },
- { name = "pydantic", version = "2.11.4", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.9' and platform_machine != 'arm64') or (python_full_version >= '3.10' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and sys_platform != 'darwin')" },
- { name = "sentencepiece", marker = "(python_full_version >= '3.9' and platform_machine != 'arm64') or (python_full_version >= '3.10' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and sys_platform != 'darwin')" },
- { name = "tiktoken", marker = "(python_full_version >= '3.9' and platform_machine != 'arm64') or (python_full_version >= '3.10' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and sys_platform != 'darwin')" },
- { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.9' and platform_machine != 'arm64') or (python_full_version >= '3.10' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and sys_platform != 'darwin')" },
- { name = "transformers", version = "4.51.3", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.9' and platform_machine != 'arm64') or (python_full_version >= '3.10' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and sys_platform != 'darwin')" },
+ { name = "ninja", marker = "python_full_version >= '3.9'" },
+ { name = "pydantic", version = "2.11.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" },
+ { name = "sentencepiece", marker = "python_full_version >= '3.9'" },
+ { name = "tiktoken", marker = "python_full_version >= '3.9'" },
+ { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" },
+ { name = "transformers", version = "4.51.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" },
{ name = "triton", version = "3.2.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and platform_machine == 'x86_64' and sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/8f/c3/22c9eeab6ee1dd6d0513d227e9d307fd20a0491db58f1f04bc5d566d13dc/xgrammar-0.1.18.tar.gz", hash = "sha256:a0438a0f9262fff1d0e4f184268eb759f094243edce92b67eb7aa5f245c47471", size = 1697230 }