Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions dreadnode/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,7 @@ def log_metric(
mode: MetricAggMode | None = None,
attributes: JsonDict | None = None,
to: ToObject = "task-or-run",
) -> None:
) -> Metric:
"""
Log a single metric to the current task or run.

Expand Down Expand Up @@ -809,6 +809,9 @@ def log_metric(
to: The target object to log the metric to. Can be "task-or-run" or "run".
Defaults to "task-or-run". If "task-or-run", the metric will be logged
to the current task or run, whichever is the nearest ancestor.

Returns:
The logged metric object.
"""

@t.overload
Expand All @@ -820,7 +823,7 @@ def log_metric(
origin: t.Any | None = None,
mode: MetricAggMode | None = None,
to: ToObject = "task-or-run",
) -> None:
) -> Metric:
"""
Log a single metric to the current task or run.

Expand Down Expand Up @@ -848,6 +851,9 @@ def log_metric(
to: The target object to log the metric to. Can be "task-or-run" or "run".
Defaults to "task-or-run". If "task-or-run", the metric will be logged
to the current task or run, whichever is the nearest ancestor.

Returns:
The logged metric object.
"""

@handle_internal_errors()
Expand All @@ -862,7 +868,7 @@ def log_metric(
mode: MetricAggMode | None = None,
attributes: JsonDict | None = None,
to: ToObject = "task-or-run",
) -> None:
) -> Metric:
task = current_task_span.get()
run = current_run_span.get()

Expand All @@ -877,7 +883,7 @@ def log_metric(
float(value), step, timestamp or datetime.now(timezone.utc), attributes or {}
)
)
target.log_metric(key, metric, origin=origin, mode=mode)
return target.log_metric(key, metric, origin=origin, mode=mode)

@handle_internal_errors()
def log_artifact(
Expand Down
5 changes: 3 additions & 2 deletions dreadnode/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def apply_mode(self, mode: MetricAggMode, others: "list[Metric]") -> "Metric":
This will modify the metric in place.

Args:
mode: The mode to apply. One of "sum", "min", "max", or "inc".
mode: The mode to apply. One of "sum", "min", "max", or "count".
others: A list of other metrics to apply the mode to.

Returns:
Expand All @@ -87,7 +87,8 @@ def apply_mode(self, mode: MetricAggMode, others: "list[Metric]") -> "Metric":
prior_values = [m.value for m in sorted(others, key=lambda m: m.timestamp)]

if mode == "sum":
self.value += max(prior_values)
# Take the max of the priors because they might already be summed
self.value += max(prior_values) if prior_values else 0
elif mode == "min":
self.value = min([self.value, *prior_values])
elif mode == "max":
Expand Down
38 changes: 17 additions & 21 deletions dreadnode/tracing/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ def log_metric(
mode: MetricAggMode | None = None,
prefix: str | None = None,
attributes: JsonDict | None = None,
) -> None: ...
) -> Metric: ...

@t.overload
def log_metric(
Expand All @@ -544,7 +544,7 @@ def log_metric(
origin: t.Any | None = None,
mode: MetricAggMode | None = None,
prefix: str | None = None,
) -> None: ...
) -> Metric: ...

def log_metric(
self,
Expand All @@ -557,7 +557,7 @@ def log_metric(
mode: MetricAggMode | None = None,
prefix: str | None = None,
attributes: JsonDict | None = None,
) -> None:
) -> Metric:
metric = (
value
if isinstance(value, Metric)
Expand All @@ -583,6 +583,8 @@ def log_metric(
metric = metric.apply_mode(mode, metrics)
metrics.append(metric)

return metric

@property
def outputs(self) -> AnyDict:
return {ref.name: self.get_object(ref.hash) for ref in self._outputs}
Expand Down Expand Up @@ -753,7 +755,7 @@ def log_metric(
timestamp: datetime | None = None,
mode: MetricAggMode | None = None,
attributes: JsonDict | None = None,
) -> None: ...
) -> Metric: ...

@t.overload
def log_metric(
Expand All @@ -763,7 +765,7 @@ def log_metric(
*,
origin: t.Any | None = None,
mode: MetricAggMode | None = None,
) -> None: ...
) -> Metric: ...

def log_metric(
self,
Expand All @@ -775,7 +777,7 @@ def log_metric(
timestamp: datetime | None = None,
mode: MetricAggMode | None = None,
attributes: JsonDict | None = None,
) -> None:
) -> Metric:
metric = (
value
if isinstance(value, Metric)
Expand All @@ -786,25 +788,19 @@ def log_metric(

key = re.sub(r"[^\w/]+", "_", key.lower())

if origin is not None:
origin_hash = self.run.log_object(
origin,
label=key,
event_name=EVENT_NAME_OBJECT_METRIC,
)
metric.attributes[METRIC_ATTRIBUTE_SOURCE_HASH] = origin_hash

metrics = self._metrics.setdefault(key, [])
if mode is not None:
metric = metric.apply_mode(mode, metrics)
metrics.append(metric)

# For every metric we log, also log it to the run
# with our `label` as a prefix.
#
# Don't include `source` and `mode` as we handled it here.
# Let the run handle the origin and mode aggregation
# for us as we don't have access to the other times
# this task-metric was logged here.

if (run := current_run_span.get()) is not None:
run.log_metric(key, metric, prefix=self._label)
metric = run.log_metric(key, metric, prefix=self._label, origin=origin, mode=mode)

self._metrics.setdefault(key, []).append(metric)

return metric

def get_average_metric_value(self, key: str | None = None) -> float:
metrics = (
Expand Down