Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 59 additions & 12 deletions dreadnode/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,29 @@
from pydantic import BaseModel
from ulid import ULID

from dreadnode.api.util import (
convert_flat_tasks_to_tree,
convert_flat_trace_to_tree,
process_run,
process_task,
)
from dreadnode.util import logger
from dreadnode.version import VERSION

from .models import (
MetricAggregationType,
Project,
RawRun,
RawTask,
Run,
RunSummary,
StatusFilter,
Task,
TaskTree,
TimeAggregationType,
TimeAxisType,
TraceSpan,
TraceTree,
UserDataCredentials,
)

Expand Down Expand Up @@ -119,27 +130,63 @@ def get_project(self, project: str) -> Project:
response = self.request("GET", f"/strikes/projects/{project!s}")
return Project(**response.json())

def list_runs(self, project: str) -> list[Run]:
def list_runs(self, project: str) -> list[RunSummary]:
response = self.request("GET", f"/strikes/projects/{project!s}/runs")
return [Run(**run) for run in response.json()]
return [RunSummary(**run) for run in response.json()]

def get_run(self, run: str | ULID) -> Run:
def _get_run(self, run: str | ULID) -> RawRun:
response = self.request("GET", f"/strikes/projects/runs/{run!s}")
return Run(**response.json())
return RawRun(**response.json())

def get_run_tasks(self, run: str | ULID) -> list[Task]:
response = self.request("GET", f"/strikes/projects/runs/{run!s}/tasks/full")
return [Task(**task) for task in response.json()]
def get_run(self, run: str | ULID) -> Run:
return process_run(self._get_run(run))

TraceFormat = t.Literal["tree", "flat"]

def get_run_trace(self, run: str | ULID) -> list[Task | TraceSpan]:
@t.overload
def get_run_tasks(
self, run: str | ULID, *, format: t.Literal["tree"] = "tree"
) -> list[TaskTree]: ...

@t.overload
def get_run_tasks(
self, run: str | ULID, *, format: t.Literal["flat"] = "flat"
) -> list[Task]: ...

def get_run_tasks(
self, run: str | ULID, *, format: TraceFormat = "flat"
) -> list[Task] | list[TaskTree]:
raw_run = self._get_run(run)
response = self.request("GET", f"/strikes/projects/runs/{run!s}/tasks/full")
raw_tasks = [RawTask(**task) for task in response.json()]
tasks = [process_task(task, raw_run) for task in raw_tasks]
tasks = sorted(tasks, key=lambda x: x.timestamp)
return tasks if format == "flat" else convert_flat_tasks_to_tree(tasks)

@t.overload
def get_run_trace(
self, run: str | ULID, *, format: t.Literal["tree"] = "tree"
) -> list[TraceTree]: ...

@t.overload
def get_run_trace(
self, run: str | ULID, *, format: t.Literal["flat"] = "flat"
) -> list[Task | TraceSpan]: ...

def get_run_trace(
self, run: str | ULID, *, format: TraceFormat = "flat"
) -> list[Task | TraceSpan] | list[TraceTree]:
raw_run = self._get_run(run)
response = self.request("GET", f"/strikes/projects/runs/{run!s}/spans/full")
spans: list[Task | TraceSpan] = []
trace: list[Task | TraceSpan] = []
for item in response.json():
if "parent_task_span_id" in item:
spans.append(Task(**item))
trace.append(process_task(RawTask(**item), raw_run))
else:
spans.append(TraceSpan(**item))
return spans
trace.append(TraceSpan(**item))

trace = sorted(trace, key=lambda x: x.timestamp)
return trace if format == "flat" else convert_flat_trace_to_tree(trace)

# Data exports

Expand Down
176 changes: 135 additions & 41 deletions dreadnode/api/models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
import contextlib
import typing as t
from datetime import datetime
from functools import cached_property
from uuid import UUID

from pydantic import BaseModel, Field
import requests
from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
TypeAdapter,
ValidationError,
field_validator,
)
from ulid import ULID

AnyDict = dict[str, t.Any]
Expand Down Expand Up @@ -79,17 +90,17 @@ class TraceLog(BaseModel):
class TraceSpan(BaseModel):
timestamp: datetime
duration: int
trace_id: str
trace_id: str = Field(repr=False)
span_id: str
parent_span_id: str | None
service_name: str | None
parent_span_id: str | None = Field(repr=False)
service_name: str | None = Field(repr=False)
status: SpanStatus
exception: SpanException | None
name: str
attributes: AnyDict
resource_attributes: AnyDict
events: list[SpanEvent]
links: list[SpanLink]
attributes: AnyDict = Field(repr=False)
resource_attributes: AnyDict = Field(repr=False)
events: list[SpanEvent] = Field(repr=False)
links: list[SpanLink] = Field(repr=False)


class Metric(BaseModel):
Expand All @@ -105,22 +116,22 @@ class ObjectRef(BaseModel):
hash: str


class ObjectUri(BaseModel):
class RawObjectUri(BaseModel):
hash: str
schema_hash: str
uri: str
size: int
type: t.Literal["uri"]


class ObjectVal(BaseModel):
class RawObjectVal(BaseModel):
hash: str
schema_hash: str
value: t.Any
type: t.Literal["val"]


Object = ObjectUri | ObjectVal
RawObject = RawObjectUri | RawObjectVal


class V0Object(BaseModel):
Expand All @@ -129,56 +140,141 @@ class V0Object(BaseModel):
value: t.Any


class Run(BaseModel):
class ObjectVal(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

name: str
label: str
hash: str = Field(repr=False)
schema_: AnyDict
schema_hash: str = Field(repr=False)
value: t.Any

@field_validator("value")
@classmethod
def validate_value(cls, value: t.Any) -> t.Any:
if isinstance(value, str):
with contextlib.suppress(ValidationError):
return TypeAdapter(t.Any).validate_json(value)

return value


class ObjectUri(BaseModel):
name: str
label: str
hash: str = Field(repr=False)
schema_: AnyDict
schema_hash: str = Field(repr=False)
uri: str
size: int

_value: t.Any = PrivateAttr(default=None)

@cached_property
def value(self) -> t.Any:
if self._value is not None:
return self._value

try:
response = requests.get(self.uri, timeout=5)
response.raise_for_status()
self._value = response.text
except requests.RequestException as e:
raise RuntimeError(f"Failed to fetch object from {self.uri}") from e

if isinstance(self._value, str):
with contextlib.suppress(ValidationError):
self._value = TypeAdapter(t.Any).validate_json(self._value)

return self._value


Object = ObjectVal | ObjectUri


class ArtifactFile(BaseModel):
hash: str
uri: str
size_bytes: int
final_real_path: str


class ArtifactDir(BaseModel):
dir_path: str
hash: str
children: list[t.Union["ArtifactDir", ArtifactFile]]


class RunSummary(BaseModel):
id: ULID
name: str
span_id: str
trace_id: str
span_id: str = Field(repr=False)
trace_id: str = Field(repr=False)
timestamp: datetime
duration: int
status: SpanStatus
exception: SpanException | None
tags: set[str]
params: AnyDict
metrics: dict[str, list[Metric]]
inputs: list[ObjectRef]
outputs: list[ObjectRef]
objects: dict[str, Object]
object_schemas: AnyDict
schema_: AnyDict = Field(alias="schema")
params: AnyDict = Field(repr=False)
metrics: dict[str, list[Metric]] = Field(repr=False)


class Task(BaseModel):
class RawRun(RunSummary):
inputs: list[ObjectRef] = Field(repr=False)
outputs: list[ObjectRef] = Field(repr=False)
objects: dict[str, RawObject] = Field(repr=False)
object_schemas: AnyDict = Field(repr=False)
artifacts: list[ArtifactDir] = Field(repr=False)
schema_: AnyDict = Field(alias="schema", repr=False)


class Run(RunSummary):
inputs: dict[str, Object] = Field(repr=False)
outputs: dict[str, Object] = Field(repr=False)
artifacts: list[ArtifactDir] = Field(repr=False)
schema_: AnyDict = Field(alias="schema", repr=False)


class _Task(BaseModel):
name: str
span_id: str
trace_id: str
parent_span_id: str | None
parent_task_span_id: str | None
trace_id: str = Field(repr=False)
parent_span_id: str | None = Field(repr=False)
parent_task_span_id: str | None = Field(repr=False)
timestamp: datetime
duration: int
status: SpanStatus
exception: SpanException | None
tags: set[str]
params: AnyDict
metrics: dict[str, list[Metric]]
inputs: list[ObjectRef] | list[V0Object] # v0 compat
outputs: list[ObjectRef] | list[V0Object] # v0 compat
schema_: AnyDict = Field(alias="schema")
attributes: AnyDict
resource_attributes: AnyDict
events: list[SpanEvent]
links: list[SpanLink]
params: AnyDict = Field(repr=False)
metrics: dict[str, list[Metric]] = Field(repr=False)
schema_: AnyDict = Field(alias="schema", repr=False)
attributes: AnyDict = Field(repr=False)
resource_attributes: AnyDict = Field(repr=False)
events: list[SpanEvent] = Field(repr=False)
links: list[SpanLink] = Field(repr=False)


class RawTask(_Task):
inputs: list[ObjectRef] | list[V0Object] = Field(repr=False)
outputs: list[ObjectRef] | list[V0Object] = Field(repr=False)


class Task(_Task):
inputs: dict[str, Object] = Field(repr=False)
outputs: dict[str, Object] = Field(repr=False)


class Project(BaseModel):
id: UUID
id: UUID = Field(repr=False)
key: str
name: str
description: str | None
description: str | None = Field(repr=False)
created_at: datetime
updated_at: datetime
run_count: int
last_run: Run | None
last_run: RawRun | None = Field(repr=False)


# Derived types
Expand All @@ -189,11 +285,9 @@ class TaskTree(BaseModel):
children: list["TaskTree"] = []


class SpanTree(BaseModel):
"""Tree representation of a trace span with its children"""

class TraceTree(BaseModel):
span: Task | TraceSpan
children: list["SpanTree"] = []
children: list["TraceTree"] = []


# User data credentials
Expand Down
Loading