Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 52 additions & 16 deletions docs/sdk/main.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,11 @@ def link_objects(
attributes: Additional attributes to attach to the link.
"""
if (run := current_run_span.get()) is None:
raise RuntimeError("link() must be called within a run")
warn_at_user_stacklevel(
"link_objects() was called outside of a run.",
category=DreadnodeUsageWarning,
)
return

origin_hash = run.log_object(origin)
link_hash = run.log_object(link)
Expand Down Expand Up @@ -652,7 +656,11 @@ def log_artifact(
local_uri: The local path to the file to upload.
"""
if (run := current_run_span.get()) is None:
raise RuntimeError("log_artifact() must be called within a run")
warn_at_user_stacklevel(
"log_artifact() was called outside of a run.",
category=DreadnodeUsageWarning,
)
return

run.log_artifact(local_uri=local_uri)
```
Expand Down Expand Up @@ -729,7 +737,11 @@ def log_input(

target = (task or run) if to == "task-or-run" else run
if target is None:
raise RuntimeError("log_inputs() must be called within a run")
warn_at_user_stacklevel(
"log_input() was called outside of a task or run.",
category=DreadnodeUsageWarning,
)
return

target.log_input(name, value, label=label, attributes=attributes)
```
Expand Down Expand Up @@ -946,13 +958,6 @@ def log_metric(
Returns:
The logged metric object.
"""
task = current_task_span.get()
run = current_run_span.get()

target = (task or run) if to == "task-or-run" else run
if target is None:
raise RuntimeError("log_metric() must be called within a run")

metric = (
value
if isinstance(value, Metric)
Expand All @@ -963,6 +968,18 @@ def log_metric(
attributes or {},
)
)

task = current_task_span.get()
run = current_run_span.get()

target = (task or run) if to == "task-or-run" else run
if target is None:
warn_at_user_stacklevel(
"log_metric() was called outside of a task or run.",
category=DreadnodeUsageWarning,
)
return metric

return target.log_metric(name, metric, origin=origin, mode=mode)
```

Expand Down Expand Up @@ -1133,7 +1150,11 @@ def log_metrics(

target = (task or run) if to == "task-or-run" else run
if target is None:
raise RuntimeError("log_metrics() must be called within a run")
warn_at_user_stacklevel(
"log_metrics() was called outside of a task or run.",
category=DreadnodeUsageWarning,
)
return []

logged_metrics: list[Metric] = []

Expand Down Expand Up @@ -1276,9 +1297,11 @@ def log_output(

target = (task or run) if to == "task-or-run" else run
if target is None:
raise RuntimeError(
"log_output() must be called within a run or a task",
warn_at_user_stacklevel(
"log_output() was called outside of a task or run.",
category=DreadnodeUsageWarning,
)
return

target.log_output(name, value, label=label, attributes=attributes)
```
Expand Down Expand Up @@ -1431,7 +1454,12 @@ def log_params(self, **params: JsonValue) -> None:
**params: The parameters to log. Each parameter is a key-value pair.
"""
if (run := current_run_span.get()) is None:
raise RuntimeError("Parameters must be logged within a run")
warn_at_user_stacklevel(
"log_params() was called outside of a run.",
category=DreadnodeUsageWarning,
)
return

run.log_params(**params)
```

Expand Down Expand Up @@ -1484,7 +1512,11 @@ def push_update(self) -> None:
# do more work
"""
if (run := current_run_span.get()) is None:
raise RuntimeError("Run updates must be pushed within a run")
warn_at_user_stacklevel(
"push_update() was called outside of a run.",
category=DreadnodeUsageWarning,
)
return

run.push_update(force=True)
```
Expand Down Expand Up @@ -1910,7 +1942,11 @@ def tag(self, *tag: str, to: ToObject = "task-or-run") -> None:

target = (task or run) if to == "task-or-run" else run
if target is None:
raise RuntimeError("Tagging must be done within a run")
warn_at_user_stacklevel(
"tag() was called outside of a task or run.",
category=DreadnodeUsageWarning,
)
return

target.add_tags(tag)
```
Expand Down
68 changes: 52 additions & 16 deletions dreadnode/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,11 @@ def tag(self, *tag: str, to: ToObject = "task-or-run") -> None:

target = (task or run) if to == "task-or-run" else run
if target is None:
raise RuntimeError("Tagging must be done within a run")
warn_at_user_stacklevel(
"tag() was called outside of a task or run.",
category=DreadnodeUsageWarning,
)
return

target.add_tags(tag)

Expand All @@ -883,7 +887,11 @@ def push_update(self) -> None:
# do more work
"""
if (run := current_run_span.get()) is None:
raise RuntimeError("Run updates must be pushed within a run")
warn_at_user_stacklevel(
"push_update() was called outside of a run.",
category=DreadnodeUsageWarning,
)
return

run.push_update(force=True)

Expand Down Expand Up @@ -934,7 +942,12 @@ def log_params(self, **params: JsonValue) -> None:
**params: The parameters to log. Each parameter is a key-value pair.
"""
if (run := current_run_span.get()) is None:
raise RuntimeError("Parameters must be logged within a run")
warn_at_user_stacklevel(
"log_params() was called outside of a run.",
category=DreadnodeUsageWarning,
)
return

run.log_params(**params)

@t.overload
Expand Down Expand Up @@ -1085,13 +1098,6 @@ def log_metric(
Returns:
The logged metric object.
"""
task = current_task_span.get()
run = current_run_span.get()

target = (task or run) if to == "task-or-run" else run
if target is None:
raise RuntimeError("log_metric() must be called within a run")

metric = (
value
if isinstance(value, Metric)
Expand All @@ -1102,6 +1108,18 @@ def log_metric(
attributes or {},
)
)

task = current_task_span.get()
run = current_run_span.get()

target = (task or run) if to == "task-or-run" else run
if target is None:
warn_at_user_stacklevel(
"log_metric() was called outside of a task or run.",
category=DreadnodeUsageWarning,
)
return metric
Comment thread
monoxgas marked this conversation as resolved.

return target.log_metric(name, metric, origin=origin, mode=mode)

@t.overload
Expand Down Expand Up @@ -1240,7 +1258,11 @@ def log_metrics(

target = (task or run) if to == "task-or-run" else run
if target is None:
raise RuntimeError("log_metrics() must be called within a run")
warn_at_user_stacklevel(
"log_metrics() was called outside of a task or run.",
category=DreadnodeUsageWarning,
)
return []

logged_metrics: list[Metric] = []

Expand Down Expand Up @@ -1312,7 +1334,11 @@ def log_artifact(
local_uri: The local path to the file to upload.
"""
if (run := current_run_span.get()) is None:
raise RuntimeError("log_artifact() must be called within a run")
warn_at_user_stacklevel(
"log_artifact() was called outside of a run.",
category=DreadnodeUsageWarning,
)
return

run.log_artifact(local_uri=local_uri)

Expand Down Expand Up @@ -1350,7 +1376,11 @@ async def my_task(x: int) -> int:

target = (task or run) if to == "task-or-run" else run
if target is None:
raise RuntimeError("log_inputs() must be called within a run")
warn_at_user_stacklevel(
"log_input() was called outside of a task or run.",
category=DreadnodeUsageWarning,
)
return

target.log_input(name, value, label=label, attributes=attributes)

Expand Down Expand Up @@ -1412,9 +1442,11 @@ async def my_task(x: int) -> int:

target = (task or run) if to == "task-or-run" else run
if target is None:
raise RuntimeError(
"log_output() must be called within a run or a task",
warn_at_user_stacklevel(
"log_output() was called outside of a task or run.",
category=DreadnodeUsageWarning,
)
return

target.log_output(name, value, label=label, attributes=attributes)

Expand Down Expand Up @@ -1461,7 +1493,11 @@ def link_objects(
attributes: Additional attributes to attach to the link.
"""
if (run := current_run_span.get()) is None:
raise RuntimeError("link() must be called within a run")
warn_at_user_stacklevel(
"link_objects() was called outside of a run.",
category=DreadnodeUsageWarning,
)
return

origin_hash = run.log_object(origin)
link_hash = run.log_object(link)
Expand Down
4 changes: 1 addition & 3 deletions dreadnode/tracing/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,14 +879,12 @@ def __init__(

def __enter__(self) -> te.Self:
self._run = current_run_span.get()
if self._run is None:
raise RuntimeError("You cannot start a task span without a run")

self._parent_task = current_task_span.get()
if self._parent_task is not None:
self.set_attribute(SPAN_ATTRIBUTE_PARENT_TASK_ID, self._parent_task.span_id)
self._parent_task._tasks.append(self) # noqa: SLF001
else:
elif self._run:
self._run._tasks.append(self) # noqa: SLF001

self._context_token = current_task_span.set(self)
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
[project]
name = "dreadnode"
version = "1.11.0"
version = "1.11.1"
description = "Dreadnode SDK"
requires-python = ">=3.10,<3.14"

[tool.poetry]
name = "dreadnode"
version = "1.11.0"
version = "1.11.1"
description = "Dreadnode SDK"
authors = ["Nick Landers <monoxgas@gmail.com>"]
repository = "https://github.com/dreadnode/sdk"
Expand Down
Loading