Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions dreadnode/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@
current_task_span,
)
from dreadnode.types import (
INHERITED,
AnyDict,
Inherited,
JsonDict,
JsonValue,
)
Expand Down Expand Up @@ -412,8 +414,8 @@ def task(
name: str | None = None,
label: str | None = None,
log_params: t.Sequence[str] | bool = False,
log_inputs: t.Sequence[str] | bool = True,
log_output: bool = True,
log_inputs: t.Sequence[str] | bool | Inherited = INHERITED,
log_output: bool | Inherited = INHERITED,
tags: t.Sequence[str] | None = None,
**attributes: t.Any,
) -> TaskDecorator: ...
Expand All @@ -426,8 +428,8 @@ def task(
name: str | None = None,
label: str | None = None,
log_params: t.Sequence[str] | bool = False,
log_inputs: t.Sequence[str] | bool = True,
log_output: bool = True,
log_inputs: t.Sequence[str] | bool | Inherited = INHERITED,
log_output: bool | Inherited = INHERITED,
tags: t.Sequence[str] | None = None,
**attributes: t.Any,
) -> ScoredTaskDecorator[R]: ...
Expand All @@ -439,8 +441,8 @@ def task(
name: str | None = None,
label: str | None = None,
log_params: t.Sequence[str] | bool = False,
log_inputs: t.Sequence[str] | bool = True,
log_output: bool = True,
log_inputs: t.Sequence[str] | bool | Inherited = INHERITED,
log_output: bool | Inherited = INHERITED,
tags: t.Sequence[str] | None = None,
**attributes: t.Any,
) -> TaskDecorator:
Expand Down Expand Up @@ -622,6 +624,7 @@ def run(
tags: t.Sequence[str] | None = None,
params: AnyDict | None = None,
project: str | None = None,
autolog: bool = True,
**attributes: t.Any,
) -> RunSpan:
"""
Expand All @@ -647,6 +650,7 @@ def run(
project: The project name to associate the run with. If not provided,
the project passed to `configure()` will be used, or the
run will be associated with a default project.
autolog: Whether to automatically log task inputs, outputs, and execution metrics if unspecified.
**attributes: Additional attributes to attach to the run span.
"""
if not self._initialized:
Expand All @@ -664,6 +668,7 @@ def run(
tags=tags,
file_system=self._fs,
prefix_path=self._fs_prefix,
autolog=autolog,
)

@handle_internal_errors()
Expand Down
3 changes: 3 additions & 0 deletions dreadnode/object.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import typing as t
from dataclasses import dataclass

from dreadnode.types import JsonDict


@dataclass
class ObjectRef:
name: str
label: str
hash: str
attributes: JsonDict


@dataclass
Expand Down
45 changes: 32 additions & 13 deletions dreadnode/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from dreadnode.metric import Scorer, ScorerCallable
from dreadnode.tracing.span import TaskSpan, current_run_span
from dreadnode.types import INHERITED, Inherited

P = t.ParamSpec("P")
R = t.TypeVar("R")
Expand Down Expand Up @@ -114,9 +115,9 @@ class Task(t.Generic[P, R]):

log_params: t.Sequence[str] | bool = False
"Whether to log all, or specific, incoming arguments to the function as parameters."
log_inputs: t.Sequence[str] | bool = True
log_inputs: t.Sequence[str] | bool | Inherited = INHERITED
"Whether to log all, or specific, incoming arguments to the function as inputs."
log_output: bool = True
log_output: bool | Inherited = INHERITED
"Whether to automatically log the result of the function as an output."

def __post_init__(self) -> None:
Expand Down Expand Up @@ -239,6 +240,9 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]:
if run is None or not run.is_recording:
raise RuntimeError("Tasks must be executed within a run")

log_inputs = run.autolog if isinstance(self.log_inputs, Inherited) else self.log_inputs
log_output = run.autolog if isinstance(self.log_output, Inherited) else self.log_output

bound_args = self._bind_args(*args, **kwargs)

params_to_log = (
Expand All @@ -250,9 +254,9 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]:
)
inputs_to_log = (
bound_args
if self.log_inputs is True
else {k: v for k, v in bound_args.items() if k in self.log_inputs}
if self.log_inputs is not False
if log_inputs is True
else {k: v for k, v in bound_args.items() if k in log_inputs}
if log_inputs is not False
else {}
)

Expand All @@ -265,13 +269,16 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]:
run_id=run.run_id,
tracer=self.tracer,
) as span:
span.run.log_metric(f"{self.label}.exec.count", 1, mode="count")
if run.autolog:
span.run.log_metric(
"count", 1, prefix=f"{self.label}.exec", mode="count", attributes={"auto": True}
)

for name, value in params_to_log.items():
span.log_param(name, value)

input_object_hashes: list[str] = [
span.log_input(name, value, label=f"{self.label}.input.{name}")
span.log_input(name, value, label=f"{self.label}.input.{name}", auto=True)
for name, value in inputs_to_log.items()
]

Expand All @@ -280,17 +287,29 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]:
if inspect.isawaitable(output):
output = await output
except Exception:
span.run.log_metric(f"{self.label}.exec.success_rate", 0, mode="avg")
if run.autolog:
span.run.log_metric(
"success_rate",
0,
prefix=f"{self.label}.exec",
mode="avg",
attributes={"auto": True},
)
raise

span.run.log_metric(f"{self.label}.exec.success_rate", 1, mode="avg")
if run.autolog:
span.run.log_metric(
"success_rate",
1,
prefix=f"{self.label}.exec",
mode="avg",
attributes={"auto": True},
)
span.output = output

if self.log_output:
if log_output:
output_object_hash = span.log_output(
"output",
output,
label=f"{self.label}.output",
"output", output, label=f"{self.label}.output", auto=True
)

# Link the output to the inputs
Expand Down
27 changes: 18 additions & 9 deletions dreadnode/tracing/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,15 @@ def __init__(
tracer: Tracer,
file_system: AbstractFileSystem,
prefix_path: str,
*,
params: AnyDict | None = None,
metrics: MetricDict | None = None,
run_id: str | None = None,
tags: t.Sequence[str] | None = None,
autolog: bool = True,
) -> None:
self.autolog = autolog

self._params = params or {}
self._metrics = metrics or {}
self._objects: dict[str, Object] = {}
Expand Down Expand Up @@ -486,9 +490,8 @@ def log_input(
value,
label=label,
event_name=EVENT_NAME_OBJECT_INPUT,
**attributes,
)
self._inputs.append(ObjectRef(name, label=label, hash=hash_))
self._inputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes))

def log_artifact(
self,
Expand Down Expand Up @@ -528,6 +531,7 @@ def log_metric(
origin: t.Any | None = None,
timestamp: datetime | None = None,
mode: MetricAggMode | None = None,
prefix: str | None = None,
attributes: JsonDict | None = None,
) -> None: ...

Expand All @@ -539,6 +543,7 @@ def log_metric(
*,
origin: t.Any | None = None,
mode: MetricAggMode | None = None,
prefix: str | None = None,
) -> None: ...

def log_metric(
Expand All @@ -550,6 +555,7 @@ def log_metric(
origin: t.Any | None = None,
timestamp: datetime | None = None,
mode: MetricAggMode | None = None,
prefix: str | None = None,
attributes: JsonDict | None = None,
) -> None:
metric = (
Expand All @@ -560,6 +566,10 @@ def log_metric(
)
)

key = re.sub(r"[^\w/]+", "_", key.lower())
if prefix is not None:
key = f"{prefix}.{key}"

if origin is not None:
origin_hash = self.log_object(
origin,
Expand Down Expand Up @@ -590,9 +600,8 @@ def log_output(
value,
label=label,
event_name=EVENT_NAME_OBJECT_OUTPUT,
**attributes,
)
self._outputs.append(ObjectRef(name, label=label, hash=hash_))
self._outputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes))


class TaskSpan(Span, t.Generic[R]):
Expand Down Expand Up @@ -694,9 +703,8 @@ def log_output(
value,
label=label,
event_name=EVENT_NAME_OBJECT_OUTPUT,
**attributes,
)
self._outputs.append(ObjectRef(name, label=label, hash=hash_))
self._outputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes))
return hash_

@property
Expand Down Expand Up @@ -726,9 +734,8 @@ def log_input(
value,
label=label,
event_name=EVENT_NAME_OBJECT_INPUT,
**attributes,
)
self._inputs.append(ObjectRef(name, label=label, hash=hash_))
self._inputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes))
return hash_

@property
Expand Down Expand Up @@ -777,6 +784,8 @@ def log_metric(
)
)

key = re.sub(r"[^\w/]+", "_", key.lower())

if origin is not None:
origin_hash = self.run.log_object(
origin,
Expand All @@ -795,7 +804,7 @@ def log_metric(
#
# Don't include `source` and `mode` as we handled it here.
if (run := current_run_span.get()) is not None:
run.log_metric(f"{self._label}.{key}", metric)
run.log_metric(key, metric, prefix=self._label)

def get_average_metric_value(self, key: str | None = None) -> float:
metrics = (
Expand Down
8 changes: 8 additions & 0 deletions dreadnode/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,11 @@ def __bool__(self) -> t.Literal[False]:


UNSET: Unset = Unset()


class Inherited:
def __repr__(self) -> str:
return "Inherited"


INHERITED: Inherited = Inherited()
Loading