diff --git a/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py b/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py index b331dddd5..cf6ac280e 100644 --- a/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py +++ b/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py @@ -43,8 +43,8 @@ def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 return ro_dir.name = READ_ONLY_GROUP_NAME - - ro_dir.create_path(ro_dir.parent, ro_dir.get_dn_prefix()) + path = ro_dir.parent.path if ro_dir.parent else [] + ro_dir.create_path(path, ro_dir.get_dn_prefix()) session.execute( update(Attribute) @@ -92,7 +92,8 @@ def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 ro_dir.name = "readonly domain controllers" - ro_dir.create_path(ro_dir.parent, ro_dir.get_dn_prefix()) + path = ro_dir.parent.path if ro_dir.parent else [] + ro_dir.create_path(path, ro_dir.get_dn_prefix()) session.execute( update(Attribute) diff --git a/app/alembic/versions/71e642808369_add_directory_is_system.py b/app/alembic/versions/71e642808369_add_directory_is_system.py index 2526190e4..48ece1bc4 100644 --- a/app/alembic/versions/71e642808369_add_directory_is_system.py +++ b/app/alembic/versions/71e642808369_add_directory_is_system.py @@ -56,8 +56,13 @@ async def _indicate_system_directories( if not base_dn_list: return - for base_dn in base_dn_list: - base_dn.is_system = True + await session.execute( + update(Directory) + .where( + qa(Directory.parent_id).is_(None), + ) + .values(is_system=True), + ) await session.flush() diff --git a/app/dtos.py b/app/dtos.py new file mode 100644 index 000000000..41d090c9e --- /dev/null +++ b/app/dtos.py @@ -0,0 +1,81 @@ +"""Module for dtos.""" + +import uuid +from dataclasses import dataclass +from datetime import datetime +from typing import ClassVar + +from adaptix.conversion import get_converter + +from entities import Directory, DistinguishedNamePrefix + + +@dataclass +class DirectoryDTO: + id: int + name: str + is_system: bool + object_sid: str + object_guid: uuid.UUID + parent_id: int | None + entity_type_id: int | None + object_class: str + rdname: str + created_at: datetime | None + updated_at: datetime | None + depth: int + password_policy_id: int | None + path: list[str] + + search_fields: ClassVar[dict[str, str]] = { + "name": "name", + "objectguid": "objectGUID", + "objectsid": "objectSid", + } + ro_fields: ClassVar[set[str]] = { + "uid", + "whencreated", + "lastlogon", + "authtimestamp", + "objectguid", + "objectsid", + "entitytypename", + } + + def get_dn_prefix(self) -> DistinguishedNamePrefix: + return { + "organizationalUnit": "ou", + "domain": "dc", + "container": "cn", + }.get( + self.object_class, + "cn", + ) # type: ignore + + def get_dn(self, dn: str = "cn") -> str: + return f"{dn}={self.name}" + + @property + def is_domain(self) -> bool: + return not self.parent_id and self.object_class == "domain" + + @property + def host_principal(self) -> str: + return f"host/{self.name}" + + @property + def path_dn(self) -> str: + return ",".join(reversed(self.path)) + + @property + def relative_id(self) -> str: + """Get RID from objectSid. + + Relative Identifier (RID) is the last sub-authority value of a SID. + """ + if "-" in self.object_sid: + return self.object_sid.split("-")[-1] + return "" + + +_directory_sqla_obj_to_dto = get_converter(Directory, DirectoryDTO) diff --git a/app/entities.py b/app/entities.py index 53f5c95e9..acc32675a 100644 --- a/app/entities.py +++ b/app/entities.py @@ -270,10 +270,10 @@ def path_dn(self) -> str: def create_path( self, - parent: Directory | None = None, + parent_path: list | None = None, dn: str = "cn", ) -> None: - pre = parent.path if parent else [] + pre = parent_path or [] self.path = pre + [self.get_dn(dn)] self.depth = len(self.path) self.rdname = dn diff --git a/app/ldap_protocol/auth/setup_gateway.py b/app/ldap_protocol/auth/setup_gateway.py index b5bfe580a..73be8a2e0 100644 --- a/app/ldap_protocol/auth/setup_gateway.py +++ b/app/ldap_protocol/auth/setup_gateway.py @@ -11,11 +11,13 @@ from sqlalchemy import exists, select from sqlalchemy.ext.asyncio import AsyncSession +from dtos import DirectoryDTO from entities import Attribute, Directory, Group, NetworkPolicy, User from ldap_protocol.ldap_schema.attribute_value_validator import ( AttributeValueValidator, ) from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO +from ldap_protocol.utils.async_cache import base_directories_cache from ldap_protocol.utils.helpers import create_object_sid, generate_domain_sid from ldap_protocol.utils.queries import get_domain_object_class from password_utils import PasswordUtils @@ -113,6 +115,7 @@ async def setup_enviroment( domain=domain, parent=domain, ) + base_directories_cache.clear() except Exception: import traceback @@ -124,21 +127,22 @@ async def create_dir( self, data: dict, is_system: bool, - domain: Directory, - parent: Directory | None = None, + domain: Directory | DirectoryDTO, + parent: Directory | DirectoryDTO | None = None, ) -> None: """Create data recursively.""" dir_ = Directory( is_system=is_system, object_class=data["object_class"], name=data["name"], - parent=parent, ) dir_.groups = [] - dir_.create_path(parent, dir_.get_dn_prefix()) + path = parent.path if parent else [] + dir_.create_path(path, dir_.get_dn_prefix()) self._session.add(dir_) await self._session.flush() + dir_.parent_id = parent.id if parent else None await self._session.refresh(dir_, ["id"]) self._session.add( diff --git a/app/ldap_protocol/ldap_requests/add.py b/app/ldap_protocol/ldap_requests/add.py index 6f29fe9af..b000b3798 100644 --- a/app/ldap_protocol/ldap_requests/add.py +++ b/app/ldap_protocol/ldap_requests/add.py @@ -211,7 +211,7 @@ async def handle( # noqa: C901 parent=parent, ) - new_dir.create_path(parent, new_dn) + new_dir.create_path(parent.path, new_dn) ctx.session.add(new_dir) await ctx.session.flush() diff --git a/app/ldap_protocol/ldap_requests/modify_dn.py b/app/ldap_protocol/ldap_requests/modify_dn.py index 7c315eadd..0ac906ce8 100644 --- a/app/ldap_protocol/ldap_requests/modify_dn.py +++ b/app/ldap_protocol/ldap_requests/modify_dn.py @@ -199,7 +199,7 @@ async def handle( return directory.parent = parent_dir - directory.create_path(directory.parent, dn=new_dn) + directory.create_path(parent_dir.path, dn=new_dn) try: await ctx.session.flush() diff --git a/app/ldap_protocol/ldap_requests/search.py b/app/ldap_protocol/ldap_requests/search.py index 1f9579dc2..ebd5730ac 100644 --- a/app/ldap_protocol/ldap_requests/search.py +++ b/app/ldap_protocol/ldap_requests/search.py @@ -23,6 +23,7 @@ from sqlalchemy.sql.elements import ColumnElement, UnaryExpression from sqlalchemy.sql.expression import Select +from dtos import DirectoryDTO from entities import ( Attribute, AttributeType, @@ -367,7 +368,7 @@ def _mutate_query_with_attributes_to_load( def _build_query( self, - base_directories: list[Directory], + base_directories: list[DirectoryDTO], user: UserSchema, access_manager: AccessManager, ) -> Select[tuple[Directory]]: diff --git a/app/ldap_protocol/roles/role_use_case.py b/app/ldap_protocol/roles/role_use_case.py index 1e978a3f1..08951cca7 100644 --- a/app/ldap_protocol/roles/role_use_case.py +++ b/app/ldap_protocol/roles/role_use_case.py @@ -6,6 +6,7 @@ from sqlalchemy import and_, insert, literal, or_, select +from dtos import DirectoryDTO from entities import AccessControlEntry, AceType, Directory, Role from enums import AuthorizationRules, RoleConstants, RoleScope from ldap_protocol.utils.queries import get_base_directories @@ -40,7 +41,7 @@ def __init__( async def inherit_parent_aces( self, - parent_directory: Directory, + parent_directory: Directory | DirectoryDTO, directory: Directory, ) -> None: """Inherit access control entries from the parent directory. diff --git a/app/ldap_protocol/utils/async_cache.py b/app/ldap_protocol/utils/async_cache.py new file mode 100644 index 000000000..446998bbc --- /dev/null +++ b/app/ldap_protocol/utils/async_cache.py @@ -0,0 +1,42 @@ +"""Async cache implementation.""" +import time +from functools import wraps +from typing import Callable, Generic, TypeVar + +from dtos import DirectoryDTO + +T = TypeVar("T") +DEFAULT_CACHE_TIME = 5 * 60 # 5 minutes + + +class AsyncTTLCache(Generic[T]): + def __init__(self, ttl: int | None = DEFAULT_CACHE_TIME) -> None: + self._ttl = ttl + self._value: T | None = None + self._expires_at: float | None = None + + def clear(self) -> None: + self._value = None + self._expires_at = None + + def __call__(self, func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args: tuple, **kwargs: dict) -> T: + if self._value is not None: + if not self._expires_at or self._expires_at > time.monotonic(): + return self._value + self.clear() + + result = await func(*args, **kwargs) + + self._value = result + self._expires_at = ( + time.monotonic() + self._ttl if self._ttl else None + ) + + return result + + return wrapper + + +base_directories_cache = AsyncTTLCache[list[DirectoryDTO]]() diff --git a/app/ldap_protocol/utils/helpers.py b/app/ldap_protocol/utils/helpers.py index 296ff5978..d02b38e49 100644 --- a/app/ldap_protocol/utils/helpers.py +++ b/app/ldap_protocol/utils/helpers.py @@ -130,6 +130,7 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ +import asyncio import functools import hashlib import random @@ -138,9 +139,10 @@ import time from calendar import timegm from datetime import datetime +from functools import wraps from hashlib import blake2b from operator import attrgetter -from typing import Callable +from typing import Any, Callable, Generic, TypeVar from zoneinfo import ZoneInfo from loguru import logger @@ -149,6 +151,7 @@ from sqlalchemy.sql.compiler import DDLCompiler from sqlalchemy.sql.expression import ClauseElement, Executable, Visitable +from dtos import DirectoryDTO from entities import Directory @@ -192,12 +195,18 @@ def validate_attribute(attribute: str) -> bool: ) -def is_dn_in_base_directory(base_directory: Directory, entry: str) -> bool: +def is_dn_in_base_directory( + base_directory: DirectoryDTO, + entry: str, +) -> bool: """Check if an entry in a base dn.""" return entry.lower().endswith(base_directory.path_dn.lower()) -def dn_is_base_directory(base_directory: Directory, entry: str) -> bool: +def dn_is_base_directory( + base_directory: DirectoryDTO, + entry: str, +) -> bool: """Check if an entry is a base dn.""" return base_directory.path_dn.lower() == entry.lower() @@ -302,7 +311,7 @@ def string_to_sid(sid_string: str) -> bytes: def create_object_sid( - domain: Directory, + domain: Directory | DirectoryDTO, rid: int, reserved: bool = False, ) -> str: @@ -402,3 +411,28 @@ async def explain_query( for row in await session.execute(explain(query, analyze=True)) ), ) + + +# def async_cache(ttl: int | None = DEFAULT_CACHE_TIME) -> Callable: +# """Cache for get_base_directories""" +# cache: list[tuple[list[DirectoryDTO], float | None]] = [] + +# def decorator(func: Callable) -> Callable: +# @wraps(func) +# async def wrapper(*args: tuple, **kwargs: dict) -> list[DirectoryDTO]: +# if cache: +# value, expires_at = cache[0] +# if not expires_at or expires_at > time.monotonic(): +# return value +# else: +# cache.clear() + +# result = await func(*args, **kwargs) +# expires_at = time.monotonic() + ttl if ttl else None +# cache.append((result, expires_at)) + +# return result + +# return wrapper + +# return decorator diff --git a/app/ldap_protocol/utils/queries.py b/app/ldap_protocol/utils/queries.py index a1f9243de..64853a708 100644 --- a/app/ldap_protocol/utils/queries.py +++ b/app/ldap_protocol/utils/queries.py @@ -14,6 +14,7 @@ from sqlalchemy.orm import InstrumentedAttribute, joinedload, selectinload from sqlalchemy.sql.expression import ColumnElement +from dtos import DirectoryDTO, _directory_sqla_obj_to_dto from entities import Attribute, Directory, Group, User from ldap_protocol.ldap_schema.attribute_value_validator import ( AttributeValueValidator, @@ -25,6 +26,7 @@ queryable_attr as qa, ) +from .async_cache import base_directories_cache from .const import EMAIL_RE, GRANT_DN_STRING from .helpers import ( create_integer_hash, @@ -35,13 +37,16 @@ ) -async def get_base_directories(session: AsyncSession) -> list[Directory]: +@base_directories_cache +async def get_base_directories(session: AsyncSession) -> list[DirectoryDTO]: """Get base domain directories.""" result = await session.execute( select(Directory) .filter(qa(Directory.parent_id).is_(None)), ) # fmt: skip - return list(result.scalars().all()) + return [ + _directory_sqla_obj_to_dto(dir_) for dir_ in result.scalars().all() + ] async def get_user(session: AsyncSession, name: str) -> User | None: @@ -362,14 +367,14 @@ async def create_group( dir_ = Directory( object_class="", name=name, - parent=parent, + parent_id=parent.id, ) session.add(dir_) await session.flush() await session.refresh(dir_, ["id"]) group = Group(directory_id=dir_.id) - dir_.create_path(parent) + dir_.create_path(parent.path) session.add(group) dir_.object_sid = create_object_sid( diff --git a/pyproject.toml b/pyproject.toml index f7adf0e26..c25e9d974 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -222,6 +222,7 @@ known-first-party = [ "extra", "enums", "errors", + "dtos", ] known-third-party = [ "alembic", # https://github.com/astral-sh/ruff/issues/10519 diff --git a/tests/test_api/test_auth/test_router.py b/tests/test_api/test_auth/test_router.py index 26c0e4523..c13c0a5a6 100644 --- a/tests/test_api/test_auth/test_router.py +++ b/tests/test_api/test_auth/test_router.py @@ -13,7 +13,6 @@ from fastapi import status from httpx import AsyncClient from jose import jwt -from password_utils import PasswordUtils from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload @@ -26,6 +25,7 @@ from ldap_protocol.ldap_requests.modify import Operation from ldap_protocol.session_storage import SessionStorage from ldap_protocol.utils.queries import get_search_path +from password_utils import PasswordUtils from repo.pg.tables import queryable_attr as qa from tests.conftest import TestCreds