diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml new file mode 100644 index 00000000..f58663f3 --- /dev/null +++ b/.github/workflows/docs.yaml @@ -0,0 +1,31 @@ +name: Notify Documentation Update + +on: + push: + branches: [main] + paths: + - "docs/**" + - "scripts/make_docs.py" + workflow_dispatch: + +jobs: + notify-docs: + runs-on: ubuntu-latest + steps: + - uses: actions/create-github-app-token@df432ceedc7162793a195dd1713ff69aefc7379e # v2.0.6 + id: app-token + with: + app-id: ${{ vars.UPDATE_DOCS_APP_ID }} + private-key: ${{ secrets.UPDATE_DOCS_PRIVATE_KEY }} + owner: "${{ github.repository_owner }}" + repositories: | + sdk + prod-docs + + - name: Trigger docs repository workflow + uses: peter-evans/repository-dispatch@ff45666b9427631e3450c54a1bcbee4d9ff4d7c0 # v3.0.0 + with: + token: ${{ steps.app-token.outputs.token }} + repository: dreadnode/prod-docs + event-type: code-update + client-payload: '{"repository": "${{ github.repository }}", "ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "product": "strikes", "docs_dir": "docs", "module_dir": "dreadnode"}' diff --git a/dreadnode/api/client.py b/dreadnode/api/client.py index da0f09e8..c8a439c9 100644 --- a/dreadnode/api/client.py +++ b/dreadnode/api/client.py @@ -37,7 +37,11 @@ class ApiClient: - """Client for the Dreadnode API.""" + """Client for the Dreadnode API. + + This class provides methods to interact with the Dreadnode API, including + retrieving projects, runs, tasks, and exporting data. + """ def __init__( self, @@ -46,6 +50,13 @@ def __init__( *, debug: bool = False, ): + """Initializes the API client. + + Args: + base_url (str): The base URL of the Dreadnode API. + api_key (str): The API key for authentication. + debug (bool, optional): Whether to enable debug logging. Defaults to False. 
+ """ self._base_url = base_url.rstrip("/") if not self._base_url.endswith("/api"): self._base_url += "/api" @@ -65,7 +76,11 @@ def __init__( self._client.event_hooks["response"].append(self._log_response) def _log_request(self, request: httpx.Request) -> None: - """Log every request to the console if debug is enabled.""" + """Logs HTTP requests if debug mode is enabled. + + Args: + request (httpx.Request): The HTTP request object. + """ logger.debug("-------------------------------------------") logger.debug("%s %s", request.method, request.url) @@ -74,7 +89,11 @@ def _log_request(self, request: httpx.Request) -> None: logger.debug("-------------------------------------------") def _log_response(self, response: httpx.Response) -> None: - """Log every response to the console if debug is enabled.""" + """Logs HTTP responses if debug mode is enabled. + + Args: + response (httpx.Response): The HTTP response object. + """ logger.debug("-------------------------------------------") logger.debug("Response: %s", response.status_code) @@ -83,7 +102,14 @@ def _log_response(self, response: httpx.Response) -> None: logger.debug("--------------------------------------------") def _get_error_message(self, response: httpx.Response) -> str: - """Get the error message from the response.""" + """Extracts the error message from an HTTP response. + + Args: + response (httpx.Response): The HTTP response object. + + Returns: + str: The error message extracted from the response. + """ try: obj = response.json() @@ -98,7 +124,17 @@ def _request( params: dict[str, t.Any] | None = None, json_data: dict[str, t.Any] | None = None, ) -> httpx.Response: - """Make a raw request to the API.""" + """Makes a raw HTTP request to the API. + + Args: + method (str): The HTTP method (e.g., "GET", "POST"). + path (str): The API endpoint path. + params (dict[str, Any] | None, optional): Query parameters for the request. Defaults to None. 
def list_projects(self) -> list[Project]:
    """List the projects returned by the API.

    Sends ``GET /strikes/projects`` through ``self.request`` and turns each
    element of the decoded JSON body into a ``Project``.

    Returns:
        One ``Project`` per element of the response, in response order.
    """
    payload = self.request("GET", "/strikes/projects").json()
    projects: list[Project] = []
    for entry in payload:
        # Each entry is expanded into keyword arguments for the model.
        projects.append(Project(**entry))
    return projects
+ status (StatusFilter, optional): The status of runs to include. Defaults to "completed". + aggregations (list[MetricAggregationType] | None, optional): A list of aggregation types to apply. Defaults to None. + + Returns: + pd.DataFrame: A DataFrame containing the exported run data. + """ response = self.request( "GET", f"/strikes/projects/{project!s}/export", @@ -217,6 +290,18 @@ def export_metrics( metrics: list[str] | None = None, aggregations: list[MetricAggregationType] | None = None, ) -> pd.DataFrame: + """Exports metric data for a specific project. + + Args: + project (str): The project identifier. + filter (str | None, optional): A filter to apply to the exported data. Defaults to None. + status (StatusFilter, optional): The status of metrics to include. Defaults to "completed". + metrics (list[str] | None, optional): A list of metric names to include. Defaults to None. + aggregations (list[MetricAggregationType] | None, optional): A list of aggregation types to apply. Defaults to None. + + Returns: + pd.DataFrame: A DataFrame containing the exported metric data. + """ response = self.request( "GET", f"/strikes/projects/{project!s}/export/metrics", @@ -241,6 +326,19 @@ def export_parameters( metrics: list[str] | None = None, aggregations: list[MetricAggregationType] | None = None, ) -> pd.DataFrame: + """Exports parameter data for a specific project. + + Args: + project (str): The project identifier. + filter (str | None, optional): A filter to apply to the exported data. Defaults to None. + status (StatusFilter, optional): The status of parameters to include. Defaults to "completed". + parameters (list[str] | None, optional): A list of parameter names to include. Defaults to None. + metrics (list[str] | None, optional): A list of metric names to include. Defaults to None. + aggregations (list[MetricAggregationType] | None, optional): A list of aggregation types to apply. Defaults to None. 
def get_user_data_credentials(self) -> UserDataCredentials:
    """Fetch the user-data credentials from the API.

    Sends ``GET /user-data/credentials`` through ``self.request`` and expands
    the decoded JSON body into a ``UserDataCredentials`` instance.

    Returns:
        The credentials object built from the response body.
    """
    payload = self.request("GET", "/user-data/credentials").json()
    return UserDataCredentials(**payload)
+""" + import importlib.util if importlib.util.find_spec("transformers") is None: @@ -17,6 +24,15 @@ def _clean_keys(data: dict[str, t.Any]) -> dict[str, t.Any]: + """ + Cleans the keys of a dictionary by replacing certain prefixes with slashes. + + Args: + data (dict[str, t.Any]): The dictionary to clean. + + Returns: + dict[str, t.Any]: A new dictionary with cleaned keys. + """ cleaned: dict[str, t.Any] = {} for key, val in data.items(): _key = key.replace("eval_", "eval/").replace("test_", "test/").replace("train_", "train/") @@ -30,6 +46,12 @@ class DreadnodeCallback(TrainerCallback): This callback is used to log metrics and parameters to Dreadnode during training inside the `transformers` library or derivations (`trl`, etc.). + + + Attributes: + project (str | None): The project name in Dreadnode. + run_name (str | None): The name of the training run. + tags (list[str]): A list of tags associated with the run. """ def __init__( @@ -38,6 +60,14 @@ def __init__( run_name: str | None = None, tags: list[str] | None = None, ): + """ + Initializes the DreadnodeCallback. + + Args: + project (str | None): The project name in Dreadnode. + run_name (str | None): The name of the training run. + tags (list[str] | None): A list of tags associated with the run. + """ self.project = project self.run_name = run_name self.tags = tags or [] @@ -48,6 +78,9 @@ def __init__( self._step_span: Span | None = None def _shutdown(self) -> None: + """ + Shuts down the callback by closing any active spans and the run. + """ if self._step_span is not None: self._step_span.__exit__(None, None, None) self._step_span = None @@ -61,6 +94,14 @@ def _shutdown(self) -> None: self._run = None def _setup(self, args: TrainingArguments, state: TrainerState, model: t.Any) -> None: + """ + Sets up the callback by initializing the Dreadnode run and logging parameters. + + Args: + args (TrainingArguments): The training arguments. + state (TrainerState): The state of the trainer. 
+ model (t.Any): The model being trained. + """ if self._initialized: return @@ -101,6 +142,16 @@ def on_train_begin( model: t.Any | None = None, **kwargs: t.Any, ) -> None: + """ + Called at the beginning of training. + + Args: + args (TrainingArguments): The training arguments. + state (TrainerState): The state of the trainer. + control (TrainerControl): The control object for the trainer. + model (t.Any | None): The model being trained. + **kwargs (t.Any): Additional keyword arguments. + """ if not self._initialized: self._setup(args, state, model) @@ -111,6 +162,15 @@ def on_train_end( control: TrainerControl, **kwargs: t.Any, ) -> None: + """ + Called at the end of training. + + Args: + args (TrainingArguments): The training arguments. + state (TrainerState): The state of the trainer. + control (TrainerControl): The control object for the trainer. + **kwargs (t.Any): Additional keyword arguments. + """ self._shutdown() def on_epoch_begin( @@ -135,6 +195,15 @@ def on_epoch_end( control: TrainerControl, **kwargs: t.Any, ) -> None: + """ + Called at the end of an epoch. + + Args: + args (TrainingArguments): The training arguments. + state (TrainerState): The state of the trainer. + control (TrainerControl): The control object for the trainer. + **kwargs (t.Any): Additional keyword arguments. + """ if self._epoch_span is not None: self._epoch_span.__exit__(None, None, None) self._epoch_span = None @@ -146,6 +215,15 @@ def on_step_begin( control: TrainerControl, **kwargs: t.Any, ) -> None: + """ + Called at the beginning of a training step. + + Args: + args (TrainingArguments): The training arguments. + state (TrainerState): The state of the trainer. + control (TrainerControl): The control object for the trainer. + **kwargs (t.Any): Additional keyword arguments. + """ if self._run is None: return @@ -161,6 +239,15 @@ def on_step_end( control: TrainerControl, **kwargs: t.Any, ) -> None: + """ + Called at the end of a training step. 
+ + Args: + args (TrainingArguments): The training arguments. + state (TrainerState): The state of the trainer. + control (TrainerControl): The control object for the trainer. + **kwargs (t.Any): Additional keyword arguments. + """ if self._step_span is not None: self._step_span.__exit__(None, None, None) self._step_span = None @@ -173,6 +260,16 @@ def on_log( logs: dict[str, t.Any] | None = None, **kwargs: t.Any, ) -> None: + """ + Called when logs are reported. + + Args: + args (TrainingArguments): The training arguments. + state (TrainerState): The state of the trainer. + control (TrainerControl): The control object for the trainer. + logs (dict[str, t.Any] | None): The logs to process. + **kwargs (t.Any): Additional keyword arguments. + """ if self._run is None or logs is None: return diff --git a/pyproject.toml b/pyproject.toml index 1234fd46..d7ced9c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,9 @@ logfire = "^3.5.3" python-ulid = "^3.0.0" coolname = "^2.2.0" pandas = "^2.2.3" -fsspec = {version = ">=2023.1.0,<=2025.3.0", extras = ["s3"]} # Pinned for datasets compatibility +fsspec = { version = ">=2023.1.0,<=2025.3.0", extras = [ + "s3", +] } # Pinned for datasets compatibility transformers = { version = "^4.41.0", optional = true } @@ -59,6 +61,7 @@ asyncio_default_fixture_loop_scope = "function" [tool.mypy] strict = true +python_version = "3.10" [tool.bandit] exclude_dirs = ["tests"] diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/scripts/make_docs.py b/scripts/make_docs.py new file mode 100644 index 00000000..e47ffebe --- /dev/null +++ b/scripts/make_docs.py @@ -0,0 +1,676 @@ +# type: ignore # noqa: PGH003 + + +import argparse +import builtins +import inspect +import io +import json +import logging +import pkgutil +import pydoc +import re +import sys +import types +import typing as t +from importlib import import_module +from pathlib import Path + +from docstring_parser 
def _split_top_level(text: str, separator: str) -> list[str]:
    """Split ``text`` on ``separator`` characters that sit outside brackets and quotes."""
    pieces: list[str] = []
    depth = 0
    quote = ""
    start = 0
    for index, char in enumerate(text):
        if quote:
            # Inside a string literal (e.g. Literal['a|b']) nothing is structural.
            if char == quote:
                quote = ""
        elif char in "\"'":
            quote = char
        elif char in "[(":
            depth += 1
        elif char in "])":
            depth -= 1
        elif char == separator and depth == 0:
            pieces.append(text[start:index])
            start = index + 1
    pieces.append(text[start:])
    return pieces


def _format_union_member(member: str) -> str:
    """Rewrite pipe unions nested inside the subscript of a single type expression."""
    if "|" not in member or not member.endswith("]"):
        return member
    opening = member.find("[")
    if opening == -1:
        return member
    head = member[:opening]
    inner = member[opening + 1 : -1]
    arguments = [format_type_annotation(arg.strip()) for arg in _split_top_level(inner, ",")]
    return f"{head}[{', '.join(arguments)}]"


def format_type_annotation(annotation_str: str) -> str:
    """Convert pipe-style union types to ``Optional[...]`` / ``Union[...]`` spelling.

    Strings without a ``|`` are returned unchanged. Otherwise the annotation is
    split on pipes that are *not* nested inside brackets or quotes, so generic
    arguments such as ``dict[str, int | str]`` are kept intact and converted
    recursively instead of being cut in half.

    Args:
        annotation_str: The textual annotation, e.g. ``"str | None"``.

    Returns:
        The annotation with every pipe union rewritten: a ``None`` member turns
        the union into ``Optional[...]``; two or more remaining members are
        wrapped in ``Union[...]``.
    """
    if "|" not in annotation_str:
        return annotation_str

    members = [
        _format_union_member(part.strip()) for part in _split_top_level(annotation_str, "|")
    ]
    if len(members) == 1:
        # Every pipe was nested (e.g. "list[int | None]"); not a union at this level.
        return members[0]

    concrete = [member for member in members if member != "None"]
    if not concrete:
        return "None"

    combined = concrete[0] if len(concrete) == 1 else f"Union[{', '.join(concrete)}]"
    if len(concrete) != len(members):
        # At least one member was None, so the whole union is optional.
        return f"Optional[{combined}]"
    return combined
formats the docstring using a clean, traditional style with concise sections.""" + raw_doc = get_raw_doc(obj) + if not raw_doc: + return "" + + try: + parsed = parse_docstring(raw_doc, style=DocstringStyle.GOOGLE) + output = io.StringIO() + + # Process each section of the docstring + self._write_description(parsed, output) + self._write_parameters(parsed, output) + self._write_returns(parsed, output) + self._write_raises(parsed, output) + + return output.getvalue() + + except ParseError as e: + logger.warning( + "Warning: Could not parse docstring for %s: %s", + getattr(obj, "__name__", "object"), + e, + ) + return raw_doc.replace("<", r"\<") + "\n\n" + except Exception as e: # noqa: BLE001 + logger.info( + "Error formatting docstring for %s: %s", getattr(obj, "__name__", "object"), e + ) + return raw_doc.replace("<", r"\<") + "\n\n" + + def _write_description(self, parsed: Docstring, output: io.StringIO) -> None: + """Writes the description section of the docstring.""" + description = "" + if parsed.short_description: + description += parsed.short_description + if parsed.long_description: + if description: + description += "\n\n" + description += parsed.long_description + if description: + output.write(description.replace("<", r"\<") + "\n\n") + + def _write_parameters(self, parsed: Docstring, output: io.StringIO) -> None: + """Writes the parameters section of the docstring.""" + if parsed.params: + output.write("**Parameters:**\n\n") + for param in parsed.params: + param_header = f"**`{param.arg_name}`**" + if param.type_name: + safe_type = param.type_name.replace("`", r"\`").replace("<", r"\<") + safe_type = format_type_annotation(safe_type) + param_header += f" (`{safe_type}`)" + if param.is_optional: + param_header += " *(optional)*" + output.write(f"- {param_header}") + if param.description: + updated_description = param.description.replace("<", r"\<") + output.write(f": {updated_description}") + if param.default: + safe_default = param.default.replace("`", 
def heading(self, title: str, level: int = 1) -> str:
    """Return a Markdown heading line for ``title``.

    Args:
        title: Heading text, used verbatim.
        level: Number of leading ``#`` characters. It must be an integer
            because it is used as a repetition count.

    Returns:
        ``level`` hash characters, a space, the title and a trailing newline.
    """
    return f"{'#' * level} {title}\n"
self._write_source_file_info(output, object) + + # Write module docstring + module_doc_formatted = self._format_docstring(object) + output.write(module_doc_formatted) + + # Collect and document members + classes, functions = self._collect_members(object, full_name) + self._write_classes(output, classes, full_name) + self._write_functions(output, functions) + + return output.getvalue() + + def _write_frontmatter(self, output, safe_short_name, short_name, full_name): + """Writes the frontmatter and module header.""" + output.write(f"---\ntitle: '{safe_short_name}'\nsidebarTitle: '{safe_short_name}'\n") + if self.auth_group: + output.write(f'groups: ["{self.auth_group}"]\n') + output.write("---\n\n") + output.write(f"# Module `{short_name}`\n\n") + output.write(f"*(Full name: `{full_name}`)*\n\n") + + def _write_source_file_info(self, output, object): + """Writes the source file information.""" + try: + source_file = inspect.getsourcefile(object) + if source_file: + output.write(f"**Source file:** `{Path(source_file).name}`\n\n") + except (TypeError, OSError): + pass + except Exception as e: # noqa: BLE001 + logger.warning("Warning: Could not get source file: %s", e) + + def _collect_members(self, object, full_name): + """Collects classes and functions defined in the module.""" + classes, functions = [], [] + try: + for member_name, member_obj in inspect.getmembers(object): + if member_name.startswith("_") and not member_name.startswith("__"): + continue + if self._is_defined_here(member_obj, full_name): + if inspect.isclass(member_obj): + classes.append((member_name, member_obj)) + elif inspect.isfunction(member_obj): + functions.append((member_name, member_obj)) + except Exception as e: # noqa: BLE001 + logger.warning("Error inspecting members: %s", e) + return classes, functions + + def _is_defined_here(self, member_obj, full_name): + """Checks if a member is defined in the current module.""" + try: + member_module = inspect.getmodule(member_obj) + except 
Exception: # noqa: BLE001 + return False + return member_module is not None and member_module.__name__ == full_name + + def _write_classes(self, output, classes, full_name): + """Writes the documentation for classes.""" + if classes: + output.write("## Classes\n\n") + for class_name, class_obj in sorted(classes, key=lambda item: item[0]): + output.write(self.docclass(class_obj, class_name, module_name=full_name)) + output.write("\n") + + def _write_functions(self, output, functions): + """Writes the documentation for functions.""" + if functions: + output.write("## Functions\n\n") + for func_name, func_obj in sorted(functions, key=lambda item: item[0]): + output.write(self.docroutine(func_obj, func_name, class_name=None)) + output.write("\n") + + def docclass( + self, object: type, name: str | None = None, module_name: str | None = None + ) -> str: + real_name = name or object.__name__ + output = io.StringIO() + output.write(f"\n### Class `{real_name}`\n\n") + + self._write_inheritance_info(object, output) + self._write_class_docstring(object, output) + + methods, properties = self._collect_class_members(object, module_name) + self._write_properties(properties, output, real_name) + self._write_methods(methods, output, real_name) + + return output.getvalue() + + def _write_inheritance_info(self, object: type, output: io.StringIO): + """Writes inheritance information for a class.""" + if object.__bases__: + bases = [] + for b in object.__bases__: + if b is object or (b is builtins.object and len(object.__bases__) > 1): + continue + base_module = getattr(b, "__module__", "") + base_name_str = getattr(b, "__name__", str(b)) + if base_module and base_module != "builtins": + bases.append(f"`{base_module}.{base_name_str}`") + else: + bases.append(f"`{base_name_str}`") + if bases: + output.write(f"**Inherits from:** {', '.join(bases)}\n\n") + + def _write_class_docstring(self, object: type, output: io.StringIO): + """Writes the formatted docstring for a class.""" + 
class_doc_formatted = self._format_docstring(object) + output.write(class_doc_formatted) + + def _collect_class_members(self, object: type, module_name: str | None) -> tuple[list, list]: + """Collects methods and properties of a class.""" + methods, properties = [], [] + try: + for member_name, member_obj in inspect.getmembers(object): + if member_name.startswith("_") and not member_name.startswith("__"): + continue + if self._is_relevant_member(member_name, member_obj, object, module_name): + if isinstance(member_obj, property): + properties.append((member_name, member_obj)) + elif self._is_method(member_obj): + methods.append((member_name, member_obj)) + except Exception as e: # noqa: BLE001 + logger.warning("Error inspecting members of %s: %s", object.__name__, e) + return methods, properties + + def _is_relevant_member( + self, member_name: str, member_obj: t.Any, object: type, module_name: str | None + ) -> bool: + """Determines if a member is relevant for documentation.""" + is_directly_defined = member_name in object.__dict__ + try: + target_obj = member_obj.fget if isinstance(member_obj, property) else member_obj + member_origin_module = inspect.getmodule(target_obj) + except Exception: # noqa: BLE001 + return False + + return is_directly_defined or ( + member_origin_module and member_origin_module.__name__ == module_name + ) + + def _is_method(self, member_obj: t.Any) -> bool: + """Checks if a member is a method.""" + return inspect.isfunction(member_obj) or isinstance(member_obj, (classmethod, staticmethod)) + + def _write_properties(self, properties: list, output: io.StringIO, class_name: str): + """Writes properties of a class.""" + if properties: + output.write("#### Properties\n\n") + for prop_name, prop_obj in sorted(properties, key=lambda item: item[0]): + output.write(self._docproperty(prop_obj, prop_name, class_name=class_name)) + + def _write_methods(self, methods: list, output: io.StringIO, class_name: str): + """Writes methods of a class.""" + 
def _docproperty(self, prop: property, name: str, class_name: str | None = None) -> str:
    """Render the MDX section for a single property.

    The section contains a level-5 heading, the return type of the getter
    (when it can be resolved), the formatted docstring, a summary of which
    accessors exist, and a trailing ``---`` separator.

    Args:
        prop: The property object to document.
        name: Attribute name of the property on its class.
        class_name: Name of the owning class; a ``ClassName.`` prefix is
            stripped from the rendered type when given.

    Returns:
        The rendered MDX text.
    """
    output = io.StringIO()
    output.write(f"##### `{name}`\n\n")

    # Resolve the getter's return annotation, if any.
    type_hint_str = ""
    target_for_type = prop.fget if prop.fget else prop
    if hasattr(target_for_type, "__annotations__"):
        try:
            return_annotation = t.get_type_hints(target_for_type).get("return")
            if return_annotation:
                is_plain_class = isinstance(return_annotation, type) and not isinstance(
                    return_annotation, types.GenericAlias
                )
                if is_plain_class:
                    # str(int) is "<class 'int'>"; render the class name instead.
                    annotation = return_annotation.__qualname__
                    owner_module = getattr(return_annotation, "__module__", "")
                    if owner_module and owner_module != "builtins":
                        annotation = f"{owner_module}.{annotation}"
                else:
                    annotation = str(return_annotation)
                annotation = re.sub(r"\btyping\.", "", annotation)
                if class_name:
                    annotation = re.sub(rf"\b{re.escape(class_name)}\.", "", annotation)
                annotation = format_type_annotation(annotation)
                safe_annotation = annotation.replace("`", r"\`").replace("<", r"\<")
                type_hint_str = f"`{safe_annotation}`"
        except Exception as e:  # noqa: BLE001
            # Log the attribute name we were given; property objects do not
            # reliably expose __name__.
            logger.warning(
                "Warning: Error getting type hint for property %s: %s",
                name,
                e,
            )

    # Show type compactly
    if type_hint_str:
        output.write(f"**Type:** {type_hint_str} *(property)*\n\n")
    else:
        output.write("*(property)*\n\n")

    # Prefer the property's own docstring; fall back to the getter's.
    doc_obj = prop
    raw_doc = get_raw_doc(prop)
    if not raw_doc and prop.fget:
        doc_obj = prop.fget

    output.write(self._format_docstring(doc_obj))

    # Property details (compact)
    details = []
    if prop.fget:
        details.append("getter")
    if prop.fset:
        details.append("setter")
    if prop.fdel:
        details.append("deleter")
    if details:
        output.write(f"*Has: {', '.join(details)}*\n\n")

    # Separator between members
    output.write("---\n\n")

    return output.getvalue()
path: %s", path) + if path.is_file() and path.suffix == ".py" and path.name != "__init__.py": + _process_file(path, mdx_formatter, processed_modules, generated_files, output_path) + elif path.is_dir(): + _process_directory(path, mdx_formatter, processed_modules, generated_files, output_path) + else: + logger.warning("Warning: Path is not Python file/directory: %s", path) + + _write_docs_json(generated_files, output_path) + + +def _setup_sys_path(project_root: Path): + """Sets up the system path for module imports.""" + if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + if Path.cwd() not in sys.path: + sys.path.insert(0, Path.cwd()) + + +def _process_file(path, mdx_formatter, processed_modules, generated_files, output_path): + """Processes a single Python file.""" + module_name = _determine_module_name(path) + if not module_name: + return + + logger.info(" Attempting to import module: %s", module_name) + try: + module = import_module(module_name) + if module.__name__ in processed_modules: + return + logger.info(" Generating MDX for module: %s", module.__name__) + _generate_mdx(module, mdx_formatter, processed_modules, generated_files, output_path) + except ImportError: + logger.exception("Error importing module '%s'", module_name) + except Exception: + logger.exception("Error processing module %s", module_name) + + +def _process_directory(path, mdx_formatter, processed_modules, generated_files, output_path): + """Processes a directory as a package.""" + logger.info(" Processing directory as package: %s", path.name) + package_name = path.name + if str(path.parent) not in sys.path: + sys.path.insert(0, str(path.parent)) + for _, modname, _ in pkgutil.walk_packages([str(path)], prefix=f"{package_name}."): + if modname in processed_modules: + continue + logger.info(" Attempting to import package module: %s", modname) + try: + module = import_module(modname) + logger.info(" Generating MDX for module: %s", module.__name__) + 
_generate_mdx(module, mdx_formatter, processed_modules, generated_files, output_path) + except ImportError: + logger.exception("Error importing package module %s", modname) + except Exception: + logger.exception("Error processing package module %s", modname) + + +def _determine_module_name(path): + """Determines the module name for a given file path.""" + try: + best_match_len = -1 + module_name = None + for p_str in sys.path: + p = Path(p_str).resolve() + try: + rel_path = path.relative_to(p) + if ".." not in rel_path.parts: + current_len = len(p.parts) + if current_len > best_match_len: + best_match_len = current_len + module_name_parts = [*list(rel_path.parts[:-1]), path.stem] + module_name = ".".join(part for part in module_name_parts if part) + except ValueError: + continue + if not module_name: + module_name = path.stem + if str(path.parent) not in sys.path: + sys.path.insert(0, str(path.parent)) + else: + return module_name + except Exception as e: # noqa: BLE001 + logger.warning("Warning: Error determining module name for %s: %s", path, e) + return None + + +def _generate_mdx(module, mdx_formatter, processed_modules, generated_files, output_path): + """Generates MDX documentation for a module.""" + mdx_content = mdx_formatter.docmodule(module) + output_filename = f"{module.__name__.replace('.', '/')}.mdx" + output_file = output_path / output_filename + output_file.parent.mkdir(parents=True, exist_ok=True) + with Path.open(output_file, "w", encoding="utf-8") as f: + f.write(mdx_content) + logger.info(" -> Wrote %s", output_file) + processed_modules.add(module.__name__) + relative_path = str(output_file.relative_to(output_path.parent)).replace(".mdx", "") + generated_files.append(relative_path) + + +def _write_docs_json(generated_files, output_path): + """Writes the docs.json file.""" + docs_json_path = output_path / "docs.json" + nav_file_paths = [ + str(file_path).replace(str(output_path) + "/", "") for file_path in generated_files + ] + for i, file_path in 
enumerate(nav_file_paths): + if "dreadnode/" in file_path: + parts = file_path.split("/") + if len(parts) > 1 and parts[0] == "dreadnode": + parts.insert(1, "library") + nav_file_paths[i] = "/".join(parts) + nested_pages = _build_nested_structure(nav_file_paths) + docs_structure = {"group": "API Reference", "pages": nested_pages} + with Path.open(docs_json_path, "w", encoding="utf-8") as f: + json.dump(docs_structure, f, indent=2) + logger.info("Generated navigation structure written to %s", docs_json_path) + + +def _build_nested_structure(file_paths, base_prefix="dreadnode/library"): + """Builds a nested structure for the navigation based on file paths.""" + nested_structure = {} + for file_path in file_paths: + if file_path.startswith(base_prefix): + relative_path = file_path[len(base_prefix) + 1 :] + parts = relative_path.split("/") + current_level = nested_structure + for part in parts[:-1]: + if part not in current_level: + current_level[part] = {} + elif isinstance(current_level[part], str): + current_level[part] = {"index": current_level[part]} + current_level = current_level[part] + if parts[-1] in current_level and isinstance(current_level[parts[-1]], dict): + if "dreadnode/" in file_path: + modified_path = file_path.replace("dreadnode/", "strikes/") + current_level[parts[-1]]["index"] = modified_path + else: + current_level[parts[-1]]["index"] = file_path + elif "dreadnode/" in file_path: + modified_path = file_path.replace("dreadnode/", "strikes/") + current_level[parts[-1]] = modified_path + else: + current_level[parts[-1]] = file_path + + def convert_to_list(structure): + result = [] + for key, value in sorted(structure.items()): + if isinstance(value, dict): + result.append({"group": key, "pages": convert_to_list(value)}) + else: + result.append(value) + return result + + return convert_to_list(nested_structure) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate simple, clean MDX documentation for Python 
modules." + ) + parser.add_argument("modules", nargs="+", help="Paths to Python files or package directories.") + parser.add_argument( + "-o", "--output-dir", default="docs", help="Directory to write MDX files (default: ./docs)." + ) + parser.add_argument( + "-p", + "--project-root", + default=None, + help="Optional path to the project root directory (assists with import resolution). Defaults to CWD.", + ) + parser.add_argument( + "-g", + "--auth-group", + choices=["crucible", "strikes", "spyglass"], + help="Optional authentication group to add to frontmatter.", + ) + + args = parser.parse_args() + generate_mdx_docs(args.modules, args.output_dir, args.auth_group, args.project_root) + logger.info("MDX generation complete.")