Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions app/service/learning_svc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import itertools
import glob
import re
import time
from base64 import b64decode
from importlib import import_module

Expand All @@ -17,6 +18,8 @@ def __init__(self):
self.model = set()
self.parsers = self.add_parsers('app/learning')
self.re_variable = re.compile(r'#{(.*?)}', flags=re.DOTALL)
self._model_dirty = True
self._model_built_at = 0.0
self.log.debug('Loaded %d parsers' % len(self.parsers))

@staticmethod
Expand All @@ -28,13 +31,31 @@ def add_parsers(directory):
return parsers

async def build_model(self):
# get_config() may return a string; coerce to int and clamp to >= 1 second.
raw_ttl = self.get_config('model_cache_ttl_seconds')
try:
cache_ttl = max(1, int(raw_ttl)) if raw_ttl is not None else 3600
except (ValueError, TypeError):
cache_ttl = 3600
if not self._model_dirty and (time.monotonic() - self._model_built_at) < cache_ttl:
self.log.debug('Skipping model rebuild - cache still valid (TTL=%ds)', cache_ttl)
return
# Build into a fresh local set to prevent stale entries from abilities
# that were removed since the last build.
new_model = set()
for ability in await self.get_service('data_svc').locate('abilities'):
for executor in ability.executors:
if executor.command:
variables = frozenset(re.findall(self.re_variable, executor.test))
if len(variables) > 1: # relationships require at least 2 variables
self.model.add(variables)
self.model = set(self.model)
new_model.add(variables)
self.model = new_model
self._model_dirty = False
self._model_built_at = time.monotonic()

def invalidate_model_cache(self):
"""Mark the model cache as dirty so it will be rebuilt on next call."""
self._model_dirty = True

async def learn(self, facts, link, blob, operation=None):
decoded_blob = b64decode(blob).decode('utf-8')
Expand Down
1 change: 1 addition & 0 deletions conf/default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ app.contact.ftp.user: caldera_user
app.contact.tcp: 0.0.0.0:7010
app.contact.udp: 0.0.0.0:7011
app.contact.websocket: 0.0.0.0:7012
model_cache_ttl_seconds: 3600
objects.planners.default: atomic
crypt_salt: REPLACE_WITH_RANDOM_VALUE
encryption_key: ADMIN123
Expand Down
81 changes: 81 additions & 0 deletions tests/test_learning_svc_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import asyncio
import time
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from app.service.learning_svc import LearningService


class TestLearningServiceCache:
def test_initial_state_dirty(self):
with patch.object(LearningService, 'add_service', return_value=MagicMock()):
with patch.object(LearningService, 'add_parsers', return_value=[]):
svc = LearningService()
assert svc._model_dirty is True
assert svc._model_built_at == 0.0

def test_invalidate_cache(self):
with patch.object(LearningService, 'add_service', return_value=MagicMock()):
with patch.object(LearningService, 'add_parsers', return_value=[]):
svc = LearningService()
svc._model_dirty = False
svc._model_built_at = time.monotonic()
svc.invalidate_model_cache()
assert svc._model_dirty is True

def test_skip_rebuild_when_cache_valid(self):
with patch.object(LearningService, 'add_service', return_value=MagicMock()):
with patch.object(LearningService, 'add_parsers', return_value=[]):
svc = LearningService()
svc._model_dirty = False
svc._model_built_at = time.monotonic()
mock_data_svc = MagicMock()
mock_data_svc.locate = AsyncMock(return_value=[])
svc.get_service = MagicMock(return_value=mock_data_svc)
svc.get_config = MagicMock(return_value=3600)
asyncio.run(svc.build_model())
mock_data_svc.locate.assert_not_called()

def test_rebuild_when_dirty(self):
with patch.object(LearningService, 'add_service', return_value=MagicMock()):
with patch.object(LearningService, 'add_parsers', return_value=[]):
svc = LearningService()
svc._model_dirty = True
mock_data_svc = MagicMock()
mock_data_svc.locate = AsyncMock(return_value=[])
svc.get_service = MagicMock(return_value=mock_data_svc)
svc.get_config = MagicMock(return_value=3600)
asyncio.run(svc.build_model())
mock_data_svc.locate.assert_called_once()
assert svc._model_dirty is False
assert svc._model_built_at > 0

def test_rebuild_when_ttl_expired(self):
"""Rebuild must occur when cache TTL has expired even if not dirty."""
with patch.object(LearningService, 'add_service', return_value=MagicMock()):
with patch.object(LearningService, 'add_parsers', return_value=[]):
svc = LearningService()
ttl = 60
svc._model_dirty = False
# Set built-at to ttl+1 seconds in the past so it is expired.
svc._model_built_at = time.monotonic() - (ttl + 1)
mock_data_svc = MagicMock()
mock_data_svc.locate = AsyncMock(return_value=[])
svc.get_service = MagicMock(return_value=mock_data_svc)
svc.get_config = MagicMock(return_value=ttl)
asyncio.run(svc.build_model())
mock_data_svc.locate.assert_called_once()
assert svc._model_dirty is False

def test_cache_ttl_as_string_is_coerced(self):
"""get_config() may return a string; it must be coerced to int without error."""
with patch.object(LearningService, 'add_service', return_value=MagicMock()):
with patch.object(LearningService, 'add_parsers', return_value=[]):
svc = LearningService()
svc._model_dirty = True
mock_data_svc = MagicMock()
mock_data_svc.locate = AsyncMock(return_value=[])
svc.get_service = MagicMock(return_value=mock_data_svc)
# Return TTL as a string (common when loaded from config files).
svc.get_config = MagicMock(return_value='3600')
asyncio.run(svc.build_model())
mock_data_svc.locate.assert_called_once()
Loading