diff --git a/app/service/learning_svc.py b/app/service/learning_svc.py index ed1cf315e..f8216081f 100644 --- a/app/service/learning_svc.py +++ b/app/service/learning_svc.py @@ -1,6 +1,7 @@ import itertools import glob import re +import time from base64 import b64decode from importlib import import_module @@ -17,6 +18,8 @@ def __init__(self): self.model = set() self.parsers = self.add_parsers('app/learning') self.re_variable = re.compile(r'#{(.*?)}', flags=re.DOTALL) + self._model_dirty = True + self._model_built_at = 0.0 self.log.debug('Loaded %d parsers' % len(self.parsers)) @staticmethod @@ -28,13 +31,31 @@ def add_parsers(directory): return parsers async def build_model(self): + # get_config() may return a string; coerce to int and clamp to >= 1 second. + raw_ttl = self.get_config('model_cache_ttl_seconds') + try: + cache_ttl = max(1, int(raw_ttl)) if raw_ttl is not None else 3600 + except (ValueError, TypeError): + cache_ttl = 3600 + if not self._model_dirty and (time.monotonic() - self._model_built_at) < cache_ttl: + self.log.debug('Skipping model rebuild - cache still valid (TTL=%ds)', cache_ttl) + return + # Build into a fresh local set to prevent stale entries from abilities + # that were removed since the last build. + new_model = set() for ability in await self.get_service('data_svc').locate('abilities'): for executor in ability.executors: if executor.command: variables = frozenset(re.findall(self.re_variable, executor.test)) if len(variables) > 1: # relationships require at least 2 variables - self.model.add(variables) - self.model = set(self.model) + new_model.add(variables) + self.model = new_model + self._model_dirty = False + self._model_built_at = time.monotonic() + + def invalidate_model_cache(self): + """Mark the model cache as dirty so it will be rebuilt on next call.""" + self._model_dirty = True async def learn(self, facts, link, blob, operation=None): decoded_blob = b64decode(blob).decode('utf-8') diff --git a/conf/default.yml b/conf/default.yml index ba0653c94..6e3976a4c 100644 --- a/conf/default.yml +++ b/conf/default.yml @@ -22,6 +22,7 @@ app.contact.ftp.user: caldera_user app.contact.tcp: 0.0.0.0:7010 app.contact.udp: 0.0.0.0:7011 app.contact.websocket: 0.0.0.0:7012 +model_cache_ttl_seconds: 3600 objects.planners.default: atomic crypt_salt: REPLACE_WITH_RANDOM_VALUE encryption_key: ADMIN123 diff --git a/tests/test_learning_svc_cache.py b/tests/test_learning_svc_cache.py new file mode 100644 index 000000000..8b0cfcefa --- /dev/null +++ b/tests/test_learning_svc_cache.py @@ -0,0 +1,81 @@ +import asyncio +import time +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from app.service.learning_svc import LearningService + + +class TestLearningServiceCache: + def test_initial_state_dirty(self): + with patch.object(LearningService, 'add_service', return_value=MagicMock()): + with patch.object(LearningService, 'add_parsers', return_value=[]): + svc = LearningService() + assert svc._model_dirty is True + assert svc._model_built_at == 0.0 + + def test_invalidate_cache(self): + with patch.object(LearningService, 'add_service', return_value=MagicMock()): + with patch.object(LearningService, 'add_parsers', return_value=[]): + svc = LearningService() + svc._model_dirty = False + svc._model_built_at = time.monotonic() + svc.invalidate_model_cache() + assert svc._model_dirty is True + + def test_skip_rebuild_when_cache_valid(self): + with patch.object(LearningService, 'add_service', return_value=MagicMock()): + with patch.object(LearningService, 'add_parsers', return_value=[]): + svc = LearningService() + svc._model_dirty = False + svc._model_built_at = time.monotonic() + mock_data_svc = MagicMock() + mock_data_svc.locate = AsyncMock(return_value=[]) + svc.get_service = MagicMock(return_value=mock_data_svc) + svc.get_config = MagicMock(return_value=3600) + asyncio.run(svc.build_model()) + mock_data_svc.locate.assert_not_called() + + def test_rebuild_when_dirty(self): + with patch.object(LearningService, 'add_service', return_value=MagicMock()): + with patch.object(LearningService, 'add_parsers', return_value=[]): + svc = LearningService() + svc._model_dirty = True + mock_data_svc = MagicMock() + mock_data_svc.locate = AsyncMock(return_value=[]) + svc.get_service = MagicMock(return_value=mock_data_svc) + svc.get_config = MagicMock(return_value=3600) + asyncio.run(svc.build_model()) + mock_data_svc.locate.assert_called_once() + assert svc._model_dirty is False + assert svc._model_built_at > 0 + + def test_rebuild_when_ttl_expired(self): + """Rebuild must occur when cache TTL has expired even if not dirty.""" + with patch.object(LearningService, 'add_service', return_value=MagicMock()): + with patch.object(LearningService, 'add_parsers', return_value=[]): + svc = LearningService() + ttl = 60 + svc._model_dirty = False + # Set built-at to ttl+1 seconds in the past so it is expired. + svc._model_built_at = time.monotonic() - (ttl + 1) + mock_data_svc = MagicMock() + mock_data_svc.locate = AsyncMock(return_value=[]) + svc.get_service = MagicMock(return_value=mock_data_svc) + svc.get_config = MagicMock(return_value=ttl) + asyncio.run(svc.build_model()) + mock_data_svc.locate.assert_called_once() + assert svc._model_dirty is False + + def test_cache_ttl_as_string_is_coerced(self): + """get_config() may return a string; it must be coerced to int without error.""" + with patch.object(LearningService, 'add_service', return_value=MagicMock()): + with patch.object(LearningService, 'add_parsers', return_value=[]): + svc = LearningService() + svc._model_dirty = True + mock_data_svc = MagicMock() + mock_data_svc.locate = AsyncMock(return_value=[]) + svc.get_service = MagicMock(return_value=mock_data_svc) + # Return TTL as a string (common when loaded from config files). + svc.get_config = MagicMock(return_value='3600') + asyncio.run(svc.build_model()) + mock_data_svc.locate.assert_called_once()