From 7df2f857e54942665674f961daf803483aefa45b Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Wed, 8 Apr 2026 02:38:40 +0530 Subject: [PATCH 1/2] [REFACTOR] Replace CostCalculationHelper with litellm.cost_per_token Move cost calculation from platform-service to sdk1's Audit class, using litellm's built-in cost_per_token() instead of a custom helper that fetched pricing data from an external URL. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../platform_service/controller/platform.py | 15 +-- .../src/unstract/platform_service/env.py | 5 - .../helper/cost_calculation.py | 124 ------------------ .../src/unstract/platform_service/utils.py | 20 --- unstract/sdk1/src/unstract/sdk1/audit.py | 37 +++++- 5 files changed, 33 insertions(+), 168 deletions(-) delete mode 100644 platform-service/src/unstract/platform_service/helper/cost_calculation.py diff --git a/platform-service/src/unstract/platform_service/controller/platform.py b/platform-service/src/unstract/platform_service/controller/platform.py index cb8dd29e06..85f759e209 100644 --- a/platform-service/src/unstract/platform_service/controller/platform.py +++ b/platform-service/src/unstract/platform_service/controller/platform.py @@ -15,7 +15,6 @@ from unstract.platform_service.helper.adapter_instance import ( AdapterInstanceRequestHelper, ) -from unstract.platform_service.helper.cost_calculation import CostCalculationHelper from unstract.platform_service.helper.prompt_studio import PromptStudioRequestHelper platform_bp = Blueprint("platform", __name__) @@ -213,23 +212,11 @@ def usage() -> Any: usage_type = payload.get("usage_type", "") llm_usage_reason = payload.get("llm_usage_reason", "") model_name = payload.get("model_name", "") - provider = payload.get("provider", "") embedding_tokens = payload.get("embedding_tokens", 0) prompt_tokens = payload.get("prompt_tokens", 0) completion_tokens = payload.get("completion_tokens", 0) total_tokens = payload.get("total_tokens", 0) - input_tokens = prompt_tokens - if usage_type == "embedding": - input_tokens = embedding_tokens - cost_in_dollars = 0.0 - if provider: - cost_calculation_helper = CostCalculationHelper() - cost_in_dollars = cost_calculation_helper.calculate_cost( - model_name=model_name, - provider=provider, - input_tokens=input_tokens, - output_tokens=completion_tokens, - ) + cost_in_dollars = payload.get("cost_in_dollars", 0.0) usage_id = uuid.uuid4() current_time = datetime.now() query = f""" diff --git a/platform-service/src/unstract/platform_service/env.py b/platform-service/src/unstract/platform_service/env.py index 2bcf2da382..a0c2f6fcb9 100644 --- a/platform-service/src/unstract/platform_service/env.py +++ b/platform-service/src/unstract/platform_service/env.py @@ -18,11 +18,6 @@ class Env: PG_BE_PASSWORD = os.environ.get("PG_BE_PASSWORD") PG_BE_DATABASE = os.environ.get("PG_BE_DATABASE") ENCRYPTION_KEY = EnvManager.get_required_setting("ENCRYPTION_KEY") - MODEL_PRICES_URL = EnvManager.get_required_setting("MODEL_PRICES_URL") - MODEL_PRICES_TTL_IN_DAYS = int( - EnvManager.get_required_setting("MODEL_PRICES_TTL_IN_DAYS") - ) - MODEL_PRICES_FILE_PATH = EnvManager.get_required_setting("MODEL_PRICES_FILE_PATH") APPLICATION_NAME = EnvManager.get_required_setting( "APPLICATION_NAME", "unstract-platform-service" ) diff --git a/platform-service/src/unstract/platform_service/helper/cost_calculation.py b/platform-service/src/unstract/platform_service/helper/cost_calculation.py deleted file mode 100644 index ba2e572330..0000000000 --- a/platform-service/src/unstract/platform_service/helper/cost_calculation.py +++ /dev/null @@ -1,124 +0,0 @@ -import json -from datetime import UTC, datetime, timedelta -from typing import Any - -import requests -from flask import current_app as app - -from unstract.platform_service.env import Env -from unstract.platform_service.utils import format_float_positional -from unstract.sdk1.exceptions import FileStorageError -from unstract.sdk1.file_storage import EnvHelper, StorageType - - -class CostCalculationHelper: - def __init__( - self, - url: str = Env.MODEL_PRICES_URL, - ttl_days: int = Env.MODEL_PRICES_TTL_IN_DAYS, - file_path: str = Env.MODEL_PRICES_FILE_PATH, - ): - self.ttl_days = ttl_days - self.url = url - self.file_path = file_path - - try: - self.file_storage = EnvHelper.get_storage( - StorageType.PERMANENT, "FILE_STORAGE_CREDENTIALS" - ) - except KeyError as e: - app.logger.error(f"Required credentials is missing in the env: {str(e)}") - raise e - except FileStorageError as e: - app.logger.error( - "Error while initialising storage: %s", - e, - stack_info=True, - exc_info=True, - ) - raise e - - self.model_token_data = self._get_model_token_data() - - def calculate_cost( - self, model_name: str, provider: str, input_tokens: int, output_tokens: int - ) -> str: - cost = 0.0 - item = None - - if not self.model_token_data: - return json.loads(format_float_positional(cost)) - # Filter the model objects by model name - filtered_models = { - k: v for k, v in self.model_token_data.items() if k.endswith(model_name) - } - # Check if the lite llm provider starts with the given provider - for _, model_info in filtered_models.items(): - if provider in model_info.get("litellm_provider", ""): - item = model_info - break - if item: - input_cost_per_token = item.get("input_cost_per_token", 0) - output_cost_per_token = item.get("output_cost_per_token", 0) - cost += input_cost_per_token * input_tokens - cost += output_cost_per_token * output_tokens - return format_float_positional(cost) - - def _get_model_token_data(self) -> dict[str, Any] | None: - try: - # File does not exist, fetch JSON data from API - if not self.file_storage.exists(self.file_path): - return self._fetch_and_save_json() - - file_mtime = self.file_storage.modification_time(self.file_path) - file_expiry_date = file_mtime + timedelta(days=self.ttl_days) - file_expiry_date_utc = file_expiry_date.replace(tzinfo=UTC) - now_utc = datetime.now().replace(tzinfo=UTC) - - if now_utc < file_expiry_date_utc: - app.logger.info(f"Reading model token data from {self.file_path}") - # File exists and TTL has not expired, read and return content - file_contents = self.file_storage.read( - self.file_path, mode="r", encoding="utf-8" - ) - return json.loads(file_contents) - else: - # TTL expired, fetch updated JSON data from API - return self._fetch_and_save_json() - except Exception as e: - app.logger.warning( - "Error in calculate_cost: %s", e, stack_info=True, exc_info=True - ) - return None - - def _fetch_and_save_json(self) -> dict[str, Any] | None: - """Fetch model's price and token data from the URL. - - Caches it in a file with the mentioned TTL - - Returns: - Optional[dict[str, Any]]: JSON of model and price / token data - """ - try: - # Fetch updated JSON data from API - response = requests.get(self.url, timeout=10) - response.raise_for_status() - json_data = response.json() - # Save JSON data to file - self.file_storage.json_dump( - path=self.file_path, - data=json_data, - ensure_ascii=False, - indent=4, - ) - app.logger.info( - "File '%s' updated successfully with TTL set to %d days.", - self.file_path, - self.ttl_days, - ) - return json_data - except Exception as e: - app.logger.error( - "Error fetching data from API: %s", e, stack_info=True, exc_info=True - ) - return None diff --git a/platform-service/src/unstract/platform_service/utils.py b/platform-service/src/unstract/platform_service/utils.py index a677083780..b6b58380db 100644 --- a/platform-service/src/unstract/platform_service/utils.py +++ b/platform-service/src/unstract/platform_service/utils.py @@ -38,23 +38,3 @@ def raise_for_missing_envs(cls) -> None: cls.missing_settings ) raise ValueError(ERROR_MESSAGE) - - -def format_float_positional(value: float, precision: int = 10) -> str: - """Format floats to a string. - - Formats a floating-point number to a string with the specified precision, - removing trailing zeros and the decimal point if not needed. - - Args: - value (float): The floating-point number to format. - precision (int, optional): The number of decimal places to - include in the formatted output. Defaults to 10. - - Returns: - str: The formatted string representation of the float, - with unnecessary trailing zeros and the decimal point - removed if the float is an integer. - """ - formatted: str = f"{value:.{precision}f}" - return formatted.rstrip("0").rstrip(".") if "." in formatted else formatted diff --git a/unstract/sdk1/src/unstract/sdk1/audit.py b/unstract/sdk1/src/unstract/sdk1/audit.py index 0396ab33a8..c545a029fa 100644 --- a/unstract/sdk1/src/unstract/sdk1/audit.py +++ b/unstract/sdk1/src/unstract/sdk1/audit.py @@ -1,12 +1,16 @@ +import logging from typing import Any import requests +from litellm import cost_per_token from llama_index.core.callbacks import CBEventType, TokenCountingHandler from unstract.sdk1.constants import LogLevel, ToolEnv from unstract.sdk1.platform import PlatformHelper from unstract.sdk1.tool.stream import StreamMixin from unstract.sdk1.utils.common import TokenCounterCompat +logger = logging.getLogger(__name__) + class Audit(StreamMixin): """The 'Audit' class is responsible for pushing usage data to the platform service. @@ -79,8 +83,30 @@ def push_usage_data( if event_type == "llm": llm_usage_reason = kwargs.get("llm_usage_reason", "") - if model_name is not None and model_name != "": - model_name = model_name.split("/", 1)[-1] + prompt_tokens = token_counter.prompt_llm_token_count + completion_tokens = token_counter.completion_llm_token_count + input_tokens = prompt_tokens + if event_type == "embedding": + input_tokens = token_counter.total_embedding_token_count + + # Compute cost using the full model name (e.g. "azure/gpt-4o") + # before stripping the provider prefix for DB storage. + cost_in_dollars = 0.0 + if model_name: + try: + prompt_cost, completion_cost = cost_per_token( + model=model_name, + prompt_tokens=input_tokens, + completion_tokens=completion_tokens, + ) + cost_in_dollars = prompt_cost + completion_cost + except Exception: + logger.debug( + "Cost lookup failed for model %s, defaulting to 0", model_name + ) + + # Strip provider prefix for DB storage (e.g. "azure/gpt-4o" -> "gpt-4o") + display_model_name = model_name.split("/", 1)[-1] if model_name else "" data = { "workflow_id": workflow_id, @@ -89,12 +115,13 @@ def push_usage_data( "run_id": run_id, "usage_type": event_type, "llm_usage_reason": llm_usage_reason, - "model_name": model_name, + "model_name": display_model_name, "provider": provider, "embedding_tokens": token_counter.total_embedding_token_count, - "prompt_tokens": token_counter.prompt_llm_token_count, - "completion_tokens": token_counter.completion_llm_token_count, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, "total_tokens": token_counter.total_llm_token_count, + "cost_in_dollars": cost_in_dollars, } url = f"{base_url}/usage" From 1a57c387436534fa14938aeb1a027142483aaa56 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Wed, 8 Apr 2026 11:12:57 +0530 Subject: [PATCH 2/2] [REFACTOR] Zero out completion_tokens for embedding cost calculation Explicitly set completion_tokens to 0 for embedding events before calling cost_per_token, making the assumption that embeddings have no completion tokens explicit rather than relying on the counter always being zero. Co-Authored-By: Claude Opus 4.6 (1M context) --- unstract/sdk1/src/unstract/sdk1/audit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unstract/sdk1/src/unstract/sdk1/audit.py b/unstract/sdk1/src/unstract/sdk1/audit.py index c545a029fa..abd764763d 100644 --- a/unstract/sdk1/src/unstract/sdk1/audit.py +++ b/unstract/sdk1/src/unstract/sdk1/audit.py @@ -88,6 +88,7 @@ def push_usage_data( input_tokens = prompt_tokens if event_type == "embedding": input_tokens = token_counter.total_embedding_token_count + completion_tokens = 0 # Compute cost using the full model name (e.g. "azure/gpt-4o") # before stripping the provider prefix for DB storage.