From 583a5a8e2bb6c68a7d10144c71b4f67c9c0003f7 Mon Sep 17 00:00:00 2001 From: Brian Greunke Date: Mon, 28 Apr 2025 13:13:28 -0500 Subject: [PATCH 1/5] wip --- dreadnode/api/client.py | 160 ++++++++++++++++++++++++- dreadnode/integrations/transformers.py | 106 ++++++++++++++++ 2 files changed, 260 insertions(+), 6 deletions(-) diff --git a/dreadnode/api/client.py b/dreadnode/api/client.py index 674f8c1f..6c047f24 100644 --- a/dreadnode/api/client.py +++ b/dreadnode/api/client.py @@ -26,7 +26,11 @@ class ApiClient: - """Client for the Dreadnode API.""" + """Client for the Dreadnode API. + + This class provides methods to interact with the Dreadnode API, including + retrieving projects, runs, tasks, and exporting data. + """ def __init__( self, @@ -35,6 +39,13 @@ def __init__( *, debug: bool = False, ): + """Initializes the API client. + + Args: + base_url (str): The base URL of the Dreadnode API. + api_key (str): The API key for authentication. + debug (bool, optional): Whether to enable debug logging. Defaults to False. + """ self._base_url = base_url.rstrip("/") if not self._base_url.endswith("/api"): self._base_url += "/api" @@ -54,7 +65,11 @@ def __init__( self._client.event_hooks["response"].append(self._log_response) def _log_request(self, request: httpx.Request) -> None: - """Log every request to the console if debug is enabled.""" + """Logs HTTP requests if debug mode is enabled. + + Args: + request (httpx.Request): The HTTP request object. + """ logger.debug("-------------------------------------------") logger.debug("%s %s", request.method, request.url) @@ -63,7 +78,11 @@ def _log_request(self, request: httpx.Request) -> None: logger.debug("-------------------------------------------") def _log_response(self, response: httpx.Response) -> None: - """Log every response to the console if debug is enabled.""" + """Logs HTTP responses if debug mode is enabled. + + Args: + response (httpx.Response): The HTTP response object. + """ logger.debug("-------------------------------------------") logger.debug("Response: %s", response.status_code) @@ -72,7 +91,14 @@ def _log_response(self, response: httpx.Response) -> None: logger.debug("--------------------------------------------") def _get_error_message(self, response: httpx.Response) -> str: - """Get the error message from the response.""" + """Extracts the error message from an HTTP response. + + Args: + response (httpx.Response): The HTTP response object. + + Returns: + str: The error message extracted from the response. + """ try: obj = response.json() @@ -87,7 +113,17 @@ def _request( params: dict[str, t.Any] | None = None, json_data: dict[str, t.Any] | None = None, ) -> httpx.Response: - """Make a raw request to the API.""" + """Makes a raw HTTP request to the API. + + Args: + method (str): The HTTP method (e.g., "GET", "POST"). + path (str): The API endpoint path. + params (dict[str, Any] | None, optional): Query parameters for the request. Defaults to None. + json_data (dict[str, Any] | None, optional): JSON payload for the request. Defaults to None. + + Returns: + httpx.Response: The HTTP response object. + """ return self._client.request(method, path, json=json_data, params=params) @@ -98,7 +134,20 @@ def request( params: dict[str, t.Any] | None = None, json_data: dict[str, t.Any] | None = None, ) -> httpx.Response: - """Make a request to the API. Raise an exception for non-200 status codes.""" + """Makes an HTTP request to the API and raises exceptions for errors. + + Args: + method (str): The HTTP method (e.g., "GET", "POST"). + path (str): The API endpoint path. + params (dict[str, Any] | None, optional): Query parameters for the request. Defaults to None. + json_data (dict[str, Any] | None, optional): JSON payload for the request. Defaults to None. + + Returns: + httpx.Response: The HTTP response object. + + Raises: + RuntimeError: If the response status code indicates an error. + """ response = self._request(method, path, params, json_data) if response.status_code == 401: # noqa: PLR2004 @@ -117,26 +166,71 @@ def request( # return UserResponse(**response.json()) def list_projects(self) -> list[Project]: + """Retrieves a list of projects. + + Returns: + list[Project]: A list of Project objects. + """ response = self.request("GET", "/strikes/projects") return [Project(**project) for project in response.json()] def get_project(self, project: str) -> Project: + """Retrieves details of a specific project. + + Args: + project (str): The project identifier. + + Returns: + Project: The Project object. + """ response = self.request("GET", f"/strikes/projects/{project!s}") return Project(**response.json()) def list_runs(self, project: str) -> list[Run]: + """Retrieves a list of runs for a specific project. + + Args: + project (str): The project identifier. + + Returns: + list[Run]: A list of Run objects. + """ response = self.request("GET", f"/strikes/projects/{project!s}/runs") return [Run(**run) for run in response.json()] def get_run(self, run: str | ULID) -> Run: + """Retrieves details of a specific run. + + Args: + run (str | ULID): The run identifier. + + Returns: + Run: The Run object. + """ response = self.request("GET", f"/strikes/projects/runs/{run!s}") return Run(**response.json()) def get_run_tasks(self, run: str | ULID) -> list[Task]: + """Retrieves tasks for a specific run. + + Args: + run (str | ULID): The run identifier. + + Returns: + list[Task]: A list of Task objects. + """ response = self.request("GET", f"/strikes/projects/runs/{run!s}/tasks") return [Task(**task) for task in response.json()] def get_run_trace(self, run: str | ULID) -> list[Task | TraceSpan]: + """Retrieves trace spans for a specific run. + + Args: + run (str | ULID): The run identifier. + + Returns: + list[Task | TraceSpan]: A list of Task or TraceSpan objects. + """ response = self.request("GET", f"/strikes/projects/runs/{run!s}/spans") spans: list[Task | TraceSpan] = [] for item in response.json(): @@ -157,6 +251,17 @@ def export_runs( status: StatusFilter = "completed", aggregations: list[MetricAggregationType] | None = None, ) -> pd.DataFrame: + """Exports run data for a specific project. + + Args: + project (str): The project identifier. + filter (str | None, optional): A filter to apply to the exported data. Defaults to None. + status (StatusFilter, optional): The status of runs to include. Defaults to "completed". + aggregations (list[MetricAggregationType] | None, optional): A list of aggregation types to apply. Defaults to None. + + Returns: + pd.DataFrame: A DataFrame containing the exported run data. + """ response = self.request( "GET", f"/strikes/projects/{project!s}/export", @@ -179,6 +284,18 @@ def export_metrics( metrics: list[str] | None = None, aggregations: list[MetricAggregationType] | None = None, ) -> pd.DataFrame: + """Exports metric data for a specific project. + + Args: + project (str): The project identifier. + filter (str | None, optional): A filter to apply to the exported data. Defaults to None. + status (StatusFilter, optional): The status of metrics to include. Defaults to "completed". + metrics (list[str] | None, optional): A list of metric names to include. Defaults to None. + aggregations (list[MetricAggregationType] | None, optional): A list of aggregation types to apply. Defaults to None. + + Returns: + pd.DataFrame: A DataFrame containing the exported metric data. + """ response = self.request( "GET", f"/strikes/projects/{project!s}/export/metrics", @@ -203,6 +320,19 @@ def export_parameters( metrics: list[str] | None = None, aggregations: list[MetricAggregationType] | None = None, ) -> pd.DataFrame: + """Exports parameter data for a specific project. + + Args: + project (str): The project identifier. + filter (str | None, optional): A filter to apply to the exported data. Defaults to None. + status (StatusFilter, optional): The status of parameters to include. Defaults to "completed". + parameters (list[str] | None, optional): A list of parameter names to include. Defaults to None. + metrics (list[str] | None, optional): A list of metric names to include. Defaults to None. + aggregations (list[MetricAggregationType] | None, optional): A list of aggregation types to apply. Defaults to None. + + Returns: + pd.DataFrame: A DataFrame containing the exported parameter data. + """ response = self.request( "GET", f"/strikes/projects/{project!s}/export/parameters", @@ -228,6 +358,19 @@ def export_timeseries( time_axis: TimeAxisType = "relative", aggregations: list[TimeAggregationType] | None = None, ) -> pd.DataFrame: + """Exports timeseries data for a specific project. + + Args: + project (str): The project identifier. + filter (str | None, optional): A filter to apply to the exported data. Defaults to None. + status (StatusFilter, optional): The status of timeseries to include. Defaults to "completed". + metrics (list[str] | None, optional): A list of metric names to include. Defaults to None. + time_axis (TimeAxisType, optional): The type of time axis to use. Defaults to "relative". + aggregations (list[TimeAggregationType] | None, optional): A list of aggregation types to apply. Defaults to None. + + Returns: + pd.DataFrame: A DataFrame containing the exported timeseries data. + """ response = self.request( "GET", f"/strikes/projects/{project!s}/export/timeseries", @@ -245,5 +388,10 @@ def export_timeseries( # User data access def get_user_data_credentials(self) -> UserDataCredentials: + """Retrieves user data credentials. + + Returns: + UserDataCredentials: The user data credentials object. + """ response = self.request("GET", "/user-data/credentials") return UserDataCredentials(**response.json()) diff --git a/dreadnode/integrations/transformers.py b/dreadnode/integrations/transformers.py index c21059b4..a873f22c 100644 --- a/dreadnode/integrations/transformers.py +++ b/dreadnode/integrations/transformers.py @@ -1,3 +1,10 @@ +""" +This module provides an integration with the `transformers` library for logging +metrics and parameters to Dreadnode during training. It includes a custom +`TrainerCallback` implementation that tracks training progress and logs relevant +information to Dreadnode. +""" + import importlib.util if importlib.util.find_spec("transformers") is None: @@ -21,6 +28,15 @@ def _clean_keys(data: dict[str, t.Any]) -> dict[str, t.Any]: + """ + Cleans the keys of a dictionary by replacing certain prefixes with slashes. + + Args: + data (dict[str, t.Any]): The dictionary to clean. + + Returns: + dict[str, t.Any]: A new dictionary with cleaned keys. + """ cleaned: dict[str, t.Any] = {} for key, val in data.items(): _key = key.replace("eval_", "eval/").replace("test_", "test/").replace("train_", "train/") @@ -34,6 +50,12 @@ class DreadnodeCallback(TrainerCallback): # type: ignore [misc] This callback is used to log metrics and parameters to Dreadnode during training inside the `transformers` library or derivations (`trl`, etc.). + + + Attributes: + project (str | None): The project name in Dreadnode. + run_name (str | None): The name of the training run. + tags (list[str]): A list of tags associated with the run. """ def __init__( @@ -42,6 +64,14 @@ def __init__( run_name: str | None = None, tags: list[str] | None = None, ): + """ + Initializes the DreadnodeCallback. + + Args: + project (str | None): The project name in Dreadnode. + run_name (str | None): The name of the training run. + tags (list[str] | None): A list of tags associated with the run. + """ self.project = project self.run_name = run_name self.tags = tags or [] @@ -52,6 +82,9 @@ def __init__( self._step_span: Span | None = None def _shutdown(self) -> None: + """ + Shuts down the callback by closing any active spans and the run. + """ if self._step_span is not None: self._step_span.__exit__(None, None, None) self._step_span = None @@ -65,6 +98,14 @@ def _shutdown(self) -> None: self._run = None def _setup(self, args: TrainingArguments, state: TrainerState, model: t.Any) -> None: + """ + Sets up the callback by initializing the Dreadnode run and logging parameters. + + Args: + args (TrainingArguments): The training arguments. + state (TrainerState): The state of the trainer. + model (t.Any): The model being trained. + """ if self._initialized: return @@ -105,6 +146,16 @@ def on_train_begin( model: t.Any | None = None, **kwargs: t.Any, ) -> None: + """ + Called at the beginning of training. + + Args: + args (TrainingArguments): The training arguments. + state (TrainerState): The state of the trainer. + control (TrainerControl): The control object for the trainer. + model (t.Any | None): The model being trained. + **kwargs (t.Any): Additional keyword arguments. + """ if not self._initialized: self._setup(args, state, model) @@ -115,6 +166,15 @@ def on_train_end( control: TrainerControl, **kwargs: t.Any, ) -> None: + """ + Called at the end of training. + + Args: + args (TrainingArguments): The training arguments. + state (TrainerState): The state of the trainer. + control (TrainerControl): The control object for the trainer. + **kwargs (t.Any): Additional keyword arguments. + """ self._shutdown() def on_epoch_begin( @@ -124,6 +184,15 @@ def on_epoch_begin( control: TrainerControl, **kwargs: t.Any, ) -> None: + """ + Called at the beginning of an epoch. + + Args: + args (TrainingArguments): The training arguments. + state (TrainerState): The state of the trainer. + control (TrainerControl): The control object for the trainer. + **kwargs (t.Any): Additional keyword arguments. + """ if self._run is None: return @@ -139,6 +208,15 @@ def on_epoch_end( control: TrainerControl, **kwargs: t.Any, ) -> None: + """ + Called at the end of an epoch. + + Args: + args (TrainingArguments): The training arguments. + state (TrainerState): The state of the trainer. + control (TrainerControl): The control object for the trainer. + **kwargs (t.Any): Additional keyword arguments. + """ if self._epoch_span is not None: self._epoch_span.__exit__(None, None, None) self._epoch_span = None @@ -150,6 +228,15 @@ def on_step_begin( control: TrainerControl, **kwargs: t.Any, ) -> None: + """ + Called at the beginning of a training step. + + Args: + args (TrainingArguments): The training arguments. + state (TrainerState): The state of the trainer. + control (TrainerControl): The control object for the trainer. + **kwargs (t.Any): Additional keyword arguments. + """ if self._run is None: return @@ -165,6 +252,15 @@ def on_step_end( control: TrainerControl, **kwargs: t.Any, ) -> None: + """ + Called at the end of a training step. + + Args: + args (TrainingArguments): The training arguments. + state (TrainerState): The state of the trainer. + control (TrainerControl): The control object for the trainer. + **kwargs (t.Any): Additional keyword arguments. + """ if self._step_span is not None: self._step_span.__exit__(None, None, None) self._step_span = None @@ -177,6 +273,16 @@ def on_log( logs: dict[str, t.Any] | None = None, **kwargs: t.Any, ) -> None: + """ + Called when logs are reported. + + Args: + args (TrainingArguments): The training arguments. + state (TrainerState): The state of the trainer. + control (TrainerControl): The control object for the trainer. + logs (dict[str, t.Any] | None): The logs to process. + **kwargs (t.Any): Additional keyword arguments. + """ if self._run is None or logs is None: return From c139c6f5bfee1ea9c8385fe47185f0a0cbc3774a Mon Sep 17 00:00:00 2001 From: Brian Greunke Date: Fri, 23 May 2025 12:35:26 -0500 Subject: [PATCH 2/5] chore: stubbed workflow --- .github/workflows/docs.yaml | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 .github/workflows/docs.yaml diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml new file mode 100644 index 00000000..70b5fe12 --- /dev/null +++ b/.github/workflows/docs.yaml @@ -0,0 +1,20 @@ +name: Notify Documentation Update + +on: + push: + branches: [main] + paths: + - "docs/**" + +jobs: + notify-docs: + runs-on: ubuntu-latest + steps: + - name: Trigger docs repository workflow + run: echo "Triggering docs repository workflow not yet implemented" + # uses: peter-evans/repository-dispatch@v2 + # with: + # token: ${{ secrets.DOCS_REPO_PAT }} + # repository: dreadnode/prod-docs + # event-type: code-update + # client-payload: '{"repository": "${{ github.repository }}", "ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "product": "sdk"}' From 382e439db69a55e954206376799d2aadbaf7160e Mon Sep 17 00:00:00 2001 From: Brian Greunke Date: Fri, 23 May 2025 12:45:58 -0500 Subject: [PATCH 3/5] chore: fixed 3.10 typing issue with backslash in fstring --- pyproject.toml | 5 ++++- scripts/make_docs.py | 3 ++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1234fd46..d7ced9c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,9 @@ logfire = "^3.5.3" python-ulid = "^3.0.0" coolname = "^2.2.0" pandas = "^2.2.3" -fsspec = {version = ">=2023.1.0,<=2025.3.0", extras = ["s3"]} # Pinned for datasets compatibility +fsspec = { version = ">=2023.1.0,<=2025.3.0", extras = [ + "s3", +] } # Pinned for datasets compatibility transformers = { version = "^4.41.0", optional = true } @@ -59,6 +61,7 @@ asyncio_default_fixture_loop_scope = "function" [tool.mypy] strict = true +python_version = "3.10" [tool.bandit] exclude_dirs = ["tests"] diff --git a/scripts/make_docs.py b/scripts/make_docs.py index 00e9cc6f..e47ffebe 100644 --- a/scripts/make_docs.py +++ b/scripts/make_docs.py @@ -138,7 +138,8 @@ def _write_parameters(self, parsed: Docstring, output: io.StringIO) -> None: param_header += " *(optional)*" output.write(f"- {param_header}") if param.description: - output.write(f": {param.description.replace('<', r'\\<')}") + updated_description = param.description.replace("<", r"\<") + output.write(f": {updated_description}") if param.default: safe_default = param.default.replace("`", r"\`") output.write(f" Default: `{safe_default}`") From 969bda68fee61608cde1b860391c82c826e6acaf Mon Sep 17 00:00:00 2001 From: Brian Greunke Date: Sun, 25 May 2025 14:14:03 -0500 Subject: [PATCH 4/5] fix: updated docs workflow- tokens and client payload --- .github/workflows/docs.yaml | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 70b5fe12..484004b0 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -5,16 +5,27 @@ on: branches: [main] paths: - "docs/**" + - "scripts/make_docs.py" + workflow_dispatch: jobs: notify-docs: runs-on: ubuntu-latest steps: + - uses: actions/create-github-app-token@df432ceedc7162793a195dd1713ff69aefc7379e # v2.0.6 + id: app-token + with: + app-id: ${{ vars.UPDATE_DOCS_APP_ID }} + private-key: ${{ secrets.UPDATE_DOCS_PRIVATE_KEY }} + owner: "${{ github.repository_owner }}" + repositories: | + spyglass + prod-docs + - name: Trigger docs repository workflow - run: echo "Triggering docs repository workflow not yet implemented" - # uses: peter-evans/repository-dispatch@v2 - # with: - # token: ${{ secrets.DOCS_REPO_PAT }} - # repository: dreadnode/prod-docs - # event-type: code-update - # client-payload: '{"repository": "${{ github.repository }}", "ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "product": "sdk"}' + uses: peter-evans/repository-dispatch@ff45666b9427631e3450c54a1bcbee4d9ff4d7c0 # v3.0.0 + with: + token: ${{ steps.app-token.outputs.token }} + repository: dreadnode/prod-docs + event-type: code-update + client-payload: '{"repository": "${{ github.repository }}", "ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "product": "strikes", "docs_dir": "docs", "module_dir": "dreadnode"}' From 4e331ce79d78f46d014a687bb977df80c069c4d7 Mon Sep 17 00:00:00 2001 From: Brian Greunke Date: Sun, 25 May 2025 14:14:35 -0500 Subject: [PATCH 5/5] fix: updated doc workflow- changed repo to sdk --- .github/workflows/docs.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 484004b0..f58663f3 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -19,7 +19,7 @@ jobs: private-key: ${{ secrets.UPDATE_DOCS_PRIVATE_KEY }} owner: "${{ github.repository_owner }}" repositories: | - spyglass + sdk prod-docs - name: Trigger docs repository workflow