Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,25 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ '3.8', '3.9', '3.10', '3.11' ]
python-version: [
'3.9',
'3.10',
'3.11',
'3.12',
]
max-parallel: 4

steps:
- uses: actions/checkout@v4
- name: Install uv and set Python to ${{ matrix.python-version }}
uses: astral-sh/setup-uv@v6
with:
version: "0.7.20"
version: "0.8.10"
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
uv sync --group dev --group docs
uv sync --extra dev --extra docs --extra vllm
uv run python -m ensurepip
- name: Check types
run: |
uv run mypy app
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ jobs:
- name: Install uv
uses: astral-sh/setup-uv@v5
with:
version: "0.6.10"
version: "0.8.10"
python-version: "3.10"
- name: Install dependencies
run: |
uv sync --group dev --group docs --group vllm
uv sync --extra dev --extra docs --extra vllm
- name: Run unit tests
run: |
uv run pytest -v tests/app --cov --cov-report=html:coverage_reports #--random-order
Expand Down
6 changes: 3 additions & 3 deletions app/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ def _get_app(
tags = TagsStreamable
else:
tags = Tags
tags_metadata = [{ # type: ignore
"name": tag.name, # type: ignore
"description": tag.value # type: ignore
tags_metadata = [{
"name": tag.name,
"description": tag.value
} for tag in tags]
app = FastAPI(
title="CogStack ModelServe",
Expand Down
1 change: 0 additions & 1 deletion app/api/auth/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,4 @@ async def get_user_db(session: AsyncSession = Depends(_get_async_session)) -> As
SQLAlchemyUserDatabase: A database instance initialised with the given session and the User model.
"""

# TODO: fix this type checking error
yield SQLAlchemyUserDatabase(session, User)
3 changes: 2 additions & 1 deletion app/api/routers/generative.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ async def _stream(prompt: str, max_tokens: int, temperature: float) -> AsyncGene
yield f"data: {json.dumps(data)}\n\n"
yield "data: [DONE]\n\n"

prompt = get_prompt_from_messages(model_service.tokenizer, messages) # type: ignore
assert hasattr(model_service, "tokenizer"), "Model service doesn't have a tokenizer"
prompt = get_prompt_from_messages(model_service.tokenizer, messages)
if stream:
return StreamingResponse(
_stream(prompt, max_tokens, temperature),
Expand Down
6 changes: 3 additions & 3 deletions app/api/routers/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def __init__(
self,
content: Any,
status_code: int = 200,
max_chunk_size: Optional[int] = 1024,
max_chunk_size: int = 1024,
headers: Optional[Mapping[str, str]] = None,
media_type: Optional[str] = None,
background: Optional[BackgroundTask] = None,
Expand All @@ -161,8 +161,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
})
response_started = True
line_bytes = line.encode("utf-8")
for i in range(0, len(line_bytes), self.max_chunk_size): # type: ignore
chunk = line_bytes[i:i + self.max_chunk_size] # type: ignore
for i in range(0, len(line_bytes), self.max_chunk_size):
chunk = line_bytes[i:i + self.max_chunk_size]
await send({
"type": "http.response.body",
"body": chunk,
Expand Down
2 changes: 1 addition & 1 deletion app/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ async def generate_text(
params = SamplingParams(max_tokens=max_tokens)

conversation, _ = parse_chat_messages(messages, model_config, tokenizer, content_format="string") # type: ignore
prompt_tokens = apply_hf_chat_template( # type: ignore
prompt_tokens = apply_hf_chat_template( # type: ignore
tokenizer,
conversation=conversation,
tools=None,
Expand Down
1 change: 1 addition & 0 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class Settings(BaseSettings): # type: ignore
TRAINING_CACHE_DIR: str = os.path.join(os.path.abspath(os.path.dirname(__file__)), "cms_cache") # the directory to cache the intermediate files created during training
HF_PIPELINE_AGGREGATION_STRATEGY: str = "simple" # the strategy used for aggregating the predictions of the Hugging Face NER model
LOG_PER_CONCEPT_ACCURACIES: str = "false" # if "true", per-concept accuracies will be exposed to the metrics scraper. Switch this on with caution due to the potentially high number of concepts
MEDCAT2_MAPPED_ONTOLOGIES: str = "" # the comma-separated names of ontologies for MedCAT2 to map to
DEBUG: str = "false" # if "true", the debug mode is switched on

class Config:
Expand Down
3 changes: 3 additions & 0 deletions app/envs/.env
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,8 @@ TRAINING_SAFE_MODEL_SERIALISATION=false
# The strategy used for aggregating the predictions of the Hugging Face NER model
HF_PIPELINE_AGGREGATION_STRATEGY=simple

# The comma-separated names of ontologies for MedCAT2 to map to
MEDCAT2_MAPPED_ONTOLOGIES=opcs4,icd10

# If "true", the debug mode is switched on
DEBUG=false
4 changes: 3 additions & 1 deletion app/management/tracker_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
import os
import socket
import mlflow
Expand Down Expand Up @@ -114,7 +115,7 @@ def send_model_stats(stats: Dict, step: int) -> None:
step (int): The current step in the training or evaluation process.
"""

metrics = {key.replace(" ", "_").lower(): val for key, val in stats.items()}
metrics = {key.replace(" ", "_").lower(): val for key, val in stats.items() if isinstance(val, (int, float))}
mlflow.log_metrics(metrics, step)

@staticmethod
Expand Down Expand Up @@ -563,6 +564,7 @@ def get_metrics_by_job_id(self, job_id: str) -> List[Dict[str, Any]]:
metrics_history = {}
for metric in run.data.metrics.keys():
metrics_history[metric] = [m.value for m in self.mlflow_client.get_metric_history(run_id=run.info.run_id, key=metric)]
metrics_history["concepts"] = ast.literal_eval(run.data.tags.get("training.entity.classes", "[]"))
metrics.append(metrics_history)
return metrics
except MlflowException as e:
Expand Down
59 changes: 32 additions & 27 deletions app/model_services/medcat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import pandas as pd

from multiprocessing import cpu_count
from typing import Dict, List, Optional, TextIO, Tuple, Any
from typing import Dict, List, Optional, TextIO, Tuple, Any, Set, Union
from medcat.cat import CAT
from medcat.data.entities import Entities, OnlyCUIEntities
from app import __version__ as app_version
from app.model_services.base import AbstractModelService
from app.trainers.medcat_trainer import MedcatSupervisedTrainer, MedcatUnsupervisedTrainer
Expand Down Expand Up @@ -46,7 +47,7 @@ def __init__(
base_model_file (Optional[str]): The model package file name. Defaults to None.
"""
super().__init__(config)
self._model: CAT = None
self._model: Optional[CAT] = None
self._config = config
self._model_parent_dir = model_parent_dir or os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "model"))
self._model_pack_path = os.path.join(self._model_parent_dir, base_model_file or config.BASE_MODEL_FILE)
Expand All @@ -55,7 +56,7 @@ def __init__(
self.model_name = model_name or "MedCAT model"

@property
def model(self) -> CAT:
def model(self) -> Optional[CAT]:
"""Getter for the MedCAT model."""

return self._model
Expand Down Expand Up @@ -113,7 +114,7 @@ def load_model(model_file_path: str, *args: Tuple, **kwargs: Dict[str, Any]) ->

model_path = os.path.join(os.path.dirname(model_file_path), get_model_data_package_base_name(model_file_path))
if unpack_model_data_package(model_file_path, model_path):
cat = CAT.load_model_pack(model_file_path.replace(".tar.gz", ".zip"), *args, **kwargs)
cat = CAT.load_model_pack(model_file_path.replace(".tar.gz", ".zip"), **kwargs)
logger.info("Model package loaded from %s", os.path.normpath(model_file_path))
return cat
else:
Expand All @@ -131,18 +132,20 @@ def init_model(self, *args: Any, **kwargs: Any) -> None:
logger.warning("Model service is already initialised and can be initialised only once")
else:
if non_default_device_is_available(get_settings().DEVICE):
self._model = self.load_model(
self._model_pack_path,
meta_cat_config_dict={"general": {"device": get_settings().DEVICE}},
)
self._model.config.general["device"] = get_settings().DEVICE
self._model = self.load_model(self._model_pack_path)
for addon in self._model.get_addons():
addon.config.general.device = get_settings().DEVICE # type: ignore
self._model.config.general.device = get_settings().DEVICE # type: ignore
else:
self._model = self.load_model(self._model_pack_path)
self._set_tuis_filtering()
if self._enable_trainer:
self._supervised_trainer = MedcatSupervisedTrainer(self)
self._unsupervised_trainer = MedcatUnsupervisedTrainer(self)
self._metacat_trainer = MetacatTrainer(self)
self._model.config.general.map_to_other_ontologies = [ # type: ignore
tui.strip() for tui in self._config.MEDCAT2_MAPPED_ONTOLOGIES.split(",")
]

def info(self) -> ModelCard:
"""
Expand All @@ -168,10 +171,8 @@ def annotate(self, text: str) -> List[Annotation]:
List[Annotation]: A list of annotations containing the extracted named entities.
"""

doc = self.model.get_entities(
text,
addl_info=["cui2icd10", "cui2opcs4", "cui2ontologies", "cui2snomed", "cui2athena_ids"],
)
assert self.model is not None, "Model is not initialised"
doc = self.model.get_entities(text)
return [load_pydantic_object_from_dict(Annotation, record) for record in self.get_records_from_doc(doc)]

def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]:
Expand All @@ -187,12 +188,12 @@ def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]:

batch_size_chars = 500000

docs = self.model.multiprocessing_batch_char_size(
self._data_iterator(texts),
assert self.model is not None, "Model is not initialised"
docs = {i: result for i, (_, result) in enumerate(self.model.get_entities_multi_texts(
texts,
batch_size_chars=batch_size_chars,
nproc=max(int(cpu_count() / 2), 1),
addl_info=["cui2icd10", "cui2opcs4", "cui2ontologies", "cui2snomed", "cui2athena_ids"],
)
n_process=max(int(cpu_count() / 2), 1),
))}
docs = dict(sorted(docs.items(), key=lambda x: x[0]))
annotations_list = []
for _, doc in docs.items():
Expand Down Expand Up @@ -342,12 +343,12 @@ def train_metacat(
**hyperparams,
)

def get_records_from_doc(self, doc: Dict) -> List[Dict]:
def get_records_from_doc(self, doc: Union[Dict, Entities, OnlyCUIEntities]) -> List[Dict]:
"""
Extracts and formats entity records from a document dictionary.

Args:
doc (Dict): The document dictionary containing extracted named entities.
doc (Union[Dict, Entities, OnlyCUIEntities]): The document dictionary containing extracted named entities.

Returns:
List[Dict]: A list of formatted entity records.
Expand All @@ -362,9 +363,9 @@ def get_records_from_doc(self, doc: Dict) -> List[Dict]:
if "athena_ids" in row and row["athena_ids"]:
df.loc[idx, "athena_ids"] = [athena_id["code"] for athena_id in row["athena_ids"]]
if self._config.INCLUDE_SPAN_TEXT == "true":
df.rename(columns={"pretty_name": "label_name", "cui": "label_id", "source_value": "text", "types": "categories", "acc": "accuracy", "athena_ids": "athena_ids"}, inplace=True)
df.rename(columns={"pretty_name": "label_name", "cui": "label_id", "source_value": "text", "type_ids": "categories", "acc": "accuracy", "athena_ids": "athena_ids"}, inplace=True)
else:
df.rename(columns={"pretty_name": "label_name", "cui": "label_id", "types": "categories", "acc": "accuracy", "athena_ids": "athena_ids"}, inplace=True)
df.rename(columns={"pretty_name": "label_name", "cui": "label_id", "type_ids": "categories", "acc": "accuracy", "athena_ids": "athena_ids"}, inplace=True)
df = self._retrieve_meta_annotations(df)
records = df.to_dict("records")
return records
Expand All @@ -384,15 +385,19 @@ def _retrieve_meta_annotations(df: pd.DataFrame) -> pd.DataFrame:

def _set_tuis_filtering(self) -> None:
# this patching may not be needed after the base 1.4.x model is fixed in the future
assert self._model is not None, "Model is not initialised"
if self._model.cdb.addl_info.get("type_id2name", {}) == {}:
self._model.cdb.addl_info["type_id2name"] = TYPE_ID_TO_NAME_PATCH

tuis2cuis = self._model.cdb.addl_info.get("type_id2cuis")
model_tuis = set(tuis2cuis.keys())
type_id2info = self._model.cdb.type_id2info
model_tuis = set(type_id2info.keys())
if self._whitelisted_tuis == {""}:
return
assert self._whitelisted_tuis.issubset(model_tuis), f"Unrecognisable Type Unique Identifier(s): {self._whitelisted_tuis - model_tuis}"
whitelisted_cuis = set()
whitelisted_cuis: Set = set()
for tui in self._whitelisted_tuis:
whitelisted_cuis.update(tuis2cuis.get(tui, {}))
self._model.cdb.config.linking.filters = {"cuis": whitelisted_cuis}
type_info = type_id2info.get(tui)
if type_info is None:
continue
whitelisted_cuis.update(type_info.cuis)
self._model.config.components.linking.filters.cuis = whitelisted_cuis
Loading