From 1477deac2d421c15b06bc948e9184c2c18edffa7 Mon Sep 17 00:00:00 2001 From: Xi Bai <82581439+baixiac@users.noreply.github.com> Date: Wed, 17 Sep 2025 15:52:05 +0100 Subject: [PATCH 1/9] Revert "feat: add the 4-bit quantisation option and remove unnecessary base model copying" This reverts commit d03f4875d323d005b43d6e3e9e60abf1c93ff07d. --- .github/workflows/main.yaml | 2 +- app/cli/cli.py | 8 +--- app/model_services/base.py | 6 +-- app/model_services/huggingface_llm_model.py | 12 ++--- app/model_services/huggingface_ner_model.py | 9 +--- app/model_services/medcat_model.py | 9 +--- app/model_services/medcat_model_deid.py | 9 +--- app/model_services/trf_model_deid.py | 2 +- app/trainers/huggingface_llm_trainer.py | 52 +++++++++------------ app/utils.py | 5 ++ 10 files changed, 42 insertions(+), 72 deletions(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 6c3a9dd..ffce1f8 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -24,7 +24,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - uv sync --group dev --group docs + uv sync --group dev --group docs --group vllm - name: Check types run: | uv run mypy app diff --git a/app/cli/cli.py b/app/cli/cli.py index 6003407..8a94647 100644 --- a/app/cli/cli.py +++ b/app/cli/cli.py @@ -67,7 +67,6 @@ def serve_model( streamable: bool = typer.Option(False, help="Serve the streamable endpoints only"), device: Device = typer.Option(Device.DEFAULT.value, help="The device to serve the model on"), llm_engine: Optional[LlmEngine] = typer.Option(LlmEngine.CMS.value, help="The engine to use for text generation"), - load_in_4bit: Optional[bool] = typer.Option(False, help="Load the model in 4-bit precision, used by 'huggingface_llm' models"), debug: Optional[bool] = typer.Option(None, help="Run in the debug mode"), ) -> None: """ @@ -85,7 +84,6 @@ def serve_model( streamable (bool): Serve the streamable endpoints only. Defaults to False. device (Device): The device to serve the model on. Defaults to Device.DEFAULT. llm_engine (LlmEngine): The inference engine to use. Defaults to LlmEngine.CMS. - load_in_4bit (bool): Load the model in 4-bit precision, used by 'huggingface_llm' models. Defaults to False. debug (Optional[bool]): Run in debug mode if set to True. """ @@ -137,7 +135,7 @@ def serve_model( if model_path: model_service = model_service_dep() model_service.model_name = model_name - model_service.init_model(load_in_4bit=load_in_4bit) + model_service.init_model() cms_globals.model_manager_dep = ModelManagerDep(model_service) elif mlflow_model_uri: model_service = ModelManager.retrieve_model_service_from_uri(mlflow_model_uri, config, dst_model_path) @@ -189,7 +187,6 @@ def train_model( description: Optional[str] = typer.Option(None, help="The description of the training or change logs"), model_name: Optional[str] = typer.Option(None, help="The string representation of the model name"), device: Device = typer.Option(Device.DEFAULT.value, help="The device to train the model on"), - load_in_4bit: Optional[bool] = typer.Option(False, help="Load the model in 4-bit precision, used by 'huggingface_llm' models"), debug: Optional[bool] = typer.Option(None, help="Run in the debug mode"), ) -> None: """ @@ -209,7 +206,6 @@ def train_model( description (Optional[str]): The optional description of the training or change logs. model_name (Optional[str]): The optional string representation of the model name. device (Device): The device to train the model on. Defaults to Device.DEFAULT. - load_in_4bit (bool): Load the model in 4-bit precision, used by 'huggingface_llm' models. Defaults to False. debug (Optional[bool]): Run in debug mode if set to True. """ @@ -233,7 +229,7 @@ def train_model( pass model_service = model_service_dep() model_service.model_name = model_name if model_name is not None else "CMS model" - model_service.init_model(load_in_4bit=load_in_4bit) + model_service.init_model() elif mlflow_model_uri: model_service = ModelManager.retrieve_model_service_from_uri(mlflow_model_uri, config, dst_model_path) model_service.model_name = model_name if model_name is not None else "CMS model" diff --git a/app/model_services/base.py b/app/model_services/base.py index a7b6323..dfde491 100644 --- a/app/model_services/base.py +++ b/app/model_services/base.py @@ -154,14 +154,10 @@ def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]: raise NotImplementedError @abstractmethod - def init_model(self, *args: Any, **kwargs: Any) -> None: + def init_model(self) -> None: """ Initialises the model and auxiliary resources. - Args: - *args (Any): Additional positional arguments to be passed to this method. - **kwargs (Any): Additional keyword arguments to be passed to this method. - Raises: NotImplementedError: If the method is not implemented by the subclass. """ diff --git a/app/model_services/huggingface_llm_model.py b/app/model_services/huggingface_llm_model.py index a747739..98f9bea 100644 --- a/app/model_services/huggingface_llm_model.py +++ b/app/model_services/huggingface_llm_model.py @@ -174,14 +174,8 @@ def load_model( else: raise ConfigurationException(f"Model package archive format is not supported: {model_file_path}") - def init_model(self, load_in_4bit: bool = False, *args: Any, **kwargs: Any) -> None: - """Initialises the HuggingFace model and its tokenizer based on the configuration. - - Args: - load_in_4bit (bool): Whether to load the model in 4-bit precision. Defaults to False. - *args (Any): Additional positional arguments to be passed to this method. - **kwargs (Any): Additional keyword arguments to be passed to this method. - """ + def init_model(self) -> None: + """Initialises the HuggingFace model and its tokenizer based on the configuration.""" if all([ hasattr(self, "_model"), @@ -191,7 +185,7 @@ def init_model(self, load_in_4bit: bool = False, *args: Any, **kwargs: Any) -> N ]): logger.warning("Model service is already initialised and can be initialised only once") else: - self._model, self._tokenizer = self.load_model(self._model_pack_path, load_in_4bit=load_in_4bit) + self._model, self._tokenizer = self.load_model(self._model_pack_path) if non_default_device_is_available(get_settings().DEVICE): self._model.to(get_settings().DEVICE) if self._enable_trainer: diff --git a/app/model_services/huggingface_ner_model.py b/app/model_services/huggingface_ner_model.py index e741705..d982836 100644 --- a/app/model_services/huggingface_ner_model.py +++ b/app/model_services/huggingface_ner_model.py @@ -175,13 +175,8 @@ def load_model(model_file_path: str, *args: Tuple, **kwargs: Dict[str, Any]) -> else: raise ConfigurationException(f"Model package archive format is not supported: {model_file_path}") - def init_model(self, *args: Any, **kwargs: Any) -> None: - """Initialises the HuggingFace model, its tokenizer and a NER pipeline based on the configuration. - - Args: - *args (Any): Additional positional arguments to be passed to this method. - **kwargs (Any): Additional keyword arguments to be passed to this method. - """ + def init_model(self) -> None: + """Initialises the HuggingFace model, its tokenizer and a NER pipeline based on the configuration.""" if all([ hasattr(self, "_model"), diff --git a/app/model_services/medcat_model.py b/app/model_services/medcat_model.py index 9ab5235..a3a6f2c 100644 --- a/app/model_services/medcat_model.py +++ b/app/model_services/medcat_model.py @@ -119,13 +119,8 @@ def load_model(model_file_path: str, *args: Tuple, **kwargs: Dict[str, Any]) -> else: raise ConfigurationException("Model package archive format is not supported") - def init_model(self, *args: Any, **kwargs: Any) -> None: - """Initializes the MedCAT model based on the configuration. - - Args: - *args (Any): Additional positional arguments to be passed to this method. - **kwargs (Any): Additional keyword arguments to be passed to this method. - """ + def init_model(self) -> None: + """Initializes the MedCAT model based on the configuration.""" if hasattr(self, "_model") and isinstance(self._model, CAT): logger.warning("Model service is already initialised and can be initialised only once") diff --git a/app/model_services/medcat_model_deid.py b/app/model_services/medcat_model_deid.py index 9ec1248..fe94dde 100644 --- a/app/model_services/medcat_model_deid.py +++ b/app/model_services/medcat_model_deid.py @@ -178,13 +178,8 @@ def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]: return annotations_list - def init_model(self, *args: Any, **kwargs: Any) -> None: - """Initializes the MedCAT De-Identification (AnonCAT) model based on the configuration. - - Args: - *args (Any): Additional positional arguments to be passed to this method. - **kwargs (Any): Additional keyword arguments to be passed to this method. - """ + def init_model(self) -> None: + """Initializes the MedCAT De-Identification (AnonCAT) model based on the configuration.""" if hasattr(self, "_model") and isinstance(self._model, CAT): logger.warning("Model service is already initialised and can be initialised only once") diff --git a/app/model_services/trf_model_deid.py b/app/model_services/trf_model_deid.py index fbf6290..fb8e3ac 100644 --- a/app/model_services/trf_model_deid.py +++ b/app/model_services/trf_model_deid.py @@ -86,7 +86,7 @@ def load_model( logger.info("Model loaded from %s", unpacked_model_dir) return tokenizer, model - def init_model(self, *args: Any, **kwargs: Any) -> None: + def init_model(self) -> None: if hasattr(self, "_model") and isinstance(self._model, PreTrainedModel): logger.warning("Model service is already initialised and can be initialised only once") else: diff --git a/app/trainers/huggingface_llm_trainer.py b/app/trainers/huggingface_llm_trainer.py index b85f44b..895a58e 100644 --- a/app/trainers/huggingface_llm_trainer.py +++ b/app/trainers/huggingface_llm_trainer.py @@ -88,11 +88,8 @@ def __init__(self, model_service: "HuggingFaceLlmModel") -> None: self._model_service = model_service self._model_name = model_service.model_name self._model_pack_path = model_service._model_pack_path - self._retrained_models_dir = os.path.join( - model_service._model_parent_dir, - "retrained", - self._model_name.replace(" ", "_"), - ) + self._retrained_models_dir = os.path.join(model_service._model_parent_dir, "retrained", + self._model_name.replace(" ", "_")) self._model_manager = ModelManager(type(model_service), model_service._config) self._max_length = model_service.model.config.max_position_embeddings os.makedirs(self._retrained_models_dir, exist_ok=True) @@ -309,7 +306,7 @@ def run( logger.error("Cannot import the GRPO Trainer. Please install it with `pip install cms[vllm]`.") raise ExtraDependencyRequiredException("Cannot import the GRPO Trainer. Please install it with `pip install cms[vllm]`.") - trained_model_pack_path = None + copied_model_pack_path = None redeploy = self._config.REDEPLOY_TRAINED_MODEL == "true" skip_save_model = self._config.SKIP_SAVE_MODEL == "true" results_path = os.path.abspath(os.path.join(self._config.TRAINING_CACHE_DIR, "results")) @@ -322,16 +319,15 @@ def run( if not eval_mode: try: - logger.info("Loading a PEFT model for training...") - model_pack_file_ext = get_model_data_package_extension(self._model_pack_path) - trained_model_pack_path = self._model_pack_path.replace( - model_pack_file_ext, - f"_trained_{run_id}{model_pack_file_ext}", + logger.info("Loading a new model copy for training...") + copied_model_pack_path = self._make_model_file_copy(self._model_pack_path, run_id) + model, tokenizer = self._model_service.load_model( + copied_model_pack_path, + load_in_4bit=True, # for memory efficient training ) - model, tokenizer = self._model_service.model, self._model_service.tokenizer - trained_model_directory = os.path.join( - os.path.dirname(trained_model_pack_path), - get_model_data_package_base_name(trained_model_pack_path), + copied_model_directory = os.path.join( + os.path.dirname(copied_model_pack_path), + get_model_data_package_base_name(copied_model_pack_path), ) if non_default_device_is_available(self._config.DEVICE): @@ -359,7 +355,7 @@ def run( ], ) - peft_model = get_peft_model(model, lora_config) + model = get_peft_model(model, lora_config) mlflow_logging_callback = MLflowLoggingCallback(self._tracker_client) cancel_event_check_callback = CancelEventCheckCallback(self._cancel_event) @@ -382,7 +378,6 @@ def run( training_args = GRPOConfig( output_dir=results_path, logging_dir=logs_path, - logging_steps=log_frequency, learning_rate=5e-6, adam_beta1=0.9, adam_beta2=0.99, @@ -390,18 +385,20 @@ def run( warmup_ratio=0.1, lr_scheduler_type="cosine", optim="paged_adamw_8bit", + logging_steps=1, per_device_train_batch_size=6, # This global batch size must be divisible by the number of generations gradient_accumulation_steps=1, num_generations=6, max_prompt_length=max_prompt_length, max_completion_length=max_seq_length - max_prompt_length, num_train_epochs = training_params["nepochs"], + max_steps=250, save_steps=250, max_grad_norm=0.1, report_to="none", ) trainer = GRPOTrainer( - model=peft_model, + model=model, processing_class=tokenizer, reward_funcs=self._get_reward_functions(), args=training_args, @@ -412,7 +409,7 @@ def run( else: raise ConfigurationException(f"Unsupported trainer type: {trainer_type}") - self._tracker_client.log_model_config({**model.config.to_dict(), **peft_model.peft_config}) + self._tracker_client.log_model_config(model.config.to_dict()) self._tracker_client.log_trainer_version(TrainerBackend.TRANSFORMERS, transformers_version) logger.info(f"Performing {trainer_type.upper()} training...") @@ -425,13 +422,11 @@ def run( model_pack_file_ext = get_model_data_package_extension(self._config.BASE_MODEL_FILE) model_pack_file_name = f"{ModelType.HUGGINGFACE_LLM.value}_{run_id}{model_pack_file_ext}" retrained_model_pack_path = os.path.join(self._retrained_models_dir, model_pack_file_name) - model = peft_model.merge_and_unload() model.save_pretrained( - trained_model_directory, + copied_model_directory, safe_serialization=(self._config.TRAINING_SAFE_MODEL_SERIALISATION == "true"), ) - tokenizer.save_pretrained(trained_model_directory) - create_model_data_package(trained_model_directory, retrained_model_pack_path) + create_model_data_package(copied_model_directory, retrained_model_pack_path) model_uri = self._tracker_client.save_model( retrained_model_pack_path, self._model_name, @@ -480,7 +475,7 @@ def run( with self._training_lock: self._training_in_progress = False self._clean_up_training_cache() - self._housekeep_file(trained_model_pack_path) + self._housekeep_file(copied_model_pack_path) if trainer is not None: del trainer gc.collect() @@ -510,7 +505,6 @@ def run( training_args = GRPOConfig( output_dir=results_path, logging_dir=logs_path, - logging_steps=log_frequency, per_device_eval_batch_size=6, num_generations=2, max_prompt_length=max_prompt_length, @@ -613,19 +607,19 @@ def correctness_reward_func( ) return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)] - def int_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]: + def int_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]: responses = [completion[0]["content"] for completion in completions] extracted_responses = [extract_xml_answer(r) for r in responses] return [0.5 if r.isdigit() else 0.0 for r in extracted_responses] - def strict_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]: + def strict_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]: """Reward function that checks if the completion has a specific format.""" pattern = r"^\n.*?\n\n\n.*?\n\n$" responses = [completion[0]["content"] for completion in completions] matches = [re.match(pattern, r) for r in responses] return [0.5 if match else 0.0 for match in matches] - def soft_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]: + def soft_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]: """Reward function that checks if the completion has a specific format.""" pattern = r".*?\s*.*?" responses = [completion[0]["content"] for completion in completions] @@ -646,7 +640,7 @@ def count_xml(text: str) -> float: count -= (len(text.split("\n")[-1]) - 1) * 0.001 return count - def xmlcount_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]: + def xmlcount_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]: contents = [completion[0]["content"] for completion in completions] return [count_xml(c) for c in contents] diff --git a/app/utils.py b/app/utils.py index 47faefa..370c4e2 100644 --- a/app/utils.py +++ b/app/utils.py @@ -547,6 +547,11 @@ def unpack_model_data_package(model_data_file_path: str, model_data_folder_path: elif model_data_file_path.endswith(".tar.gz"): with tarfile.open(model_data_file_path, "r:gz") as f: for member in f.getmembers(): + path_parts = member.name.split(os.sep) + stripped_path = os.sep.join(path_parts[1:]) + if not stripped_path: + continue + member.name = stripped_path f.extract(member, path=model_data_folder_path) return True else: From e4966c223c606a2159ac6503927606dbcc9d2f8f Mon Sep 17 00:00:00 2001 From: Xi Bai <82581439+baixiac@users.noreply.github.com> Date: Wed, 17 Sep 2025 15:52:05 +0100 Subject: [PATCH 2/9] Revert "feat: add the option to include rewards metrics" This reverts commit d1ff2fb180178e16e73bf7be06bcbe8000d4e44d. --- app/trainers/huggingface_llm_trainer.py | 91 +------------------------ 1 file changed, 3 insertions(+), 88 deletions(-) diff --git a/app/trainers/huggingface_llm_trainer.py b/app/trainers/huggingface_llm_trainer.py index 895a58e..b41bdd1 100644 --- a/app/trainers/huggingface_llm_trainer.py +++ b/app/trainers/huggingface_llm_trainer.py @@ -7,7 +7,6 @@ import re import threading import json -import inspect import pandas as pd from typing import final, Dict, TextIO, Optional, Any, List, Tuple, TYPE_CHECKING, Callable from transformers import __version__ as transformers_version @@ -483,7 +482,6 @@ def run( else: try: logger.info("Evaluating the running model...") - include_rewards_metrics = training_params.get("include_rewards_metrics", False) model, tokenizer = self._model_service.model, self._model_service.tokenizer if non_default_device_is_available(self._config.DEVICE): model.to(self._config.DEVICE) @@ -530,23 +528,8 @@ def run( ) eval_metrics = trainer.evaluate() - if "perplexity" not in eval_metrics and "eval_loss" in eval_metrics: - eval_metrics.update({"perplexity": math.exp(eval_metrics["eval_loss"])}) logger.info(f"Evaluation metrics: {eval_metrics}") self._tracker_client.send_hf_metrics_logs(eval_metrics, 0) - if include_rewards_metrics: - try: - reward_metrics = self._evaluate_with_rewards( - model=model, - tokenizer=tokenizer, - eval_dataset=eval_dataset, - max_new_tokens=training_args.max_completion_length, - ) - if reward_metrics: - logger.info(f"Reward metrics: {reward_metrics}") - self._tracker_client.send_hf_metrics_logs(reward_metrics, 0) - except Exception as e: - logger.warning(f"Failed to compute reward-based metrics: {e}") self._tracker_client.end_with_success() logger.info("Model evaluation finished") except torch.OutOfMemoryError as e: @@ -594,8 +577,8 @@ def correctness_reward_func( answer: List, **kwargs: Dict[str, Any] ) -> List[float]: - responses = [completion[0]["content"] for completion in completions] - q = prompts[0][-1]["content"] + responses = [completion[0]['content'] for completion in completions] + q = prompts[0][-1]['content'] extracted_responses = [extract_xml_answer(r) for r in responses] logger.debug( "%s\nQuestion:\n%s\nAnswer:\n%s\nResponse:\n%s\nExtracted:\n%s", @@ -608,7 +591,7 @@ def correctness_reward_func( return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)] def int_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]: - responses = [completion[0]["content"] for completion in completions] + responses = [completion[0]['content'] for completion in completions] extracted_responses = [extract_xml_answer(r) for r in responses] return [0.5 if r.isdigit() else 0.0 for r in extracted_responses] @@ -652,74 +635,6 @@ def xmlcount_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> l correctness_reward_func, ] - def _evaluate_with_rewards( - self, - model: PreTrainedModel, - tokenizer: PreTrainedTokenizerBase, - eval_dataset: datasets.Dataset, - max_new_tokens: int, - ) -> Dict[str, float]: - model.eval() - if non_default_device_is_available(self._config.DEVICE): - model.to(self._config.DEVICE) - - reward_funcs = self._get_reward_functions() - reward_sums: Dict[str, float] = {fn.__name__: 0.0 for fn in reward_funcs} - count = 0 - - for example in eval_dataset: - if "prompt" not in example: - continue - messages = example["prompt"] - answer = example.get("answer", "") - - prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - inputs = tokenizer(prompt_text, return_tensors="pt") - input_ids = inputs["input_ids"] - attention_mask = inputs.get("attention_mask") - if non_default_device_is_available(self._config.DEVICE): - input_ids = input_ids.to(self._config.DEVICE) - attention_mask = attention_mask.to(self._config.DEVICE) - - with torch.no_grad(): - generated = model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - max_new_tokens=max_new_tokens, - do_sample=False, - temperature=0.0, - eos_token_id=getattr(tokenizer, "eos_token_id", None), - pad_token_id=getattr(tokenizer, "pad_token_id", 0), - ) - - completion_text = tokenizer.decode(generated[0][input_ids.shape[1]:], skip_special_tokens=True) - for fn in reward_funcs: - sig = inspect.signature(fn) - kwargs: Dict[str, Any] = {} - if "prompts" in sig.parameters: - kwargs["prompts"] = [messages] - if "completions" in sig.parameters: - kwargs["completions"] = [({"content": completion_text},)] - if "answer" in sig.parameters: - kwargs["answer"] = [answer] - - try: - rewards = fn(**kwargs) # type: ignore - value = float(rewards[0]) if isinstance(rewards, (list, tuple)) and rewards else float(rewards) - except Exception: - value = 0.0 - - reward_sums[fn.__name__] += value - count += 1 - if count == 0: - return {} - - reward_avgs = {f"reward_{name}": total / count for name, total in reward_sums.items()} - reward_overall_mean = sum(reward_avgs.values()) / len(reward_avgs) if reward_avgs else 0.0 - reward_avgs["reward_overall_mean"] = reward_overall_mean - reward_avgs["reward_samples"] = float(count) - return reward_avgs - @final class MLflowLoggingCallback(TrainerCallback): From ed1e0fa166339254996b880250006b2e5d25e6b9 Mon Sep 17 00:00:00 2001 From: Xi Bai <82581439+baixiac@users.noreply.github.com> Date: Wed, 17 Sep 2025 15:52:05 +0100 Subject: [PATCH 3/9] Revert "fix: use the GRPO trainer for evaluation" This reverts commit 5075fa3cd1f2cc32ab4dbbb3ce2ff22e4b1d056e. --- app/api/utils.py | 30 +-- app/exception.py | 4 - app/model_services/huggingface_llm_model.py | 3 +- app/trainers/huggingface_llm_trainer.py | 261 +++++++++----------- 4 files changed, 116 insertions(+), 182 deletions(-) diff --git a/app/api/utils.py b/app/api/utils.py index 87cea26..e73425d 100644 --- a/app/api/utils.py +++ b/app/api/utils.py @@ -27,13 +27,7 @@ from fastapi_users.jwt import decode_jwt from app.config import Settings from app.domain import TagsGenerative -from app.exception import ( - StartTrainingException, - AnnotationException, - ConfigurationException, - ClientException, - ExtraDependencyRequiredException, -) +from app.exception import StartTrainingException, AnnotationException, ConfigurationException, ClientException logger = logging.getLogger("cms") @@ -124,24 +118,6 @@ async def configuration_exception_handler(_: Request, exception: ConfigurationEx logger.exception(exception) return JSONResponse(status_code=HTTP_500_INTERNAL_SERVER_ERROR, content={"message": str(exception)}) - @app.exception_handler(ExtraDependencyRequiredException) - async def extra_dependency_exception_handler( - _: Request, - exception: ExtraDependencyRequiredException - ) -> JSONResponse: - """ - Handles extra dependency required exceptions. - - Args: - _ (Request): The request object. - exception (ExtraDependencyRequiredException): The extra dependency required exception. - - Returns: - JSONResponse: A JSON response with a 500 status code and an error message. - """ - logger.exception(exception) - return JSONResponse(status_code=HTTP_500_INTERNAL_SERVER_ERROR, content={"message": str(exception)}) - @app.exception_handler(ClientException) async def client_exception_handler(_: Request, exception: ClientException) -> JSONResponse: """ @@ -323,8 +299,8 @@ async def init_vllm_engine(app: FastAPI, ) from vllm import SamplingParams, TokensPrompt except ImportError: - logger.error("Cannot import the vLLM engine. Please install it with `pip install cms[vllm]`.") - raise ExtraDependencyRequiredException("Cannot import the vLLM engine. Please install it with `pip install cms[vllm]`.") + # Raise a custom exception if vLLM is not installed + raise ConfigurationException("Cannot import the vLLM engine. Please install it with `pip install vllm`.") parser = FlexibleArgumentParser() parser = make_arg_parser(parser) diff --git a/app/exception.py b/app/exception.py index ddba71b..99b87a6 100644 --- a/app/exception.py +++ b/app/exception.py @@ -32,7 +32,3 @@ class DatasetException(Exception): class DeviceNotAvailableError(RuntimeError): """An exception raised when a specificy device is required but not available.""" - - -class ExtraDependencyRequiredException(Exception): - """An exception raised when an extra dependency is required but not found.""" diff --git a/app/model_services/huggingface_llm_model.py b/app/model_services/huggingface_llm_model.py index 98f9bea..13838d7 100644 --- a/app/model_services/huggingface_llm_model.py +++ b/app/model_services/huggingface_llm_model.py @@ -369,9 +369,8 @@ def create_embeddings( sum_hidden_states = masked_hidden_states.sum(dim=1) num_tokens = attention_mask.sum(dim=1, keepdim=True) embeddings = sum_hidden_states / num_tokens - l2_normalised = torch.nn.functional.normalize(embeddings, p=2, dim=1) - results = l2_normalised.cpu().numpy().tolist() + results = embeddings.cpu().numpy().tolist() return results[0] if isinstance(text, str) else results def train_supervised( diff --git a/app/trainers/huggingface_llm_trainer.py b/app/trainers/huggingface_llm_trainer.py index b41bdd1..49082d5 100644 --- a/app/trainers/huggingface_llm_trainer.py +++ b/app/trainers/huggingface_llm_trainer.py @@ -18,6 +18,7 @@ TrainerCallback, TrainerState, TrainerControl, + Trainer ) from peft import LoraConfig, get_peft_model # type: ignore from app.management.model_manager import ModelManager @@ -40,7 +41,6 @@ DatasetException, ConfigurationException, DeviceNotAvailableError, - ExtraDependencyRequiredException, ) if TYPE_CHECKING: from app.model_services.huggingface_llm_model import HuggingFaceLlmModel @@ -93,6 +93,25 @@ def __init__(self, model_service: "HuggingFaceLlmModel") -> None: self._max_length = model_service.model.config.max_position_embeddings os.makedirs(self._retrained_models_dir, exist_ok=True) + class _LocalDataCollator: + + def __init__(self, max_length: int, pad_token_id: int) -> None: + self.max_length = max_length + self.pad_token_id = pad_token_id + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + return { + "input_ids": torch.tensor([self._add_padding(f["input_ids"], self.max_length, self.pad_token_id) for f in features], dtype=torch.long), + "labels": torch.tensor([self._add_padding(f["labels"], self.max_length, HuggingFaceLlmSupervisedTrainer.PAD_LABEL_ID) for f in features], dtype=torch.long), + "attention_mask": torch.tensor([self._add_padding(f["attention_mask"], self.max_length, 0) for f in features], dtype=torch.long), + } + + @staticmethod + def _add_padding(target: List[int], max_length: int, pad_token_id: int) -> List[int]: + padding_length = max(0, max_length - len(target)) + paddings = [pad_token_id] * padding_length + return target + paddings + def _load_dataset_from_config(self, data_file: TextIO, training_params: Dict) -> Tuple[datasets.Dataset, datasets.Dataset]: """ Loads training and validation datasets based on configuration in training_params. @@ -137,6 +156,8 @@ def _set_dataset_format(train_dataset: datasets.Dataset, test_dataset: datasets. else: raise DatasetException("Unsupported dataset format") + + def _load_huggingface_dataset(self, training_params: Dict) -> Tuple[datasets.Dataset, datasets.Dataset]: """Loads dataset from HuggingFace Hub.""" @@ -299,12 +320,6 @@ def run( if self._config.DEVICE is not Device.GPU.value: raise DeviceNotAvailableError("This trainer currently requires a CUDA device") - try: - from trl import GRPOConfig, GRPOTrainer # , PPOConfig, PPOTrainer - except ImportError: - logger.error("Cannot import the GRPO Trainer. Please install it with `pip install cms[vllm]`.") - raise ExtraDependencyRequiredException("Cannot import the GRPO Trainer. Please install it with `pip install cms[vllm]`.") - copied_model_pack_path = None redeploy = self._config.REDEPLOY_TRAINED_MODEL == "true" skip_save_model = self._config.SKIP_SAVE_MODEL == "true" @@ -313,9 +328,6 @@ def run( reset_random_seed() eval_mode = training_params["nepochs"] == 0 self._tracker_client.log_trainer_mode(not eval_mode) - trainer = None - max_seq_length = 1024 - if not eval_mode: try: logger.info("Loading a new model copy for training...") @@ -356,27 +368,79 @@ def run( model = get_peft_model(model, lora_config) + def extract_xml_answer(text: str) -> str: + answer = text.split("")[-1] + answer = answer.split("")[0] + return answer.strip() + + # Reward functions + def correctness_reward_func( + prompts: List, + completions: List, + answer: List, + **kwargs: Dict[str, Any] + ) -> List[float]: + responses = [completion[0]['content'] for completion in completions] + q = prompts[0][-1]['content'] + extracted_responses = [extract_xml_answer(r) for r in responses] + print('-' * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", + f"\nExtracted:\n{extracted_responses[0]}") + return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)] + + def int_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]: + responses = [completion[0]['content'] for completion in completions] + extracted_responses = [extract_xml_answer(r) for r in responses] + return [0.5 if r.isdigit() else 0.0 for r in extracted_responses] + + def strict_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]: + """Reward function that checks if the completion has a specific format.""" + pattern = r"^\n.*?\n\n\n.*?\n\n$" + responses = [completion[0]["content"] for completion in completions] + matches = [re.match(pattern, r) for r in responses] + return [0.5 if match else 0.0 for match in matches] + + def soft_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]: + """Reward function that checks if the completion has a specific format.""" + pattern = r".*?\s*.*?" + responses = [completion[0]["content"] for completion in completions] + matches = [re.match(pattern, r) for r in responses] + return [0.5 if match else 0.0 for match in matches] + + def count_xml(text: str) -> float: + count = 0.0 + if text.count("\n") == 1: + count += 0.125 + if text.count("\n\n") == 1: + count += 0.125 + if text.count("\n\n") == 1: + count += 0.125 + count -= len(text.split("\n\n")[-1]) * 0.001 + if text.count("\n") == 1: + count += 0.125 + count -= (len(text.split("\n")[-1]) - 1) * 0.001 + return count + + def xmlcount_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]: + contents = [completion[0]["content"] for completion in completions] + return [count_xml(c) for c in contents] + mlflow_logging_callback = MLflowLoggingCallback(self._tracker_client) cancel_event_check_callback = CancelEventCheckCallback(self._cancel_event) trainer_callbacks = [mlflow_logging_callback, cancel_event_check_callback] + max_prompt_length = 256 + max_seq_length = 1024 + + try: + from trl import GRPOConfig, GRPOTrainer #, PPOConfig, PPOTrainer + except ImportError: + logger.error("Cannot import the GRPO Trainer. Please install it with `pip install cms[vllm]`.") + trainer_type = training_params.get("trainer_type", LlmTrainerType.GRPO.value).lower() - max_prompt_length = max(train_dataset.map( - lambda x: { - "tokens": tokenizer.apply_chat_template( - x["prompt"], - add_generation_prompt=True, - tokenize=True - ) - }, - batched=True, - ).map(lambda x: {"length": len(x["tokens"])})["length"]) + 1 if trainer_type == LlmTrainerType.PPO.value: raise NotImplementedError("PPO training is not yet supported for HuggingFace LLM models") elif trainer_type == LlmTrainerType.GRPO.value: training_args = GRPOConfig( - output_dir=results_path, - logging_dir=logs_path, learning_rate=5e-6, adam_beta1=0.9, adam_beta2=0.99, @@ -395,11 +459,18 @@ def run( save_steps=250, max_grad_norm=0.1, report_to="none", + output_dir="outputs", ) trainer = GRPOTrainer( model=model, processing_class=tokenizer, - reward_funcs=self._get_reward_functions(), + reward_funcs=[ + xmlcount_reward_func, + soft_format_reward_func, + strict_format_reward_func, + int_reward_func, + correctness_reward_func, + ], args=training_args, train_dataset=train_dataset, eval_dataset=test_dataset, @@ -475,56 +546,41 @@ def run( self._training_in_progress = False self._clean_up_training_cache() self._housekeep_file(copied_model_pack_path) - if trainer is not None: - del trainer - gc.collect() - torch.cuda.empty_cache() + del trainer + gc.collect() + torch.cuda.empty_cache() else: try: logger.info("Evaluating the running model...") - model, tokenizer = self._model_service.model, self._model_service.tokenizer + model, tokenizer = self._model_service.load_model(self._model_pack_path) if non_default_device_is_available(self._config.DEVICE): model.to(self._config.DEVICE) eval_dataset, _ = self._load_dataset_from_config(data_file, training_params) make_conversation = self._create_conversation_formatter(training_params) eval_dataset = eval_dataset.map(make_conversation) - max_prompt_length = max(eval_dataset.map( - lambda x: { - "tokens": tokenizer.apply_chat_template( - x["prompt"], - add_generation_prompt=True, - tokenize=True - ) - }, - batched=True, - ).map(lambda x: {"length": len(x["tokens"])})["length"]) + 1 - - training_args = GRPOConfig( + + data_collator = self._LocalDataCollator( + max_length=self._max_length, + pad_token_id=tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else 0 + ) + + training_args = TrainingArguments( output_dir=results_path, logging_dir=logs_path, - per_device_eval_batch_size=6, - num_generations=2, - max_prompt_length=max_prompt_length, - max_completion_length=max_seq_length - max_prompt_length, - num_train_epochs=training_params["nepochs"], - report_to="none", + per_device_eval_batch_size=1, do_train=False, do_eval=True, + report_to="none", + dataloader_drop_last=False, ) - mlflow_logging_callback = MLflowLoggingCallback(self._tracker_client) - cancel_event_check_callback = CancelEventCheckCallback(self._cancel_event) - trainer_callbacks = [mlflow_logging_callback, cancel_event_check_callback] - - trainer = GRPOTrainer( + trainer = Trainer( model=model, - processing_class=tokenizer, args=training_args, - reward_funcs=self._get_reward_functions(), - train_dataset=None, + data_collator=data_collator, eval_dataset=eval_dataset, - callbacks=trainer_callbacks, + tokenizer=tokenizer, ) eval_metrics = trainer.evaluate() @@ -532,24 +588,8 @@ def run( self._tracker_client.send_hf_metrics_logs(eval_metrics, 0) self._tracker_client.end_with_success() logger.info("Model evaluation finished") - except torch.OutOfMemoryError as e: - logger.exception("Evaluation failed on CUDA OOM") - try: - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - try: - torch.cuda.reset_peak_memory_stats() - torch.cuda.reset_accumulated_memory_stats() - except Exception: - pass - torch.cuda.synchronize() - except Exception: - pass - self._tracker_client.log_exceptions(e) - self._tracker_client.end_with_failure() except Exception as e: - logger.exception("Evaluation failed") + logger.exception("Model evaluation failed") self._tracker_client.log_exceptions(e) self._tracker_client.end_with_failure() finally: @@ -557,83 +597,6 @@ def run( with self._training_lock: self._training_in_progress = False self._clean_up_training_cache() - if trainer is not None: - del trainer - gc.collect() - torch.cuda.empty_cache() - - @staticmethod - def _get_reward_functions() -> List: - - def extract_xml_answer(text: str) -> str: - answer = text.split("")[-1] - answer = answer.split("")[0] - return answer.strip() - - # Reward functions - def correctness_reward_func( - prompts: List, - completions: List, - answer: List, - **kwargs: Dict[str, Any] - ) -> List[float]: - responses = [completion[0]['content'] for completion in completions] - q = prompts[0][-1]['content'] - extracted_responses = [extract_xml_answer(r) for r in responses] - logger.debug( - "%s\nQuestion:\n%s\nAnswer:\n%s\nResponse:\n%s\nExtracted:\n%s", - "-" * 20, - q, - answer[0], - responses[0], - extracted_responses[0] - ) - return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)] - - def int_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]: - responses = [completion[0]['content'] for completion in completions] - extracted_responses = [extract_xml_answer(r) for r in responses] - return [0.5 if r.isdigit() else 0.0 for r in extracted_responses] - - def strict_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]: - """Reward function that checks if the completion has a specific format.""" - pattern = r"^\n.*?\n\n\n.*?\n\n$" - responses = [completion[0]["content"] for completion in completions] - matches = [re.match(pattern, r) for r in responses] - return [0.5 if match else 0.0 for match in matches] - - def soft_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]: - """Reward function that checks if the completion has a specific format.""" - pattern = r".*?\s*.*?" - responses = [completion[0]["content"] for completion in completions] - matches = [re.match(pattern, r) for r in responses] - return [0.5 if match else 0.0 for match in matches] - - def count_xml(text: str) -> float: - count = 0.0 - if text.count("\n") == 1: - count += 0.125 - if text.count("\n\n") == 1: - count += 0.125 - if text.count("\n\n") == 1: - count += 0.125 - count -= len(text.split("\n\n")[-1]) * 0.001 - if text.count("\n") == 1: - count += 0.125 - count -= (len(text.split("\n")[-1]) - 1) * 0.001 - return count - - def xmlcount_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]: - contents = [completion[0]["content"] for completion in completions] - return [count_xml(c) for c in contents] - - return [ - xmlcount_reward_func, - soft_format_reward_func, - strict_format_reward_func, - int_reward_func, - correctness_reward_func, - ] @final From 69da1e3c2efab802a6a9c44117527a76cdbe0df8 Mon Sep 17 00:00:00 2001 From: Xi Bai <82581439+baixiac@users.noreply.github.com> Date: Wed, 17 Sep 2025 15:52:05 +0100 Subject: [PATCH 4/9] Revert "feat: add the trainer for HF LLMs" This reverts commit 994f88d55233a733274b34f2952538243811ed5f. --- app/api/api.py | 24 +- app/api/routers/generative.py | 5 +- app/api/routers/supervised_training.py | 24 +- app/api/utils.py | 4 +- app/domain.py | 19 - app/exception.py | 6 +- app/model_services/huggingface_llm_model.py | 71 +- app/processors/metrics_collector.py | 30 - app/trainers/huggingface_llm_trainer.py | 674 ------------------ app/utils.py | 50 -- pyproject.toml | 4 - tests/app/api/test_api.py | 4 +- .../app/processors/test_metrics_collector.py | 54 -- uv.lock | 177 +---- 14 files changed, 26 insertions(+), 1120 deletions(-) delete mode 100644 app/trainers/huggingface_llm_trainer.py diff --git a/app/api/api.py b/app/api/api.py index 5dd1522..b9874ab 100644 --- a/app/api/api.py +++ b/app/api/api.py @@ -4,7 +4,7 @@ import os.path import app.api.globals as cms_globals -from typing import Dict, Any, Optional, Union, Type +from typing import Dict, Any, Optional from concurrent.futures import ThreadPoolExecutor from anyio.lowlevel import RunVar from anyio import CapacityLimiter @@ -20,7 +20,7 @@ from app.api.dependencies import ModelServiceDep from app.api.utils import add_exception_handlers, add_rate_limiter, init_vllm_engine from app.config import Settings -from app.domain import Tags, TagsStreamable, TagsGenerative +from app.domain import Tags, TagsStreamable from app.management.tracker_client import TrackerClient from app.utils import get_settings, unpack_model_data_package, get_model_data_package_base_name from app.exception import ConfigurationException @@ -131,11 +131,6 @@ def get_generative_server(config: Settings, msd_overwritten: Optional[ModelServi app = _load_health_check_router(app) logger.debug("Health check router loaded") - if config.ENABLE_TRAINING_APIS == "true": - app = _load_supervised_training_router(app) - logger.debug("Supervised training router loaded") - app = _load_training_operations(app) - if config.AUTH_USER_ENABLED == "true": app = _load_auth_router(app) logger.debug("Auth router loaded") @@ -203,18 +198,11 @@ def _get_app( streamable: bool = False, generative: bool = False, ) -> FastAPI: - config = get_settings() - tags: Union[Type[Tags], Type[TagsStreamable], Type[TagsGenerative]] - if generative: - tags = TagsGenerative - elif streamable: - tags = TagsStreamable - else: - tags = Tags tags_metadata = [{ # type: ignore - "name": tag.name, # type: ignore - "description": tag.value # type: ignore - } for tag in tags] + "name": tag.name, + "description": tag.value + } for tag in (Tags if not streamable else TagsStreamable)] + config = get_settings() app = FastAPI( title="CogStack ModelServe", summary="A model serving and governance system for CogStack NLP solutions", diff --git a/app/api/routers/generative.py b/app/api/routers/generative.py index b2d9454..26f4fd1 100644 --- a/app/api/routers/generative.py +++ b/app/api/routers/generative.py @@ -13,7 +13,6 @@ from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR from app.domain import ( Tags, - TagsGenerative, OpenAIChatRequest, OpenAIChatResponse, OpenAIEmbeddingsRequest, @@ -42,7 +41,7 @@ @router.post( PATH_GENERATE, - tags=[TagsGenerative.Generative], + tags=[Tags.Generative.name], response_class=PlainTextResponse, dependencies=[Depends(cms_globals.props.current_active_user)], description="Generate text", @@ -92,7 +91,7 @@ def generate_text( @router.post( PATH_GENERATE_ASYNC, - tags=[TagsGenerative.Generative], + tags=[Tags.Generative.name], response_class=StreamingResponse, dependencies=[Depends(cms_globals.props.current_active_user)], description="Generate a stream of texts", diff --git a/app/api/routers/supervised_training.py b/app/api/routers/supervised_training.py index 1ee7c9b..89d0d17 100644 --- a/app/api/routers/supervised_training.py +++ b/app/api/routers/supervised_training.py @@ -12,9 +12,9 @@ import app.api.globals as cms_globals from app.api.dependencies import validate_tracking_id -from app.domain import Tags, ModelType +from app.domain import Tags from app.model_services.base import AbstractModelService -from app.processors.metrics_collector import concat_json_lists, concat_trainer_exports +from app.processors.metrics_collector import concat_trainer_exports from app.utils import filter_by_concept_ids router = APIRouter() @@ -72,19 +72,12 @@ async def train_supervised( files.append(temp_te) file_names.append("" if te.filename is None else te.filename) - if model_service.info().model_type is not ModelType.HUGGINGFACE_LLM: - concatenated_te = concat_trainer_exports([file.name for file in files], allow_recurring_doc_ids=False) - logger.debug("Training exports concatenated") - data_file = tempfile.NamedTemporaryFile(mode="w+") - concatenated_te = filter_by_concept_ids(cast(Dict[str, Any], concatenated_te), model_service.info().model_type) - logger.debug("Training exports filtered by concept IDs") - json.dump(concatenated_te, data_file) - else: - concatenated = concat_json_lists([file.name for file in files]) - logger.debug("Training exports concatenated") - data_file = tempfile.NamedTemporaryFile(mode="w+") - json.dump(concatenated, data_file) - + concatenated = concat_trainer_exports([file.name for file in files], allow_recurring_doc_ids=False) + logger.debug("Training exports concatenated") + data_file = tempfile.NamedTemporaryFile(mode="w") + concatenated = filter_by_concept_ids(cast(Dict[str, Any], concatenated), model_service.info().model_type) + logger.debug("Training exports filtered by concept IDs") + json.dump(concatenated, data_file) data_file.flush() data_file.seek(0) training_id = tracking_id or str(uuid.uuid4()) @@ -109,7 +102,6 @@ async def train_supervised( return _get_training_response(training_response, training_id) - def _get_training_response(training_response: Tuple[bool, str, str], training_id: str) -> JSONResponse: training_accepted, experiment_id, run_id = training_response if training_accepted: diff --git a/app/api/utils.py b/app/api/utils.py index e73425d..05726ba 100644 --- a/app/api/utils.py +++ b/app/api/utils.py @@ -26,7 +26,7 @@ from slowapi.errors import RateLimitExceeded from fastapi_users.jwt import decode_jwt from app.config import Settings -from app.domain import TagsGenerative +from app.domain import Tags from app.exception import StartTrainingException, AnnotationException, ConfigurationException, ClientException logger = logging.getLogger("cms") @@ -376,7 +376,7 @@ async def _stream() -> AsyncGenerator[bytes, None]: endpoint=endpoint, methods=methods, include_in_schema=True, - tags=[TagsGenerative.Generative.name], + tags=[Tags.Generative], ) app.include_router(router) diff --git a/app/domain.py b/app/domain.py index 6be1564..de098a8 100644 --- a/app/domain.py +++ b/app/domain.py @@ -31,15 +31,9 @@ class Tags(str, Enum): class TagsStreamable(str, Enum): - Metadata = "Get the model card" Streaming = "Retrieve NER entities as a stream by running the model" -class TagsGenerative(str, Enum): - Metadata = "Get the model card" - Generative = "Generate text based on the input prompt" - - class CodeType(str, Enum): SNOMED = "SNOMED" UMLS = "UMLS" @@ -110,19 +104,6 @@ class LlmEngine(Enum): CMS = "CMS" VLLM = "vLLM" -class LlmRole(Enum): - SYSTEM = "system" - USER = "user" - ASSISTANT = "assistant" - TOOL = "tool" - -class LlmTrainerType(Enum): - GRPO = "grpo" - PPO = "ppo" - -class LlmDatasetType(Enum): - JSON = "json" - CSV = "csv" class Annotation(BaseModel): doc_name: Optional[str] = Field(default=None, description="The name of the document to which the annotation belongs") diff --git a/app/exception.py b/app/exception.py index 99b87a6..1b8f9bc 100644 --- a/app/exception.py +++ b/app/exception.py @@ -27,8 +27,4 @@ class ClientException(Exception): class DatasetException(Exception): - """An exception raised due to dataset errors""" - - -class DeviceNotAvailableError(RuntimeError): - """An exception raised when a specificy device is required but not available.""" + """ An exception raised due to dataset errors""" diff --git a/app/model_services/huggingface_llm_model.py b/app/model_services/huggingface_llm_model.py index 13838d7..25cd032 100644 --- a/app/model_services/huggingface_llm_model.py +++ b/app/model_services/huggingface_llm_model.py @@ -3,19 +3,17 @@ import asyncio import torch from concurrent.futures import ThreadPoolExecutor -from typing import Dict, List, Optional, Tuple, Any, AsyncIterable, TextIO, Callable, Union +from typing import Dict, List, Optional, Tuple, Any, AsyncIterable, Callable, Union from transformers import ( AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase, TextIteratorStreamer, - BitsAndBytesConfig, ) from app import __version__ as app_version from app.exception import ConfigurationException from app.model_services.base import AbstractModelService -from app.trainers.huggingface_llm_trainer import HuggingFaceLlmSupervisedTrainer from app.domain import ModelCard, ModelType, Annotation from app.config import Settings from app.utils import ( @@ -125,19 +123,13 @@ def from_model(cls, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase) return model_service @staticmethod - def load_model( - model_file_path: str, - *args: Tuple, - load_in_4bit: bool = False, - **kwargs: Dict[str, Any] - ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]: + def load_model(model_file_path: str, *args: Tuple, **kwargs: Dict[str, Any]) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]: """ Loads a pre-trained model and its tokenizer from a model package file. Args: model_file_path (str): The path to the model package file. *args (Tuple): Additional positional arguments. - load_in_4bit (bool): Whether to load the model in 4-bit precision. Defaults to False. **kwargs (Dict[str, Any]): Additional keyword arguments. Returns: @@ -150,16 +142,7 @@ def load_model( model_path = os.path.join(os.path.dirname(model_file_path), get_model_data_package_base_name(model_file_path)) if unpack_model_data_package(model_file_path, model_path): try: - if load_in_4bit: - bnb_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype=torch.bfloat16, - bnb_4bit_use_double_quant=True, - ) - model = AutoModelForCausalLM.from_pretrained(model_path, quantization_config=bnb_config) - else: - model = AutoModelForCausalLM.from_pretrained(model_path) + model = AutoModelForCausalLM.from_pretrained(model_path) ensure_tensor_contiguity(model) tokenizer = AutoTokenizer.from_pretrained( model_path, @@ -189,7 +172,7 @@ def init_model(self) -> None: if non_default_device_is_available(get_settings().DEVICE): self._model.to(get_settings().DEVICE) if self._enable_trainer: - self._supervised_trainer = HuggingFaceLlmSupervisedTrainer(self) + logger.error("Trainers are not yet implemented for HuggingFace Generative models") def info(self) -> ModelCard: """ @@ -372,49 +355,3 @@ def create_embeddings( results = embeddings.cpu().numpy().tolist() return results[0] if isinstance(text, str) else results - - def train_supervised( - self, - data_file: TextIO, - epochs: int, - log_frequency: int, - training_id: str, - input_file_name: str, - raw_data_files: Optional[List[TextIO]] = None, - description: Optional[str] = None, - synchronised: bool = False, - **hyperparams: Dict[str, Any], - ) -> Tuple[bool, str, str]: - """ - Initiates supervised training on the model. - - Args: - data_file (TextIO): The file containing the trainer export data. - epochs (int): The number of training epochs. - log_frequency (int): The number of epochs after which training metrics will be logged. - training_id (str): A unique identifier for the training process. - input_file_name (str): The name of the input file to be logged. - raw_data_files (Optional[List[TextIO]]): Additional raw data files to be logged. Defaults to None. - description (Optional[str]): The description of the training or change logs. Defaults to empty. - synchronised (bool): Whether to wait for the training to complete. - **hyperparams (Dict[str, Any]): Additional hyperparameters for training. - - Returns: - Tuple[bool, str, str]: A tuple with the first element indicating success or failure. - - Raises: - ConfigurationException: If the supervised trainer is not enabled. - """ - if self._supervised_trainer is None: - raise ConfigurationException("The supervised trainer is not enabled") - return self._supervised_trainer.train( - data_file, - epochs, - log_frequency, - training_id, - input_file_name, - raw_data_files, - description, - synchronised, - **hyperparams, - ) diff --git a/app/processors/metrics_collector.py b/app/processors/metrics_collector.py index 84f74da..07f0592 100644 --- a/app/processors/metrics_collector.py +++ b/app/processors/metrics_collector.py @@ -194,36 +194,6 @@ def concat_trainer_exports( return combined -def concat_json_lists( - data_file_paths: List[str], - combined_data_file_path: Optional[str] = None, -) -> Union[List[Dict[str, Any]], str]: - """ - Concatenates multiple json list files into a single combined file. - - Args: - data_file_paths (List[str]): List of paths to files each containing a json list. - combined_data_file_path (Optional[str]): The file path where the combined data will be saved. If None, the combined data will be returned as a list. - - - Returns: - Union[List[Dict[str, Any]], str]: The path to the combined data file if `combined_data_file_path` is provided, or the combined data as a list otherwise. - """ - combined: List = [] - for path in data_file_paths: - with open(path, "r") as f: - data = json.load(f) - combined.extend(data) - - if isinstance(combined_data_file_path, str): - with open(combined_data_file_path, "w") as f: - json.dump(combined, f) - - return combined_data_file_path - else: - return combined - - def get_stats_from_trainer_export( trainer_export: Union[str, IO, Dict], return_df: bool = False, diff --git a/app/trainers/huggingface_llm_trainer.py b/app/trainers/huggingface_llm_trainer.py deleted file mode 100644 index 49082d5..0000000 --- a/app/trainers/huggingface_llm_trainer.py +++ /dev/null @@ -1,674 +0,0 @@ -import os -import logging -import math -import torch -import gc -import datasets -import re -import threading -import json -import pandas as pd -from typing import final, Dict, TextIO, Optional, Any, List, Tuple, TYPE_CHECKING, Callable -from transformers import __version__ as transformers_version -from transformers import ( - TrainingArguments, - PreTrainedModel, - PreTrainedTokenizerBase, - PreTrainedTokenizerFast, - TrainerCallback, - TrainerState, - TrainerControl, - Trainer -) -from peft import LoraConfig, get_peft_model # type: ignore -from app.management.model_manager import ModelManager -from app.management.tracker_client import TrackerClient -from app.utils import ( - reset_random_seed, - non_default_device_is_available, - create_model_data_package, - get_model_data_package_extension, - load_pydantic_object_from_dict, - get_default_chat_template, - get_default_system_prompt, - get_model_data_package_base_name, -) -from app.trainers.base import SupervisedTrainer -from app.domain import ModelType, TrainerBackend, LlmRole, LlmTrainerType, LlmDatasetType, PromptMessage, Device -from app.exception import ( - TrainingCancelledException, - ManagedModelException, - DatasetException, - ConfigurationException, - DeviceNotAvailableError, -) -if TYPE_CHECKING: - from app.model_services.huggingface_llm_model import HuggingFaceLlmModel - -logger = logging.getLogger("cms") - - -class _HuggingFaceLlmTrainerCommon(object): - - @staticmethod - def deploy_model( - model_service: "HuggingFaceLlmModel", - model: PreTrainedModel, - tokenizer: PreTrainedTokenizerBase, - ) -> None: - del model_service.model - del model_service.tokenizer - gc.collect() - model_service.model = model - model_service.tokenizer = tokenizer - logger.info("Retrained model deployed") - - -@final -class HuggingFaceLlmSupervisedTrainer(SupervisedTrainer, _HuggingFaceLlmTrainerCommon): - """ - A supervised trainer class for HuggingFace LLM models. - - Args: - model_service (HuggingFaceLlmModel): An instance of the HuggingFace LLM model service. - """ - - MIN_EXAMPLE_COUNT_FOR_TRAINABLE_CONCEPT = 5 - MAX_CONCEPTS_TO_TRACK = 20 - PAD_LABEL_ID = -100 - DEFAULT_LABEL_ID = 0 - CONTINUING_TOKEN_LABEL_ID = 1 - - def __init__(self, model_service: "HuggingFaceLlmModel") -> None: - if not isinstance(model_service.tokenizer, PreTrainedTokenizerFast): - logger.error("The supervised trainer requires a fast tokenizer to function correctly") - raise ManagedModelException("The supervised trainer requires a fast tokenizer to function correctly") - SupervisedTrainer.__init__(self, model_service._config, model_service.model_name) - self._model_service = model_service - self._model_name = model_service.model_name - self._model_pack_path = model_service._model_pack_path - self._retrained_models_dir = os.path.join(model_service._model_parent_dir, "retrained", - self._model_name.replace(" ", "_")) - self._model_manager = ModelManager(type(model_service), model_service._config) - self._max_length = model_service.model.config.max_position_embeddings - os.makedirs(self._retrained_models_dir, exist_ok=True) - - class _LocalDataCollator: - - def __init__(self, max_length: int, pad_token_id: int) -> None: - self.max_length = max_length - self.pad_token_id = pad_token_id - - def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: - return { - "input_ids": torch.tensor([self._add_padding(f["input_ids"], self.max_length, self.pad_token_id) for f in features], dtype=torch.long), - "labels": torch.tensor([self._add_padding(f["labels"], self.max_length, HuggingFaceLlmSupervisedTrainer.PAD_LABEL_ID) for f in features], dtype=torch.long), - "attention_mask": torch.tensor([self._add_padding(f["attention_mask"], self.max_length, 0) for f in features], dtype=torch.long), - } - - @staticmethod - def _add_padding(target: List[int], max_length: int, pad_token_id: int) -> List[int]: - padding_length = max(0, max_length - len(target)) - paddings = [pad_token_id] * padding_length - return target + paddings - - def _load_dataset_from_config(self, data_file: TextIO, training_params: Dict) -> Tuple[datasets.Dataset, datasets.Dataset]: - """ - Loads training and validation datasets based on configuration in training_params. - - Args: - data_file: The training data file - training_params: Dictionary containing dataset configuration - - Returns: - Tuple of (train_dataset, validation_dataset) - """ - dataset_type = training_params.get("dataset_type", "json") - - # if dataset_type == "huggingface": - # return self._load_huggingface_dataset(training_params) - if dataset_type == LlmDatasetType.JSON.value: - return self._load_json_dataset(data_file, training_params) - elif dataset_type == LlmDatasetType.CSV.value: - return self._load_csv_dataset(data_file, training_params) - else: - raise DatasetException(f"Unsupported dataset type: {dataset_type}") - - @staticmethod - def _set_dataset_format(train_dataset: datasets.Dataset, test_dataset: datasets.Dataset) -> None: - """Sets the format of the datasets based on the dataset structure.""" - - if "messages" in train_dataset.column_names: - train_dataset.set_format(type=None, columns=["messages"]) - test_dataset.set_format(type=None, columns=["messages"]) - elif "question" in train_dataset.column_names and "answer" in train_dataset.column_names: - train_dataset.set_format(type=None, columns=["question", "answer"]) - test_dataset.set_format(type=None, columns=["question", "answer"]) - elif "input" in train_dataset.column_names and "output" in train_dataset.column_names: - train_dataset.set_format(type=None, columns=["input", "output"]) - test_dataset.set_format(type=None, columns=["input", "output"]) - elif "prompt" in train_dataset.column_names and "completion" in train_dataset.column_names: - train_dataset.set_format(type=None, columns=["prompt", "completion"]) - test_dataset.set_format(type=None, columns=["prompt", "completion"]) - elif "problem" in train_dataset.column_names and "solution" in train_dataset.column_names: - train_dataset.set_format(type=None, columns=["problem", "solution"]) - test_dataset.set_format(type=None, columns=["problem", "solution"]) - else: - raise DatasetException("Unsupported dataset format") - - - - def _load_huggingface_dataset(self, training_params: Dict) -> Tuple[datasets.Dataset, datasets.Dataset]: - """Loads dataset from HuggingFace Hub.""" - - dataset_id = training_params.get("dataset_id", "AI-MO/NuminaMath-TIR") - test_size = 0.2 if training_params.get("test_size") is None else training_params["test_size"] - split_ratio = 1 - test_size - train_percentage = int(split_ratio * 100) - test_percentage = 100 - train_percentage - train_split = training_params.get("train_split", f"train[:{train_percentage}%]") - test_split = training_params.get("test_split", f"test[:{test_percentage}%]") - - logger.info(f"Loading HuggingFace dataset: {dataset_id}") - train_dataset, test_dataset = datasets.load_dataset(dataset_id, split=[train_split, test_split]) - self._set_dataset_format(train_dataset, test_dataset) - - return train_dataset, test_dataset - - - def _load_json_dataset(self, data_file: TextIO, training_params: Dict) -> Tuple[datasets.Dataset, datasets.Dataset]: - """Loads dataset from JSON file.""" - - data = json.load(data_file) - test_size = 0.2 if training_params.get("test_size") is None else training_params["test_size"] - split_ratio = 1 - test_size - - if isinstance(data, list): - examples = data - split_idx = int(len(examples) * split_ratio) - train_examples = examples[:split_idx] - test_examples = examples[split_idx:] - elif isinstance(data, dict) and "train" in data and "test" in data: - train_examples = data["train"] - test_examples = data["test"] - elif isinstance(data, dict) and "examples" in data: - examples = data["examples"] - split_idx = int(len(examples) * split_ratio) - train_examples = examples[:split_idx] - test_examples = examples[split_idx:] - else: - raise DatasetException("Unsupported JSON format") - - train_dataset = datasets.Dataset.from_list(train_examples) - test_dataset = datasets.Dataset.from_list(test_examples) - self._set_dataset_format(train_dataset, test_dataset) - - return train_dataset, test_dataset - - def _load_csv_dataset(self, data_file: TextIO, training_params: Dict) -> Tuple[datasets.Dataset, datasets.Dataset]: - """Loads dataset from CSV file.""" - - df = pd.read_csv(data_file) - test_size = 0.2 if training_params.get("test_size") is None else training_params["test_size"] - split_ratio = 1 - test_size - split_idx = int(len(df) * split_ratio) - - train_df = df[:split_idx] - test_df = df[split_idx:] - - train_dataset = datasets.Dataset.from_pandas(train_df) - test_dataset = datasets.Dataset.from_pandas(test_df) - self._set_dataset_format(train_dataset, test_dataset) - - return train_dataset, test_dataset - - def _create_conversation_formatter(self, training_params: Dict) -> Callable: - """ - Creates a conversation formatter based on training parameters. - - Args: - training_params: Dictionary containing formatting configuration - - Returns: - Function that formats examples into conversations - """ - format_config = training_params.get("format_config", {}) - system_prompt = format_config.get("system_prompt", get_default_system_prompt()) - - def make_conversation(example: Dict[str, Any]) -> Dict[str, Any]: - # Handle different input formats - if "messages" in example: - system_content = None - question_content = None - answer_content = None - for message in example.get("messages", []): - msg = load_pydantic_object_from_dict(PromptMessage, message) - if msg.role == LlmRole.SYSTEM: - system_content = msg.content - elif msg.role == LlmRole.USER: - question_content = msg.content - elif msg.role == LlmRole.ASSISTANT: - answer_content = msg.content - - return { - "prompt": [ - {"role": "system", "content": system_prompt if system_content is None else system_content}, - {"role": "user", "content": question_content if question_content is not None else ""}, - ], - "answer": answer_content if answer_content is not None else "", - } - elif "question" in example and "answer" in example: - # Question/Answer format - return { - "prompt": [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": example.get("question")}, - ], - "answer": example["answer"], - } - elif "input" in example and "output" in example: - # Input/Output format - return { - "prompt": [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": example.get("input")}, - ], - "answer": example["output"], - } - elif "prompt" in example and "completion" in example: - # Prompt/Completion format - return { - "prompt": [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": example.get("prompt")}, - ], - "answer": example["completion"], - } - elif "problem" in example and "solution" in example: - # Problem/Solution format - return { - "prompt": [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": example.get("problem")}, - ], - "answer": example["solution"], - } - else: - raise DatasetException(f"Cannot determine the conversation format from example: {example}") - - return make_conversation - - def run( - self, - training_params: Dict, - data_file: TextIO, - log_frequency: int, - run_id: str, - description: Optional[str] = None, - ) -> None: - """ - Runs the supervised training loop for HuggingFace LLM models. - - Args: - training_params (Dict): A dictionary containing parameters for the training. - data_file (Union[TextIO, tempfile.TemporaryDirectory]): The file-like object or temporary directory containing the training data. - log_frequency (int): The frequency at which logs should be recorded (e.g, the number of processed documents or finished epochs). - run_id (str): The run ID of the training job. - description (Optional[str]): The optional description of the training or change logs. - """ - - if self._config.DEVICE is not Device.GPU.value: - raise DeviceNotAvailableError("This trainer currently requires a CUDA device") - - copied_model_pack_path = None - redeploy = self._config.REDEPLOY_TRAINED_MODEL == "true" - skip_save_model = self._config.SKIP_SAVE_MODEL == "true" - results_path = os.path.abspath(os.path.join(self._config.TRAINING_CACHE_DIR, "results")) - logs_path = os.path.abspath(os.path.join(self._config.TRAINING_CACHE_DIR, "logs")) - reset_random_seed() - eval_mode = training_params["nepochs"] == 0 - self._tracker_client.log_trainer_mode(not eval_mode) - if not eval_mode: - try: - logger.info("Loading a new model copy for training...") - copied_model_pack_path = self._make_model_file_copy(self._model_pack_path, run_id) - model, tokenizer = self._model_service.load_model( - copied_model_pack_path, - load_in_4bit=True, # for memory efficient training - ) - copied_model_directory = os.path.join( - os.path.dirname(copied_model_pack_path), - get_model_data_package_base_name(copied_model_pack_path), - ) - - if non_default_device_is_available(self._config.DEVICE): - model.to(self._config.DEVICE) - - train_dataset, test_dataset = self._load_dataset_from_config(data_file, training_params) - make_conversation = self._create_conversation_formatter(training_params) - train_dataset = train_dataset.map(make_conversation) - test_dataset = test_dataset.map(make_conversation) - - if hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None: - logger.warning("The tokenizer does not have a chat template. Using the default one.") - tokenizer.chat_template = get_default_chat_template() - else: - logger.debug(f"Found a chat template in the tokenizer:\n {tokenizer.chat_template}") - - lora_config = LoraConfig( - task_type="CAUSAL_LM", - r=8, - lora_alpha=32, - lora_dropout=0.1, - target_modules=[ - "q_proj", "k_proj", "v_proj", "o_proj", - "gate_proj", "up_proj", "down_proj", - ], - ) - - model = get_peft_model(model, lora_config) - - def extract_xml_answer(text: str) -> str: - answer = text.split("")[-1] - answer = answer.split("")[0] - return answer.strip() - - # Reward functions - def correctness_reward_func( - prompts: List, - completions: List, - answer: List, - **kwargs: Dict[str, Any] - ) -> List[float]: - responses = [completion[0]['content'] for completion in completions] - q = prompts[0][-1]['content'] - extracted_responses = [extract_xml_answer(r) for r in responses] - print('-' * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", - f"\nExtracted:\n{extracted_responses[0]}") - return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)] - - def int_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]: - responses = [completion[0]['content'] for completion in completions] - extracted_responses = [extract_xml_answer(r) for r in responses] - return [0.5 if r.isdigit() else 0.0 for r in extracted_responses] - - def strict_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]: - """Reward function that checks if the completion has a specific format.""" - pattern = r"^\n.*?\n\n\n.*?\n\n$" - responses = [completion[0]["content"] for completion in completions] - matches = [re.match(pattern, r) for r in responses] - return [0.5 if match else 0.0 for match in matches] - - def soft_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]: - """Reward function that checks if the completion has a specific format.""" - pattern = r".*?\s*.*?" - responses = [completion[0]["content"] for completion in completions] - matches = [re.match(pattern, r) for r in responses] - return [0.5 if match else 0.0 for match in matches] - - def count_xml(text: str) -> float: - count = 0.0 - if text.count("\n") == 1: - count += 0.125 - if text.count("\n\n") == 1: - count += 0.125 - if text.count("\n\n") == 1: - count += 0.125 - count -= len(text.split("\n\n")[-1]) * 0.001 - if text.count("\n") == 1: - count += 0.125 - count -= (len(text.split("\n")[-1]) - 1) * 0.001 - return count - - def xmlcount_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]: - contents = [completion[0]["content"] for completion in completions] - return [count_xml(c) for c in contents] - - mlflow_logging_callback = MLflowLoggingCallback(self._tracker_client) - cancel_event_check_callback = CancelEventCheckCallback(self._cancel_event) - trainer_callbacks = [mlflow_logging_callback, cancel_event_check_callback] - - max_prompt_length = 256 - max_seq_length = 1024 - - try: - from trl import GRPOConfig, GRPOTrainer #, PPOConfig, PPOTrainer - except ImportError: - logger.error("Cannot import the GRPO Trainer. Please install it with `pip install cms[vllm]`.") - - trainer_type = training_params.get("trainer_type", LlmTrainerType.GRPO.value).lower() - if trainer_type == LlmTrainerType.PPO.value: - raise NotImplementedError("PPO training is not yet supported for HuggingFace LLM models") - elif trainer_type == LlmTrainerType.GRPO.value: - training_args = GRPOConfig( - learning_rate=5e-6, - adam_beta1=0.9, - adam_beta2=0.99, - weight_decay=0.1, - warmup_ratio=0.1, - lr_scheduler_type="cosine", - optim="paged_adamw_8bit", - logging_steps=1, - per_device_train_batch_size=6, # This global batch size must be divisible by the number of generations - gradient_accumulation_steps=1, - num_generations=6, - max_prompt_length=max_prompt_length, - max_completion_length=max_seq_length - max_prompt_length, - num_train_epochs = training_params["nepochs"], - max_steps=250, - save_steps=250, - max_grad_norm=0.1, - report_to="none", - output_dir="outputs", - ) - trainer = GRPOTrainer( - model=model, - processing_class=tokenizer, - reward_funcs=[ - xmlcount_reward_func, - soft_format_reward_func, - strict_format_reward_func, - int_reward_func, - correctness_reward_func, - ], - args=training_args, - train_dataset=train_dataset, - eval_dataset=test_dataset, - callbacks=trainer_callbacks, - ) - else: - raise ConfigurationException(f"Unsupported trainer type: {trainer_type}") - - self._tracker_client.log_model_config(model.config.to_dict()) - self._tracker_client.log_trainer_version(TrainerBackend.TRANSFORMERS, transformers_version) - - logger.info(f"Performing {trainer_type.upper()} training...") - trainer.train() - - if cancel_event_check_callback.training_cancelled: - raise TrainingCancelledException("Training was cancelled by the user") - - if not skip_save_model: - model_pack_file_ext = get_model_data_package_extension(self._config.BASE_MODEL_FILE) - model_pack_file_name = f"{ModelType.HUGGINGFACE_LLM.value}_{run_id}{model_pack_file_ext}" - retrained_model_pack_path = os.path.join(self._retrained_models_dir, model_pack_file_name) - model.save_pretrained( - copied_model_directory, - safe_serialization=(self._config.TRAINING_SAFE_MODEL_SERIALISATION == "true"), - ) - create_model_data_package(copied_model_directory, retrained_model_pack_path) - model_uri = self._tracker_client.save_model( - retrained_model_pack_path, - self._model_name, - self._model_manager, - ) - logger.info(f"Retrained model saved: {model_uri}") - else: - logger.info("Skipped saving on the retrained model") - if redeploy: - self.deploy_model(self._model_service, model, tokenizer) - else: - del model - del tokenizer - gc.collect() - logger.info("Skipped deployment on the retrained model") - logger.info("Supervised training finished") - self._tracker_client.end_with_success() - except TrainingCancelledException as e: - logger.exception(e) - logger.info("Supervised training was cancelled") - del model - gc.collect() - self._tracker_client.end_with_interruption() - except torch.OutOfMemoryError as e: - logger.exception("Supervised training failed on CUDA OOM") - try: - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - try: - torch.cuda.reset_peak_memory_stats() - torch.cuda.reset_accumulated_memory_stats() - except Exception: - pass - torch.cuda.synchronize() - except Exception: - pass - self._tracker_client.log_exceptions(e) - self._tracker_client.end_with_failure() - except Exception as e: - logger.exception("Supervised training failed") - self._tracker_client.log_exceptions(e) - self._tracker_client.end_with_failure() - finally: - data_file.close() - with self._training_lock: - self._training_in_progress = False - self._clean_up_training_cache() - self._housekeep_file(copied_model_pack_path) - del trainer - gc.collect() - torch.cuda.empty_cache() - else: - try: - logger.info("Evaluating the running model...") - model, tokenizer = self._model_service.load_model(self._model_pack_path) - if non_default_device_is_available(self._config.DEVICE): - model.to(self._config.DEVICE) - - eval_dataset, _ = self._load_dataset_from_config(data_file, training_params) - make_conversation = self._create_conversation_formatter(training_params) - eval_dataset = eval_dataset.map(make_conversation) - - data_collator = self._LocalDataCollator( - max_length=self._max_length, - pad_token_id=tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else 0 - ) - - training_args = TrainingArguments( - output_dir=results_path, - logging_dir=logs_path, - per_device_eval_batch_size=1, - do_train=False, - do_eval=True, - report_to="none", - dataloader_drop_last=False, - ) - - trainer = Trainer( - model=model, - args=training_args, - data_collator=data_collator, - eval_dataset=eval_dataset, - tokenizer=tokenizer, - ) - - eval_metrics = trainer.evaluate() - logger.info(f"Evaluation metrics: {eval_metrics}") - self._tracker_client.send_hf_metrics_logs(eval_metrics, 0) - self._tracker_client.end_with_success() - logger.info("Model evaluation finished") - except Exception as e: - logger.exception("Model evaluation failed") - self._tracker_client.log_exceptions(e) - self._tracker_client.end_with_failure() - finally: - data_file.close() - with self._training_lock: - self._training_in_progress = False - self._clean_up_training_cache() - - -@final -class MLflowLoggingCallback(TrainerCallback): - """ - A callback class for logging training metrics to MLflow. - - Args: - tracker_client (TrackerClient): An instance of TrackerClient used for logging. - """ - - def __init__(self, tracker_client: TrackerClient) -> None: - self.tracker_client = tracker_client - self.epoch = 0 - - def on_log( - self, - args: TrainingArguments, - state: TrainerState, - control: TrainerControl, - logs: Dict[str, float], - **kwargs: Dict[str, Any], - ) -> None: - """ - Logs metrics at the end of each epoch. - - Args: - args (TrainingArguments): The arguments used for training. - state (TrainerState): The current state of the Trainer. - control (TrainerControl): The current control of the Trainer. - logs (Dict[str, float]): A dictionary containing the metrics to log. - **kwargs (Dict[str, Any]): Additional keyword arguments. - """ - - if logs is not None: - if logs.get("eval_loss", None) is not None: - logs["perplexity"] = math.exp(logs["eval_loss"]) - self.tracker_client.send_hf_metrics_logs(logs, self.epoch) - self.epoch += 1 - - -@final -class CancelEventCheckCallback(TrainerCallback): - """ - A callback class for checking a cancellation event during training. - - Args: - cancel_event (threading.Event): A threading event that signals whether training should be cancelled. - """ - - def __init__(self, cancel_event: threading.Event) -> None: - self.cancel_event = cancel_event - self.training_cancelled = False - - def on_step_end( - self, - args: TrainingArguments, - state: TrainerState, - control: TrainerControl, - **kwargs: Dict[str, Any], - ) -> None: - """ - Checks if the training should be cancelled at the end of each training step. - - Args: - args (TrainingArguments): The arguments used for training. - state (TrainerState): The current state of the Trainer. - control (TrainerControl): The current control of the Trainer. - **kwargs (Dict[str, Any]): Additional keyword arguments. - """ - - if self.cancel_event.is_set(): - control.should_training_stop = True - self.cancel_event.clear() - self.training_cancelled = True diff --git a/app/utils.py b/app/utils.py index 370c4e2..fcc231e 100644 --- a/app/utils.py +++ b/app/utils.py @@ -740,55 +740,6 @@ def download_model_package( retry_delay *= 2 -def get_default_chat_template() -> str: - """ - Gets the default chat template. - - Returns: - str: The default chat template. - """ - - return ( - "{% if messages[0]['role'] == 'system' %}" - "{% set loop_messages = messages[1:] %}" - "{% set system_message = messages[0]['content'] %}" - "{% else %}" - "{% set loop_messages = messages %}" - "{% set system_message = false %}" - "{% endif %}" - "{% for message in loop_messages %}" - "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" - "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" - "{% endif %}" - "{% if loop.index0 == 0 and system_message != false %}" - "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}" - "{% else %}" - "{% set content = message['content'] %}" - "{% endif %}" - "{% if message['role'] == 'user' %}" - "{{ '[INST] ' + content + ' [/INST]' }}" - "{% elif message['role'] == 'assistant' %}" - "{{ ' ' + content + ' ' }}" - "{% endif %}" - "{% endfor %}" - ) - - -def get_default_system_prompt() -> str: - """ - Gets the default system prompt. - - Returns: - str: The default system prompt. - """ - return ( - "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant " - "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning " - "process and answer are enclosed within and tags, respectively, i.e., " - " reasoning process here answer here " - ) - - def get_prompt_from_messages( tokenizer: PreTrainedTokenizer, messages: List[PromptMessage], @@ -845,7 +796,6 @@ def get_prompt_from_messages( ) return prompt - TYPE_ID_TO_NAME_PATCH = { "32816260": "physical object", "2680757": "observable entity", diff --git a/pyproject.toml b/pyproject.toml index 0aaaaea..eaba606 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,8 +82,6 @@ docs = [ vllm = [ "vllm~=0.8.5; python_version >= '3.9'", - "trl>=0.11.4", - "bitsandbytes>=0.45.5", ] # For pip versions not supporting PEP 735 @@ -113,8 +111,6 @@ docs = [ vllm = [ "vllm~=0.8.5; python_version >= '3.9'", - "trl>=0.11.4", - "bitsandbytes>=0.45.5", ] [tool.setuptools] diff --git a/tests/app/api/test_api.py b/tests/app/api/test_api.py index 1d5f077..1a0399b 100644 --- a/tests/app/api/test_api.py +++ b/tests/app/api/test_api.py @@ -27,6 +27,7 @@ def test_get_model_server(): assert {"name": "Training", "description": "Trigger model training on input annotations"} in tags assert {"name": "Evaluating", "description": "Evaluate the deployed model with trainer export"} in tags assert {"name": "Authentication", "description": "Authenticate registered users"} in tags + assert {"name": "Generative", "description": "Generate text based on the input prompt"} in tags assert "/info" in paths assert "/process" in paths assert "/process_jsonl" in paths @@ -90,8 +91,7 @@ def test_get_generative_server(): assert isinstance(info["title"], str) assert isinstance(info["summary"], str) assert isinstance(info["version"], str) - assert {"name": "Metadata", "description": "Get the model card"} in tags - assert {"name": "Generative", "description": "Generate text based on the input prompt"} in tags + assert {"name": "Streaming", "description": "Retrieve NER entities as a stream by running the model"} in tags assert "/info" in paths assert "/generate" in paths assert "/stream/generate" in paths diff --git a/tests/app/processors/test_metrics_collector.py b/tests/app/processors/test_metrics_collector.py index 6e53275..bd46fff 100644 --- a/tests/app/processors/test_metrics_collector.py +++ b/tests/app/processors/test_metrics_collector.py @@ -7,7 +7,6 @@ from app.processors.metrics_collector import ( sanity_check_model_with_trainer_export, concat_trainer_exports, - concat_json_lists, get_stats_from_trainer_export, get_iaa_scores_per_concept, get_iaa_scores_per_doc, @@ -333,56 +332,3 @@ def test_get_iaa_scores_per_span_and_return_dataframe(): assert len(result["cohens_kappa"]) == 30 assert len(result["iaa_percentage_meta"]) == 30 assert len(result["cohens_kappa_meta"]) == 30 - - -def test_concat_json_lists_return_list(): - with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f1: - json.dump([{"question": "question_1", "answer": "answer_1"}, {"question": "question_2", "answer": "answer_2"}], f1) - file1_path = f1.name - - with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f2: - json.dump([{"question": "question_3", "answer": "answer_3"}], f2) - file2_path = f2.name - - try: - result = concat_json_lists([file1_path, file2_path]) - - assert isinstance(result, list) - assert len(result) == 3 - assert result[0] == {"question": "question_1", "answer": "answer_1"} - assert result[1] == {"question": "question_2", "answer": "answer_2"} - assert result[2] == {"question": "question_3", "answer": "answer_3"} - finally: - os.unlink(file1_path) - os.unlink(file2_path) - - -def test_concat_json_lists_save_to_file(): - with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f1: - json.dump([{"question": "question_1", "answer": "answer_1"}], f1) - file1_path = f1.name - - with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f2: - json.dump([{"question": "question_2", "answer": "answer_2"}], f2) - file2_path = f2.name - - with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as output_file: - output_path = output_file.name - - try: - result = concat_json_lists([file1_path, file2_path], output_path) - - assert isinstance(result, str) - assert result == output_path - - with open(output_path, 'r') as f: - saved_data = json.load(f) - - assert isinstance(saved_data, list) - assert len(saved_data) == 2 - assert saved_data[0] == {"question": "question_1", "answer": "answer_1"} - assert saved_data[1] == {"question": "question_2", "answer": "answer_2"} - finally: - os.unlink(file1_path) - os.unlink(file2_path) - os.unlink(output_path) diff --git a/uv.lock b/uv.lock index 255a329..cc97b01 100644 --- a/uv.lock +++ b/uv.lock @@ -734,55 +734,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d6/59/831b66ba317496332d4e9e1a33bcdd14922d6cfecc411dc315a229b67127/bcrypt-4.1.2-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:ba4e4cc26610581a6329b3937e02d319f5ad4b85b074846bf4fef8a8cf51e7bb", size = 698384 }, ] -[[package]] -name = "bitsandbytes" -version = "0.45.5" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version < '3.9' and sys_platform != 'win32'", - "python_full_version < '3.9' and sys_platform == 'win32'", -] -dependencies = [ - { name = "numpy", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "torch", version = "2.4.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/07/b7/cb5ce4d1a382cf53c19ef06c5fc29e85f5e129b4da6527dd207d90a5b8ad/bitsandbytes-0.45.5-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:a5453f30cc6aab6ccaac364e6bf51a7808d3da5f71763dffeb6d9694c59136e4", size = 76059261 }, - { url = "https://files.pythonhosted.org/packages/a6/4c/77b535e025ce780d2ada8271c1e481fb7337c1df2588a52fe1c9bd87d2e8/bitsandbytes-0.45.5-py3-none-win_amd64.whl", hash = "sha256:ed1c61b91d989d6a33fd05737d6edbf5086d8ebc89235ee632c7a19144085da2", size = 75430204 }, -] - -[[package]] -name = "bitsandbytes" -version = "0.47.0" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version > '3.11' and sys_platform == 'darwin'", - "python_full_version > '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "(python_full_version > '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version > '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", - "python_full_version == '3.11' and sys_platform == 'darwin'", - "python_full_version == '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "(python_full_version == '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", - "python_full_version == '3.10.*' and sys_platform == 'darwin'", - "python_full_version == '3.10.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "(python_full_version == '3.10.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.10.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", - "python_full_version == '3.9.*' and platform_machine == 'arm64' and sys_platform == 'darwin'", - "python_full_version == '3.9.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.9.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.9.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", - "python_full_version > '3.11' and sys_platform == 'win32'", - "python_full_version == '3.11' and sys_platform == 'win32'", - "python_full_version == '3.10.*' and sys_platform == 'win32'", - "python_full_version == '3.9.*' and sys_platform == 'win32'", -] -dependencies = [ - { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/aa/eb/477d6b5602f469c7305fd43eec71d890c39909f615c1d7138f6e7d226eff/bitsandbytes-0.47.0-py3-none-manylinux_2_24_aarch64.whl", hash = "sha256:2f805b76891a596025e9e13318b675d08481b9ee650d65e5d2f9d844084c6521", size = 30004641 }, - { url = "https://files.pythonhosted.org/packages/9c/40/91f1a5a694f434bc13cba160045fdc4e867032e627b001bf411048fefd9c/bitsandbytes-0.47.0-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:68f3fffd494a47ed1fd7593bfc5dd2ac69b68260599b71b4c4b3a32f90f3b184", size = 61284639 }, - { url = "https://files.pythonhosted.org/packages/18/a9/e07a227f1cd6562844cea2f05ee576b0991a9a91f45965c06034178ba0f6/bitsandbytes-0.47.0-py3-none-win_amd64.whl", hash = "sha256:4880a6d42ca9628b5a571c8cc3093dc3f5f52511e5a9e47d52d569807975531a", size = 60725121 }, -] - [[package]] name = "blake3" version = "1.0.5" @@ -1283,10 +1234,6 @@ docs = [ { name = "sphinx-rtd-theme" }, ] vllm = [ - { name = "bitsandbytes", version = "0.45.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "bitsandbytes", version = "0.47.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "trl", version = "0.11.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "trl", version = "0.15.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, { name = "vllm", marker = "python_full_version >= '3.9'" }, ] @@ -1315,10 +1262,6 @@ docs = [ { name = "sphinx-rtd-theme" }, ] vllm = [ - { name = "bitsandbytes", version = "0.45.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "bitsandbytes", version = "0.47.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "trl", version = "0.11.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "trl", version = "0.15.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, { name = "vllm", marker = "python_full_version >= '3.9'" }, ] @@ -1326,7 +1269,6 @@ vllm = [ requires-dist = [ { name = "aiosqlite", specifier = "~=0.19.0" }, { name = "asyncpg", specifier = "~=0.27.0" }, - { name = "bitsandbytes", marker = "extra == 'vllm'", specifier = ">=0.45.5" }, { name = "blis", specifier = "<1.0.0" }, { name = "boto3", specifier = "~=1.28.84" }, { name = "click", specifier = "<8.2.0" }, @@ -1367,7 +1309,6 @@ requires-dist = [ { name = "sphinx-rtd-theme", marker = "extra == 'docs'", specifier = "~=3.0.2" }, { name = "toml", specifier = "~=0.10.2" }, { name = "torch", marker = "python_full_version < '3.9'", specifier = "<2.5.0" }, - { name = "trl", marker = "extra == 'vllm'", specifier = ">=0.11.4" }, { name = "typer", specifier = "~=0.15.1" }, { name = "typer-cli", marker = "extra == 'dev'", specifier = "~=0.15.1" }, { name = "types-toml", marker = "extra == 'dev'", specifier = "==0.10.8.20240310" }, @@ -1400,11 +1341,7 @@ docs = [ { name = "sphinx-autodoc-typehints", specifier = "~=2.0.1" }, { name = "sphinx-rtd-theme", specifier = "~=3.0.2" }, ] -vllm = [ - { name = "bitsandbytes", specifier = ">=0.45.5" }, - { name = "trl", specifier = ">=0.11.4" }, - { name = "vllm", marker = "python_full_version >= '3.9'", specifier = "~=0.8.5" }, -] +vllm = [{ name = "vllm", marker = "python_full_version >= '3.9'", specifier = "~=0.8.5" }] [[package]] name = "colorama" @@ -2010,15 +1947,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774 }, ] -[[package]] -name = "docstring-parser" -version = "0.16" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/08/12/9c22a58c0b1e29271051222d8906257616da84135af9ed167c9e28f85cb3/docstring_parser-0.16.tar.gz", hash = "sha256:538beabd0af1e2db0146b6bd3caa526c35a34d61af9fd2887f3a8a27a739aa6e", size = 26565 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d5/7c/e9fcff7623954d86bdc17782036cbf715ecab1bec4847c008557affe1ca8/docstring_parser-0.16-py3-none-any.whl", hash = "sha256:bf0a1387354d3691d102edef7ec124f219ef639982d096e26e3b60aeffa90637", size = 36533 }, -] - [[package]] name = "docutils" version = "0.20.1" @@ -2051,15 +1979,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a3/05/8b171626b850e870fc4433225cd6d5bec5a9916b1c39b3d7c67a60492aeb/email_validator-2.1.2-py3-none-any.whl", hash = "sha256:d89f6324e13b1e39889eab7f9ca2f91dc9aebb6fa50a6d8bd4329ab50f251115", size = 30739 }, ] -[[package]] -name = "eval-type-backport" -version = "0.2.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/30/ea/8b0ac4469d4c347c6a385ff09dc3c048c2d021696664e26c7ee6791631b5/eval_type_backport-0.2.2.tar.gz", hash = "sha256:f0576b4cf01ebb5bd358d02314d31846af5e07678387486e2c798af0e7d849c1", size = 9079 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ce/31/55cd413eaccd39125368be33c46de24a1f639f2e12349b0361b4678f3915/eval_type_backport-0.2.2-py3-none-any.whl", hash = "sha256:cb6ad7c393517f476f96d456d0412ea80f0a8cf96f6892834cd9340149111b0a", size = 5830 }, -] - [[package]] name = "evaluate" version = "0.4.3" @@ -8179,15 +8098,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755 }, ] -[[package]] -name = "shtab" -version = "1.7.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/5a/3e/837067b970c1d2ffa936c72f384a63fdec4e186b74da781e921354a94024/shtab-1.7.2.tar.gz", hash = "sha256:8c16673ade76a2d42417f03e57acf239bfb5968e842204c17990cae357d07d6f", size = 45751 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/74/03/3271b7bb470fbab4adf5bd30b0d32143909d96f3608d815b447357f47f2b/shtab-1.7.2-py3-none-any.whl", hash = "sha256:858a5805f6c137bb0cda4f282d27d08fd44ca487ab4a6a36d2a400263cd0b5c1", size = 14214 }, -] - [[package]] name = "six" version = "1.17.0" @@ -9444,73 +9354,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bc/74/9f12bdedeb110242d8bb1bd621f6605e753ee0cbf73cf7f3a62b8173f190/triton-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:30ceed0eff2c4a73b14eb63e052992f44bbdf175f3fad21e1ac8097a772de7ee", size = 253057866 }, ] -[[package]] -name = "trl" -version = "0.11.4" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version < '3.9' and sys_platform != 'win32'", - "python_full_version < '3.9' and sys_platform == 'win32'", -] -dependencies = [ - { name = "accelerate", version = "1.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "datasets", marker = "python_full_version < '3.9'" }, - { name = "numpy", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "torch", version = "2.4.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "transformers", version = "4.46.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "tyro", marker = "python_full_version < '3.9'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/07/39/a78c0608190cc412c49631dfb8c3e57f5c5b2fb0d79709071c992e707aa4/trl-0.11.4.tar.gz", hash = "sha256:de52a023fc35d580ab809fd74cd4f362a259e463bb968580e0e97e1b98a0fe79", size = 307304 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e4/dd/d2cf3dbc1013cee71ceef584f5ab69915fc05d209ef1e276f8652058c350/trl-0.11.4-py3-none-any.whl", hash = "sha256:071d64164c152ef65b44d15f878793b28d3340310c9e157dc3608bbe5fa549a9", size = 316575 }, -] - -[[package]] -name = "trl" -version = "0.15.2" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version > '3.11' and sys_platform == 'darwin'", - "python_full_version > '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "(python_full_version > '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version > '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", - "python_full_version == '3.11' and sys_platform == 'darwin'", - "python_full_version == '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "(python_full_version == '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", - "python_full_version == '3.10.*' and sys_platform == 'darwin'", - "python_full_version == '3.10.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "(python_full_version == '3.10.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.10.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", - "python_full_version == '3.9.*' and platform_machine == 'arm64' and sys_platform == 'darwin'", - "python_full_version == '3.9.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.9.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.9.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", - "python_full_version > '3.11' and sys_platform == 'win32'", - "python_full_version == '3.11' and sys_platform == 'win32'", - "python_full_version == '3.10.*' and sys_platform == 'win32'", - "python_full_version == '3.9.*' and sys_platform == 'win32'", -] -dependencies = [ - { name = "accelerate", version = "1.7.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "datasets", marker = "python_full_version >= '3.9'" }, - { name = "rich", marker = "python_full_version >= '3.9'" }, - { name = "transformers", version = "4.51.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/95/fe/ae0d782c48eef56d0ec125ebd05998539ede7cbf0e307a48f9323998b9e7/trl-0.15.2.tar.gz", hash = "sha256:0f82190a058a0a194dbcfae1fe9548b68a0a05b2f4d1824f8db1ae7d949cdd47", size = 333962 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7b/29/25378447c48359843de0e4ce1995d367210601c3b437ddf1c779b6393d74/trl-0.15.2-py3-none-any.whl", hash = "sha256:bf2b88e3cf5da08cd533dc03273d977965bd5d86c5878f76285fba45d9cb9634", size = 318931 }, -] - -[[package]] -name = "typeguard" -version = "4.4.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "importlib-metadata", version = "8.5.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "typing-extensions", marker = "python_full_version < '3.9'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/79/5a/91b7c8cfc2e96962442abc9d65c650436dd831910b4d7878980d6596fb98/typeguard-4.4.0.tar.gz", hash = "sha256:463bd8697a65a4aa576a63767c369b1ecfba8a5ba735edfe3223127b6ecfa28c", size = 74399 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/61/a3/00203767544b597a9e3c57b29a84967b3230f00bdd9aa6a52a73187043b4/typeguard-4.4.0-py3-none-any.whl", hash = "sha256:8ca34c14043f53b2caae7040549ba431770869bcd6287cfa8239db7ecb882b4a", size = 35736 }, -] - [[package]] name = "typer" version = "0.15.4" @@ -9568,24 +9411,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/31/08/aa4fdfb71f7de5176385bd9e90852eaf6b5d622735020ad600f2bab54385/typing_inspection-0.4.0-py3-none-any.whl", hash = "sha256:50e72559fcd2a6367a19f7a7e610e6afcb9fac940c650290eed893d61386832f", size = 14125 }, ] -[[package]] -name = "tyro" -version = "0.9.24" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "colorama", marker = "python_full_version < '3.9' and sys_platform == 'win32'" }, - { name = "docstring-parser", marker = "python_full_version < '3.9'" }, - { name = "eval-type-backport", marker = "python_full_version < '3.9'" }, - { name = "rich", marker = "python_full_version < '3.9'" }, - { name = "shtab", marker = "python_full_version < '3.9'" }, - { name = "typeguard", marker = "python_full_version < '3.9'" }, - { name = "typing-extensions", marker = "python_full_version < '3.9'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/57/49/ca1698fcc5479fe9c7eff48861ebb671c5a6afba0245ea7cd560a939f281/tyro-0.9.24.tar.gz", hash = "sha256:5a9ef93d1b8e93cff2c5d82789a571d905d152e92af82a3ec96a17d668194df3", size = 303651 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ac/59/4c865b56babef1aa6a9662879c94507dc62d0173ac7433579d7a2728f7e5/tyro-0.9.24-py3-none-any.whl", hash = "sha256:d8152e47375419752210da455226007b4bb9bd9c65af1de8fb12daf0658c91dc", size = 128326 }, -] - [[package]] name = "tzdata" version = "2025.2" From 9d18564f0ee023d28cd5afb94e8a0fdfd63ba01a Mon Sep 17 00:00:00 2001 From: Xi Bai <82581439+baixiac@users.noreply.github.com> Date: Wed, 17 Sep 2025 15:52:05 +0100 Subject: [PATCH 5/9] Revert "feat: add the text embedding endpoint for LLM serving" This reverts commit d76098609662430b74e02c1dcc463b090d7a1dac. --- app/api/routers/generative.py | 110 ++---------------- app/domain.py | 15 +-- app/model_services/base.py | 27 +---- app/model_services/huggingface_llm_model.py | 50 +------- tests/app/api/test_serving_hf_llm.py | 22 ---- .../test_huggingface_llm_model.py | 104 +---------------- 6 files changed, 13 insertions(+), 315 deletions(-) diff --git a/app/api/routers/generative.py b/app/api/routers/generative.py index 26f4fd1..763c717 100644 --- a/app/api/routers/generative.py +++ b/app/api/routers/generative.py @@ -10,16 +10,8 @@ from fastapi import APIRouter, Depends, Request, Body, Query from fastapi.encoders import jsonable_encoder from fastapi.responses import PlainTextResponse, StreamingResponse, JSONResponse -from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR -from app.domain import ( - Tags, - OpenAIChatRequest, - OpenAIChatResponse, - OpenAIEmbeddingsRequest, - OpenAIEmbeddingsResponse, - PromptMessage, - PromptRole, -) +from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST +from app.domain import Tags, OpenAIChatRequest, OpenAIChatResponse, PromptMessage, PromptRole from app.model_services.base import AbstractModelService from app.utils import get_settings, get_prompt_from_messages from app.api.utils import get_rate_limiter @@ -29,7 +21,6 @@ PATH_GENERATE = "/generate" PATH_GENERATE_ASYNC = "/stream/generate" PATH_OPENAI_COMPLETIONS = "/v1/chat/completions" -PATH_OPENAI_EMBEDDINGS = "/v1/embeddings" router = APIRouter() config = get_settings() @@ -143,7 +134,7 @@ async def generate_text_stream( @router.post( PATH_OPENAI_COMPLETIONS, - tags=[Tags.OpenAICompatible.name], + tags=[Tags.Generative.name], response_model=None, dependencies=[Depends(cms_globals.props.current_active_user)], description="Generate chat response based on messages, similar to OpenAI's /v1/chat/completions", @@ -171,7 +162,6 @@ def generate_chat_completions( """ messages = request_data.messages - model = model_service.model_name if request_data.model != model_service.model_name else request_data.model stream = request_data.stream max_tokens = request_data.max_tokens temperature = request_data.temperature @@ -234,7 +224,7 @@ async def _stream(prompt: str, max_tokens: int, temperature: float) -> AsyncGene id=tracking_id, object="chat.completion", created=int(time.time()), - model=model, + model=model_service.model_name, choices=[ { "index": 0, @@ -249,100 +239,14 @@ async def _stream(prompt: str, max_tokens: int, temperature: float) -> AsyncGene return JSONResponse(content=jsonable_encoder(completion), headers={"x-cms-tracking-id": tracking_id}) -@router.post( - PATH_OPENAI_EMBEDDINGS, - tags=[Tags.OpenAICompatible.name], - response_model=None, - dependencies=[Depends(cms_globals.props.current_active_user)], - description="Create embeddings based on text(s), similar to OpenAI's /v1/embeddings endpoint", -) -def embed_texts( - request: Request, - request_data: Annotated[OpenAIEmbeddingsRequest, Body( - description="Text(s) to be embedded", media_type="application/json" - )], - tracking_id: Union[str, None] = Depends(validate_tracking_id), - model_service: AbstractModelService = Depends(cms_globals.model_service_dep) -) -> JSONResponse: - """ - Embeds text or a list of texts, mimicking OpenAI's /v1/embeddings endpoint. - - Args: - request (Request): The request object. - request_data (OpenAIEmbeddingsRequest): The request data containing model and input text(s). - tracking_id (Union[str, None]): An optional tracking ID of the requested task. - model_service (AbstractModelService): The model service dependency. - - Returns: - JSONResponse: A response containing the embeddings of the text(s). - """ - tracking_id = tracking_id or str(uuid.uuid4()) - - if not hasattr(model_service, "create_embeddings"): - error_response = { - "error": { - "message": "Model does not support embeddings", - "type": "invalid_request_error", - "param": "model", - "code": "model_not_supported", - } - } - return JSONResponse( - content=error_response, - status_code=HTTP_500_INTERNAL_SERVER_ERROR, - headers={"x-cms-tracking-id": tracking_id}, - ) - - input_text = request_data.input - model = model_service.model_name if request_data.model != model_service.model_name else request_data.model - - if isinstance(input_text, str): - input_texts = [input_text] - else: - input_texts = input_text - - try: - embeddings_data = [] - - for i, embedding in enumerate(model_service.create_embeddings(input_texts)): - embeddings_data.append({ - "object": "embedding", - "embedding": embedding, - "index": i, - }) - - response = OpenAIEmbeddingsResponse(object="list", data=embeddings_data, model=model) - - return JSONResponse( - content=jsonable_encoder(response), - headers={"x-cms-tracking-id": tracking_id}, - ) - - except Exception as e: - logger.error("Failed to create embeddings") - logger.exception(e) - error_response = { - "error": { - "message": f"Failed to create embeddings: {str(e)}", - "type": "server_error", - "code": "internal_error", - } - } - return JSONResponse( - content=error_response, - status_code=HTTP_500_INTERNAL_SERVER_ERROR, - headers={"x-cms-tracking-id": tracking_id}, - ) - - def _empty_prompt_error() -> Iterable[str]: yield "ERROR: No prompt text provided\n" def _send_usage_metrics(handler: str, prompt_token_num: int, completion_token_num: int) -> None: cms_prompt_tokens.labels(handler=handler).observe(prompt_token_num) - logger.debug("Sent prompt tokens usage: %s", prompt_token_num) + logger.debug(f"Sent prompt tokens usage: {prompt_token_num}") cms_completion_tokens.labels(handler=handler).observe(completion_token_num) - logger.debug("Sent completion tokens usage: %s", completion_token_num) + logger.debug(f"Sent completion tokens usage: {completion_token_num}") cms_total_tokens.labels(handler=handler).observe(prompt_token_num + completion_token_num) - logger.debug("Sent total tokens usage: %s", prompt_token_num + completion_token_num) + logger.debug(f"Sent total tokens usage: {prompt_token_num + completion_token_num}") diff --git a/app/domain.py b/app/domain.py index de098a8..2ddd6e3 100644 --- a/app/domain.py +++ b/app/domain.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import List, Optional, Dict, Any, Union +from typing import List, Optional, Dict, Any from fastapi import HTTPException from starlette.status import HTTP_400_BAD_REQUEST @@ -27,7 +27,6 @@ class Tags(str, Enum): Evaluating = "Evaluate the deployed model with trainer export" Authentication = "Authenticate registered users" Generative = "Generate text based on the input prompt" - OpenAICompatible = "Compatible with OpenAI APIs" class TagsStreamable(str, Enum): @@ -186,7 +185,6 @@ class OpenAIChatRequest(BaseModel): messages: List[PromptMessage] = Field(..., description="A list of messages to be sent to the model") stream: bool = Field(..., description="Whether to stream the response") max_tokens: int = Field(512, description="The maximum number of tokens to generate", gt=0) - model: str = Field(..., description="The name of the model used for generating the completion") temperature: float = Field(0.7, description="The temperature of the generated text", ge=0.0, le=1.0) @@ -196,14 +194,3 @@ class OpenAIChatResponse(BaseModel): created: int = Field(..., description="The timestamp when the completion was generated") model: str = Field(..., description="The name of the model used for generating the completion") choices: List = Field(..., description="The generated messages and their metadata") - - -class OpenAIEmbeddingsRequest(BaseModel): - input: Union[str, List[str]] = Field(..., description="Input text or list of texts to embed") - model: str = Field(..., description="The name of the model used for creating the embeddings") - - -class OpenAIEmbeddingsResponse(BaseModel): - object: str = Field(..., description="The type of the response") - data: List[Dict[str, Any]] = Field(..., description="List of embedding objects") - model: str = Field(..., description="The name of the model used for creating the embeddings") diff --git a/app/model_services/base.py b/app/model_services/base.py index dfde491..a3c1ccc 100644 --- a/app/model_services/base.py +++ b/app/model_services/base.py @@ -1,6 +1,6 @@ import asyncio from abc import ABC, abstractmethod -from typing import Any, List, Iterable, Tuple, final, Optional, Generic, TypeVar, Protocol, AsyncIterable, Union +from typing import Any, List, Iterable, Tuple, final, Optional, Generic, TypeVar, Protocol, AsyncIterable from app.config import Settings from app.domain import ModelCard, Annotation @@ -17,7 +17,7 @@ def tracker_client(self) -> Any: T = TypeVar("T", bound=_TrainerCommon) class AbstractModelService(ABC, Generic[T]): - """An abstract base class defining the common interface for NER model services.""" + """An abstract base class defining the common interface for all model services.""" @abstractmethod def __init__(self, config: Settings, *args: Any, **kwargs: Any) -> None: @@ -200,29 +200,6 @@ def generate_async(self, prompt: str, *args: Any, **kwargs: Any) -> AsyncIterabl raise NotImplementedError - def create_embeddings( - self, - text: Union[str, List[str]], - *args: Any, - **kwargs: Any - ) -> Union[List[float], List[List[float]]]: - """ - Creates embeddings for a given text or list of texts. - - Args: - text (Union[str, List[str]]): The text(s) to be embedded. - *args (Any): Additional positional arguments to be passed to this method. - **kwargs (Any): Additional keyword arguments to be passed to this method. - - Returns: - Union[List[float], List[List[float]]]: The embedding vector(s) for the text(s). - - Raises: - NotImplementedError: If the method is not implemented by the subclass. - """ - - raise NotImplementedError - def train_supervised(self, *args: Any, **kwargs: Any) -> Tuple[bool, str, str]: """ Initiates supervised training on the model. diff --git a/app/model_services/huggingface_llm_model.py b/app/model_services/huggingface_llm_model.py index 25cd032..7340f67 100644 --- a/app/model_services/huggingface_llm_model.py +++ b/app/model_services/huggingface_llm_model.py @@ -1,9 +1,8 @@ import os import logging import asyncio -import torch from concurrent.futures import ThreadPoolExecutor -from typing import Dict, List, Optional, Tuple, Any, AsyncIterable, Callable, Union +from typing import Dict, List, Optional, Tuple, Any, AsyncIterable, Callable from transformers import ( AutoModelForCausalLM, AutoTokenizer, @@ -308,50 +307,3 @@ async def generate_async( return finally: logger.debug("Chat response generation completed") - - def create_embeddings( - self, - text: Union[str, List[str]], - *args: Any, - **kwargs: Any - ) -> Union[List[float], List[List[float]]]: - """ - Creates embeddings for a given text or list of texts using the model's hidden states. - - Args: - text (Union[str, List[str]]): The text(s) to be embedded. - *args (Any): Additional positional arguments to be passed to this method. - **kwargs (Any): Additional keyword arguments to be passed to this method. - - Returns: - List[float], List[List[float]]: The embedding vector(s) for the text(s). - - Raises: - NotImplementedError: If the model doesn't support embeddings. - """ - - self.model.eval() - - inputs = self.tokenizer( - text, - add_special_tokens=False, - return_tensors="pt", - padding=True, - truncation=True, - ) - - if non_default_device_is_available(self._config.DEVICE): - inputs.to(get_settings().DEVICE) - - with torch.no_grad(): - outputs = self.model(**inputs, output_hidden_states=True) - - last_hidden_state = outputs.hidden_states[-1] - attention_mask = inputs["attention_mask"] - masked_hidden_states = last_hidden_state * attention_mask.unsqueeze(-1) - sum_hidden_states = masked_hidden_states.sum(dim=1) - num_tokens = attention_mask.sum(dim=1, keepdim=True) - embeddings = sum_hidden_states / num_tokens - - results = embeddings.cpu().numpy().tolist() - return results[0] if isinstance(text, str) else results diff --git a/tests/app/api/test_serving_hf_llm.py b/tests/app/api/test_serving_hf_llm.py index 9b7ea9f..4bd4e9d 100644 --- a/tests/app/api/test_serving_hf_llm.py +++ b/tests/app/api/test_serving_hf_llm.py @@ -31,9 +31,7 @@ def llm_app(llm_model_service): @pytest.fixture(scope="function") def client(llm_model_service): - llm_model_service.model_name = "HuggingFace LLM model" llm_model_service.generate.return_value = "Yeah." - llm_model_service.create_embeddings.return_value = [[1.0, 2.0, 3.0]] app = get_generative_server(config, msd_overwritten=lambda: llm_model_service) app.dependency_overrides[cms_globals.props.current_active_user] = lambda: None client = TestClient(app) @@ -84,7 +82,6 @@ async def test_generate_chat_completions(llm_model_service, llm_app): "content": "Who are you?" } ], - "model": "HuggingFace LLM model", "stream": True, "max_tokens": 128, "temperature": 0.7 @@ -101,22 +98,3 @@ async def test_generate_chat_completions(llm_model_service, llm_app): assert response.text.startswith("data:") assert "id" in response.text assert "chat.completion.chunk" in response.text - - -def test_create_embeddings(client): - request_data = { - "input": ["Alright"], - "model": "HuggingFace LLM model", - } - response = client.post( - "/v1/embeddings", - data=json.dumps(request_data), - headers={"Content-Type": "application/json"}, - ) - assert response.status_code == 200 - assert response.headers["content-type"] == "application/json" - assert response.json() == { - "object": "list", - "data": [{"object": "embedding", "embedding": [1.0, 2.0, 3.0], "index": 0}], - "model": "HuggingFace LLM model" - } diff --git a/tests/app/model_services/test_huggingface_llm_model.py b/tests/app/model_services/test_huggingface_llm_model.py index e6ee637..7cd4941 100644 --- a/tests/app/model_services/test_huggingface_llm_model.py +++ b/tests/app/model_services/test_huggingface_llm_model.py @@ -1,5 +1,5 @@ import os -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock from tests.app.conftest import MODEL_PARENT_DIR from transformers import PreTrainedModel, PreTrainedTokenizerBase from app import __version__ @@ -129,104 +129,4 @@ async def test_generate_async(huggingface_llm_model): prompt_token_num=2, completion_token_num=2, ) - assert result == "Yeah." - - -def test_create_embeddings_single_text(huggingface_llm_model): - """Test create_embeddings with single text input.""" - huggingface_llm_model.init_model() - huggingface_llm_model.model = MagicMock() - huggingface_llm_model.tokenizer = MagicMock() - mock_hidden_states = [MagicMock(), MagicMock(), MagicMock()] - mock_outputs = MagicMock() - mock_outputs.hidden_states = mock_hidden_states - mock_last_hidden_state = MagicMock() - mock_last_hidden_state.shape = [1, 3, 768] - mock_hidden_states[-1] = mock_last_hidden_state - mock_attention_mask = MagicMock() - mock_attention_mask.shape = [1, 3] - mock_attention_mask.sum.return_value = MagicMock() - mock_attention_mask.sum.return_value.unsqueeze.return_value = MagicMock() - mock_inputs = MagicMock() - mock_inputs.__getitem__.side_effect = lambda key: mock_attention_mask if key == "attention_mask" else MagicMock() - huggingface_llm_model.tokenizer.return_value = mock_inputs - huggingface_llm_model.model.return_value = mock_outputs - expected_result = [0.1, 0.2, 0.3] - mock_embeddings_batch = MagicMock() - mock_first_embedding = MagicMock() - mock_cpu_tensor = MagicMock() - mock_numpy_array = MagicMock() - mock_numpy_array.tolist.return_value = expected_result - mock_embeddings_batch.__getitem__.return_value = mock_first_embedding - mock_first_embedding.cpu.return_value = mock_cpu_tensor - mock_cpu_tensor.numpy.return_value = mock_numpy_array - mock_masked_hidden_states = MagicMock() - mock_sum_hidden_states = MagicMock() - mock_num_tokens = MagicMock() - mock_last_hidden_state.__mul__.return_value = mock_masked_hidden_states - mock_masked_hidden_states.sum.return_value = mock_sum_hidden_states - mock_attention_mask.sum.return_value = mock_num_tokens - mock_sum_hidden_states.__truediv__.return_value = mock_embeddings_batch - - result = huggingface_llm_model.create_embeddings("Alright") - - huggingface_llm_model.model.eval.assert_called_once() - huggingface_llm_model.tokenizer.assert_called_once_with( - "Alright", - add_special_tokens=False, - return_tensors="pt", - padding=True, - truncation=True - ) - huggingface_llm_model.model.assert_called_once_with( - **mock_inputs, - output_hidden_states=True - ) - - assert result is not None - - -def test_create_embeddings_list_text(huggingface_llm_model): - huggingface_llm_model.init_model() - huggingface_llm_model.model = MagicMock() - huggingface_llm_model.tokenizer = MagicMock() - mock_hidden_states = [MagicMock(), MagicMock(), MagicMock()] - mock_outputs = MagicMock() - mock_outputs.hidden_states = mock_hidden_states - mock_last_hidden_state = MagicMock() - mock_last_hidden_state.shape = [2, 3, 768] - mock_hidden_states[-1] = mock_last_hidden_state - mock_attention_mask = MagicMock() - mock_attention_mask.shape = [2, 3] - mock_attention_mask.sum.return_value = MagicMock() - mock_attention_mask.sum.return_value.unsqueeze.return_value = MagicMock() - mock_inputs = MagicMock() - mock_inputs.__getitem__.side_effect = lambda key: mock_attention_mask if key == "attention_mask" else MagicMock() - huggingface_llm_model.tokenizer.return_value = mock_inputs - huggingface_llm_model.model.return_value = mock_outputs - mock_embeddings_batch = MagicMock() - mock_first_embedding = MagicMock() - mock_cpu_tensor = MagicMock() - mock_numpy_array = MagicMock() - mock_numpy_array.tolist.return_value = [[0.1, 0.2, 0.3],[0.1, 0.2, 0.3]] - mock_embeddings_batch.__getitem__.return_value = mock_first_embedding - mock_first_embedding.cpu.return_value = mock_cpu_tensor - mock_cpu_tensor.numpy.return_value = mock_numpy_array - mock_masked_hidden_states = MagicMock() - mock_sum_hidden_states = MagicMock() - mock_num_tokens = MagicMock() - mock_last_hidden_state.__mul__.return_value = mock_masked_hidden_states - mock_masked_hidden_states.sum.return_value = mock_sum_hidden_states - mock_attention_mask.sum.return_value = mock_num_tokens - mock_sum_hidden_states.__truediv__.return_value = mock_embeddings_batch - - result = huggingface_llm_model.create_embeddings(["Alright", "Alright"]) - - huggingface_llm_model.tokenizer.assert_called_once_with( - ["Alright", "Alright"], - add_special_tokens=False, - return_tensors="pt", - padding=True, - truncation=True, - ) - assert result is not None + assert result == "Yeah." \ No newline at end of file From eb422a3b2aceaf9f9047c008010b5e9ad2320d29 Mon Sep 17 00:00:00 2001 From: Xi Bai <82581439+baixiac@users.noreply.github.com> Date: Wed, 17 Sep 2025 15:52:05 +0100 Subject: [PATCH 6/9] Revert "feat: add the chat template factory" This reverts commit 587e9ea3c5b87bca3f668fd641d60e5b3d23057a. --- app/processors/prompt_factory.py | 262 -------------------- app/utils.py | 66 ++--- tests/app/processors/test_prompt_factory.py | 228 ----------------- 3 files changed, 26 insertions(+), 530 deletions(-) delete mode 100644 app/processors/prompt_factory.py delete mode 100644 tests/app/processors/test_prompt_factory.py diff --git a/app/processors/prompt_factory.py b/app/processors/prompt_factory.py deleted file mode 100644 index ee1a45c..0000000 --- a/app/processors/prompt_factory.py +++ /dev/null @@ -1,262 +0,0 @@ -class PromptFactory: - - _ALPACA = ( - "{% if messages[0]['role'] == 'system' %}" - "{% set loop_messages = messages[1:] %}" - "{% set system_message = messages[0]['content'].strip() + '\n' %}" - "{% else %}" - "{% set loop_messages = messages %}" - "{% set system_message = '' %}" - "{% endif %}" - "{% for message in loop_messages %}" - "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" - "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" - "{% endif %}" - "{% if loop.index0 == 0 %}" - "{% set content = system_message + message['content'] %}" - "{% else %}" - "{% set content = message['content'] %}" - "{% endif %}" - - "{% if message['role'] == 'user' %}" - "{{ '### Instruction:\n' + content.strip() + '\n\n'}}" - "{% elif message['role'] == 'assistant' %}" - "{{ '### Response:\n' + content.strip() + '\n\n' }}" - "{% endif %}" - "{% endfor %}" - "{% if add_generation_prompt %}" - "{{ '### Response:\n' }}" - "{% endif %}" - ) - - _CHAT_ML = ( - "{% for message in messages %}" - "{{'<|im_start|>' + message['role'] + '\n' + message['content'].strip() + '<|im_end|>' + '\n'}}" - "{% endfor %}" - "{% if add_generation_prompt %}" - "{{'<|im_start|>assistant\n'}}" - "{% endif %}" - ) - - _DEFAULT = ( - "{% for message in messages %}" - "{% if message['role'] == 'user' %}" - "{{'<|user|>\n' + message['content'] + eos_token}}" - "{% elif message['role'] == 'system' %}" - "{{'<|system|>\n' + message['content'] + eos_token}}" - "{% elif message['role'] == 'assistant' %}" - "{{'<|assistant|>\n' + message['content'] + eos_token}}" - "{% endif %}" - "{% if loop.last and add_generation_prompt %}" - "{{'<|assistant|>'}}" - "{% endif %}" - "{% endfor %}" - ) - - _FALCON = ( - "{% if messages[0]['role'] == 'system' %}" - "{% set loop_messages = messages[1:] %}" - "{% set system_message = messages[0]['content'] %}" - "{% else %}" - "{% set loop_messages = messages %}" - "{% set system_message = '' %}" - "{% endif %}" - "{% for message in loop_messages %}" - "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" - "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" - "{% endif %}" - "{% if loop.index0 == 0 %}" - "{{ system_message.strip() }}" - "{% endif %}" - "{{ '\n\n' + message['role'].title() + ': ' + message['content'].strip().replace('\r\n', '\n').replace('\n\n', '\n') }}" - "{% endfor %}" - "{% if add_generation_prompt %}" - "{ '\n\nAssistant:' }}" - "{% endif %}" - ) - - _GEMMA = ( - "{% if messages[0]['role'] == 'system' %}" - "{% set loop_messages = messages[1:] %}" - "{% set system_message = messages[0]['content'].strip() + '\n\n' %}" - "{% else %}" - "{% set loop_messages = messages %}" - "{% set system_message = '' %}" - "{% endif %}" - "{% for message in loop_messages %}" - "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" - "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" - "{% endif %}" - "{% if loop.index0 == 0 %}" - "{% set content = system_message + message['content'] %}" - "{% else %}" - "{% set content = message['content'] %}" - "{% endif %}" - "{% if (message['role'] == 'assistant') %}" - "{% set role = 'model' %}" - "{% else %}" - "{% set role = message['role'] %}" - "{% endif %}" - "{{ '' + role + '\n' + content.strip() + '\n' }}" - "{% endfor %}" - "{% if add_generation_prompt %}" - "{{'model\n'}}" - "{% endif %}" - ) - - _LLAMA_2 = ( - "{% if messages[0]['role'] == 'system' %}" - "{% set loop_messages = messages[1:] %}" - "{% set system_message = '<>\n' + messages[0]['content'].strip() + '\n<>\n\n' %}" - "{% else %}" - "{% set loop_messages = messages %}" - "{% set system_message = '' %}" - "{% endif %}" - "{% for message in loop_messages %}" - "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" - "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" - "{% endif %}" - "{% if loop.index0 == 0 %}" - "{% set content = system_message + message['content'] %}" - "{% else %}" - "{% set content = message['content'] %}" - "{% endif %}" - "{% if message['role'] == 'user' %}" - "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" - "{% elif message['role'] == 'assistant' %}" - "{{ ' ' + content.strip() + ' ' + eos_token }}" - "{% endif %}" - "{% endfor %}" - ) - - _LLAMA_3 = ( - "{{ bos_token }}" - "{% if messages[0]['role'] == 'system' %}" - "{% set loop_messages = messages[1:] %}" - "{% set system_message = '<|start_header_id|>' + 'system' + '<|end_header_id|>\n\n' + messages[0]['content'].strip() + '<|eot_id|>' %}" - "{% else %}" - "{% set loop_messages = messages %}" - "{% set system_message = '' %}" - "{% endif %}" - "{% for message in loop_messages %}" - "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" - "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" - "{% endif %}" - "{% if loop.index0 == 0 %}" - "{{ system_message }}" - "{% endif %}" - "{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'].strip() + '<|eot_id|>' }}" - "{% if loop.last and message['role'] == 'user' and add_generation_prompt %}" - "{{ '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}" - "{% endif %}" - "{% endfor %}" - ) - - _MISTRAL = ( - "{{ bos_token }}" - "{% if messages[0]['role'] == 'system' %}" - "{% set loop_messages = messages[1:] %}" - "{% set system_message = messages[0]['content'].strip() + '\n\n' %}" - "{% else %}" - "{% set loop_messages = messages %}" - "{% set system_message = '' %}" - "{% endif %}" - "{% for message in loop_messages %}" - "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" - "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" - "{% endif %}" - "{% if loop.index0 == 0 %}" - "{% set content = system_message + message['content'] %}" - "{% else %}" - "{% set content = message['content'] %}" - "{% endif %}" - "{% if message['role'] == 'user' %}" - "{{ '[INST] ' + content.strip() + ' [/INST]' }}" - "{% elif message['role'] == 'assistant' %}" - "{{ content.strip() + eos_token}}" - "{% endif %}" - "{% endfor %}" - ) - - _PHI_2 = ( - "{% if messages[0]['role'] == 'system' %}" - "{% set loop_messages = messages[1:] %}" - "{% set system_message = messages[0]['content'].strip() + '\n\n' %}" - "{% else %}" - "{% set loop_messages = messages %}" - "{% set system_message = '' %}" - "{% endif %}" - "{% for message in loop_messages %}" - "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" - "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" - "{% endif %}" - "{% if loop.index0 == 0 %}" - "{% set content = system_message + message['content'] %}" - "{% else %}" - "{% set content = message['content'] %}" - "{% endif %}" - "{% if message['role'] == 'user' %}" - "{{ 'Instruct: ' + content.strip() + '\n' }}" - "{% elif message['role'] == 'assistant' %}" - "{{ 'Output: ' + content.strip() + '\n' }}" - "{% endif %}" - "{% endfor %}" - "{% if add_generation_prompt %}" - "{{ 'Output:' }}" - "{% endif %}" - ) - - _PHI_3 = ( - "{{ bos_token }}" - "{% for message in messages %}" - "{% if (message['role'] == 'system') %}" - "{{'<|system|>' + '\n' + message['content'].strip() + '<|end|>' + '\n'}}" - "{% elif (message['role'] == 'user') %}" - "{{'<|user|>' + '\n' + message['content'].strip() + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}" - "{% elif message['role'] == 'assistant' %}" - "{{message['content'].strip() + '<|end|>' + '\n'}}" - "{% endif %}" - "{% endfor %}" - ) - - _QWEN = ( - "{% for message in messages %}" - "{% if loop.first and messages[0]['role'] != 'system' %}" - "{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}" - "{% endif %}" - "{{'<|im_start|>' + message['role'] + '\n' + message['content'].strip() }}" - "{% if (loop.last and add_generation_prompt) or not loop.last %}" - "{{ '<|im_end|>' + '\n'}}" - "{% endif %}" - "{% endfor %}" - "{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}" - "{{ '<|im_start|>assistant\n' }}" - "{% endif %}" - ) - - @classmethod - def create_chat_template(cls, name: str = "default") -> str: - if name.lower() == "default": - return cls._DEFAULT - elif name.lower() == "alpaca": - return cls._ALPACA - elif name.lower() == "chat_ml": - return cls._CHAT_ML - elif name.lower() == "falcon": - return cls._FALCON - elif name.lower() == "gemma": - return cls._GEMMA - elif name.lower() == "llama_2": - return cls._LLAMA_2 - elif name.lower() == "llama_3": - return cls._LLAMA_3 - elif name.lower() == "mistral": - return cls._MISTRAL - elif name.lower() == "phi_2": - return cls._PHI_2 - elif name.lower() == "phi_3": - return cls._PHI_3 - elif name.lower() == "qwen": - return cls._QWEN - else: - raise ValueError("Invalid template name") diff --git a/app/utils.py b/app/utils.py index fcc231e..ee97d51 100644 --- a/app/utils.py +++ b/app/utils.py @@ -27,7 +27,6 @@ from app.config import Settings from app.domain import Annotation, Entity, CodeType, ModelType, Device, PromptMessage, PromptRole from app.exception import ManagedModelException -from app.processors.prompt_factory import PromptFactory @lru_cache @@ -740,60 +739,47 @@ def download_model_package( retry_delay *= 2 -def get_prompt_from_messages( - tokenizer: PreTrainedTokenizer, - messages: List[PromptMessage], - override_template: Optional[str] = None, -) -> str: +def get_prompt_from_messages(tokenizer: PreTrainedTokenizer, messages: List[PromptMessage]) -> str: """ Generates a prompt from a list of prompt messages. Args: tokenizer (PreTrainedTokenizer): The tokenizer to use for applying the chat template. messages (List[PromptMessage]): The list of prompt messages to use for generating the prompt. - override_template (str): The name of the chat template to use for generating the prompt. Returns: str: The generated prompt. """ - if override_template is None: - if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template: - prompt = tokenizer.apply_chat_template( - [dump_pydantic_object_to_dict(message) for message in messages], - tokenize=False, - add_generation_prompt=True, - ) - elif hasattr(tokenizer, "default_chat_template") and tokenizer.default_chat_template: - # This largely depends on how older versions of HF tokenizers behave and may not work universally - tokenizer.chat_template = tokenizer.default_chat_template - prompt = tokenizer.apply_chat_template( - [dump_pydantic_object_to_dict(message) for message in messages], - tokenize=False, - add_generation_prompt=True, - ) - else: - system_content = "" - prompt_parts: List[str] = [] - for message in messages: - content = message.content.strip() - if message.role == PromptRole.SYSTEM: - system_content = content - elif message.role == PromptRole.USER: - prompt_parts.append(f"<|user|>\n{content}") - elif message.role == PromptRole.ASSISTANT: - prompt_parts.append(f"<|assistant|>\n{content}") - if system_content: - prompt = f"<|system|>\n{system_content}\n" + "\n".join(prompt_parts) - else: - prompt = "\n".join(prompt_parts) - prompt += "\n<|assistant|>\n" - else: - tokenizer.chat_template = PromptFactory.create_chat_template(name=override_template) + if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template: + prompt = tokenizer.apply_chat_template( + [dump_pydantic_object_to_dict(message) for message in messages], + tokenize=False, + add_generation_prompt=True, + ) + elif hasattr(tokenizer, "default_chat_template") and tokenizer.default_chat_template: + # This largely depends on how older versions of HF tokenizers behave and may not work universally + tokenizer.chat_template = tokenizer.default_chat_template prompt = tokenizer.apply_chat_template( [dump_pydantic_object_to_dict(message) for message in messages], tokenize=False, add_generation_prompt=True, ) + else: + system_content = "" + prompt_parts: List[str] = [] + for message in messages: + content = message.content.strip() + if message.role == PromptRole.SYSTEM: + system_content = content + elif message.role == PromptRole.USER: + prompt_parts.append(f"<|user|>\n{content}") + elif message.role == PromptRole.ASSISTANT: + prompt_parts.append(f"<|assistant|>\n{content}") + if system_content: + prompt = f"<|system|>\n{system_content}\n" + "\n".join(prompt_parts) + else: + prompt = "\n".join(prompt_parts) + prompt += "\n<|assistant|>\n" return prompt TYPE_ID_TO_NAME_PATCH = { diff --git a/tests/app/processors/test_prompt_factory.py b/tests/app/processors/test_prompt_factory.py deleted file mode 100644 index ab03388..0000000 --- a/tests/app/processors/test_prompt_factory.py +++ /dev/null @@ -1,228 +0,0 @@ -from jinja2.sandbox import ImmutableSandboxedEnvironment -from app.processors.prompt_factory import PromptFactory - - -def test_create_default_chat_template(): - env = ImmutableSandboxedEnvironment() - messages = [ - { - "role": "system", - "content": "Alright?" - }, - { - "role": "user", - "content": "Yeah." - }, - ] - template = env.from_string(PromptFactory.create_chat_template("default")) - prompt = template.render( - messages=messages, - bos_token="<|system|>", - eos_token="<|end|>", - add_generation_prompt=True, - ) - assert prompt == "<|system|>\nAlright?<|end|><|user|>\nYeah.<|end|><|assistant|>" - - -def test_create_alpaca_chat_template(): - env = ImmutableSandboxedEnvironment() - messages = [ - { - "role": "system", - "content": "Alright?" - }, - { - "role": "user", - "content": "Yeah." - }, - ] - template = env.from_string(PromptFactory.create_chat_template("alpaca")) - prompt = template.render( - messages=messages, - add_generation_prompt=True, - ) - assert prompt == "### Instruction:\nAlright?\nYeah.\n\n### Response:\n" - - -def test_create_chat_ml_chat_template(): - env = ImmutableSandboxedEnvironment() - messages = [ - { - "role": "system", - "content": "Alright?" - }, - { - "role": "user", - "content": "Yeah." - }, - ] - template = env.from_string(PromptFactory.create_chat_template("chat_ml")) - prompt = template.render( - messages=messages, - add_generation_prompt=True, - ) - assert prompt == "<|im_start|>system\nAlright?<|im_end|>\n<|im_start|>user\nYeah.<|im_end|>\n<|im_start|>assistant\n" - -def test_create_falcon_chat_template(): - env = ImmutableSandboxedEnvironment() - messages = [ - { - "role": "system", - "content": "Alright?" - }, - { - "role": "user", - "content": "Yeah." - }, - ] - template = env.from_string(PromptFactory.create_chat_template("falcon")) - prompt = template.render( - messages=messages, - add_generation_prompt=True, - ) - assert prompt == "Alright?\n\nUser: Yeah.{ '\n\nAssistant:' }}" - -def test_create_gemma_chat_template(): - env = ImmutableSandboxedEnvironment() - messages = [ - { - "role": "system", - "content": "Alright?" - }, - { - "role": "user", - "content": "Yeah." - }, - ] - template = env.from_string(PromptFactory.create_chat_template("gemma")) - prompt = template.render( - messages=messages, - add_generation_prompt=True, - ) - assert prompt == "user\nAlright?\n\nYeah.\nmodel\n" - -def test_create_llama_2_chat_template(): - env = ImmutableSandboxedEnvironment() - messages = [ - { - "role": "system", - "content": "Alright?" - }, - { - "role": "user", - "content": "Yeah." - }, - ] - template = env.from_string(PromptFactory.create_chat_template("LLAMA_2")) - prompt = template.render( - messages=messages, - bos_token="", - eos_token="", - add_generation_prompt=True, - ) - assert prompt == "[INST] <>\nAlright?\n<>\n\nYeah. [/INST]" - -def test_create_llama_3_chat_template(): - env = ImmutableSandboxedEnvironment() - messages = [ - { - "role": "system", - "content": "Alright?" - }, - { - "role": "user", - "content": "Yeah." - }, - ] - template = env.from_string(PromptFactory.create_chat_template("LLAMA_2")) - prompt = template.render( - messages=messages, - bos_token="", - eos_token="", - add_generation_prompt=True, - ) - assert prompt == "[INST] <>\nAlright?\n<>\n\nYeah. [/INST]" - -def test_create_mistral_chat_template(): - env = ImmutableSandboxedEnvironment() - messages = [ - { - "role": "system", - "content": "Alright?" - }, - { - "role": "user", - "content": "Yeah." - }, - ] - template = env.from_string(PromptFactory.create_chat_template("mistral")) - prompt = template.render( - messages=messages, - bos_token="", - eos_token="", - add_generation_prompt=True, - ) - assert prompt == "[INST] Alright?\n\nYeah. [/INST]" - -def test_create_phi_2_chat_template(): - env = ImmutableSandboxedEnvironment() - messages = [ - { - "role": "system", - "content": "Alright?" - }, - { - "role": "user", - "content": "Yeah." - }, - ] - template = env.from_string(PromptFactory.create_chat_template("phi_2")) - prompt = template.render( - messages=messages, - bos_token="<|endoftext|>", - eos_token="<|endoftext|>", - add_generation_prompt=True, - ) - assert prompt == "Instruct: Alright?\n\nYeah.\nOutput:" - -def test_create_phi_3_chat_template(): - env = ImmutableSandboxedEnvironment() - messages = [ - { - "role": "system", - "content": "Alright?" - }, - { - "role": "user", - "content": "Yeah." - }, - ] - template = env.from_string(PromptFactory.create_chat_template("phi_3")) - prompt = template.render( - messages=messages, - bos_token="", - eos_token="<|end|>", - add_generation_prompt=True, - ) - assert prompt == "<|system|>\nAlright?<|end|>\n<|user|>\nYeah.<|end|>\n<|assistant|>\n" - -def test_create_qwen_chat_template(): - env = ImmutableSandboxedEnvironment() - messages = [ - { - "role": "system", - "content": "Alright?" - }, - { - "role": "user", - "content": "Yeah." - }, - ] - template = env.from_string(PromptFactory.create_chat_template("qwen")) - prompt = template.render( - messages=messages, - bos_token="", - eos_token="<|end|>", - add_generation_prompt=True, - ) - assert prompt == "<|im_start|>system\nAlright?<|im_end|>\n<|im_start|>user\nYeah.<|im_end|>\n<|im_start|>assistant\n" From 1a5c2334b4fe110cb92146b261d9ff4865e0e63b Mon Sep 17 00:00:00 2001 From: Xi Bai <82581439+baixiac@users.noreply.github.com> Date: Wed, 17 Sep 2025 15:52:05 +0100 Subject: [PATCH 7/9] Revert "feat: handle the deprecated default chat template" This reverts commit 6c05ee1c5c1e5a3c1daf4c70b6ba1238ed2777df. --- app/api/utils.py | 2 +- app/utils.py | 9 +-------- tests/app/test_utils.py | 21 ++------------------- 3 files changed, 4 insertions(+), 28 deletions(-) diff --git a/app/api/utils.py b/app/api/utils.py index 05726ba..ba99eed 100644 --- a/app/api/utils.py +++ b/app/api/utils.py @@ -352,7 +352,7 @@ async def generate_text( chat_template="{% for message in messages %}\n{% if message['role'] == 'user' %}\nUser: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nAssistant: {{ message['content'] }}\n{% endif %}\n{% endfor %}\nAssistant:", tokenize=True, ) - prompt_obj = TokensPrompt(prompt_token_ids=prompt_tokens) # type: ignore + prompt_obj = TokensPrompt(prompt_token_ids=prompt_tokens) async def _stream() -> AsyncGenerator[bytes, None]: start = 0 diff --git a/app/utils.py b/app/utils.py index ee97d51..23968b2 100644 --- a/app/utils.py +++ b/app/utils.py @@ -756,14 +756,6 @@ def get_prompt_from_messages(tokenizer: PreTrainedTokenizer, messages: List[Prom tokenize=False, add_generation_prompt=True, ) - elif hasattr(tokenizer, "default_chat_template") and tokenizer.default_chat_template: - # This largely depends on how older versions of HF tokenizers behave and may not work universally - tokenizer.chat_template = tokenizer.default_chat_template - prompt = tokenizer.apply_chat_template( - [dump_pydantic_object_to_dict(message) for message in messages], - tokenize=False, - add_generation_prompt=True, - ) else: system_content = "" prompt_parts: List[str] = [] @@ -843,3 +835,4 @@ def get_prompt_from_messages(tokenizer: PreTrainedTokenizer, messages: List[Prom "25624495": '© 2002-2020 International Health Terminology Standards Development Organisation (IHTSDO). All rights reserved. SNOMED CT®, was originally created by The College of American Pathologists. "SNOMED" and "SNOMED CT" are registered trademarks of the IHTSDO.', "55540447": "linkage concept" } + diff --git a/tests/app/test_utils.py b/tests/app/test_utils.py index 2f00e10..470d0d8 100644 --- a/tests/app/test_utils.py +++ b/tests/app/test_utils.py @@ -412,27 +412,10 @@ def test_get_prompt_with_chat_template(): assert prompt == "Mock chat template applied" -def test_get_prompt_with_default_chat_template(): - with patch('transformers.PreTrainedTokenizer') as tok: - mock_tokenizer = tok.return_value - mock_tokenizer.chat_template = None - mock_tokenizer.default_chat_template = "Mock default chat template" - mock_tokenizer.apply_chat_template.return_value = "Mock default chat template applied" - messages = [ - PromptMessage(content="Alright?", role=PromptRole.USER.value), - PromptMessage(content="Yeah.", role=PromptRole.ASSISTANT.value), - ] - - prompt = get_prompt_from_messages(mock_tokenizer, messages) - - assert prompt == "Mock default chat template applied" - - def test_get_prompt_without_chat_template(): with patch('transformers.PreTrainedTokenizer') as tok: mock_tokenizer = tok.return_value mock_tokenizer.chat_template = None - mock_tokenizer.default_chat_template = None messages = [ PromptMessage(content="You are a helpful assistant.", role=PromptRole.SYSTEM.value), PromptMessage(content="Alright?", role=PromptRole.USER.value), @@ -449,9 +432,9 @@ def test_get_prompt_with_no_messages(): with patch('transformers.PreTrainedTokenizer') as tok: mock_tokenizer = tok.return_value mock_tokenizer.chat_template = None - mock_tokenizer.default_chat_template = None messages = [] prompt = get_prompt_from_messages(mock_tokenizer, messages) - assert prompt == "\n<|assistant|>\n" + expected_prompt = "\n<|assistant|>\n" + assert prompt == expected_prompt \ No newline at end of file From a4ff387eaf4bb3dac9152852c47834d7f3dbe45b Mon Sep 17 00:00:00 2001 From: Xi Bai <82581439+baixiac@users.noreply.github.com> Date: Wed, 17 Sep 2025 15:52:05 +0100 Subject: [PATCH 8/9] Revert "feat: add metrics for usages of prompt and completion tokens" This reverts commit b338b4f827d107aaab27ac9bed060664171f8dbc. --- app/api/routers/generative.py | 67 +++++-------------- app/domain.py | 10 +-- app/management/prometheus_metrics.py | 21 ------ app/model_services/huggingface_llm_model.py | 25 +------ app/utils.py | 3 +- .../test_huggingface_llm_model.py | 12 ---- 6 files changed, 24 insertions(+), 114 deletions(-) diff --git a/app/api/routers/generative.py b/app/api/routers/generative.py index 763c717..5f9ff9f 100644 --- a/app/api/routers/generative.py +++ b/app/api/routers/generative.py @@ -6,7 +6,6 @@ from typing import Union, Iterable, AsyncGenerator from typing_extensions import Annotated -from functools import partial from fastapi import APIRouter, Depends, Request, Body, Query from fastapi.encoders import jsonable_encoder from fastapi.responses import PlainTextResponse, StreamingResponse, JSONResponse @@ -16,7 +15,6 @@ from app.utils import get_settings, get_prompt_from_messages from app.api.utils import get_rate_limiter from app.api.dependencies import validate_tracking_id -from app.management.prometheus_metrics import cms_prompt_tokens, cms_completion_tokens, cms_total_tokens PATH_GENERATE = "/generate" PATH_GENERATE_ASYNC = "/stream/generate" @@ -46,7 +44,7 @@ def generate_text( model_service: AbstractModelService = Depends(cms_globals.model_service_dep) ) -> PlainTextResponse: """ - Generates text based on the prompt provided. + Generate text based on the prompt provided. Args: request (Request): The request object. @@ -63,12 +61,7 @@ def generate_text( tracking_id = tracking_id or str(uuid.uuid4()) if prompt: return PlainTextResponse( - model_service.generate( - prompt, - max_tokens=max_tokens, - temperature=temperature, - report_tokens=partial(_send_usage_metrics, handler=PATH_GENERATE), - ), + model_service.generate(prompt, max_tokens=max_tokens, temperature=temperature), headers={"x-cms-tracking-id": tracking_id}, status_code=HTTP_200_OK, ) @@ -96,7 +89,7 @@ async def generate_text_stream( model_service: AbstractModelService = Depends(cms_globals.model_service_dep) ) -> StreamingResponse: """ - Generates a stream of texts in near real-time. + Generate a stream of texts in near real-time. Args: request (Request): The request object. @@ -113,12 +106,7 @@ async def generate_text_stream( tracking_id = tracking_id or str(uuid.uuid4()) if prompt: return StreamingResponse( - model_service.generate_async( - prompt, - max_tokens=max_tokens, - temperature=temperature, - report_tokens=partial(_send_usage_metrics, handler=PATH_GENERATE_ASYNC), - ), + model_service.generate_async(prompt, max_tokens=max_tokens, temperature=temperature), media_type="text/event-stream", headers={"x-cms-tracking-id": tracking_id}, status_code=HTTP_200_OK, @@ -133,7 +121,7 @@ async def generate_text_stream( @router.post( - PATH_OPENAI_COMPLETIONS, + "/v1/chat/completions", tags=[Tags.Generative.name], response_model=None, dependencies=[Depends(cms_globals.props.current_active_user)], @@ -148,7 +136,7 @@ def generate_chat_completions( model_service: AbstractModelService = Depends(cms_globals.model_service_dep) ) -> Union[StreamingResponse, JSONResponse]: """ - Generates chat response based on messages, mimicking OpenAI's /v1/chat/completions endpoint. + Generate chat response based on messages, mimicking OpenAI's /v1/chat/completions endpoint. Args: request (Request): The request object. @@ -165,7 +153,6 @@ def generate_chat_completions( stream = request_data.stream max_tokens = request_data.max_tokens temperature = request_data.temperature - tracking_id = tracking_id or str(uuid.uuid4()) if not messages: error_response = { @@ -176,25 +163,16 @@ def generate_chat_completions( "code": "missing_field", } } - return JSONResponse( - content=error_response, - status_code=HTTP_400_BAD_REQUEST, - headers={"x-cms-tracking-id": tracking_id}, - ) + return JSONResponse(content=error_response, status_code=HTTP_400_BAD_REQUEST) - async def _stream(prompt: str, max_tokens: int, temperature: float) -> AsyncGenerator: + async def _stream(p: str, mt: int, t: float) -> AsyncGenerator: data = { - "id": tracking_id, + "id": tracking_id or str(uuid.uuid4()), "object": "chat.completion.chunk", "choices": [{"delta": {"role": PromptRole.ASSISTANT.value}}], } yield f"data: {json.dumps(data)}\n\n" - async for chunk in model_service.generate_async( - prompt, - max_tokens=max_tokens, - temperature=temperature, - report_tokens=partial(_send_usage_metrics, handler=PATH_OPENAI_COMPLETIONS) - ): + async for chunk in model_service.generate_async(p, max_tokens=mt, temperature=t): data = { "choices": [ { @@ -210,18 +188,12 @@ async def _stream(prompt: str, max_tokens: int, temperature: float) -> AsyncGene if stream: return StreamingResponse( _stream(prompt, max_tokens, temperature), - media_type="text/event-stream", - headers={"x-cms-tracking-id": tracking_id}, + media_type="text/event-stream" ) else: - generated_text = model_service.generate( - prompt, - max_tokens=max_tokens, - temperature=temperature, - send_metrics=partial(_send_usage_metrics, handler=PATH_OPENAI_COMPLETIONS), - ) + generated_text = model_service.generate(prompt, max_tokens=max_tokens, temperature=temperature) completion = OpenAIChatResponse( - id=tracking_id, + id=str(uuid.uuid4()), object="chat.completion", created=int(time.time()), model=model_service.model_name, @@ -234,19 +206,10 @@ async def _stream(prompt: str, max_tokens: int, temperature: float) -> AsyncGene ), "finish_reason": "stop", } - ], + ] ) - return JSONResponse(content=jsonable_encoder(completion), headers={"x-cms-tracking-id": tracking_id}) + return JSONResponse(content=jsonable_encoder(completion)) def _empty_prompt_error() -> Iterable[str]: yield "ERROR: No prompt text provided\n" - - -def _send_usage_metrics(handler: str, prompt_token_num: int, completion_token_num: int) -> None: - cms_prompt_tokens.labels(handler=handler).observe(prompt_token_num) - logger.debug(f"Sent prompt tokens usage: {prompt_token_num}") - cms_completion_tokens.labels(handler=handler).observe(completion_token_num) - logger.debug(f"Sent completion tokens usage: {completion_token_num}") - cms_total_tokens.labels(handler=handler).observe(prompt_token_num + completion_token_num) - logger.debug(f"Sent total tokens usage: {prompt_token_num + completion_token_num}") diff --git a/app/domain.py b/app/domain.py index 2ddd6e3..29a779a 100644 --- a/app/domain.py +++ b/app/domain.py @@ -189,8 +189,8 @@ class OpenAIChatRequest(BaseModel): class OpenAIChatResponse(BaseModel): - id: str = Field(..., description="The unique identifier for the chat completion request") - object: str = Field(..., description="The type of the response") - created: int = Field(..., description="The timestamp when the completion was generated") - model: str = Field(..., description="The name of the model used for generating the completion") - choices: List = Field(..., description="The generated messages and their metadata") + id: str + object: str + created: int + model: str + choices: List diff --git a/app/management/prometheus_metrics.py b/app/management/prometheus_metrics.py index 78c5698..3f48858 100644 --- a/app/management/prometheus_metrics.py +++ b/app/management/prometheus_metrics.py @@ -34,24 +34,3 @@ "Number of bulk-processed documents", ["handler"], ) - -# The histogram metric to track the number of tokens in the messages of the input prompt -cms_prompt_tokens = Histogram( - "cms_prompt_tokens", - "Number of tokens in the messages of the input prompt", - ["handler"], -) - -# The histogram metric to track the number of tokens in the generated assistant reply -cms_completion_tokens = Histogram( - "cms_completion_tokens", - "Number of tokens in the generated assistant reply", - ["handler"], -) - -# The histogram metric to track the total number of tokens used in the prompt and the completion -cms_total_tokens = Histogram( - "cms_total_tokens", - "Number of tokens used in the prompt and the completion", - ["handler"], -) diff --git a/app/model_services/huggingface_llm_model.py b/app/model_services/huggingface_llm_model.py index 7340f67..848b02c 100644 --- a/app/model_services/huggingface_llm_model.py +++ b/app/model_services/huggingface_llm_model.py @@ -2,7 +2,7 @@ import logging import asyncio from concurrent.futures import ThreadPoolExecutor -from typing import Dict, List, Optional, Tuple, Any, AsyncIterable, Callable +from typing import Dict, List, Optional, Tuple, Any, AsyncIterable from transformers import ( AutoModelForCausalLM, AutoTokenizer, @@ -198,7 +198,6 @@ def generate( prompt: str, max_tokens: int = 512, temperature: float = 0.7, - report_tokens: Optional[Callable[[str], None]] = None, **kwargs: Any ) -> str: """ @@ -208,7 +207,6 @@ def generate( prompt (str): The prompt for the text generation max_tokens (int): The maximum number of tokens to generate. Defaults to 512. temperature (float): The temperature for the text generation. Defaults to 0.7. - report_tokens (Optional[Callable[[str], None]]): The callback function to send metrics. Defaults to None. **kwargs (Any): Additional keyword arguments to be passed to this method. Returns: @@ -232,13 +230,9 @@ def generate( outputs = self.model.generate(**generation_kwargs) generated_text = self.tokenizer.decode(outputs[0], skip_prompt=True, skip_special_tokens=True) - logger.debug("Response generation completed") - if report_tokens: - report_tokens( - prompt_token_num=inputs.input_ids.shape[-1], # type: ignore - completion_token_num=outputs[0].shape[-1], # type: ignore - ) + + logger.debug("Response generation completed") return generated_text @@ -247,7 +241,6 @@ async def generate_async( prompt: str, max_tokens: int = 512, temperature: float = 0.7, - report_tokens: Optional[Callable[[str], None]] = None, **kwargs: Any ) -> AsyncIterable: """ @@ -257,7 +250,6 @@ async def generate_async( prompt (str): The prompt for the text generation. max_tokens (int): The maximum number of tokens to generate. Defaults to 512. temperature (float): The temperature for the text generation. Defaults to 0.7. - report_tokens (Optional[Callable[[str], None]]): The callback function to send metrics. Defaults to None. **kwargs (Any): Additional keyword arguments to be passed to the model loader. Returns: @@ -287,20 +279,9 @@ async def generate_async( try: _ = self._text_generator.submit(self.model.generate, **generation_kwargs) - output = "" for content in streamer: yield content - output += content await asyncio.sleep(0.01) - if report_tokens: - report_tokens( - prompt_token_num=inputs.input_ids.shape[-1], # type: ignore - completion_token_num=self.tokenizer( # type: ignore - output, - add_special_tokens=False, - return_tensors="pt" - ).input_ids.shape[-1], - ) except Exception as e: logger.error("An error occurred while generating the response") logger.exception(e) diff --git a/app/utils.py b/app/utils.py index 23968b2..705539b 100644 --- a/app/utils.py +++ b/app/utils.py @@ -694,7 +694,7 @@ def dump_pydantic_object_to_dict(model: BaseModel) -> Dict: """ if hasattr(model, "model_dump"): - return model.model_dump(mode="json") # type: ignore + return model.model_dump() # type: ignore elif hasattr(model, "dict"): return model.dict() # type: ignore else: @@ -835,4 +835,3 @@ def get_prompt_from_messages(tokenizer: PreTrainedTokenizer, messages: List[Prom "25624495": '© 2002-2020 International Health Terminology Standards Development Organisation (IHTSDO). All rights reserved. SNOMED CT®, was originally created by The College of American Pathologists. "SNOMED" and "SNOMED CT" are registered trademarks of the IHTSDO.', "55540447": "linkage concept" } - diff --git a/tests/app/model_services/test_huggingface_llm_model.py b/tests/app/model_services/test_huggingface_llm_model.py index 7cd4941..0262c41 100644 --- a/tests/app/model_services/test_huggingface_llm_model.py +++ b/tests/app/model_services/test_huggingface_llm_model.py @@ -46,7 +46,6 @@ def test_generate(huggingface_llm_model): huggingface_llm_model.init_model() huggingface_llm_model.model = MagicMock() huggingface_llm_model.tokenizer = MagicMock() - mock_send_metrics = MagicMock() inputs = MagicMock() inputs.input_ids = MagicMock(shape=[1, 2]) inputs.attention_mask = MagicMock() @@ -59,7 +58,6 @@ def test_generate(huggingface_llm_model): prompt="Alright?", max_tokens=128, temperature=0.5, - report_tokens=mock_send_metrics ) huggingface_llm_model.tokenizer.assert_called_once_with( @@ -80,10 +78,6 @@ def test_generate(huggingface_llm_model): skip_prompt=True, skip_special_tokens=True, ) - mock_send_metrics.assert_called_once_with( - prompt_token_num=2, - completion_token_num=2, - ) assert result == "Yeah." @@ -91,7 +85,6 @@ async def test_generate_async(huggingface_llm_model): huggingface_llm_model.init_model() huggingface_llm_model.model = MagicMock() huggingface_llm_model.tokenizer = MagicMock() - mock_send_metrics = MagicMock() inputs = MagicMock() inputs.input_ids = MagicMock(shape=[1, 2]) inputs.attention_mask = MagicMock() @@ -104,7 +97,6 @@ async def test_generate_async(huggingface_llm_model): prompt="Alright?", max_tokens=128, temperature=0.5, - report_tokens=mock_send_metrics ) huggingface_llm_model.tokenizer.assert_called_once_with( @@ -125,8 +117,4 @@ async def test_generate_async(huggingface_llm_model): skip_prompt=True, skip_special_tokens=True, ) - mock_send_metrics.assert_called_once_with( - prompt_token_num=2, - completion_token_num=2, - ) assert result == "Yeah." \ No newline at end of file From 3b043eea03316f6f951a00d15cecb54b18ac7b32 Mon Sep 17 00:00:00 2001 From: Xi Bai <82581439+baixiac@users.noreply.github.com> Date: Wed, 17 Sep 2025 15:52:05 +0100 Subject: [PATCH 9/9] Revert "feat: add the endpoint compatible with OpenAI client protocols" This reverts commit bcea8fe447795ea0fd4cd8d61904b877e518d42e. --- app/api/routers/generative.py | 150 +----------------- app/api/utils.py | 39 ++--- app/domain.py | 27 ---- app/model_services/huggingface_llm_model.py | 26 +-- app/utils.py | 57 +------ pyproject.toml | 4 +- tests/app/api/test_serving_hf_llm.py | 64 +------- .../test_huggingface_llm_model.py | 77 +-------- tests/app/test_utils.py | 49 +----- uv.lock | 133 +++------------- 10 files changed, 60 insertions(+), 566 deletions(-) diff --git a/app/api/routers/generative.py b/app/api/routers/generative.py index 5f9ff9f..007eb0b 100644 --- a/app/api/routers/generative.py +++ b/app/api/routers/generative.py @@ -1,24 +1,16 @@ -import json import logging -import time -import uuid import app.api.globals as cms_globals -from typing import Union, Iterable, AsyncGenerator from typing_extensions import Annotated from fastapi import APIRouter, Depends, Request, Body, Query -from fastapi.encoders import jsonable_encoder -from fastapi.responses import PlainTextResponse, StreamingResponse, JSONResponse -from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST -from app.domain import Tags, OpenAIChatRequest, OpenAIChatResponse, PromptMessage, PromptRole +from fastapi.responses import PlainTextResponse, StreamingResponse +from app.domain import Tags from app.model_services.base import AbstractModelService -from app.utils import get_settings, get_prompt_from_messages +from app.utils import get_settings from app.api.utils import get_rate_limiter -from app.api.dependencies import validate_tracking_id PATH_GENERATE = "/generate" PATH_GENERATE_ASYNC = "/stream/generate" -PATH_OPENAI_COMPLETIONS = "/v1/chat/completions" router = APIRouter() config = get_settings() @@ -39,8 +31,6 @@ def generate_text( request: Request, prompt: Annotated[str, Body(description="The prompt to be sent to the model", media_type="text/plain")], max_tokens: Annotated[int, Query(description="The maximum number of tokens to generate", gt=0)] = 512, - temperature: Annotated[float, Query(description="The temperature of the generated text", gt=0.0, lt=1.0)] = 0.7, - tracking_id: Union[str, None] = Depends(validate_tracking_id), model_service: AbstractModelService = Depends(cms_globals.model_service_dep) ) -> PlainTextResponse: """ @@ -50,27 +40,13 @@ def generate_text( request (Request): The request object. prompt (str): The prompt to be sent to the model. max_tokens (int): The maximum number of tokens to generate. - temperature (float): The temperature of the generated text. - tracking_id (Union[str, None]): An optional tracking ID of the requested task. model_service (AbstractModelService): The model service dependency. Returns: PlainTextResponse: A response containing the generated text. """ - tracking_id = tracking_id or str(uuid.uuid4()) - if prompt: - return PlainTextResponse( - model_service.generate(prompt, max_tokens=max_tokens, temperature=temperature), - headers={"x-cms-tracking-id": tracking_id}, - status_code=HTTP_200_OK, - ) - else: - return PlainTextResponse( - _empty_prompt_error(), - headers={"x-cms-tracking-id": tracking_id}, - status_code=HTTP_400_BAD_REQUEST, - ) + return PlainTextResponse(model_service.generate(prompt, max_tokens=max_tokens)) @router.post( @@ -84,8 +60,6 @@ async def generate_text_stream( request: Request, prompt: Annotated[str, Body(description="The prompt to be sent to the model", media_type="text/plain")], max_tokens: Annotated[int, Query(description="The maximum number of tokens to generate", gt=0)] = 512, - temperature: Annotated[float, Query(description="The temperature of the generated text", gt=0.0, lt=1.0)] = 0.7, - tracking_id: Union[str, None] = Depends(validate_tracking_id), model_service: AbstractModelService = Depends(cms_globals.model_service_dep) ) -> StreamingResponse: """ @@ -95,121 +69,13 @@ async def generate_text_stream( request (Request): The request object. prompt (str): The prompt to be sent to the model. max_tokens (int): The maximum number of tokens to generate. - temperature (float): The temperature of the generated text. - tracking_id (Union[str, None]): An optional tracking ID of the requested task. model_service (AbstractModelService): The model service dependency. Returns: StreamingResponse: A streaming response containing the text generated in near real-time. """ - tracking_id = tracking_id or str(uuid.uuid4()) - if prompt: - return StreamingResponse( - model_service.generate_async(prompt, max_tokens=max_tokens, temperature=temperature), - media_type="text/event-stream", - headers={"x-cms-tracking-id": tracking_id}, - status_code=HTTP_200_OK, - ) - else: - return StreamingResponse( - _empty_prompt_error(), - media_type="text/event-stream", - headers={"x-cms-tracking-id": tracking_id}, - status_code=HTTP_400_BAD_REQUEST, - ) - - -@router.post( - "/v1/chat/completions", - tags=[Tags.Generative.name], - response_model=None, - dependencies=[Depends(cms_globals.props.current_active_user)], - description="Generate chat response based on messages, similar to OpenAI's /v1/chat/completions", -) -def generate_chat_completions( - request: Request, - request_data: Annotated[OpenAIChatRequest, Body( - description="OpenAI-like completion request", media_type="application/json" - )], - tracking_id: Union[str, None] = Depends(validate_tracking_id), - model_service: AbstractModelService = Depends(cms_globals.model_service_dep) -) -> Union[StreamingResponse, JSONResponse]: - """ - Generate chat response based on messages, mimicking OpenAI's /v1/chat/completions endpoint. - - Args: - request (Request): The request object. - request_data (OpenAIChatRequest): The request data containing model, messages, and stream. - tracking_id (Union[str, None]): An optional tracking ID of the requested task. - model_service (AbstractModelService): The model service dependency. - - Returns: - StreamingResponse: A OpenAI-like response containing the text generated in near real-time. - JSONResponse: A response containing an error message if the prompt messages are empty. - """ - - messages = request_data.messages - stream = request_data.stream - max_tokens = request_data.max_tokens - temperature = request_data.temperature - - if not messages: - error_response = { - "error": { - "message": "No prompt messages provided", - "type": "invalid_request_error", - "param": "messages", - "code": "missing_field", - } - } - return JSONResponse(content=error_response, status_code=HTTP_400_BAD_REQUEST) - - async def _stream(p: str, mt: int, t: float) -> AsyncGenerator: - data = { - "id": tracking_id or str(uuid.uuid4()), - "object": "chat.completion.chunk", - "choices": [{"delta": {"role": PromptRole.ASSISTANT.value}}], - } - yield f"data: {json.dumps(data)}\n\n" - async for chunk in model_service.generate_async(p, max_tokens=mt, temperature=t): - data = { - "choices": [ - { - "delta": {"content": chunk} - } - ], - "object": "chat.completion.chunk", - } - yield f"data: {json.dumps(data)}\n\n" - yield "data: [DONE]\n\n" - - prompt = get_prompt_from_messages(model_service.tokenizer, messages) # type: ignore - if stream: - return StreamingResponse( - _stream(prompt, max_tokens, temperature), - media_type="text/event-stream" - ) - else: - generated_text = model_service.generate(prompt, max_tokens=max_tokens, temperature=temperature) - completion = OpenAIChatResponse( - id=str(uuid.uuid4()), - object="chat.completion", - created=int(time.time()), - model=model_service.model_name, - choices=[ - { - "index": 0, - "message": PromptMessage( - role=PromptRole.ASSISTANT, - content=generated_text, - ), - "finish_reason": "stop", - } - ] - ) - return JSONResponse(content=jsonable_encoder(completion)) - - -def _empty_prompt_error() -> Iterable[str]: - yield "ERROR: No prompt text provided\n" + return StreamingResponse( + model_service.generate_async(prompt, max_tokens=max_tokens), + media_type="text/event-stream" + ) diff --git a/app/api/utils.py b/app/api/utils.py index ba99eed..a14714c 100644 --- a/app/api/utils.py +++ b/app/api/utils.py @@ -286,7 +286,6 @@ async def init_vllm_engine(app: FastAPI, """ try: - # Import necessary vLLM components from vllm.utils import FlexibleArgumentParser from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args @@ -299,19 +298,16 @@ async def init_vllm_engine(app: FastAPI, ) from vllm import SamplingParams, TokensPrompt except ImportError: - # Raise a custom exception if vLLM is not installed - raise ConfigurationException("Cannot import the vLLM engine. Please install it with `pip install vllm`.") + logger.error("Cannot import the vLLM engine. Please install it with `pip install cms[vllm]`.") parser = FlexibleArgumentParser() parser = make_arg_parser(parser) args = parser.parse_args([]) validate_parsed_serve_args(args) - args.model = model_dir_path args.dtype = "float16" args.served_model_name = [model_name] - args.max_model_len = 2048 # The default batched length (2048) needs to be higher than max_model_len. - # args.tokenizer = model_dir_path # Uncomment if your tokenizer is in a different path or needs explicit setting. + # args.tokenizer = model_dir_path args.log_level = log_level exit_stack = contextlib.AsyncExitStack() @@ -321,11 +317,9 @@ async def init_vllm_engine(app: FastAPI, disable_frontend_multiprocessing=True, ) ) - tokenizer = await engine.get_tokenizer() vllm_config = await engine.get_vllm_config() model_config = await engine.get_model_config() - await init_app_state(engine, vllm_config, app.state, args) async def generate_text( @@ -333,32 +327,27 @@ async def generate_text( prompt: Annotated[str, Body(description="The prompt to be sent to the model", media_type="text/plain")], max_tokens: Annotated[int, Query(description="The maximum number of tokens to generate", gt=0)] = 512 ) -> StreamingResponse: - """ - Custom endpoint for streaming text generation. - This endpoint takes a raw text prompt and streams back the generated text. - It applies a chat template to the prompt internally for model compatibility. - """ messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}] params = SamplingParams(max_tokens=max_tokens) - conversation, _ = parse_chat_messages(messages, model_config, tokenizer, content_format="string") # type: ignore - prompt_tokens = apply_hf_chat_template( # type: ignore - tokenizer, - conversation=conversation, - tools=None, - add_generation_prompt=True, - continue_final_message=False, - chat_template="{% for message in messages %}\n{% if message['role'] == 'user' %}\nUser: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nAssistant: {{ message['content'] }}\n{% endif %}\n{% endfor %}\nAssistant:", - tokenize=True, + prompt = TokensPrompt( + prompt_token_ids=apply_hf_chat_template( # type: ignore + tokenizer, + conversation=conversation, + tools=None, + add_generation_prompt=True, + continue_final_message=False, + chat_template="{% for message in messages %}\n{% if message['role'] == 'user' %}\nUser: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nAssistant: {{ message['content'] }}\n{% endif %}\n{% endfor %}\nAssistant:", + tokenize=True, + ) ) - prompt_obj = TokensPrompt(prompt_token_ids=prompt_tokens) async def _stream() -> AsyncGenerator[bytes, None]: start = 0 - async for output in engine.generate(request_id=uuid.uuid4().hex, prompt=prompt_obj, sampling_params=params): + async for output in engine.generate(request_id=uuid.uuid4().hex, prompt=prompt, sampling_params=params): text = output.outputs[0].text - yield text[start:].encode("utf-8") + yield text[start:] # type: ignore start = len(text) return StreamingResponse(_stream(), media_type="text/event-stream") diff --git a/app/domain.py b/app/domain.py index 29a779a..c9d38cf 100644 --- a/app/domain.py +++ b/app/domain.py @@ -167,30 +167,3 @@ class Doc(BaseModel): text: str = Field(description="The text from which the entities are extracted") ents: List[Entity] = Field(description="The list of extracted entities") title: Optional[str] = Field(default=None, description="The headline of the text") - - -class PromptRole(Enum): - SYSTEM = "system" - USER = "user" - ASSISTANT = "assistant" - TOOL = "tool" - - -class PromptMessage(BaseModel): - role: PromptRole = Field(description="The role who generates the message") - content: str = Field(description="The actual text of the message") - - -class OpenAIChatRequest(BaseModel): - messages: List[PromptMessage] = Field(..., description="A list of messages to be sent to the model") - stream: bool = Field(..., description="Whether to stream the response") - max_tokens: int = Field(512, description="The maximum number of tokens to generate", gt=0) - temperature: float = Field(0.7, description="The temperature of the generated text", ge=0.0, le=1.0) - - -class OpenAIChatResponse(BaseModel): - id: str - object: str - created: int - model: str - choices: List diff --git a/app/model_services/huggingface_llm_model.py b/app/model_services/huggingface_llm_model.py index 848b02c..566bbf9 100644 --- a/app/model_services/huggingface_llm_model.py +++ b/app/model_services/huggingface_llm_model.py @@ -168,8 +168,6 @@ def init_model(self) -> None: logger.warning("Model service is already initialised and can be initialised only once") else: self._model, self._tokenizer = self.load_model(self._model_pack_path) - if non_default_device_is_available(get_settings().DEVICE): - self._model.to(get_settings().DEVICE) if self._enable_trainer: logger.error("Trainers are not yet implemented for HuggingFace Generative models") @@ -193,20 +191,13 @@ def annotate(self, text: str) -> List[Annotation]: def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]: raise NotImplementedError("Batch annotation is not yet implemented for HuggingFace Generative models") - def generate( - self, - prompt: str, - max_tokens: int = 512, - temperature: float = 0.7, - **kwargs: Any - ) -> str: + def generate(self, prompt: str, max_tokens: int = 512, **kwargs: Any) -> str: """ Generates text based on the prompt. Args: prompt (str): The prompt for the text generation max_tokens (int): The maximum number of tokens to generate. Defaults to 512. - temperature (float): The temperature for the text generation. Defaults to 0.7. **kwargs (Any): Additional keyword arguments to be passed to this method. Returns: @@ -223,8 +214,8 @@ def generate( inputs=inputs.input_ids, attention_mask=inputs.attention_mask, max_new_tokens=max_tokens, - do_sample=False, - temperature=temperature, + do_sample=True, + temperature=0.7, top_p=0.9, ) @@ -236,20 +227,13 @@ def generate( return generated_text - async def generate_async( - self, - prompt: str, - max_tokens: int = 512, - temperature: float = 0.7, - **kwargs: Any - ) -> AsyncIterable: + async def generate_async(self, prompt: str, max_tokens: int = 512, **kwargs: Any) -> AsyncIterable: """ Asynchronously generates text stream based on the prompt. Args: prompt (str): The prompt for the text generation. max_tokens (int): The maximum number of tokens to generate. Defaults to 512. - temperature (float): The temperature for the text generation. Defaults to 0.7. **kwargs (Any): Additional keyword arguments to be passed to the model loader. Returns: @@ -273,7 +257,7 @@ async def generate_async( streamer=streamer, max_new_tokens=max_tokens, do_sample=True, - temperature=temperature, + temperature=0.7, top_p=0.9, ) diff --git a/app/utils.py b/app/utils.py index 705539b..245da97 100644 --- a/app/utils.py +++ b/app/utils.py @@ -20,12 +20,12 @@ from spacy.lang.en import English from spacy.util import filter_spans from safetensors.torch import load_file -from transformers import PreTrainedModel, PreTrainedTokenizer +from transformers import PreTrainedModel from urllib.parse import ParseResult from functools import lru_cache from typing import List, Optional, Dict, Callable, Any, Union, Type, TypeVar from app.config import Settings -from app.domain import Annotation, Entity, CodeType, ModelType, Device, PromptMessage, PromptRole +from app.domain import Annotation, Entity, CodeType, ModelType, Device from app.exception import ManagedModelException @@ -682,24 +682,6 @@ def load_pydantic_object_from_dict(model: Type[T], obj: Dict) -> T: raise TypeError("Model must have a known method for parsing objects.") -def dump_pydantic_object_to_dict(model: BaseModel) -> Dict: - """ - Dumps the pydantic model object to a dictionary. - - Args: - model (BaseModel): The pydantic model to dump. - - Returns: - Dict: The dictionary object. - """ - - if hasattr(model, "model_dump"): - return model.model_dump() # type: ignore - elif hasattr(model, "dict"): - return model.dict() # type: ignore - else: - raise TypeError("Model must have a known method for dumping objects.") - def download_model_package( model_package_url: str, destination_path: str, @@ -739,41 +721,6 @@ def download_model_package( retry_delay *= 2 -def get_prompt_from_messages(tokenizer: PreTrainedTokenizer, messages: List[PromptMessage]) -> str: - """ - Generates a prompt from a list of prompt messages. - - Args: - tokenizer (PreTrainedTokenizer): The tokenizer to use for applying the chat template. - messages (List[PromptMessage]): The list of prompt messages to use for generating the prompt. - - Returns: - str: The generated prompt. - """ - if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template: - prompt = tokenizer.apply_chat_template( - [dump_pydantic_object_to_dict(message) for message in messages], - tokenize=False, - add_generation_prompt=True, - ) - else: - system_content = "" - prompt_parts: List[str] = [] - for message in messages: - content = message.content.strip() - if message.role == PromptRole.SYSTEM: - system_content = content - elif message.role == PromptRole.USER: - prompt_parts.append(f"<|user|>\n{content}") - elif message.role == PromptRole.ASSISTANT: - prompt_parts.append(f"<|assistant|>\n{content}") - if system_content: - prompt = f"<|system|>\n{system_content}\n" + "\n".join(prompt_parts) - else: - prompt = "\n".join(prompt_parts) - prompt += "\n<|assistant|>\n" - return prompt - TYPE_ID_TO_NAME_PATCH = { "32816260": "physical object", "2680757": "observable entity", diff --git a/pyproject.toml b/pyproject.toml index eaba606..01b8d3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ "pynvml~=11.5.3", "toml~=0.10.2", "peft<0.14.0", - "huggingface-hub~=0.33.0", + "huggingface-hub~=0.32.0", ] readme = "README.md" keywords = ["natural-language-processing", "electronic-health-records", "clinical-data"] @@ -71,7 +71,6 @@ dev = [ "locust<2.32.0", "typer-cli~=0.15.1", "types-toml==0.10.8.20240310", - "openai>=1.84.0", ] docs = [ "sphinx~=7.1.2", @@ -100,7 +99,6 @@ dev = [ "locust<2.32.0", "typer-cli~=0.15.1", "types-toml==0.10.8.20240310", - "openai>=1.84.0", ] docs = [ "sphinx~=7.1.2", diff --git a/tests/app/api/test_serving_hf_llm.py b/tests/app/api/test_serving_hf_llm.py index 4bd4e9d..39e82bf 100644 --- a/tests/app/api/test_serving_hf_llm.py +++ b/tests/app/api/test_serving_hf_llm.py @@ -1,10 +1,8 @@ import httpx -import json import pytest import app.api.globals as cms_globals from unittest.mock import create_autospec -from fastapi.testclient import TestClient from app.api.api import get_generative_server from app.model_services.huggingface_llm_model import HuggingFaceLlmModel from app.utils import get_settings @@ -29,72 +27,14 @@ def llm_app(llm_model_service): yield app app.dependency_overrides.clear() -@pytest.fixture(scope="function") -def client(llm_model_service): - llm_model_service.generate.return_value = "Yeah." - app = get_generative_server(config, msd_overwritten=lambda: llm_model_service) - app.dependency_overrides[cms_globals.props.current_active_user] = lambda: None - client = TestClient(app) - yield client - client.app.dependency_overrides.clear() - - -def test_generate(client): - response = client.post( - "/generate?max_tokens=128&temperature=0.7", - data="Alright?", - headers={"Content-Type": "text/plain"}, - ) - - assert response.status_code == 200 - assert response.headers["x-cms-tracking-id"], "x-cms-tracking-id header is missing" - assert response.headers["content-type"] == "text/plain; charset=utf-8" - assert response.text == "Yeah." - @pytest.mark.asyncio async def test_stream_generate(llm_model_service, llm_app): - llm_model_service.generate_async.return_value = "Fine." async with httpx.AsyncClient(app=llm_app, base_url="http://test") as ac: response = await ac.post( - "/stream/generate?max_tokens=32&temperature=0.7", + "/stream/generate?max_tokens=32", data="How are you doing?", headers={"Content-Type": "text/plain"}, ) - assert response.status_code == 200 - assert response.headers["x-cms-tracking-id"], "x-cms-tracking-id header is missing" - assert response.headers["content-type"] == "text/event-stream; charset=utf-8" - assert response.text == "Fine." - - -@pytest.mark.asyncio -async def test_generate_chat_completions(llm_model_service, llm_app): - llm_model_service.generate.return_value = "I'm a chat bot." - request_data = { - "messages": [ - { - "role": "system", - "content": "You are a chat bot." - }, - { - "role": "user", - "content": "Who are you?" - } - ], - "stream": True, - "max_tokens": 128, - "temperature": 0.7 - } - async with httpx.AsyncClient(app=llm_app, base_url="http://test") as ac: - response = await ac.post( - "/v1/chat/completions?max_tokens=128&temperature=0.7", - data=json.dumps(request_data), - headers={"Content-Type": "application/json"}, - ) - - assert response.status_code == 200 - assert response.headers["content-type"] == "text/event-stream; charset=utf-8" - assert response.text.startswith("data:") - assert "id" in response.text - assert "chat.completion.chunk" in response.text + assert response.status_code == 200 \ No newline at end of file diff --git a/tests/app/model_services/test_huggingface_llm_model.py b/tests/app/model_services/test_huggingface_llm_model.py index 0262c41..1f33572 100644 --- a/tests/app/model_services/test_huggingface_llm_model.py +++ b/tests/app/model_services/test_huggingface_llm_model.py @@ -1,5 +1,4 @@ import os -from unittest.mock import MagicMock from tests.app.conftest import MODEL_PARENT_DIR from transformers import PreTrainedModel, PreTrainedTokenizerBase from app import __version__ @@ -44,77 +43,5 @@ def test_info(huggingface_llm_model): def test_generate(huggingface_llm_model): huggingface_llm_model.init_model() - huggingface_llm_model.model = MagicMock() - huggingface_llm_model.tokenizer = MagicMock() - inputs = MagicMock() - inputs.input_ids = MagicMock(shape=[1, 2]) - inputs.attention_mask = MagicMock() - huggingface_llm_model.tokenizer.return_value = inputs - outputs = [MagicMock(shape=[2])] - huggingface_llm_model.model.generate.return_value = outputs - huggingface_llm_model.tokenizer.decode.return_value = "Yeah." - - result = huggingface_llm_model.generate( - prompt="Alright?", - max_tokens=128, - temperature=0.5, - ) - - huggingface_llm_model.tokenizer.assert_called_once_with( - "Alright?", - add_special_tokens=False, - return_tensors="pt", - ) - huggingface_llm_model.model.generate.assert_called_once_with( - inputs=inputs.input_ids, - attention_mask=inputs.attention_mask, - max_new_tokens=128, - do_sample=False, - temperature=0.5, - top_p=0.9, - ) - huggingface_llm_model.tokenizer.decode.assert_called_once_with( - outputs[0], - skip_prompt=True, - skip_special_tokens=True, - ) - assert result == "Yeah." - - -async def test_generate_async(huggingface_llm_model): - huggingface_llm_model.init_model() - huggingface_llm_model.model = MagicMock() - huggingface_llm_model.tokenizer = MagicMock() - inputs = MagicMock() - inputs.input_ids = MagicMock(shape=[1, 2]) - inputs.attention_mask = MagicMock() - huggingface_llm_model.tokenizer.return_value = inputs - outputs = [MagicMock(shape=[2])] - huggingface_llm_model.model.generate.return_value = outputs - huggingface_llm_model.tokenizer.decode.return_value = "Yeah." - - result = await huggingface_llm_model.generate_async( - prompt="Alright?", - max_tokens=128, - temperature=0.5, - ) - - huggingface_llm_model.tokenizer.assert_called_once_with( - "Alright?", - add_special_tokens=False, - return_tensors="pt", - ) - huggingface_llm_model.model.generate_async.assert_called_once_with( - inputs=inputs.input_ids, - attention_mask=inputs.attention_mask, - max_new_tokens=128, - do_sample=False, - temperature=0.5, - top_p=0.9, - ) - huggingface_llm_model.tokenizer.decode.assert_called_once_with( - outputs[0], - skip_prompt=True, - skip_special_tokens=True, - ) - assert result == "Yeah." \ No newline at end of file + output = huggingface_llm_model.generate("How are you doing?") + assert isinstance(output, str) diff --git a/tests/app/test_utils.py b/tests/app/test_utils.py index 470d0d8..4519350 100644 --- a/tests/app/test_utils.py +++ b/tests/app/test_utils.py @@ -6,7 +6,7 @@ import zipfile import tarfile import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock from safetensors.torch import save_file from transformers import PreTrainedModel from urllib.parse import urlparse @@ -33,10 +33,8 @@ pyproject_dependencies_to_pip_requirements, get_model_data_package_base_name, load_pydantic_object_from_dict, - dump_pydantic_object_to_dict, - get_prompt_from_messages, ) -from app.domain import Annotation, Entity, PromptMessage, PromptRole +from app.domain import Annotation, Entity def test_get_code_base_uri(): @@ -395,46 +393,3 @@ def __init__(self): def forward(self, x): return self.linear(x) - - -def test_get_prompt_with_chat_template(): - with patch('transformers.PreTrainedTokenizer') as tok: - mock_tokenizer = tok.return_value - mock_tokenizer.chat_template = "Mock chat template" - mock_tokenizer.apply_chat_template.return_value = "Mock chat template applied" - messages = [ - PromptMessage(content="Alright?", role=PromptRole.USER.value), - PromptMessage(content="Yeah.", role=PromptRole.ASSISTANT.value), - ] - - prompt = get_prompt_from_messages(mock_tokenizer, messages) - - assert prompt == "Mock chat template applied" - - -def test_get_prompt_without_chat_template(): - with patch('transformers.PreTrainedTokenizer') as tok: - mock_tokenizer = tok.return_value - mock_tokenizer.chat_template = None - messages = [ - PromptMessage(content="You are a helpful assistant.", role=PromptRole.SYSTEM.value), - PromptMessage(content="Alright?", role=PromptRole.USER.value), - PromptMessage(content="Yeah.", role=PromptRole.ASSISTANT.value), - ] - - prompt = get_prompt_from_messages(mock_tokenizer, messages) - - expected_prompt = "<|system|>\nYou are a helpful assistant.\n<|user|>\nAlright?\n<|assistant|>\nYeah.\n<|assistant|>\n" - assert prompt == expected_prompt - - -def test_get_prompt_with_no_messages(): - with patch('transformers.PreTrainedTokenizer') as tok: - mock_tokenizer = tok.return_value - mock_tokenizer.chat_template = None - messages = [] - - prompt = get_prompt_from_messages(mock_tokenizer, messages) - - expected_prompt = "\n<|assistant|>\n" - assert prompt == expected_prompt \ No newline at end of file diff --git a/uv.lock b/uv.lock index cc97b01..0b032d7 100644 --- a/uv.lock +++ b/uv.lock @@ -1215,7 +1215,6 @@ dev = [ { name = "locust", version = "2.25.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, { name = "locust", version = "2.31.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, { name = "mypy" }, - { name = "openai" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-bdd" }, @@ -1243,7 +1242,6 @@ dev = [ { name = "locust", version = "2.25.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, { name = "locust", version = "2.31.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, { name = "mypy" }, - { name = "openai" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-bdd" }, @@ -1279,14 +1277,13 @@ requires-dist = [ { name = "fastapi-users-db-sqlalchemy", specifier = "~=5.0.0" }, { name = "graypy", specifier = "~=2.1.0" }, { name = "httpx", marker = "extra == 'dev'", specifier = "~=0.24.1" }, - { name = "huggingface-hub", specifier = "~=0.33.0" }, + { name = "huggingface-hub", specifier = "~=0.32.0" }, { name = "ijson", specifier = "~=3.1.4" }, { name = "locust", marker = "extra == 'dev'", specifier = "<2.32.0" }, { name = "medcat", marker = "python_full_version < '3.9'", specifier = "~=1.13.1" }, { name = "medcat", marker = "python_full_version >= '3.9'", specifier = "~=1.16.0" }, { name = "mlflow", specifier = "~=2.16.2" }, { name = "mypy", marker = "extra == 'dev'", specifier = "~=1.14.0" }, - { name = "openai", marker = "extra == 'dev'", specifier = ">=1.84.0" }, { name = "peft", specifier = "<0.14.0" }, { name = "prometheus-fastapi-instrumentator", specifier = "~=7.0.0" }, { name = "psycopg2-binary", specifier = "~=2.9.4" }, @@ -1323,7 +1320,6 @@ dev = [ { name = "httpx", specifier = "~=0.24.1" }, { name = "locust", specifier = "<2.32.0" }, { name = "mypy", specifier = "~=1.14.0" }, - { name = "openai", specifier = ">=1.84.0" }, { name = "pytest", specifier = "~=7.1.2" }, { name = "pytest-asyncio", specifier = "~=0.23.7" }, { name = "pytest-bdd", specifier = "~=7.2.0" }, @@ -3267,7 +3263,7 @@ wheels = [ [[package]] name = "huggingface-hub" -version = "0.33.5" +version = "0.32.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock", version = "3.16.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, @@ -3280,9 +3276,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/02/16/5716d03e2b48bcc8e32d9b18ed7e55d2ae52e3d5df146cced9fe0581b5ff/huggingface_hub-0.33.5.tar.gz", hash = "sha256:814097e475646d170c44be4c38f7d381ccc4539156a5ac62a54f53aaf1602ed8", size = 427075 } +sdist = { url = "https://files.pythonhosted.org/packages/60/c8/4f7d270285c46324fd66f62159eb16739aa5696f422dba57678a8c6b78e9/huggingface_hub-0.32.4.tar.gz", hash = "sha256:f61d45cd338736f59fb0e97550b74c24ee771bcc92c05ae0766b9116abe720be", size = 424494 } wheels = [ - { url = "https://files.pythonhosted.org/packages/33/d5/d9e9b75d8dc9cf125fff16fb0cd51d864a29e8b46b6880d8808940989405/huggingface_hub-0.33.5-py3-none-any.whl", hash = "sha256:29b4e64982c2064006021af297e1b17d44c85a8aaf90a0d7efeff7e7d2426296", size = 515705 }, + { url = "https://files.pythonhosted.org/packages/67/8b/222140f3cfb6f17b0dd8c4b9a0b36bd4ebefe9fb0098ba35d6960abcda0f/huggingface_hub-0.32.4-py3-none-any.whl", hash = "sha256:37abf8826b38d971f60d3625229221c36e53fe58060286db9baf619cfbf39767", size = 512101 }, ] [package.optional-dependencies] @@ -3624,88 +3620,10 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899 }, ] -[[package]] -name = "jiter" -version = "0.9.1" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version < '3.9' and sys_platform != 'win32'", - "python_full_version < '3.9' and sys_platform == 'win32'", -] -sdist = { url = "https://files.pythonhosted.org/packages/84/72/c28662416d9807bb5a38625eadedb82d4bd14fd2700c308ece7acdb8e89f/jiter-0.9.1.tar.gz", hash = "sha256:7852990068b6e06102ecdc44c1619855a2af63347bfb5e7e009928dcacf04fdd", size = 162540 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2b/5f/7f6aaca7943c644b4fd220650771f39dbfb74f9690efc6fb8c0d4092a399/jiter-0.9.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:c0163baa7ee85860fdc14cc39263014500df901eeffdf94c1eab9a2d713b2a9d", size = 312882 }, - { url = "https://files.pythonhosted.org/packages/86/0d/aac9eafc5d46bdf5c4f127ac1ce85e434d003bb5e3ae886f5e726a988cf6/jiter-0.9.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:514d4dd845e0af4da15112502e6fcb952f0721f27f17e530454e379472b90c14", size = 311743 }, - { url = "https://files.pythonhosted.org/packages/b8/54/fab1f4d8634af7bb1ad6dc49bee50ea9f649de0e5309c80192ace739f968/jiter-0.9.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b879faee1cc1a67fde3f3f370041239fd260ac452bd53e861aa4a94a51e3fd02", size = 1085889 }, - { url = "https://files.pythonhosted.org/packages/bd/86/bf4ed251d8035d5d72a46c8f9969bd5054fad052371cbea0cb161060e660/jiter-0.9.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:20a5ce641f93bfb8d8e336f8c4a045e491652f41eaacc707b15b245ece611e72", size = 1117896 }, - { url = "https://files.pythonhosted.org/packages/62/40/b04c40deccd5edd5f2a3853f4a80dc0ddbe157d1d523a573fb3d224315fc/jiter-0.9.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8575b1d2b49df04ca82d658882f4a432b7ed315a69126a379df4d10aeb416021", size = 1211956 }, - { url = "https://files.pythonhosted.org/packages/85/f0/114e9893e4ef5b423718efe9b3da01117539c333f06ef19543c68c8b7ed1/jiter-0.9.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cc61831699904e0c58e82943f529713833db87acd13f95a3c0feb791f862d47b", size = 1219691 }, - { url = "https://files.pythonhosted.org/packages/02/9a/1aeac4541ce1c59c65dc76dbab642232da3d8db0581df3e61b8943033bd7/jiter-0.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fb733faf4d0e730d6663873249c1acb572fc8bd9dae3836ceda69751f27c5be", size = 352604 }, - { url = "https://files.pythonhosted.org/packages/6b/27/446ec6ca0a25d9d2f45ad546633a2b4a1b6a7f28fb6819c7056b163c5aee/jiter-0.9.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d903b3bb917c0df24f2ef62f587c8f32f6003cb2f97264109ca56c023262557f", size = 1147136 }, - { url = "https://files.pythonhosted.org/packages/09/9d/c8540bc097b07e106d060c21395c6fa6561223e7366c948a04ef0aa39979/jiter-0.9.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:eac3eb5206845b170142c016ae467eca523a25459dc9c53fcd8e154ea263406c", size = 1255843 }, - { url = "https://files.pythonhosted.org/packages/d3/61/9b377ecf4e09e325e90f77a7a4859ec933162f58ff5c6b7730aff6352033/jiter-0.9.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7ea0c20cfc61acc5335bb8ee36d639e6a4ded03f34f878b2b3038bb9f3bb553c", size = 1257536 }, - { url = "https://files.pythonhosted.org/packages/ed/f6/b6754e11ac9d02f05a2d713c0846ce813a69c1f6f7de7f1ae216c4e35ace/jiter-0.9.1-cp310-cp310-win32.whl", hash = "sha256:0f8f812dd6d2b4112db9ab4c1079c4fe73e553a500e936657fdda394fa2517e1", size = 214064 }, - { url = "https://files.pythonhosted.org/packages/1d/cb/7b9c5d6f73499d1fb5e97e36e8078f3bea00d7541a973117eccf9db1e079/jiter-0.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:f7f0198889170e7af6210509803e6527b402efc6c26f42e2896883597a10426f", size = 209952 }, - { url = "https://files.pythonhosted.org/packages/ee/3b/9f9deaef471e346354c832b6627e0d1b9ba3d9611d0e0fd394c2acf2a615/jiter-0.9.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:6b8564e3198c4c8d835fc95cc54d6bcbd2fd8dc33a047fecc12c208491196995", size = 312737 }, - { url = "https://files.pythonhosted.org/packages/36/00/76fa6d519f8289aad32ec1caf3716eb700ba48e3212d1dda71e74c385a5c/jiter-0.9.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:90b92044588d14efe89b394eca735adc4ac096eba82dc75d93c3083b1eebce8d", size = 313357 }, - { url = "https://files.pythonhosted.org/packages/b3/e9/f864ebe9ddf07761d5bdd3148b45a5d433c6cbce7c7e8be29baf806fa612/jiter-0.9.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3505f7f419b355c7788fcaae0dfc4c6ccbc50c0dc3633a2da797e841c5a423dc", size = 1085946 }, - { url = "https://files.pythonhosted.org/packages/82/a1/ed02d4c86d620989dcd392366daa67198961eedaf2e66f7a68f0d3846dba/jiter-0.9.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:93af8c3f4a3bf145c690e857a945eb5c655534bf95c67e1447d85c02e5af64d7", size = 1118090 }, - { url = "https://files.pythonhosted.org/packages/d3/01/d107531d215a57cda3cbc4adfcf3119166dd32adc1c332c1f3f36efd3484/jiter-0.9.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:43b81dd21e260a249780764921b1f9a6379cb31e24e7b61e6bf0799f38ec4b91", size = 1212231 }, - { url = "https://files.pythonhosted.org/packages/45/1e/6801a81a2ef1f917fe9a7d2139e576dd4f53497c309dab9461136922709c/jiter-0.9.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:db639fad5631b3d1692609f6dd77b64e8578321b7aeec07a026acd2c867c04a5", size = 1219263 }, - { url = "https://files.pythonhosted.org/packages/a5/d4/40082e8666cfdb24461855e9bb29fe77f063cc65a6c903291f2e5225f780/jiter-0.9.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15356b943e70ca7ab3b587ffaffadc0158467f6c4e0b491e52a0743c4bdf5ba1", size = 350364 }, - { url = "https://files.pythonhosted.org/packages/c4/09/09bc72dd143f76acd55e04c3a45b9f9ee3ed28e00b49924e3702ad041812/jiter-0.9.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:53a7033a46141ff815518a6972d657c75d8f5946b9315e1c25b07e9677c1ff6c", size = 1146802 }, - { url = "https://files.pythonhosted.org/packages/5b/34/9d15a9c04d5760537b432134447bde94b936ec73dc922b4d14a48def2e1f/jiter-0.9.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:68cf519a6f00b8127f9be64a37e97e978094438abced5adebe088a98c64bdcff", size = 1256019 }, - { url = "https://files.pythonhosted.org/packages/8f/01/1fcd165fb28968a54bb46a209d5919f7649b96608eef7dc4622ea378b95a/jiter-0.9.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9098abdd34cd9ddeb04768cc4f5fc725ebd9a52978c488da74e58a837ce93506", size = 1257610 }, - { url = "https://files.pythonhosted.org/packages/9f/87/93ac6a57331dd90e4c896ac852bf8ce6b28b40dace4b9698a207dbb99af2/jiter-0.9.1-cp311-cp311-win32.whl", hash = "sha256:7179ce96aecd096af890dd57b84133e47a59fbde32a77734f09bafa6a4da619e", size = 214515 }, - { url = "https://files.pythonhosted.org/packages/bb/ee/3678b8a3bd5f6471d0a492540e7ff9c63db278d844214458ec5cfb22adb2/jiter-0.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:e6517f5b7b6f60fd77fc1099572f445be19553c6f61b907ab5b413fb7179663f", size = 212258 }, - { url = "https://files.pythonhosted.org/packages/26/ca/1c7438d66969a13938266492de65daf752754ec59f2a3f3716027c7d708f/jiter-0.9.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:95065923a49ae387bab62b1bf5f798beb12e6fb4469a079fdd0ecad64b40b272", size = 313516 }, - { url = "https://files.pythonhosted.org/packages/e8/d9/3a6300309e312f8ed529ae57d565f69abdb520e4f12460cefa7996d0716c/jiter-0.9.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a179fbc5c7922844a673be35099a3036a7276dc63753c6c81a77c3cb525f2f8d", size = 308161 }, - { url = "https://files.pythonhosted.org/packages/b3/91/2aca15be38514daf8f1a1460fd9c4b652ed09148fe109520298858be7928/jiter-0.9.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:abd30dc5c0183d31faf30ce8279d723809c54b3fe6d95d922d4a4b31bc462799", size = 1086100 }, - { url = "https://files.pythonhosted.org/packages/9f/6f/f7ba3dfe7be08bf58939324e0bb4f4aa605eff7f2c2ac140a41221cf50a4/jiter-0.9.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9765512bdeae269843e6615377f48123432da247e18048d05e9c5685377c241c", size = 1118922 }, - { url = "https://files.pythonhosted.org/packages/b5/4e/b1f4d9bdba81de293e1b8672598300a9195cf3d77b0acc5f331a75695b58/jiter-0.9.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6f15cdbdc1e1e89e0d9ea581de63e03975043a4b40ab87d5554fdc440357b771", size = 1212327 }, - { url = "https://files.pythonhosted.org/packages/3e/ab/e417aaf5a62067bd91c5f7ed4e5ab83bd46f349449adde1159ad8e2d3a21/jiter-0.9.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b1a639b2cfe56b5b687c678ed45d68f46dfb922c2f338fdfb227eb500053929d", size = 1220860 }, - { url = "https://files.pythonhosted.org/packages/1e/50/c5ba756c641ca8ebc1e4ff07c03ce5c8ef5052b0238f514436f8de3c9fc4/jiter-0.9.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41955c9d83c8470de9cc64c97b04a3ffd2f32815bb2c4307f44d8e21542b74df", size = 344077 }, - { url = "https://files.pythonhosted.org/packages/c6/b3/bd7d8d4bad65aa1f4a20562233080054149785c0d7f7b9027e761335d882/jiter-0.9.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f26f6d42c330e26a6ba3471b390364faad96f3ca965a6c579957810b0c078efa", size = 1148785 }, - { url = "https://files.pythonhosted.org/packages/c0/12/bfd9a167709f96171312d1e0ae2c1be70a167abcc3bff6f3441967e3626a/jiter-0.9.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6a23e01bd7e918f27f02d3df8721b8a395211070a8a65aeb353209b8c72720cf", size = 1255962 }, - { url = "https://files.pythonhosted.org/packages/5f/3c/3a79020862d2511b854b350bc9229cf228fd38b836e94f274ca940e22e95/jiter-0.9.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8a96ad217989dd9df661711c3fa2e6fb2601c4bbb482e89718110bdafbc16c9e", size = 1257561 }, - { url = "https://files.pythonhosted.org/packages/93/d3/7f6f8e57613d4947a872980befa6af19de9252e310ea4a512eed0fe1e064/jiter-0.9.1-cp38-cp38-win32.whl", hash = "sha256:4b180e7baa4747b3834c5a9202b1ba30dc64797f45236d9142cdb2a8807763cf", size = 215019 }, - { url = "https://files.pythonhosted.org/packages/9b/5d/b6f0cd60c8f702936f253644a92dee19e2c82010290e4607af462033351f/jiter-0.9.1-cp38-cp38-win_amd64.whl", hash = "sha256:baf881de1fbc7b3343cce24f75a2ab6350e03fc13d16d00f452929788a6cdc3f", size = 199563 }, - { url = "https://files.pythonhosted.org/packages/4f/3a/a8a4768af26578c87894bb130bcd6fb6c97f4cb36ed7a20a664412d41935/jiter-0.9.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:ec95aa1b433c50b2b129456b4680b239ec93206ea3f86cfd41b6a70be5beb2f3", size = 313942 }, - { url = "https://files.pythonhosted.org/packages/63/74/05977891db48000d985a5f573493c43adf0f190eada670e51b92c9ed9139/jiter-0.9.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5d92cb50d135dbdd33b638fa2e0c6af25e1d635d38da13aa9ab05d021fb0c869", size = 308160 }, - { url = "https://files.pythonhosted.org/packages/21/54/75f529e90442c8ad41acd8cf08323a4f3dcaa105710b2c8a1fda56e3a462/jiter-0.9.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b146dc2464f1d96007271d08bdf79288a5f1aa4aae5329eb79dcffb1181c703e", size = 1086503 }, - { url = "https://files.pythonhosted.org/packages/bf/fa/02532a7ce7b712c576125d4f2614e77bc897c95b2b15e21ee25f42b3ff34/jiter-0.9.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fcf20ba858658ecd54b4710172d92009afa66d41d967c86d11607592a3c220fa", size = 1120444 }, - { url = "https://files.pythonhosted.org/packages/91/c2/ab8cebaea6f2691eddcc5b6c67deb1399adbd85f12ad836f7cd77be78bf8/jiter-0.9.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:147fccc44bebdb672d4c601e9312730488b840d415e201e89c8ea0929a63dacf", size = 1212370 }, - { url = "https://files.pythonhosted.org/packages/13/e3/90dddb7877b67cc0e1ddb864c2ca74314def26ff6542431a6e3061e0f805/jiter-0.9.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a428061aae26efaa6fb690ef9e7d6224aefe4eef7524165d073beb3cdad75f6f", size = 1221210 }, - { url = "https://files.pythonhosted.org/packages/81/76/90ee847519a94a4a1a8bad7addce7019f424aea03c55eacf068469226760/jiter-0.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7164d92bb901784bd3c098ac0b0beae4306ea6c741dbd3a375449a8affc5366", size = 353774 }, - { url = "https://files.pythonhosted.org/packages/59/a6/614a5d672d4b9c6bc9ad34579f0522577a0a78cc265069fca96543a832ca/jiter-0.9.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:93049a562233808914a2b938b0c745d7049db1667b3f42f0f5cf48e617393ba5", size = 1148581 }, - { url = "https://files.pythonhosted.org/packages/2d/94/c100147c310361fa83e25c4c6ce17723532147580252962b89e6085795c2/jiter-0.9.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f6dcf2cb16cc15d82a018e20eeaf169e6f6cd8c426f4c312ebe11710c623bed2", size = 1256636 }, - { url = "https://files.pythonhosted.org/packages/51/9a/dc82e218ba839052899df555e34f16b8ad1d7da9c01be208f65a5bf0083c/jiter-0.9.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2da9d485a7c526817cde9ff8b3394fa50ff5b782b86b6896378a3ba8844550f2", size = 1258099 }, - { url = "https://files.pythonhosted.org/packages/58/d5/d853e069624038950265ac0e877985b249049b624e925dab6cd11035140c/jiter-0.9.1-cp39-cp39-win32.whl", hash = "sha256:ea58c155d827d24e5ba8d7958ec4738b26be0894c0881a91d88b39ff48bb06c9", size = 214611 }, - { url = "https://files.pythonhosted.org/packages/cb/8d/7b6b1ee6e3d9d1a06237bbdfe4c6bb21baf323d3f70a0cc8f203de40c6b2/jiter-0.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:be2e911ecdb438951290c2079fe4190e7cc5be9e849df4caeb085b83ed620ff6", size = 211171 }, -] - [[package]] name = "jiter" version = "0.10.0" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version > '3.11' and sys_platform == 'darwin'", - "python_full_version > '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "(python_full_version > '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version > '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", - "python_full_version == '3.11' and sys_platform == 'darwin'", - "python_full_version == '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "(python_full_version == '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", - "python_full_version == '3.10.*' and sys_platform == 'darwin'", - "python_full_version == '3.10.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "(python_full_version == '3.10.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.10.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", - "python_full_version == '3.9.*' and platform_machine == 'arm64' and sys_platform == 'darwin'", - "python_full_version == '3.9.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.9.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.9.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", - "python_full_version > '3.11' and sys_platform == 'win32'", - "python_full_version == '3.11' and sys_platform == 'win32'", - "python_full_version == '3.10.*' and sys_platform == 'win32'", - "python_full_version == '3.9.*' and sys_platform == 'win32'", -] sdist = { url = "https://files.pythonhosted.org/packages/ee/9d/ae7ddb4b8ab3fb1b51faf4deb36cb48a4fbbd7cb36bad6a5fca4741306f7/jiter-0.10.0.tar.gz", hash = "sha256:07a7142c38aacc85194391108dc91b5b57093c978a9932bd86a36862759d9500", size = 162759 } wheels = [ { url = "https://files.pythonhosted.org/packages/be/7e/4011b5c77bec97cb2b572f566220364e3e21b51c48c5bd9c4a9c26b41b67/jiter-0.10.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:cd2fb72b02478f06a900a5782de2ef47e0396b3e1f7d5aba30daeb1fce66f303", size = 317215 }, @@ -5586,7 +5504,7 @@ version = "9.1.0.70" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nvidia-cublas-cu12", version = "12.1.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9' and sys_platform != 'win32'" }, - { name = "nvidia-cublas-cu12", version = "12.4.5.8", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, + { name = "nvidia-cublas-cu12", version = "12.4.5.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, @@ -5614,7 +5532,7 @@ resolution-markers = [ "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.9.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.9.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", ] dependencies = [ - { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, + { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 }, @@ -5672,9 +5590,9 @@ resolution-markers = [ "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.9.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.9.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", ] dependencies = [ - { name = "nvidia-cublas-cu12", version = "12.4.5.8", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, - { name = "nvidia-cusparse-cu12", version = "12.3.1.170", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, - { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, + { name = "nvidia-cublas-cu12", version = "12.4.5.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" }, + { name = "nvidia-cusparse-cu12", version = "12.3.1.170", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" }, + { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 }, @@ -5705,7 +5623,7 @@ resolution-markers = [ "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.9.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.9.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", ] dependencies = [ - { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, + { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 }, @@ -5799,17 +5717,14 @@ name = "openai" version = "1.84.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "anyio", version = "4.5.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, { name = "anyio", version = "4.9.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "distro" }, - { name = "httpx" }, - { name = "jiter", version = "0.9.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "jiter", version = "0.10.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "pydantic", version = "1.10.22", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "distro", marker = "python_full_version >= '3.9'" }, + { name = "httpx", marker = "python_full_version >= '3.9'" }, + { name = "jiter", marker = "python_full_version >= '3.9'" }, { name = "pydantic", version = "2.11.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "sniffio" }, - { name = "tqdm" }, - { name = "typing-extensions" }, + { name = "sniffio", marker = "python_full_version >= '3.9'" }, + { name = "tqdm", marker = "python_full_version >= '3.9'" }, + { name = "typing-extensions", marker = "python_full_version >= '3.9'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/91/a3/128caf24e116f48fad3e4d5122cdf84db06c5127911849d51663c66158c8/openai-1.84.0.tar.gz", hash = "sha256:4caa43bdab262cc75680ce1a2322cfc01626204074f7e8d9939ab372acf61698", size = 467066 } wheels = [ @@ -9890,8 +9805,8 @@ name = "xformers" version = "0.0.29.post2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, - { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" }, + { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/27/ed/04ec7ef97a7e1c836add41ef5a2aef8cbdd45c0190ca42cc08f3c21e2b7b/xformers-0.0.29.post2.tar.gz", hash = "sha256:6ca3d1a6db6f2abff25c1154adee96987f77f4dfd5141771805afa5fc13e9395", size = 8468494 } wheels = [ @@ -9905,12 +9820,12 @@ name = "xgrammar" version = "0.1.18" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "ninja", marker = "(python_full_version >= '3.9' and platform_machine != 'arm64') or (python_full_version >= '3.10' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and sys_platform != 'darwin')" }, - { name = "pydantic", version = "2.11.4", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.9' and platform_machine != 'arm64') or (python_full_version >= '3.10' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and sys_platform != 'darwin')" }, - { name = "sentencepiece", marker = "(python_full_version >= '3.9' and platform_machine != 'arm64') or (python_full_version >= '3.10' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and sys_platform != 'darwin')" }, - { name = "tiktoken", marker = "(python_full_version >= '3.9' and platform_machine != 'arm64') or (python_full_version >= '3.10' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and sys_platform != 'darwin')" }, - { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.9' and platform_machine != 'arm64') or (python_full_version >= '3.10' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and sys_platform != 'darwin')" }, - { name = "transformers", version = "4.51.3", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.9' and platform_machine != 'arm64') or (python_full_version >= '3.10' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and sys_platform != 'darwin')" }, + { name = "ninja", marker = "python_full_version >= '3.9'" }, + { name = "pydantic", version = "2.11.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "sentencepiece", marker = "python_full_version >= '3.9'" }, + { name = "tiktoken", marker = "python_full_version >= '3.9'" }, + { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "transformers", version = "4.51.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, { name = "triton", version = "3.2.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/8f/c3/22c9eeab6ee1dd6d0513d227e9d307fd20a0491db58f1f04bc5d566d13dc/xgrammar-0.1.18.tar.gz", hash = "sha256:a0438a0f9262fff1d0e4f184268eb759f094243edce92b67eb7aa5f245c47471", size = 1697230 }