diff --git a/docs/sdk/api.mdx b/docs/sdk/api.mdx index 46acf043..854d71c0 100644 --- a/docs/sdk/api.mdx +++ b/docs/sdk/api.mdx @@ -728,11 +728,21 @@ def get_user(self) -> UserResponse: ### get\_user\_data\_credentials ```python -get_user_data_credentials() -> UserDataCredentials +get_user_data_credentials( + duration: int = DEFAULT_FS_CREDENTIAL_DURATION, +) -> UserDataCredentials ``` Retrieves user data credentials for secondary storage access. +**Parameters:** + +* **`duration`** + (`int`, default: + `DEFAULT_FS_CREDENTIAL_DURATION` + ) + –Credential lifetime in seconds (default: 4 hours) + **Returns:** * `UserDataCredentials` @@ -740,14 +750,19 @@ Retrieves user data credentials for secondary storage access. ```python -def get_user_data_credentials(self) -> UserDataCredentials: +def get_user_data_credentials( + self, duration: int = DEFAULT_FS_CREDENTIAL_DURATION +) -> UserDataCredentials: """ Retrieves user data credentials for secondary storage access. + Args: + duration: Credential lifetime in seconds (default: 4 hours) + Returns: The user data credentials object. """ - response = self.request("GET", "/user-data/credentials") + response = self._request("GET", "/user-data/credentials", params={"duration": duration}) return UserDataCredentials(**response.json()) ``` diff --git a/docs/sdk/artifact.mdx b/docs/sdk/artifact.mdx index 6e902b22..7c6495f1 100644 --- a/docs/sdk/artifact.mdx +++ b/docs/sdk/artifact.mdx @@ -244,7 +244,10 @@ ArtifactStorage --------------- ```python -ArtifactStorage(file_system: AbstractFileSystem) +ArtifactStorage( + file_system: AbstractFileSystem, + credential_refresher: Callable[[], bool] | None = None, +) ``` Storage for artifacts with efficient handling of large files and directories. @@ -260,17 +263,28 @@ Initialize artifact storage with a file system and prefix path. * **`file_system`** (`AbstractFileSystem`) –FSSpec-compatible file system +* **`credential_refresher`** + (`Callable[[], bool] | None`, default: + `None` + ) + –Optional function to refresh credentials when it's about to expire ```python -def __init__(self, file_system: fsspec.AbstractFileSystem): +def __init__( + self, + file_system: fsspec.AbstractFileSystem, + credential_refresher: t.Callable[[], bool] | None = None, +): """ Initialize artifact storage with a file system and prefix path. Args: file_system: FSSpec-compatible file system + credential_refresher: Optional function to refresh credentials when it's about to expire """ self._file_system = file_system + self._credential_refresher = credential_refresher ``` @@ -464,6 +478,7 @@ Store a file in the storage system, using multipart upload for large files. ```python +@with_credential_refresh def store_file(self, file_path: Path, target_key: str) -> str: """ Store a file in the storage system, using multipart upload for large files. diff --git a/docs/sdk/main.mdx b/docs/sdk/main.mdx index 6b7c479c..9e42fb5e 100644 --- a/docs/sdk/main.mdx +++ b/docs/sdk/main.mdx @@ -65,6 +65,9 @@ def __init__( self._fs_prefix: str = ".dreadnode/storage/" self._initialized = False + + self._credentials: UserDataCredentials | None = None + self._credentials_expiry: datetime | None = None ``` @@ -380,6 +383,7 @@ def continue_run(self, run_context: RunContext) -> RunSpan: tracer=self._get_tracer(), file_system=self._fs, prefix_path=self._fs_prefix, + credential_refresher=self._refresh_storage_credentials if self._credentials else None, ) ``` @@ -524,18 +528,21 @@ def initialize(self) -> None: # ) # ) - credentials = self._api.get_user_data_credentials() - resolved_endpoint = resolve_endpoint(credentials.endpoint) + self._credentials = self._api.get_user_data_credentials( + duration=DEFAULT_FS_CREDENTIAL_DURATION + ) + self._credentials_expiry = self._credentials.expiration + resolved_endpoint = self._resolve_endpoint(self._credentials.endpoint) self._fs = S3FileSystem( - key=credentials.access_key_id, - secret=credentials.secret_access_key, - token=credentials.session_token, + key=self._credentials.access_key_id, + secret=self._credentials.secret_access_key, + token=self._credentials.session_token, client_kwargs={ "endpoint_url": resolved_endpoint, - "region_name": credentials.region, + "region_name": self._credentials.region, }, ) - self._fs_prefix = f"{credentials.bucket}/{credentials.prefix}/" + self._fs_prefix = f"{self._credentials.bucket}/{self._credentials.prefix}/" self._logfire = logfire.configure( local=not self.is_default, @@ -1723,6 +1730,7 @@ def run( file_system=self._fs, prefix_path=self._fs_prefix, autolog=autolog, + credential_refresher=self._refresh_storage_credentials if self._credentials else None, ) ``` diff --git a/docs/sdk/metric.mdx b/docs/sdk/metric.mdx index 57a9257e..0a537149 100644 --- a/docs/sdk/metric.mdx +++ b/docs/sdk/metric.mdx @@ -31,8 +31,8 @@ Metric Metric( value: float, step: int = 0, - timestamp: datetime = lambda: datetime.now( - timezone.utc + timestamp: datetime = ( + lambda: datetime.now(timezone.utc) )(), attributes: JsonDict = dict(), ) diff --git a/dreadnode/api/client.py b/dreadnode/api/client.py index 5b311482..f2c0f78f 100644 --- a/dreadnode/api/client.py +++ b/dreadnode/api/client.py @@ -36,7 +36,11 @@ process_run, process_task, ) -from dreadnode.constants import DEFAULT_MAX_POLL_TIME, DEFAULT_POLL_INTERVAL +from dreadnode.constants import ( + DEFAULT_FS_CREDENTIAL_DURATION, + DEFAULT_MAX_POLL_TIME, + DEFAULT_POLL_INTERVAL, +) from dreadnode.util import logger from dreadnode.version import VERSION @@ -306,7 +310,9 @@ def get_run(self, run: str | ULID) -> Run: TraceFormat = t.Literal["tree", "flat"] @t.overload - def get_run_tasks(self, run: str | ULID, *, format: t.Literal["tree"]) -> list[TaskTree]: ... + def get_run_tasks( + self, run: str | ULID, *, format: t.Literal["tree"] + ) -> list[TaskTree]: ... @t.overload def get_run_tasks( @@ -334,7 +340,9 @@ def get_run_tasks( return tasks if format == "flat" else convert_flat_tasks_to_tree(tasks) @t.overload - def get_run_trace(self, run: str | ULID, *, format: t.Literal["tree"]) -> list[TraceTree]: ... + def get_run_trace( + self, run: str | ULID, *, format: t.Literal["tree"] + ) -> list[TraceTree]: ... @t.overload def get_run_trace( @@ -517,12 +525,19 @@ def export_timeseries( # User data access - def get_user_data_credentials(self) -> UserDataCredentials: + def get_user_data_credentials( + self, duration: int = DEFAULT_FS_CREDENTIAL_DURATION + ) -> UserDataCredentials: """ Retrieves user data credentials for secondary storage access. + Args: + duration: Credential lifetime in seconds (default: 4 hours) + Returns: The user data credentials object. """ - response = self.request("GET", "/user-data/credentials") + response = self._request( + "GET", "/user-data/credentials", params={"duration": duration} + ) return UserDataCredentials(**response.json()) diff --git a/dreadnode/artifact/storage.py b/dreadnode/artifact/storage.py index 3b482589..f6fa180a 100644 --- a/dreadnode/artifact/storage.py +++ b/dreadnode/artifact/storage.py @@ -4,10 +4,12 @@ """ import hashlib +import typing as t from pathlib import Path import fsspec # type: ignore[import-untyped] +from dreadnode.storage_utils import with_credential_refresh from dreadnode.util import logger CHUNK_SIZE = 8 * 1024 * 1024 # 8MB @@ -22,15 +24,27 @@ class ArtifactStorage: - Batch uploads for directories handled by fsspec """ - def __init__(self, file_system: fsspec.AbstractFileSystem): + def __init__( + self, + file_system: fsspec.AbstractFileSystem, + credential_refresher: t.Callable[[], bool] | None = None, + ): """ Initialize artifact storage with a file system and prefix path. Args: file_system: FSSpec-compatible file system + credential_refresher: Optional function to refresh credentials when it's about to expire """ self._file_system = file_system + self._credential_refresher = credential_refresher + def _refresh_credentials_if_needed(self) -> None: + """Refresh credentials if refresher is available.""" + if self._credential_refresher: + self._credential_refresher() + + @with_credential_refresh def store_file(self, file_path: Path, target_key: str) -> str: """ Store a file in the storage system, using multipart upload for large files. diff --git a/dreadnode/constants.py b/dreadnode/constants.py index d2570d51..a2b241a1 100644 --- a/dreadnode/constants.py +++ b/dreadnode/constants.py @@ -54,5 +54,10 @@ # path to the user configuration file USER_CONFIG_PATH = pathlib.Path( # allow overriding the user config file via env variable - os.getenv("DREADNODE_USER_CONFIG_FILE") or pathlib.Path.home() / ".dreadnode" / "config" + os.getenv("DREADNODE_USER_CONFIG_FILE") + or pathlib.Path.home() / ".dreadnode" / "config" ) + +# Default values for the file system credential management +DEFAULT_FS_CREDENTIAL_DURATION = 14400 # 4 hours in seconds +FS_CREDENTIAL_REFRESH_BUFFER = 300 # 5 minutes in seconds diff --git a/dreadnode/main.py b/dreadnode/main.py index 68d48e47..e6ec7734 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -26,6 +26,7 @@ from dreadnode.api.client import ApiClient from dreadnode.config import UserConfig from dreadnode.constants import ( + DEFAULT_FS_CREDENTIAL_DURATION, DEFAULT_SERVER_URL, ENV_API_KEY, ENV_API_TOKEN, @@ -35,6 +36,7 @@ ENV_PROJECT, ENV_SERVER, ENV_SERVER_URL, + FS_CREDENTIAL_REFRESH_BUFFER, ) from dreadnode.metric import ( Metric, @@ -64,7 +66,7 @@ Inherited, JsonValue, ) -from dreadnode.util import clean_str, handle_internal_errors, resolve_endpoint +from dreadnode.util import clean_str, handle_internal_errors from dreadnode.version import VERSION if t.TYPE_CHECKING: @@ -73,6 +75,8 @@ from opentelemetry.sdk.trace import SpanProcessor from opentelemetry.trace import Tracer + from dreadnode.api.models import UserDataCredentials + ToObject = t.Literal["task-or-run", "run"] @@ -138,6 +142,9 @@ def __init__( self._initialized = False + self._credentials: UserDataCredentials | None = None + self._credentials_expiry: datetime | None = None + def _get_profile_server(self, profile: str | None = None) -> str | None: with contextlib.suppress(Exception): user_config = UserConfig.read() @@ -264,7 +271,9 @@ def configure( # Log config information for clarity if self.server or self.token or self.local_dir: destination = self.server or DEFAULT_SERVER_URL or "local storage" - rich.print(f"Dreadnode logging to [orange_red1]{destination}[/] ({config_source})") + rich.print( + f"Dreadnode logging to [orange_red1]{destination}[/] ({config_source})" + ) # Warn the user if the profile didn't resolve elif active_profile and not (self.server or self.token): @@ -348,18 +357,21 @@ def initialize(self) -> None: # ) # ) - credentials = self._api.get_user_data_credentials() - resolved_endpoint = resolve_endpoint(credentials.endpoint) + self._credentials = self._api.get_user_data_credentials( + duration=DEFAULT_FS_CREDENTIAL_DURATION + ) + self._credentials_expiry = self._credentials.expiration + resolved_endpoint = self._resolve_endpoint(self._credentials.endpoint) self._fs = S3FileSystem( - key=credentials.access_key_id, - secret=credentials.secret_access_key, - token=credentials.session_token, + key=self._credentials.access_key_id, + secret=self._credentials.secret_access_key, + token=self._credentials.session_token, client_kwargs={ "endpoint_url": resolved_endpoint, - "region_name": credentials.region, + "region_name": self._credentials.region, }, ) - self._fs_prefix = f"{credentials.bucket}/{credentials.prefix}/" + self._fs_prefix = f"{self._credentials.bucket}/{self._credentials.prefix}/" self._logfire = logfire.configure( local=not self.is_default, @@ -406,6 +418,46 @@ def api(self, *, server: str | None = None, token: str | None = None) -> ApiClie return self._api + def _refresh_storage_credentials(self) -> bool: + """Refresh storage credentials if they are about to expire.""" + if not self._api or not self._credentials: + return False + + now = datetime.now(timezone.utc) + + if ( + self._credentials_expiry is None + or (self._credentials_expiry - now).total_seconds() + < FS_CREDENTIAL_REFRESH_BUFFER + ): + try: + logger.info("Refreshing storage credentials") + self._credentials = self._api.get_user_data_credentials( + duration=DEFAULT_FS_CREDENTIAL_DURATION + ) + self._credentials_expiry = self._credentials.expiration + + resolved_endpoint = self._resolve_endpoint(self._credentials.endpoint) + self._fs = S3FileSystem( + key=self._credentials.access_key_id, + secret=self._credentials.secret_access_key, + token=self._credentials.session_token, + client_kwargs={ + "endpoint_url": resolved_endpoint, + "region_name": self._credentials.region, + }, + ) + logger.info( + f"Storage credentials refreshed, valid until {self._credentials_expiry}" + ) + return True # noqa: TRY300 + + except Exception as e: # noqa: BLE001 + logger.error(f"Failed to refresh storage credentials: {e}") + return False + + return True + def _get_tracer(self, *, is_span_tracer: bool = True) -> "Tracer": return self._logfire._tracer_provider.get_tracer( # noqa: SLF001 self.otel_scope, @@ -624,7 +676,9 @@ def make_task( attributes=_attributes, func=t.cast("t.Callable[P, R]", func), scorers=[ - scorer if isinstance(scorer, Scorer) else Scorer.from_callable(scorer) + scorer + if isinstance(scorer, Scorer) + else Scorer.from_callable(scorer) for scorer in scorers or [] ], tags=list(tags or []), @@ -778,6 +832,9 @@ def run( file_system=self._fs, prefix_path=self._fs_prefix, autolog=autolog, + credential_refresher=self._refresh_storage_credentials + if self._credentials + else None, ) def get_run_context(self) -> RunContext: @@ -824,6 +881,9 @@ def continue_run(self, run_context: RunContext) -> RunSpan: tracer=self._get_tracer(), file_system=self._fs, prefix_path=self._fs_prefix, + credential_refresher=self._refresh_storage_credentials + if self._credentials + else None, ) def tag(self, *tag: str, to: ToObject = "task-or-run") -> None: diff --git a/dreadnode/storage_utils.py b/dreadnode/storage_utils.py new file mode 100644 index 00000000..9599a238 --- /dev/null +++ b/dreadnode/storage_utils.py @@ -0,0 +1,37 @@ +import functools +import typing as t + +from dreadnode.util import logger + + +def with_credential_refresh(func: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: + """Decorator that automatically handles credential refresh on storage errors.""" + + @functools.wraps(func) + def wrapper(self: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any: + # Try to refresh credentials before operation + if hasattr(self, "_refresh_credentials_if_needed"): + self._refresh_credentials_if_needed() + + try: + return func(self, *args, **kwargs) + except Exception as e: + error_str = str(e) + if any( + error in error_str + for error in [ + "ExpiredToken", + "TokenRefreshRequired", + "InvalidAccessKeyId", + "The Access Key Id you provided does not exist", + ] + ): + logger.info("Storage credential error, forcing refresh and retrying") + + if hasattr(self, "_refresh_credentials_if_needed"): + self._refresh_credentials_if_needed() + + return func(self, *args, **kwargs) + raise + + return wrapper diff --git a/dreadnode/tracing/span.py b/dreadnode/tracing/span.py index 3ed93910..2c2cbcd7 100644 --- a/dreadnode/tracing/span.py +++ b/dreadnode/tracing/span.py @@ -36,6 +36,7 @@ from dreadnode.metric import Metric, MetricAggMode, MetricsDict from dreadnode.object import Object, ObjectRef, ObjectUri, ObjectVal from dreadnode.serialization import Serialized, serialize +from dreadnode.storage_utils import with_credential_refresh from dreadnode.tracing.constants import ( EVENT_ATTRIBUTE_LINK_HASH, EVENT_ATTRIBUTE_OBJECT_HASH, @@ -235,7 +236,9 @@ def set_attribute( self._added_attributes = True if schema and raw is False: self._schema[key] = create_json_schema(value, set()) - otel_value = self._pre_attributes[key] = value if raw else prepare_otlp_attribute(value) + otel_value = self._pre_attributes[key] = ( + value if raw else prepare_otlp_attribute(value) + ) if self._span is not None: self._span.set_attribute(key, otel_value) self._pre_attributes[key] = otel_value @@ -323,7 +326,11 @@ def __init__( **({SPAN_ATTRIBUTE_INPUTS: inputs} if inputs else {}), **({SPAN_ATTRIBUTE_OUTPUTS: outputs} if outputs else {}), **({SPAN_ATTRIBUTE_OBJECTS: objects} if objects else {}), - **({SPAN_ATTRIBUTE_OBJECT_SCHEMAS: object_schemas} if object_schemas else {}), + **( + {SPAN_ATTRIBUTE_OBJECT_SCHEMAS: object_schemas} + if object_schemas + else {} + ), } # Mark objects and schemas as large attributes if present @@ -335,7 +342,9 @@ def __init__( large_attrs.append(SPAN_ATTRIBUTE_OBJECT_SCHEMAS) attributes[SPAN_ATTRIBUTE_LARGE_ATTRIBUTES] = large_attrs - super().__init__(f"run.{run_id}.update", tracer, type="run_update", attributes=attributes) + super().__init__( + f"run.{run_id}.update", tracer, type="run_update", attributes=attributes + ) def __repr__(self) -> str: status = "active" if self.is_recording else "inactive" @@ -365,6 +374,7 @@ def __init__( update_frequency: int = 5, run_id: str | ULID | None = None, type: SpanType = "run", + credential_refresher: t.Callable[[], bool] | None = None, ) -> None: self.autolog = autolog self.project = project @@ -375,7 +385,9 @@ def __init__( self._object_schemas: dict[str, JsonDict] = {} self._inputs: list[ObjectRef] = [] self._outputs: list[ObjectRef] = [] - self._artifact_storage = ArtifactStorage(file_system=file_system) + self._artifact_storage = ArtifactStorage( + file_system=file_system, credential_refresher=credential_refresher + ) self._artifacts: list[DirectoryNode] = [] self._artifact_merger = ArtifactMerger() self._artifact_tree_builder = ArtifactTreeBuilder( @@ -406,6 +418,7 @@ def __init__( SPAN_ATTRIBUTE_PROJECT: project, **(attributes or {}), } + self._credential_refresher = credential_refresher super().__init__(name, tracer, attributes=attributes, type=type, tags=tags) @classmethod @@ -415,6 +428,7 @@ def from_context( tracer: Tracer, file_system: AbstractFileSystem, prefix_path: str, + credential_refresher: t.Callable[[], bool] | None = None, ) -> "RunSpan": self = RunSpan( name=f"run.{context['run_id']}.fragment", @@ -425,6 +439,7 @@ def from_context( prefix_path=prefix_path, type="run_fragment", run_id=context["run_id"], + credential_refresher=credential_refresher, ) self._remote_context = context["trace_context"] @@ -500,12 +515,18 @@ def __exit__( if self._context_token is not None: current_run_span.reset(self._context_token) + def _refresh_credentials_if_needed(self) -> None: + if self._credential_refresher: + self._credential_refresher() + def push_update(self, *, force: bool = False) -> None: if self._span is None: return current_time = time.time() - force_update = force or (current_time - self._last_update_time >= self._update_frequency) + force_update = force or ( + current_time - self._last_update_time >= self._update_frequency + ) should_update = force_update and ( self._pending_params or self._pending_inputs @@ -527,7 +548,9 @@ def push_update(self, *, force: bool = False) -> None: inputs=self._pending_inputs if self._pending_inputs else None, outputs=self._pending_outputs if self._pending_outputs else None, objects=self._pending_objects if self._pending_objects else None, - object_schemas=self._pending_object_schemas if self._pending_object_schemas else None, + object_schemas=self._pending_object_schemas + if self._pending_object_schemas + else None, ): pass @@ -604,6 +627,7 @@ def log_object( return composite_hash + @with_credential_refresh def _store_file_by_hash(self, data: bytes, full_path: str) -> str: """ Writes data to the given full_path in the object store if it doesn't already exist. @@ -622,7 +646,9 @@ def _store_file_by_hash(self, data: bytes, full_path: str) -> str: return str(self._file_system.unstrip_protocol(full_path)) - def _create_object_by_hash(self, serialized: Serialized, object_hash: str) -> Object: + def _create_object_by_hash( + self, serialized: Serialized, object_hash: str + ) -> Object: """Create an ObjectVal or ObjectUri depending on size with a specific hash.""" data = serialized.data data_bytes = serialized.data_bytes @@ -630,7 +656,11 @@ def _create_object_by_hash(self, serialized: Serialized, object_hash: str) -> Ob data_hash = serialized.data_hash schema_hash = serialized.schema_hash - if data is None or data_bytes is None or data_len <= DEFAULT_MAX_INLINE_OBJECT_BYTES: + if ( + data is None + or data_bytes is None + or data_len <= DEFAULT_MAX_INLINE_OBJECT_BYTES + ): return ObjectVal( hash=object_hash, value=data, @@ -863,7 +893,9 @@ def __init__( self._output: R | Unset = UNSET # For the python output - self._context_token: Token[TaskSpan[t.Any] | None] | None = None # contextvars context + self._context_token: Token[TaskSpan[t.Any] | None] | None = ( + None # contextvars context + ) self._tasks: list[TaskSpan[t.Any]] = [] self._parent_task: TaskSpan[t.Any] | None = None @@ -875,7 +907,9 @@ def __init__( SPAN_ATTRIBUTE_OUTPUTS: self._outputs, **(attributes or {}), } - super().__init__(name, tracer, type="task", attributes=attributes, label=label, tags=tags) + super().__init__( + name, tracer, type="task", attributes=attributes, label=label, tags=tags + ) def __enter__(self) -> te.Self: self._run = current_run_span.get() @@ -975,7 +1009,9 @@ def log_output( label=label, event_name=EVENT_NAME_OBJECT_OUTPUT, ) - self._outputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes)) + self._outputs.append( + ObjectRef(name, label=label, hash=hash_, attributes=attributes) + ) return hash_ @property @@ -1004,7 +1040,9 @@ def log_input( label=label, event_name=EVENT_NAME_OBJECT_INPUT, ) - self._inputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes)) + self._inputs.append( + ObjectRef(name, label=label, hash=hash_, attributes=attributes) + ) return hash_ @property @@ -1066,7 +1104,9 @@ def log_metric( # this task-metric was logged here. if (run := current_run_span.get()) is not None: - metric = run.log_metric(key, metric, prefix=self._label, origin=origin, mode=mode) + metric = run.log_metric( + key, metric, prefix=self._label, origin=origin, mode=mode + ) self._metrics.setdefault(key, []).append(metric) diff --git a/poetry.lock b/poetry.lock index 003108b3..82db87b6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5123,4 +5123,4 @@ training = ["transformers"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "0f9e538475309634ca67a66835b23db97718351ceb13f5a835b47ad8b740908b" +content-hash = "0f9e538475309634ca67a66835b23db97718351ceb13f5a835b47ad8b740908b" \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 27f6dc6a..a2dcc93d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,4 +130,4 @@ skip-magic-trailing-comma = false "INP001", # namespace not required for pytest "S101", # asserts allowed in tests... "SLF001", # allow access to private members -] +] \ No newline at end of file diff --git a/scripts/__init__.py b/scripts/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/scripts/make_docs.py b/scripts/make_docs.py deleted file mode 100644 index e47ffebe..00000000 --- a/scripts/make_docs.py +++ /dev/null @@ -1,676 +0,0 @@ -# type: ignore # noqa: PGH003 - - -import argparse -import builtins -import inspect -import io -import json -import logging -import pkgutil -import pydoc -import re -import sys -import types -import typing as t -from importlib import import_module -from pathlib import Path - -from docstring_parser import Docstring, DocstringStyle, ParseError -from docstring_parser import parse as parse_docstring - -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") - -logger = logging.getLogger(__name__) - - -# --- Helper Functions --- -def get_raw_doc(obj: t.Any) -> str | None: - """Gets the raw, uncleaned docstring.""" - return inspect.getdoc(obj) - - -def format_signature(obj: t.Any, class_name: str | None = None) -> str: - """Formats the signature of a callable object for code display.""" - try: - actual_obj = obj - if isinstance(obj, property): - actual_obj = obj.fget if obj.fget else obj - elif isinstance(obj, (staticmethod, classmethod)): - actual_obj = getattr(obj, "__func__", obj) - if not callable(actual_obj): - return "" - sig = inspect.signature(actual_obj) - sig_str = str(sig) - if class_name: - sig_str = re.sub(rf"\b{re.escape(class_name)}\.([\w]+)\b", r"\1", sig_str) - except (ValueError, TypeError): - return "(...)" - except Exception as e: # noqa: BLE001 - logger.warning( - "Warning: Could not get signature for %s: %s", getattr(obj, "__name__", "unknown"), e - ) - return "(...)" - - return sig_str - - -def format_type_annotation(annotation_str: str) -> str: - """Convert pipe-style union types to MDX-compatible format.""" - if "|" in annotation_str: - # Handle simple Optional types - if " | None" in annotation_str: - base_type = annotation_str.replace(" | None", "").strip() - return f"Optional[{base_type}]" - - # General case for union types - parts = [part.strip() for part in annotation_str.split("|")] - return f"Union[{', '.join(parts)}]" - - return annotation_str - - -# --- Core MDX Generator Class --- - - -class MDXDoc(pydoc.HTMLDoc): - """Formatter class for creating clean, readable MDX documentation.""" - - def __init__(self, auth_group: str | None = None): - super().__init__() - self.auth_group = auth_group - - # --- Docstring Formatting Logic --- - def _format_docstring(self, obj: t.Any) -> str: - """Parses and formats the docstring using a clean, traditional style with concise sections.""" - raw_doc = get_raw_doc(obj) - if not raw_doc: - return "" - - try: - parsed = parse_docstring(raw_doc, style=DocstringStyle.GOOGLE) - output = io.StringIO() - - # Process each section of the docstring - self._write_description(parsed, output) - self._write_parameters(parsed, output) - self._write_returns(parsed, output) - self._write_raises(parsed, output) - - return output.getvalue() - - except ParseError as e: - logger.warning( - "Warning: Could not parse docstring for %s: %s", - getattr(obj, "__name__", "object"), - e, - ) - return raw_doc.replace("<", r"\<") + "\n\n" - except Exception as e: # noqa: BLE001 - logger.info( - "Error formatting docstring for %s: %s", getattr(obj, "__name__", "object"), e - ) - return raw_doc.replace("<", r"\<") + "\n\n" - - def _write_description(self, parsed: Docstring, output: io.StringIO) -> None: - """Writes the description section of the docstring.""" - description = "" - if parsed.short_description: - description += parsed.short_description - if parsed.long_description: - if description: - description += "\n\n" - description += parsed.long_description - if description: - output.write(description.replace("<", r"\<") + "\n\n") - - def _write_parameters(self, parsed: Docstring, output: io.StringIO) -> None: - """Writes the parameters section of the docstring.""" - if parsed.params: - output.write("**Parameters:**\n\n") - for param in parsed.params: - param_header = f"**`{param.arg_name}`**" - if param.type_name: - safe_type = param.type_name.replace("`", r"\`").replace("<", r"\<") - safe_type = format_type_annotation(safe_type) - param_header += f" (`{safe_type}`)" - if param.is_optional: - param_header += " *(optional)*" - output.write(f"- {param_header}") - if param.description: - updated_description = param.description.replace("<", r"\<") - output.write(f": {updated_description}") - if param.default: - safe_default = param.default.replace("`", r"\`") - output.write(f" Default: `{safe_default}`") - output.write("\n") - output.write("\n") - - def _write_returns(self, parsed: Docstring, output: io.StringIO) -> None: - """Writes the returns section of the docstring.""" - if parsed.returns: - return_line = "**Returns:** " - if parsed.returns.type_name: - safe_type = parsed.returns.type_name.replace("`", r"\`").replace("<", r"\<") - safe_type = format_type_annotation(safe_type) - return_line += f"`{safe_type}`" - if parsed.returns.description: - if parsed.returns.type_name: - return_line += " — " - return_line += parsed.returns.description.replace("<", r"\<") - output.write(return_line + "\n\n") - - def _write_raises(self, parsed: Docstring, output: io.StringIO) -> None: - """Writes the raises section of the docstring.""" - if parsed.raises: - output.write("**Raises:**\n\n") - for exc in parsed.raises: - exc_line = "- " - if exc.type_name: - safe_type = exc.type_name.replace("`", r"\`").replace("<", r"\<") - safe_type = format_type_annotation(safe_type) - exc_line += f"`{safe_type}`" - if exc.description: - if exc.type_name: - exc_line += " — " - exc_line += exc.description.replace("<", r"\<") - output.write(exc_line + "\n") - output.write("\n") - - # --- Overridden pydoc methods --- - - def page(self, title: str, contents: str) -> str: - safe_title = title.replace("'", "''") - return f"---\ntitle: '{safe_title}'\n---\n\n{contents}" - - def heading(self, title: str, level: str = 1) -> str: - return f"{'#' * level} {title}\n" - - def section(self, title: str, contents: str, level: int = 2) -> str: - return f"\n{'#' * level} {title}\n\n{contents}\n" - - def docmodule( - self, - object: types.ModuleType, - ) -> str: - full_name = object.__name__ - short_name = full_name.split(".")[-1] - safe_short_name = short_name.replace("'", "''") - output = io.StringIO() - - # Write frontmatter and module header - self._write_frontmatter(output, safe_short_name, short_name, full_name) - - # Write source file information - self._write_source_file_info(output, object) - - # Write module docstring - module_doc_formatted = self._format_docstring(object) - output.write(module_doc_formatted) - - # Collect and document members - classes, functions = self._collect_members(object, full_name) - self._write_classes(output, classes, full_name) - self._write_functions(output, functions) - - return output.getvalue() - - def _write_frontmatter(self, output, safe_short_name, short_name, full_name): - """Writes the frontmatter and module header.""" - output.write(f"---\ntitle: '{safe_short_name}'\nsidebarTitle: '{safe_short_name}'\n") - if self.auth_group: - output.write(f'groups: ["{self.auth_group}"]\n') - output.write("---\n\n") - output.write(f"# Module `{short_name}`\n\n") - output.write(f"*(Full name: `{full_name}`)*\n\n") - - def _write_source_file_info(self, output, object): - """Writes the source file information.""" - try: - source_file = inspect.getsourcefile(object) - if source_file: - output.write(f"**Source file:** `{Path(source_file).name}`\n\n") - except (TypeError, OSError): - pass - except Exception as e: # noqa: BLE001 - logger.warning("Warning: Could not get source file: %s", e) - - def _collect_members(self, object, full_name): - """Collects classes and functions defined in the module.""" - classes, functions = [], [] - try: - for member_name, member_obj in inspect.getmembers(object): - if member_name.startswith("_") and not member_name.startswith("__"): - continue - if self._is_defined_here(member_obj, full_name): - if inspect.isclass(member_obj): - classes.append((member_name, member_obj)) - elif inspect.isfunction(member_obj): - functions.append((member_name, member_obj)) - except Exception as e: # noqa: BLE001 - logger.warning("Error inspecting members: %s", e) - return classes, functions - - def _is_defined_here(self, member_obj, full_name): - """Checks if a member is defined in the current module.""" - try: - member_module = inspect.getmodule(member_obj) - except Exception: # noqa: BLE001 - return False - return member_module is not None and member_module.__name__ == full_name - - def _write_classes(self, output, classes, full_name): - """Writes the documentation for classes.""" - if classes: - output.write("## Classes\n\n") - for class_name, class_obj in sorted(classes, key=lambda item: item[0]): - output.write(self.docclass(class_obj, class_name, module_name=full_name)) - output.write("\n") - - def _write_functions(self, output, functions): - """Writes the documentation for functions.""" - if functions: - output.write("## Functions\n\n") - for func_name, func_obj in sorted(functions, key=lambda item: item[0]): - output.write(self.docroutine(func_obj, func_name, class_name=None)) - output.write("\n") - - def docclass( - self, object: type, name: str | None = None, module_name: str | None = None - ) -> str: - real_name = name or object.__name__ - output = io.StringIO() - output.write(f"\n### Class `{real_name}`\n\n") - - self._write_inheritance_info(object, output) - self._write_class_docstring(object, output) - - methods, properties = self._collect_class_members(object, module_name) - self._write_properties(properties, output, real_name) - self._write_methods(methods, output, real_name) - - return output.getvalue() - - def _write_inheritance_info(self, object: type, output: io.StringIO): - """Writes inheritance information for a class.""" - if object.__bases__: - bases = [] - for b in object.__bases__: - if b is object or (b is builtins.object and len(object.__bases__) > 1): - continue - base_module = getattr(b, "__module__", "") - base_name_str = getattr(b, "__name__", str(b)) - if base_module and base_module != "builtins": - bases.append(f"`{base_module}.{base_name_str}`") - else: - bases.append(f"`{base_name_str}`") - if bases: - output.write(f"**Inherits from:** {', '.join(bases)}\n\n") - - def _write_class_docstring(self, object: type, output: io.StringIO): - """Writes the formatted docstring for a class.""" - class_doc_formatted = self._format_docstring(object) - output.write(class_doc_formatted) - - def _collect_class_members(self, object: type, module_name: str | None) -> tuple[list, list]: - """Collects methods and properties of a class.""" - methods, properties = [], [] - try: - for member_name, member_obj in inspect.getmembers(object): - if member_name.startswith("_") and not member_name.startswith("__"): - continue - if self._is_relevant_member(member_name, member_obj, object, module_name): - if isinstance(member_obj, property): - properties.append((member_name, member_obj)) - elif self._is_method(member_obj): - methods.append((member_name, member_obj)) - except Exception as e: # noqa: BLE001 - logger.warning("Error inspecting members of %s: %s", object.__name__, e) - return methods, properties - - def _is_relevant_member( - self, member_name: str, member_obj: t.Any, object: type, module_name: str | None - ) -> bool: - """Determines if a member is relevant for documentation.""" - is_directly_defined = member_name in object.__dict__ - try: - target_obj = member_obj.fget if isinstance(member_obj, property) else member_obj - member_origin_module = inspect.getmodule(target_obj) - except Exception: # noqa: BLE001 - return False - - return is_directly_defined or ( - member_origin_module and member_origin_module.__name__ == module_name - ) - - def _is_method(self, member_obj: t.Any) -> bool: - """Checks if a member is a method.""" - return inspect.isfunction(member_obj) or isinstance(member_obj, (classmethod, staticmethod)) - - def _write_properties(self, properties: list, output: io.StringIO, class_name: str): - """Writes properties of a class.""" - if properties: - output.write("#### Properties\n\n") - for prop_name, prop_obj in sorted(properties, key=lambda item: item[0]): - output.write(self._docproperty(prop_obj, prop_name, class_name=class_name)) - - def _write_methods(self, methods: list, output: io.StringIO, class_name: str): - """Writes methods of a class.""" - if methods: - output.write("#### Methods\n\n") - for method_name, method_obj in sorted(methods, key=lambda item: item[0]): - output.write(self.docroutine(method_obj, method_name, class_name=class_name)) - - def docroutine( - self, - object: t.Any, - name: str | None = None, - class_name: str | None = None, - ) -> str: - real_name = name or getattr(object, "__name__", "unknown_routine") - output = io.StringIO() - - # Heading level based on context (class method vs standalone function) - heading_level = 5 if class_name else 3 - output.write(f"{'#' * heading_level} `{real_name}`\n\n") - - # Function/method signature - signature = format_signature(object, class_name=class_name) - if signature and signature != "(...)": - output.write(f"```python\n{real_name}{signature}\n```\n\n") - elif real_name: - output.write(f"`{real_name}(...)`\n\n") - - # Docstring content - doc_formatted = self._format_docstring(object) - output.write(doc_formatted) - - # Add a separator only if we're not at the end of a section - if class_name: - output.write("---\n\n") - - return output.getvalue() - - def _docproperty(self, prop: property, name: str, class_name: str | None = None) -> str: - output = io.StringIO() - output.write(f"##### `{name}`\n\n") - - # Get property type annotation - type_hint_str = "" - target_for_type = prop.fget if prop.fget else prop - if hasattr(target_for_type, "__annotations__"): - try: - return_annotation = t.get_type_hints(target_for_type).get("return") - if return_annotation: - annotation = str(return_annotation) - annotation = re.sub(r"\btyping\.", "", annotation) - if class_name: - annotation = re.sub(rf"\b{re.escape(class_name)}\.", "", annotation) - annotation = format_type_annotation(annotation) - safe_annotation = annotation.replace("`", r"\`").replace("<", r"\<") - type_hint_str = f"`{safe_annotation}`" - except Exception as e: # noqa: BLE001 - logger.warning( - "Warning: Error getting type hint for property %s: %s", - getattr(prop, "__name__", "unknown_property"), - e, - ) - - # Show type compactly - if type_hint_str: - output.write(f"**Type:** {type_hint_str} *(property)*\n\n") - else: - output.write("*(property)*\n\n") - - # Documentation - doc_obj = prop - raw_doc = get_raw_doc(prop) - if not raw_doc and prop.fget: - doc_obj = prop.fget - - doc_formatted = self._format_docstring(doc_obj) - output.write(doc_formatted) - - # Property details (compact) - details = [] - if prop.fget: - details.append("getter") - if prop.fset: - details.append("setter") - if prop.fdel: - details.append("deleter") - if details: - output.write(f"*Has: {', '.join(details)}*\n\n") - - # Add separator - output.write("---\n\n") - - return output.getvalue() - - def link(self, text: str, url: str) -> str: - return f"[{text}]({url})" - - def strong(self, text: str) -> str: - return f"**{text}**" - - def emphasis(self, text: str) -> str: - return f"*{text}*" - - def escape(self, text: str) -> str: - return text.replace("<", r"\<") - - def preformat(self, text: str) -> str: - return f"```\n{text}\n```" - - def multicolumn(self, list_items: list[t.Any], fmt: t.Callable) -> str: - return "\n".join(f"- {fmt(item)}" for item in list_items) - - def grey(self, text: str) -> str: - return text - - def write(self, *args, **kwargs): - pass - - -# --- Main execution logic --- -def generate_mdx_docs( - module_paths: list[str], - output_dir: str, - auth_group: str | None = None, - project_root: str | None = None, -): - """Generates clean, traditional MDX documentation for Python modules.""" - output_path = Path(output_dir) - output_path.mkdir(parents=True, exist_ok=True) - logger.info("Output directory: %s", output_path.resolve()) - - mdx_formatter = MDXDoc(auth_group=auth_group) - _project_root_path = Path(project_root).resolve() if project_root else Path.cwd() - _setup_sys_path(_project_root_path) - - processed_modules = set() - generated_files = [] - - for path_str in module_paths: - path = Path(path_str).resolve() - logger.info("Processing path: %s", path) - if path.is_file() and path.suffix == ".py" and path.name != "__init__.py": - _process_file(path, mdx_formatter, processed_modules, generated_files, output_path) - elif path.is_dir(): - _process_directory(path, mdx_formatter, processed_modules, generated_files, output_path) - else: - logger.warning("Warning: Path is not Python file/directory: %s", path) - - _write_docs_json(generated_files, output_path) - - -def _setup_sys_path(project_root: Path): - """Sets up the system path for module imports.""" - if str(project_root) not in sys.path: - sys.path.insert(0, str(project_root)) - if Path.cwd() not in sys.path: - sys.path.insert(0, Path.cwd()) - - -def _process_file(path, mdx_formatter, processed_modules, generated_files, output_path): - """Processes a single Python file.""" - module_name = _determine_module_name(path) - if not module_name: - return - - logger.info(" Attempting to import module: %s", module_name) - try: - module = import_module(module_name) - if module.__name__ in processed_modules: - return - logger.info(" Generating MDX for module: %s", module.__name__) - _generate_mdx(module, mdx_formatter, processed_modules, generated_files, output_path) - except ImportError: - logger.exception("Error importing module '%s'", module_name) - except Exception: - logger.exception("Error processing module %s", module_name) - - -def _process_directory(path, mdx_formatter, processed_modules, generated_files, output_path): - """Processes a directory as a package.""" - logger.info(" Processing directory as package: %s", path.name) - package_name = path.name - if str(path.parent) not in sys.path: - sys.path.insert(0, str(path.parent)) - for _, modname, _ in pkgutil.walk_packages([str(path)], prefix=f"{package_name}."): - if modname in processed_modules: - continue - logger.info(" Attempting to import package module: %s", modname) - try: - module = import_module(modname) - logger.info(" Generating MDX for module: %s", module.__name__) - _generate_mdx(module, mdx_formatter, processed_modules, generated_files, output_path) - except ImportError: - logger.exception("Error importing package module %s", modname) - except Exception: - logger.exception("Error processing package module %s", modname) - - -def _determine_module_name(path): - """Determines the module name for a given file path.""" - try: - best_match_len = -1 - module_name = None - for p_str in sys.path: - p = Path(p_str).resolve() - try: - rel_path = path.relative_to(p) - if ".." not in rel_path.parts: - current_len = len(p.parts) - if current_len > best_match_len: - best_match_len = current_len - module_name_parts = [*list(rel_path.parts[:-1]), path.stem] - module_name = ".".join(part for part in module_name_parts if part) - except ValueError: - continue - if not module_name: - module_name = path.stem - if str(path.parent) not in sys.path: - sys.path.insert(0, str(path.parent)) - else: - return module_name - except Exception as e: # noqa: BLE001 - logger.warning("Warning: Error determining module name for %s: %s", path, e) - return None - - -def _generate_mdx(module, mdx_formatter, processed_modules, generated_files, output_path): - """Generates MDX documentation for a module.""" - mdx_content = mdx_formatter.docmodule(module) - output_filename = f"{module.__name__.replace('.', '/')}.mdx" - output_file = output_path / output_filename - output_file.parent.mkdir(parents=True, exist_ok=True) - with Path.open(output_file, "w", encoding="utf-8") as f: - f.write(mdx_content) - logger.info(" -> Wrote %s", output_file) - processed_modules.add(module.__name__) - relative_path = str(output_file.relative_to(output_path.parent)).replace(".mdx", "") - generated_files.append(relative_path) - - -def _write_docs_json(generated_files, output_path): - """Writes the docs.json file.""" - docs_json_path = output_path / "docs.json" - nav_file_paths = [ - str(file_path).replace(str(output_path) + "/", "") for file_path in generated_files - ] - for i, file_path in enumerate(nav_file_paths): - if "dreadnode/" in file_path: - parts = file_path.split("/") - if len(parts) > 1 and parts[0] == "dreadnode": - parts.insert(1, "library") - nav_file_paths[i] = "/".join(parts) - nested_pages = _build_nested_structure(nav_file_paths) - docs_structure = {"group": "API Reference", "pages": nested_pages} - with Path.open(docs_json_path, "w", encoding="utf-8") as f: - json.dump(docs_structure, f, indent=2) - logger.info("Generated navigation structure written to %s", docs_json_path) - - -def _build_nested_structure(file_paths, base_prefix="dreadnode/library"): - """Builds a nested structure for the navigation based on file paths.""" - nested_structure = {} - for file_path in file_paths: - if file_path.startswith(base_prefix): - relative_path = file_path[len(base_prefix) + 1 :] - parts = relative_path.split("/") - current_level = nested_structure - for part in parts[:-1]: - if part not in current_level: - current_level[part] = {} - elif isinstance(current_level[part], str): - current_level[part] = {"index": current_level[part]} - current_level = current_level[part] - if parts[-1] in current_level and isinstance(current_level[parts[-1]], dict): - if "dreadnode/" in file_path: - modified_path = file_path.replace("dreadnode/", "strikes/") - current_level[parts[-1]]["index"] = modified_path - else: - current_level[parts[-1]]["index"] = file_path - elif "dreadnode/" in file_path: - modified_path = file_path.replace("dreadnode/", "strikes/") - current_level[parts[-1]] = modified_path - else: - current_level[parts[-1]] = file_path - - def convert_to_list(structure): - result = [] - for key, value in sorted(structure.items()): - if isinstance(value, dict): - result.append({"group": key, "pages": convert_to_list(value)}) - else: - result.append(value) - return result - - return convert_to_list(nested_structure) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Generate simple, clean MDX documentation for Python modules." - ) - parser.add_argument("modules", nargs="+", help="Paths to Python files or package directories.") - parser.add_argument( - "-o", "--output-dir", default="docs", help="Directory to write MDX files (default: ./docs)." - ) - parser.add_argument( - "-p", - "--project-root", - default=None, - help="Optional path to the project root directory (assists with import resolution). Defaults to CWD.", - ) - parser.add_argument( - "-g", - "--auth-group", - choices=["crucible", "strikes", "spyglass"], - help="Optional authentication group to add to frontmatter.", - ) - - args = parser.parse_args() - generate_mdx_docs(args.modules, args.output_dir, args.auth_group, args.project_root) - logger.info("MDX generation complete.")