From c70a8a23eced302943e223002af9d813c23e33c7 Mon Sep 17 00:00:00 2001 From: TheMihMih Date: Fri, 16 Jan 2026 12:01:30 +0300 Subject: [PATCH 1/4] added cache func and cached get_base_directories --- app/ldap_protocol/utils/helpers.py | 76 +++++++++++++++++++++++++++++- app/ldap_protocol/utils/queries.py | 2 + 2 files changed, 77 insertions(+), 1 deletion(-) diff --git a/app/ldap_protocol/utils/helpers.py b/app/ldap_protocol/utils/helpers.py index 296ff5978..73b061050 100644 --- a/app/ldap_protocol/utils/helpers.py +++ b/app/ldap_protocol/utils/helpers.py @@ -130,6 +130,7 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ +import asyncio import functools import hashlib import random @@ -138,19 +139,23 @@ import time from calendar import timegm from datetime import datetime +from functools import wraps from hashlib import blake2b from operator import attrgetter -from typing import Callable +from typing import Any, Callable, Iterable from zoneinfo import ZoneInfo from loguru import logger from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.compiler import compiles +from sqlalchemy.orm.attributes import instance_state from sqlalchemy.sql.compiler import DDLCompiler from sqlalchemy.sql.expression import ClauseElement, Executable, Visitable from entities import Directory +DEFAULT_CACHE_TIME = 5 * 60 # 5 minutes + def validate_entry(entry: str) -> bool: """Validate entry str. @@ -402,3 +407,72 @@ async def explain_query( for row in await session.execute(explain(query, analyze=True)) ), ) + + +def has_expired_sqla_objs(obj: Any, max_depth: int = 3) -> bool: + def _check(value: Any) -> bool: + try: + state = instance_state(value) + return bool(state.expired_attributes) + except AttributeError: + return False + + def _walk(value: Any, depth: int = 0) -> bool: + if depth > max_depth: + return False + + if _check(value): + return True + + if isinstance(value, str | bytes | bytearray): + return False + + if isinstance(value, dict): + return any(_walk(v, depth + 1) for v in value.values()) + + if isinstance(value, Iterable): + return any(_walk(v, depth + 1) for v in value) + + return False + + return _walk(obj) + + +def async_lru_cache(ttl: int | None = DEFAULT_CACHE_TIME) -> Callable: + cache: dict = {} + locks: dict = {} + + def _is_value_expired( + value: Any, + now: float, + expires_at: float | None, + ) -> bool: + return bool( + expires_at and expires_at < now or has_expired_sqla_objs(value), + ) + + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args: tuple, **kwargs: dict) -> Any: + key = (args, tuple(sorted(kwargs.items()))) + now = time.monotonic() + if key not in locks: + locks[key] = asyncio.Lock() + + async with locks[key]: + if key in cache: + value, expires_at = cache[key] + if not _is_value_expired(value, now, expires_at): + return value + else: + del cache[key] + + result = await func(*args, **kwargs) + expires_at = now + ttl if ttl else None + cache[key] = (result, expires_at) + del locks[key] + return result + + return wrapper + + return decorator diff --git a/app/ldap_protocol/utils/queries.py b/app/ldap_protocol/utils/queries.py index a1f9243de..4e40514aa 100644 --- a/app/ldap_protocol/utils/queries.py +++ b/app/ldap_protocol/utils/queries.py @@ -27,6 +27,7 @@ from .const import EMAIL_RE, GRANT_DN_STRING from .helpers import ( + async_lru_cache, create_integer_hash, create_object_sid, dn_is_base_directory, @@ -35,6 +36,7 @@ ) +@async_lru_cache() async def get_base_directories(session: AsyncSession) -> list[Directory]: """Get base domain directories.""" result = await session.execute( From 6e414332b29c9d98acc3c50a368f6db9524071ab Mon Sep 17 00:00:00 2001 From: TheMihMih Date: Tue, 20 Jan 2026 12:22:19 +0300 Subject: [PATCH 2/4] refactor: change get_base_directories output to dto --- .../versions/16a9fa2c1f1e_rename_readonly_group.py | 7 ++++--- .../71e642808369_add_directory_is_system.py | 9 +++++++-- app/entities.py | 4 ++-- app/ldap_protocol/auth/setup_gateway.py | 10 ++++++---- app/ldap_protocol/ldap_requests/add.py | 2 +- app/ldap_protocol/ldap_requests/modify_dn.py | 2 +- app/ldap_protocol/ldap_requests/search.py | 3 ++- app/ldap_protocol/roles/role_use_case.py | 3 ++- app/ldap_protocol/utils/helpers.py | 13 ++++++++++--- app/ldap_protocol/utils/queries.py | 13 ++++++++----- pyproject.toml | 1 + tests/test_api/test_auth/test_router.py | 2 +- 12 files changed, 45 insertions(+), 24 deletions(-) diff --git a/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py b/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py index b331dddd5..cf6ac280e 100644 --- a/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py +++ b/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py @@ -43,8 +43,8 @@ def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 return ro_dir.name = READ_ONLY_GROUP_NAME - - ro_dir.create_path(ro_dir.parent, ro_dir.get_dn_prefix()) + path = ro_dir.parent.path if ro_dir.parent else [] + ro_dir.create_path(path, ro_dir.get_dn_prefix()) session.execute( update(Attribute) @@ -92,7 +92,8 @@ def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 ro_dir.name = "readonly domain controllers" - ro_dir.create_path(ro_dir.parent, ro_dir.get_dn_prefix()) + path = ro_dir.parent.path if ro_dir.parent else [] + ro_dir.create_path(path, ro_dir.get_dn_prefix()) session.execute( update(Attribute) diff --git a/app/alembic/versions/71e642808369_add_directory_is_system.py b/app/alembic/versions/71e642808369_add_directory_is_system.py index 2526190e4..48ece1bc4 100644 --- a/app/alembic/versions/71e642808369_add_directory_is_system.py +++ b/app/alembic/versions/71e642808369_add_directory_is_system.py @@ -56,8 +56,13 @@ async def _indicate_system_directories( if not base_dn_list: return - for base_dn in base_dn_list: - base_dn.is_system = True + await session.execute( + update(Directory) + .where( + qa(Directory.parent_id).is_(None), + ) + .values(is_system=True), + ) await session.flush() diff --git a/app/entities.py b/app/entities.py index 53f5c95e9..acc32675a 100644 --- a/app/entities.py +++ b/app/entities.py @@ -270,10 +270,10 @@ def path_dn(self) -> str: def create_path( self, - parent: Directory | None = None, + parent_path: list | None = None, dn: str = "cn", ) -> None: - pre = parent.path if parent else [] + pre = parent_path or [] self.path = pre + [self.get_dn(dn)] self.depth = len(self.path) self.rdname = dn diff --git a/app/ldap_protocol/auth/setup_gateway.py b/app/ldap_protocol/auth/setup_gateway.py index b5bfe580a..a9aefd094 100644 --- a/app/ldap_protocol/auth/setup_gateway.py +++ b/app/ldap_protocol/auth/setup_gateway.py @@ -11,6 +11,7 @@ from sqlalchemy import exists, select from sqlalchemy.ext.asyncio import AsyncSession +from dtos import DirectoryDTO from entities import Attribute, Directory, Group, NetworkPolicy, User from ldap_protocol.ldap_schema.attribute_value_validator import ( AttributeValueValidator, @@ -124,21 +125,22 @@ async def create_dir( self, data: dict, is_system: bool, - domain: Directory, - parent: Directory | None = None, + domain: Directory | DirectoryDTO, + parent: Directory | DirectoryDTO | None = None, ) -> None: """Create data recursively.""" dir_ = Directory( is_system=is_system, object_class=data["object_class"], name=data["name"], - parent=parent, ) dir_.groups = [] - dir_.create_path(parent, dir_.get_dn_prefix()) + path = parent.path if parent else [] + dir_.create_path(path, dir_.get_dn_prefix()) self._session.add(dir_) await self._session.flush() + dir_.parent_id = parent.id if parent else None await self._session.refresh(dir_, ["id"]) self._session.add( diff --git a/app/ldap_protocol/ldap_requests/add.py b/app/ldap_protocol/ldap_requests/add.py index 6f29fe9af..b000b3798 100644 --- a/app/ldap_protocol/ldap_requests/add.py +++ b/app/ldap_protocol/ldap_requests/add.py @@ -211,7 +211,7 @@ async def handle( # noqa: C901 parent=parent, ) - new_dir.create_path(parent, new_dn) + new_dir.create_path(parent.path, new_dn) ctx.session.add(new_dir) await ctx.session.flush() diff --git a/app/ldap_protocol/ldap_requests/modify_dn.py b/app/ldap_protocol/ldap_requests/modify_dn.py index 7c315eadd..0ac906ce8 100644 --- a/app/ldap_protocol/ldap_requests/modify_dn.py +++ b/app/ldap_protocol/ldap_requests/modify_dn.py @@ -199,7 +199,7 @@ async def handle( return directory.parent = parent_dir - directory.create_path(directory.parent, dn=new_dn) + directory.create_path(parent_dir.path, dn=new_dn) try: await ctx.session.flush() diff --git a/app/ldap_protocol/ldap_requests/search.py b/app/ldap_protocol/ldap_requests/search.py index 1f9579dc2..ebd5730ac 100644 --- a/app/ldap_protocol/ldap_requests/search.py +++ b/app/ldap_protocol/ldap_requests/search.py @@ -23,6 +23,7 @@ from sqlalchemy.sql.elements import ColumnElement, UnaryExpression from sqlalchemy.sql.expression import Select +from dtos import DirectoryDTO from entities import ( Attribute, AttributeType, @@ -367,7 +368,7 @@ def _mutate_query_with_attributes_to_load( def _build_query( self, - base_directories: list[Directory], + base_directories: list[DirectoryDTO], user: UserSchema, access_manager: AccessManager, ) -> Select[tuple[Directory]]: diff --git a/app/ldap_protocol/roles/role_use_case.py b/app/ldap_protocol/roles/role_use_case.py index 1e978a3f1..08951cca7 100644 --- a/app/ldap_protocol/roles/role_use_case.py +++ b/app/ldap_protocol/roles/role_use_case.py @@ -6,6 +6,7 @@ from sqlalchemy import and_, insert, literal, or_, select +from dtos import DirectoryDTO from entities import AccessControlEntry, AceType, Directory, Role from enums import AuthorizationRules, RoleConstants, RoleScope from ldap_protocol.utils.queries import get_base_directories @@ -40,7 +41,7 @@ def __init__( async def inherit_parent_aces( self, - parent_directory: Directory, + parent_directory: Directory | DirectoryDTO, directory: Directory, ) -> None: """Inherit access control entries from the parent directory. diff --git a/app/ldap_protocol/utils/helpers.py b/app/ldap_protocol/utils/helpers.py index 73b061050..4fb72bf80 100644 --- a/app/ldap_protocol/utils/helpers.py +++ b/app/ldap_protocol/utils/helpers.py @@ -152,6 +152,7 @@ from sqlalchemy.sql.compiler import DDLCompiler from sqlalchemy.sql.expression import ClauseElement, Executable, Visitable +from dtos import DirectoryDTO from entities import Directory DEFAULT_CACHE_TIME = 5 * 60 # 5 minutes @@ -197,12 +198,18 @@ def validate_attribute(attribute: str) -> bool: ) -def is_dn_in_base_directory(base_directory: Directory, entry: str) -> bool: +def is_dn_in_base_directory( + base_directory: Directory | DirectoryDTO, + entry: str, +) -> bool: """Check if an entry in a base dn.""" return entry.lower().endswith(base_directory.path_dn.lower()) -def dn_is_base_directory(base_directory: Directory, entry: str) -> bool: +def dn_is_base_directory( + base_directory: Directory | DirectoryDTO, + entry: str, +) -> bool: """Check if an entry is a base dn.""" return base_directory.path_dn.lower() == entry.lower() @@ -307,7 +314,7 @@ def string_to_sid(sid_string: str) -> bytes: def create_object_sid( - domain: Directory, + domain: Directory | DirectoryDTO, rid: int, reserved: bool = False, ) -> str: diff --git a/app/ldap_protocol/utils/queries.py b/app/ldap_protocol/utils/queries.py index 4e40514aa..694b52b69 100644 --- a/app/ldap_protocol/utils/queries.py +++ b/app/ldap_protocol/utils/queries.py @@ -14,6 +14,7 @@ from sqlalchemy.orm import InstrumentedAttribute, joinedload, selectinload from sqlalchemy.sql.expression import ColumnElement +from dtos import DirectoryDTO, _directory_sqla_obj_to_dto from entities import Attribute, Directory, Group, User from ldap_protocol.ldap_schema.attribute_value_validator import ( AttributeValueValidator, @@ -37,13 +38,15 @@ @async_lru_cache() -async def get_base_directories(session: AsyncSession) -> list[Directory]: +async def get_base_directories(session: AsyncSession) -> list[DirectoryDTO]: """Get base domain directories.""" result = await session.execute( select(Directory) .filter(qa(Directory.parent_id).is_(None)), ) # fmt: skip - return list(result.scalars().all()) + return [ + _directory_sqla_obj_to_dto(dir_) for dir_ in result.scalars().all() + ] async def get_user(session: AsyncSession, name: str) -> User | None: @@ -364,14 +367,14 @@ async def create_group( dir_ = Directory( object_class="", name=name, - parent=parent, + parent_id=parent.id, ) session.add(dir_) await session.flush() - await session.refresh(dir_, ["id"]) + await session.refresh(dir_, ["id", "parent_id", "parent"]) group = Group(directory_id=dir_.id) - dir_.create_path(parent) + dir_.create_path(parent.path) session.add(group) dir_.object_sid = create_object_sid( diff --git a/pyproject.toml b/pyproject.toml index f7adf0e26..c25e9d974 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -222,6 +222,7 @@ known-first-party = [ "extra", "enums", "errors", + "dtos", ] known-third-party = [ "alembic", # https://github.com/astral-sh/ruff/issues/10519 diff --git a/tests/test_api/test_auth/test_router.py b/tests/test_api/test_auth/test_router.py index 26c0e4523..c13c0a5a6 100644 --- a/tests/test_api/test_auth/test_router.py +++ b/tests/test_api/test_auth/test_router.py @@ -13,7 +13,6 @@ from fastapi import status from httpx import AsyncClient from jose import jwt -from password_utils import PasswordUtils from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload @@ -26,6 +25,7 @@ from ldap_protocol.ldap_requests.modify import Operation from ldap_protocol.session_storage import SessionStorage from ldap_protocol.utils.queries import get_search_path +from password_utils import PasswordUtils from repo.pg.tables import queryable_attr as qa from tests.conftest import TestCreds From 97a98245bff312cfdd99cf8c7c115f1f4ffa20fa Mon Sep 17 00:00:00 2001 From: TheMihMih Date: Tue, 20 Jan 2026 12:28:44 +0300 Subject: [PATCH 3/4] refactor: deleted has_expired_sqla_objs --- app/ldap_protocol/utils/helpers.py | 43 ++---------------------------- app/ldap_protocol/utils/queries.py | 2 +- 2 files changed, 3 insertions(+), 42 deletions(-) diff --git a/app/ldap_protocol/utils/helpers.py b/app/ldap_protocol/utils/helpers.py index 4fb72bf80..bcc9845a9 100644 --- a/app/ldap_protocol/utils/helpers.py +++ b/app/ldap_protocol/utils/helpers.py @@ -142,13 +142,12 @@ from functools import wraps from hashlib import blake2b from operator import attrgetter -from typing import Any, Callable, Iterable +from typing import Any, Callable from zoneinfo import ZoneInfo from loguru import logger from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.compiler import compiles -from sqlalchemy.orm.attributes import instance_state from sqlalchemy.sql.compiler import DDLCompiler from sqlalchemy.sql.expression import ClauseElement, Executable, Visitable @@ -416,48 +415,10 @@ async def explain_query( ) -def has_expired_sqla_objs(obj: Any, max_depth: int = 3) -> bool: - def _check(value: Any) -> bool: - try: - state = instance_state(value) - return bool(state.expired_attributes) - except AttributeError: - return False - - def _walk(value: Any, depth: int = 0) -> bool: - if depth > max_depth: - return False - - if _check(value): - return True - - if isinstance(value, str | bytes | bytearray): - return False - - if isinstance(value, dict): - return any(_walk(v, depth + 1) for v in value.values()) - - if isinstance(value, Iterable): - return any(_walk(v, depth + 1) for v in value) - - return False - - return _walk(obj) - - def async_lru_cache(ttl: int | None = DEFAULT_CACHE_TIME) -> Callable: cache: dict = {} locks: dict = {} - def _is_value_expired( - value: Any, - now: float, - expires_at: float | None, - ) -> bool: - return bool( - expires_at and expires_at < now or has_expired_sqla_objs(value), - ) - def decorator(func: Callable) -> Callable: @wraps(func) async def wrapper(*args: tuple, **kwargs: dict) -> Any: @@ -469,7 +430,7 @@ async def wrapper(*args: tuple, **kwargs: dict) -> Any: async with locks[key]: if key in cache: value, expires_at = cache[key] - if not _is_value_expired(value, now, expires_at): + if not expires_at or expires_at > now: return value else: del cache[key] diff --git a/app/ldap_protocol/utils/queries.py b/app/ldap_protocol/utils/queries.py index 694b52b69..d2a8e7420 100644 --- a/app/ldap_protocol/utils/queries.py +++ b/app/ldap_protocol/utils/queries.py @@ -371,7 +371,7 @@ async def create_group( ) session.add(dir_) await session.flush() - await session.refresh(dir_, ["id", "parent_id", "parent"]) + await session.refresh(dir_, ["id"]) group = Group(directory_id=dir_.id) dir_.create_path(parent.path) From ded6d9d065f50e7417ee07e17100af4ccb4ea257 Mon Sep 17 00:00:00 2001 From: TheMihMih Date: Tue, 20 Jan 2026 12:31:09 +0300 Subject: [PATCH 4/4] add: dtos file --- app/dtos.py | 91 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 app/dtos.py diff --git a/app/dtos.py b/app/dtos.py new file mode 100644 index 000000000..7efee30ca --- /dev/null +++ b/app/dtos.py @@ -0,0 +1,91 @@ +"""Module for dtos.""" + +import uuid +from dataclasses import dataclass +from datetime import datetime +from typing import ClassVar + +from adaptix.conversion import get_converter + +from entities import Directory, DistinguishedNamePrefix + + +@dataclass +class DirectoryDTO: + id: int + name: str + is_system: bool + object_sid: str + object_guid: uuid.UUID + parent_id: int | None + entity_type_id: int | None + object_class: str + rdname: str + created_at: datetime | None + updated_at: datetime | None + depth: int + password_policy_id: int | None + path: list[str] + + search_fields: ClassVar[dict[str, str]] = { + "name": "name", + "objectguid": "objectGUID", + "objectsid": "objectSid", + } + ro_fields: ClassVar[set[str]] = { + "uid", + "whencreated", + "lastlogon", + "authtimestamp", + "objectguid", + "objectsid", + "entitytypename", + } + + def get_dn_prefix(self) -> DistinguishedNamePrefix: + return { + "organizationalUnit": "ou", + "domain": "dc", + "container": "cn", + }.get( + self.object_class, + "cn", + ) # type: ignore + + def get_dn(self, dn: str = "cn") -> str: + return f"{dn}={self.name}" + + @property + def is_domain(self) -> bool: + return not self.parent_id and self.object_class == "domain" + + @property + def host_principal(self) -> str: + return f"host/{self.name}" + + @property + def path_dn(self) -> str: + return ",".join(reversed(self.path)) + + def create_path( + self, + parent: Directory | None = None, + dn: str = "cn", + ) -> None: + pre = parent.path if parent else [] + self.path = pre + [self.get_dn(dn)] + self.depth = len(self.path) + self.rdname = dn + + @property + def relative_id(self) -> str: + """Get RID from objectSid. + + Relative Identifier (RID) is the last sub-authority value of a SID. + """ + if "-" in self.object_sid: + return self.object_sid.split("-")[-1] + return "" + + +_directory_sqla_obj_to_dto = get_converter(Directory, DirectoryDTO)