From babcc346b496e25ee970c4f3049a820943a3a6f0 Mon Sep 17 00:00:00 2001 From: deacon Date: Sun, 15 Mar 2026 13:19:01 -0400 Subject: [PATCH 1/2] fix: cache trained ML model in learning_svc across operation completions Add TTL-based caching to avoid unnecessary model rebuilds. Model is only rebuilt when marked dirty or after cache TTL expires. --- app/service/learning_svc.py | 13 +++++++++ conf/default.yml | 1 + tests/test_learning_svc_cache.py | 49 ++++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+) create mode 100644 tests/test_learning_svc_cache.py diff --git a/app/service/learning_svc.py b/app/service/learning_svc.py index ed1cf315e..5c9a7049f 100644 --- a/app/service/learning_svc.py +++ b/app/service/learning_svc.py @@ -1,6 +1,7 @@ import itertools import glob import re +import time from base64 import b64decode from importlib import import_module @@ -17,6 +18,8 @@ def __init__(self): self.model = set() self.parsers = self.add_parsers('app/learning') self.re_variable = re.compile(r'#{(.*?)}', flags=re.DOTALL) + self._model_dirty = True + self._model_built_at = 0.0 self.log.debug('Loaded %d parsers' % len(self.parsers)) @staticmethod @@ -28,6 +31,10 @@ def add_parsers(directory): return parsers async def build_model(self): + cache_ttl = self.get_config('model_cache_ttl_seconds') or 3600 + if not self._model_dirty and (time.monotonic() - self._model_built_at) < cache_ttl: + self.log.debug('Skipping model rebuild - cache still valid (TTL=%ds)', cache_ttl) + return for ability in await self.get_service('data_svc').locate('abilities'): for executor in ability.executors: if executor.command: @@ -35,6 +42,12 @@ async def build_model(self): if len(variables) > 1: # relationships require at least 2 variables self.model.add(variables) self.model = set(self.model) + self._model_dirty = False + self._model_built_at = time.monotonic() + + def invalidate_model_cache(self): + """Mark the model cache as dirty so it will be rebuilt on next call.""" + self._model_dirty = True async def learn(self, facts, link, blob, operation=None): decoded_blob = b64decode(blob).decode('utf-8') diff --git a/conf/default.yml b/conf/default.yml index ba0653c94..6e3976a4c 100644 --- a/conf/default.yml +++ b/conf/default.yml @@ -22,6 +22,7 @@ app.contact.ftp.user: caldera_user app.contact.tcp: 0.0.0.0:7010 app.contact.udp: 0.0.0.0:7011 app.contact.websocket: 0.0.0.0:7012 +model_cache_ttl_seconds: 3600 objects.planners.default: atomic crypt_salt: REPLACE_WITH_RANDOM_VALUE encryption_key: ADMIN123 diff --git a/tests/test_learning_svc_cache.py b/tests/test_learning_svc_cache.py new file mode 100644 index 000000000..d510d1e98 --- /dev/null +++ b/tests/test_learning_svc_cache.py @@ -0,0 +1,49 @@ +import asyncio +import time +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from app.service.learning_svc import LearningService + + +class TestLearningServiceCache: + def test_initial_state_dirty(self): + with patch.object(LearningService, 'add_service', return_value=MagicMock()): + with patch.object(LearningService, 'add_parsers', return_value=[]): + svc = LearningService() + assert svc._model_dirty is True + assert svc._model_built_at == 0.0 + + def test_invalidate_cache(self): + with patch.object(LearningService, 'add_service', return_value=MagicMock()): + with patch.object(LearningService, 'add_parsers', return_value=[]): + svc = LearningService() + svc._model_dirty = False + svc._model_built_at = time.monotonic() + svc.invalidate_model_cache() + assert svc._model_dirty is True + + def test_skip_rebuild_when_cache_valid(self): + with patch.object(LearningService, 'add_service', return_value=MagicMock()): + with patch.object(LearningService, 'add_parsers', return_value=[]): + svc = LearningService() + svc._model_dirty = False + svc._model_built_at = time.monotonic() + mock_data_svc = AsyncMock() + svc.get_service = MagicMock(return_value=mock_data_svc) + svc.get_config = MagicMock(return_value=3600) + asyncio.run(svc.build_model()) + mock_data_svc.locate.assert_not_called() + + def test_rebuild_when_dirty(self): + with patch.object(LearningService, 'add_service', return_value=MagicMock()): + with patch.object(LearningService, 'add_parsers', return_value=[]): + svc = LearningService() + svc._model_dirty = True + mock_data_svc = AsyncMock() + mock_data_svc.locate = AsyncMock(return_value=[]) + svc.get_service = MagicMock(return_value=mock_data_svc) + svc.get_config = MagicMock(return_value=3600) + asyncio.run(svc.build_model()) + mock_data_svc.locate.assert_called_once() + assert svc._model_dirty is False + assert svc._model_built_at > 0 From 270c3d38832c9ead3eb310ebbade94aa5fe275bb Mon Sep 17 00:00:00 2001 From: deacon Date: Mon, 16 Mar 2026 00:51:37 -0400 Subject: [PATCH 2/2] fix: address Copilot review feedback on learning-svc-model-cache - Coerce cache_ttl to int (clamped >= 1s) since get_config() may return string - Rebuild model into a fresh local set to remove stale entries from deleted abilities before assigning to self.model - Fix test: explicitly set mock_data_svc.locate = AsyncMock() before assert_not_called() to avoid false-positive child mock assertion - Add test for TTL-expiry rebuild (dirty=False but TTL elapsed) - Add test for string TTL coercion path --- app/service/learning_svc.py | 14 ++++++++++--- tests/test_learning_svc_cache.py | 36 ++++++++++++++++++++++++++++++-- 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/app/service/learning_svc.py b/app/service/learning_svc.py index 5c9a7049f..f8216081f 100644 --- a/app/service/learning_svc.py +++ b/app/service/learning_svc.py @@ -31,17 +31,25 @@ def add_parsers(directory): return parsers async def build_model(self): - cache_ttl = self.get_config('model_cache_ttl_seconds') or 3600 + # get_config() may return a string; coerce to int and clamp to >= 1 second. + raw_ttl = self.get_config('model_cache_ttl_seconds') + try: + cache_ttl = max(1, int(raw_ttl)) if raw_ttl is not None else 3600 + except (ValueError, TypeError): + cache_ttl = 3600 if not self._model_dirty and (time.monotonic() - self._model_built_at) < cache_ttl: self.log.debug('Skipping model rebuild - cache still valid (TTL=%ds)', cache_ttl) return + # Build into a fresh local set to prevent stale entries from abilities + # that were removed since the last build. + new_model = set() for ability in await self.get_service('data_svc').locate('abilities'): for executor in ability.executors: if executor.command: variables = frozenset(re.findall(self.re_variable, executor.test)) if len(variables) > 1: # relationships require at least 2 variables - self.model.add(variables) - self.model = set(self.model) + new_model.add(variables) + self.model = new_model self._model_dirty = False self._model_built_at = time.monotonic() diff --git a/tests/test_learning_svc_cache.py b/tests/test_learning_svc_cache.py index d510d1e98..8b0cfcefa 100644 --- a/tests/test_learning_svc_cache.py +++ b/tests/test_learning_svc_cache.py @@ -28,7 +28,8 @@ def test_skip_rebuild_when_cache_valid(self): svc = LearningService() svc._model_dirty = False svc._model_built_at = time.monotonic() - mock_data_svc = AsyncMock() + mock_data_svc = MagicMock() + mock_data_svc.locate = AsyncMock(return_value=[]) svc.get_service = MagicMock(return_value=mock_data_svc) svc.get_config = MagicMock(return_value=3600) asyncio.run(svc.build_model()) @@ -39,7 +40,7 @@ def test_rebuild_when_dirty(self): with patch.object(LearningService, 'add_parsers', return_value=[]): svc = LearningService() svc._model_dirty = True - mock_data_svc = AsyncMock() + mock_data_svc = MagicMock() mock_data_svc.locate = AsyncMock(return_value=[]) svc.get_service = MagicMock(return_value=mock_data_svc) svc.get_config = MagicMock(return_value=3600) @@ -47,3 +48,34 @@ def test_rebuild_when_dirty(self): mock_data_svc.locate.assert_called_once() assert svc._model_dirty is False assert svc._model_built_at > 0 + + def test_rebuild_when_ttl_expired(self): + """Rebuild must occur when cache TTL has expired even if not dirty.""" + with patch.object(LearningService, 'add_service', return_value=MagicMock()): + with patch.object(LearningService, 'add_parsers', return_value=[]): + svc = LearningService() + ttl = 60 + svc._model_dirty = False + # Set built-at to ttl+1 seconds in the past so it is expired. + svc._model_built_at = time.monotonic() - (ttl + 1) + mock_data_svc = MagicMock() + mock_data_svc.locate = AsyncMock(return_value=[]) + svc.get_service = MagicMock(return_value=mock_data_svc) + svc.get_config = MagicMock(return_value=ttl) + asyncio.run(svc.build_model()) + mock_data_svc.locate.assert_called_once() + assert svc._model_dirty is False + + def test_cache_ttl_as_string_is_coerced(self): + """get_config() may return a string; it must be coerced to int without error.""" + with patch.object(LearningService, 'add_service', return_value=MagicMock()): + with patch.object(LearningService, 'add_parsers', return_value=[]): + svc = LearningService() + svc._model_dirty = True + mock_data_svc = MagicMock() + mock_data_svc.locate = AsyncMock(return_value=[]) + svc.get_service = MagicMock(return_value=mock_data_svc) + # Return TTL as a string (common when loaded from config files). + svc.get_config = MagicMock(return_value='3600') + asyncio.run(svc.build_model()) + mock_data_svc.locate.assert_called_once()