diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
index 1eb07adfde44..5a6db12edded 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
@@ -443,6 +443,17 @@ def has_request_pools(self, model_id: str, device_id: torch.device = None) -> bo
             return device_id in self._request_pool_map[model_id]
         return True
 
+    def has_running_pools(self, model_id: str) -> bool:
+        """
+        Check if there are running pools for the given model_id.
+        """
+        if model_id not in self._request_pool_map:
+            return False
+        for device_id, pool_group in self._request_pool_map[model_id].items():
+            if pool_group.get_running_pool_count():
+                return True
+        return False
+
     def get_request_pools_group(
         self, model_id: str, device_id: torch.device
     ) -> Optional[PoolGroup]:
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
index ebbb036a9dca..180cc00ff492 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
@@ -18,7 +18,6 @@
 
 import threading
 import time
-from typing import Dict
 
 import torch
 import torch.multiprocessing as mp
@@ -69,9 +68,6 @@ class InferenceManager:
     def __init__(self):
         self._model_manager = ModelManager()
         self._backend = DeviceManager()
-        self._model_mem_usage_map: Dict[str, int] = (
-            {}
-        )  # store model memory usage for each model
         self._result_queue = mp.Queue()
         self._result_wrapper_map = {}
         self._result_wrapper_lock = threading.RLock()
@@ -207,14 +203,14 @@ def _run(
         ):
             raise NumericalRangeException(
                 "output_length",
+                output_length,
                 1,
                 AINodeDescriptor()
                 .get_config()
                 .get_ain_inference_max_output_length(),
-                output_length,
             )
 
-        if self._pool_controller.has_request_pools(model_id=model_id):
+        if self._pool_controller.has_running_pools(model_id):
             infer_req = InferenceRequest(
                 req_id=generate_req_id(),
                 model_id=model_id,