diff --git a/docs/sdk/main.mdx b/docs/sdk/main.mdx index 6988eb6e..9acafcce 100644 --- a/docs/sdk/main.mdx +++ b/docs/sdk/main.mdx @@ -759,6 +759,405 @@ def log_inputs( ``` + + +### log\_metric + +```python +log_metric( + name: str, + value: float | bool, + *, + step: int = 0, + origin: Any | None = None, + timestamp: datetime | None = None, + mode: MetricAggMode | None = None, + attributes: JsonDict | None = None, + to: ToObject = "task-or-run", +) -> Metric +``` + +```python +log_metric( + name: str, + value: Metric, + *, + origin: Any | None = None, + mode: MetricAggMode | None = None, + to: ToObject = "task-or-run", +) -> Metric +``` + +```python +log_metric( + name: str, + value: float | bool | Metric, + *, + step: int = 0, + origin: Any | None = None, + timestamp: datetime | None = None, + mode: MetricAggMode | None = None, + attributes: JsonDict | None = None, + to: ToObject = "task-or-run", +) -> Metric +``` + +Log a single metric to the current task or run. + +Metrics are some measurement or recorded value related to the task or run. +They can be used to track performance, resource usage, or other quantitative data. + +**Examples:** + +With a raw value: + +```python +with dreadnode.run("my_run") as run: + run.log_metric("accuracy", 0.95, step=10) + run.log_metric("loss", 0.05, step=10, mode="min") +``` + +With a Metric object: + +```python +with dreadnode.run("my_run") as run: + metric = Metric(0.95, step=10, timestamp=datetime.now(timezone.utc)) + run.log_metric("accuracy", metric) +``` + +**Parameters:** + +* **`name`** + (`str`) + –The name of the metric. +* **`value`** + (`float | bool | Metric`) + –The value of the metric, either as a raw float/bool or a Metric object. +* **`step`** + (`int`, default: + `0` + ) + –The step of the metric. +* **`origin`** + (`Any | None`, default: + `None` + ) + –The origin of the metric - can be provided any object which was logged + as an input or output anywhere in the run. +* **`timestamp`** + (`datetime | None`, default: + `None` + ) + –The timestamp of the metric - defaults to the current time. +* **`mode`** + (`MetricAggMode | None`, default: + `None` + ) + –The aggregation mode to use for the metric. Helpful when you want to let + the library take care of translating your raw values into better representations. + - direct: do not modify the value at all (default) + - min: the lowest observed value reported for this metric + - max: the highest observed value reported for this metric + - avg: the average of all reported values for this metric + - sum: the cumulative sum of all reported values for this metric + - count: increment every time this metric is logged - disregard value +* **`attributes`** + (`JsonDict | None`, default: + `None` + ) + –A dictionary of additional attributes to attach to the metric. +* **`to`** + (`ToObject`, default: + `'task-or-run'` + ) + –The target object to log the metric to. Can be "task-or-run" or "run". + Defaults to "task-or-run". If "task-or-run", the metric will be logged + to the current task or run, whichever is the nearest ancestor. + +**Returns:** + +* `Metric` + –The logged metric object. + + +```python +@handle_internal_errors() +def log_metric( + self, + name: str, + value: float | bool | Metric, + *, + step: int = 0, + origin: t.Any | None = None, + timestamp: datetime | None = None, + mode: MetricAggMode | None = None, + attributes: JsonDict | None = None, + to: ToObject = "task-or-run", +) -> Metric: + """ + Log a single metric to the current task or run. + + Metrics are some measurement or recorded value related to the task or run. + They can be used to track performance, resource usage, or other quantitative data. + + Examples: + With a raw value: + ~~~ + with dreadnode.run("my_run") as run: + run.log_metric("accuracy", 0.95, step=10) + run.log_metric("loss", 0.05, step=10, mode="min") + ~~~ + + With a Metric object: + ~~~ + with dreadnode.run("my_run") as run: + metric = Metric(0.95, step=10, timestamp=datetime.now(timezone.utc)) + run.log_metric("accuracy", metric) + ~~~ + + Args: + name: The name of the metric. + value: The value of the metric, either as a raw float/bool or a Metric object. + step: The step of the metric. + origin: The origin of the metric - can be provided any object which was logged + as an input or output anywhere in the run. + timestamp: The timestamp of the metric - defaults to the current time. + mode: The aggregation mode to use for the metric. Helpful when you want to let + the library take care of translating your raw values into better representations. + - direct: do not modify the value at all (default) + - min: the lowest observed value reported for this metric + - max: the highest observed value reported for this metric + - avg: the average of all reported values for this metric + - sum: the cumulative sum of all reported values for this metric + - count: increment every time this metric is logged - disregard value + attributes: A dictionary of additional attributes to attach to the metric. + to: The target object to log the metric to. Can be "task-or-run" or "run". + Defaults to "task-or-run". If "task-or-run", the metric will be logged + to the current task or run, whichever is the nearest ancestor. + + Returns: + The logged metric object. + """ + task = current_task_span.get() + run = current_run_span.get() + + target = (task or run) if to == "task-or-run" else run + if target is None: + raise RuntimeError("log_metric() must be called within a run") + + metric = ( + value + if isinstance(value, Metric) + else Metric( + float(value), step, timestamp or datetime.now(timezone.utc), attributes or {} + ) + ) + return target.log_metric(name, metric, origin=origin, mode=mode) +``` + + + + +### log\_metrics + +```python +log_metrics( + metrics: dict[str, float | bool], + *, + step: int = 0, + timestamp: datetime | None = None, + mode: MetricAggMode | None = None, + attributes: JsonDict | None = None, + to: ToObject = "task-or-run", +) -> list[Metric] +``` + +```python +log_metrics( + metrics: list[MetricDict], + *, + step: int = 0, + timestamp: datetime | None = None, + mode: MetricAggMode | None = None, + attributes: JsonDict | None = None, + to: ToObject = "task-or-run", +) -> list[Metric] +``` + +```python +log_metrics( + metrics: dict[str, float | bool] | list[MetricDict], + *, + step: int = 0, + timestamp: datetime | None = None, + mode: MetricAggMode | None = None, + attributes: JsonDict | None = None, + to: ToObject = "task-or-run", +) -> list[Metric] +``` + +Log multiple metrics to the current task or run. + +**Examples:** + +Log metrics from a dictionary: + +```python +dreadnode.log_metrics( + { + "accuracy": 0.95, + "loss": 0.05, + "f1_score": 0.92 + }, + step=10 +) +``` + +Log metrics from a list of MetricDicts: + +```python +dreadnode.log_metrics( + [ + {"name": "accuracy", "value": 0.95}, + {"name": "loss", "value": 0.05, "mode": "min"} + ], + step=10 +) +``` + +**Parameters:** + +* **`metrics`** + (`dict[str, float | bool] | list[MetricDict]`) + –Either a dictionary of name/value pairs or a list of MetricDicts to log. +* **`step`** + (`int`, default: + `0` + ) + –Default step value for metrics if not supplied. +* **`timestamp`** + (`datetime | None`, default: + `None` + ) + –Default timestamp for metrics if not supplied. +* **`mode`** + (`MetricAggMode | None`, default: + `None` + ) + –Default aggregation mode for metrics if not supplied. +* **`attributes`** + (`JsonDict | None`, default: + `None` + ) + –Default attributes for metrics if not supplied. +* **`to`** + (`ToObject`, default: + `'task-or-run'` + ) + –The target object to log metrics to. Can be "task-or-run" or "run". + Defaults to "task-or-run". If "task-or-run", the metrics will be logged + to the current task or run, whichever is the nearest ancestor. + +**Returns:** + +* `list[Metric]` + –List of logged Metric objects. + + +```python +@handle_internal_errors() +def log_metrics( + self, + metrics: dict[str, float | bool] | list[MetricDict], + *, + step: int = 0, + timestamp: datetime | None = None, + mode: MetricAggMode | None = None, + attributes: JsonDict | None = None, + to: ToObject = "task-or-run", +) -> list[Metric]: + """ + Log multiple metrics to the current task or run. + + Examples: + Log metrics from a dictionary: + ~~~ + dreadnode.log_metrics( + { + "accuracy": 0.95, + "loss": 0.05, + "f1_score": 0.92 + }, + step=10 + ) + ~~~ + + Log metrics from a list of MetricDicts: + ~~~ + dreadnode.log_metrics( + [ + {"name": "accuracy", "value": 0.95}, + {"name": "loss", "value": 0.05, "mode": "min"} + ], + step=10 + ) + ~~~ + + Args: + metrics: Either a dictionary of name/value pairs or a list of MetricDicts to log. + step: Default step value for metrics if not supplied. + timestamp: Default timestamp for metrics if not supplied. + mode: Default aggregation mode for metrics if not supplied. + attributes: Default attributes for metrics if not supplied. + to: The target object to log metrics to. Can be "task-or-run" or "run". + Defaults to "task-or-run". If "task-or-run", the metrics will be logged + to the current task or run, whichever is the nearest ancestor. + + Returns: + List of logged Metric objects. + """ + + task = current_task_span.get() + run = current_run_span.get() + + target = (task or run) if to == "task-or-run" else run + if target is None: + raise RuntimeError("log_metrics() must be called within a run") + + logged_metrics: list[Metric] = [] + + # Dictionary of name/value pairs + if isinstance(metrics, dict): + logged_metrics = [ + target.log_metric( + name, + value, + step=step, + timestamp=timestamp, + mode=mode, + attributes=attributes, + ) + for name, value in metrics.items() + ] + + # List of MetricDicts + else: + logged_metrics = [ + target.log_metric( + metric["name"], + metric["value"], + step=metric.get("step", step), + timestamp=metric.get("timestamp", timestamp), + mode=metric.get("mode", mode), + attributes=metric.get("attributes", attributes) or {}, + ) + for metric in metrics + ] + + return logged_metrics +``` + + ### log\_output diff --git a/docs/sdk/metric.mdx b/docs/sdk/metric.mdx index 33754078..68609cde 100644 --- a/docs/sdk/metric.mdx +++ b/docs/sdk/metric.mdx @@ -6,6 +6,24 @@ title: dreadnode.metric ::: dreadnode.metric */} +MetricsDict +----------- + +```python +MetricsDict = dict[str, list[Metric]] +``` + +A dictionary of metrics, where the key is the metric name and the value is a list of metrics with that name. + +ScorerResult +------------ + +```python +ScorerResult = float | int | bool | Metric +``` + +The result of a scorer function, which can be a numeric value or a Metric object. + Metric ------ @@ -200,6 +218,11 @@ def from_many( +MetricDict +---------- + +Dictionary representation of a metric for easier APIs + Scorer ------ diff --git a/docs/usage/metrics.mdx b/docs/usage/metrics.mdx index c6c0a274..3516c504 100644 --- a/docs/usage/metrics.mdx +++ b/docs/usage/metrics.mdx @@ -58,7 +58,11 @@ with dn.run("my-experiment"): dn.log_metric("loss", 0.19, step=2) dn.log_metric("loss", 0.15, step=3) - dn.log_metric("success", True) + # Log multiple metrics at once + dn.log_metrics({ + "success": True, + "execution_time": 0.45, + }) ``` Metrics can be logged for your run as a whole (run-level) or for individual tasks within a run (task-level). Run-level metrics are generally used to track the broad performance of the system, and task-level metrics monitor more nuanced behaviors inside your flows. To make things easy, any task-level metrics will also be mirrored to the run level using the label (name) of the originating task as a prefix. This means that you can still use the same metric name in different tasks, and they will be reported separately in the UI. @@ -96,6 +100,10 @@ dn.log_metric("processing_time", 1.23, origin=document) The `origin` parameter creates a reference to the object that was measured, allowing you to track which specific inputs led to particular performance outcomes. + +When you associate scorers with tasks, the metrics they generate will automatically include the task's output objects as their origin. This makes it easy to trace back the results of your evaluations to the specific data that was processed. + + ### Aggregation Modes When working with metrics, it's important to provide context—such as averages, sums, or counts. You can always do this manually by keeping separate variables or lists of previous values. But Strikes provides a way to do this automatically for you: @@ -117,11 +125,15 @@ dn.log_metric("total_processed", 10, mode="sum") dn.log_metric("total_processed", 15, mode="sum") # Will be 25 # Count mode: count the number of times a metric is logged -dn.log_metric("failures", 1, mode="count") -dn.log_metric("failures", 1, mode="count") # Will be 2 +dn.log_metric("failures", 3, mode="count") +dn.log_metric("failures", 5, mode="count") # Will be 2 (note the values are ignored) ``` -These modes help create meaningful aggregate metrics without requiring you to manually track previous values. +These modes help create meaningful aggregate metrics without requiring you to manually track previous values and perform calculations like averages or sums. + + +The original values you log are still stored in the metric attributes, so you can always retrieve the raw data if needed. + ## Metrics in Tasks diff --git a/dreadnode/__init__.py b/dreadnode/__init__.py index f21f81c6..8ea2178d 100644 --- a/dreadnode/__init__.py +++ b/dreadnode/__init__.py @@ -15,18 +15,18 @@ task_span = DEFAULT_INSTANCE.task_span run = DEFAULT_INSTANCE.run scorer = DEFAULT_INSTANCE.scorer -task_span = DEFAULT_INSTANCE.task_span push_update = DEFAULT_INSTANCE.push_update tag = DEFAULT_INSTANCE.tag get_run_context = DEFAULT_INSTANCE.get_run_context continue_run = DEFAULT_INSTANCE.continue_run - log_metric = DEFAULT_INSTANCE.log_metric +log_metrics = DEFAULT_INSTANCE.log_metrics log_param = DEFAULT_INSTANCE.log_param log_params = DEFAULT_INSTANCE.log_params log_input = DEFAULT_INSTANCE.log_input log_inputs = DEFAULT_INSTANCE.log_inputs log_output = DEFAULT_INSTANCE.log_output +log_outputs = DEFAULT_INSTANCE.log_outputs link_objects = DEFAULT_INSTANCE.link_objects log_artifact = DEFAULT_INSTANCE.log_artifact diff --git a/dreadnode/main.py b/dreadnode/main.py index 38f54228..0bea356b 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -32,7 +32,7 @@ ENV_SERVER, ENV_SERVER_URL, ) -from dreadnode.metric import Metric, MetricAggMode, Scorer, ScorerCallable, T +from dreadnode.metric import Metric, MetricAggMode, MetricDict, Scorer, ScorerCallable, T from dreadnode.task import P, R, Task from dreadnode.tracing.exporters import ( FileExportConfig, @@ -846,7 +846,7 @@ def log_params(self, to: ToObject = "run", **params: JsonValue) -> None: @t.overload def log_metric( self, - key: str, + name: str, value: float | bool, *, step: int = 0, @@ -869,7 +869,7 @@ def log_metric( ``` Args: - key: The name of the metric. + name: The name of the metric. value: The value of the metric. step: The step of the metric. origin: The origin of the metric - can be provided any object which was logged @@ -895,7 +895,7 @@ def log_metric( @t.overload def log_metric( self, - key: str, + name: str, value: Metric, *, origin: t.Any | None = None, @@ -915,7 +915,7 @@ def log_metric( ``` Args: - key: The name of the metric. + name: The name of the metric. value: The metric object. origin: The origin of the metric - can be provided any object which was logged as an input or output anywhere in the run. @@ -937,7 +937,7 @@ def log_metric( @handle_internal_errors() def log_metric( self, - key: str, + name: str, value: float | bool | Metric, *, step: int = 0, @@ -947,6 +947,50 @@ def log_metric( attributes: JsonDict | None = None, to: ToObject = "task-or-run", ) -> Metric: + """ + Log a single metric to the current task or run. + + Metrics are some measurement or recorded value related to the task or run. + They can be used to track performance, resource usage, or other quantitative data. + + Examples: + With a raw value: + ``` + with dreadnode.run("my_run") as run: + run.log_metric("accuracy", 0.95, step=10) + run.log_metric("loss", 0.05, step=10, mode="min") + ``` + + With a Metric object: + ``` + with dreadnode.run("my_run") as run: + metric = Metric(0.95, step=10, timestamp=datetime.now(timezone.utc)) + run.log_metric("accuracy", metric) + ``` + + Args: + name: The name of the metric. + value: The value of the metric, either as a raw float/bool or a Metric object. + step: The step of the metric. + origin: The origin of the metric - can be provided any object which was logged + as an input or output anywhere in the run. + timestamp: The timestamp of the metric - defaults to the current time. + mode: The aggregation mode to use for the metric. Helpful when you want to let + the library take care of translating your raw values into better representations. + - direct: do not modify the value at all (default) + - min: the lowest observed value reported for this metric + - max: the highest observed value reported for this metric + - avg: the average of all reported values for this metric + - sum: the cumulative sum of all reported values for this metric + - count: increment every time this metric is logged - disregard value + attributes: A dictionary of additional attributes to attach to the metric. + to: The target object to log the metric to. Can be "task-or-run" or "run". + Defaults to "task-or-run". If "task-or-run", the metric will be logged + to the current task or run, whichever is the nearest ancestor. + + Returns: + The logged metric object. + """ task = current_task_span.get() run = current_run_span.get() @@ -961,7 +1005,177 @@ def log_metric( float(value), step, timestamp or datetime.now(timezone.utc), attributes or {} ) ) - return target.log_metric(key, metric, origin=origin, mode=mode) + return target.log_metric(name, metric, origin=origin, mode=mode) + + @t.overload + def log_metrics( + self, + metrics: dict[str, float | bool], + *, + step: int = 0, + timestamp: datetime | None = None, + mode: MetricAggMode | None = None, + attributes: JsonDict | None = None, + to: ToObject = "task-or-run", + ) -> list[Metric]: + """ + Log multiple metrics from a dictionary of name/value pairs. + + Examples: + ``` + dreadnode.log_metrics( + { + "accuracy": 0.95, + "loss": 0.05, + "f1_score": 0.92 + }, + step=10 + ) + ``` + + Args: + metrics: Dictionary of name/value pairs to log as metrics. + step: Step value for all metrics. + timestamp: Timestamp for all metrics. + mode: Aggregation mode for all metrics. + attributes: Attributes for all metrics. + to: The target object to log metrics to. Can be "task-or-run" or "run". + Defaults to "task-or-run". If "task-or-run", the metrics will be logged + to the current task or run, whichever is the nearest ancestor. + + Returns: + List of logged Metric objects. + """ + + @t.overload + def log_metrics( + self, + metrics: list[MetricDict], + *, + step: int = 0, + timestamp: datetime | None = None, + mode: MetricAggMode | None = None, + attributes: JsonDict | None = None, + to: ToObject = "task-or-run", + ) -> list[Metric]: + """ + Log multiple metrics from a list of metric configurations. + + Example: + ``` + dreadnode.log_metrics( + [ + {"name": "accuracy", "value": 0.95}, + {"name": "loss", "value": 0.05, "mode": "min"} + ], + step=10 + ) + ``` + + Args: + metrics: List of metric configurations to log. + step: Default step value for metrics if not supplied. + timestamp: Default timestamp for metrics if not supplied. + mode: Default aggregation mode for metrics if not supplied. + attributes: Default attributes for metrics if not supplied. + to: The target object to log metrics to. Can be "task-or-run" or "run". + Defaults to "task-or-run". If "task-or-run", the metrics will be logged + to the current task or run, whichever is the nearest ancestor. + + Returns: + List of logged Metric objects. + """ + + @handle_internal_errors() + def log_metrics( + self, + metrics: dict[str, float | bool] | list[MetricDict], + *, + step: int = 0, + timestamp: datetime | None = None, + mode: MetricAggMode | None = None, + attributes: JsonDict | None = None, + to: ToObject = "task-or-run", + ) -> list[Metric]: + """ + Log multiple metrics to the current task or run. + + Examples: + Log metrics from a dictionary: + ``` + dreadnode.log_metrics( + { + "accuracy": 0.95, + "loss": 0.05, + "f1_score": 0.92 + }, + step=10 + ) + ``` + + Log metrics from a list of MetricDicts: + ``` + dreadnode.log_metrics( + [ + {"name": "accuracy", "value": 0.95}, + {"name": "loss", "value": 0.05, "mode": "min"} + ], + step=10 + ) + ``` + + Args: + metrics: Either a dictionary of name/value pairs or a list of MetricDicts to log. + step: Default step value for metrics if not supplied. + timestamp: Default timestamp for metrics if not supplied. + mode: Default aggregation mode for metrics if not supplied. + attributes: Default attributes for metrics if not supplied. + to: The target object to log metrics to. Can be "task-or-run" or "run". + Defaults to "task-or-run". If "task-or-run", the metrics will be logged + to the current task or run, whichever is the nearest ancestor. + + Returns: + List of logged Metric objects. + """ + + task = current_task_span.get() + run = current_run_span.get() + + target = (task or run) if to == "task-or-run" else run + if target is None: + raise RuntimeError("log_metrics() must be called within a run") + + logged_metrics: list[Metric] = [] + + # Dictionary of name/value pairs + if isinstance(metrics, dict): + logged_metrics = [ + target.log_metric( + name, + value, + step=step, + timestamp=timestamp, + mode=mode, + attributes=attributes, + ) + for name, value in metrics.items() + ] + + # List of MetricDicts + else: + logged_metrics = [ + target.log_metric( + metric["name"], + metric["value"], + step=metric.get("step", step), + timestamp=metric.get("timestamp", timestamp), + mode=metric.get("mode", mode), + attributes=metric.get("attributes", attributes) or {}, + ) + for metric in metrics + ] + + return logged_metrics @handle_internal_errors() def log_artifact( diff --git a/dreadnode/metric.py b/dreadnode/metric.py index 148d8962..244191cc 100644 --- a/dreadnode/metric.py +++ b/dreadnode/metric.py @@ -3,6 +3,7 @@ from dataclasses import dataclass, field from datetime import datetime, timezone +import typing_extensions as te from logfire._internal.stack_info import warn_at_user_stacklevel from logfire._internal.utils import safe_repr from opentelemetry.trace import Tracer @@ -18,6 +19,18 @@ class MetricWarning(UserWarning): pass +class MetricDict(te.TypedDict, total=False): + """Dictionary representation of a metric for easier APIs""" + + name: str + value: float | bool + step: int + timestamp: datetime | None + mode: MetricAggMode | None + attributes: JsonDict | None + origin: t.Any | None + + @dataclass class Metric: """ @@ -102,9 +115,10 @@ def apply_mode(self, mode: MetricAggMode, others: "list[Metric]") -> "Metric": return self -MetricDict = dict[str, list[Metric]] - +MetricsDict = dict[str, list[Metric]] +"""A dictionary of metrics, where the key is the metric name and the value is a list of metrics with that name.""" ScorerResult = float | int | bool | Metric +"""The result of a scorer function, which can be a numeric value or a Metric object.""" ScorerCallable = t.Callable[[T], t.Awaitable[ScorerResult]] | t.Callable[[T], ScorerResult] diff --git a/dreadnode/tracing/span.py b/dreadnode/tracing/span.py index 47449d7f..3c6fe1e5 100644 --- a/dreadnode/tracing/span.py +++ b/dreadnode/tracing/span.py @@ -30,7 +30,7 @@ from dreadnode.artifact.storage import ArtifactStorage from dreadnode.artifact.tree_builder import ArtifactTreeBuilder, DirectoryNode from dreadnode.constants import MAX_INLINE_OBJECT_BYTES -from dreadnode.metric import Metric, MetricAggMode, MetricDict +from dreadnode.metric import Metric, MetricAggMode, MetricsDict from dreadnode.object import Object, ObjectRef, ObjectUri, ObjectVal from dreadnode.serialization import Serialized, serialize from dreadnode.types import UNSET, AnyDict, JsonDict, JsonValue, Unset @@ -242,7 +242,7 @@ def __init__( tracer: Tracer, project: str, *, - metrics: MetricDict | None = None, + metrics: MetricsDict | None = None, params: JsonDict | None = None, inputs: list[ObjectRef] | None = None, outputs: list[ObjectRef] | None = None, @@ -283,7 +283,7 @@ def __init__( prefix_path: str, *, params: AnyDict | None = None, - metrics: MetricDict | None = None, + metrics: MetricsDict | None = None, tags: t.Sequence[str] | None = None, autolog: bool = True, update_frequency: int = 5, @@ -625,13 +625,13 @@ def log_artifact( self._artifacts = self._artifact_merger.get_merged_trees() @property - def metrics(self) -> MetricDict: + def metrics(self) -> MetricsDict: return self._metrics @t.overload def log_metric( self, - key: str, + name: str, value: float | bool, *, step: int = 0, @@ -645,7 +645,7 @@ def log_metric( @t.overload def log_metric( self, - key: str, + name: str, value: Metric, *, origin: t.Any | None = None, @@ -655,7 +655,7 @@ def log_metric( def log_metric( self, - key: str, + name: str, value: float | bool | Metric, *, step: int = 0, @@ -673,7 +673,7 @@ def log_metric( ) ) - key = clean_str(key) + key = clean_str(name) if prefix is not None: key = f"{prefix}.{key}" @@ -726,7 +726,7 @@ def __init__( *, label: str | None = None, params: AnyDict | None = None, - metrics: MetricDict | None = None, + metrics: MetricsDict | None = None, tags: t.Sequence[str] | None = None, ) -> None: self._params = params or {} @@ -857,7 +857,7 @@ def metrics(self) -> dict[str, list[Metric]]: @t.overload def log_metric( self, - key: str, + name: str, value: float | bool, *, step: int = 0, @@ -870,7 +870,7 @@ def log_metric( @t.overload def log_metric( self, - key: str, + name: str, value: Metric, *, origin: t.Any | None = None, @@ -879,7 +879,7 @@ def log_metric( def log_metric( self, - key: str, + name: str, value: float | bool | Metric, *, step: int = 0, @@ -896,7 +896,7 @@ def log_metric( ) ) - key = clean_str(key) + key = clean_str(name) # For every metric we log, also log it to the run # with our `label` as a prefix.