diff --git a/docs/api/chat.mdx b/docs/api/chat.mdx
index 3d1469f..e0bbab8 100644
--- a/docs/api/chat.mdx
+++ b/docs/api/chat.mdx
@@ -10,17 +10,6 @@ Chats are used pre and post generation to hold messages.
They are the primary way to interact with the generator.
-CacheMode
----------
-
-```python
-CacheMode = Literal['latest']
-```
-
-How to handle cache\_control entries on messages.
-
-* latest: Assign cache\_control to the latest 2 non-assistant messages in the pipeline before inference.
-
DEFAULT\_MAX\_DEPTH
-------------------
@@ -1274,14 +1263,18 @@ def __init__(
"""How to handle failures in the pipeline unless overridden in calls."""
self.caching: CacheMode | None = None
"""How to handle cache_control entries on messages."""
+ self.task_name: str = generator.to_identifier(short=True)
+ """The name of the pipeline task, used for logging and debugging."""
+ self.scorers: list[dn.Scorer[Chat]] = []
+ """List of dreadnode scorers to evaluate the generated chat upon completion."""
self.until_types: list[type[Model]] = []
self.tools: list[Tool[..., t.Any]] = []
self.tool_mode: ToolMode = "auto"
self.inject_tool_prompt = True
self.add_tool_stop_token = True
- self.then_callbacks: list[tuple[ThenChatCallback, int]] = []
- self.map_callbacks: list[tuple[MapChatCallback, int]] = []
+ self.then_callbacks: list[tuple[ThenChatCallback, int, bool]] = []
+ self.map_callbacks: list[tuple[MapChatCallback, int, bool]] = []
self.watch_callbacks: list[WatchChatCallback] = watch_callbacks or []
self.transforms: list[Transform] = []
```
@@ -1356,6 +1349,22 @@ params = params
The parameters for generating messages.
+### scorers
+
+```python
+scorers: list[Scorer[Chat]] = []
+```
+
+List of dreadnode scorers to evaluate the generated chat upon completion.
+
+### task\_name
+
+```python
+task_name: str = to_identifier(short=True)
+```
+
+The name of the pipeline task, used for logging and debugging.
+
### add
```python
@@ -1725,6 +1734,9 @@ def clone(
new.errors_to_catch = self.errors_to_catch.copy()
new.errors_to_exclude = self.errors_to_exclude.copy()
new.caching = self.caching
+ new.task_name = self.task_name
+ new.scorers = self.scorers.copy()
+ new.transforms = self.transforms.copy()
new.watch_callbacks = self.watch_callbacks.copy()
@@ -1736,18 +1748,18 @@ def clone(
return new
new.then_callbacks = [
- (callback, max_depth)
+ (callback, max_depth, as_task)
if not hasattr(callback, "__self__")
or not isinstance(callback.__self__, ChatPipeline)
- else (types.MethodType(callback.__func__, new), max_depth) # type: ignore [union-attr]
- for callback, max_depth in self.then_callbacks.copy()
+ else (types.MethodType(callback.__func__, new), max_depth, as_task) # type: ignore [union-attr]
+ for callback, max_depth, as_task in self.then_callbacks.copy()
]
new.map_callbacks = [
- (callback, max_depth)
+ (callback, max_depth, as_task)
if not hasattr(callback, "__self__")
or not isinstance(callback.__self__, ChatPipeline)
- else (types.MethodType(callback.__func__, new), max_depth) # type: ignore [union-attr]
- for callback, max_depth in self.map_callbacks.copy()
+ else (types.MethodType(callback.__func__, new), max_depth, as_task) # type: ignore [union-attr]
+ for callback, max_depth, as_task in self.map_callbacks.copy()
]
new.transforms = [
callback
@@ -1759,13 +1771,13 @@ def clone(
if not isinstance(callbacks, bool):
new.then_callbacks = [
- (callback, max_depth)
- for callback, max_depth in self.then_callbacks
+ (callback, max_depth, as_task)
+ for callback, max_depth, as_task in self.then_callbacks
if callback in callbacks
]
new.map_callbacks = [
- (callback, max_depth)
- for callback, max_depth in self.map_callbacks
+ (callback, max_depth, as_task)
+ for callback, max_depth, as_task in self.map_callbacks
if callback in callbacks
]
new.transforms = [callback for callback in self.transforms if callback in callbacks]
@@ -1833,6 +1845,7 @@ map(
*callbacks: MapChatCallback,
max_depth: int = DEFAULT_MAX_DEPTH,
allow_duplicates: bool = False,
+ as_task: bool = True,
) -> ChatPipeline
```
@@ -1861,6 +1874,11 @@ the final return value from the pipeline.
`False`
)
–Whether to allow (seemingly) duplicate callbacks to be added.
+* **`as_task`**
+ (`bool`, default:
+ `True`
+ )
+ –Whether to create a task for this callback.
**Returns:**
@@ -1884,6 +1902,7 @@ def map(
*callbacks: MapChatCallback,
max_depth: int = DEFAULT_MAX_DEPTH,
allow_duplicates: bool = False,
+ as_task: bool = True,
) -> "ChatPipeline":
"""
Registers a callback to be executed after the generation process completes.
@@ -1897,6 +1916,7 @@ def map(
callbacks: The callback function to be executed.
max_depth: The maximum depth to allow recursive pipeline calls during this callback.
allow_duplicates: Whether to allow (seemingly) duplicate callbacks to be added.
+ as_task: Whether to create a task for this callback.
Returns:
The updated pipeline.
@@ -1918,7 +1938,7 @@ def map(
f"Callback '{get_qualified_name(callback)}' is already registered.",
)
- self.map_callbacks.extend([(callback, max_depth) for callback in callbacks])
+ self.map_callbacks.extend([(callback, max_depth, as_task) for callback in callbacks])
return self
```
@@ -1963,6 +1983,44 @@ def meta(self, **kwargs: t.Any) -> "ChatPipeline":
```
+
+
+### name
+
+```python
+name(name: str) -> ChatPipeline
+```
+
+Sets the name of the pipeline.
+
+**Parameters:**
+
+* **`name`**
+ (`str`)
+ –The name to set for the pipeline.
+
+**Returns:**
+
+* `ChatPipeline`
+ –The updated pipeline.
+
+
+```python
+def name(self, name: str) -> "ChatPipeline":
+ """
+ Sets the name of the pipeline.
+
+ Args:
+ name: The name to set for the pipeline.
+
+ Returns:
+ The updated pipeline.
+ """
+ self.task_name = name
+ return self
+```
+
+
### prompt
@@ -2015,6 +2073,7 @@ def prompt(self, func: t.Callable[P, t.Coroutine[None, None, R]]) -> "Prompt[P,
```python
run(
*,
+ name: str | None = None,
on_failed: FailMode | None = None,
allow_failed: bool = False,
) -> Chat
@@ -2024,6 +2083,11 @@ Execute the generation process for a single message.
**Parameters:**
+* **`name`**
+ (`str | None`, default:
+ `None`
+ )
+ –The name of the task for logging purposes.
* **`on_failed`**
(`FailMode | None`, default:
`None`
@@ -2045,6 +2109,7 @@ Execute the generation process for a single message.
async def run(
self,
*,
+ name: str | None = None,
on_failed: FailMode | None = None,
allow_failed: bool = False,
) -> Chat:
@@ -2052,6 +2117,7 @@ async def run(
Execute the generation process for a single message.
Args:
+ name: The name of the task for logging purposes.
on_failed: The behavior when a message fails to generate.
allow_failed: Deprecated, use `on_failed="include"`.
@@ -2068,16 +2134,44 @@ async def run(
if on_failed is None:
on_failed = "include" if allow_failed else self.on_failed
- last: PipelineStep | None = None
- async with self.step(on_failed=on_failed) as steps:
- async for step in steps:
- last = step
+ if on_failed == "skip":
+ raise ValueError(
+ "Cannot use 'skip' mode with single message generation (pass allow_failed=True or on_failed='include'/'raise')",
+ )
- if last is None or last.state != "final":
- raise RuntimeError("The pipeline did not complete successfully")
+ messages = [self.chat.all]
+ params = self._fit_params(1, [self.params])
- if not last.chats:
- raise RuntimeError("The pipeline process did not produce any chats")
+ last: PipelineStep | None = None
+ with dn.task_span(
+ name or f"pipeline - {self.task_name}",
+ label=name or f"pipeline_{self.task_name}",
+ attributes={"rigging.type": "chat_pipeline.run"},
+ ) as task:
+ dn.log_inputs(
+ messages=messages[0],
+ params=params[0],
+ generator_id=self.generator.to_identifier(),
+ )
+
+ try:
+ async with aclosing(
+ self._step(messages, params, on_failed),
+ ) as steps:
+ async for step in steps:
+ last = step
+ finally:
+ if last is not None and last.chats:
+ dn.log_output("chat", last.chats[-1])
+ await self._score_chats(last.chats)
+ # TODO: Remove once Strikes UI is ported
+ task.set_attribute("chats", last.chats)
+
+ if last is None or last.state != "final":
+ raise RuntimeError("The pipeline did not complete successfully")
+
+ if not last.chats:
+ raise RuntimeError("The pipeline process did not produce any chats")
return last.chats[-1]
```
@@ -2097,6 +2191,8 @@ run_batch(
| str,
params: Sequence[GenerateParams | None] | None = None,
*,
+ name: str | None = None,
+ mode: Literal["merged", "parallel"] = "parallel",
on_failed: FailMode | None = None,
) -> ChatList
```
@@ -2117,6 +2213,18 @@ Anything already in this chat pipeline will be prepended to the input messages.
`None`
)
–A sequence of parameters to be used for each set of messages.
+* **`name`**
+ (`str | None`, default:
+ `None`
+ )
+ –The name of the task for logging purposes.
+* **`mode`**
+ (`Literal['merged', 'parallel']`, default:
+ `'parallel'`
+ )
+ –The mode of execution, either "merged" or "parallel".
+ - In "merged" mode, a single pipeline manages all generation simultaneously
+ - In "parallel" mode, independent pipelines are created for each generation
* **`on_failed`**
(`FailMode | None`, default:
`None`
@@ -2140,6 +2248,8 @@ async def run_batch(
| str,
params: t.Sequence[GenerateParams | None] | None = None,
*,
+ name: str | None = None,
+ mode: t.Literal["merged", "parallel"] = "parallel",
on_failed: FailMode | None = None,
) -> ChatList:
"""
@@ -2151,21 +2261,76 @@ async def run_batch(
Args:
many: A sequence of sequences of messages to be generated.
params: A sequence of parameters to be used for each set of messages.
+ name: The name of the task for logging purposes.
+ mode: The mode of execution, either "merged" or "parallel".
+ - In "merged" mode, a single pipeline manages all generation simultaneously
+ - In "parallel" mode, independent pipelines are created for each generation
on_failed: The behavior when a message fails to generate.
Returns:
A list of generatated Chats.
"""
+ on_failed = on_failed or self.on_failed
+ count, messages, params = self._fit_batch_args(many, params)
last: PipelineStep | None = None
- async with self.step_batch(many, params=params, on_failed=on_failed) as steps:
- async for step in steps:
- last = step
+ with dn.task_span(
+ name or f"pipeline - {self.task_name} (batch x{count})",
+ label=name or f"pipeline_batch_{self.task_name}",
+ attributes={"rigging.type": "chat_pipeline.run_batch"},
+ ) as task:
+ dn.log_inputs(
+ count=count,
+ messages=messages,
+ params=params,
+ generator_id=self.generator.to_identifier(),
+ )
+
+ if mode == "merged":
+ try:
+ async with aclosing(
+ self._step(messages, params, on_failed),
+ ) as steps:
+ async for step in steps:
+ last = step
+ finally:
+ if last is not None:
+ dn.log_output("chats", last.chats)
+ await self._score_chats(last.chats)
+ # TODO: Remove once Strikes UI is ported
+ task.set_attribute("chats", last.chats)
+
+ if last is None or last.state != "final":
+ raise RuntimeError("The pipeline did not complete successfully")
+
+ return last.chats
+
+ if mode == "parallel":
+ tasks = [
+ asyncio.create_task(
+ self.clone().add(_messages).with_(_params).run(on_failed="include")
+ )
+ for _messages, _params in zip(messages, params, strict=True)
+ ]
+ chats_or_errors = await asyncio.gather(*tasks, return_exceptions=True)
+
+ self._raise_if_failed(chats_or_errors, on_failed)
- if last is None or last.state != "final":
- raise ValueError("The generation process did not complete successfully")
+ chats = [
+ chat
+ for chat in chats_or_errors
+ if isinstance(chat, Chat) and (on_failed != "skip" or not chat.failed)
+ ]
+
+ dn.log_output("chats", chats)
+ # TODO: Remove once Strikes UI is ported
+ task.set_attribute("chats", chats)
+
+ return ChatList(chats)
- return last.chats
+ raise ValueError(
+ f"Invalid mode '{mode}', expected 'merged' or 'parallel'",
+ )
```
@@ -2178,6 +2343,8 @@ run_many(
count: int,
*,
params: Sequence[GenerateParams | None] | None = None,
+ name: str | None = None,
+ mode: Literal["merged", "parallel"] = "parallel",
on_failed: FailMode | None = None,
) -> ChatList
```
@@ -2194,6 +2361,18 @@ Executes the generation process in parallel over the same input.
`None`
)
–A sequence of parameters to be used for each execution.
+* **`name`**
+ (`str | None`, default:
+ `None`
+ )
+ –The name of the task for logging purposes.
+* **`mode`**
+ (`Literal['merged', 'parallel']`, default:
+ `'parallel'`
+ )
+ –The mode of execution, either "merged" or "parallel".
+ - In "merged" mode, a single pipeline manages all generation simultaneously
+ - In "parallel" mode, independent pipelines are created for each generation
* **`on_failed`**
(`FailMode | None`, default:
`None`
@@ -2203,7 +2382,7 @@ Executes the generation process in parallel over the same input.
**Returns:**
* `ChatList`
- –A list of generatated Chats.
+ –A list of generated Chats.
```python
@@ -2212,6 +2391,8 @@ async def run_many(
count: int,
*,
params: t.Sequence[GenerateParams | None] | None = None,
+ name: str | None = None,
+ mode: t.Literal["merged", "parallel"] = "parallel",
on_failed: FailMode | None = None,
) -> ChatList:
"""
@@ -2220,21 +2401,76 @@ async def run_many(
Args:
count: The number of times to execute the generation process.
params: A sequence of parameters to be used for each execution.
+ name: The name of the task for logging purposes.
+ mode: The mode of execution, either "merged" or "parallel".
+ - In "merged" mode, a single pipeline manages all generation simultaneously
+ - In "parallel" mode, independent pipelines are created for each generation
on_failed: The behavior when a message fails to generate.
Returns:
- A list of generatated Chats.
+ A list of generated Chats.
"""
+ if count < 1:
+ raise ValueError("Count must be greater than 0")
+
+ on_failed = on_failed or self.on_failed
+
+ messages = [self.chat.all] * count
+ params = self._fit_params(count, params)
last: PipelineStep | None = None
- async with self.step_many(count, params=params, on_failed=on_failed) as steps:
- async for step in steps:
- last = step
+ with dn.task_span(
+ name or f"pipeline - {self.task_name} (x{count})",
+ label=name or f"pipeline_many_{self.task_name}",
+ attributes={"rigging.type": "chat_pipeline.run_many"},
+ ) as task:
+ dn.log_inputs(
+ count=count,
+ messages=messages[0],
+ params=params[0],
+ generator_id=self.generator.to_identifier(),
+ )
+
+ if mode == "merged":
+ try:
+ async with aclosing(
+ self._step(messages, params, on_failed),
+ ) as steps:
+ async for step in steps:
+ last = step
+ finally:
+ if last is not None:
+ dn.log_output("chats", last.chats)
+ await self._score_chats(last.chats)
+ # TODO: Remove once Strikes UI is ported
+ task.set_attribute("chats", last.chats)
+
+ if last is None or last.state != "final":
+ raise RuntimeError("The pipeline did not complete successfully")
+
+ return last.chats
+
+ if mode == "parallel":
+ tasks = [asyncio.create_task(self.run(on_failed="include")) for _ in range(count)]
+ chats_or_errors = await asyncio.gather(*tasks, return_exceptions=True)
+
+ self._raise_if_failed(chats_or_errors, on_failed)
+
+ chats = [
+ chat
+ for chat in chats_or_errors
+ if isinstance(chat, Chat) and (on_failed != "skip" or not chat.failed)
+ ]
- if last is None or last.state != "final":
- raise ValueError("The generation process did not complete successfully")
+ dn.log_output("chats", chats)
+ # TODO: Remove once Strikes UI is ported
+ task.set_attribute("chats", chats)
- return last.chats
+ return ChatList(chats)
+
+ raise ValueError(
+ f"Invalid mode '{mode}', expected 'merged' or 'parallel'",
+ )
```
@@ -2314,11 +2550,100 @@ async def run_over(
sub.generator = generator
coros.append(sub.run(allow_failed=(on_failed != "raise")))
- with tracer.span(f"Chat over {len(coros)} generators", count=len(coros)):
+ short_generators = [g.to_identifier(short=True) for g in _generators]
+ task_name = "iterate - " + ", ".join(short_generators)
+
+ with dn.task_span(
+ task_name,
+ label="iterate_over",
+ attributes={"rigging.type": "chat_pipeline.run_over"},
+ ):
+ dn.log_input("generators", [g.to_identifier() for g in _generators])
return ChatList(await asyncio.gather(*coros))
```
+
+
+### score
+
+```python
+score(
+ *scorers: Scorer[Chat] | ScorerCallable[Chat],
+ filter: ChatFilterMode | ChatFilterFunction = "last",
+) -> ChatPipeline
+```
+
+Adds one or more scorers to the pipeline to evaluate the generated chat upon completion.
+
+**Parameters:**
+
+* **`*scorers`**
+ (`Scorer[Chat] | ScorerCallable[Chat]`, default:
+ `()`
+ )
+ –The scorer or scorers to be added. These can be either:
+ - A dreadnode.Scorer instance.
+ - A callable function that can be converted to a dreadnode.Scorer.
+* **`filter`**
+ (`ChatFilterMode | ChatFilterFunction`, default:
+ `'last'`
+ )
+ –The strategy for filtering which messages to include:
+ - "all": Use all messages in the chat.
+ - "last": Use only the last message.
+ - "first": Use only the first message.
+ - "user": Use only user messages.
+ - "assistant": Use only assistant messages.
+ - "last\_user": Use only the last user message.
+ - "last\_assistant": Use only the last assistant message.
+ - A callable that takes a list of `Message` objects and returns a filtered list.
+
+**Returns:**
+
+* `ChatPipeline`
+ –The updated pipeline.
+
+
+```python
+def score(
+ self,
+ *scorers: dn.Scorer[Chat] | ScorerCallable[Chat],
+ filter: "ChatFilterMode | ChatFilterFunction" = "last",
+) -> "ChatPipeline":
+ """
+ Adds one or more scorers to the pipeline to evaluate the generated chat upon completion.
+
+ Args:
+ *scorers: The scorer or scorers to be added. These can be either:
+ - A dreadnode.Scorer instance.
+ - A callable function that can be converted to a dreadnode.Scorer.
+ filter: The strategy for filtering which messages to include:
+ - "all": Use all messages in the chat.
+ - "last": Use only the last message.
+ - "first": Use only the first message.
+ - "user": Use only user messages.
+ - "assistant": Use only assistant messages.
+ - "last_user": Use only the last user message.
+ - "last_assistant": Use only the last assistant message.
+ - A callable that takes a list of `Message` objects and returns a filtered list.
+
+ Returns:
+ The updated pipeline.
+ """
+ self.scorers.extend(
+ [
+ dn.scorers.wrap_chat(
+ scorer if isinstance(scorer, dn.Scorer) else dn.Scorer.from_callable(scorer),
+ filter=filter,
+ )
+ for scorer in scorers
+ ]
+ )
+ return self
+```
+
+
### step
@@ -2374,15 +2699,10 @@ async def step(
messages = [self.chat.all]
params = self._fit_params(1, [self.params])
- with tracer.span(
- f"Chat with {self.generator.to_identifier()}",
- generator_id=self.generator.to_identifier(),
- params=self.params.to_dict() if self.params is not None else {},
- ) as span:
- async with aclosing(
- self._step(span, messages, params, on_failed),
- ) as generator:
- yield generator
+ async with aclosing(
+ self._step(messages, params, on_failed),
+ ) as generator:
+ yield generator
```
@@ -2461,37 +2781,12 @@ async def step_batch(
Pipeline steps.
"""
on_failed = on_failed or self.on_failed
+ _, messages, params = self._fit_batch_args(many, params)
- # Get the maximum of either incoming messages or params
-
- count = max(len(many), len(params) if params is not None else 0)
-
- # If we have less messages than params, we need to either:
- # 1. Error because we have >1 messages that we can't reasonably
- # zip with our parameters of a different length
- # 2. Duplicate a single message we have len(params) times as the
- # user is just batching only over parameters
-
- messages = [[*self.chat.all, *Message.fit_as_list(m)] for m in many]
- if len(messages) < count:
- if len(messages) != 1:
- raise ValueError(
- f"Can't fit {len(messages)} messages to {count} params",
- )
- messages = messages * count
-
- params = self._fit_params(count, params)
-
- with tracer.span(
- f"Chat batch with {self.generator.to_identifier()} ({count})",
- count=count,
- generator_id=self.generator.to_identifier(),
- params=self.params.to_dict() if self.params is not None else {},
- ) as span:
- async with aclosing(
- self._step(span, messages, params, on_failed),
- ) as generator:
- yield generator
+ async with aclosing(
+ self._step(messages, params, on_failed),
+ ) as generator:
+ yield generator
```
@@ -2557,16 +2852,10 @@ async def step_many(
messages = [self.chat.all] * count
params = self._fit_params(count, params)
- with tracer.span(
- f"Chat with {self.generator.to_identifier()} (x{count})",
- count=count,
- generator_id=self.generator.to_identifier(),
- params=self.params.to_dict() if self.params is not None else {},
- ) as span:
- async with aclosing(
- self._step(span, messages, params, on_failed),
- ) as generator:
- yield generator
+ async with aclosing(
+ self._step(messages, params, on_failed),
+ ) as generator:
+ yield generator
```
@@ -2579,6 +2868,7 @@ then(
*callbacks: ThenChatCallback,
max_depth: int = DEFAULT_MAX_DEPTH,
allow_duplicates: bool = False,
+ as_task: bool = True,
) -> ChatPipeline
```
@@ -2607,6 +2897,11 @@ from the pipeline.
`False`
)
–Whether to allow (seemingly) duplicate callbacks to be added.
+* **`as_task`**
+ (`bool`, default:
+ `True`
+ )
+ –Whether to create a task for this callback.
**Returns:**
@@ -2630,6 +2925,7 @@ def then(
*callbacks: ThenChatCallback,
max_depth: int = DEFAULT_MAX_DEPTH,
allow_duplicates: bool = False,
+ as_task: bool = True,
) -> "ChatPipeline":
"""
Registers one or many callbacks to be executed after the generation process completes.
@@ -2643,6 +2939,7 @@ def then(
callbacks: The callback functions to be added.
max_depth: The maximum depth to allow recursive pipeline calls during this callback.
allow_duplicates: Whether to allow (seemingly) duplicate callbacks to be added.
+ as_task: Whether to create a task for this callback.
Returns:
The updated pipeline.
@@ -2664,7 +2961,7 @@ def then(
f"Callback '{get_qualified_name(callback)}' is already registered.",
)
- self.then_callbacks.extend([(callback, max_depth) for callback in callbacks])
+ self.then_callbacks.extend([(callback, max_depth, as_task) for callback in callbacks])
return self
```
@@ -2867,11 +3164,11 @@ def until_parsed_as(
max_depth = max_rounds or max_depth
self.then_callbacks = [
- (callback, max_depth)
- for callback, max_depth in self.then_callbacks
+ (callback, max_depth, as_task)
+ for callback, max_depth, as_task in self.then_callbacks
if callback != self._then_parse
]
- self.then_callbacks.append((self._then_parse, max_depth))
+ self.then_callbacks.append((self._then_parse, max_depth, False))
return self
```
@@ -3037,13 +3334,13 @@ def using(
self.tools = [tool for tool in self.tools if tool.name not in new_names] + _tools
self.then_callbacks = [
- (callback, max_depth)
- for callback, max_depth in self.then_callbacks
+ (callback, max_depth, as_task)
+ for callback, max_depth, as_task in self.then_callbacks
if callback != self._then_tools # Always remove to update max_depth
]
self.then_callbacks.insert(
0, # make sure this is first
- (self._then_tools, max_depth),
+ (self._then_tools, max_depth, False),
)
if mode is not None:
diff --git a/docs/api/completion.mdx b/docs/api/completion.mdx
index e829262..2c394cb 100644
--- a/docs/api/completion.mdx
+++ b/docs/api/completion.mdx
@@ -897,12 +897,21 @@ async def run(
on_failed = on_failed or self.on_failed
states = self._initialize_states(1)
- with tracer.span(
- f"Completion with {self.generator.to_identifier()}",
- generator_id=self.generator.to_identifier(),
- params=self.params.to_dict() if self.params is not None else {},
- ) as span:
- return (await self._run(span, states, on_failed))[0]
+ with dn.task_span(
+ f"pipeline - {self.generator.to_identifier(short=True)}",
+ label=f"pipeline_{self.generator.to_identifier(short=True)}",
+ attributes={"rigging.type": "completion_pipeline.run"},
+ ) as task:
+ dn.log_inputs(
+ text=self.text,
+ params=self.params.to_dict() if self.params is not None else {},
+ generator_id=self.generator.to_identifier(),
+ )
+ completions = await self._run(task, states, on_failed)
+ completion = completions[0]
+ dn.log_output("completion", completion)
+ task.set_attribute("completions", completions)
+ return completion
```
@@ -978,13 +987,21 @@ async def run_batch(
for state in states:
next(state.processor)
- with tracer.span(
- f"Completion batch with {self.generator.to_identifier()} ({len(states)})",
- count=len(states),
- generator_id=self.generator.to_identifier(),
- params=self.params.to_dict() if self.params is not None else {},
- ) as span:
- return await self._run(span, states, on_failed, batch_mode=True)
+ with dn.task_span(
+ f"pipeline - {self.generator.to_identifier(short=True)} (batch x{len(states)})",
+ label=f"pipeline_batch_{self.generator.to_identifier(short=True)}",
+ attributes={"rigging.type": "completion_pipeline.run_batch"},
+ ) as task:
+ dn.log_inputs(
+ count=len(states),
+ many=many,
+ params=params,
+ generator_id=self.generator.to_identifier(),
+ )
+ completions = await self._run(task, states, on_failed, batch_mode=True)
+ dn.log_output("completions", completions)
+ task.set_attribute("completions", completions)
+ return completions
```
@@ -1047,13 +1064,21 @@ async def run_many(
on_failed = on_failed or self.on_failed
states = self._initialize_states(count, params)
- with tracer.span(
- f"Completion with {self.generator.to_identifier()} (x{count})",
- count=count,
- generator_id=self.generator.to_identifier(),
- params=self.params.to_dict() if self.params is not None else {},
- ) as span:
- return await self._run(span, states, on_failed)
+ with dn.task_span(
+ f"pipeline - {self.generator.to_identifier(short=True)} (x{count})",
+ label=f"pipeline_many_{self.generator.to_identifier(short=True)}",
+ attributes={"rigging.type": "completion_pipeline.run_many"},
+ ) as task:
+ dn.log_inputs(
+ count=count,
+ text=self.text,
+ params=self.params.to_dict() if self.params is not None else {},
+ generator_id=self.generator.to_identifier(),
+ )
+ completions = await self._run(task, states, on_failed)
+ dn.log_output("completions", completions)
+ task.set_attribute("completions", completions)
+ return completions
```
@@ -1133,9 +1158,20 @@ async def run_over(
sub.generator = generator
coros.append(sub.run(allow_failed=(on_failed != "raise")))
- with tracer.span(f"Completion over {len(coros)} generators", count=len(coros)):
+ short_generators = [g.to_identifier(short=True) for g in _generators]
+ task_name = "iterate - " + ", ".join(short_generators)
+
+ with dn.task_span(
+ task_name,
+ label="iterate_over",
+ attributes={"rigging.type": "completion_pipeline.run_over"},
+ ) as task:
+ dn.log_input("generators", [g.to_identifier() for g in _generators])
completions = await asyncio.gather(*coros)
- return await self._post_run(completions, on_failed)
+ final_completions = await self._post_run(completions, on_failed)
+ dn.log_output("completions", final_completions)
+ task.set_attribute("completions", final_completions)
+ return final_completions
```
diff --git a/docs/api/generator.mdx b/docs/api/generator.mdx
index f9c4c17..e77a07b 100644
--- a/docs/api/generator.mdx
+++ b/docs/api/generator.mdx
@@ -793,7 +793,11 @@ async def supports_function_calling(self) -> bool | None:
### to\_identifier
```python
-to_identifier(params: GenerateParams | None = None) -> str
+to_identifier(
+ params: GenerateParams | None = None,
+ *,
+ short: bool = False,
+) -> str
```
Converts the generator instance back into a rigging identifier string.
@@ -815,7 +819,7 @@ This calls [rigging.generator.get\_identifier][] with the current instance.
```python
-def to_identifier(self, params: GenerateParams | None = None) -> str:
+def to_identifier(self, params: GenerateParams | None = None, *, short: bool = False) -> str:
"""
Converts the generator instance back into a rigging identifier string.
@@ -827,7 +831,7 @@ def to_identifier(self, params: GenerateParams | None = None) -> str:
Returns:
The identifier string.
"""
- return get_identifier(self, params)
+ return get_identifier(self, params, short=short)
```
@@ -1075,6 +1079,163 @@ min_delay_between_requests: float = 0.0
Minimum time (ms) between each request.
This is useful to set when you run into API limits at a provider.
+TransformersGenerator
+---------------------
+
+Generator backed by the Transformers library for local model loading.
+
+
+The use of Transformers requires the `transformers` package to be installed directly or by
+installing rigging as `rigging[all]`.
+
+
+
+The `transformers` library is expansive with many different models, tokenizers,
+options, constructors, etc. We do our best to implement a consistent interface,
+but there may be limitations. Where needed, use
+[`.from_obj()`][rigging.generator.transformers\_.TransformersGenerator.from\_obj].
+
+
+
+This generator doesn't leverage any async capabilities.
+
+
+
+The model is loaded into memory lazily when the first generation is requested.
+If you want to force this to happen earlier, you can use the
+[`.load()`][rigging.generator.Generator.load] method.
+
+To unload, call [`.unload()`][rigging.generator.Generator.unload].
+
+
+### device\_map
+
+```python
+device_map: str = 'auto'
+```
+
+Device map passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto)
+
+### llm
+
+```python
+llm: AutoModelForCausalLM
+```
+
+The underlying `AutoModelForCausalLM` instance.
+
+### load\_in\_4bit
+
+```python
+load_in_4bit: bool = False
+```
+
+Load in 4 bit passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto)
+
+### load\_in\_8bit
+
+```python
+load_in_8bit: bool = False
+```
+
+Load in 8 bit passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto)
+
+### pipeline
+
+```python
+pipeline: TextGenerationPipeline
+```
+
+The underlying `TextGenerationPipeline` instance.
+
+### tokenizer
+
+```python
+tokenizer: AutoTokenizer
+```
+
+The underlying `AutoTokenizer` instance.
+
+### torch\_dtype
+
+```python
+torch_dtype: str = 'auto'
+```
+
+Torch dtype passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto)
+
+### trust\_remote\_code
+
+```python
+trust_remote_code: bool = False
+```
+
+Trust remote code passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto)
+
+### from\_obj
+
+```python
+from_obj(
+ model: Any,
+ tokenizer: AutoTokenizer,
+ *,
+ pipeline: TextGenerationPipeline | None = None,
+ params: GenerateParams | None = None,
+) -> TransformersGenerator
+```
+
+Create a new instance of TransformersGenerator from an already loaded model and tokenizer.
+
+**Parameters:**
+
+* **`model`**
+ (`Any`)
+ –The loaded model for text generation.
+* **`tokenizer`**
+ –The tokenizer associated with the model.
+* **`pipeline`**
+ (`TextGenerationPipeline | None`, default:
+ `None`
+ )
+ –The text generation pipeline. Defaults to None.
+
+**Returns:**
+
+* `TransformersGenerator`
+ –The TransformersGenerator instance.
+
+
+```python
+@classmethod
+def from_obj(
+ cls,
+ model: t.Any,
+ tokenizer: AutoTokenizer,
+ *,
+ pipeline: TextGenerationPipeline | None = None,
+ params: GenerateParams | None = None,
+) -> "TransformersGenerator":
+ """
+ Create a new instance of TransformersGenerator from an already loaded model and tokenizer.
+
+ Args:
+ model: The loaded model for text generation.
+ tokenizer: The tokenizer associated with the model.
+ pipeline: The text generation pipeline. Defaults to None.
+
+ Returns:
+ The TransformersGenerator instance.
+ """
+ instance = cls(model=model, params=params or GenerateParams(), api_key=None)
+ instance._llm = model
+ instance._tokenizer = tokenizer
+ instance._pipeline = pipeline
+ return instance
+```
+
+
+
+
Usage
-----
@@ -1104,6 +1265,128 @@ total_tokens: int
The total number of tokens processed.
+VLLMGenerator
+-------------
+
+Generator backed by the vLLM library for local model loading.
+
+Find more information about supported models and formats [in their docs.](https://docs.vllm.ai/en/latest/index.html)
+
+
+The use of VLLM requires the `vllm` package to be installed directly or by
+installing rigging as `rigging[all]`.
+
+
+
+This generator doesn't leverage any async capabilities.
+
+
+
+The model is loaded into memory lazily, when the first generation is requested.
+If you want to force this to happen earlier, you can use the
+[`.load()`][rigging.generator.Generator.load] method.
+
+To unload, call [`.unload()`][rigging.generator.Generator.unload].
+
+
+### dtype
+
+```python
+dtype: str = 'auto'
+```
+
+Tensor dtype passed to [`vllm.LLM`](https://docs.vllm.ai/en/latest/offline_inference/llm.html)
+
+### enforce\_eager
+
+```python
+enforce_eager: bool = False
+```
+
+Eager enforcement passed to [`vllm.LLM`](https://docs.vllm.ai/en/latest/offline_inference/llm.html)
+
+### gpu\_memory\_utilization
+
+```python
+gpu_memory_utilization: float = 0.9
+```
+
+Memory utilization passed to [`vllm.LLM`](https://docs.vllm.ai/en/latest/offline_inference/llm.html)
+
+### llm
+
+```python
+llm: LLM
+```
+
+The underlying [`vLLM model`](https://docs.vllm.ai/en/latest/offline_inference/llm.html) instance.
+
+### quantization
+
+```python
+quantization: str | None = None
+```
+
+Quantization passed to [`vllm.LLM`](https://docs.vllm.ai/en/latest/offline_inference/llm.html)
+
+### trust\_remote\_code
+
+```python
+trust_remote_code: bool = False
+```
+
+Trust remote code passed to [`vllm.LLM`](https://docs.vllm.ai/en/latest/offline_inference/llm.html)
+
+### from\_obj
+
+```python
+from_obj(
+ model: str,
+ llm: LLM,
+ *,
+ params: GenerateParams | None = None,
+) -> VLLMGenerator
+```
+
+Create a generator from an existing vLLM instance.
+
+**Parameters:**
+
+* **`llm`**
+ (`LLM`)
+ –The vLLM instance to create the generator from.
+
+**Returns:**
+
+* `VLLMGenerator`
+ –The VLLMGenerator instance.
+
+
+```python
+@classmethod
+def from_obj(
+ cls,
+ model: str,
+ llm: vllm.LLM,
+ *,
+ params: GenerateParams | None = None,
+) -> "VLLMGenerator":
+ """Create a generator from an existing vLLM instance.
+
+ Args:
+ llm: The vLLM instance to create the generator from.
+
+ Returns:
+ The VLLMGenerator instance.
+ """
+ generator = cls(model=model, params=params or GenerateParams(), api_key=None)
+ generator._llm = llm
+ return generator
+```
+
+
+
+
chat
----
@@ -1361,6 +1644,8 @@ get\_identifier
get_identifier(
generator: Generator,
params: GenerateParams | None = None,
+ *,
+ short: bool = False,
) -> str
```
@@ -1388,7 +1673,9 @@ The `extra` parameter field is not currently supported in identifiers.
```python
-def get_identifier(generator: Generator, params: GenerateParams | None = None) -> str:
+def get_identifier(
+ generator: Generator, params: GenerateParams | None = None, *, short: bool = False
+) -> str:
"""
Converts the generator instance back into a rigging identifier string.
@@ -1408,7 +1695,10 @@ def get_identifier(generator: Generator, params: GenerateParams | None = None) -
for name, klass in g_generators.items()
if isinstance(klass, type) and isinstance(generator, klass)
)
- identifier = f"{provider}!{generator.model}"
+ identifier = f"{provider}!{generator.model}" if provider != "litellm" else generator.model
+
+ if short:
+ return identifier
identifier_extra = generator.model_dump(
exclude_unset=True,
@@ -1613,7 +1903,7 @@ def from_obj(
Returns:
The VLLMGenerator instance.
"""
- generator = cls(model=model, params=params or GenerateParams())
+ generator = cls(model=model, params=params or GenerateParams(), api_key=None)
generator._llm = llm
return generator
```
@@ -1776,7 +2066,7 @@ def from_obj(
Returns:
The TransformersGenerator instance.
"""
- instance = cls(model=model, params=params or GenerateParams())
+ instance = cls(model=model, params=params or GenerateParams(), api_key=None)
instance._llm = model
instance._tokenizer = tokenizer
instance._pipeline = pipeline
diff --git a/docs/api/message.mdx b/docs/api/message.mdx
index 2050fd4..87fe06b 100644
--- a/docs/api/message.mdx
+++ b/docs/api/message.mdx
@@ -2471,14 +2471,16 @@ def clone(self) -> "MessageSlice":
Returns:
A new MessageSlice instance with the same properties.
"""
- return MessageSlice(
+ cloned = MessageSlice(
type=self.type,
obj=self.obj,
start=self.start,
stop=self.stop,
metadata=copy.deepcopy(self.metadata),
- _message=self._message, # Keep the reference to the original message
)
+ # Leaving this detached to align with tests
+ # cloned._message = self._message
+ return cloned # noqa: RET504
```
diff --git a/docs/api/prompt.mdx b/docs/api/prompt.mdx
index d798f9f..2f0559d 100644
--- a/docs/api/prompt.mdx
+++ b/docs/api/prompt.mdx
@@ -269,10 +269,36 @@ def bind(
raise NotImplementedError(
"pipeline.on_failed='skip' cannot be used for prompt methods that return one object",
)
+ if pipeline.on_failed == "include" and not isinstance(self.output, ChatOutput):
+ raise NotImplementedError(
+ "pipeline.on_failed='include' cannot be used with prompts that process outputs",
+ )
async def run(*args: P.args, **kwargs: P.kwargs) -> R:
- results = await self.bind_many(pipeline)(1, *args, **kwargs)
- return results[0]
+ name = get_qualified_name(self.func) if self.func else ""
+ with dn.task_span(
+ f"prompt - {name}",
+ attributes={"prompt_name": name, "rigging.type": "prompt.run"},
+ ):
+ dn.log_inputs(**self._bind_args(*args, **kwargs))
+ content = self.render(*args, **kwargs)
+ _pipeline = (
+ pipeline.fork(content)
+ .using(*self.tools, max_depth=self.max_tool_rounds)
+ .then(self._then_parse, max_depth=self.max_parsing_rounds, as_task=False)
+ .then(*self.then_callbacks)
+ .map(*self.map_callbacks)
+ .watch(*self.watch_callbacks)
+ .with_(self.params)
+ )
+
+ if self.system_prompt:
+ _pipeline.chat.inject_system_content(self.system_prompt)
+
+ chat = await _pipeline.run()
+ output = self.process(chat)
+ dn.log_output("output", output)
+ return output
run.__signature__ = self.__signature__ # type: ignore [attr-defined]
run.__name__ = self.__name__
@@ -315,7 +341,7 @@ Example
def say_hello(name: str) -> str:
"""Say hello to {{ name }}"""
-await say_hello.bind("gpt-3.5-turbo")(5, "the world")
+await say_hello.bind_many("gpt-4.1")(5, "the world")
```
@@ -340,7 +366,7 @@ def bind_many(
def say_hello(name: str) -> str:
\"""Say hello to {{ name }}\"""
- await say_hello.bind("gpt-3.5-turbo")(5, "the world")
+ await say_hello.bind_many("gpt-4.1")(5, "the world")
~~~
"""
pipeline = self._resolve_to_pipeline(other)
@@ -351,17 +377,17 @@ def bind_many(
async def run_many(count: int, /, *args: P.args, **kwargs: P.kwargs) -> list[R]:
name = get_qualified_name(self.func) if self.func else ""
- with tracer.span(
- f"Prompt {name}()" if count == 1 else f"Prompt {name}() (x{count})",
- count=count,
- name=name,
- arguments=self._bind_args(*args, **kwargs),
+ with dn.task_span(
+ f"prompt - {name} (x{count})",
+ label=f"prompt_{name}",
+ attributes={"prompt_name": name, "rigging.type": "prompt.run_many"},
) as span:
+ dn.log_inputs(**self._bind_args(*args, **kwargs))
content = self.render(*args, **kwargs)
_pipeline = (
pipeline.fork(content)
.using(*self.tools, max_depth=self.max_tool_rounds)
- .then(self._then_parse, max_depth=self.max_parsing_rounds)
+ .then(self._then_parse, max_depth=self.max_parsing_rounds, as_task=False)
.then(*self.then_callbacks)
.map(*self.map_callbacks)
.watch(*self.watch_callbacks)
@@ -372,35 +398,13 @@ def bind_many(
_pipeline.chat.inject_system_content(self.system_prompt)
chats = await _pipeline.run_many(count)
-
- # TODO: I can't remember why we don't just pass the watch_callbacks to the pipeline
- # Maybe it has something to do with uniqueness and merging?
-
- def wrap_watch_callback(callback: "WatchChatCallback") -> "WatchChatCallback":
- async def traced_watch_callback(chats: list[Chat]) -> None:
- callback_name = get_qualified_name(callback)
- with tracer.span(
- f"Watch with {callback_name}()",
- callback=callback_name,
- chat_count=len(chats),
- chat_ids=[str(c.uuid) for c in chats],
- ):
- await callback(chats)
-
- return traced_watch_callback
-
- coros = [
- wrap_watch_callback(watch)(chats)
- for watch in self.watch_callbacks
- if watch not in pipeline.watch_callbacks
- ]
- await asyncio.gather(*coros)
-
- results = [self.process(chat) for chat in chats]
- span.set_attribute("results", results)
- return results
+ outputs = [self.process(chat) for chat in chats]
+ span.log_output("outputs", outputs)
+ return outputs
run_many.__rg_prompt__ = self # type: ignore [attr-defined]
+ run_many.__name__ = self.__name__
+ run_many.__doc__ = self.__doc__
return run_many
```
@@ -445,7 +449,7 @@ Example
def say_hello(name: str) -> str:
"""Say hello to {{ name }}"""
-await say_hello.bind("gpt-3.5-turbo")(["gpt-4o", "gpt-4"], "the world")
+await say_hello.bind_over()(["gpt-4o", "gpt-4.1", "o4-mini"], "the world")
```
@@ -473,7 +477,7 @@ def bind_over(
def say_hello(name: str) -> str:
\"""Say hello to {{ name }}\"""
- await say_hello.bind("gpt-3.5-turbo")(["gpt-4o", "gpt-4"], "the world")
+ await say_hello.bind_over()(["gpt-4o", "gpt-4.1", "o4-mini"], "the world")
~~~
"""
include_original = other is not None
@@ -500,7 +504,7 @@ def bind_over(
_pipeline = (
pipeline.fork(content)
.using(*self.tools, max_depth=self.max_tool_rounds)
- .then(self._then_parse, max_depth=self.max_parsing_rounds)
+ .then(self._then_parse, max_depth=self.max_parsing_rounds, as_task=False)
.then(*self.then_callbacks)
.map(*self.map_callbacks)
.watch(*self.watch_callbacks)
@@ -512,13 +516,6 @@ def bind_over(
chats = await _pipeline.run_over(*generators, include_original=include_original)
- coros = [
- watch(chats)
- for watch in self.watch_callbacks
- if watch not in pipeline.watch_callbacks
- ]
- await asyncio.gather(*coros)
-
return [self.process(chat) for chat in chats]
run_over.__rg_prompt__ = self # type: ignore [attr-defined]
diff --git a/docs/api/tokenize.mdx b/docs/api/tokenize.mdx
index 3b4c849..527053d 100644
--- a/docs/api/tokenize.mdx
+++ b/docs/api/tokenize.mdx
@@ -366,6 +366,125 @@ async def tokenize_chat(self, chat: "Chat") -> TokenizedChat:
```
+
+
+TransformersTokenizer
+---------------------
+
+A tokenizer implementation using Hugging Face Transformers.
+
+This class provides tokenization capabilities for chat conversations
+using transformers models and their associated tokenizers.
+
+### apply\_chat\_template\_kwargs
+
+```python
+apply_chat_template_kwargs: dict[str, Any] = Field(
+ default_factory=dict
+)
+```
+
+Additional keyword arguments for applying the chat template.
+
+### decode\_kwargs
+
+```python
+decode_kwargs: dict[str, Any] = Field(default_factory=dict)
+```
+
+Additional keyword arguments for decoding tokens.
+
+### encode\_kwargs
+
+```python
+encode_kwargs: dict[str, Any] = Field(default_factory=dict)
+```
+
+Additional keyword arguments for encoding text.
+
+### tokenizer
+
+```python
+tokenizer: PreTrainedTokenizer
+```
+
+The underlying `PreTrainedTokenizer` instance.
+
+### encode
+
+```python
+encode(text: str) -> list[int]
+```
+
+Encodes the given text into a list of tokens.
+
+**Parameters:**
+
+* **`text`**
+ (`str`)
+ –The text to encode.
+
+**Returns:**
+
+* `list[int]`
+ –A list of tokens representing the encoded text.
+
+
+```python
+def encode(self, text: str) -> list[int]:
+ """
+ Encodes the given text into a list of tokens.
+
+ Args:
+ text: The text to encode.
+
+ Returns:
+ A list of tokens representing the encoded text.
+ """
+ return self.tokenizer.encode(text, **self.encode_kwargs) # type: ignore [no-any-return]
+```
+
+
+
+
+### from\_obj
+
+```python
+from_obj(
+ tokenizer: PreTrainedTokenizer,
+) -> TransformersTokenizer
+```
+
+Create a new instance of TransformersTokenizer from an already loaded tokenizer.
+
+**Parameters:**
+
+* **`tokenizer`**
+ (`PreTrainedTokenizer`)
+ –The tokenizer associated with the model.
+
+**Returns:**
+
+* `TransformersTokenizer`
+ –The TransformersTokenizer instance.
+
+
+```python
+@classmethod
+def from_obj(cls, tokenizer: "PreTrainedTokenizer") -> "TransformersTokenizer":
+ """
+ Create a new instance of TransformersTokenizer from an already loaded tokenizer.
+
+ Args:
+ tokenizer: The tokenizer associated with the model.
+
+ Returns:
+ The TransformersTokenizer instance.
+ """
+ return cls(model=str(tokenizer), _tokenizer=tokenizer)
+```
+
+
get\_tokenizer
diff --git a/docs/api/tools.mdx b/docs/api/tools.mdx
index 259d132..4c78436 100644
--- a/docs/api/tools.mdx
+++ b/docs/api/tools.mdx
@@ -36,17 +36,6 @@ How tool calls are handled.
Tool
----
-```python
-Tool(
- name: str,
- description: str,
- parameters_schema: dict[str, Any],
- fn: Callable[P, R],
- catch: bool | set[type[Exception]] = False,
- truncate: int | None = None,
-)
-```
-
Base class for representing a tool to a generator.
### catch
@@ -82,7 +71,10 @@ A description of the tool.
### fn
```python
-fn: Callable[P, R]
+fn: Callable[P, R] = Field(
+ default_factory=lambda: lambda *args, **kwargs: None,
+ exclude=True,
+)
```
The function to call.
@@ -153,7 +145,12 @@ async def handle_tool_call( # noqa: PLR0912
from rigging.message import ContentText, ContentTypes, Message
- with tracer.span(f"Tool {self.name}()", name=self.name) as span:
+ with dn.task_span(
+ f"tool - {self.name}",
+ attributes={"tool_name": self.name, "rigging.type": "tool"},
+ ) as task:
+ dn.log_input("tool_call", tool_call)
+
if tool_call.name != self.name:
warnings.warn(
f"Tool call name mismatch: {tool_call.name} != {self.name}",
@@ -163,7 +160,7 @@ async def handle_tool_call( # noqa: PLR0912
return Message.from_model(SystemErrorModel(content="Invalid tool call.")), True
if hasattr(tool_call, "id") and isinstance(tool_call.id, str):
- span.set_attribute("tool_call_id", tool_call.id)
+ task.set_attribute("tool_call_id", tool_call.id)
result: t.Any
stop = False
@@ -174,8 +171,9 @@ async def handle_tool_call( # noqa: PLR0912
kwargs = json.loads(tool_call.function.arguments)
if self._type_adapter is not None:
kwargs = self._type_adapter.validate_python(kwargs)
- span.set_attribute("arguments", kwargs)
+ dn.log_inputs(**kwargs)
except (json.JSONDecodeError, ValidationError) as e:
+ task.set_exception(e)
result = ErrorModel.from_exception(e)
# Call the function
@@ -190,17 +188,18 @@ async def handle_tool_call( # noqa: PLR0912
raise result # noqa: TRY301
except Stop as e:
result = f"<{TOOL_STOP_TAG}>{e.message}</{TOOL_STOP_TAG}>"
- span.set_attribute("stop", True)
+ task.set_attribute("stop", True)
stop = True
except Exception as e:
if self.catch is True or (
not isinstance(self.catch, bool) and isinstance(e, tuple(self.catch))
):
+ task.set_exception(e)
result = ErrorModel.from_exception(e)
else:
raise
- span.set_attribute("result", result)
+ dn.log_output("output", result)
message = Message(role="tool", tool_call_id=tool_call.id)
@@ -232,17 +231,6 @@ async def handle_tool_call( # noqa: PLR0912
ToolMethod
----------
-```python
-ToolMethod(
- name: str,
- description: str,
- parameters_schema: dict[str, Any],
- fn: Callable[P, R],
- catch: bool | set[type[Exception]] = False,
- truncate: int | None = None,
-)
-```
-
A Tool wrapping a class method.
tool
@@ -683,10 +671,10 @@ def robopages(url: str, *, name_filter: str | None = None) -> list[Tool[..., t.A
tools.append(
Tool(
- function.name,
- function.description or "",
- function.parameters or {},
- make_execute_on_server(url, function.name),
+ name=function.name,
+ description=function.description or "",
+ parameters_schema=function.parameters or {},
+ fn=make_execute_on_server(url, function.name),
),
)
diff --git a/poetry.lock b/poetry.lock
index 9674924..e53e55d 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1156,6 +1156,18 @@ traitlets = ">=4"
[package.extras]
test = ["pytest"]
+[[package]]
+name = "coolname"
+version = "2.2.0"
+description = "Random name and slug generator"
+optional = false
+python-versions = "*"
+groups = ["main"]
+files = [
+ {file = "coolname-2.2.0-py2.py3-none-any.whl", hash = "sha256:4d1563186cfaf71b394d5df4c744f8c41303b6846413645e31d31915cdeb13e8"},
+ {file = "coolname-2.2.0.tar.gz", hash = "sha256:6c5d5731759104479e7ca195a9b64f7900ac5bead40183c09323c7d0be9e75c7"},
+]
+
[[package]]
name = "coverage"
version = "7.9.2"
@@ -1432,6 +1444,32 @@ files = [
{file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"},
]
+[[package]]
+name = "dreadnode"
+version = "1.12.0"
+description = "Dreadnode SDK"
+optional = false
+python-versions = "<3.14,>=3.10"
+groups = ["main"]
+files = [
+ {file = "dreadnode-1.12.0-py3-none-any.whl", hash = "sha256:0286ba18c47718891e43e39bc8f330f80045d80be2efce8e89c082b1f7101c5a"},
+ {file = "dreadnode-1.12.0.tar.gz", hash = "sha256:73204c6ac0424931e505d6ca0598a6703dd7465a61f54d8bc62e0a52e0f98b67"},
+]
+
+[package.dependencies]
+coolname = ">=2.2.0,<3.0.0"
+fsspec = {version = ">=2023.1.0,<=2025.3.0", extras = ["s3"]}
+httpx = ">=0.28.0,<0.29.0"
+logfire = ">=3.5.3,<=3.20.0"
+pandas = ">=2.2.3,<3.0.0"
+pydantic = ">=2.9.2,<3.0.0"
+python-ulid = ">=3.0.0,<4.0.0"
+rigging = ">=3.1.1,<4.0.0"
+
+[package.extras]
+multimodal = ["moviepy (>=2.1.2,<3.0.0)", "pillow (>=11.2.1,<12.0.0)", "soundfile (>=0.13.1,<0.14.0)"]
+training = ["transformers (>=4.41.0,<5.0.0)"]
+
[[package]]
name = "elastic-transport"
version = "8.17.1"
@@ -1504,7 +1542,7 @@ version = "2.2.0"
description = "Get the currently executing AST node of a frame, and other information"
optional = false
python-versions = ">=3.8"
-groups = ["dev"]
+groups = ["main", "dev"]
files = [
{file = "executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa"},
{file = "executing-2.2.0.tar.gz", hash = "sha256:5d108c028108fe2551d1a7b2e8b713341e2cb4fc0aa7dcf966fa4327a5226755"},
@@ -1680,6 +1718,7 @@ files = [
[package.dependencies]
aiohttp = {version = "<4.0.0a0 || >4.0.0a0,<4.0.0a1 || >4.0.0a1", optional = true, markers = "extra == \"http\""}
+s3fs = {version = "*", optional = true, markers = "extra == \"s3\""}
[package.extras]
abfs = ["adlfs"]
@@ -1744,6 +1783,24 @@ python-dateutil = ">=2.8.1"
[package.extras]
dev = ["flake8", "markdown", "twine", "wheel"]
+[[package]]
+name = "googleapis-common-protos"
+version = "1.70.0"
+description = "Common protobufs used in Google APIs"
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "googleapis_common_protos-1.70.0-py3-none-any.whl", hash = "sha256:b8bfcca8c25a2bb253e0e0b0adaf8c00773e5e6af6fd92397576680b807e0fd8"},
+ {file = "googleapis_common_protos-1.70.0.tar.gz", hash = "sha256:0e1b44e0ea153e6594f9f394fef15193a68aaaea2d843f83e2742717ca753257"},
+]
+
+[package.dependencies]
+protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0"
+
+[package.extras]
+grpc = ["grpcio (>=1.44.0,<2.0.0)"]
+
[[package]]
name = "griffe"
version = "1.7.3"
@@ -1874,14 +1931,14 @@ test = ["Cython (>=0.29.24)"]
[[package]]
name = "httpx"
-version = "0.27.2"
+version = "0.28.1"
description = "The next generation HTTP client."
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
- {file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"},
- {file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"},
+ {file = "httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad"},
+ {file = "httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc"},
]
[package.dependencies]
@@ -1889,7 +1946,6 @@ anyio = "*"
certifi = "*"
httpcore = "==1.*"
idna = "*"
-sniffio = "*"
[package.extras]
brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""]
@@ -2588,17 +2644,51 @@ pydantic = ">=1.10.8"
pyyaml = "*"
[[package]]
-name = "logfire-api"
-version = "3.25.0"
-description = "Shim for the Logfire SDK which does nothing unless Logfire is installed"
+name = "logfire"
+version = "3.20.0"
+description = "The best Python observability tool! 🪵🔥"
optional = false
-python-versions = ">=3.8"
+python-versions = ">=3.9"
groups = ["main"]
files = [
- {file = "logfire_api-3.25.0-py3-none-any.whl", hash = "sha256:cc1c2482d6a738e15cd165c483577f8ef7a8a4c462eafa0f6129aa9077676a8d"},
- {file = "logfire_api-3.25.0.tar.gz", hash = "sha256:d6aeeeb246cc8d7aeb14a503523422292047db5e7be35d47c8979f70b0962bb0"},
+ {file = "logfire-3.20.0-py3-none-any.whl", hash = "sha256:561ea5f197f4c3a4e521e893f35535b955fc22592fd6cbd5901434c5ad16226d"},
+ {file = "logfire-3.20.0.tar.gz", hash = "sha256:592f242edb6ef7e33cc245de6f457ac92c5d012cf48cff0725830bdc5ba602bf"},
]
+[package.dependencies]
+executing = ">=2.0.1"
+opentelemetry-exporter-otlp-proto-http = ">=1.21.0,<1.35.0"
+opentelemetry-instrumentation = ">=0.41b0"
+opentelemetry-sdk = ">=1.21.0,<1.35.0"
+protobuf = ">=4.23.4"
+rich = ">=13.4.2"
+tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""}
+typing-extensions = ">=4.1.0"
+
+[package.extras]
+aiohttp = ["opentelemetry-instrumentation-aiohttp-client (>=0.42b0)"]
+aiohttp-client = ["opentelemetry-instrumentation-aiohttp-client (>=0.42b0)"]
+aiohttp-server = ["opentelemetry-instrumentation-aiohttp-server (>=0.55b0)"]
+asgi = ["opentelemetry-instrumentation-asgi (>=0.42b0)"]
+asyncpg = ["opentelemetry-instrumentation-asyncpg (>=0.42b0)"]
+aws-lambda = ["opentelemetry-instrumentation-aws-lambda (>=0.42b0)"]
+celery = ["opentelemetry-instrumentation-celery (>=0.42b0)"]
+django = ["opentelemetry-instrumentation-asgi (>=0.42b0)", "opentelemetry-instrumentation-django (>=0.42b0)"]
+fastapi = ["opentelemetry-instrumentation-fastapi (>=0.42b0)"]
+flask = ["opentelemetry-instrumentation-flask (>=0.42b0)"]
+httpx = ["opentelemetry-instrumentation-httpx (>=0.42b0)"]
+mysql = ["opentelemetry-instrumentation-mysql (>=0.42b0)"]
+psycopg = ["opentelemetry-instrumentation-psycopg (>=0.42b0)", "packaging"]
+psycopg2 = ["opentelemetry-instrumentation-psycopg2 (>=0.42b0)", "packaging"]
+pymongo = ["opentelemetry-instrumentation-pymongo (>=0.42b0)"]
+redis = ["opentelemetry-instrumentation-redis (>=0.42b0)"]
+requests = ["opentelemetry-instrumentation-requests (>=0.42b0)"]
+sqlalchemy = ["opentelemetry-instrumentation-sqlalchemy (>=0.42b0)"]
+sqlite3 = ["opentelemetry-instrumentation-sqlite3 (>=0.42b0)"]
+starlette = ["opentelemetry-instrumentation-starlette (>=0.42b0)"]
+system-metrics = ["opentelemetry-instrumentation-system-metrics (>=0.42b0)"]
+wsgi = ["opentelemetry-instrumentation-wsgi (>=0.42b0)"]
+
[[package]]
name = "loguru"
version = "0.7.3"
@@ -2634,6 +2724,31 @@ files = [
docs = ["mdx_gh_links (>=0.2)", "mkdocs (>=1.6)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"]
testing = ["coverage", "pyyaml"]
+[[package]]
+name = "markdown-it-py"
+version = "3.0.0"
+description = "Python port of markdown-it. Markdown parsing, done right!"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"},
+ {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"},
+]
+
+[package.dependencies]
+mdurl = ">=0.1,<1.0"
+
+[package.extras]
+benchmarking = ["psutil", "pytest", "pytest-benchmark"]
+code-style = ["pre-commit (>=3.0,<4.0)"]
+compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
+linkify = ["linkify-it-py (>=1,<3)"]
+plugins = ["mdit-py-plugins"]
+profiling = ["gprof2dot"]
+rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
+testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
+
[[package]]
name = "markdownify"
version = "1.1.0"
@@ -2766,6 +2881,18 @@ cli = ["python-dotenv (>=1.0.0)", "typer (>=0.16.0)"]
rich = ["rich (>=13.9.4)"]
ws = ["websockets (>=15.0.1)"]
+[[package]]
+name = "mdurl"
+version = "0.1.2"
+description = "Markdown URL utilities"
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"},
+ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
+]
+
[[package]]
name = "mergedeep"
version = "1.3.4"
@@ -3584,6 +3711,124 @@ datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
realtime = ["websockets (>=13,<16)"]
voice-helpers = ["numpy (>=2.0.2)", "sounddevice (>=0.5.1)"]
+[[package]]
+name = "opentelemetry-api"
+version = "1.34.1"
+description = "OpenTelemetry Python API"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_api-1.34.1-py3-none-any.whl", hash = "sha256:b7df4cb0830d5a6c29ad0c0691dbae874d8daefa934b8b1d642de48323d32a8c"},
+ {file = "opentelemetry_api-1.34.1.tar.gz", hash = "sha256:64f0bd06d42824843731d05beea88d4d4b6ae59f9fe347ff7dfa2cc14233bbb3"},
+]
+
+[package.dependencies]
+importlib-metadata = ">=6.0,<8.8.0"
+typing-extensions = ">=4.5.0"
+
+[[package]]
+name = "opentelemetry-exporter-otlp-proto-common"
+version = "1.34.1"
+description = "OpenTelemetry Protobuf encoding"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_exporter_otlp_proto_common-1.34.1-py3-none-any.whl", hash = "sha256:8e2019284bf24d3deebbb6c59c71e6eef3307cd88eff8c633e061abba33f7e87"},
+ {file = "opentelemetry_exporter_otlp_proto_common-1.34.1.tar.gz", hash = "sha256:b59a20a927facd5eac06edaf87a07e49f9e4a13db487b7d8a52b37cb87710f8b"},
+]
+
+[package.dependencies]
+opentelemetry-proto = "1.34.1"
+
+[[package]]
+name = "opentelemetry-exporter-otlp-proto-http"
+version = "1.34.1"
+description = "OpenTelemetry Collector Protobuf over HTTP Exporter"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_exporter_otlp_proto_http-1.34.1-py3-none-any.whl", hash = "sha256:5251f00ca85872ce50d871f6d3cc89fe203b94c3c14c964bbdc3883366c705d8"},
+ {file = "opentelemetry_exporter_otlp_proto_http-1.34.1.tar.gz", hash = "sha256:aaac36fdce46a8191e604dcf632e1f9380c7d5b356b27b3e0edb5610d9be28ad"},
+]
+
+[package.dependencies]
+googleapis-common-protos = ">=1.52,<2.0"
+opentelemetry-api = ">=1.15,<2.0"
+opentelemetry-exporter-otlp-proto-common = "1.34.1"
+opentelemetry-proto = "1.34.1"
+opentelemetry-sdk = ">=1.34.1,<1.35.0"
+requests = ">=2.7,<3.0"
+typing-extensions = ">=4.5.0"
+
+[[package]]
+name = "opentelemetry-instrumentation"
+version = "0.55b1"
+description = "Instrumentation Tools & Auto Instrumentation for OpenTelemetry Python"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_instrumentation-0.55b1-py3-none-any.whl", hash = "sha256:cbb1496b42bc394e01bc63701b10e69094e8564e281de063e4328d122cc7a97e"},
+ {file = "opentelemetry_instrumentation-0.55b1.tar.gz", hash = "sha256:2dc50aa207b9bfa16f70a1a0571e011e737a9917408934675b89ef4d5718c87b"},
+]
+
+[package.dependencies]
+opentelemetry-api = ">=1.4,<2.0"
+opentelemetry-semantic-conventions = "0.55b1"
+packaging = ">=18.0"
+wrapt = ">=1.0.0,<2.0.0"
+
+[[package]]
+name = "opentelemetry-proto"
+version = "1.34.1"
+description = "OpenTelemetry Python Proto"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_proto-1.34.1-py3-none-any.whl", hash = "sha256:eb4bb5ac27f2562df2d6857fc557b3a481b5e298bc04f94cc68041f00cebcbd2"},
+ {file = "opentelemetry_proto-1.34.1.tar.gz", hash = "sha256:16286214e405c211fc774187f3e4bbb1351290b8dfb88e8948af209ce85b719e"},
+]
+
+[package.dependencies]
+protobuf = ">=5.0,<6.0"
+
+[[package]]
+name = "opentelemetry-sdk"
+version = "1.34.1"
+description = "OpenTelemetry Python SDK"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_sdk-1.34.1-py3-none-any.whl", hash = "sha256:308effad4059562f1d92163c61c8141df649da24ce361827812c40abb2a1e96e"},
+ {file = "opentelemetry_sdk-1.34.1.tar.gz", hash = "sha256:8091db0d763fcd6098d4781bbc80ff0971f94e260739aa6afe6fd379cdf3aa4d"},
+]
+
+[package.dependencies]
+opentelemetry-api = "1.34.1"
+opentelemetry-semantic-conventions = "0.55b1"
+typing-extensions = ">=4.5.0"
+
+[[package]]
+name = "opentelemetry-semantic-conventions"
+version = "0.55b1"
+description = "OpenTelemetry Semantic Conventions"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_semantic_conventions-0.55b1-py3-none-any.whl", hash = "sha256:5da81dfdf7d52e3d37f8fe88d5e771e191de924cfff5f550ab0b8f7b2409baed"},
+ {file = "opentelemetry_semantic_conventions-0.55b1.tar.gz", hash = "sha256:ef95b1f009159c28d7a7849f5cbc71c4c34c845bb514d66adfdf1b3fff3598b3"},
+]
+
+[package.dependencies]
+opentelemetry-api = "1.34.1"
+typing-extensions = ">=4.5.0"
+
[[package]]
name = "outlines"
version = "0.0.46"
@@ -4107,22 +4352,23 @@ files = [
[[package]]
name = "protobuf"
-version = "6.31.0"
+version = "5.29.5"
description = ""
-optional = true
-python-versions = ">=3.9"
+optional = false
+python-versions = ">=3.8"
groups = ["main"]
-markers = "extra == \"all\""
files = [
- {file = "protobuf-6.31.0-cp310-abi3-win32.whl", hash = "sha256:10bd62802dfa0588649740a59354090eaf54b8322f772fbdcca19bc78d27f0d6"},
- {file = "protobuf-6.31.0-cp310-abi3-win_amd64.whl", hash = "sha256:3e987c99fd634be8347246a02123250f394ba20573c953de133dc8b2c107dd71"},
- {file = "protobuf-6.31.0-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:2c812f0f96ceb6b514448cefeb1df54ec06dde456783f5099c0e2f8a0f2caa89"},
- {file = "protobuf-6.31.0-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:67ce50195e4e584275623b8e6bc6d3d3dfd93924bf6116b86b3b8975ab9e4571"},
- {file = "protobuf-6.31.0-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:5353e38844168a327acd2b2aa440044411cd8d1b6774d5701008bd1dba067c79"},
- {file = "protobuf-6.31.0-cp39-cp39-win32.whl", hash = "sha256:96d8da25c83b11db5fe9e0376351ce25e7205e13224d939e097b6f82a72af824"},
- {file = "protobuf-6.31.0-cp39-cp39-win_amd64.whl", hash = "sha256:00a873c06efdfb854145d9ded730b09cf57d206075c38132674093370e2edabb"},
- {file = "protobuf-6.31.0-py3-none-any.whl", hash = "sha256:6ac2e82556e822c17a8d23aa1190bbc1d06efb9c261981da95c71c9da09e9e23"},
- {file = "protobuf-6.31.0.tar.gz", hash = "sha256:314fab1a6a316469dc2dd46f993cbbe95c861ea6807da910becfe7475bc26ffe"},
+ {file = "protobuf-5.29.5-cp310-abi3-win32.whl", hash = "sha256:3f1c6468a2cfd102ff4703976138844f78ebd1fb45f49011afc5139e9e283079"},
+ {file = "protobuf-5.29.5-cp310-abi3-win_amd64.whl", hash = "sha256:3f76e3a3675b4a4d867b52e4a5f5b78a2ef9565549d4037e06cf7b0942b1d3fc"},
+ {file = "protobuf-5.29.5-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e38c5add5a311f2a6eb0340716ef9b039c1dfa428b28f25a7838ac329204a671"},
+ {file = "protobuf-5.29.5-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:fa18533a299d7ab6c55a238bf8629311439995f2e7eca5caaff08663606e9015"},
+ {file = "protobuf-5.29.5-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:63848923da3325e1bf7e9003d680ce6e14b07e55d0473253a690c3a8b8fd6e61"},
+ {file = "protobuf-5.29.5-cp38-cp38-win32.whl", hash = "sha256:ef91363ad4faba7b25d844ef1ada59ff1604184c0bcd8b39b8a6bef15e1af238"},
+ {file = "protobuf-5.29.5-cp38-cp38-win_amd64.whl", hash = "sha256:7318608d56b6402d2ea7704ff1e1e4597bee46d760e7e4dd42a3d45e24b87f2e"},
+ {file = "protobuf-5.29.5-cp39-cp39-win32.whl", hash = "sha256:6f642dc9a61782fa72b90878af134c5afe1917c89a568cd3476d758d3c3a0736"},
+ {file = "protobuf-5.29.5-cp39-cp39-win_amd64.whl", hash = "sha256:470f3af547ef17847a28e1f47200a1cbf0ba3ff57b7de50d22776607cd2ea353"},
+ {file = "protobuf-5.29.5-py3-none-any.whl", hash = "sha256:6cf42630262c59b2d8de33954443d94b746c952b01434fc58a417fdbd2e84bd5"},
+ {file = "protobuf-5.29.5.tar.gz", hash = "sha256:bc1463bafd4b0929216c35f437a8e28731a2b7fe3d98bb77a600efced5a15c84"},
]
[[package]]
@@ -4483,7 +4729,7 @@ version = "2.19.1"
description = "Pygments is a syntax highlighting package written in Python."
optional = false
python-versions = ">=3.8"
-groups = ["dev"]
+groups = ["main", "dev"]
files = [
{file = "pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c"},
{file = "pygments-2.19.1.tar.gz", hash = "sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f"},
@@ -4597,6 +4843,21 @@ files = [
{file = "python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13"},
]
+[[package]]
+name = "python-ulid"
+version = "3.0.0"
+description = "Universally unique lexicographically sortable identifier"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "python_ulid-3.0.0-py3-none-any.whl", hash = "sha256:e4c4942ff50dbd79167ad01ac725ec58f924b4018025ce22c858bfcff99a5e31"},
+ {file = "python_ulid-3.0.0.tar.gz", hash = "sha256:e50296a47dc8209d28629a22fc81ca26c00982c78934bd7766377ba37ea49a9f"},
+]
+
+[package.extras]
+pydantic = ["pydantic (>=2.0)"]
+
[[package]]
name = "pytz"
version = "2025.2"
@@ -5027,6 +5288,26 @@ urllib3 = ">=1.21.1,<3"
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
+[[package]]
+name = "rich"
+version = "14.0.0"
+description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
+optional = false
+python-versions = ">=3.8.0"
+groups = ["main"]
+files = [
+ {file = "rich-14.0.0-py3-none-any.whl", hash = "sha256:1c9491e1951aac09caffd42f448ee3d04e58923ffe14993f6e83068dc395d7e0"},
+ {file = "rich-14.0.0.tar.gz", hash = "sha256:82f1bc23a6a21ebca4ae0c45af9bdbc492ed20231dcb63f297d6d1021a9d5725"},
+]
+
+[package.dependencies]
+markdown-it-py = ">=2.2.0"
+pygments = ">=2.13.0,<3.0.0"
+typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.11\""}
+
+[package.extras]
+jupyter = ["ipywidgets (>=7.5.1,<9)"]
+
[[package]]
name = "rpds-py"
version = "0.25.1"
@@ -5180,7 +5461,7 @@ description = "C version of reader, parser and emitter for ruamel.yaml derived f
optional = false
python-versions = ">=3.9"
groups = ["main"]
-markers = "platform_python_implementation == \"CPython\" and python_version < \"3.14\""
+markers = "platform_python_implementation == \"CPython\""
files = [
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:11f891336688faf5156a36293a9c362bdc7c88f03a8a027c2c1d8e0bcde998e5"},
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:a606ef75a60ecf3d924613892cc603b154178ee25abb3055db5062da811fd969"},
@@ -5188,7 +5469,6 @@ files = [
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f66efbc1caa63c088dead1c4170d148eabc9b80d95fb75b6c92ac0aad2437d76"},
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:22353049ba4181685023b25b5b51a574bce33e7f51c759371a7422dcae5402a6"},
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:932205970b9f9991b34f55136be327501903f7c66830e9760a8ffb15b07f05cd"},
- {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a52d48f4e7bf9005e8f0a89209bf9a73f7190ddf0489eee5eb51377385f59f2a"},
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-win32.whl", hash = "sha256:3eac5a91891ceb88138c113f9db04f3cebdae277f5d44eaa3651a4f573e6a5da"},
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-win_amd64.whl", hash = "sha256:ab007f2f5a87bd08ab1499bdf96f3d5c6ad4dcfa364884cb4549aa0154b13a28"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:4a6679521a58256a90b0d89e03992c15144c5f3858f40d7c18886023d7943db6"},
@@ -5197,7 +5477,6 @@ files = [
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:811ea1594b8a0fb466172c384267a4e5e367298af6b228931f273b111f17ef52"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cf12567a7b565cbf65d438dec6cfbe2917d3c1bdddfce84a9930b7d35ea59642"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7dd5adc8b930b12c8fc5b99e2d535a09889941aa0d0bd06f4749e9a9397c71d2"},
- {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1492a6051dab8d912fc2adeef0e8c72216b24d57bd896ea607cb90bb0c4981d3"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-win32.whl", hash = "sha256:bd0a08f0bab19093c54e18a14a10b4322e1eacc5217056f3c063bd2f59853ce4"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-win_amd64.whl", hash = "sha256:a274fb2cb086c7a3dea4322ec27f4cb5cc4b6298adb583ab0e211a4682f241eb"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:20b0f8dc160ba83b6dcc0e256846e1a02d044e13f7ea74a3d1d56ede4e48c632"},
@@ -5206,7 +5485,6 @@ files = [
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:749c16fcc4a2b09f28843cda5a193e0283e47454b63ec4b81eaa2242f50e4ccd"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bf165fef1f223beae7333275156ab2022cffe255dcc51c27f066b4370da81e31"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:32621c177bbf782ca5a18ba4d7af0f1082a3f6e517ac2a18b3974d4edf349680"},
- {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b82a7c94a498853aa0b272fd5bc67f29008da798d4f93a2f9f289feb8426a58d"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-win32.whl", hash = "sha256:e8c4ebfcfd57177b572e2040777b8abc537cdef58a2120e830124946aa9b42c5"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-win_amd64.whl", hash = "sha256:0467c5965282c62203273b838ae77c0d29d7638c8a4e3a1c8bdd3602c10904e4"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:4c8c5d82f50bb53986a5e02d1b3092b03622c02c2eb78e29bec33fd9593bae1a"},
@@ -5215,7 +5493,6 @@ files = [
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:96777d473c05ee3e5e3c3e999f5d23c6f4ec5b0c38c098b3a5229085f74236c6"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:3bc2a80e6420ca8b7d3590791e2dfc709c88ab9152c00eeb511c9875ce5778bf"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:e188d2699864c11c36cdfdada94d781fd5d6b0071cd9c427bceb08ad3d7c70e1"},
- {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:4f6f3eac23941b32afccc23081e1f50612bdbe4e982012ef4f5797986828cd01"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-win32.whl", hash = "sha256:6442cb36270b3afb1b4951f060eccca1ce49f3d087ca1ca4563a6eb479cb3de6"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-win_amd64.whl", hash = "sha256:e5b8daf27af0b90da7bb903a876477a9e6d7270be6146906b276605997c7e9a3"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:fc4b630cd3fa2cf7fce38afa91d7cfe844a9f75d7f0f36393fa98815e911d987"},
@@ -5224,7 +5501,6 @@ files = [
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2f1c3765db32be59d18ab3953f43ab62a761327aafc1594a2a1fbe038b8b8a7"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d85252669dc32f98ebcd5d36768f5d4faeaeaa2d655ac0473be490ecdae3c285"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e143ada795c341b56de9418c58d028989093ee611aa27ffb9b7f609c00d813ed"},
- {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2c59aa6170b990d8d2719323e628aaf36f3bfbc1c26279c0eeeb24d05d2d11c7"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-win32.whl", hash = "sha256:beffaed67936fbbeffd10966a4eb53c402fafd3d6833770516bf7314bc6ffa12"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-win_amd64.whl", hash = "sha256:040ae85536960525ea62868b642bdb0c2cc6021c9f9d507810c0c604e66f5a7b"},
{file = "ruamel.yaml.clib-0.2.12.tar.gz", hash = "sha256:6c8fbb13ec503f99a91901ab46e0b07ae7941cd527393187039aec586fdfd36f"},
@@ -5258,6 +5534,22 @@ files = [
{file = "ruff-0.10.0.tar.gz", hash = "sha256:fa1554e18deaf8aa097dbcfeafaf38b17a2a1e98fdc18f50e62e8a836abee392"},
]
+[[package]]
+name = "s3fs"
+version = "0.4.2"
+description = "Convenient Filesystem interface over S3"
+optional = false
+python-versions = ">= 3.5"
+groups = ["main"]
+files = [
+ {file = "s3fs-0.4.2-py3-none-any.whl", hash = "sha256:91c1dfb45e5217bd441a7a560946fe865ced6225ff7eb0fb459fe6e601a95ed3"},
+ {file = "s3fs-0.4.2.tar.gz", hash = "sha256:2ca5de8dc18ad7ad350c0bd01aef0406aa5d0fff78a561f0f710f9d9858abdd0"},
+]
+
+[package.dependencies]
+botocore = ">=1.12.91"
+fsspec = ">=0.6.0"
+
[[package]]
name = "s3transfer"
version = "0.13.0"
@@ -5821,7 +6113,7 @@ version = "2.2.1"
description = "A lil' TOML parser"
optional = false
python-versions = ">=3.8"
-groups = ["dev"]
+groups = ["main", "dev"]
markers = "python_version < \"3.11\""
files = [
{file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
@@ -6652,6 +6944,95 @@ files = [
[package.extras]
dev = ["black (>=19.3b0) ; python_version >= \"3.6\"", "pytest (>=4.6.2)"]
+[[package]]
+name = "wrapt"
+version = "1.17.2"
+description = "Module for decorators, wrappers and monkey patching."
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "wrapt-1.17.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3d57c572081fed831ad2d26fd430d565b76aa277ed1d30ff4d40670b1c0dd984"},
+ {file = "wrapt-1.17.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5e251054542ae57ac7f3fba5d10bfff615b6c2fb09abeb37d2f1463f841ae22"},
+ {file = "wrapt-1.17.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:80dd7db6a7cb57ffbc279c4394246414ec99537ae81ffd702443335a61dbf3a7"},
+ {file = "wrapt-1.17.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a6e821770cf99cc586d33833b2ff32faebdbe886bd6322395606cf55153246c"},
+ {file = "wrapt-1.17.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b60fb58b90c6d63779cb0c0c54eeb38941bae3ecf7a73c764c52c88c2dcb9d72"},
+ {file = "wrapt-1.17.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b870b5df5b71d8c3359d21be8f0d6c485fa0ebdb6477dda51a1ea54a9b558061"},
+ {file = "wrapt-1.17.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4011d137b9955791f9084749cba9a367c68d50ab8d11d64c50ba1688c9b457f2"},
+ {file = "wrapt-1.17.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:1473400e5b2733e58b396a04eb7f35f541e1fb976d0c0724d0223dd607e0f74c"},
+ {file = "wrapt-1.17.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3cedbfa9c940fdad3e6e941db7138e26ce8aad38ab5fe9dcfadfed9db7a54e62"},
+ {file = "wrapt-1.17.2-cp310-cp310-win32.whl", hash = "sha256:582530701bff1dec6779efa00c516496968edd851fba224fbd86e46cc6b73563"},
+ {file = "wrapt-1.17.2-cp310-cp310-win_amd64.whl", hash = "sha256:58705da316756681ad3c9c73fd15499aa4d8c69f9fd38dc8a35e06c12468582f"},
+ {file = "wrapt-1.17.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ff04ef6eec3eee8a5efef2401495967a916feaa353643defcc03fc74fe213b58"},
+ {file = "wrapt-1.17.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4db983e7bca53819efdbd64590ee96c9213894272c776966ca6306b73e4affda"},
+ {file = "wrapt-1.17.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9abc77a4ce4c6f2a3168ff34b1da9b0f311a8f1cfd694ec96b0603dff1c79438"},
+ {file = "wrapt-1.17.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b929ac182f5ace000d459c59c2c9c33047e20e935f8e39371fa6e3b85d56f4a"},
+ {file = "wrapt-1.17.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f09b286faeff3c750a879d336fb6d8713206fc97af3adc14def0cdd349df6000"},
+ {file = "wrapt-1.17.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a7ed2d9d039bd41e889f6fb9364554052ca21ce823580f6a07c4ec245c1f5d6"},
+ {file = "wrapt-1.17.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:129a150f5c445165ff941fc02ee27df65940fcb8a22a61828b1853c98763a64b"},
+ {file = "wrapt-1.17.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1fb5699e4464afe5c7e65fa51d4f99e0b2eadcc176e4aa33600a3df7801d6662"},
+ {file = "wrapt-1.17.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9a2bce789a5ea90e51a02dfcc39e31b7f1e662bc3317979aa7e5538e3a034f72"},
+ {file = "wrapt-1.17.2-cp311-cp311-win32.whl", hash = "sha256:4afd5814270fdf6380616b321fd31435a462019d834f83c8611a0ce7484c7317"},
+ {file = "wrapt-1.17.2-cp311-cp311-win_amd64.whl", hash = "sha256:acc130bc0375999da18e3d19e5a86403667ac0c4042a094fefb7eec8ebac7cf3"},
+ {file = "wrapt-1.17.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:d5e2439eecc762cd85e7bd37161d4714aa03a33c5ba884e26c81559817ca0925"},
+ {file = "wrapt-1.17.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3fc7cb4c1c744f8c05cd5f9438a3caa6ab94ce8344e952d7c45a8ed59dd88392"},
+ {file = "wrapt-1.17.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8fdbdb757d5390f7c675e558fd3186d590973244fab0c5fe63d373ade3e99d40"},
+ {file = "wrapt-1.17.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5bb1d0dbf99411f3d871deb6faa9aabb9d4e744d67dcaaa05399af89d847a91d"},
+ {file = "wrapt-1.17.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d18a4865f46b8579d44e4fe1e2bcbc6472ad83d98e22a26c963d46e4c125ef0b"},
+ {file = "wrapt-1.17.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc570b5f14a79734437cb7b0500376b6b791153314986074486e0b0fa8d71d98"},
+ {file = "wrapt-1.17.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6d9187b01bebc3875bac9b087948a2bccefe464a7d8f627cf6e48b1bbae30f82"},
+ {file = "wrapt-1.17.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:9e8659775f1adf02eb1e6f109751268e493c73716ca5761f8acb695e52a756ae"},
+ {file = "wrapt-1.17.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e8b2816ebef96d83657b56306152a93909a83f23994f4b30ad4573b00bd11bb9"},
+ {file = "wrapt-1.17.2-cp312-cp312-win32.whl", hash = "sha256:468090021f391fe0056ad3e807e3d9034e0fd01adcd3bdfba977b6fdf4213ea9"},
+ {file = "wrapt-1.17.2-cp312-cp312-win_amd64.whl", hash = "sha256:ec89ed91f2fa8e3f52ae53cd3cf640d6feff92ba90d62236a81e4e563ac0e991"},
+ {file = "wrapt-1.17.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:6ed6ffac43aecfe6d86ec5b74b06a5be33d5bb9243d055141e8cabb12aa08125"},
+ {file = "wrapt-1.17.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:35621ae4c00e056adb0009f8e86e28eb4a41a4bfa8f9bfa9fca7d343fe94f998"},
+ {file = "wrapt-1.17.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a604bf7a053f8362d27eb9fefd2097f82600b856d5abe996d623babd067b1ab5"},
+ {file = "wrapt-1.17.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5cbabee4f083b6b4cd282f5b817a867cf0b1028c54d445b7ec7cfe6505057cf8"},
+ {file = "wrapt-1.17.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:49703ce2ddc220df165bd2962f8e03b84c89fee2d65e1c24a7defff6f988f4d6"},
+ {file = "wrapt-1.17.2-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8112e52c5822fc4253f3901b676c55ddf288614dc7011634e2719718eaa187dc"},
+ {file = "wrapt-1.17.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9fee687dce376205d9a494e9c121e27183b2a3df18037f89d69bd7b35bcf59e2"},
+ {file = "wrapt-1.17.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:18983c537e04d11cf027fbb60a1e8dfd5190e2b60cc27bc0808e653e7b218d1b"},
+ {file = "wrapt-1.17.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:703919b1633412ab54bcf920ab388735832fdcb9f9a00ae49387f0fe67dad504"},
+ {file = "wrapt-1.17.2-cp313-cp313-win32.whl", hash = "sha256:abbb9e76177c35d4e8568e58650aa6926040d6a9f6f03435b7a522bf1c487f9a"},
+ {file = "wrapt-1.17.2-cp313-cp313-win_amd64.whl", hash = "sha256:69606d7bb691b50a4240ce6b22ebb319c1cfb164e5f6569835058196e0f3a845"},
+ {file = "wrapt-1.17.2-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:4a721d3c943dae44f8e243b380cb645a709ba5bd35d3ad27bc2ed947e9c68192"},
+ {file = "wrapt-1.17.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:766d8bbefcb9e00c3ac3b000d9acc51f1b399513f44d77dfe0eb026ad7c9a19b"},
+ {file = "wrapt-1.17.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e496a8ce2c256da1eb98bd15803a79bee00fc351f5dfb9ea82594a3f058309e0"},
+ {file = "wrapt-1.17.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40d615e4fe22f4ad3528448c193b218e077656ca9ccb22ce2cb20db730f8d306"},
+ {file = "wrapt-1.17.2-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a5aaeff38654462bc4b09023918b7f21790efb807f54c000a39d41d69cf552cb"},
+ {file = "wrapt-1.17.2-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a7d15bbd2bc99e92e39f49a04653062ee6085c0e18b3b7512a4f2fe91f2d681"},
+ {file = "wrapt-1.17.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:e3890b508a23299083e065f435a492b5435eba6e304a7114d2f919d400888cc6"},
+ {file = "wrapt-1.17.2-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:8c8b293cd65ad716d13d8dd3624e42e5a19cc2a2f1acc74b30c2c13f15cb61a6"},
+ {file = "wrapt-1.17.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:4c82b8785d98cdd9fed4cac84d765d234ed3251bd6afe34cb7ac523cb93e8b4f"},
+ {file = "wrapt-1.17.2-cp313-cp313t-win32.whl", hash = "sha256:13e6afb7fe71fe7485a4550a8844cc9ffbe263c0f1a1eea569bc7091d4898555"},
+ {file = "wrapt-1.17.2-cp313-cp313t-win_amd64.whl", hash = "sha256:eaf675418ed6b3b31c7a989fd007fa7c3be66ce14e5c3b27336383604c9da85c"},
+ {file = "wrapt-1.17.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5c803c401ea1c1c18de70a06a6f79fcc9c5acfc79133e9869e730ad7f8ad8ef9"},
+ {file = "wrapt-1.17.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f917c1180fdb8623c2b75a99192f4025e412597c50b2ac870f156de8fb101119"},
+ {file = "wrapt-1.17.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ecc840861360ba9d176d413a5489b9a0aff6d6303d7e733e2c4623cfa26904a6"},
+ {file = "wrapt-1.17.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb87745b2e6dc56361bfde481d5a378dc314b252a98d7dd19a651a3fa58f24a9"},
+ {file = "wrapt-1.17.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:58455b79ec2661c3600e65c0a716955adc2410f7383755d537584b0de41b1d8a"},
+ {file = "wrapt-1.17.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4e42a40a5e164cbfdb7b386c966a588b1047558a990981ace551ed7e12ca9c2"},
+ {file = "wrapt-1.17.2-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:91bd7d1773e64019f9288b7a5101f3ae50d3d8e6b1de7edee9c2ccc1d32f0c0a"},
+ {file = "wrapt-1.17.2-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:bb90fb8bda722a1b9d48ac1e6c38f923ea757b3baf8ebd0c82e09c5c1a0e7a04"},
+ {file = "wrapt-1.17.2-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:08e7ce672e35efa54c5024936e559469436f8b8096253404faeb54d2a878416f"},
+ {file = "wrapt-1.17.2-cp38-cp38-win32.whl", hash = "sha256:410a92fefd2e0e10d26210e1dfb4a876ddaf8439ef60d6434f21ef8d87efc5b7"},
+ {file = "wrapt-1.17.2-cp38-cp38-win_amd64.whl", hash = "sha256:95c658736ec15602da0ed73f312d410117723914a5c91a14ee4cdd72f1d790b3"},
+ {file = "wrapt-1.17.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:99039fa9e6306880572915728d7f6c24a86ec57b0a83f6b2491e1d8ab0235b9a"},
+ {file = "wrapt-1.17.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2696993ee1eebd20b8e4ee4356483c4cb696066ddc24bd70bcbb80fa56ff9061"},
+ {file = "wrapt-1.17.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:612dff5db80beef9e649c6d803a8d50c409082f1fedc9dbcdfde2983b2025b82"},
+ {file = "wrapt-1.17.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:62c2caa1585c82b3f7a7ab56afef7b3602021d6da34fbc1cf234ff139fed3cd9"},
+ {file = "wrapt-1.17.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c958bcfd59bacc2d0249dcfe575e71da54f9dcf4a8bdf89c4cb9a68a1170d73f"},
+ {file = "wrapt-1.17.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc78a84e2dfbc27afe4b2bd7c80c8db9bca75cc5b85df52bfe634596a1da846b"},
+ {file = "wrapt-1.17.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:ba0f0eb61ef00ea10e00eb53a9129501f52385c44853dbd6c4ad3f403603083f"},
+ {file = "wrapt-1.17.2-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:1e1fe0e6ab7775fd842bc39e86f6dcfc4507ab0ffe206093e76d61cde37225c8"},
+ {file = "wrapt-1.17.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:c86563182421896d73858e08e1db93afdd2b947a70064b813d515d66549e15f9"},
+ {file = "wrapt-1.17.2-cp39-cp39-win32.whl", hash = "sha256:f393cda562f79828f38a819f4788641ac7c4085f30f1ce1a68672baa686482bb"},
+ {file = "wrapt-1.17.2-cp39-cp39-win_amd64.whl", hash = "sha256:36ccae62f64235cf8ddb682073a60519426fdd4725524ae38874adf72b5f2aeb"},
+ {file = "wrapt-1.17.2-py3-none-any.whl", hash = "sha256:b18f2d1533a71f069c7f82d524a52599053d4c7166e9dd374ae2136b7f40f7c8"},
+ {file = "wrapt-1.17.2.tar.gz", hash = "sha256:41388e9d4d1522446fe79d3213196bd9e3b301a336965b9e27ca2788ebd122f3"},
+]
+
[[package]]
name = "xformers"
version = "0.0.27.post2"
@@ -6970,5 +7351,5 @@ tracing = []
[metadata]
lock-version = "2.1"
-python-versions = "^3.10"
-content-hash = "70c9b2fdea3938a6ed1aa0a1316aaadf2d0a0633fbed8340bdf29d3f35337511"
+python-versions = ">=3.10,<3.14"
+content-hash = "f89b57179a89ba18d1d0cf3cbc8785348925eabc3b921147c0bbf80326f987f7"
diff --git a/pyproject.toml b/pyproject.toml
index 81ed7c2..e4c945e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,7 +11,7 @@ packages = [{ include = "rigging" }]
# Dependencies
[tool.poetry.dependencies]
-python = "^3.10"
+python = ">=3.10,<3.14"
pydantic = "^2.7.3"
pydantic-xml = "^2.11.0"
loguru = "^0.7.2"
@@ -21,11 +21,11 @@ xmltodict = "^0.13.0"
colorama = "^0.4.6"
boto3 = "^1.35.0"
boto3-stubs = { extras = ["s3"], version = "^1.35.0" }
-logfire-api = "^3.1.1"
jsonpath-ng = "^1.7.0"
ruamel-yaml = "^0.18.10"
jsonref = "^1.1.0"
mcp = "^1.5.0"
+dreadnode = ">=1.12.0"
vllm = { version = "^0.5.0", optional = true }
transformers = { version = "^4.41.0", optional = true }
@@ -34,10 +34,11 @@ elasticsearch = { version = "^8.13.2", optional = true }
asyncssh = { version = "^2.14.2", optional = true }
click = { version = "^8.1.7", optional = true }
-httpx = { version = "^0.27.0", optional = true }
+httpx = { version = "^0.28.0", optional = true }
aiodocker = { version = "^0.22.2", optional = true }
websockets = { version = "^13.0", optional = true }
+
[tool.poetry.extras]
tracing = ["logfire"]
examples = ["asyncssh", "click", "httpx", "aiodocker", "websockets"]
diff --git a/rigging/__init__.py b/rigging/__init__.py
index 279f8e9..65712ac 100644
--- a/rigging/__init__.py
+++ b/rigging/__init__.py
@@ -1,4 +1,5 @@
from rigging import (
+ caching,
data,
error,
generator,
@@ -94,6 +95,7 @@
"Transform",
"attr",
"await_",
+ "caching",
"chat",
"complete",
"data",
diff --git a/rigging/caching.py b/rigging/caching.py
new file mode 100644
index 0000000..765aabc
--- /dev/null
+++ b/rigging/caching.py
@@ -0,0 +1,42 @@
+import typing as t
+
+from loguru import logger
+
+if t.TYPE_CHECKING:
+ from rigging.message import Message
+
+CacheMode = t.Literal["latest"]
+"""
+How to handle cache_control entries on messages.
+
+- latest: Assign cache_control to the latest 2 non-assistant messages in the pipeline before inference.
+"""
+
+
+def apply_cache_mode_to_messages(
+ mode: CacheMode | None,
+ messages: "list[list[Message]]",
+) -> "list[list[Message]]":
+ if mode is None:
+ return messages
+
+ if mode != "latest":
+ logger.warning(
+ f"Unknown caching mode '{mode}', defaulting to 'latest'",
+ )
+ mode = "latest"
+
+ # first remove existing cache settings
+ updated: list[list[Message]] = []
+ for _messages in messages:
+ updated = [
+ *updated,
+ [m.clone().cache(cache_control=False) for m in _messages],
+ ]
+
+ # then apply the latest cache settings
+ for _messages in updated:
+ for message in [m for m in _messages if m.role != "assistant"][-2:]:
+ message.cache(cache_control=True)
+
+ return updated
diff --git a/rigging/chat.py b/rigging/chat.py
index 5e5c242..0646930 100644
--- a/rigging/chat.py
+++ b/rigging/chat.py
@@ -18,7 +18,8 @@
from typing import runtime_checkable
from uuid import UUID, uuid4
-from loguru import logger
+import dreadnode as dn
+from dreadnode.metric import ScorerCallable
from pydantic import (
BaseModel,
ConfigDict,
@@ -29,6 +30,7 @@
computed_field,
)
+from rigging.caching import CacheMode, apply_cache_mode_to_messages
from rigging.error import MaxDepthError, PipelineWarning
from rigging.generator import GenerateParams, Generator, get_generator
from rigging.generator.base import StopReason, Usage
@@ -46,7 +48,6 @@
from rigging.model import Model, ModelT, SystemErrorModel, ValidationErrorModel
from rigging.tokenizer import TokenizedChat, Tokenizer, get_tokenizer
from rigging.tools import Tool, ToolCall, ToolChoice, ToolMode
-from rigging.tracing import Span, tracer
from rigging.transform import (
PostTransform,
Transform,
@@ -59,6 +60,7 @@
from rigging.util import flatten_list, get_qualified_name
if t.TYPE_CHECKING:
+ from dreadnode.scorers.rigging import ChatFilterFunction, ChatFilterMode
from elasticsearch import AsyncElasticsearch
from rigging.data import ElasticOpType
@@ -84,20 +86,15 @@
- include: Mark the message as failed and include it in the final output.
"""
-CacheMode = t.Literal["latest"]
-"""
-How to handle cache_control entries on messages.
-
-- latest: Assign cache_control to the latest 2 non-assistant messages in the pipeline before inference.
-"""
-
class Chat(BaseModel):
"""
A completed chat interaction.
"""
- model_config = ConfigDict(arbitrary_types_allowed=True)
+ model_config = ConfigDict(
+ arbitrary_types_allowed=True, json_schema_extra={"rigging.type": "chat"}
+ )
uuid: UUID = Field(default_factory=uuid4)
"""The unique identifier for the chat."""
@@ -765,19 +762,12 @@ def depth(self) -> int:
def _wrap_watch_callback(callback: WatchChatCallback) -> WatchChatCallback:
callback_name = get_qualified_name(callback)
-
- async def traced_watch_callback(chats: list[Chat]) -> None:
- with tracer.span(
- f"Watch with {callback_name}()",
- callback=callback_name,
- chat_count=len(chats),
- chat_ids=[str(c.uuid) for c in chats],
- ):
- result = callback(chats)
- if inspect.isawaitable(result):
- await result
-
- return traced_watch_callback
+ return dn.task(
+ name=f"watch - {callback_name}",
+ attributes={"rigging.type": "chat_pipeline.watch_callback"},
+ log_inputs=True,
+ log_output=False,
+ )(callback)
# Pipeline
@@ -812,14 +802,18 @@ def __init__(
"""How to handle failures in the pipeline unless overridden in calls."""
self.caching: CacheMode | None = None
"""How to handle cache_control entries on messages."""
+ self.task_name: str = generator.to_identifier(short=True)
+ """The name of the pipeline task, used for logging and debugging."""
+ self.scorers: list[dn.Scorer[Chat]] = []
+ """List of dreadnode scorers to evaluate the generated chat upon completion."""
self.until_types: list[type[Model]] = []
self.tools: list[Tool[..., t.Any]] = []
self.tool_mode: ToolMode = "auto"
self.inject_tool_prompt = True
self.add_tool_stop_token = True
- self.then_callbacks: list[tuple[ThenChatCallback, int]] = []
- self.map_callbacks: list[tuple[MapChatCallback, int]] = []
+ self.then_callbacks: list[tuple[ThenChatCallback, int, bool]] = []
+ self.map_callbacks: list[tuple[MapChatCallback, int, bool]] = []
self.watch_callbacks: list[WatchChatCallback] = watch_callbacks or []
self.transforms: list[Transform] = []
@@ -1011,6 +1005,9 @@ def clone(
new.errors_to_catch = self.errors_to_catch.copy()
new.errors_to_exclude = self.errors_to_exclude.copy()
new.caching = self.caching
+ new.task_name = self.task_name
+ new.scorers = self.scorers.copy()
+ new.transforms = self.transforms.copy()
new.watch_callbacks = self.watch_callbacks.copy()
@@ -1022,18 +1019,18 @@ def clone(
return new
new.then_callbacks = [
- (callback, max_depth)
+ (callback, max_depth, as_task)
if not hasattr(callback, "__self__")
or not isinstance(callback.__self__, ChatPipeline)
- else (types.MethodType(callback.__func__, new), max_depth) # type: ignore [union-attr]
- for callback, max_depth in self.then_callbacks.copy()
+ else (types.MethodType(callback.__func__, new), max_depth, as_task) # type: ignore [union-attr]
+ for callback, max_depth, as_task in self.then_callbacks.copy()
]
new.map_callbacks = [
- (callback, max_depth)
+ (callback, max_depth, as_task)
if not hasattr(callback, "__self__")
or not isinstance(callback.__self__, ChatPipeline)
- else (types.MethodType(callback.__func__, new), max_depth) # type: ignore [union-attr]
- for callback, max_depth in self.map_callbacks.copy()
+ else (types.MethodType(callback.__func__, new), max_depth, as_task) # type: ignore [union-attr]
+ for callback, max_depth, as_task in self.map_callbacks.copy()
]
new.transforms = [
callback
@@ -1045,13 +1042,13 @@ def clone(
if not isinstance(callbacks, bool):
new.then_callbacks = [
- (callback, max_depth)
- for callback, max_depth in self.then_callbacks
+ (callback, max_depth, as_task)
+ for callback, max_depth, as_task in self.then_callbacks
if callback in callbacks
]
new.map_callbacks = [
- (callback, max_depth)
- for callback, max_depth in self.map_callbacks
+ (callback, max_depth, as_task)
+ for callback, max_depth, as_task in self.map_callbacks
if callback in callbacks
]
new.transforms = [callback for callback in self.transforms if callback in callbacks]
@@ -1071,11 +1068,25 @@ def meta(self, **kwargs: t.Any) -> "ChatPipeline":
self.metadata.update(kwargs)
return self
+ def name(self, name: str) -> "ChatPipeline":
+ """
+ Sets the name of the pipeline.
+
+ Args:
+ name: The name to set for the pipeline.
+
+ Returns:
+ The updated pipeline.
+ """
+ self.task_name = name
+ return self
+
def then(
self,
*callbacks: ThenChatCallback,
max_depth: int = DEFAULT_MAX_DEPTH,
allow_duplicates: bool = False,
+ as_task: bool = True,
) -> "ChatPipeline":
"""
Registers one or many callbacks to be executed after the generation process completes.
@@ -1089,6 +1100,7 @@ def then(
callbacks: The callback functions to be added.
max_depth: The maximum depth to allow recursive pipeline calls during this callback.
allow_duplicates: Whether to allow (seemingly) duplicate callbacks to be added.
+ as_task: Whether to create a task for this callback.
Returns:
The updated pipeline.
@@ -1110,7 +1122,7 @@ async def process(chat: Chat) -> Chat | None:
f"Callback '{get_qualified_name(callback)}' is already registered.",
)
- self.then_callbacks.extend([(callback, max_depth) for callback in callbacks])
+ self.then_callbacks.extend([(callback, max_depth, as_task) for callback in callbacks])
return self
def map(
@@ -1118,6 +1130,7 @@ def map(
*callbacks: MapChatCallback,
max_depth: int = DEFAULT_MAX_DEPTH,
allow_duplicates: bool = False,
+ as_task: bool = True,
) -> "ChatPipeline":
"""
Registers a callback to be executed after the generation process completes.
@@ -1131,6 +1144,7 @@ def map(
callbacks: The callback function to be executed.
max_depth: The maximum depth to allow recursive pipeline calls during this callback.
allow_duplicates: Whether to allow (seemingly) duplicate callbacks to be added.
+ as_task: Whether to create a task for this callback.
Returns:
The updated pipeline.
@@ -1152,7 +1166,7 @@ async def process(chats: list[Chat]) -> list[Chat]:
f"Callback '{get_qualified_name(callback)}' is already registered.",
)
- self.map_callbacks.extend([(callback, max_depth) for callback in callbacks])
+ self.map_callbacks.extend([(callback, max_depth, as_task) for callback in callbacks])
return self
def transform(
@@ -1337,13 +1351,13 @@ async def get_weather(city: Annotated[str, "The city name to get weather for"])
self.tools = [tool for tool in self.tools if tool.name not in new_names] + _tools
self.then_callbacks = [
- (callback, max_depth)
- for callback, max_depth in self.then_callbacks
+ (callback, max_depth, as_task)
+ for callback, max_depth, as_task in self.then_callbacks
if callback != self._then_tools # Always remove to update max_depth
]
self.then_callbacks.insert(
0, # make sure this is first
- (self._then_tools, max_depth),
+ (self._then_tools, max_depth, False),
)
if mode is not None:
@@ -1364,6 +1378,44 @@ async def get_weather(city: Annotated[str, "The city name to get weather for"])
return self
+ def score(
+ self,
+ *scorers: dn.Scorer[Chat] | ScorerCallable[Chat],
+ filter: "ChatFilterMode | ChatFilterFunction" = "last",
+ ) -> "ChatPipeline":
+ """
+ Adds one or more scorers to the pipeline to evaluate the generated chat upon completion.
+
+ Args:
+ *scorers: The scorer or scorers to be added. These can be either:
+ - A dreadnode.Scorer instance.
+ - A callable function that can be converted to a dreadnode.Scorer.
+ filter: The strategy for filtering which messages to include:
+ - "all": Use all messages in the chat.
+ - "last": Use only the last message.
+ - "first": Use only the first message.
+ - "user": Use only user messages.
+ - "assistant": Use only assistant messages.
+ - "last_user": Use only the last user message.
+ - "last_assistant": Use only the last assistant message.
+ - A callable that takes a list of `Message` objects and returns a filtered list.
+
+ Returns:
+ The updated pipeline.
+ """
+ self.scorers.extend(
+ [
+ dn.scorers.wrap_chat(
+ scorer if isinstance(scorer, dn.Scorer) else dn.Scorer.from_callable(scorer),
+ filter=filter,
+ )
+ for scorer in scorers
+ ]
+ )
+ return self
+
+ # Internal callbacks for handling tools and parsing
+
def until_parsed_as(
self,
*types: type[ModelT],
@@ -1410,16 +1462,14 @@ def until_parsed_as(
max_depth = max_rounds or max_depth
self.then_callbacks = [
- (callback, max_depth)
- for callback, max_depth in self.then_callbacks
+ (callback, max_depth, as_task)
+ for callback, max_depth, as_task in self.then_callbacks
if callback != self._then_parse
]
- self.then_callbacks.append((self._then_parse, max_depth))
+ self.then_callbacks.append((self._then_parse, max_depth, False))
return self
- # Internal callbacks for handling tools and parsing
-
async def _then_tools(self, chat: Chat) -> PipelineStepContextManager | None:
if not chat.last.tool_calls:
return None
@@ -1467,8 +1517,14 @@ async def _then_parse(self, chat: Chat) -> PipelineStepContextManager | None:
next_pipeline = self.clone(chat=chat)
+ type_names = " | ".join(sorted(until_type.__name__ for until_type in self.until_types))
+ task_name = f"parse - {type_names}"
+
try:
- chat.last.parse_many(*self.until_types)
+ with dn.task_span(task_name, attributes={"rigging.type": "chat_pipeline.parse"}):
+ dn.log_input("message", chat.last)
+ parsed = chat.last.parse_many(*self.until_types)
+ dn.log_output("parsed", parsed)
except ValidationError as e:
next_pipeline.add(
Message.from_model(
@@ -1502,33 +1558,6 @@ def _fit_params(
params = [self.params.merge_with(p) for p in params]
return [(p or GenerateParams()) for p in params]
- def _apply_cache_mode_to_messages(
- self,
- messages: list[list[Message]],
- ) -> list[list[Message]]:
- if self.caching is None:
- return messages
-
- if self.caching != "latest":
- logger.warning(
- f"Unknown caching mode '{self.caching}', defaulting to 'latest'",
- )
-
- # first remove existing cache settings
- updated: list[list[Message]] = []
- for _messages in messages:
- updated = [
- *updated,
- [m.clone().cache(cache_control=False) for m in _messages],
- ]
-
- # then apply the latest cache settings
- for _messages in updated:
- for message in [m for m in _messages if m.role != "assistant"][-2:]:
- message.cache(cache_control=True)
-
- return updated
-
@dataclass
class CallbackState:
chat: Chat
@@ -1548,51 +1577,69 @@ async def complete() -> None:
state.completed = True
state.ready_event.set()
- with tracer.span(
- f"Then with {callback_name}()",
- callback=callback_name,
- chat_id=str(state.chat.uuid),
- ):
- async with contextlib.AsyncExitStack() as exit_stack:
- exit_stack.push_async_callback(complete)
-
- result = callback(state.chat)
- if inspect.isawaitable(result):
- result = await result
+ async with contextlib.AsyncExitStack() as exit_stack:
+ exit_stack.push_async_callback(complete)
- if result is None or isinstance(result, Chat):
- state.chat = result or state.chat
- return
+ result = callback(state.chat)
+ if inspect.isawaitable(result):
+ result = await result
- if isinstance(result, contextlib.AbstractAsyncContextManager):
- result = await exit_stack.enter_async_context(result)
+ if result is None or isinstance(result, Chat):
+ state.chat = result or state.chat
+ return
- if not inspect.isasyncgen(result):
- raise TypeError(
- f"Callback '{callback_name}' must return a Chat, PipelineStepGenerator, or None",
- )
+ if isinstance(result, contextlib.AbstractAsyncContextManager):
+ result = await exit_stack.enter_async_context(result)
- generator = t.cast(
- "PipelineStepGenerator",
- await exit_stack.enter_async_context(aclosing(result)),
+ if not inspect.isasyncgen(result):
+ raise TypeError(
+ f"Callback '{callback_name}' must return a Chat, PipelineStepGenerator, or None",
)
- async for step in generator:
- state.step = step
- state.ready_event.set()
- await state.continue_event.wait()
+ generator = t.cast(
+ "PipelineStepGenerator",
+ await exit_stack.enter_async_context(aclosing(result)),
+ )
+ async for step in generator:
+ state.step = step
+
+ state.ready_event.set()
+ await state.continue_event.wait()
+
+ state.ready_event.clear()
+ state.continue_event.clear()
+ state.step = None
+
+ state.chat = step.chats[-1] if step.chats else state.chat
- state.ready_event.clear()
- state.continue_event.clear()
- state.step = None
+ async def _score_chats(self, chats: list[Chat]) -> None:
+ if not self.scorers:
+ return
+
+ for scorer in self.scorers:
+ for metric in await asyncio.gather(
+ *[scorer(chat) for chat in chats],
+ ):
+ dn.log_metric(scorer.name, metric)
- state.chat = step.chats[-1] if step.chats else state.chat
+ def _raise_if_failed(
+ self,
+ chats: list[Chat | BaseException] | ChatList,
+ on_failed: FailMode | None = None,
+ ) -> None:
+ for chat in chats:
+ error = chat.error if isinstance(chat, Chat) else chat
+ if error is not None and (
+ on_failed == "raise"
+ or not any(isinstance(error, t) for t in self.errors_to_catch)
+ or any(isinstance(error, t) for t in self.errors_to_exclude)
+ ):
+ raise error
# Run methods
async def _step( # noqa: PLR0915, PLR0912
self,
- span: Span,
messages: list[list[Message]],
params: list[GenerateParams],
on_failed: FailMode,
@@ -1644,8 +1691,16 @@ async def _step( # noqa: PLR0915, PLR0912
# Pass the messages to the generator
try:
- messages = self._apply_cache_mode_to_messages(messages)
- generated = await self.generator.generate_messages(messages, params)
+ messages = apply_cache_mode_to_messages(self.caching, messages)
+
+ with dn.task_span(
+ f"generate - {self.generator.to_identifier(short=True)}",
+ attributes={"rigging.type": "chat_pipeline.generate"},
+ ):
+ dn.log_input("messages", messages)
+ dn.log_input("params", params)
+ generated = await self.generator.generate_messages(messages, params)
+ dn.log_output("generated", generated)
# If we got a total failure here for generation as a whole,
# we can't distinguish between incoming messages in terms
@@ -1653,9 +1708,6 @@ async def _step( # noqa: PLR0915, PLR0912
# on all of them.
except Exception as error: # noqa: BLE001
- span.set_attribute("failed", True)
- span.set_attribute("error", error)
-
chats = ChatList(
[
Chat(
@@ -1724,7 +1776,6 @@ async def _step( # noqa: PLR0915, PLR0912
# Yield what we generated
- span.set_attribute("chats", chats)
current_step = PipelineStep(
state="generated",
chats=chats,
@@ -1734,23 +1785,13 @@ async def _step( # noqa: PLR0915, PLR0912
# Check if we should immediately raise
- for chat in chats:
- if chat.error is not None and (
- on_failed == "raise"
- or not any(isinstance(chat.error, t) for t in self.errors_to_catch)
- or any(isinstance(chat.error, t) for t in self.errors_to_exclude)
- ):
- span.set_attribute("error", chat.error)
- span.set_attribute("failed", True)
- raise chat.error
+ self._raise_if_failed(chats, on_failed)
# Chat cleanup
if on_failed == "skip":
chats = ChatList([chat for chat in chats if not chat.failed])
- span.set_attribute("chats", chats)
-
if len(chats) == 0 or all(chat.failed for chat in chats):
yield PipelineStep(
state="final",
@@ -1761,7 +1802,7 @@ async def _step( # noqa: PLR0915, PLR0912
# Then callbacks
- for then_callback, max_depth in self.then_callbacks:
+ for then_callback, max_depth, as_task in self.then_callbacks:
callback_name = get_qualified_name(then_callback)
states = [
@@ -1773,8 +1814,19 @@ async def _step( # noqa: PLR0915, PLR0912
for chat in chats
]
+ callback_task = (
+ dn.task(
+ name=f"then - {callback_name}",
+ attributes={"rigging.type": "chat_pipeline.then_callback"},
+ log_inputs=True,
+ log_output=True,
+ )(then_callback)
+ if as_task
+ else then_callback
+ )
+
tasks = [
- asyncio.create_task(self._process_then_callback(then_callback, state))
+ asyncio.create_task(self._process_then_callback(callback_task, state)) # type: ignore [arg-type]
for state in states
]
@@ -1818,7 +1870,6 @@ async def _step( # noqa: PLR0915, PLR0912
chats = ChatList([state.chat for state in states if state.chat])
- span.set_attribute("chats", chats)
current_step = PipelineStep(
state="callback",
chats=chats,
@@ -1832,8 +1883,6 @@ async def _step( # noqa: PLR0915, PLR0912
if on_failed == "skip":
chats = ChatList([chat for chat in chats if not chat.failed])
- span.set_attribute("chats", chats)
-
if len(chats) == 0 or all(chat.failed for chat in chats):
yield PipelineStep(
state="final",
@@ -1844,56 +1893,59 @@ async def _step( # noqa: PLR0915, PLR0912
# Map callbacks
- for map_callback, max_depth in self.map_callbacks:
+ for map_callback, max_depth, as_task in self.map_callbacks:
callback_name = get_qualified_name(map_callback)
- with tracer.span(
- f"Map with {callback_name}()",
- callback=callback_name,
- chat_count=len(chats),
- chat_ids=[str(c.uuid) for c in chats],
- ):
- async with contextlib.AsyncExitStack() as exit_stack:
- result = map_callback(chats)
- chats_or_generator = await result if inspect.isawaitable(result) else result
-
- if isinstance(result, contextlib.AbstractAsyncContextManager):
- result = await exit_stack.enter_async_context(result)
-
- if inspect.isasyncgen(chats_or_generator):
- generator = t.cast(
- "PipelineStepGenerator",
- await exit_stack.enter_async_context(
- aclosing(chats_or_generator),
- ),
- )
- async for step in generator:
- _step = step.with_parent(current_step)
- if _step.depth > max_depth:
- max_depth_error = MaxDepthError(
- max_depth,
- _step,
- callback_name,
- )
- if on_failed == "raise":
- raise max_depth_error
+ map_task = (
+ dn.task(
+ name=f"map - {callback_name}",
+ attributes={"rigging.type": "chat_pipeline.map_callback"},
+ log_inputs=True,
+ log_output=True,
+ )(map_callback)
+ if as_task
+ else map_callback
+ )
- chats = ChatList(chats)
- for chat in chats:
- chat.error = max_depth_error
- chat.failed = True
- else:
- yield _step
- chats = step.chats
+ async with contextlib.AsyncExitStack() as exit_stack:
+ result = map_task(chats)
+ if inspect.isawaitable(result):
+ result = await result
- chats = step.chats
+ if isinstance(result, contextlib.AbstractAsyncContextManager):
+ result = await exit_stack.enter_async_context(result)
- elif isinstance(chats_or_generator, list) and all(
- isinstance(c, Chat) for c in chats_or_generator
- ):
- chats = ChatList(chats_or_generator)
+ if inspect.isasyncgen(result):
+ generator = t.cast(
+ "PipelineStepGenerator",
+ await exit_stack.enter_async_context(
+ aclosing(result),
+ ),
+ )
+ async for step in generator:
+ _step = step.with_parent(current_step)
+ if _step.depth > max_depth:
+ max_depth_error = MaxDepthError(
+ max_depth,
+ _step,
+ callback_name,
+ )
+ if on_failed == "raise":
+ raise max_depth_error
+
+ chats = ChatList(chats)
+ for chat in chats:
+ chat.error = max_depth_error
+ chat.failed = True
+ else:
+ yield _step
+ chats = step.chats
+
+ chats = step.chats
+
+ elif isinstance(result, list) and all(isinstance(c, Chat) for c in result):
+ chats = ChatList(result)
- span.set_attribute("chats", chats)
current_step = PipelineStep(
state="callback",
chats=chats,
@@ -1905,7 +1957,6 @@ async def _step( # noqa: PLR0915, PLR0912
if on_failed == "skip":
chats = ChatList([chat for chat in chats if not chat.failed])
- span.set_attribute("chats", chats)
yield PipelineStep(
state="final",
chats=chats,
@@ -1941,19 +1992,15 @@ async def step(
messages = [self.chat.all]
params = self._fit_params(1, [self.params])
- with tracer.span(
- f"Chat with {self.generator.to_identifier()}",
- generator_id=self.generator.to_identifier(),
- params=self.params.to_dict() if self.params is not None else {},
- ) as span:
- async with aclosing(
- self._step(span, messages, params, on_failed),
- ) as generator:
- yield generator
+ async with aclosing(
+ self._step(messages, params, on_failed),
+ ) as generator:
+ yield generator
async def run(
self,
*,
+ name: str | None = None,
on_failed: FailMode | None = None,
allow_failed: bool = False,
) -> Chat:
@@ -1961,6 +2008,7 @@ async def run(
Execute the generation process for a single message.
Args:
+ name: The name of the task for logging purposes.
on_failed: The behavior when a message fails to generate.
allow_failed: Deprecated, use `on_failed="include"`.
@@ -1977,16 +2025,44 @@ async def run(
if on_failed is None:
on_failed = "include" if allow_failed else self.on_failed
+ if on_failed == "skip":
+ raise ValueError(
+ "Cannot use 'skip' mode with single message generation (pass allow_failed=True or on_failed='include'/'raise')",
+ )
+
+ messages = [self.chat.all]
+ params = self._fit_params(1, [self.params])
+
last: PipelineStep | None = None
- async with self.step(on_failed=on_failed) as steps:
- async for step in steps:
- last = step
+ with dn.task_span(
+ name or f"pipeline - {self.task_name}",
+ label=name or f"pipeline_{self.task_name}",
+ attributes={"rigging.type": "chat_pipeline.run"},
+ ) as task:
+ dn.log_inputs(
+ messages=messages[0],
+ params=params[0],
+ generator_id=self.generator.to_identifier(),
+ )
+
+ try:
+ async with aclosing(
+ self._step(messages, params, on_failed),
+ ) as steps:
+ async for step in steps:
+ last = step
+ finally:
+ if last is not None and last.chats:
+ dn.log_output("chat", last.chats[-1])
+ await self._score_chats(last.chats)
+ # TODO: Remove once Strikes UI is ported
+ task.set_attribute("chats", last.chats)
- if last is None or last.state != "final":
- raise RuntimeError("The pipeline did not complete successfully")
+ if last is None or last.state != "final":
+ raise RuntimeError("The pipeline did not complete successfully")
- if not last.chats:
- raise RuntimeError("The pipeline process did not produce any chats")
+ if not last.chats:
+ raise RuntimeError("The pipeline process did not produce any chats")
return last.chats[-1]
@@ -2018,22 +2094,18 @@ async def step_many(
messages = [self.chat.all] * count
params = self._fit_params(count, params)
- with tracer.span(
- f"Chat with {self.generator.to_identifier()} (x{count})",
- count=count,
- generator_id=self.generator.to_identifier(),
- params=self.params.to_dict() if self.params is not None else {},
- ) as span:
- async with aclosing(
- self._step(span, messages, params, on_failed),
- ) as generator:
- yield generator
+ async with aclosing(
+ self._step(messages, params, on_failed),
+ ) as generator:
+ yield generator
async def run_many(
self,
count: int,
*,
params: t.Sequence[GenerateParams | None] | None = None,
+ name: str | None = None,
+ mode: t.Literal["merged", "parallel"] = "parallel",
on_failed: FailMode | None = None,
) -> ChatList:
"""
@@ -2042,24 +2114,111 @@ async def run_many(
Args:
count: The number of times to execute the generation process.
params: A sequence of parameters to be used for each execution.
+ name: The name of the task for logging purposes.
+ mode: The mode of execution, either "merged" or "parallel".
+ - In "merged" mode, a single pipeline manages all generation simultaneously
+ - In "parallel" mode, independent pipelines are created for each generation
on_failed: The behavior when a message fails to generate.
Returns:
- A list of generatated Chats.
+ A list of generated Chats.
"""
+ if count < 1:
+ raise ValueError("Count must be greater than 0")
+
+ on_failed = on_failed or self.on_failed
+
+ messages = [self.chat.all] * count
+ params = self._fit_params(count, params)
last: PipelineStep | None = None
- async with self.step_many(count, params=params, on_failed=on_failed) as steps:
- async for step in steps:
- last = step
+ with dn.task_span(
+ name or f"pipeline - {self.task_name} (x{count})",
+ label=name or f"pipeline_many_{self.task_name}",
+ attributes={"rigging.type": "chat_pipeline.run_many"},
+ ) as task:
+ dn.log_inputs(
+ count=count,
+ messages=messages[0],
+ params=params[0],
+ generator_id=self.generator.to_identifier(),
+ )
+
+ if mode == "merged":
+ try:
+ async with aclosing(
+ self._step(messages, params, on_failed),
+ ) as steps:
+ async for step in steps:
+ last = step
+ finally:
+ if last is not None:
+ dn.log_output("chats", last.chats)
+ await self._score_chats(last.chats)
+ # TODO: Remove once Strikes UI is ported
+ task.set_attribute("chats", last.chats)
+
+ if last is None or last.state != "final":
+ raise RuntimeError("The pipeline did not complete successfully")
+
+ return last.chats
+
+ if mode == "parallel":
+ tasks = [asyncio.create_task(self.run(on_failed="include")) for _ in range(count)]
+ chats_or_errors = await asyncio.gather(*tasks, return_exceptions=True)
+
+ self._raise_if_failed(chats_or_errors, on_failed)
+
+ chats = [
+ chat
+ for chat in chats_or_errors
+ if isinstance(chat, Chat) and (on_failed != "skip" or not chat.failed)
+ ]
- if last is None or last.state != "final":
- raise ValueError("The generation process did not complete successfully")
+ dn.log_output("chats", chats)
+ # TODO: Remove once Strikes UI is ported
+ task.set_attribute("chats", chats)
- return last.chats
+ return ChatList(chats)
+
+ raise ValueError(
+ f"Invalid mode '{mode}', expected 'merged' or 'parallel'",
+ )
# Batch messages
+ def _fit_batch_args(
+ self,
+ many: t.Sequence[t.Sequence[Message]]
+ | t.Sequence[Message]
+ | t.Sequence[MessageDict]
+ | t.Sequence[str]
+ | MessageDict
+ | str,
+ params: t.Sequence[GenerateParams | None] | None = None,
+ ) -> tuple[int, list[list[Message]], list[GenerateParams]]:
+ # Get the maximum of either incoming messages or params
+
+ count = max(len(many), len(params) if params is not None else 0)
+
+ # If we have fewer messages than params, we need to either:
+ # 1. Error because we have >1 messages that we can't reasonably
+ # zip with our parameters of a different length
+ # 2. Duplicate a single message we have len(params) times as the
+ # user is just batching only over parameters
+
+ messages = [[*self.chat.all, *Message.fit_as_list(m)] for m in many]
+ if len(messages) < count:
+ if len(messages) != 1:
+ raise ValueError(
+ f"Can't fit {len(messages)} messages to {count} params",
+ )
+ messages = messages * count
+
+ params = self._fit_params(count, params)
+
+ return count, messages, params
+
@asynccontextmanager
async def step_batch(
self,
@@ -2088,37 +2247,12 @@ async def step_batch(
Pipeline steps.
"""
on_failed = on_failed or self.on_failed
+ _, messages, params = self._fit_batch_args(many, params)
- # Get the maximum of either incoming messages or params
-
- count = max(len(many), len(params) if params is not None else 0)
-
- # If we have less messages than params, we need to either:
- # 1. Error because we have >1 messages that we can't reasonably
- # zip with our parameters of a different length
- # 2. Duplicate a single message we have len(params) times as the
- # user is just batching only over parameters
-
- messages = [[*self.chat.all, *Message.fit_as_list(m)] for m in many]
- if len(messages) < count:
- if len(messages) != 1:
- raise ValueError(
- f"Can't fit {len(messages)} messages to {count} params",
- )
- messages = messages * count
-
- params = self._fit_params(count, params)
-
- with tracer.span(
- f"Chat batch with {self.generator.to_identifier()} ({count})",
- count=count,
- generator_id=self.generator.to_identifier(),
- params=self.params.to_dict() if self.params is not None else {},
- ) as span:
- async with aclosing(
- self._step(span, messages, params, on_failed),
- ) as generator:
- yield generator
+ async with aclosing(
+ self._step(messages, params, on_failed),
+ ) as generator:
+ yield generator
async def run_batch(
self,
@@ -2130,6 +2264,8 @@ async def run_batch(
| str,
params: t.Sequence[GenerateParams | None] | None = None,
*,
+ name: str | None = None,
+ mode: t.Literal["merged", "parallel"] = "parallel",
on_failed: FailMode | None = None,
) -> ChatList:
"""
@@ -2141,21 +2277,76 @@ async def run_batch(
Args:
many: A sequence of sequences of messages to be generated.
params: A sequence of parameters to be used for each set of messages.
+ name: The name of the task for logging purposes.
+ mode: The mode of execution, either "merged" or "parallel".
+ - In "merged" mode, a single pipeline manages all generation simultaneously
+ - In "parallel" mode, independent pipelines are created for each generation
on_failed: The behavior when a message fails to generate.
Returns:
- A list of generatated Chats.
+ A list of generated Chats.
"""
+ on_failed = on_failed or self.on_failed
+ count, messages, params = self._fit_batch_args(many, params)
last: PipelineStep | None = None
- async with self.step_batch(many, params=params, on_failed=on_failed) as steps:
- async for step in steps:
- last = step
+ with dn.task_span(
+ name or f"pipeline - {self.task_name} (batch x{count})",
+ label=name or f"pipeline_batch_{self.task_name}",
+ attributes={"rigging.type": "chat_pipeline.run_batch"},
+ ) as task:
+ dn.log_inputs(
+ count=count,
+ messages=messages,
+ params=params,
+ generator_id=self.generator.to_identifier(),
+ )
+
+ if mode == "merged":
+ try:
+ async with aclosing(
+ self._step(messages, params, on_failed),
+ ) as steps:
+ async for step in steps:
+ last = step
+ finally:
+ if last is not None:
+ dn.log_output("chats", last.chats)
+ await self._score_chats(last.chats)
+ # TODO: Remove once Strikes UI is ported
+ task.set_attribute("chats", last.chats)
+
+ if last is None or last.state != "final":
+ raise RuntimeError("The pipeline did not complete successfully")
+
+ return last.chats
+
+ if mode == "parallel":
+ tasks = [
+ asyncio.create_task(
+ self.clone().add(_messages).with_(_params).run(on_failed="include")
+ )
+ for _messages, _params in zip(messages, params, strict=True)
+ ]
+ chats_or_errors = await asyncio.gather(*tasks, return_exceptions=True)
+
+ self._raise_if_failed(chats_or_errors, on_failed)
- if last is None or last.state != "final":
- raise ValueError("The generation process did not complete successfully")
+ chats = [
+ chat
+ for chat in chats_or_errors
+ if isinstance(chat, Chat) and (on_failed != "skip" or not chat.failed)
+ ]
- return last.chats
+ dn.log_output("chats", chats)
+ # TODO: Remove once Strikes UI is ported
+ task.set_attribute("chats", chats)
+
+ return ChatList(chats)
+
+ raise ValueError(
+ f"Invalid mode '{mode}', expected 'merged' or 'parallel'",
+ )
# Generator iteration
@@ -2193,7 +2384,15 @@ async def run_over(
sub.generator = generator
coros.append(sub.run(allow_failed=(on_failed != "raise")))
- with tracer.span(f"Chat over {len(coros)} generators", count=len(coros)):
+ short_generators = [g.to_identifier(short=True) for g in _generators]
+ task_name = "iterate - " + ", ".join(short_generators)
+
+ with dn.task_span(
+ task_name,
+ label="iterate_over",
+ attributes={"rigging.type": "chat_pipeline.run_over"},
+ ):
+ dn.log_input("generators", [g.to_identifier() for g in _generators])
return ChatList(await asyncio.gather(*coros))
# Prompt binding
diff --git a/rigging/completion.py b/rigging/completion.py
index 4ab1e07..89250bb 100644
--- a/rigging/completion.py
+++ b/rigging/completion.py
@@ -11,6 +11,7 @@
from typing import runtime_checkable
from uuid import UUID, uuid4
+import dreadnode as dn
from loguru import logger
from pydantic import BaseModel, ConfigDict, Field, computed_field
@@ -18,7 +19,6 @@
from rigging.generator import GenerateParams, Generator, get_generator
from rigging.generator.base import GeneratedText, StopReason, Usage
from rigging.parsing import parse_many
-from rigging.tracing import Span, tracer
from rigging.util import get_qualified_name
if t.TYPE_CHECKING:
@@ -571,20 +571,19 @@ def _until_parse_callback(self, text: str) -> bool:
return False
async def _watch_callback(self, completions: list[Completion]) -> None:
- def wrap_watch_callback(callback: WatchCompletionCallback) -> WatchCompletionCallback:
- async def traced_watch_callback(completions: list[Completion]) -> None:
- callback_name = get_qualified_name(callback)
- with tracer.span(
- f"Watch with {callback_name}()",
- callback=callback_name,
- competion_count=len(completions),
- completion_ids=[str(c.uuid) for c in completions],
- ):
- await callback(completions)
-
- return traced_watch_callback
-
- coros = [wrap_watch_callback(callback)(completions) for callback in self.watch_callbacks]
+ def wrap_watch_callback(
+ callback: WatchCompletionCallback,
+ ) -> t.Callable[[list[Completion]], t.Awaitable[None]]:
+ callback_name = get_qualified_name(callback)
+ return dn.task(
+ name=f"watch - {callback_name}",
+ attributes={"rigging.type": "completion_pipeline.watch_callback"},
+ log_inputs=True,
+ log_output=False,
+ )(callback)
+
+ traced_callbacks = [wrap_watch_callback(callback) for callback in self.watch_callbacks]
+ coros = [callback(completions) for callback in traced_callbacks]
await asyncio.gather(*coros)
# TODO: It's opaque exactly how we should blend multiple
@@ -633,37 +632,32 @@ async def _post_run(
for map_callback in self.map_callbacks:
callback_name = get_qualified_name(map_callback)
- with tracer.span(
- f"Map with {callback_name}()",
- callback=callback_name,
- completion_count=len(completions),
- completion_ids=[str(c.uuid) for c in completions],
- ):
- completions = await map_callback(completions)
- if not all(isinstance(c, Completion) for c in completions):
- raise ValueError(
- f".map() callback must return a Completion object or None ({callback_name})",
- )
-
- def wrap_then_callback(callback: ThenCompletionCallback) -> ThenCompletionCallback:
- callback_name = get_qualified_name(callback)
-
- async def traced_then_callback(completion: Completion) -> Completion | None:
- with tracer.span(
- f"Then with {callback_name}()",
- callback=callback_name,
- completion_id=str(completion.uuid),
- ):
- return await callback(completion)
-
- return traced_then_callback
+ traced_map_callback = dn.task(
+ name=f"map - {callback_name}",
+ attributes={"rigging.type": "completion_pipeline.map_callback"},
+ log_inputs=True,
+ log_output=True,
+ )(map_callback)
+ completions = await traced_map_callback(completions)
+ if not all(isinstance(c, Completion) for c in completions):
+ raise ValueError(
+ f".map() callback must return a Completion object or None ({callback_name})",
+ )
for then_callback in self.then_callbacks:
- coros = [wrap_then_callback(then_callback)(completion) for completion in completions]
+ callback_name = get_qualified_name(then_callback)
+ traced_then_callback = dn.task(
+ name=f"then - {callback_name}",
+ attributes={"rigging.type": "completion_pipeline.then_callback"},
+ log_inputs=True,
+ log_output=True,
+ )(then_callback)
+
+ coros = [traced_then_callback(completion) for completion in completions]
new_completions = await asyncio.gather(*coros)
if not all(isinstance(c, Completion) or c is None for c in new_completions):
raise ValueError(
- f".then() callback must return a Completion object or None ({get_qualified_name(then_callback)})",
+ f".then() callback must return a Completion object or None ({callback_name})",
)
completions = [
@@ -729,7 +723,7 @@ def _initialize_states(
async def _run( # noqa: PLR0912
self,
- span: Span,
+ span: dn.Span,
states: list[RunState],
on_failed: "FailMode",
batch_mode: bool = False, # noqa: FBT001, FBT002
@@ -805,12 +799,10 @@ async def _run( # noqa: PLR0912
for state in to_watch_states:
state.watched = True
- completions = await self._post_run(
+ return await self._post_run(
[s.completion for s in states if s.completion is not None],
on_failed,
)
- span.set_attribute("completions", completions)
- return completions
async def run(
self,
@@ -841,12 +833,21 @@ async def run(
on_failed = on_failed or self.on_failed
states = self._initialize_states(1)
- with tracer.span(
- f"Completion with {self.generator.to_identifier()}",
- generator_id=self.generator.to_identifier(),
- params=self.params.to_dict() if self.params is not None else {},
- ) as span:
- return (await self._run(span, states, on_failed))[0]
+ with dn.task_span(
+ f"pipeline - {self.generator.to_identifier(short=True)}",
+ label=f"pipeline_{self.generator.to_identifier(short=True)}",
+ attributes={"rigging.type": "completion_pipeline.run"},
+ ) as task:
+ dn.log_inputs(
+ text=self.text,
+ params=self.params.to_dict() if self.params is not None else {},
+ generator_id=self.generator.to_identifier(),
+ )
+ completions = await self._run(task, states, on_failed)
+ completion = completions[0]
+ dn.log_output("completion", completion)
+ task.set_attribute("completions", completions)
+ return completion
__call__ = run
@@ -873,13 +874,21 @@ async def run_many(
on_failed = on_failed or self.on_failed
states = self._initialize_states(count, params)
- with tracer.span(
- f"Completion with {self.generator.to_identifier()} (x{count})",
- count=count,
- generator_id=self.generator.to_identifier(),
- params=self.params.to_dict() if self.params is not None else {},
- ) as span:
- return await self._run(span, states, on_failed)
+ with dn.task_span(
+ f"pipeline - {self.generator.to_identifier(short=True)} (x{count})",
+ label=f"pipeline_many_{self.generator.to_identifier(short=True)}",
+ attributes={"rigging.type": "completion_pipeline.run_many"},
+ ) as task:
+ dn.log_inputs(
+ count=count,
+ text=self.text,
+ params=self.params.to_dict() if self.params is not None else {},
+ generator_id=self.generator.to_identifier(),
+ )
+ completions = await self._run(task, states, on_failed)
+ dn.log_output("completions", completions)
+ task.set_attribute("completions", completions)
+ return completions
# Batch completions
@@ -913,13 +922,21 @@ async def run_batch(
for state in states:
next(state.processor)
- with tracer.span(
- f"Completion batch with {self.generator.to_identifier()} ({len(states)})",
- count=len(states),
- generator_id=self.generator.to_identifier(),
- params=self.params.to_dict() if self.params is not None else {},
- ) as span:
- return await self._run(span, states, on_failed, batch_mode=True)
+ with dn.task_span(
+ f"pipeline - {self.generator.to_identifier(short=True)} (batch x{len(states)})",
+ label=f"pipeline_batch_{self.generator.to_identifier(short=True)}",
+ attributes={"rigging.type": "completion_pipeline.run_batch"},
+ ) as task:
+ dn.log_inputs(
+ count=len(states),
+ many=many,
+ params=params,
+ generator_id=self.generator.to_identifier(),
+ )
+ completions = await self._run(task, states, on_failed, batch_mode=True)
+ dn.log_output("completions", completions)
+ task.set_attribute("completions", completions)
+ return completions
# Generator iteration
@@ -957,6 +974,17 @@ async def run_over(
sub.generator = generator
coros.append(sub.run(allow_failed=(on_failed != "raise")))
- with tracer.span(f"Completion over {len(coros)} generators", count=len(coros)):
+ short_generators = [g.to_identifier(short=True) for g in _generators]
+ task_name = "iterate - " + ", ".join(short_generators)
+
+ with dn.task_span(
+ task_name,
+ label="iterate_over",
+ attributes={"rigging.type": "completion_pipeline.run_over"},
+ ) as task:
+ dn.log_input("generators", [g.to_identifier() for g in _generators])
completions = await asyncio.gather(*coros)
- return await self._post_run(completions, on_failed)
+ final_completions = await self._post_run(completions, on_failed)
+ dn.log_output("completions", final_completions)
+ task.set_attribute("completions", final_completions)
+ return final_completions
diff --git a/rigging/generator/__init__.py b/rigging/generator/__init__.py
index 8acf3a8..9cb346a 100644
--- a/rigging/generator/__init__.py
+++ b/rigging/generator/__init__.py
@@ -2,6 +2,8 @@
Generators produce completions for a given set of messages or text.
"""
+import typing as t
+
from rigging.generator.base import (
GeneratedMessage,
GeneratedText,
@@ -68,5 +70,12 @@ def get_transformers_lazy() -> type[Generator]:
"get_generator",
"get_identifier",
"register_generator",
- # TODO: We can't add VLLM and Transformers here because they are lazy loaded
]
+
+
+def __getattr__(name: str) -> t.Any:
+ if name == "VLLMGenerator":
+ return get_vllm_lazy()
+ if name == "TransformersGenerator":
+ return get_transformers_lazy()
+ raise AttributeError(f"module {__name__} has no attribute {name}")
diff --git a/rigging/generator/__init__.pyi b/rigging/generator/__init__.pyi
new file mode 100644
index 0000000..027a502
--- /dev/null
+++ b/rigging/generator/__init__.pyi
@@ -0,0 +1,36 @@
+from rigging.generator.base import (
+ GeneratedMessage,
+ GeneratedText,
+ GenerateParams,
+ Generator,
+ StopReason,
+ Usage,
+ chat,
+ complete,
+ get_generator,
+ get_identifier,
+ register_generator,
+)
+from rigging.generator.http import HTTPGenerator
+from rigging.generator.litellm_ import LiteLLMGenerator
+from rigging.generator.transformers_ import TransformersGenerator
+from rigging.generator.vllm_ import VLLMGenerator
+
+__all__ = [
+ "GenerateParams",
+ "GeneratedMessage",
+ "GeneratedText",
+ "Generator",
+ "HTTPGenerator",
+ "LiteLLMGenerator",
+ "StopReason",
+ "TransformersGenerator",
+ "Usage",
+ "VLLMGenerator",
+ "chat",
+ "complete",
+ "get_generator",
+ "get_generator",
+ "get_identifier",
+ "register_generator",
+]
diff --git a/rigging/generator/base.py b/rigging/generator/base.py
index b612efa..336b15d 100644
--- a/rigging/generator/base.py
+++ b/rigging/generator/base.py
@@ -7,7 +7,14 @@
from functools import lru_cache
from loguru import logger
-from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, TypeAdapter, field_validator
+from pydantic import (
+ BaseModel,
+ BeforeValidator,
+ ConfigDict,
+ Field,
+ TypeAdapter,
+ field_validator,
+)
from typing_extensions import Self
from rigging.error import InvalidGeneratorError
@@ -123,7 +130,7 @@ async def wrapper(
return result
- return wrapper # type: ignore[return-value]
+ return wrapper # type: ignore [return-value]
return decorator
@@ -144,7 +151,7 @@ class GenerateParams(BaseModel):
Use the `extra` field to pass additional parameters to the API.
"""
- model_config = ConfigDict(extra="forbid")
+ model_config = ConfigDict(extra="forbid", json_schema_extra={"rigging.type": "generate_params"})
temperature: float | None = None
"""The sampling temperature."""
@@ -294,6 +301,15 @@ class Usage(BaseModel):
total_tokens: int
"""The total number of tokens processed."""
+ def __add__(self, other: "Usage") -> "Usage":
+ if not isinstance(other, Usage):
+ raise TypeError(f"Cannot add {type(other)} to Usage")
+ return Usage(
+ input_tokens=self.input_tokens + other.input_tokens,
+ output_tokens=self.output_tokens + other.output_tokens,
+ total_tokens=self.total_tokens + other.total_tokens,
+ )
+
GeneratedT = t.TypeVar("GeneratedT", Message, str)
@@ -301,6 +317,8 @@ class Usage(BaseModel):
class GeneratedMessage(BaseModel):
"""A generated message with additional generation information."""
+ model_config = ConfigDict(json_schema_extra={"rigging.type": "generated_message"})
+
message: Message
"""The generated message."""
@@ -364,6 +382,8 @@ class Generator(BaseModel):
- `generate_texts`: Process a batch of texts.
"""
+ model_config = ConfigDict(json_schema_extra={"rigging.type": "generator"})
+
model: str
"""The model name to be used by the generator."""
api_key: str | None = Field(None, exclude=True)
@@ -374,7 +394,7 @@ class Generator(BaseModel):
_watch_callbacks: list["WatchChatCallback | WatchCompletionCallback"] = []
_wrap: t.Callable[[CallableT], CallableT] | None = None
- def to_identifier(self, params: GenerateParams | None = None) -> str:
+ def to_identifier(self, params: GenerateParams | None = None, *, short: bool = False) -> str:
"""
Converts the generator instance back into a rigging identifier string.
@@ -386,7 +406,7 @@ def to_identifier(self, params: GenerateParams | None = None) -> str:
Returns:
The identifier string.
"""
- return get_identifier(self, params)
+ return get_identifier(self, params, short=short)
def watch(
self,
@@ -651,7 +671,9 @@ def complete(
return generator.complete(text, params)
-def get_identifier(generator: Generator, params: GenerateParams | None = None) -> str:
+def get_identifier(
+ generator: Generator, params: GenerateParams | None = None, *, short: bool = False
+) -> str:
"""
Converts the generator instance back into a rigging identifier string.
@@ -671,7 +693,10 @@ def get_identifier(generator: Generator, params: GenerateParams | None = None) -
for name, klass in g_generators.items()
if isinstance(klass, type) and isinstance(generator, klass)
)
- identifier = f"{provider}!{generator.model}"
+ identifier = f"{provider}!{generator.model}" if provider != "litellm" else generator.model
+
+ if short:
+ return identifier
identifier_extra = generator.model_dump(
exclude_unset=True,
diff --git a/rigging/generator/http.py b/rigging/generator/http.py
index 70db36c..d2eac0d 100644
--- a/rigging/generator/http.py
+++ b/rigging/generator/http.py
@@ -328,7 +328,7 @@ async def _generate_message(
all_content="\n".join(m.content for m in messages),
messages=[m.to_openai() for m in messages],
params=params.to_dict(),
- api_key=self.api_key,
+ api_key=self.api_key or "",
model=self.model,
)
diff --git a/rigging/generator/litellm_.py b/rigging/generator/litellm_.py
index e5f5612..a82d809 100644
--- a/rigging/generator/litellm_.py
+++ b/rigging/generator/litellm_.py
@@ -4,6 +4,7 @@
import re
import typing as t
+import dreadnode as dn
import litellm
import litellm.types.utils
from loguru import logger
@@ -20,7 +21,6 @@
)
from rigging.message import ContentAudioInput, ContentImageUrl, ContentText, Message
from rigging.tools.base import FunctionDefinition, ToolDefinition
-from rigging.tracing import tracer
# We should probably let people configure
# this independently, but for now we'll
@@ -172,7 +172,7 @@ async def supports_function_calling(self) -> bool | None:
# Otherwise we'll run a small check to see if we can
- with tracer.span(f"Checking '{self.model}' for function calling support") as span:
+ with dn.span(f"Checking '{self.model}' for function calling support") as span:
try:
generated = await self.generate_messages(
[[Message(role="user", content="Call the test function")]],
diff --git a/rigging/generator/transformers_.py b/rigging/generator/transformers_.py
index 9eb08a8..40276f8 100644
--- a/rigging/generator/transformers_.py
+++ b/rigging/generator/transformers_.py
@@ -80,14 +80,14 @@ def llm(self) -> AutoModelForCausalLM:
"load_in_4bit",
},
)
- self._llm = AutoModelForCausalLM.from_pretrained(self.model, **llm_kwargs) # type: ignore [no-untyped-call]
+ self._llm = AutoModelForCausalLM.from_pretrained(self.model, **llm_kwargs) # type: ignore [no-untyped-call] # nosec
return self._llm
@property
def tokenizer(self) -> AutoTokenizer:
"""The underlying `AutoTokenizer` instance."""
if self._tokenizer is None:
- self._tokenizer = AutoTokenizer.from_pretrained(self.model)
+ self._tokenizer = AutoTokenizer.from_pretrained(self.model) # nosec
return self._tokenizer
@property
@@ -122,7 +122,7 @@ def from_obj(
Returns:
The TransformersGenerator instance.
"""
- instance = cls(model=model, params=params or GenerateParams())
+ instance = cls(model=model, params=params or GenerateParams(), api_key=None)
instance._llm = model
instance._tokenizer = tokenizer
instance._pipeline = pipeline
diff --git a/rigging/generator/vllm_.py b/rigging/generator/vllm_.py
index a2a0e0e..12674b5 100644
--- a/rigging/generator/vllm_.py
+++ b/rigging/generator/vllm_.py
@@ -90,7 +90,7 @@ def from_obj(
Returns:
The VLLMGenerator instance.
"""
- generator = cls(model=model, params=params or GenerateParams())
+ generator = cls(model=model, params=params or GenerateParams(), api_key=None)
generator._llm = llm
return generator
@@ -142,7 +142,7 @@ def _generate(
return [
GeneratedText(
text=o.outputs[-1].text,
- stop_reason=o.outputs[-1].finish_reason,
+ stop_reason=o.outputs[-1].finish_reason or "unknown",
extra={
"request_id": o.request_id,
"metrics": o.metrics,
diff --git a/rigging/message.py b/rigging/message.py
index 4020658..6741817 100644
--- a/rigging/message.py
+++ b/rigging/message.py
@@ -121,14 +121,16 @@ def clone(self) -> "MessageSlice":
Returns:
A new MessageSlice instance with the same properties.
"""
- return MessageSlice(
+ cloned = MessageSlice(
type=self.type,
obj=self.obj,
start=self.start,
stop=self.stop,
metadata=copy.deepcopy(self.metadata),
- _message=self._message, # Keep the reference to the original message
)
+ # Leaving this detached to align with tests
+ # cloned._message = self._message
+ return cloned # noqa: RET504
class ContentText(BaseModel):
@@ -405,7 +407,9 @@ class Message(BaseModel):
`content_parts` to `content` for compatibility.
"""
- model_config = ConfigDict(serialize_by_alias=True)
+ model_config = ConfigDict(
+ serialize_by_alias=True, json_schema_extra={"rigging.type": "message"}
+ )
uuid: UUID = Field(default_factory=uuid4, repr=False)
"""The unique identifier for the message."""
diff --git a/rigging/prompt.py b/rigging/prompt.py
index 70bfb4c..7491c0a 100644
--- a/rigging/prompt.py
+++ b/rigging/prompt.py
@@ -9,6 +9,7 @@
import typing as t
from collections import OrderedDict
+import dreadnode as dn
from jinja2 import Environment, StrictUndefined, meta
from pydantic import ValidationError
from typing_extensions import Concatenate, ParamSpec # noqa: UP035
@@ -25,7 +26,6 @@
from rigging.message import Message
from rigging.model import Model, SystemErrorModel, ValidationErrorModel, make_primitive
from rigging.tools import Tool
-from rigging.tracing import tracer
from rigging.util import escape_xml, get_qualified_name, to_snake, to_xml_tag
DEFAULT_DOC = "Convert the following inputs to outputs ({func_name})."
@@ -561,7 +561,12 @@ async def _then_parse(self, chat: Chat) -> PipelineStepContextManager | None:
try:
# A bit weird, but we need from_chat to properly handle
# wrapping Chat output types inside lists/dataclasses
- self.output.from_chat(chat)
+ with dn.task_span(
+ f"prompt parse - {self.output.tag}",
+ attributes={"rigging.type": "prompt.parse"},
+ ):
+ dn.log_input("message", chat.last)
+ self.output.from_chat(chat)
except ValidationError as e:
next_pipeline.add(
Message.from_model(
@@ -836,10 +841,36 @@ def say_hello(name: str) -> str:
raise NotImplementedError(
"pipeline.on_failed='skip' cannot be used for prompt methods that return one object",
)
+ if pipeline.on_failed == "include" and not isinstance(self.output, ChatOutput):
+ raise NotImplementedError(
+ "pipeline.on_failed='include' cannot be used with prompts that process outputs",
+ )
async def run(*args: P.args, **kwargs: P.kwargs) -> R:
- results = await self.bind_many(pipeline)(1, *args, **kwargs)
- return results[0]
+ name = get_qualified_name(self.func) if self.func else ""
+ with dn.task_span(
+ f"prompt - {name}",
+ attributes={"prompt_name": name, "rigging.type": "prompt.run"},
+ ):
+ dn.log_inputs(**self._bind_args(*args, **kwargs))
+ content = self.render(*args, **kwargs)
+ _pipeline = (
+ pipeline.fork(content)
+ .using(*self.tools, max_depth=self.max_tool_rounds)
+ .then(self._then_parse, max_depth=self.max_parsing_rounds, as_task=False)
+ .then(*self.then_callbacks)
+ .map(*self.map_callbacks)
+ .watch(*self.watch_callbacks)
+ .with_(self.params)
+ )
+
+ if self.system_prompt:
+ _pipeline.chat.inject_system_content(self.system_prompt)
+
+ chat = await _pipeline.run()
+ output = self.process(chat)
+ dn.log_output("output", output)
+ return output
run.__signature__ = self.__signature__ # type: ignore [attr-defined]
run.__name__ = self.__name__
@@ -867,7 +898,7 @@ def bind_many(
def say_hello(name: str) -> str:
\"""Say hello to {{ name }}\"""
- await say_hello.bind("gpt-3.5-turbo")(5, "the world")
+ await say_hello.bind_many("gpt-4.1")(5, "the world")
```
"""
pipeline = self._resolve_to_pipeline(other)
@@ -878,17 +909,17 @@ def say_hello(name: str) -> str:
async def run_many(count: int, /, *args: P.args, **kwargs: P.kwargs) -> list[R]:
name = get_qualified_name(self.func) if self.func else ""
- with tracer.span(
- f"Prompt {name}()" if count == 1 else f"Prompt {name}() (x{count})",
- count=count,
- name=name,
- arguments=self._bind_args(*args, **kwargs),
+ with dn.task_span(
+ f"prompt - {name} (x{count})",
+ label=f"prompt_{name}",
+ attributes={"prompt_name": name, "rigging.type": "prompt.run_many"},
) as span:
+ dn.log_inputs(**self._bind_args(*args, **kwargs))
content = self.render(*args, **kwargs)
_pipeline = (
pipeline.fork(content)
.using(*self.tools, max_depth=self.max_tool_rounds)
- .then(self._then_parse, max_depth=self.max_parsing_rounds)
+ .then(self._then_parse, max_depth=self.max_parsing_rounds, as_task=False)
.then(*self.then_callbacks)
.map(*self.map_callbacks)
.watch(*self.watch_callbacks)
@@ -899,35 +930,13 @@ async def run_many(count: int, /, *args: P.args, **kwargs: P.kwargs) -> list[R]:
_pipeline.chat.inject_system_content(self.system_prompt)
chats = await _pipeline.run_many(count)
-
- # TODO: I can't remember why we don't just pass the watch_callbacks to the pipeline
- # Maybe it has something to do with uniqueness and merging?
-
- def wrap_watch_callback(callback: "WatchChatCallback") -> "WatchChatCallback":
- async def traced_watch_callback(chats: list[Chat]) -> None:
- callback_name = get_qualified_name(callback)
- with tracer.span(
- f"Watch with {callback_name}()",
- callback=callback_name,
- chat_count=len(chats),
- chat_ids=[str(c.uuid) for c in chats],
- ):
- await callback(chats)
-
- return traced_watch_callback
-
- coros = [
- wrap_watch_callback(watch)(chats)
- for watch in self.watch_callbacks
- if watch not in pipeline.watch_callbacks
- ]
- await asyncio.gather(*coros)
-
- results = [self.process(chat) for chat in chats]
- span.set_attribute("results", results)
- return results
+ outputs = [self.process(chat) for chat in chats]
+ span.log_output("outputs", outputs)
+ return outputs
run_many.__rg_prompt__ = self # type: ignore [attr-defined]
+ run_many.__name__ = self.__name__
+ run_many.__doc__ = self.__doc__
return run_many
@@ -953,7 +962,7 @@ def bind_over(
def say_hello(name: str) -> str:
\"""Say hello to {{ name }}\"""
- await say_hello.bind("gpt-3.5-turbo")(["gpt-4o", "gpt-4"], "the world")
+ await say_hello.bind_over()(["gpt-4o", "gpt-4.1", "o4-mini"], "the world")
```
"""
include_original = other is not None
@@ -980,7 +989,7 @@ async def run_over(
_pipeline = (
pipeline.fork(content)
.using(*self.tools, max_depth=self.max_tool_rounds)
- .then(self._then_parse, max_depth=self.max_parsing_rounds)
+ .then(self._then_parse, max_depth=self.max_parsing_rounds, as_task=False)
.then(*self.then_callbacks)
.map(*self.map_callbacks)
.watch(*self.watch_callbacks)
@@ -992,13 +1001,6 @@ async def run_over(
chats = await _pipeline.run_over(*generators, include_original=include_original)
- coros = [
- watch(chats)
- for watch in self.watch_callbacks
- if watch not in pipeline.watch_callbacks
- ]
- await asyncio.gather(*coros)
-
return [self.process(chat) for chat in chats]
run_over.__rg_prompt__ = self # type: ignore [attr-defined]
diff --git a/rigging/tokenizer/__init__.py b/rigging/tokenizer/__init__.py
index 4180205..a8c2b02 100644
--- a/rigging/tokenizer/__init__.py
+++ b/rigging/tokenizer/__init__.py
@@ -2,6 +2,8 @@
Tokenizers encode chats and associated message data into tokens for training and inference.
"""
+import typing as t
+
from rigging.tokenizer.base import (
TokenizedChat,
Tokenizer,
@@ -31,3 +33,9 @@ def get_transformers_lazy() -> type[Tokenizer]:
"get_tokenizer",
"register_tokenizer",
]
+
+
+def __getattr__(name: str) -> t.Any:
+ if name == "TransformersTokenizer":
+ return get_transformers_lazy()
+ raise AttributeError(f"module {__name__} has no attribute {name}")
diff --git a/rigging/tokenizer/__init__.pyi b/rigging/tokenizer/__init__.pyi
new file mode 100644
index 0000000..1c74d9a
--- /dev/null
+++ b/rigging/tokenizer/__init__.pyi
@@ -0,0 +1,17 @@
+from rigging.tokenizer.base import (
+ TokenizedChat,
+ Tokenizer,
+ TokenSlice,
+ get_tokenizer,
+ register_tokenizer,
+)
+from rigging.tokenizer.transformers_ import TransformersTokenizer
+
+__all__ = [
+ "TokenSlice",
+ "TokenizedChat",
+ "Tokenizer",
+ "TransformersTokenizer",
+ "get_tokenizer",
+ "register_tokenizer",
+]
diff --git a/rigging/tokenizer/transformers_.py b/rigging/tokenizer/transformers_.py
index ec24aef..8f2616c 100644
--- a/rigging/tokenizer/transformers_.py
+++ b/rigging/tokenizer/transformers_.py
@@ -34,7 +34,7 @@ class TransformersTokenizer(Tokenizer):
def tokenizer(self) -> "PreTrainedTokenizer":
"""The underlying `PreTrainedTokenizer` instance."""
if self._tokenizer is None:
- self._tokenizer = AutoTokenizer.from_pretrained(self.model)
+ self._tokenizer = AutoTokenizer.from_pretrained(self.model) # nosec
return self._tokenizer
@classmethod
diff --git a/rigging/tools/base.py b/rigging/tools/base.py
index d6ff410..f8d8a6f 100644
--- a/rigging/tools/base.py
+++ b/rigging/tools/base.py
@@ -8,11 +8,19 @@
import re
import typing as t
import warnings
-from dataclasses import dataclass, field
from functools import cached_property
+import dreadnode as dn
import typing_extensions as te
-from pydantic import BaseModel, TypeAdapter, ValidationError, field_validator
+from pydantic import (
+ BaseModel,
+ ConfigDict,
+ Field,
+ PrivateAttr,
+ TypeAdapter,
+ ValidationError,
+ field_validator,
+)
from pydantic_xml import attr
from rigging.error import Stop, ToolDefinitionError, ToolWarning
@@ -23,7 +31,6 @@
make_from_schema,
make_from_signature,
)
-from rigging.tracing import tracer
from rigging.util import deref_json
if t.TYPE_CHECKING:
@@ -103,6 +110,8 @@ class FunctionCall(BaseModel):
class ToolCall(BaseModel):
+ model_config = ConfigDict(json_schema_extra={"rigging.type": "tool_call"})
+
id: str
type: t.Literal["function"] = "function"
function: FunctionCall
@@ -133,8 +142,7 @@ def _is_unbound_method(func: t.Any) -> bool:
return is_method is not hasattr(func, "__self__")
-@dataclass
-class Tool(t.Generic[P, R]):
+class Tool(BaseModel, t.Generic[P, R]):
"""Base class for representing a tool to a generator."""
name: str
@@ -143,7 +151,10 @@ class Tool(t.Generic[P, R]):
"""A description of the tool."""
parameters_schema: dict[str, t.Any]
"""The JSON schema for the tool's parameters."""
- fn: t.Callable[P, R]
+ fn: t.Callable[P, R] = Field( # type: ignore [assignment]
+ default_factory=lambda: lambda *args, **kwargs: None, # noqa: ARG005
+ exclude=True,
+ )
"""The function to call."""
catch: bool | set[type[Exception]] = False
"""
@@ -156,13 +167,9 @@ class Tool(t.Generic[P, R]):
truncate: int | None = None
"""If set, the maximum number of characters to truncate any tool output to."""
- _signature: inspect.Signature | None = field(default=None, init=False, repr=False)
- _type_adapter: TypeAdapter[t.Any] | None = field(
- default=None,
- init=False,
- repr=False,
- )
- _model: type[Model] | None = field(default=None, init=False, repr=False)
+ _signature: inspect.Signature | None = PrivateAttr(default=None, init=False)
+ _type_adapter: TypeAdapter[t.Any] | None = PrivateAttr(default=None, init=False)
+ _model: type[Model] | None = PrivateAttr(default=None, init=False)
# In general we are split between 2 strategies for handling the data translations:
#
@@ -281,7 +288,7 @@ def empty_func(*args, **kwargs): # type: ignore [no-untyped-def] # noqa: ARG001
)
self._signature = signature
- self.__signature__ = signature # type: ignore [attr-defined]
+ self.__signature__ = signature # type: ignore [misc]
self.__name__ = self.name # type: ignore [attr-defined]
self.__doc__ = self.description
@@ -351,7 +358,12 @@ async def handle_tool_call( # noqa: PLR0912
from rigging.message import ContentText, ContentTypes, Message
- with tracer.span(f"Tool {self.name}()", name=self.name) as span:
+ with dn.task_span(
+ f"tool - {self.name}",
+ attributes={"tool_name": self.name, "rigging.type": "tool"},
+ ) as task:
+ dn.log_input("tool_call", tool_call)
+
if tool_call.name != self.name:
warnings.warn(
f"Tool call name mismatch: {tool_call.name} != {self.name}",
@@ -361,7 +373,7 @@ async def handle_tool_call( # noqa: PLR0912
return Message.from_model(SystemErrorModel(content="Invalid tool call.")), True
if hasattr(tool_call, "id") and isinstance(tool_call.id, str):
- span.set_attribute("tool_call_id", tool_call.id)
+ task.set_attribute("tool_call_id", tool_call.id)
result: t.Any
stop = False
@@ -372,8 +384,9 @@ async def handle_tool_call( # noqa: PLR0912
kwargs = json.loads(tool_call.function.arguments)
if self._type_adapter is not None:
kwargs = self._type_adapter.validate_python(kwargs)
- span.set_attribute("arguments", kwargs)
+ dn.log_inputs(**kwargs)
except (json.JSONDecodeError, ValidationError) as e:
+ task.set_exception(e)
result = ErrorModel.from_exception(e)
# Call the function
@@ -388,17 +401,18 @@ async def handle_tool_call( # noqa: PLR0912
raise result # noqa: TRY301
except Stop as e:
result = f"<{TOOL_STOP_TAG}>{e.message}</{TOOL_STOP_TAG}>"
- span.set_attribute("stop", True)
+ task.set_attribute("stop", True)
stop = True
except Exception as e:
if self.catch is True or (
not isinstance(self.catch, bool) and isinstance(e, tuple(self.catch))
):
+ task.set_exception(e)
result = ErrorModel.from_exception(e)
else:
raise
- span.set_attribute("result", result)
+ dn.log_output("output", result)
message = Message(role="tool", tool_call_id=tool_call.id)
@@ -521,7 +535,7 @@ def __get__(self, instance: t.Any, owner: t.Any) -> "Tool[P, R]":
catch=self.catch,
)
- bound_tool.__signature__ = self.__signature__ # type: ignore [attr-defined]
+ bound_tool.__signature__ = self.__signature__ # type: ignore [misc]
bound_tool._signature = self._signature # noqa: SLF001
bound_tool._type_adapter = self._type_adapter # noqa: SLF001
bound_tool._model = self._model # noqa: SLF001
diff --git a/rigging/tools/robopages.py b/rigging/tools/robopages.py
index a84b08f..e3fe335 100644
--- a/rigging/tools/robopages.py
+++ b/rigging/tools/robopages.py
@@ -92,10 +92,10 @@ def robopages(url: str, *, name_filter: str | None = None) -> list[Tool[..., t.A
tools.append(
Tool(
- function.name,
- function.description or "",
- function.parameters or {},
- make_execute_on_server(url, function.name),
+ name=function.name,
+ description=function.description or "",
+ parameters_schema=function.parameters or {},
+ fn=make_execute_on_server(url, function.name),
),
)
diff --git a/rigging/tracing.py b/rigging/tracing.py
deleted file mode 100644
index ff9e103..0000000
--- a/rigging/tracing.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import typing as t
-
-import logfire_api
-
-Span = logfire_api.LogfireSpan
-
-
-class Tracer(logfire_api.Logfire):
- def span(
- self,
- msg_template: str,
- /,
- *,
- _tags: t.Sequence[str] | None = None,
- _span_name: str | None = None,
- _level: t.Any | None = None,
- _links: t.Any = (),
- **attributes: t.Any,
- ) -> logfire_api.LogfireSpan:
- # Pass msg_template as the span name
- # to avoid weird fstring behaviors
- return super().span(
- msg_template,
- _tags=_tags,
- _span_name=msg_template,
- _level=_level,
- _links=_links,
- **attributes,
- )
-
-
-tracer = Tracer(otel_scope="rigging")
diff --git a/rigging/transform/xml_tools.py b/rigging/transform/xml_tools.py
index da4b239..1baf998 100644
--- a/rigging/transform/xml_tools.py
+++ b/rigging/transform/xml_tools.py
@@ -4,7 +4,7 @@
import uuid
import warnings
-import xmltodict # type: ignore[import-untyped]
+import xmltodict # type: ignore [import-untyped]
from pydantic.fields import FieldInfo
from pydantic_xml import attr
diff --git a/tests/test_chat.py b/tests/test_chat.py
index 66e0152..e625aa2 100644
--- a/tests/test_chat.py
+++ b/tests/test_chat.py
@@ -171,7 +171,7 @@ def test_message_double_content_part_separation() -> None:
def test_chat_generator_id() -> None:
generator = get_generator("gpt-3.5")
chat = Chat([], generator=generator)
- assert chat.generator_id == "litellm!gpt-3.5"
+ assert chat.generator_id == "gpt-3.5"
other = Chat([])
assert other.generator_id is None
diff --git a/tests/test_chat_pipeline.py b/tests/test_chat_pipeline.py
index 96f6d69..ae16f04 100644
--- a/tests/test_chat_pipeline.py
+++ b/tests/test_chat_pipeline.py
@@ -188,13 +188,36 @@ async def double_chats(chats: list[Chat]) -> list[Chat]:
return chats + new_chats
chats = (
- await generator.chat([{"role": "user", "content": "Hello"}]).map(double_chats).run_many(1)
+ await generator.chat([{"role": "user", "content": "Hello"}])
+ .map(double_chats)
+ .run_many(1, mode="merged")
)
assert len(chats) == 2
assert chats[0].last.content == "Response 1"
assert chats[1].last.content == "Modified: Response 1"
+ # in parallel mode, we expect only one chat per internal pipeline
+
+ chats = (
+ await generator.chat([{"role": "user", "content": "Hello"}])
+ .map(double_chats)
+ .run_many(1, mode="parallel")
+ )
+
+ assert len(chats) == 1
+ assert chats[0].last.content == "Modified: Response 1"
+
+ chats = (
+ await generator.chat([{"role": "user", "content": "Hello"}])
+ .map(double_chats)
+ .run_many(2, mode="parallel")
+ )
+
+ assert len(chats) == 2
+ assert chats[0].last.content == "Modified: Response 1"
+ assert chats[1].last.content == "Modified: Response 1"
+
@pytest.mark.asyncio
async def test_watch_callback() -> None:
diff --git a/tests/test_completion_pipeline.py b/tests/test_completion_pipeline.py
index 1530bd6..8387130 100644
--- a/tests/test_completion_pipeline.py
+++ b/tests/test_completion_pipeline.py
@@ -9,7 +9,7 @@
def test_completion_generator_id() -> None:
generator = get_generator("gpt-3.5")
completion = Completion("foo", "bar", generator)
- assert completion.generator_id == "litellm!gpt-3.5"
+ assert completion.generator_id == "gpt-3.5"
completion.generator = None
assert completion.generator_id is None
diff --git a/tests/test_generator_ids.py b/tests/test_generator_ids.py
index 134e407..a53560e 100644
--- a/tests/test_generator_ids.py
+++ b/tests/test_generator_ids.py
@@ -48,10 +48,10 @@ def test_get_generator_with_params(identifier: str, valid_params: GenerateParams
@pytest.mark.parametrize(
"identifier",
[
- ("litellm!test_model,max_tokens=1024,top_p=0.1"),
- ("litellm!custom,temperature=1.0,max_tokens=100,api_base=https://localhost:8000"),
- ("litellm!many/model/slashes,stop=a;b;c;"),
- ("litellm!with_cls_args,max_connections=10"),
+ ("test_model,max_tokens=1024,top_p=0.1"),
+ ("custom,temperature=1.0,max_tokens=100,api_base=https://localhost:8000"),
+ ("many/model/slashes,stop=a;b;c;"),
+ ("http!with_cls_args"),
],
)
def test_identifier_roundtrip(identifier: str) -> None:
diff --git a/tests/test_tool.py b/tests/test_tool.py
index f462c26..dfcdb50 100644
--- a/tests/test_tool.py
+++ b/tests/test_tool.py
@@ -24,8 +24,8 @@ def simple_function(name: str, age: int) -> str:
assert tool.name == "simple_function"
assert tool.description == "A simple function that returns a greeting."
- assert "_signature" in tool.__dict__
- assert "_type_adapter" in tool.__dict__
+ assert getattr(tool, "_signature", None) is not None
+ assert getattr(tool, "_type_adapter", None) is not None
# Check schema
assert "name" in tool.parameters_schema["properties"]