diff --git a/docs/api/chat.mdx b/docs/api/chat.mdx index 3d1469f..e0bbab8 100644 --- a/docs/api/chat.mdx +++ b/docs/api/chat.mdx @@ -10,17 +10,6 @@ Chats are used pre and post generation to hold messages. They are the primary way to interact with the generator. -CacheMode ---------- - -```python -CacheMode = Literal['latest'] -``` - -How to handle cache\_control entries on messages. - -* latest: Assign cache\_control to the latest 2 non-assistant messages in the pipeline before inference. - DEFAULT\_MAX\_DEPTH ------------------- @@ -1274,14 +1263,18 @@ def __init__( """How to handle failures in the pipeline unless overridden in calls.""" self.caching: CacheMode | None = None """How to handle cache_control entries on messages.""" + self.task_name: str = generator.to_identifier(short=True) + """The name of the pipeline task, used for logging and debugging.""" + self.scorers: list[dn.Scorer[Chat]] = [] + """List of dreadnode scorers to evaluate the generated chat upon completion.""" self.until_types: list[type[Model]] = [] self.tools: list[Tool[..., t.Any]] = [] self.tool_mode: ToolMode = "auto" self.inject_tool_prompt = True self.add_tool_stop_token = True - self.then_callbacks: list[tuple[ThenChatCallback, int]] = [] - self.map_callbacks: list[tuple[MapChatCallback, int]] = [] + self.then_callbacks: list[tuple[ThenChatCallback, int, bool]] = [] + self.map_callbacks: list[tuple[MapChatCallback, int, bool]] = [] self.watch_callbacks: list[WatchChatCallback] = watch_callbacks or [] self.transforms: list[Transform] = [] ``` @@ -1356,6 +1349,22 @@ params = params The parameters for generating messages. +### scorers + +```python +scorers: list[Scorer[Chat]] = [] +``` + +List of dreadnode scorers to evaluate the generated chat upon completion. + +### task\_name + +```python +task_name: str = to_identifier(short=True) +``` + +The name of the pipeline task, used for logging and debugging. 
+ ### add ```python @@ -1725,6 +1734,9 @@ def clone( new.errors_to_catch = self.errors_to_catch.copy() new.errors_to_exclude = self.errors_to_exclude.copy() new.caching = self.caching + new.task_name = self.task_name + new.scorers = self.scorers.copy() + new.transforms = self.transforms.copy() new.watch_callbacks = self.watch_callbacks.copy() @@ -1736,18 +1748,18 @@ def clone( return new new.then_callbacks = [ - (callback, max_depth) + (callback, max_depth, as_task) if not hasattr(callback, "__self__") or not isinstance(callback.__self__, ChatPipeline) - else (types.MethodType(callback.__func__, new), max_depth) # type: ignore [union-attr] - for callback, max_depth in self.then_callbacks.copy() + else (types.MethodType(callback.__func__, new), max_depth, as_task) # type: ignore [union-attr] + for callback, max_depth, as_task in self.then_callbacks.copy() ] new.map_callbacks = [ - (callback, max_depth) + (callback, max_depth, as_task) if not hasattr(callback, "__self__") or not isinstance(callback.__self__, ChatPipeline) - else (types.MethodType(callback.__func__, new), max_depth) # type: ignore [union-attr] - for callback, max_depth in self.map_callbacks.copy() + else (types.MethodType(callback.__func__, new), max_depth, as_task) # type: ignore [union-attr] + for callback, max_depth, as_task in self.map_callbacks.copy() ] new.transforms = [ callback @@ -1759,13 +1771,13 @@ def clone( if not isinstance(callbacks, bool): new.then_callbacks = [ - (callback, max_depth) - for callback, max_depth in self.then_callbacks + (callback, max_depth, as_task) + for callback, max_depth, as_task in self.then_callbacks if callback in callbacks ] new.map_callbacks = [ - (callback, max_depth) - for callback, max_depth in self.map_callbacks + (callback, max_depth, as_task) + for callback, max_depth, as_task in self.map_callbacks if callback in callbacks ] new.transforms = [callback for callback in self.transforms if callback in callbacks] @@ -1833,6 +1845,7 @@ map( *callbacks: 
MapChatCallback, max_depth: int = DEFAULT_MAX_DEPTH, allow_duplicates: bool = False, + as_task: bool = True, ) -> ChatPipeline ``` @@ -1861,6 +1874,11 @@ the final return value from the pipeline. `False` ) –Whether to allow (seemingly) duplicate callbacks to be added. +* **`as_task`** + (`bool`, default: + `True` + ) + –Whether to create a task for this callback. **Returns:** @@ -1884,6 +1902,7 @@ def map( *callbacks: MapChatCallback, max_depth: int = DEFAULT_MAX_DEPTH, allow_duplicates: bool = False, + as_task: bool = True, ) -> "ChatPipeline": """ Registers a callback to be executed after the generation process completes. @@ -1897,6 +1916,7 @@ def map( callbacks: The callback function to be executed. max_depth: The maximum depth to allow recursive pipeline calls during this callback. allow_duplicates: Whether to allow (seemingly) duplicate callbacks to be added. + as_task: Whether to create a task for this callback. Returns: The updated pipeline. @@ -1918,7 +1938,7 @@ def map( f"Callback '{get_qualified_name(callback)}' is already registered.", ) - self.map_callbacks.extend([(callback, max_depth) for callback in callbacks]) + self.map_callbacks.extend([(callback, max_depth, as_task) for callback in callbacks]) return self ``` @@ -1963,6 +1983,44 @@ def meta(self, **kwargs: t.Any) -> "ChatPipeline": ``` + + +### name + +```python +name(name: str) -> ChatPipeline +``` + +Sets the name of the pipeline. + +**Parameters:** + +* **`name`** + (`str`) + –The name to set for the pipeline. + +**Returns:** + +* `ChatPipeline` + –The updated pipeline. + + +```python +def name(self, name: str) -> "ChatPipeline": + """ + Sets the name of the pipeline. + + Args: + name: The name to set for the pipeline. + + Returns: + The updated pipeline. 
+ """ + self.task_name = name + return self +``` + + ### prompt @@ -2015,6 +2073,7 @@ def prompt(self, func: t.Callable[P, t.Coroutine[None, None, R]]) -> "Prompt[P, ```python run( *, + name: str | None = None, on_failed: FailMode | None = None, allow_failed: bool = False, ) -> Chat @@ -2024,6 +2083,11 @@ Execute the generation process for a single message. **Parameters:** +* **`name`** + (`str | None`, default: + `None` + ) + –The name of the task for logging purposes. * **`on_failed`** (`FailMode | None`, default: `None` @@ -2045,6 +2109,7 @@ Execute the generation process for a single message. async def run( self, *, + name: str | None = None, on_failed: FailMode | None = None, allow_failed: bool = False, ) -> Chat: @@ -2052,6 +2117,7 @@ async def run( Execute the generation process for a single message. Args: + name: The name of the task for logging purposes. on_failed: The behavior when a message fails to generate. allow_failed: Deprecated, use `on_failed="include"`. @@ -2068,16 +2134,44 @@ async def run( if on_failed is None: on_failed = "include" if allow_failed else self.on_failed - last: PipelineStep | None = None - async with self.step(on_failed=on_failed) as steps: - async for step in steps: - last = step + if on_failed == "skip": + raise ValueError( + "Cannot use 'skip' mode with single message generation (pass allow_failed=True or on_failed='include'/'raise')", + ) - if last is None or last.state != "final": - raise RuntimeError("The pipeline did not complete successfully") + messages = [self.chat.all] + params = self._fit_params(1, [self.params]) - if not last.chats: - raise RuntimeError("The pipeline process did not produce any chats") + last: PipelineStep | None = None + with dn.task_span( + name or f"pipeline - {self.task_name}", + label=name or f"pipeline_{self.task_name}", + attributes={"rigging.type": "chat_pipeline.run"}, + ) as task: + dn.log_inputs( + messages=messages[0], + params=params[0], + generator_id=self.generator.to_identifier(), + ) 
+ + try: + async with aclosing( + self._step(messages, params, on_failed), + ) as steps: + async for step in steps: + last = step + finally: + if last is not None and last.chats: + dn.log_output("chat", last.chats[-1]) + await self._score_chats(last.chats) + # TODO: Remove once Strikes UI is ported + task.set_attribute("chats", last.chats) + + if last is None or last.state != "final": + raise RuntimeError("The pipeline did not complete successfully") + + if not last.chats: + raise RuntimeError("The pipeline process did not produce any chats") return last.chats[-1] ``` @@ -2097,6 +2191,8 @@ run_batch( | str, params: Sequence[GenerateParams | None] | None = None, *, + name: str | None = None, + mode: Literal["merged", "parallel"] = "parallel", on_failed: FailMode | None = None, ) -> ChatList ``` @@ -2117,6 +2213,18 @@ Anything already in this chat pipeline will be prepended to the input messages. `None` ) –A sequence of parameters to be used for each set of messages. +* **`name`** + (`str | None`, default: + `None` + ) + –The name of the task for logging purposes. +* **`mode`** + (`Literal['merged', 'parallel']`, default: + `'parallel'` + ) + –The mode of execution, either "merged" or "parallel". + - In "merged" mode, a single pipeline manages all generation simultaneously + - In "parallel" mode, independent pipelines are created for each generation * **`on_failed`** (`FailMode | None`, default: `None` @@ -2140,6 +2248,8 @@ async def run_batch( | str, params: t.Sequence[GenerateParams | None] | None = None, *, + name: str | None = None, + mode: t.Literal["merged", "parallel"] = "parallel", on_failed: FailMode | None = None, ) -> ChatList: """ @@ -2151,21 +2261,76 @@ async def run_batch( Args: many: A sequence of sequences of messages to be generated. params: A sequence of parameters to be used for each set of messages. + name: The name of the task for logging purposes. + mode: The mode of execution, either "merged" or "parallel". 
+ - In "merged" mode, a single pipeline manages all generation simultaneously + - In "parallel" mode, independent pipelines are created for each generation on_failed: The behavior when a message fails to generate. Returns: A list of generatated Chats. """ + on_failed = on_failed or self.on_failed + count, messages, params = self._fit_batch_args(many, params) last: PipelineStep | None = None - async with self.step_batch(many, params=params, on_failed=on_failed) as steps: - async for step in steps: - last = step + with dn.task_span( + name or f"pipeline - {self.task_name} (batch x{count})", + label=name or f"pipeline_batch_{self.task_name}", + attributes={"rigging.type": "chat_pipeline.run_batch"}, + ) as task: + dn.log_inputs( + count=count, + messages=messages, + params=params, + generator_id=self.generator.to_identifier(), + ) + + if mode == "merged": + try: + async with aclosing( + self._step(messages, params, on_failed), + ) as steps: + async for step in steps: + last = step + finally: + if last is not None: + dn.log_output("chats", last.chats) + await self._score_chats(last.chats) + # TODO: Remove once Strikes UI is ported + task.set_attribute("chats", last.chats) + + if last is None or last.state != "final": + raise RuntimeError("The pipeline did not complete successfully") + + return last.chats + + if mode == "parallel": + tasks = [ + asyncio.create_task( + self.clone().add(_messages).with_(_params).run(on_failed="include") + ) + for _messages, _params in zip(messages, params, strict=True) + ] + chats_or_errors = await asyncio.gather(*tasks, return_exceptions=True) + + self._raise_if_failed(chats_or_errors, on_failed) - if last is None or last.state != "final": - raise ValueError("The generation process did not complete successfully") + chats = [ + chat + for chat in chats_or_errors + if isinstance(chat, Chat) and (on_failed != "skip" or not chat.failed) + ] + + dn.log_output("chats", chats) + # TODO: Remove once Strikes UI is ported + 
task.set_attribute("chats", chats) + + return ChatList(chats) - return last.chats + raise ValueError( + f"Invalid mode '{mode}', expected 'merged' or 'separate'", + ) ``` @@ -2178,6 +2343,8 @@ run_many( count: int, *, params: Sequence[GenerateParams | None] | None = None, + name: str | None = None, + mode: Literal["merged", "parallel"] = "parallel", on_failed: FailMode | None = None, ) -> ChatList ``` @@ -2194,6 +2361,18 @@ Executes the generation process in parallel over the same input. `None` ) –A sequence of parameters to be used for each execution. +* **`name`** + (`str | None`, default: + `None` + ) + –The name of the task for logging purposes. +* **`mode`** + (`Literal['merged', 'parallel']`, default: + `'parallel'` + ) + –The mode of execution, either "merged" or "parallel". + - In "merged" mode, a single pipeline manages all generation simultaneously + - In "parallel" mode, independent pipelines are created for each generation * **`on_failed`** (`FailMode | None`, default: `None` @@ -2203,7 +2382,7 @@ Executes the generation process in parallel over the same input. **Returns:** * `ChatList` - –A list of generatated Chats. + –A list of generated Chats. ```python @@ -2212,6 +2391,8 @@ async def run_many( count: int, *, params: t.Sequence[GenerateParams | None] | None = None, + name: str | None = None, + mode: t.Literal["merged", "parallel"] = "parallel", on_failed: FailMode | None = None, ) -> ChatList: """ @@ -2220,21 +2401,76 @@ async def run_many( Args: count: The number of times to execute the generation process. params: A sequence of parameters to be used for each execution. + name: The name of the task for logging purposes. + mode: The mode of execution, either "merged" or "parallel". + - In "merged" mode, a single pipeline manages all generation simultaneously + - In "parallel" mode, independent pipelines are created for each generation on_failed: The behavior when a message fails to generate. Returns: - A list of generatated Chats. 
+ A list of generated Chats. """ + if count < 1: + raise ValueError("Count must be greater than 0") + + on_failed = on_failed or self.on_failed + + messages = [self.chat.all] * count + params = self._fit_params(count, params) last: PipelineStep | None = None - async with self.step_many(count, params=params, on_failed=on_failed) as steps: - async for step in steps: - last = step + with dn.task_span( + name or f"pipeline - {self.task_name} (x{count})", + label=name or f"pipeline_many_{self.task_name}", + attributes={"rigging.type": "chat_pipeline.run_many"}, + ) as task: + dn.log_inputs( + count=count, + messages=messages[0], + params=params[0], + generator_id=self.generator.to_identifier(), + ) + + if mode == "merged": + try: + async with aclosing( + self._step(messages, params, on_failed), + ) as steps: + async for step in steps: + last = step + finally: + if last is not None: + dn.log_output("chats", last.chats) + await self._score_chats(last.chats) + # TODO: Remove once Strikes UI is ported + task.set_attribute("chats", last.chats) + + if last is None or last.state != "final": + raise RuntimeError("The pipeline did not complete successfully") + + return last.chats + + if mode == "parallel": + tasks = [asyncio.create_task(self.run(on_failed="include")) for _ in range(count)] + chats_or_errors = await asyncio.gather(*tasks, return_exceptions=True) + + self._raise_if_failed(chats_or_errors, on_failed) + + chats = [ + chat + for chat in chats_or_errors + if isinstance(chat, Chat) and (on_failed != "skip" or not chat.failed) + ] - if last is None or last.state != "final": - raise ValueError("The generation process did not complete successfully") + dn.log_output("chats", chats) + # TODO: Remove once Strikes UI is ported + task.set_attribute("chats", chats) - return last.chats + return ChatList(chats) + + raise ValueError( + f"Invalid mode '{mode}', expected 'merged' or 'parallel'", + ) ``` @@ -2314,11 +2550,100 @@ async def run_over( sub.generator = generator 
coros.append(sub.run(allow_failed=(on_failed != "raise"))) - with tracer.span(f"Chat over {len(coros)} generators", count=len(coros)): + short_generators = [g.to_identifier(short=True) for g in _generators] + task_name = "iterate - " + ", ".join(short_generators) + + with dn.task_span( + task_name, + label="iterate_over", + attributes={"rigging.type": "chat_pipeline.run_over"}, + ): + dn.log_input("generators", [g.to_identifier() for g in _generators]) return ChatList(await asyncio.gather(*coros)) ``` + + +### score + +```python +score( + *scorers: Scorer[Chat] | ScorerCallable[Chat], + filter: ChatFilterMode | ChatFilterFunction = "last", +) -> ChatPipeline +``` + +Adds one or more scorers to the pipeline to evaluate the generated chat upon completion. + +**Parameters:** + +* **`*scorers`** + (`Scorer[Chat] | ScorerCallable[Chat]`, default: + `()` + ) + –The scorer or scorers to be added. These can be either: + - A dreadnode.Scorer instance. + - A callable function that can be converted to a dreadnode.Scorer. +* **`filter`** + (`ChatFilterMode | ChatFilterFunction`, default: + `'last'` + ) + –The strategy for filtering which messages to include: + - "all": Use all messages in the chat. + - "last": Use only the last message. + - "first": Use only the first message. + - "user": Use only user messages. + - "assistant": Use only assistant messages. + - "last\_user": Use only the last user message. + - "last\_assistant": Use only the last assistant message. + - A callable that takes a list of `Message` objects and returns a filtered list. + +**Returns:** + +* `ChatPipeline` + –The updated pipeline. + + +```python +def score( + self, + *scorers: dn.Scorer[Chat] | ScorerCallable[Chat], + filter: "ChatFilterMode | ChatFilterFunction" = "last", +) -> "ChatPipeline": + """ + Adds one or more scorers to the pipeline to evaluate the generated chat upon completion. + + Args: + *scorers: The scorer or scorers to be added. These can be either: + - A dreadnode.Scorer instance. 
+ - A callable function that can be converted to a dreadnode.Scorer. + filter: The strategy for filtering which messages to include: + - "all": Use all messages in the chat. + - "last": Use only the last message. + - "first": Use only the first message. + - "user": Use only user messages. + - "assistant": Use only assistant messages. + - "last_user": Use only the last user message. + - "last_assistant": Use only the last assistant message. + - A callable that takes a list of `Message` objects and returns a filtered list. + + Returns: + The updated pipeline. + """ + self.scorers.extend( + [ + dn.scorers.wrap_chat( + scorer if isinstance(scorer, dn.Scorer) else dn.Scorer.from_callable(scorer), + filter=filter, + ) + for scorer in scorers + ] + ) + return self +``` + + ### step @@ -2374,15 +2699,10 @@ async def step( messages = [self.chat.all] params = self._fit_params(1, [self.params]) - with tracer.span( - f"Chat with {self.generator.to_identifier()}", - generator_id=self.generator.to_identifier(), - params=self.params.to_dict() if self.params is not None else {}, - ) as span: - async with aclosing( - self._step(span, messages, params, on_failed), - ) as generator: - yield generator + async with aclosing( + self._step(messages, params, on_failed), + ) as generator: + yield generator ``` @@ -2461,37 +2781,12 @@ async def step_batch( Pipeline steps. """ on_failed = on_failed or self.on_failed + _, messages, params = self._fit_batch_args(many, params) - # Get the maximum of either incoming messages or params - - count = max(len(many), len(params) if params is not None else 0) - - # If we have less messages than params, we need to either: - # 1. Error because we have >1 messages that we can't reasonably - # zip with our parameters of a different length - # 2. 
Duplicate a single message we have len(params) times as the - # user is just batching only over parameters - - messages = [[*self.chat.all, *Message.fit_as_list(m)] for m in many] - if len(messages) < count: - if len(messages) != 1: - raise ValueError( - f"Can't fit {len(messages)} messages to {count} params", - ) - messages = messages * count - - params = self._fit_params(count, params) - - with tracer.span( - f"Chat batch with {self.generator.to_identifier()} ({count})", - count=count, - generator_id=self.generator.to_identifier(), - params=self.params.to_dict() if self.params is not None else {}, - ) as span: - async with aclosing( - self._step(span, messages, params, on_failed), - ) as generator: - yield generator + async with aclosing( + self._step(messages, params, on_failed), + ) as generator: + yield generator ``` @@ -2557,16 +2852,10 @@ async def step_many( messages = [self.chat.all] * count params = self._fit_params(count, params) - with tracer.span( - f"Chat with {self.generator.to_identifier()} (x{count})", - count=count, - generator_id=self.generator.to_identifier(), - params=self.params.to_dict() if self.params is not None else {}, - ) as span: - async with aclosing( - self._step(span, messages, params, on_failed), - ) as generator: - yield generator + async with aclosing( + self._step(messages, params, on_failed), + ) as generator: + yield generator ``` @@ -2579,6 +2868,7 @@ then( *callbacks: ThenChatCallback, max_depth: int = DEFAULT_MAX_DEPTH, allow_duplicates: bool = False, + as_task: bool = True, ) -> ChatPipeline ``` @@ -2607,6 +2897,11 @@ from the pipeline. `False` ) –Whether to allow (seemingly) duplicate callbacks to be added. +* **`as_task`** + (`bool`, default: + `True` + ) + –Whether to create a task for this callback. 
**Returns:** @@ -2630,6 +2925,7 @@ def then( *callbacks: ThenChatCallback, max_depth: int = DEFAULT_MAX_DEPTH, allow_duplicates: bool = False, + as_task: bool = True, ) -> "ChatPipeline": """ Registers one or many callbacks to be executed after the generation process completes. @@ -2643,6 +2939,7 @@ def then( callbacks: The callback functions to be added. max_depth: The maximum depth to allow recursive pipeline calls during this callback. allow_duplicates: Whether to allow (seemingly) duplicate callbacks to be added. + as_task: Whether to create a task for this callback. Returns: The updated pipeline. @@ -2664,7 +2961,7 @@ def then( f"Callback '{get_qualified_name(callback)}' is already registered.", ) - self.then_callbacks.extend([(callback, max_depth) for callback in callbacks]) + self.then_callbacks.extend([(callback, max_depth, as_task) for callback in callbacks]) return self ``` @@ -2867,11 +3164,11 @@ def until_parsed_as( max_depth = max_rounds or max_depth self.then_callbacks = [ - (callback, max_depth) - for callback, max_depth in self.then_callbacks + (callback, max_depth, as_task) + for callback, max_depth, as_task in self.then_callbacks if callback != self._then_parse ] - self.then_callbacks.append((self._then_parse, max_depth)) + self.then_callbacks.append((self._then_parse, max_depth, False)) return self ``` @@ -3037,13 +3334,13 @@ def using( self.tools = [tool for tool in self.tools if tool.name not in new_names] + _tools self.then_callbacks = [ - (callback, max_depth) - for callback, max_depth in self.then_callbacks + (callback, max_depth, as_task) + for callback, max_depth, as_task in self.then_callbacks if callback != self._then_tools # Always remove to update max_depth ] self.then_callbacks.insert( 0, # make sure this is first - (self._then_tools, max_depth), + (self._then_tools, max_depth, False), ) if mode is not None: diff --git a/docs/api/completion.mdx b/docs/api/completion.mdx index e829262..2c394cb 100644 --- a/docs/api/completion.mdx +++ 
b/docs/api/completion.mdx @@ -897,12 +897,21 @@ async def run( on_failed = on_failed or self.on_failed states = self._initialize_states(1) - with tracer.span( - f"Completion with {self.generator.to_identifier()}", - generator_id=self.generator.to_identifier(), - params=self.params.to_dict() if self.params is not None else {}, - ) as span: - return (await self._run(span, states, on_failed))[0] + with dn.task_span( + f"pipeline - {self.generator.to_identifier(short=True)}", + label=f"pipeline_{self.generator.to_identifier(short=True)}", + attributes={"rigging.type": "completion_pipeline.run"}, + ) as task: + dn.log_inputs( + text=self.text, + params=self.params.to_dict() if self.params is not None else {}, + generator_id=self.generator.to_identifier(), + ) + completions = await self._run(task, states, on_failed) + completion = completions[0] + dn.log_output("completion", completion) + task.set_attribute("completions", completions) + return completion ``` @@ -978,13 +987,21 @@ async def run_batch( for state in states: next(state.processor) - with tracer.span( - f"Completion batch with {self.generator.to_identifier()} ({len(states)})", - count=len(states), - generator_id=self.generator.to_identifier(), - params=self.params.to_dict() if self.params is not None else {}, - ) as span: - return await self._run(span, states, on_failed, batch_mode=True) + with dn.task_span( + f"pipeline - {self.generator.to_identifier(short=True)} (batch x{len(states)})", + label=f"pipeline_batch_{self.generator.to_identifier(short=True)}", + attributes={"rigging.type": "completion_pipeline.run_batch"}, + ) as task: + dn.log_inputs( + count=len(states), + many=many, + params=params, + generator_id=self.generator.to_identifier(), + ) + completions = await self._run(task, states, on_failed, batch_mode=True) + dn.log_output("completions", completions) + task.set_attribute("completions", completions) + return completions ``` @@ -1047,13 +1064,21 @@ async def run_many( on_failed = on_failed or 
self.on_failed states = self._initialize_states(count, params) - with tracer.span( - f"Completion with {self.generator.to_identifier()} (x{count})", - count=count, - generator_id=self.generator.to_identifier(), - params=self.params.to_dict() if self.params is not None else {}, - ) as span: - return await self._run(span, states, on_failed) + with dn.task_span( + f"pipeline - {self.generator.to_identifier(short=True)} (x{count})", + label=f"pipeline_many_{self.generator.to_identifier(short=True)}", + attributes={"rigging.type": "completion_pipeline.run_many"}, + ) as task: + dn.log_inputs( + count=count, + text=self.text, + params=self.params.to_dict() if self.params is not None else {}, + generator_id=self.generator.to_identifier(), + ) + completions = await self._run(task, states, on_failed) + dn.log_output("completions", completions) + task.set_attribute("completions", completions) + return completions ``` @@ -1133,9 +1158,20 @@ async def run_over( sub.generator = generator coros.append(sub.run(allow_failed=(on_failed != "raise"))) - with tracer.span(f"Completion over {len(coros)} generators", count=len(coros)): + short_generators = [g.to_identifier(short=True) for g in _generators] + task_name = "iterate - " + ", ".join(short_generators) + + with dn.task_span( + task_name, + label="iterate_over", + attributes={"rigging.type": "completion_pipeline.run_over"}, + ) as task: + dn.log_input("generators", [g.to_identifier() for g in _generators]) completions = await asyncio.gather(*coros) - return await self._post_run(completions, on_failed) + final_completions = await self._post_run(completions, on_failed) + dn.log_output("completions", final_completions) + task.set_attribute("completions", final_completions) + return final_completions ``` diff --git a/docs/api/generator.mdx b/docs/api/generator.mdx index f9c4c17..e77a07b 100644 --- a/docs/api/generator.mdx +++ b/docs/api/generator.mdx @@ -793,7 +793,11 @@ async def supports_function_calling(self) -> bool | None: ### 
to\_identifier ```python -to_identifier(params: GenerateParams | None = None) -> str +to_identifier( + params: GenerateParams | None = None, + *, + short: bool = False, +) -> str ``` Converts the generator instance back into a rigging identifier string. @@ -815,7 +819,7 @@ This calls [rigging.generator.get\_identifier][] with the current instance. ```python -def to_identifier(self, params: GenerateParams | None = None) -> str: +def to_identifier(self, params: GenerateParams | None = None, *, short: bool = False) -> str: """ Converts the generator instance back into a rigging identifier string. @@ -827,7 +831,7 @@ def to_identifier(self, params: GenerateParams | None = None) -> str: Returns: The identifier string. """ - return get_identifier(self, params) + return get_identifier(self, params, short=short) ``` @@ -1075,6 +1079,163 @@ min_delay_between_requests: float = 0.0 Minimum time (ms) between each request. This is useful to set when you run into API limits at a provider. +TransformersGenerator +--------------------- + +Generator backed by the Transformers library for local model loading. + + +The use of Transformers requires the `transformers` package to be installed directly or by +installing rigging as `rigging[all]`. + + + +The `transformers` library is expansive with many different models, tokenizers, +options, constructors, etc. We do our best to implement a consistent interface, +but there may be limitations. Where needed, use +[`.from_obj()`][rigging.generator.transformers\_.TransformersGenerator.from\_obj]. + + + +This generator doesn't leverage any async capabilities. + + + +The model is loaded into memory lazily when the first generation is requested. +If you want to force this to happen earlier, you can use the +[`.load()`][rigging.generator.Generator.load] method. + +To unload, call [`.unload()`][rigging.generator.Generator.unload]. 
+ + +### device\_map + +```python +device_map: str = 'auto' +``` + +Device map passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto) + +### llm + +```python +llm: AutoModelForCausalLM +``` + +The underlying `AutoModelForCausalLM` instance. + +### load\_in\_4bit + +```python +load_in_4bit: bool = False +``` + +Load in 4 bit passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto) + +### load\_in\_8bit + +```python +load_in_8bit: bool = False +``` + +Load in 8 bit passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto) + +### pipeline + +```python +pipeline: TextGenerationPipeline +``` + +The underlying `TextGenerationPipeline` instance. + +### tokenizer + +```python +tokenizer: AutoTokenizer +``` + +The underlying `AutoTokenizer` instance. + +### torch\_dtype + +```python +torch_dtype: str = 'auto' +``` + +Torch dtype passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto) + +### trust\_remote\_code + +```python +trust_remote_code: bool = False +``` + +Trust remote code passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto) + +### from\_obj + +```python +from_obj( + model: Any, + tokenizer: AutoTokenizer, + *, + pipeline: TextGenerationPipeline | None = None, + params: GenerateParams | None = None, +) -> TransformersGenerator +``` + +Create a new instance of TransformersGenerator from an already loaded model and tokenizer. + +**Parameters:** + +* **`model`** + (`Any`) + –The loaded model for text generation. +* **`tokenizer`** + –The tokenizer associated with the model. +* **`pipeline`** + (`TextGenerationPipeline | None`, default: + `None` + ) + –The text generation pipeline. Defaults to None. 
+ +**Returns:** + +* `TransformersGenerator` + –The TransformersGenerator instance. + + +```python +@classmethod +def from_obj( + cls, + model: t.Any, + tokenizer: AutoTokenizer, + *, + pipeline: TextGenerationPipeline | None = None, + params: GenerateParams | None = None, +) -> "TransformersGenerator": + """ + Create a new instance of TransformersGenerator from an already loaded model and tokenizer. + + Args: + model: The loaded model for text generation. + tokenizer : The tokenizer associated with the model. + pipeline: The text generation pipeline. Defaults to None. + + Returns: + The TransformersGenerator instance. + """ + instance = cls(model=model, params=params or GenerateParams(), api_key=None) + instance._llm = model + instance._tokenizer = tokenizer + instance._pipeline = pipeline + return instance +``` + + + + Usage ----- @@ -1104,6 +1265,128 @@ total_tokens: int The total number of tokens processed. +VLLMGenerator +------------- + +Generator backed by the vLLM library for local model loading. + +Find more information about supported models and formats [in their docs.](https://docs.vllm.ai/en/latest/index.html) + + +The use of vLLM requires the `vllm` package to be installed directly or by +installing rigging as `rigging[all]`. + + + +This generator doesn't leverage any async capabilities. + + + +The model is loaded into memory lazily when the first generation is requested. +If you want to force this to happen earlier, you can use the +[`.load()`][rigging.generator.Generator.load] method. + +To unload, call [`.unload()`][rigging.generator.Generator.unload]. 
+ + +### dtype + +```python +dtype: str = 'auto' +``` + +Tensor dtype passed to [`vllm.LLM`](https://docs.vllm.ai/en/latest/offline_inference/llm.html) + +### enforce\_eager + +```python +enforce_eager: bool = False +``` + +Eager enforcement passed to [`vllm.LLM`](https://docs.vllm.ai/en/latest/offline_inference/llm.html) + +### gpu\_memory\_utilization + +```python +gpu_memory_utilization: float = 0.9 +``` + +Memory utilization passed to [`vllm.LLM`](https://docs.vllm.ai/en/latest/offline_inference/llm.html) + +### llm + +```python +llm: LLM +``` + +The underlying [`vLLM model`](https://docs.vllm.ai/en/latest/offline_inference/llm.html) instance. + +### quantization + +```python +quantization: str | None = None +``` + +Quantization passed to [`vllm.LLM`](https://docs.vllm.ai/en/latest/offline_inference/llm.html) + +### trust\_remote\_code + +```python +trust_remote_code: bool = False +``` + +Trust remote code passed to [`vllm.LLM`](https://docs.vllm.ai/en/latest/offline_inference/llm.html) + +### from\_obj + +```python +from_obj( + model: str, + llm: LLM, + *, + params: GenerateParams | None = None, +) -> VLLMGenerator +``` + +Create a generator from an existing vLLM instance. + +**Parameters:** + +* **`llm`** + (`LLM`) + –The vLLM instance to create the generator from. + +**Returns:** + +* `VLLMGenerator` + –The VLLMGenerator instance. + + +```python +@classmethod +def from_obj( + cls, + model: str, + llm: vllm.LLM, + *, + params: GenerateParams | None = None, +) -> "VLLMGenerator": + """Create a generator from an existing vLLM instance. + + Args: + llm: The vLLM instance to create the generator from. + + Returns: + The VLLMGenerator instance. 
+ """ + generator = cls(model=model, params=params or GenerateParams(), api_key=None) + generator._llm = llm + return generator +``` + + + + chat ---- @@ -1361,6 +1644,8 @@ get\_identifier get_identifier( generator: Generator, params: GenerateParams | None = None, + *, + short: bool = False, ) -> str ``` @@ -1388,7 +1673,9 @@ The `extra` parameter field is not currently supported in identifiers. ```python -def get_identifier(generator: Generator, params: GenerateParams | None = None) -> str: +def get_identifier( + generator: Generator, params: GenerateParams | None = None, *, short: bool = False +) -> str: """ Converts the generator instance back into a rigging identifier string. @@ -1408,7 +1695,10 @@ def get_identifier(generator: Generator, params: GenerateParams | None = None) - for name, klass in g_generators.items() if isinstance(klass, type) and isinstance(generator, klass) ) - identifier = f"{provider}!{generator.model}" + identifier = f"{provider}!{generator.model}" if provider != "litellm" else generator.model + + if short: + return identifier identifier_extra = generator.model_dump( exclude_unset=True, @@ -1613,7 +1903,7 @@ def from_obj( Returns: The VLLMGenerator instance. """ - generator = cls(model=model, params=params or GenerateParams()) + generator = cls(model=model, params=params or GenerateParams(), api_key=None) generator._llm = llm return generator ``` @@ -1776,7 +2066,7 @@ def from_obj( Returns: The TransformersGenerator instance. """ - instance = cls(model=model, params=params or GenerateParams()) + instance = cls(model=model, params=params or GenerateParams(), api_key=None) instance._llm = model instance._tokenizer = tokenizer instance._pipeline = pipeline diff --git a/docs/api/message.mdx b/docs/api/message.mdx index 2050fd4..87fe06b 100644 --- a/docs/api/message.mdx +++ b/docs/api/message.mdx @@ -2471,14 +2471,16 @@ def clone(self) -> "MessageSlice": Returns: A new MessageSlice instance with the same properties. 
""" - return MessageSlice( + cloned = MessageSlice( type=self.type, obj=self.obj, start=self.start, stop=self.stop, metadata=copy.deepcopy(self.metadata), - _message=self._message, # Keep the reference to the original message ) + # Leaving this detached to align with tests + # cloned._message = self._message + return cloned # noqa: RET504 ``` diff --git a/docs/api/prompt.mdx b/docs/api/prompt.mdx index d798f9f..2f0559d 100644 --- a/docs/api/prompt.mdx +++ b/docs/api/prompt.mdx @@ -269,10 +269,36 @@ def bind( raise NotImplementedError( "pipeline.on_failed='skip' cannot be used for prompt methods that return one object", ) + if pipeline.on_failed == "include" and not isinstance(self.output, ChatOutput): + raise NotImplementedError( + "pipeline.on_failed='include' cannot be used with prompts that process outputs", + ) async def run(*args: P.args, **kwargs: P.kwargs) -> R: - results = await self.bind_many(pipeline)(1, *args, **kwargs) - return results[0] + name = get_qualified_name(self.func) if self.func else "" + with dn.task_span( + f"prompt - {name}", + attributes={"prompt_name": name, "rigging.type": "prompt.run"}, + ): + dn.log_inputs(**self._bind_args(*args, **kwargs)) + content = self.render(*args, **kwargs) + _pipeline = ( + pipeline.fork(content) + .using(*self.tools, max_depth=self.max_tool_rounds) + .then(self._then_parse, max_depth=self.max_parsing_rounds, as_task=False) + .then(*self.then_callbacks) + .map(*self.map_callbacks) + .watch(*self.watch_callbacks) + .with_(self.params) + ) + + if self.system_prompt: + _pipeline.chat.inject_system_content(self.system_prompt) + + chat = await _pipeline.run() + output = self.process(chat) + dn.log_output("output", output) + return output run.__signature__ = self.__signature__ # type: ignore [attr-defined] run.__name__ = self.__name__ @@ -315,7 +341,7 @@ Example def say_hello(name: str) -> str: """Say hello to {{ name }}""" -await say_hello.bind("gpt-3.5-turbo")(5, "the world") +await 
say_hello.bind_many("gpt-4.1")(5, "the world") ``` @@ -340,7 +366,7 @@ def bind_many( def say_hello(name: str) -> str: \"""Say hello to {{ name }}\""" - await say_hello.bind("gpt-3.5-turbo")(5, "the world") + await say_hello.bind_many("gpt-4.1")(5, "the world") ~~~ """ pipeline = self._resolve_to_pipeline(other) @@ -351,17 +377,17 @@ def bind_many( async def run_many(count: int, /, *args: P.args, **kwargs: P.kwargs) -> list[R]: name = get_qualified_name(self.func) if self.func else "" - with tracer.span( - f"Prompt {name}()" if count == 1 else f"Prompt {name}() (x{count})", - count=count, - name=name, - arguments=self._bind_args(*args, **kwargs), + with dn.task_span( + f"prompt - {name} (x{count})", + label=f"prompt_{name}", + attributes={"prompt_name": name, "rigging.type": "prompt.run_many"}, ) as span: + dn.log_inputs(**self._bind_args(*args, **kwargs)) content = self.render(*args, **kwargs) _pipeline = ( pipeline.fork(content) .using(*self.tools, max_depth=self.max_tool_rounds) - .then(self._then_parse, max_depth=self.max_parsing_rounds) + .then(self._then_parse, max_depth=self.max_parsing_rounds, as_task=False) .then(*self.then_callbacks) .map(*self.map_callbacks) .watch(*self.watch_callbacks) @@ -372,35 +398,13 @@ def bind_many( _pipeline.chat.inject_system_content(self.system_prompt) chats = await _pipeline.run_many(count) - - # TODO: I can't remember why we don't just pass the watch_callbacks to the pipeline - # Maybe it has something to do with uniqueness and merging? 
- - def wrap_watch_callback(callback: "WatchChatCallback") -> "WatchChatCallback": - async def traced_watch_callback(chats: list[Chat]) -> None: - callback_name = get_qualified_name(callback) - with tracer.span( - f"Watch with {callback_name}()", - callback=callback_name, - chat_count=len(chats), - chat_ids=[str(c.uuid) for c in chats], - ): - await callback(chats) - - return traced_watch_callback - - coros = [ - wrap_watch_callback(watch)(chats) - for watch in self.watch_callbacks - if watch not in pipeline.watch_callbacks - ] - await asyncio.gather(*coros) - - results = [self.process(chat) for chat in chats] - span.set_attribute("results", results) - return results + outputs = [self.process(chat) for chat in chats] + span.log_output("outputs", outputs) + return outputs run_many.__rg_prompt__ = self # type: ignore [attr-defined] + run_many.__name__ = self.__name__ + run_many.__doc__ = self.__doc__ return run_many ``` @@ -445,7 +449,7 @@ Example def say_hello(name: str) -> str: """Say hello to {{ name }}""" -await say_hello.bind("gpt-3.5-turbo")(["gpt-4o", "gpt-4"], "the world") +await say_hello.bind_over()(["gpt-4o", "gpt-4.1", "o4-mini"], "the world") ``` @@ -473,7 +477,7 @@ def bind_over( def say_hello(name: str) -> str: \"""Say hello to {{ name }}\""" - await say_hello.bind("gpt-3.5-turbo")(["gpt-4o", "gpt-4"], "the world") + await say_hello.bind_over()(["gpt-4o", "gpt-4.1", "o4-mini"], "the world") ~~~ """ include_original = other is not None @@ -500,7 +504,7 @@ def bind_over( _pipeline = ( pipeline.fork(content) .using(*self.tools, max_depth=self.max_tool_rounds) - .then(self._then_parse, max_depth=self.max_parsing_rounds) + .then(self._then_parse, max_depth=self.max_parsing_rounds, as_task=False) .then(*self.then_callbacks) .map(*self.map_callbacks) .watch(*self.watch_callbacks) @@ -512,13 +516,6 @@ def bind_over( chats = await _pipeline.run_over(*generators, include_original=include_original) - coros = [ - watch(chats) - for watch in self.watch_callbacks - 
if watch not in pipeline.watch_callbacks - ] - await asyncio.gather(*coros) - return [self.process(chat) for chat in chats] run_over.__rg_prompt__ = self # type: ignore [attr-defined] diff --git a/docs/api/tokenize.mdx b/docs/api/tokenize.mdx index 3b4c849..527053d 100644 --- a/docs/api/tokenize.mdx +++ b/docs/api/tokenize.mdx @@ -366,6 +366,125 @@ async def tokenize_chat(self, chat: "Chat") -> TokenizedChat: ``` + + +TransformersTokenizer +--------------------- + +A tokenizer implementation using Hugging Face Transformers. + +This class provides tokenization capabilities for chat conversations +using transformers models and their associated tokenizers. + +### apply\_chat\_template\_kwargs + +```python +apply_chat_template_kwargs: dict[str, Any] = Field( + default_factory=dict +) +``` + +Additional keyword arguments for applying the chat template. + +### decode\_kwargs + +```python +decode_kwargs: dict[str, Any] = Field(default_factory=dict) +``` + +Additional keyword arguments for decoding tokens. + +### encode\_kwargs + +```python +encode_kwargs: dict[str, Any] = Field(default_factory=dict) +``` + +Additional keyword arguments for encoding text. + +### tokenizer + +```python +tokenizer: PreTrainedTokenizer +``` + +The underlying `PreTrainedTokenizer` instance. + +### encode + +```python +encode(text: str) -> list[int] +``` + +Encodes the given text into a list of tokens. + +**Parameters:** + +* **`text`** + (`str`) + –The text to encode. + +**Returns:** + +* `list[int]` + –A list of tokens representing the encoded text. + + +```python +def encode(self, text: str) -> list[int]: + """ + Encodes the given text into a list of tokens. + + Args: + text: The text to encode. + + Returns: + A list of tokens representing the encoded text. 
+ """ + return self.tokenizer.encode(text, **self.encode_kwargs) # type: ignore [no-any-return] +``` + + + + +### from\_obj + +```python +from_obj( + tokenizer: PreTrainedTokenizer, +) -> TransformersTokenizer +``` + +Create a new instance of TransformersTokenizer from an already loaded tokenizer. + +**Parameters:** + +* **`tokenizer`** + (`PreTrainedTokenizer`) + –The tokenizer associated with the model. + +**Returns:** + +* `TransformersTokenizer` + –The TransformersTokenizer instance. + + +```python +@classmethod +def from_obj(cls, tokenizer: "PreTrainedTokenizer") -> "TransformersTokenizer": + """ + Create a new instance of TransformersTokenizer from an already loaded tokenizer. + + Args: + tokenizer: The tokenizer associated with the model. + + Returns: + The TransformersTokenizer instance. + """ + return cls(model=str(tokenizer), _tokenizer=tokenizer) +``` + + get\_tokenizer diff --git a/docs/api/tools.mdx b/docs/api/tools.mdx index 259d132..4c78436 100644 --- a/docs/api/tools.mdx +++ b/docs/api/tools.mdx @@ -36,17 +36,6 @@ How tool calls are handled. Tool ---- -```python -Tool( - name: str, - description: str, - parameters_schema: dict[str, Any], - fn: Callable[P, R], - catch: bool | set[type[Exception]] = False, - truncate: int | None = None, -) -``` - Base class for representing a tool to a generator. ### catch @@ -82,7 +71,10 @@ A description of the tool. ### fn ```python -fn: Callable[P, R] +fn: Callable[P, R] = Field( + default_factory=lambda: lambda *args, **kwargs: None, + exclude=True, +) ``` The function to call. 
@@ -153,7 +145,12 @@ async def handle_tool_call( # noqa: PLR0912 from rigging.message import ContentText, ContentTypes, Message - with tracer.span(f"Tool {self.name}()", name=self.name) as span: + with dn.task_span( + f"tool - {self.name}", + attributes={"tool_name": self.name, "rigging.type": "tool"}, + ) as task: + dn.log_input("tool_call", tool_call) + if tool_call.name != self.name: warnings.warn( f"Tool call name mismatch: {tool_call.name} != {self.name}", @@ -163,7 +160,7 @@ async def handle_tool_call( # noqa: PLR0912 return Message.from_model(SystemErrorModel(content="Invalid tool call.")), True if hasattr(tool_call, "id") and isinstance(tool_call.id, str): - span.set_attribute("tool_call_id", tool_call.id) + task.set_attribute("tool_call_id", tool_call.id) result: t.Any stop = False @@ -174,8 +171,9 @@ async def handle_tool_call( # noqa: PLR0912 kwargs = json.loads(tool_call.function.arguments) if self._type_adapter is not None: kwargs = self._type_adapter.validate_python(kwargs) - span.set_attribute("arguments", kwargs) + dn.log_inputs(**kwargs) except (json.JSONDecodeError, ValidationError) as e: + task.set_exception(e) result = ErrorModel.from_exception(e) # Call the function @@ -190,17 +188,18 @@ async def handle_tool_call( # noqa: PLR0912 raise result # noqa: TRY301 except Stop as e: result = f"<{TOOL_STOP_TAG}>{e.message}" - span.set_attribute("stop", True) + task.set_attribute("stop", True) stop = True except Exception as e: if self.catch is True or ( not isinstance(self.catch, bool) and isinstance(e, tuple(self.catch)) ): + task.set_exception(e) result = ErrorModel.from_exception(e) else: raise - span.set_attribute("result", result) + dn.log_output("output", result) message = Message(role="tool", tool_call_id=tool_call.id) @@ -232,17 +231,6 @@ async def handle_tool_call( # noqa: PLR0912 ToolMethod ---------- -```python -ToolMethod( - name: str, - description: str, - parameters_schema: dict[str, Any], - fn: Callable[P, R], - catch: bool | 
set[type[Exception]] = False, - truncate: int | None = None, -) -``` - A Tool wrapping a class method. tool @@ -683,10 +671,10 @@ def robopages(url: str, *, name_filter: str | None = None) -> list[Tool[..., t.A tools.append( Tool( - function.name, - function.description or "", - function.parameters or {}, - make_execute_on_server(url, function.name), + name=function.name, + description=function.description or "", + parameters_schema=function.parameters or {}, + fn=make_execute_on_server(url, function.name), ), ) diff --git a/poetry.lock b/poetry.lock index 9674924..e53e55d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1156,6 +1156,18 @@ traitlets = ">=4" [package.extras] test = ["pytest"] +[[package]] +name = "coolname" +version = "2.2.0" +description = "Random name and slug generator" +optional = false +python-versions = "*" +groups = ["main"] +files = [ + {file = "coolname-2.2.0-py2.py3-none-any.whl", hash = "sha256:4d1563186cfaf71b394d5df4c744f8c41303b6846413645e31d31915cdeb13e8"}, + {file = "coolname-2.2.0.tar.gz", hash = "sha256:6c5d5731759104479e7ca195a9b64f7900ac5bead40183c09323c7d0be9e75c7"}, +] + [[package]] name = "coverage" version = "7.9.2" @@ -1432,6 +1444,32 @@ files = [ {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, ] +[[package]] +name = "dreadnode" +version = "1.12.0" +description = "Dreadnode SDK" +optional = false +python-versions = "<3.14,>=3.10" +groups = ["main"] +files = [ + {file = "dreadnode-1.12.0-py3-none-any.whl", hash = "sha256:0286ba18c47718891e43e39bc8f330f80045d80be2efce8e89c082b1f7101c5a"}, + {file = "dreadnode-1.12.0.tar.gz", hash = "sha256:73204c6ac0424931e505d6ca0598a6703dd7465a61f54d8bc62e0a52e0f98b67"}, +] + +[package.dependencies] +coolname = ">=2.2.0,<3.0.0" +fsspec = {version = ">=2023.1.0,<=2025.3.0", extras = ["s3"]} +httpx = ">=0.28.0,<0.29.0" +logfire = ">=3.5.3,<=3.20.0" +pandas = ">=2.2.3,<3.0.0" +pydantic = ">=2.9.2,<3.0.0" +python-ulid = 
">=3.0.0,<4.0.0" +rigging = ">=3.1.1,<4.0.0" + +[package.extras] +multimodal = ["moviepy (>=2.1.2,<3.0.0)", "pillow (>=11.2.1,<12.0.0)", "soundfile (>=0.13.1,<0.14.0)"] +training = ["transformers (>=4.41.0,<5.0.0)"] + [[package]] name = "elastic-transport" version = "8.17.1" @@ -1504,7 +1542,7 @@ version = "2.2.0" description = "Get the currently executing AST node of a frame, and other information" optional = false python-versions = ">=3.8" -groups = ["dev"] +groups = ["main", "dev"] files = [ {file = "executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa"}, {file = "executing-2.2.0.tar.gz", hash = "sha256:5d108c028108fe2551d1a7b2e8b713341e2cb4fc0aa7dcf966fa4327a5226755"}, @@ -1680,6 +1718,7 @@ files = [ [package.dependencies] aiohttp = {version = "<4.0.0a0 || >4.0.0a0,<4.0.0a1 || >4.0.0a1", optional = true, markers = "extra == \"http\""} +s3fs = {version = "*", optional = true, markers = "extra == \"s3\""} [package.extras] abfs = ["adlfs"] @@ -1744,6 +1783,24 @@ python-dateutil = ">=2.8.1" [package.extras] dev = ["flake8", "markdown", "twine", "wheel"] +[[package]] +name = "googleapis-common-protos" +version = "1.70.0" +description = "Common protobufs used in Google APIs" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "googleapis_common_protos-1.70.0-py3-none-any.whl", hash = "sha256:b8bfcca8c25a2bb253e0e0b0adaf8c00773e5e6af6fd92397576680b807e0fd8"}, + {file = "googleapis_common_protos-1.70.0.tar.gz", hash = "sha256:0e1b44e0ea153e6594f9f394fef15193a68aaaea2d843f83e2742717ca753257"}, +] + +[package.dependencies] +protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0" + +[package.extras] +grpc = ["grpcio (>=1.44.0,<2.0.0)"] + [[package]] name = "griffe" version = "1.7.3" @@ -1874,14 +1931,14 @@ test = ["Cython (>=0.29.24)"] [[package]] name = "httpx" -version = "0.27.2" +version = "0.28.1" 
description = "The next generation HTTP client." optional = false python-versions = ">=3.8" groups = ["main"] files = [ - {file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"}, - {file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"}, + {file = "httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad"}, + {file = "httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc"}, ] [package.dependencies] @@ -1889,7 +1946,6 @@ anyio = "*" certifi = "*" httpcore = "==1.*" idna = "*" -sniffio = "*" [package.extras] brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""] @@ -2588,17 +2644,51 @@ pydantic = ">=1.10.8" pyyaml = "*" [[package]] -name = "logfire-api" -version = "3.25.0" -description = "Shim for the Logfire SDK which does nothing unless Logfire is installed" +name = "logfire" +version = "3.20.0" +description = "The best Python observability tool! 
🪵🔥" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" groups = ["main"] files = [ - {file = "logfire_api-3.25.0-py3-none-any.whl", hash = "sha256:cc1c2482d6a738e15cd165c483577f8ef7a8a4c462eafa0f6129aa9077676a8d"}, - {file = "logfire_api-3.25.0.tar.gz", hash = "sha256:d6aeeeb246cc8d7aeb14a503523422292047db5e7be35d47c8979f70b0962bb0"}, + {file = "logfire-3.20.0-py3-none-any.whl", hash = "sha256:561ea5f197f4c3a4e521e893f35535b955fc22592fd6cbd5901434c5ad16226d"}, + {file = "logfire-3.20.0.tar.gz", hash = "sha256:592f242edb6ef7e33cc245de6f457ac92c5d012cf48cff0725830bdc5ba602bf"}, ] +[package.dependencies] +executing = ">=2.0.1" +opentelemetry-exporter-otlp-proto-http = ">=1.21.0,<1.35.0" +opentelemetry-instrumentation = ">=0.41b0" +opentelemetry-sdk = ">=1.21.0,<1.35.0" +protobuf = ">=4.23.4" +rich = ">=13.4.2" +tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""} +typing-extensions = ">=4.1.0" + +[package.extras] +aiohttp = ["opentelemetry-instrumentation-aiohttp-client (>=0.42b0)"] +aiohttp-client = ["opentelemetry-instrumentation-aiohttp-client (>=0.42b0)"] +aiohttp-server = ["opentelemetry-instrumentation-aiohttp-server (>=0.55b0)"] +asgi = ["opentelemetry-instrumentation-asgi (>=0.42b0)"] +asyncpg = ["opentelemetry-instrumentation-asyncpg (>=0.42b0)"] +aws-lambda = ["opentelemetry-instrumentation-aws-lambda (>=0.42b0)"] +celery = ["opentelemetry-instrumentation-celery (>=0.42b0)"] +django = ["opentelemetry-instrumentation-asgi (>=0.42b0)", "opentelemetry-instrumentation-django (>=0.42b0)"] +fastapi = ["opentelemetry-instrumentation-fastapi (>=0.42b0)"] +flask = ["opentelemetry-instrumentation-flask (>=0.42b0)"] +httpx = ["opentelemetry-instrumentation-httpx (>=0.42b0)"] +mysql = ["opentelemetry-instrumentation-mysql (>=0.42b0)"] +psycopg = ["opentelemetry-instrumentation-psycopg (>=0.42b0)", "packaging"] +psycopg2 = ["opentelemetry-instrumentation-psycopg2 (>=0.42b0)", "packaging"] +pymongo = 
["opentelemetry-instrumentation-pymongo (>=0.42b0)"] +redis = ["opentelemetry-instrumentation-redis (>=0.42b0)"] +requests = ["opentelemetry-instrumentation-requests (>=0.42b0)"] +sqlalchemy = ["opentelemetry-instrumentation-sqlalchemy (>=0.42b0)"] +sqlite3 = ["opentelemetry-instrumentation-sqlite3 (>=0.42b0)"] +starlette = ["opentelemetry-instrumentation-starlette (>=0.42b0)"] +system-metrics = ["opentelemetry-instrumentation-system-metrics (>=0.42b0)"] +wsgi = ["opentelemetry-instrumentation-wsgi (>=0.42b0)"] + [[package]] name = "loguru" version = "0.7.3" @@ -2634,6 +2724,31 @@ files = [ docs = ["mdx_gh_links (>=0.2)", "mkdocs (>=1.6)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"] testing = ["coverage", "pyyaml"] +[[package]] +name = "markdown-it-py" +version = "3.0.0" +description = "Python port of markdown-it. Markdown parsing, done right!" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"}, + {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"}, +] + +[package.dependencies] +mdurl = ">=0.1,<1.0" + +[package.extras] +benchmarking = ["psutil", "pytest", "pytest-benchmark"] +code-style = ["pre-commit (>=3.0,<4.0)"] +compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"] +linkify = ["linkify-it-py (>=1,<3)"] +plugins = ["mdit-py-plugins"] +profiling = ["gprof2dot"] +rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] +testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] + [[package]] name = "markdownify" version = "1.1.0" @@ -2766,6 +2881,18 @@ cli = ["python-dotenv (>=1.0.0)", "typer 
(>=0.16.0)"] rich = ["rich (>=13.9.4)"] ws = ["websockets (>=15.0.1)"] +[[package]] +name = "mdurl" +version = "0.1.2" +description = "Markdown URL utilities" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, + {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, +] + [[package]] name = "mergedeep" version = "1.3.4" @@ -3584,6 +3711,124 @@ datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] realtime = ["websockets (>=13,<16)"] voice-helpers = ["numpy (>=2.0.2)", "sounddevice (>=0.5.1)"] +[[package]] +name = "opentelemetry-api" +version = "1.34.1" +description = "OpenTelemetry Python API" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "opentelemetry_api-1.34.1-py3-none-any.whl", hash = "sha256:b7df4cb0830d5a6c29ad0c0691dbae874d8daefa934b8b1d642de48323d32a8c"}, + {file = "opentelemetry_api-1.34.1.tar.gz", hash = "sha256:64f0bd06d42824843731d05beea88d4d4b6ae59f9fe347ff7dfa2cc14233bbb3"}, +] + +[package.dependencies] +importlib-metadata = ">=6.0,<8.8.0" +typing-extensions = ">=4.5.0" + +[[package]] +name = "opentelemetry-exporter-otlp-proto-common" +version = "1.34.1" +description = "OpenTelemetry Protobuf encoding" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "opentelemetry_exporter_otlp_proto_common-1.34.1-py3-none-any.whl", hash = "sha256:8e2019284bf24d3deebbb6c59c71e6eef3307cd88eff8c633e061abba33f7e87"}, + {file = "opentelemetry_exporter_otlp_proto_common-1.34.1.tar.gz", hash = "sha256:b59a20a927facd5eac06edaf87a07e49f9e4a13db487b7d8a52b37cb87710f8b"}, +] + +[package.dependencies] +opentelemetry-proto = "1.34.1" + +[[package]] +name = "opentelemetry-exporter-otlp-proto-http" +version = "1.34.1" +description = "OpenTelemetry Collector Protobuf over HTTP 
Exporter" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "opentelemetry_exporter_otlp_proto_http-1.34.1-py3-none-any.whl", hash = "sha256:5251f00ca85872ce50d871f6d3cc89fe203b94c3c14c964bbdc3883366c705d8"}, + {file = "opentelemetry_exporter_otlp_proto_http-1.34.1.tar.gz", hash = "sha256:aaac36fdce46a8191e604dcf632e1f9380c7d5b356b27b3e0edb5610d9be28ad"}, +] + +[package.dependencies] +googleapis-common-protos = ">=1.52,<2.0" +opentelemetry-api = ">=1.15,<2.0" +opentelemetry-exporter-otlp-proto-common = "1.34.1" +opentelemetry-proto = "1.34.1" +opentelemetry-sdk = ">=1.34.1,<1.35.0" +requests = ">=2.7,<3.0" +typing-extensions = ">=4.5.0" + +[[package]] +name = "opentelemetry-instrumentation" +version = "0.55b1" +description = "Instrumentation Tools & Auto Instrumentation for OpenTelemetry Python" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "opentelemetry_instrumentation-0.55b1-py3-none-any.whl", hash = "sha256:cbb1496b42bc394e01bc63701b10e69094e8564e281de063e4328d122cc7a97e"}, + {file = "opentelemetry_instrumentation-0.55b1.tar.gz", hash = "sha256:2dc50aa207b9bfa16f70a1a0571e011e737a9917408934675b89ef4d5718c87b"}, +] + +[package.dependencies] +opentelemetry-api = ">=1.4,<2.0" +opentelemetry-semantic-conventions = "0.55b1" +packaging = ">=18.0" +wrapt = ">=1.0.0,<2.0.0" + +[[package]] +name = "opentelemetry-proto" +version = "1.34.1" +description = "OpenTelemetry Python Proto" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "opentelemetry_proto-1.34.1-py3-none-any.whl", hash = "sha256:eb4bb5ac27f2562df2d6857fc557b3a481b5e298bc04f94cc68041f00cebcbd2"}, + {file = "opentelemetry_proto-1.34.1.tar.gz", hash = "sha256:16286214e405c211fc774187f3e4bbb1351290b8dfb88e8948af209ce85b719e"}, +] + +[package.dependencies] +protobuf = ">=5.0,<6.0" + +[[package]] +name = "opentelemetry-sdk" +version = "1.34.1" +description = "OpenTelemetry Python SDK" +optional = false 
+python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "opentelemetry_sdk-1.34.1-py3-none-any.whl", hash = "sha256:308effad4059562f1d92163c61c8141df649da24ce361827812c40abb2a1e96e"}, + {file = "opentelemetry_sdk-1.34.1.tar.gz", hash = "sha256:8091db0d763fcd6098d4781bbc80ff0971f94e260739aa6afe6fd379cdf3aa4d"}, +] + +[package.dependencies] +opentelemetry-api = "1.34.1" +opentelemetry-semantic-conventions = "0.55b1" +typing-extensions = ">=4.5.0" + +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.55b1" +description = "OpenTelemetry Semantic Conventions" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "opentelemetry_semantic_conventions-0.55b1-py3-none-any.whl", hash = "sha256:5da81dfdf7d52e3d37f8fe88d5e771e191de924cfff5f550ab0b8f7b2409baed"}, + {file = "opentelemetry_semantic_conventions-0.55b1.tar.gz", hash = "sha256:ef95b1f009159c28d7a7849f5cbc71c4c34c845bb514d66adfdf1b3fff3598b3"}, +] + +[package.dependencies] +opentelemetry-api = "1.34.1" +typing-extensions = ">=4.5.0" + [[package]] name = "outlines" version = "0.0.46" @@ -4107,22 +4352,23 @@ files = [ [[package]] name = "protobuf" -version = "6.31.0" +version = "5.29.5" description = "" -optional = true -python-versions = ">=3.9" +optional = false +python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"all\"" files = [ - {file = "protobuf-6.31.0-cp310-abi3-win32.whl", hash = "sha256:10bd62802dfa0588649740a59354090eaf54b8322f772fbdcca19bc78d27f0d6"}, - {file = "protobuf-6.31.0-cp310-abi3-win_amd64.whl", hash = "sha256:3e987c99fd634be8347246a02123250f394ba20573c953de133dc8b2c107dd71"}, - {file = "protobuf-6.31.0-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:2c812f0f96ceb6b514448cefeb1df54ec06dde456783f5099c0e2f8a0f2caa89"}, - {file = "protobuf-6.31.0-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:67ce50195e4e584275623b8e6bc6d3d3dfd93924bf6116b86b3b8975ab9e4571"}, - {file = 
"protobuf-6.31.0-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:5353e38844168a327acd2b2aa440044411cd8d1b6774d5701008bd1dba067c79"}, - {file = "protobuf-6.31.0-cp39-cp39-win32.whl", hash = "sha256:96d8da25c83b11db5fe9e0376351ce25e7205e13224d939e097b6f82a72af824"}, - {file = "protobuf-6.31.0-cp39-cp39-win_amd64.whl", hash = "sha256:00a873c06efdfb854145d9ded730b09cf57d206075c38132674093370e2edabb"}, - {file = "protobuf-6.31.0-py3-none-any.whl", hash = "sha256:6ac2e82556e822c17a8d23aa1190bbc1d06efb9c261981da95c71c9da09e9e23"}, - {file = "protobuf-6.31.0.tar.gz", hash = "sha256:314fab1a6a316469dc2dd46f993cbbe95c861ea6807da910becfe7475bc26ffe"}, + {file = "protobuf-5.29.5-cp310-abi3-win32.whl", hash = "sha256:3f1c6468a2cfd102ff4703976138844f78ebd1fb45f49011afc5139e9e283079"}, + {file = "protobuf-5.29.5-cp310-abi3-win_amd64.whl", hash = "sha256:3f76e3a3675b4a4d867b52e4a5f5b78a2ef9565549d4037e06cf7b0942b1d3fc"}, + {file = "protobuf-5.29.5-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e38c5add5a311f2a6eb0340716ef9b039c1dfa428b28f25a7838ac329204a671"}, + {file = "protobuf-5.29.5-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:fa18533a299d7ab6c55a238bf8629311439995f2e7eca5caaff08663606e9015"}, + {file = "protobuf-5.29.5-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:63848923da3325e1bf7e9003d680ce6e14b07e55d0473253a690c3a8b8fd6e61"}, + {file = "protobuf-5.29.5-cp38-cp38-win32.whl", hash = "sha256:ef91363ad4faba7b25d844ef1ada59ff1604184c0bcd8b39b8a6bef15e1af238"}, + {file = "protobuf-5.29.5-cp38-cp38-win_amd64.whl", hash = "sha256:7318608d56b6402d2ea7704ff1e1e4597bee46d760e7e4dd42a3d45e24b87f2e"}, + {file = "protobuf-5.29.5-cp39-cp39-win32.whl", hash = "sha256:6f642dc9a61782fa72b90878af134c5afe1917c89a568cd3476d758d3c3a0736"}, + {file = "protobuf-5.29.5-cp39-cp39-win_amd64.whl", hash = "sha256:470f3af547ef17847a28e1f47200a1cbf0ba3ff57b7de50d22776607cd2ea353"}, + {file = "protobuf-5.29.5-py3-none-any.whl", hash = 
"sha256:6cf42630262c59b2d8de33954443d94b746c952b01434fc58a417fdbd2e84bd5"}, + {file = "protobuf-5.29.5.tar.gz", hash = "sha256:bc1463bafd4b0929216c35f437a8e28731a2b7fe3d98bb77a600efced5a15c84"}, ] [[package]] @@ -4483,7 +4729,7 @@ version = "2.19.1" description = "Pygments is a syntax highlighting package written in Python." optional = false python-versions = ">=3.8" -groups = ["dev"] +groups = ["main", "dev"] files = [ {file = "pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c"}, {file = "pygments-2.19.1.tar.gz", hash = "sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f"}, @@ -4597,6 +4843,21 @@ files = [ {file = "python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13"}, ] +[[package]] +name = "python-ulid" +version = "3.0.0" +description = "Universally unique lexicographically sortable identifier" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "python_ulid-3.0.0-py3-none-any.whl", hash = "sha256:e4c4942ff50dbd79167ad01ac725ec58f924b4018025ce22c858bfcff99a5e31"}, + {file = "python_ulid-3.0.0.tar.gz", hash = "sha256:e50296a47dc8209d28629a22fc81ca26c00982c78934bd7766377ba37ea49a9f"}, +] + +[package.extras] +pydantic = ["pydantic (>=2.0)"] + [[package]] name = "pytz" version = "2025.2" @@ -5027,6 +5288,26 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "rich" +version = "14.0.0" +description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +optional = false +python-versions = ">=3.8.0" +groups = ["main"] +files = [ + {file = "rich-14.0.0-py3-none-any.whl", hash = "sha256:1c9491e1951aac09caffd42f448ee3d04e58923ffe14993f6e83068dc395d7e0"}, + {file = "rich-14.0.0.tar.gz", hash = "sha256:82f1bc23a6a21ebca4ae0c45af9bdbc492ed20231dcb63f297d6d1021a9d5725"}, +] + 
+[package.dependencies] +markdown-it-py = ">=2.2.0" +pygments = ">=2.13.0,<3.0.0" +typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.11\""} + +[package.extras] +jupyter = ["ipywidgets (>=7.5.1,<9)"] + [[package]] name = "rpds-py" version = "0.25.1" @@ -5180,7 +5461,7 @@ description = "C version of reader, parser and emitter for ruamel.yaml derived f optional = false python-versions = ">=3.9" groups = ["main"] -markers = "platform_python_implementation == \"CPython\" and python_version < \"3.14\"" +markers = "platform_python_implementation == \"CPython\"" files = [ {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:11f891336688faf5156a36293a9c362bdc7c88f03a8a027c2c1d8e0bcde998e5"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:a606ef75a60ecf3d924613892cc603b154178ee25abb3055db5062da811fd969"}, @@ -5188,7 +5469,6 @@ files = [ {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f66efbc1caa63c088dead1c4170d148eabc9b80d95fb75b6c92ac0aad2437d76"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:22353049ba4181685023b25b5b51a574bce33e7f51c759371a7422dcae5402a6"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:932205970b9f9991b34f55136be327501903f7c66830e9760a8ffb15b07f05cd"}, - {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a52d48f4e7bf9005e8f0a89209bf9a73f7190ddf0489eee5eb51377385f59f2a"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-win32.whl", hash = "sha256:3eac5a91891ceb88138c113f9db04f3cebdae277f5d44eaa3651a4f573e6a5da"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-win_amd64.whl", hash = "sha256:ab007f2f5a87bd08ab1499bdf96f3d5c6ad4dcfa364884cb4549aa0154b13a28"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-macosx_13_0_arm64.whl", hash = 
"sha256:4a6679521a58256a90b0d89e03992c15144c5f3858f40d7c18886023d7943db6"}, @@ -5197,7 +5477,6 @@ files = [ {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:811ea1594b8a0fb466172c384267a4e5e367298af6b228931f273b111f17ef52"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cf12567a7b565cbf65d438dec6cfbe2917d3c1bdddfce84a9930b7d35ea59642"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7dd5adc8b930b12c8fc5b99e2d535a09889941aa0d0bd06f4749e9a9397c71d2"}, - {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1492a6051dab8d912fc2adeef0e8c72216b24d57bd896ea607cb90bb0c4981d3"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-win32.whl", hash = "sha256:bd0a08f0bab19093c54e18a14a10b4322e1eacc5217056f3c063bd2f59853ce4"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-win_amd64.whl", hash = "sha256:a274fb2cb086c7a3dea4322ec27f4cb5cc4b6298adb583ab0e211a4682f241eb"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:20b0f8dc160ba83b6dcc0e256846e1a02d044e13f7ea74a3d1d56ede4e48c632"}, @@ -5206,7 +5485,6 @@ files = [ {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:749c16fcc4a2b09f28843cda5a193e0283e47454b63ec4b81eaa2242f50e4ccd"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bf165fef1f223beae7333275156ab2022cffe255dcc51c27f066b4370da81e31"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:32621c177bbf782ca5a18ba4d7af0f1082a3f6e517ac2a18b3974d4edf349680"}, - {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b82a7c94a498853aa0b272fd5bc67f29008da798d4f93a2f9f289feb8426a58d"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-win32.whl", hash = 
"sha256:e8c4ebfcfd57177b572e2040777b8abc537cdef58a2120e830124946aa9b42c5"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-win_amd64.whl", hash = "sha256:0467c5965282c62203273b838ae77c0d29d7638c8a4e3a1c8bdd3602c10904e4"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:4c8c5d82f50bb53986a5e02d1b3092b03622c02c2eb78e29bec33fd9593bae1a"}, @@ -5215,7 +5493,6 @@ files = [ {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:96777d473c05ee3e5e3c3e999f5d23c6f4ec5b0c38c098b3a5229085f74236c6"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:3bc2a80e6420ca8b7d3590791e2dfc709c88ab9152c00eeb511c9875ce5778bf"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:e188d2699864c11c36cdfdada94d781fd5d6b0071cd9c427bceb08ad3d7c70e1"}, - {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:4f6f3eac23941b32afccc23081e1f50612bdbe4e982012ef4f5797986828cd01"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-win32.whl", hash = "sha256:6442cb36270b3afb1b4951f060eccca1ce49f3d087ca1ca4563a6eb479cb3de6"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-win_amd64.whl", hash = "sha256:e5b8daf27af0b90da7bb903a876477a9e6d7270be6146906b276605997c7e9a3"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:fc4b630cd3fa2cf7fce38afa91d7cfe844a9f75d7f0f36393fa98815e911d987"}, @@ -5224,7 +5501,6 @@ files = [ {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2f1c3765db32be59d18ab3953f43ab62a761327aafc1594a2a1fbe038b8b8a7"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d85252669dc32f98ebcd5d36768f5d4faeaeaa2d655ac0473be490ecdae3c285"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_x86_64.whl", hash = 
"sha256:e143ada795c341b56de9418c58d028989093ee611aa27ffb9b7f609c00d813ed"}, - {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2c59aa6170b990d8d2719323e628aaf36f3bfbc1c26279c0eeeb24d05d2d11c7"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-win32.whl", hash = "sha256:beffaed67936fbbeffd10966a4eb53c402fafd3d6833770516bf7314bc6ffa12"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-win_amd64.whl", hash = "sha256:040ae85536960525ea62868b642bdb0c2cc6021c9f9d507810c0c604e66f5a7b"}, {file = "ruamel.yaml.clib-0.2.12.tar.gz", hash = "sha256:6c8fbb13ec503f99a91901ab46e0b07ae7941cd527393187039aec586fdfd36f"}, @@ -5258,6 +5534,22 @@ files = [ {file = "ruff-0.10.0.tar.gz", hash = "sha256:fa1554e18deaf8aa097dbcfeafaf38b17a2a1e98fdc18f50e62e8a836abee392"}, ] +[[package]] +name = "s3fs" +version = "0.4.2" +description = "Convenient Filesystem interface over S3" +optional = false +python-versions = ">= 3.5" +groups = ["main"] +files = [ + {file = "s3fs-0.4.2-py3-none-any.whl", hash = "sha256:91c1dfb45e5217bd441a7a560946fe865ced6225ff7eb0fb459fe6e601a95ed3"}, + {file = "s3fs-0.4.2.tar.gz", hash = "sha256:2ca5de8dc18ad7ad350c0bd01aef0406aa5d0fff78a561f0f710f9d9858abdd0"}, +] + +[package.dependencies] +botocore = ">=1.12.91" +fsspec = ">=0.6.0" + [[package]] name = "s3transfer" version = "0.13.0" @@ -5821,7 +6113,7 @@ version = "2.2.1" description = "A lil' TOML parser" optional = false python-versions = ">=3.8" -groups = ["dev"] +groups = ["main", "dev"] markers = "python_version < \"3.11\"" files = [ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, @@ -6652,6 +6944,95 @@ files = [ [package.extras] dev = ["black (>=19.3b0) ; python_version >= \"3.6\"", "pytest (>=4.6.2)"] +[[package]] +name = "wrapt" +version = "1.17.2" +description = "Module for decorators, wrappers and monkey patching." 
+optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "wrapt-1.17.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3d57c572081fed831ad2d26fd430d565b76aa277ed1d30ff4d40670b1c0dd984"}, + {file = "wrapt-1.17.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5e251054542ae57ac7f3fba5d10bfff615b6c2fb09abeb37d2f1463f841ae22"}, + {file = "wrapt-1.17.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:80dd7db6a7cb57ffbc279c4394246414ec99537ae81ffd702443335a61dbf3a7"}, + {file = "wrapt-1.17.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a6e821770cf99cc586d33833b2ff32faebdbe886bd6322395606cf55153246c"}, + {file = "wrapt-1.17.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b60fb58b90c6d63779cb0c0c54eeb38941bae3ecf7a73c764c52c88c2dcb9d72"}, + {file = "wrapt-1.17.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b870b5df5b71d8c3359d21be8f0d6c485fa0ebdb6477dda51a1ea54a9b558061"}, + {file = "wrapt-1.17.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4011d137b9955791f9084749cba9a367c68d50ab8d11d64c50ba1688c9b457f2"}, + {file = "wrapt-1.17.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:1473400e5b2733e58b396a04eb7f35f541e1fb976d0c0724d0223dd607e0f74c"}, + {file = "wrapt-1.17.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3cedbfa9c940fdad3e6e941db7138e26ce8aad38ab5fe9dcfadfed9db7a54e62"}, + {file = "wrapt-1.17.2-cp310-cp310-win32.whl", hash = "sha256:582530701bff1dec6779efa00c516496968edd851fba224fbd86e46cc6b73563"}, + {file = "wrapt-1.17.2-cp310-cp310-win_amd64.whl", hash = "sha256:58705da316756681ad3c9c73fd15499aa4d8c69f9fd38dc8a35e06c12468582f"}, + {file = "wrapt-1.17.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ff04ef6eec3eee8a5efef2401495967a916feaa353643defcc03fc74fe213b58"}, + {file = "wrapt-1.17.2-cp311-cp311-macosx_10_9_x86_64.whl", 
hash = "sha256:4db983e7bca53819efdbd64590ee96c9213894272c776966ca6306b73e4affda"}, + {file = "wrapt-1.17.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9abc77a4ce4c6f2a3168ff34b1da9b0f311a8f1cfd694ec96b0603dff1c79438"}, + {file = "wrapt-1.17.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b929ac182f5ace000d459c59c2c9c33047e20e935f8e39371fa6e3b85d56f4a"}, + {file = "wrapt-1.17.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f09b286faeff3c750a879d336fb6d8713206fc97af3adc14def0cdd349df6000"}, + {file = "wrapt-1.17.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a7ed2d9d039bd41e889f6fb9364554052ca21ce823580f6a07c4ec245c1f5d6"}, + {file = "wrapt-1.17.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:129a150f5c445165ff941fc02ee27df65940fcb8a22a61828b1853c98763a64b"}, + {file = "wrapt-1.17.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1fb5699e4464afe5c7e65fa51d4f99e0b2eadcc176e4aa33600a3df7801d6662"}, + {file = "wrapt-1.17.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9a2bce789a5ea90e51a02dfcc39e31b7f1e662bc3317979aa7e5538e3a034f72"}, + {file = "wrapt-1.17.2-cp311-cp311-win32.whl", hash = "sha256:4afd5814270fdf6380616b321fd31435a462019d834f83c8611a0ce7484c7317"}, + {file = "wrapt-1.17.2-cp311-cp311-win_amd64.whl", hash = "sha256:acc130bc0375999da18e3d19e5a86403667ac0c4042a094fefb7eec8ebac7cf3"}, + {file = "wrapt-1.17.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:d5e2439eecc762cd85e7bd37161d4714aa03a33c5ba884e26c81559817ca0925"}, + {file = "wrapt-1.17.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3fc7cb4c1c744f8c05cd5f9438a3caa6ab94ce8344e952d7c45a8ed59dd88392"}, + {file = "wrapt-1.17.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8fdbdb757d5390f7c675e558fd3186d590973244fab0c5fe63d373ade3e99d40"}, + {file = 
"wrapt-1.17.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5bb1d0dbf99411f3d871deb6faa9aabb9d4e744d67dcaaa05399af89d847a91d"}, + {file = "wrapt-1.17.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d18a4865f46b8579d44e4fe1e2bcbc6472ad83d98e22a26c963d46e4c125ef0b"}, + {file = "wrapt-1.17.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc570b5f14a79734437cb7b0500376b6b791153314986074486e0b0fa8d71d98"}, + {file = "wrapt-1.17.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6d9187b01bebc3875bac9b087948a2bccefe464a7d8f627cf6e48b1bbae30f82"}, + {file = "wrapt-1.17.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:9e8659775f1adf02eb1e6f109751268e493c73716ca5761f8acb695e52a756ae"}, + {file = "wrapt-1.17.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e8b2816ebef96d83657b56306152a93909a83f23994f4b30ad4573b00bd11bb9"}, + {file = "wrapt-1.17.2-cp312-cp312-win32.whl", hash = "sha256:468090021f391fe0056ad3e807e3d9034e0fd01adcd3bdfba977b6fdf4213ea9"}, + {file = "wrapt-1.17.2-cp312-cp312-win_amd64.whl", hash = "sha256:ec89ed91f2fa8e3f52ae53cd3cf640d6feff92ba90d62236a81e4e563ac0e991"}, + {file = "wrapt-1.17.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:6ed6ffac43aecfe6d86ec5b74b06a5be33d5bb9243d055141e8cabb12aa08125"}, + {file = "wrapt-1.17.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:35621ae4c00e056adb0009f8e86e28eb4a41a4bfa8f9bfa9fca7d343fe94f998"}, + {file = "wrapt-1.17.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a604bf7a053f8362d27eb9fefd2097f82600b856d5abe996d623babd067b1ab5"}, + {file = "wrapt-1.17.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5cbabee4f083b6b4cd282f5b817a867cf0b1028c54d445b7ec7cfe6505057cf8"}, + {file = "wrapt-1.17.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:49703ce2ddc220df165bd2962f8e03b84c89fee2d65e1c24a7defff6f988f4d6"}, + {file = "wrapt-1.17.2-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8112e52c5822fc4253f3901b676c55ddf288614dc7011634e2719718eaa187dc"}, + {file = "wrapt-1.17.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9fee687dce376205d9a494e9c121e27183b2a3df18037f89d69bd7b35bcf59e2"}, + {file = "wrapt-1.17.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:18983c537e04d11cf027fbb60a1e8dfd5190e2b60cc27bc0808e653e7b218d1b"}, + {file = "wrapt-1.17.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:703919b1633412ab54bcf920ab388735832fdcb9f9a00ae49387f0fe67dad504"}, + {file = "wrapt-1.17.2-cp313-cp313-win32.whl", hash = "sha256:abbb9e76177c35d4e8568e58650aa6926040d6a9f6f03435b7a522bf1c487f9a"}, + {file = "wrapt-1.17.2-cp313-cp313-win_amd64.whl", hash = "sha256:69606d7bb691b50a4240ce6b22ebb319c1cfb164e5f6569835058196e0f3a845"}, + {file = "wrapt-1.17.2-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:4a721d3c943dae44f8e243b380cb645a709ba5bd35d3ad27bc2ed947e9c68192"}, + {file = "wrapt-1.17.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:766d8bbefcb9e00c3ac3b000d9acc51f1b399513f44d77dfe0eb026ad7c9a19b"}, + {file = "wrapt-1.17.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e496a8ce2c256da1eb98bd15803a79bee00fc351f5dfb9ea82594a3f058309e0"}, + {file = "wrapt-1.17.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40d615e4fe22f4ad3528448c193b218e077656ca9ccb22ce2cb20db730f8d306"}, + {file = "wrapt-1.17.2-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a5aaeff38654462bc4b09023918b7f21790efb807f54c000a39d41d69cf552cb"}, + {file = "wrapt-1.17.2-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:9a7d15bbd2bc99e92e39f49a04653062ee6085c0e18b3b7512a4f2fe91f2d681"}, + {file = "wrapt-1.17.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:e3890b508a23299083e065f435a492b5435eba6e304a7114d2f919d400888cc6"}, + {file = "wrapt-1.17.2-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:8c8b293cd65ad716d13d8dd3624e42e5a19cc2a2f1acc74b30c2c13f15cb61a6"}, + {file = "wrapt-1.17.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:4c82b8785d98cdd9fed4cac84d765d234ed3251bd6afe34cb7ac523cb93e8b4f"}, + {file = "wrapt-1.17.2-cp313-cp313t-win32.whl", hash = "sha256:13e6afb7fe71fe7485a4550a8844cc9ffbe263c0f1a1eea569bc7091d4898555"}, + {file = "wrapt-1.17.2-cp313-cp313t-win_amd64.whl", hash = "sha256:eaf675418ed6b3b31c7a989fd007fa7c3be66ce14e5c3b27336383604c9da85c"}, + {file = "wrapt-1.17.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5c803c401ea1c1c18de70a06a6f79fcc9c5acfc79133e9869e730ad7f8ad8ef9"}, + {file = "wrapt-1.17.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f917c1180fdb8623c2b75a99192f4025e412597c50b2ac870f156de8fb101119"}, + {file = "wrapt-1.17.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ecc840861360ba9d176d413a5489b9a0aff6d6303d7e733e2c4623cfa26904a6"}, + {file = "wrapt-1.17.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb87745b2e6dc56361bfde481d5a378dc314b252a98d7dd19a651a3fa58f24a9"}, + {file = "wrapt-1.17.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:58455b79ec2661c3600e65c0a716955adc2410f7383755d537584b0de41b1d8a"}, + {file = "wrapt-1.17.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4e42a40a5e164cbfdb7b386c966a588b1047558a990981ace551ed7e12ca9c2"}, + {file = "wrapt-1.17.2-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:91bd7d1773e64019f9288b7a5101f3ae50d3d8e6b1de7edee9c2ccc1d32f0c0a"}, + {file = "wrapt-1.17.2-cp38-cp38-musllinux_1_2_i686.whl", hash = 
"sha256:bb90fb8bda722a1b9d48ac1e6c38f923ea757b3baf8ebd0c82e09c5c1a0e7a04"}, + {file = "wrapt-1.17.2-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:08e7ce672e35efa54c5024936e559469436f8b8096253404faeb54d2a878416f"}, + {file = "wrapt-1.17.2-cp38-cp38-win32.whl", hash = "sha256:410a92fefd2e0e10d26210e1dfb4a876ddaf8439ef60d6434f21ef8d87efc5b7"}, + {file = "wrapt-1.17.2-cp38-cp38-win_amd64.whl", hash = "sha256:95c658736ec15602da0ed73f312d410117723914a5c91a14ee4cdd72f1d790b3"}, + {file = "wrapt-1.17.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:99039fa9e6306880572915728d7f6c24a86ec57b0a83f6b2491e1d8ab0235b9a"}, + {file = "wrapt-1.17.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2696993ee1eebd20b8e4ee4356483c4cb696066ddc24bd70bcbb80fa56ff9061"}, + {file = "wrapt-1.17.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:612dff5db80beef9e649c6d803a8d50c409082f1fedc9dbcdfde2983b2025b82"}, + {file = "wrapt-1.17.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:62c2caa1585c82b3f7a7ab56afef7b3602021d6da34fbc1cf234ff139fed3cd9"}, + {file = "wrapt-1.17.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c958bcfd59bacc2d0249dcfe575e71da54f9dcf4a8bdf89c4cb9a68a1170d73f"}, + {file = "wrapt-1.17.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc78a84e2dfbc27afe4b2bd7c80c8db9bca75cc5b85df52bfe634596a1da846b"}, + {file = "wrapt-1.17.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:ba0f0eb61ef00ea10e00eb53a9129501f52385c44853dbd6c4ad3f403603083f"}, + {file = "wrapt-1.17.2-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:1e1fe0e6ab7775fd842bc39e86f6dcfc4507ab0ffe206093e76d61cde37225c8"}, + {file = "wrapt-1.17.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:c86563182421896d73858e08e1db93afdd2b947a70064b813d515d66549e15f9"}, + {file = "wrapt-1.17.2-cp39-cp39-win32.whl", hash = 
"sha256:f393cda562f79828f38a819f4788641ac7c4085f30f1ce1a68672baa686482bb"}, + {file = "wrapt-1.17.2-cp39-cp39-win_amd64.whl", hash = "sha256:36ccae62f64235cf8ddb682073a60519426fdd4725524ae38874adf72b5f2aeb"}, + {file = "wrapt-1.17.2-py3-none-any.whl", hash = "sha256:b18f2d1533a71f069c7f82d524a52599053d4c7166e9dd374ae2136b7f40f7c8"}, + {file = "wrapt-1.17.2.tar.gz", hash = "sha256:41388e9d4d1522446fe79d3213196bd9e3b301a336965b9e27ca2788ebd122f3"}, +] + [[package]] name = "xformers" version = "0.0.27.post2" @@ -6970,5 +7351,5 @@ tracing = [] [metadata] lock-version = "2.1" -python-versions = "^3.10" -content-hash = "70c9b2fdea3938a6ed1aa0a1316aaadf2d0a0633fbed8340bdf29d3f35337511" +python-versions = ">=3.10,<3.14" +content-hash = "f89b57179a89ba18d1d0cf3cbc8785348925eabc3b921147c0bbf80326f987f7" diff --git a/pyproject.toml b/pyproject.toml index 81ed7c2..e4c945e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ packages = [{ include = "rigging" }] # Dependencies [tool.poetry.dependencies] -python = "^3.10" +python = ">=3.10,<3.14" pydantic = "^2.7.3" pydantic-xml = "^2.11.0" loguru = "^0.7.2" @@ -21,11 +21,11 @@ xmltodict = "^0.13.0" colorama = "^0.4.6" boto3 = "^1.35.0" boto3-stubs = { extras = ["s3"], version = "^1.35.0" } -logfire-api = "^3.1.1" jsonpath-ng = "^1.7.0" ruamel-yaml = "^0.18.10" jsonref = "^1.1.0" mcp = "^1.5.0" +dreadnode = ">=1.12.0" vllm = { version = "^0.5.0", optional = true } transformers = { version = "^4.41.0", optional = true } @@ -34,10 +34,11 @@ elasticsearch = { version = "^8.13.2", optional = true } asyncssh = { version = "^2.14.2", optional = true } click = { version = "^8.1.7", optional = true } -httpx = { version = "^0.27.0", optional = true } +httpx = { version = "^0.28.0", optional = true } aiodocker = { version = "^0.22.2", optional = true } websockets = { version = "^13.0", optional = true } + [tool.poetry.extras] tracing = ["logfire"] examples = ["asyncssh", "click", "httpx", "aiodocker", "websockets"] diff --git 
a/rigging/__init__.py b/rigging/__init__.py index 279f8e9..65712ac 100644 --- a/rigging/__init__.py +++ b/rigging/__init__.py @@ -1,4 +1,5 @@ from rigging import ( + caching, data, error, generator, @@ -94,6 +95,7 @@ "Transform", "attr", "await_", + "caching", "chat", "complete", "data", diff --git a/rigging/caching.py b/rigging/caching.py new file mode 100644 index 0000000..765aabc --- /dev/null +++ b/rigging/caching.py @@ -0,0 +1,42 @@ +import typing as t + +from loguru import logger + +if t.TYPE_CHECKING: + from rigging.message import Message + +CacheMode = t.Literal["latest"] +""" +How to handle cache_control entries on messages. + +- latest: Assign cache_control to the latest 2 non-assistant messages in the pipeline before inference. +""" + + +def apply_cache_mode_to_messages( + mode: CacheMode | None, + messages: "list[list[Message]]", +) -> "list[list[Message]]": + if mode is None: + return messages + + if mode != "latest": + logger.warning( + f"Unknown caching mode '{mode}', defaulting to 'latest'", + ) + mode = "latest" + + # first remove existing cache settings + updated: list[list[Message]] = [] + for _messages in messages: + updated = [ + *updated, + [m.clone().cache(cache_control=False) for m in _messages], + ] + + # then apply the latest cache settings + for _messages in updated: + for message in [m for m in _messages if m.role != "assistant"][-2:]: + message.cache(cache_control=True) + + return updated diff --git a/rigging/chat.py b/rigging/chat.py index 5e5c242..0646930 100644 --- a/rigging/chat.py +++ b/rigging/chat.py @@ -18,7 +18,8 @@ from typing import runtime_checkable from uuid import UUID, uuid4 -from loguru import logger +import dreadnode as dn +from dreadnode.metric import ScorerCallable from pydantic import ( BaseModel, ConfigDict, @@ -29,6 +30,7 @@ computed_field, ) +from rigging.caching import CacheMode, apply_cache_mode_to_messages from rigging.error import MaxDepthError, PipelineWarning from rigging.generator import GenerateParams, 
Generator, get_generator from rigging.generator.base import StopReason, Usage @@ -46,7 +48,6 @@ from rigging.model import Model, ModelT, SystemErrorModel, ValidationErrorModel from rigging.tokenizer import TokenizedChat, Tokenizer, get_tokenizer from rigging.tools import Tool, ToolCall, ToolChoice, ToolMode -from rigging.tracing import Span, tracer from rigging.transform import ( PostTransform, Transform, @@ -59,6 +60,7 @@ from rigging.util import flatten_list, get_qualified_name if t.TYPE_CHECKING: + from dreadnode.scorers.rigging import ChatFilterFunction, ChatFilterMode from elasticsearch import AsyncElasticsearch from rigging.data import ElasticOpType @@ -84,20 +86,15 @@ - include: Mark the message as failed and include it in the final output. """ -CacheMode = t.Literal["latest"] -""" -How to handle cache_control entries on messages. - -- latest: Assign cache_control to the latest 2 non-assistant messages in the pipeline before inference. -""" - class Chat(BaseModel): """ A completed chat interaction. 
""" - model_config = ConfigDict(arbitrary_types_allowed=True) + model_config = ConfigDict( + arbitrary_types_allowed=True, json_schema_extra={"rigging.type": "chat"} + ) uuid: UUID = Field(default_factory=uuid4) """The unique identifier for the chat.""" @@ -765,19 +762,12 @@ def depth(self) -> int: def _wrap_watch_callback(callback: WatchChatCallback) -> WatchChatCallback: callback_name = get_qualified_name(callback) - - async def traced_watch_callback(chats: list[Chat]) -> None: - with tracer.span( - f"Watch with {callback_name}()", - callback=callback_name, - chat_count=len(chats), - chat_ids=[str(c.uuid) for c in chats], - ): - result = callback(chats) - if inspect.isawaitable(result): - await result - - return traced_watch_callback + return dn.task( + name=f"watch - {callback_name}", + attributes={"rigging.type": "chat_pipeline.watch_callback"}, + log_inputs=True, + log_output=False, + )(callback) # Pipeline @@ -812,14 +802,18 @@ def __init__( """How to handle failures in the pipeline unless overridden in calls.""" self.caching: CacheMode | None = None """How to handle cache_control entries on messages.""" + self.task_name: str = generator.to_identifier(short=True) + """The name of the pipeline task, used for logging and debugging.""" + self.scorers: list[dn.Scorer[Chat]] = [] + """List of dreadnode scorers to evaluate the generated chat upon completion.""" self.until_types: list[type[Model]] = [] self.tools: list[Tool[..., t.Any]] = [] self.tool_mode: ToolMode = "auto" self.inject_tool_prompt = True self.add_tool_stop_token = True - self.then_callbacks: list[tuple[ThenChatCallback, int]] = [] - self.map_callbacks: list[tuple[MapChatCallback, int]] = [] + self.then_callbacks: list[tuple[ThenChatCallback, int, bool]] = [] + self.map_callbacks: list[tuple[MapChatCallback, int, bool]] = [] self.watch_callbacks: list[WatchChatCallback] = watch_callbacks or [] self.transforms: list[Transform] = [] @@ -1011,6 +1005,9 @@ def clone( new.errors_to_catch = 
self.errors_to_catch.copy() new.errors_to_exclude = self.errors_to_exclude.copy() new.caching = self.caching + new.task_name = self.task_name + new.scorers = self.scorers.copy() + new.transforms = self.transforms.copy() new.watch_callbacks = self.watch_callbacks.copy() @@ -1022,18 +1019,18 @@ def clone( return new new.then_callbacks = [ - (callback, max_depth) + (callback, max_depth, as_task) if not hasattr(callback, "__self__") or not isinstance(callback.__self__, ChatPipeline) - else (types.MethodType(callback.__func__, new), max_depth) # type: ignore [union-attr] - for callback, max_depth in self.then_callbacks.copy() + else (types.MethodType(callback.__func__, new), max_depth, as_task) # type: ignore [union-attr] + for callback, max_depth, as_task in self.then_callbacks.copy() ] new.map_callbacks = [ - (callback, max_depth) + (callback, max_depth, as_task) if not hasattr(callback, "__self__") or not isinstance(callback.__self__, ChatPipeline) - else (types.MethodType(callback.__func__, new), max_depth) # type: ignore [union-attr] - for callback, max_depth in self.map_callbacks.copy() + else (types.MethodType(callback.__func__, new), max_depth, as_task) # type: ignore [union-attr] + for callback, max_depth, as_task in self.map_callbacks.copy() ] new.transforms = [ callback @@ -1045,13 +1042,13 @@ def clone( if not isinstance(callbacks, bool): new.then_callbacks = [ - (callback, max_depth) - for callback, max_depth in self.then_callbacks + (callback, max_depth, as_task) + for callback, max_depth, as_task in self.then_callbacks if callback in callbacks ] new.map_callbacks = [ - (callback, max_depth) - for callback, max_depth in self.map_callbacks + (callback, max_depth, as_task) + for callback, max_depth, as_task in self.map_callbacks if callback in callbacks ] new.transforms = [callback for callback in self.transforms if callback in callbacks] @@ -1071,11 +1068,25 @@ def meta(self, **kwargs: t.Any) -> "ChatPipeline": self.metadata.update(kwargs) return self + def 
name(self, name: str) -> "ChatPipeline": + """ + Sets the name of the pipeline. + + Args: + name: The name to set for the pipeline. + + Returns: + The updated pipeline. + """ + self.task_name = name + return self + def then( self, *callbacks: ThenChatCallback, max_depth: int = DEFAULT_MAX_DEPTH, allow_duplicates: bool = False, + as_task: bool = True, ) -> "ChatPipeline": """ Registers one or many callbacks to be executed after the generation process completes. @@ -1089,6 +1100,7 @@ def then( callbacks: The callback functions to be added. max_depth: The maximum depth to allow recursive pipeline calls during this callback. allow_duplicates: Whether to allow (seemingly) duplicate callbacks to be added. + as_task: Whether to create a task for this callback. Returns: The updated pipeline. @@ -1110,7 +1122,7 @@ async def process(chat: Chat) -> Chat | None: f"Callback '{get_qualified_name(callback)}' is already registered.", ) - self.then_callbacks.extend([(callback, max_depth) for callback in callbacks]) + self.then_callbacks.extend([(callback, max_depth, as_task) for callback in callbacks]) return self def map( @@ -1118,6 +1130,7 @@ def map( *callbacks: MapChatCallback, max_depth: int = DEFAULT_MAX_DEPTH, allow_duplicates: bool = False, + as_task: bool = True, ) -> "ChatPipeline": """ Registers a callback to be executed after the generation process completes. @@ -1131,6 +1144,7 @@ def map( callbacks: The callback function to be executed. max_depth: The maximum depth to allow recursive pipeline calls during this callback. allow_duplicates: Whether to allow (seemingly) duplicate callbacks to be added. + as_task: Whether to create a task for this callback. Returns: The updated pipeline. 
@@ -1152,7 +1166,7 @@ async def process(chats: list[Chat]) -> list[Chat]: f"Callback '{get_qualified_name(callback)}' is already registered.", ) - self.map_callbacks.extend([(callback, max_depth) for callback in callbacks]) + self.map_callbacks.extend([(callback, max_depth, as_task) for callback in callbacks]) return self def transform( @@ -1337,13 +1351,13 @@ async def get_weather(city: Annotated[str, "The city name to get weather for"]) self.tools = [tool for tool in self.tools if tool.name not in new_names] + _tools self.then_callbacks = [ - (callback, max_depth) - for callback, max_depth in self.then_callbacks + (callback, max_depth, as_task) + for callback, max_depth, as_task in self.then_callbacks if callback != self._then_tools # Always remove to update max_depth ] self.then_callbacks.insert( 0, # make sure this is first - (self._then_tools, max_depth), + (self._then_tools, max_depth, False), ) if mode is not None: @@ -1364,6 +1378,44 @@ async def get_weather(city: Annotated[str, "The city name to get weather for"]) return self + def score( + self, + *scorers: dn.Scorer[Chat] | ScorerCallable[Chat], + filter: "ChatFilterMode | ChatFilterFunction" = "last", + ) -> "ChatPipeline": + """ + Adds one or more scorers to the pipeline to evaluate the generated chat upon completion. + + Args: + *scorers: The scorer or scorers to be added. These can be either: + - A dreadnode.Scorer instance. + - A callable function that can be converted to a dreadnode.Scorer. + filter: The strategy for filtering which messages to include: + - "all": Use all messages in the chat. + - "last": Use only the last message. + - "first": Use only the first message. + - "user": Use only user messages. + - "assistant": Use only assistant messages. + - "last_user": Use only the last user message. + - "last_assistant": Use only the last assistant message. + - A callable that takes a list of `Message` objects and returns a filtered list. + + Returns: + The updated pipeline. 
+ """ + self.scorers.extend( + [ + dn.scorers.wrap_chat( + scorer if isinstance(scorer, dn.Scorer) else dn.Scorer.from_callable(scorer), + filter=filter, + ) + for scorer in scorers + ] + ) + return self + + # Internal callbacks for handling tools and parsing + def until_parsed_as( self, *types: type[ModelT], @@ -1410,16 +1462,14 @@ def until_parsed_as( max_depth = max_rounds or max_depth self.then_callbacks = [ - (callback, max_depth) - for callback, max_depth in self.then_callbacks + (callback, max_depth, as_task) + for callback, max_depth, as_task in self.then_callbacks if callback != self._then_parse ] - self.then_callbacks.append((self._then_parse, max_depth)) + self.then_callbacks.append((self._then_parse, max_depth, False)) return self - # Internal callbacks for handling tools and parsing - async def _then_tools(self, chat: Chat) -> PipelineStepContextManager | None: if not chat.last.tool_calls: return None @@ -1467,8 +1517,14 @@ async def _then_parse(self, chat: Chat) -> PipelineStepContextManager | None: next_pipeline = self.clone(chat=chat) + type_names = " | ".join(sorted(until_type.__name__ for until_type in self.until_types)) + task_name = f"parse - {type_names}" + try: - chat.last.parse_many(*self.until_types) + with dn.task_span(task_name, attributes={"rigging.type": "chat_pipeline.parse"}): + dn.log_input("message", chat.last) + parsed = chat.last.parse_many(*self.until_types) + dn.log_output("parsed", parsed) except ValidationError as e: next_pipeline.add( Message.from_model( @@ -1502,33 +1558,6 @@ def _fit_params( params = [self.params.merge_with(p) for p in params] return [(p or GenerateParams()) for p in params] - def _apply_cache_mode_to_messages( - self, - messages: list[list[Message]], - ) -> list[list[Message]]: - if self.caching is None: - return messages - - if self.caching != "latest": - logger.warning( - f"Unknown caching mode '{self.caching}', defaulting to 'latest'", - ) - - # first remove existing cache settings - updated: 
list[list[Message]] = [] - for _messages in messages: - updated = [ - *updated, - [m.clone().cache(cache_control=False) for m in _messages], - ] - - # then apply the latest cache settings - for _messages in updated: - for message in [m for m in _messages if m.role != "assistant"][-2:]: - message.cache(cache_control=True) - - return updated - @dataclass class CallbackState: chat: Chat @@ -1548,51 +1577,69 @@ async def complete() -> None: state.completed = True state.ready_event.set() - with tracer.span( - f"Then with {callback_name}()", - callback=callback_name, - chat_id=str(state.chat.uuid), - ): - async with contextlib.AsyncExitStack() as exit_stack: - exit_stack.push_async_callback(complete) - - result = callback(state.chat) - if inspect.isawaitable(result): - result = await result + async with contextlib.AsyncExitStack() as exit_stack: + exit_stack.push_async_callback(complete) - if result is None or isinstance(result, Chat): - state.chat = result or state.chat - return + result = callback(state.chat) + if inspect.isawaitable(result): + result = await result - if isinstance(result, contextlib.AbstractAsyncContextManager): - result = await exit_stack.enter_async_context(result) + if result is None or isinstance(result, Chat): + state.chat = result or state.chat + return - if not inspect.isasyncgen(result): - raise TypeError( - f"Callback '{callback_name}' must return a Chat, PipelineStepGenerator, or None", - ) + if isinstance(result, contextlib.AbstractAsyncContextManager): + result = await exit_stack.enter_async_context(result) - generator = t.cast( - "PipelineStepGenerator", - await exit_stack.enter_async_context(aclosing(result)), + if not inspect.isasyncgen(result): + raise TypeError( + f"Callback '{callback_name}' must return a Chat, PipelineStepGenerator, or None", ) - async for step in generator: - state.step = step - state.ready_event.set() - await state.continue_event.wait() + generator = t.cast( + "PipelineStepGenerator", + await 
exit_stack.enter_async_context(aclosing(result)), + ) + async for step in generator: + state.step = step + + state.ready_event.set() + await state.continue_event.wait() + + state.ready_event.clear() + state.continue_event.clear() + state.step = None + + state.chat = step.chats[-1] if step.chats else state.chat - state.ready_event.clear() - state.continue_event.clear() - state.step = None + async def _score_chats(self, chats: list[Chat]) -> None: + if not self.scorers: + return + + for scorer in self.scorers: + for metric in await asyncio.gather( + *[scorer(chat) for chat in chats], + ): + dn.log_metric(scorer.name, metric) - state.chat = step.chats[-1] if step.chats else state.chat + def _raise_if_failed( + self, + chats: list[Chat | BaseException] | ChatList, + on_failed: FailMode | None = None, + ) -> None: + for chat in chats: + error = chat.error if isinstance(chat, Chat) else chat + if error is not None and ( + on_failed == "raise" + or not any(isinstance(error, t) for t in self.errors_to_catch) + or any(isinstance(error, t) for t in self.errors_to_exclude) + ): + raise error # Run methods async def _step( # noqa: PLR0915, PLR0912 self, - span: Span, messages: list[list[Message]], params: list[GenerateParams], on_failed: FailMode, @@ -1644,8 +1691,16 @@ async def _step( # noqa: PLR0915, PLR0912 # Pass the messages to the generator try: - messages = self._apply_cache_mode_to_messages(messages) - generated = await self.generator.generate_messages(messages, params) + messages = apply_cache_mode_to_messages(self.caching, messages) + + with dn.task_span( + f"generate - {self.generator.to_identifier(short=True)}", + attributes={"rigging.type": "chat_pipeline.generate"}, + ): + dn.log_input("messages", messages) + dn.log_input("params", params) + generated = await self.generator.generate_messages(messages, params) + dn.log_output("generated", generated) # If we got a total failure here for generation as a whole, # we can't distinguish between incoming messages in 
terms @@ -1653,9 +1708,6 @@ async def _step( # noqa: PLR0915, PLR0912 # on all of them. except Exception as error: # noqa: BLE001 - span.set_attribute("failed", True) - span.set_attribute("error", error) - chats = ChatList( [ Chat( @@ -1724,7 +1776,6 @@ async def _step( # noqa: PLR0915, PLR0912 # Yield what we generated - span.set_attribute("chats", chats) current_step = PipelineStep( state="generated", chats=chats, @@ -1734,23 +1785,13 @@ async def _step( # noqa: PLR0915, PLR0912 # Check if we should immediately raise - for chat in chats: - if chat.error is not None and ( - on_failed == "raise" - or not any(isinstance(chat.error, t) for t in self.errors_to_catch) - or any(isinstance(chat.error, t) for t in self.errors_to_exclude) - ): - span.set_attribute("error", chat.error) - span.set_attribute("failed", True) - raise chat.error + self._raise_if_failed(chats, on_failed) # Chat cleanup if on_failed == "skip": chats = ChatList([chat for chat in chats if not chat.failed]) - span.set_attribute("chats", chats) - if len(chats) == 0 or all(chat.failed for chat in chats): yield PipelineStep( state="final", @@ -1761,7 +1802,7 @@ async def _step( # noqa: PLR0915, PLR0912 # Then callbacks - for then_callback, max_depth in self.then_callbacks: + for then_callback, max_depth, as_task in self.then_callbacks: callback_name = get_qualified_name(then_callback) states = [ @@ -1773,8 +1814,19 @@ async def _step( # noqa: PLR0915, PLR0912 for chat in chats ] + callback_task = ( + dn.task( + name=f"then - {callback_name}", + attributes={"rigging.type": "chat_pipeline.then_callback"}, + log_inputs=True, + log_output=True, + )(then_callback) + if as_task + else then_callback + ) + tasks = [ - asyncio.create_task(self._process_then_callback(then_callback, state)) + asyncio.create_task(self._process_then_callback(callback_task, state)) # type: ignore [arg-type] for state in states ] @@ -1818,7 +1870,6 @@ async def _step( # noqa: PLR0915, PLR0912 chats = ChatList([state.chat for state in 
states if state.chat]) - span.set_attribute("chats", chats) current_step = PipelineStep( state="callback", chats=chats, @@ -1832,8 +1883,6 @@ async def _step( # noqa: PLR0915, PLR0912 if on_failed == "skip": chats = ChatList([chat for chat in chats if not chat.failed]) - span.set_attribute("chats", chats) - if len(chats) == 0 or all(chat.failed for chat in chats): yield PipelineStep( state="final", @@ -1844,56 +1893,59 @@ async def _step( # noqa: PLR0915, PLR0912 # Map callbacks - for map_callback, max_depth in self.map_callbacks: + for map_callback, max_depth, as_task in self.map_callbacks: callback_name = get_qualified_name(map_callback) - with tracer.span( - f"Map with {callback_name}()", - callback=callback_name, - chat_count=len(chats), - chat_ids=[str(c.uuid) for c in chats], - ): - async with contextlib.AsyncExitStack() as exit_stack: - result = map_callback(chats) - chats_or_generator = await result if inspect.isawaitable(result) else result - - if isinstance(result, contextlib.AbstractAsyncContextManager): - result = await exit_stack.enter_async_context(result) - - if inspect.isasyncgen(chats_or_generator): - generator = t.cast( - "PipelineStepGenerator", - await exit_stack.enter_async_context( - aclosing(chats_or_generator), - ), - ) - async for step in generator: - _step = step.with_parent(current_step) - if _step.depth > max_depth: - max_depth_error = MaxDepthError( - max_depth, - _step, - callback_name, - ) - if on_failed == "raise": - raise max_depth_error + map_task = ( + dn.task( + name=f"map - {callback_name}", + attributes={"rigging.type": "chat_pipeline.map_callback"}, + log_inputs=True, + log_output=True, + )(map_callback) + if as_task + else map_callback + ) - chats = ChatList(chats) - for chat in chats: - chat.error = max_depth_error - chat.failed = True - else: - yield _step - chats = step.chats + async with contextlib.AsyncExitStack() as exit_stack: + result = map_task(chats) + if inspect.isawaitable(result): + result = await result - chats 
= step.chats + if isinstance(result, contextlib.AbstractAsyncContextManager): + result = await exit_stack.enter_async_context(result) - elif isinstance(chats_or_generator, list) and all( - isinstance(c, Chat) for c in chats_or_generator - ): - chats = ChatList(chats_or_generator) + if inspect.isasyncgen(result): + generator = t.cast( + "PipelineStepGenerator", + await exit_stack.enter_async_context( + aclosing(result), + ), + ) + async for step in generator: + _step = step.with_parent(current_step) + if _step.depth > max_depth: + max_depth_error = MaxDepthError( + max_depth, + _step, + callback_name, + ) + if on_failed == "raise": + raise max_depth_error + + chats = ChatList(chats) + for chat in chats: + chat.error = max_depth_error + chat.failed = True + else: + yield _step + chats = step.chats + + chats = step.chats + + elif isinstance(result, list) and all(isinstance(c, Chat) for c in result): + chats = ChatList(result) - span.set_attribute("chats", chats) current_step = PipelineStep( state="callback", chats=chats, @@ -1905,7 +1957,6 @@ async def _step( # noqa: PLR0915, PLR0912 if on_failed == "skip": chats = ChatList([chat for chat in chats if not chat.failed]) - span.set_attribute("chats", chats) yield PipelineStep( state="final", chats=chats, @@ -1941,19 +1992,15 @@ async def step( messages = [self.chat.all] params = self._fit_params(1, [self.params]) - with tracer.span( - f"Chat with {self.generator.to_identifier()}", - generator_id=self.generator.to_identifier(), - params=self.params.to_dict() if self.params is not None else {}, - ) as span: - async with aclosing( - self._step(span, messages, params, on_failed), - ) as generator: - yield generator + async with aclosing( + self._step(messages, params, on_failed), + ) as generator: + yield generator async def run( self, *, + name: str | None = None, on_failed: FailMode | None = None, allow_failed: bool = False, ) -> Chat: @@ -1961,6 +2008,7 @@ async def run( Execute the generation process for a single 
message. Args: + name: The name of the task for logging purposes. on_failed: The behavior when a message fails to generate. allow_failed: Deprecated, use `on_failed="include"`. @@ -1977,16 +2025,44 @@ async def run( if on_failed is None: on_failed = "include" if allow_failed else self.on_failed + if on_failed == "skip": + raise ValueError( + "Cannot use 'skip' mode with single message generation (pass allow_failed=True or on_failed='include'/'raise')", + ) + + messages = [self.chat.all] + params = self._fit_params(1, [self.params]) + last: PipelineStep | None = None - async with self.step(on_failed=on_failed) as steps: - async for step in steps: - last = step + with dn.task_span( + name or f"pipeline - {self.task_name}", + label=name or f"pipeline_{self.task_name}", + attributes={"rigging.type": "chat_pipeline.run"}, + ) as task: + dn.log_inputs( + messages=messages[0], + params=params[0], + generator_id=self.generator.to_identifier(), + ) + + try: + async with aclosing( + self._step(messages, params, on_failed), + ) as steps: + async for step in steps: + last = step + finally: + if last is not None and last.chats: + dn.log_output("chat", last.chats[-1]) + await self._score_chats(last.chats) + # TODO: Remove once Strikes UI is ported + task.set_attribute("chats", last.chats) - if last is None or last.state != "final": - raise RuntimeError("The pipeline did not complete successfully") + if last is None or last.state != "final": + raise RuntimeError("The pipeline did not complete successfully") - if not last.chats: - raise RuntimeError("The pipeline process did not produce any chats") + if not last.chats: + raise RuntimeError("The pipeline process did not produce any chats") return last.chats[-1] @@ -2018,22 +2094,18 @@ async def step_many( messages = [self.chat.all] * count params = self._fit_params(count, params) - with tracer.span( - f"Chat with {self.generator.to_identifier()} (x{count})", - count=count, - generator_id=self.generator.to_identifier(), - 
params=self.params.to_dict() if self.params is not None else {}, - ) as span: - async with aclosing( - self._step(span, messages, params, on_failed), - ) as generator: - yield generator + async with aclosing( + self._step(messages, params, on_failed), + ) as generator: + yield generator async def run_many( self, count: int, *, params: t.Sequence[GenerateParams | None] | None = None, + name: str | None = None, + mode: t.Literal["merged", "parallel"] = "parallel", on_failed: FailMode | None = None, ) -> ChatList: """ @@ -2042,24 +2114,111 @@ async def run_many( Args: count: The number of times to execute the generation process. params: A sequence of parameters to be used for each execution. + name: The name of the task for logging purposes. + mode: The mode of execution, either "merged" or "parallel". + - In "merged" mode, a single pipeline manages all generation simultaneously + - In "parallel" mode, independent pipelines are created for each generation on_failed: The behavior when a message fails to generate. Returns: - A list of generatated Chats. + A list of generated Chats. 
""" + if count < 1: + raise ValueError("Count must be greater than 0") + + on_failed = on_failed or self.on_failed + + messages = [self.chat.all] * count + params = self._fit_params(count, params) last: PipelineStep | None = None - async with self.step_many(count, params=params, on_failed=on_failed) as steps: - async for step in steps: - last = step + with dn.task_span( + name or f"pipeline - {self.task_name} (x{count})", + label=name or f"pipeline_many_{self.task_name}", + attributes={"rigging.type": "chat_pipeline.run_many"}, + ) as task: + dn.log_inputs( + count=count, + messages=messages[0], + params=params[0], + generator_id=self.generator.to_identifier(), + ) + + if mode == "merged": + try: + async with aclosing( + self._step(messages, params, on_failed), + ) as steps: + async for step in steps: + last = step + finally: + if last is not None: + dn.log_output("chats", last.chats) + await self._score_chats(last.chats) + # TODO: Remove once Strikes UI is ported + task.set_attribute("chats", last.chats) + + if last is None or last.state != "final": + raise RuntimeError("The pipeline did not complete successfully") + + return last.chats + + if mode == "parallel": + tasks = [asyncio.create_task(self.run(on_failed="include")) for _ in range(count)] + chats_or_errors = await asyncio.gather(*tasks, return_exceptions=True) + + self._raise_if_failed(chats_or_errors, on_failed) + + chats = [ + chat + for chat in chats_or_errors + if isinstance(chat, Chat) and (on_failed != "skip" or not chat.failed) + ] - if last is None or last.state != "final": - raise ValueError("The generation process did not complete successfully") + dn.log_output("chats", chats) + # TODO: Remove once Strikes UI is ported + task.set_attribute("chats", chats) - return last.chats + return ChatList(chats) + + raise ValueError( + f"Invalid mode '{mode}', expected 'merged' or 'parallel'", + ) # Batch messages + def _fit_batch_args( + self, + many: t.Sequence[t.Sequence[Message]] + | t.Sequence[Message] + 
| t.Sequence[MessageDict] + | t.Sequence[str] + | MessageDict + | str, + params: t.Sequence[GenerateParams | None] | None = None, + ) -> tuple[int, list[list[Message]], list[GenerateParams]]: + # Get the maximum of either incoming messages or params + + count = max(len(many), len(params) if params is not None else 0) + + # If we have less messages than params, we need to either: + # 1. Error because we have >1 messages that we can't reasonably + # zip with our parameters of a different length + # 2. Duplicate a single message we have len(params) times as the + # user is just batching only over parameters + + messages = [[*self.chat.all, *Message.fit_as_list(m)] for m in many] + if len(messages) < count: + if len(messages) != 1: + raise ValueError( + f"Can't fit {len(messages)} messages to {count} params", + ) + messages = messages * count + + params = self._fit_params(count, params) + + return count, messages, params + @asynccontextmanager async def step_batch( self, @@ -2088,37 +2247,12 @@ async def step_batch( Pipeline steps. """ on_failed = on_failed or self.on_failed + _, messages, params = self._fit_batch_args(many, params) - # Get the maximum of either incoming messages or params - - count = max(len(many), len(params) if params is not None else 0) - - # If we have less messages than params, we need to either: - # 1. Error because we have >1 messages that we can't reasonably - # zip with our parameters of a different length - # 2. 
Duplicate a single message we have len(params) times as the - # user is just batching only over parameters - - messages = [[*self.chat.all, *Message.fit_as_list(m)] for m in many] - if len(messages) < count: - if len(messages) != 1: - raise ValueError( - f"Can't fit {len(messages)} messages to {count} params", - ) - messages = messages * count - - params = self._fit_params(count, params) - - with tracer.span( - f"Chat batch with {self.generator.to_identifier()} ({count})", - count=count, - generator_id=self.generator.to_identifier(), - params=self.params.to_dict() if self.params is not None else {}, - ) as span: - async with aclosing( - self._step(span, messages, params, on_failed), - ) as generator: - yield generator + async with aclosing( + self._step(messages, params, on_failed), + ) as generator: + yield generator async def run_batch( self, @@ -2130,6 +2264,8 @@ async def run_batch( | str, params: t.Sequence[GenerateParams | None] | None = None, *, + name: str | None = None, + mode: t.Literal["merged", "parallel"] = "parallel", on_failed: FailMode | None = None, ) -> ChatList: """ @@ -2141,21 +2277,76 @@ async def run_batch( Args: many: A sequence of sequences of messages to be generated. params: A sequence of parameters to be used for each set of messages. + name: The name of the task for logging purposes. + mode: The mode of execution, either "merged" or "parallel". + - In "merged" mode, a single pipeline manages all generation simultaneously + - In "parallel" mode, independent pipelines are created for each generation on_failed: The behavior when a message fails to generate. Returns: A list of generatated Chats. 
""" + on_failed = on_failed or self.on_failed + count, messages, params = self._fit_batch_args(many, params) last: PipelineStep | None = None - async with self.step_batch(many, params=params, on_failed=on_failed) as steps: - async for step in steps: - last = step + with dn.task_span( + name or f"pipeline - {self.task_name} (batch x{count})", + label=name or f"pipeline_batch_{self.task_name}", + attributes={"rigging.type": "chat_pipeline.run_batch"}, + ) as task: + dn.log_inputs( + count=count, + messages=messages, + params=params, + generator_id=self.generator.to_identifier(), + ) + + if mode == "merged": + try: + async with aclosing( + self._step(messages, params, on_failed), + ) as steps: + async for step in steps: + last = step + finally: + if last is not None: + dn.log_output("chats", last.chats) + await self._score_chats(last.chats) + # TODO: Remove once Strikes UI is ported + task.set_attribute("chats", last.chats) + + if last is None or last.state != "final": + raise RuntimeError("The pipeline did not complete successfully") + + return last.chats + + if mode == "parallel": + tasks = [ + asyncio.create_task( + self.clone().add(_messages).with_(_params).run(on_failed="include") + ) + for _messages, _params in zip(messages, params, strict=True) + ] + chats_or_errors = await asyncio.gather(*tasks, return_exceptions=True) + + self._raise_if_failed(chats_or_errors, on_failed) - if last is None or last.state != "final": - raise ValueError("The generation process did not complete successfully") + chats = [ + chat + for chat in chats_or_errors + if isinstance(chat, Chat) and (on_failed != "skip" or not chat.failed) + ] - return last.chats + dn.log_output("chats", chats) + # TODO: Remove once Strikes UI is ported + task.set_attribute("chats", chats) + + return ChatList(chats) + + raise ValueError( + f"Invalid mode '{mode}', expected 'merged' or 'parallel'", + ) # Generator iteration @@ -2193,7 +2384,15 @@ async def run_over( sub.generator = generator 
coros.append(sub.run(allow_failed=(on_failed != "raise"))) - with tracer.span(f"Chat over {len(coros)} generators", count=len(coros)): + short_generators = [g.to_identifier(short=True) for g in _generators] + task_name = "iterate - " + ", ".join(short_generators) + + with dn.task_span( + task_name, + label="iterate_over", + attributes={"rigging.type": "chat_pipeline.run_over"}, + ): + dn.log_input("generators", [g.to_identifier() for g in _generators]) return ChatList(await asyncio.gather(*coros)) # Prompt binding diff --git a/rigging/completion.py b/rigging/completion.py index 4ab1e07..89250bb 100644 --- a/rigging/completion.py +++ b/rigging/completion.py @@ -11,6 +11,7 @@ from typing import runtime_checkable from uuid import UUID, uuid4 +import dreadnode as dn from loguru import logger from pydantic import BaseModel, ConfigDict, Field, computed_field @@ -18,7 +19,6 @@ from rigging.generator import GenerateParams, Generator, get_generator from rigging.generator.base import GeneratedText, StopReason, Usage from rigging.parsing import parse_many -from rigging.tracing import Span, tracer from rigging.util import get_qualified_name if t.TYPE_CHECKING: @@ -571,20 +571,19 @@ def _until_parse_callback(self, text: str) -> bool: return False async def _watch_callback(self, completions: list[Completion]) -> None: - def wrap_watch_callback(callback: WatchCompletionCallback) -> WatchCompletionCallback: - async def traced_watch_callback(completions: list[Completion]) -> None: - callback_name = get_qualified_name(callback) - with tracer.span( - f"Watch with {callback_name}()", - callback=callback_name, - competion_count=len(completions), - completion_ids=[str(c.uuid) for c in completions], - ): - await callback(completions) - - return traced_watch_callback - - coros = [wrap_watch_callback(callback)(completions) for callback in self.watch_callbacks] + def wrap_watch_callback( + callback: WatchCompletionCallback, + ) -> t.Callable[[list[Completion]], t.Awaitable[None]]: + 
callback_name = get_qualified_name(callback) + return dn.task( + name=f"watch - {callback_name}", + attributes={"rigging.type": "completion_pipeline.watch_callback"}, + log_inputs=True, + log_output=False, + )(callback) + + traced_callbacks = [wrap_watch_callback(callback) for callback in self.watch_callbacks] + coros = [callback(completions) for callback in traced_callbacks] await asyncio.gather(*coros) # TODO: It's opaque exactly how we should blend multiple @@ -633,37 +632,32 @@ async def _post_run( for map_callback in self.map_callbacks: callback_name = get_qualified_name(map_callback) - with tracer.span( - f"Map with {callback_name}()", - callback=callback_name, - completion_count=len(completions), - completion_ids=[str(c.uuid) for c in completions], - ): - completions = await map_callback(completions) - if not all(isinstance(c, Completion) for c in completions): - raise ValueError( - f".map() callback must return a Completion object or None ({callback_name})", - ) - - def wrap_then_callback(callback: ThenCompletionCallback) -> ThenCompletionCallback: - callback_name = get_qualified_name(callback) - - async def traced_then_callback(completion: Completion) -> Completion | None: - with tracer.span( - f"Then with {callback_name}()", - callback=callback_name, - completion_id=str(completion.uuid), - ): - return await callback(completion) - - return traced_then_callback + traced_map_callback = dn.task( + name=f"map - {callback_name}", + attributes={"rigging.type": "completion_pipeline.map_callback"}, + log_inputs=True, + log_output=True, + )(map_callback) + completions = await traced_map_callback(completions) + if not all(isinstance(c, Completion) for c in completions): + raise ValueError( + f".map() callback must return a Completion object or None ({callback_name})", + ) for then_callback in self.then_callbacks: - coros = [wrap_then_callback(then_callback)(completion) for completion in completions] + callback_name = get_qualified_name(then_callback) + 
traced_then_callback = dn.task( + name=f"then - {callback_name}", + attributes={"rigging.type": "completion_pipeline.then_callback"}, + log_inputs=True, + log_output=True, + )(then_callback) + + coros = [traced_then_callback(completion) for completion in completions] new_completions = await asyncio.gather(*coros) if not all(isinstance(c, Completion) or c is None for c in new_completions): raise ValueError( - f".then() callback must return a Completion object or None ({get_qualified_name(then_callback)})", + f".then() callback must return a Completion object or None ({callback_name})", ) completions = [ @@ -729,7 +723,7 @@ def _initialize_states( async def _run( # noqa: PLR0912 self, - span: Span, + span: dn.Span, states: list[RunState], on_failed: "FailMode", batch_mode: bool = False, # noqa: FBT001, FBT002 @@ -805,12 +799,10 @@ async def _run( # noqa: PLR0912 for state in to_watch_states: state.watched = True - completions = await self._post_run( + return await self._post_run( [s.completion for s in states if s.completion is not None], on_failed, ) - span.set_attribute("completions", completions) - return completions async def run( self, @@ -841,12 +833,21 @@ async def run( on_failed = on_failed or self.on_failed states = self._initialize_states(1) - with tracer.span( - f"Completion with {self.generator.to_identifier()}", - generator_id=self.generator.to_identifier(), - params=self.params.to_dict() if self.params is not None else {}, - ) as span: - return (await self._run(span, states, on_failed))[0] + with dn.task_span( + f"pipeline - {self.generator.to_identifier(short=True)}", + label=f"pipeline_{self.generator.to_identifier(short=True)}", + attributes={"rigging.type": "completion_pipeline.run"}, + ) as task: + dn.log_inputs( + text=self.text, + params=self.params.to_dict() if self.params is not None else {}, + generator_id=self.generator.to_identifier(), + ) + completions = await self._run(task, states, on_failed) + completion = completions[0] + 
dn.log_output("completion", completion) + task.set_attribute("completions", completions) + return completion __call__ = run @@ -873,13 +874,21 @@ async def run_many( on_failed = on_failed or self.on_failed states = self._initialize_states(count, params) - with tracer.span( - f"Completion with {self.generator.to_identifier()} (x{count})", - count=count, - generator_id=self.generator.to_identifier(), - params=self.params.to_dict() if self.params is not None else {}, - ) as span: - return await self._run(span, states, on_failed) + with dn.task_span( + f"pipeline - {self.generator.to_identifier(short=True)} (x{count})", + label=f"pipeline_many_{self.generator.to_identifier(short=True)}", + attributes={"rigging.type": "completion_pipeline.run_many"}, + ) as task: + dn.log_inputs( + count=count, + text=self.text, + params=self.params.to_dict() if self.params is not None else {}, + generator_id=self.generator.to_identifier(), + ) + completions = await self._run(task, states, on_failed) + dn.log_output("completions", completions) + task.set_attribute("completions", completions) + return completions # Batch completions @@ -913,13 +922,21 @@ async def run_batch( for state in states: next(state.processor) - with tracer.span( - f"Completion batch with {self.generator.to_identifier()} ({len(states)})", - count=len(states), - generator_id=self.generator.to_identifier(), - params=self.params.to_dict() if self.params is not None else {}, - ) as span: - return await self._run(span, states, on_failed, batch_mode=True) + with dn.task_span( + f"pipeline - {self.generator.to_identifier(short=True)} (batch x{len(states)})", + label=f"pipeline_batch_{self.generator.to_identifier(short=True)}", + attributes={"rigging.type": "completion_pipeline.run_batch"}, + ) as task: + dn.log_inputs( + count=len(states), + many=many, + params=params, + generator_id=self.generator.to_identifier(), + ) + completions = await self._run(task, states, on_failed, batch_mode=True) + 
dn.log_output("completions", completions) + task.set_attribute("completions", completions) + return completions # Generator iteration @@ -957,6 +974,17 @@ async def run_over( sub.generator = generator coros.append(sub.run(allow_failed=(on_failed != "raise"))) - with tracer.span(f"Completion over {len(coros)} generators", count=len(coros)): + short_generators = [g.to_identifier(short=True) for g in _generators] + task_name = "iterate - " + ", ".join(short_generators) + + with dn.task_span( + task_name, + label="iterate_over", + attributes={"rigging.type": "completion_pipeline.run_over"}, + ) as task: + dn.log_input("generators", [g.to_identifier() for g in _generators]) completions = await asyncio.gather(*coros) - return await self._post_run(completions, on_failed) + final_completions = await self._post_run(completions, on_failed) + dn.log_output("completions", final_completions) + task.set_attribute("completions", final_completions) + return final_completions diff --git a/rigging/generator/__init__.py b/rigging/generator/__init__.py index 8acf3a8..9cb346a 100644 --- a/rigging/generator/__init__.py +++ b/rigging/generator/__init__.py @@ -2,6 +2,8 @@ Generators produce completions for a given set of messages or text. 
""" +import typing as t + from rigging.generator.base import ( GeneratedMessage, GeneratedText, @@ -68,5 +70,12 @@ def get_transformers_lazy() -> type[Generator]: "get_generator", "get_identifier", "register_generator", - # TODO: We can't add VLLM and Transformers here because they are lazy loaded ] + + +def __getattr__(name: str) -> t.Any: + if name == "VLLMGenerator": + return get_vllm_lazy() + if name == "TransformersGenerator": + return get_transformers_lazy() + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/rigging/generator/__init__.pyi b/rigging/generator/__init__.pyi new file mode 100644 index 0000000..027a502 --- /dev/null +++ b/rigging/generator/__init__.pyi @@ -0,0 +1,36 @@ +from rigging.generator.base import ( + GeneratedMessage, + GeneratedText, + GenerateParams, + Generator, + StopReason, + Usage, + chat, + complete, + get_generator, + get_identifier, + register_generator, +) +from rigging.generator.http import HTTPGenerator +from rigging.generator.litellm_ import LiteLLMGenerator +from rigging.generator.transformers_ import TransformersGenerator +from rigging.generator.vllm_ import VLLMGenerator + +__all__ = [ + "GenerateParams", + "GeneratedMessage", + "GeneratedText", + "Generator", + "HTTPGenerator", + "LiteLLMGenerator", + "StopReason", + "TransformersGenerator", + "Usage", + "VLLMGenerator", + "chat", + "complete", + "get_generator", + "get_generator", + "get_identifier", + "register_generator", +] diff --git a/rigging/generator/base.py b/rigging/generator/base.py index b612efa..336b15d 100644 --- a/rigging/generator/base.py +++ b/rigging/generator/base.py @@ -7,7 +7,14 @@ from functools import lru_cache from loguru import logger -from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, TypeAdapter, field_validator +from pydantic import ( + BaseModel, + BeforeValidator, + ConfigDict, + Field, + TypeAdapter, + field_validator, +) from typing_extensions import Self from rigging.error import 
InvalidGeneratorError @@ -123,7 +130,7 @@ async def wrapper( return result - return wrapper # type: ignore[return-value] + return wrapper # type: ignore [return-value] return decorator @@ -144,7 +151,7 @@ class GenerateParams(BaseModel): Use the `extra` field to pass additional parameters to the API. """ - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", json_schema_extra={"rigging.type": "generate_params"}) temperature: float | None = None """The sampling temperature.""" @@ -294,6 +301,15 @@ class Usage(BaseModel): total_tokens: int """The total number of tokens processed.""" + def __add__(self, other: "Usage") -> "Usage": + if not isinstance(other, Usage): + raise TypeError(f"Cannot add {type(other)} to Usage") + return Usage( + input_tokens=self.input_tokens + other.input_tokens, + output_tokens=self.output_tokens + other.output_tokens, + total_tokens=self.total_tokens + other.total_tokens, + ) + GeneratedT = t.TypeVar("GeneratedT", Message, str) @@ -301,6 +317,8 @@ class Usage(BaseModel): class GeneratedMessage(BaseModel): """A generated message with additional generation information.""" + model_config = ConfigDict(json_schema_extra={"rigging.type": "generated_message"}) + message: Message """The generated message.""" @@ -364,6 +382,8 @@ class Generator(BaseModel): - `generate_texts`: Process a batch of texts. """ + model_config = ConfigDict(json_schema_extra={"rigging.type": "generator"}) + model: str """The model name to be used by the generator.""" api_key: str | None = Field(None, exclude=True) @@ -374,7 +394,7 @@ class Generator(BaseModel): _watch_callbacks: list["WatchChatCallback | WatchCompletionCallback"] = [] _wrap: t.Callable[[CallableT], CallableT] | None = None - def to_identifier(self, params: GenerateParams | None = None) -> str: + def to_identifier(self, params: GenerateParams | None = None, *, short: bool = False) -> str: """ Converts the generator instance back into a rigging identifier string. 
@@ -386,7 +406,7 @@ def to_identifier(self, params: GenerateParams | None = None) -> str: Returns: The identifier string. """ - return get_identifier(self, params) + return get_identifier(self, params, short=short) def watch( self, @@ -651,7 +671,9 @@ def complete( return generator.complete(text, params) -def get_identifier(generator: Generator, params: GenerateParams | None = None) -> str: +def get_identifier( + generator: Generator, params: GenerateParams | None = None, *, short: bool = False +) -> str: """ Converts the generator instance back into a rigging identifier string. @@ -671,7 +693,10 @@ def get_identifier(generator: Generator, params: GenerateParams | None = None) - for name, klass in g_generators.items() if isinstance(klass, type) and isinstance(generator, klass) ) - identifier = f"{provider}!{generator.model}" + identifier = f"{provider}!{generator.model}" if provider != "litellm" else generator.model + + if short: + return identifier identifier_extra = generator.model_dump( exclude_unset=True, diff --git a/rigging/generator/http.py b/rigging/generator/http.py index 70db36c..d2eac0d 100644 --- a/rigging/generator/http.py +++ b/rigging/generator/http.py @@ -328,7 +328,7 @@ async def _generate_message( all_content="\n".join(m.content for m in messages), messages=[m.to_openai() for m in messages], params=params.to_dict(), - api_key=self.api_key, + api_key=self.api_key or "", model=self.model, ) diff --git a/rigging/generator/litellm_.py b/rigging/generator/litellm_.py index e5f5612..a82d809 100644 --- a/rigging/generator/litellm_.py +++ b/rigging/generator/litellm_.py @@ -4,6 +4,7 @@ import re import typing as t +import dreadnode as dn import litellm import litellm.types.utils from loguru import logger @@ -20,7 +21,6 @@ ) from rigging.message import ContentAudioInput, ContentImageUrl, ContentText, Message from rigging.tools.base import FunctionDefinition, ToolDefinition -from rigging.tracing import tracer # We should probably let people configure # this 
independently, but for now we'll @@ -172,7 +172,7 @@ async def supports_function_calling(self) -> bool | None: # Otherwise we'll run a small check to see if we can - with tracer.span(f"Checking '{self.model}' for function calling support") as span: + with dn.span(f"Checking '{self.model}' for function calling support") as span: try: generated = await self.generate_messages( [[Message(role="user", content="Call the test function")]], diff --git a/rigging/generator/transformers_.py b/rigging/generator/transformers_.py index 9eb08a8..40276f8 100644 --- a/rigging/generator/transformers_.py +++ b/rigging/generator/transformers_.py @@ -80,14 +80,14 @@ def llm(self) -> AutoModelForCausalLM: "load_in_4bit", }, ) - self._llm = AutoModelForCausalLM.from_pretrained(self.model, **llm_kwargs) # type: ignore [no-untyped-call] + self._llm = AutoModelForCausalLM.from_pretrained(self.model, **llm_kwargs) # type: ignore [no-untyped-call] # nosec return self._llm @property def tokenizer(self) -> AutoTokenizer: """The underlying `AutoTokenizer` instance.""" if self._tokenizer is None: - self._tokenizer = AutoTokenizer.from_pretrained(self.model) + self._tokenizer = AutoTokenizer.from_pretrained(self.model) # nosec return self._tokenizer @property @@ -122,7 +122,7 @@ def from_obj( Returns: The TransformersGenerator instance. """ - instance = cls(model=model, params=params or GenerateParams()) + instance = cls(model=model, params=params or GenerateParams(), api_key=None) instance._llm = model instance._tokenizer = tokenizer instance._pipeline = pipeline diff --git a/rigging/generator/vllm_.py b/rigging/generator/vllm_.py index a2a0e0e..12674b5 100644 --- a/rigging/generator/vllm_.py +++ b/rigging/generator/vllm_.py @@ -90,7 +90,7 @@ def from_obj( Returns: The VLLMGenerator instance. 
""" - generator = cls(model=model, params=params or GenerateParams()) + generator = cls(model=model, params=params or GenerateParams(), api_key=None) generator._llm = llm return generator @@ -142,7 +142,7 @@ def _generate( return [ GeneratedText( text=o.outputs[-1].text, - stop_reason=o.outputs[-1].finish_reason, + stop_reason=o.outputs[-1].finish_reason or "unknown", extra={ "request_id": o.request_id, "metrics": o.metrics, diff --git a/rigging/message.py b/rigging/message.py index 4020658..6741817 100644 --- a/rigging/message.py +++ b/rigging/message.py @@ -121,14 +121,16 @@ def clone(self) -> "MessageSlice": Returns: A new MessageSlice instance with the same properties. """ - return MessageSlice( + cloned = MessageSlice( type=self.type, obj=self.obj, start=self.start, stop=self.stop, metadata=copy.deepcopy(self.metadata), - _message=self._message, # Keep the reference to the original message ) + # Leaving this detached to align with tests + # cloned._message = self._message + return cloned # noqa: RET504 class ContentText(BaseModel): @@ -405,7 +407,9 @@ class Message(BaseModel): `content_parts` to `content` for compatibility. 
""" - model_config = ConfigDict(serialize_by_alias=True) + model_config = ConfigDict( + serialize_by_alias=True, json_schema_extra={"rigging.type": "message"} + ) uuid: UUID = Field(default_factory=uuid4, repr=False) """The unique identifier for the message.""" diff --git a/rigging/prompt.py b/rigging/prompt.py index 70bfb4c..7491c0a 100644 --- a/rigging/prompt.py +++ b/rigging/prompt.py @@ -9,6 +9,7 @@ import typing as t from collections import OrderedDict +import dreadnode as dn from jinja2 import Environment, StrictUndefined, meta from pydantic import ValidationError from typing_extensions import Concatenate, ParamSpec # noqa: UP035 @@ -25,7 +26,6 @@ from rigging.message import Message from rigging.model import Model, SystemErrorModel, ValidationErrorModel, make_primitive from rigging.tools import Tool -from rigging.tracing import tracer from rigging.util import escape_xml, get_qualified_name, to_snake, to_xml_tag DEFAULT_DOC = "Convert the following inputs to outputs ({func_name})." @@ -561,7 +561,12 @@ async def _then_parse(self, chat: Chat) -> PipelineStepContextManager | None: try: # A bit weird, but we need from_chat to properly handle # wrapping Chat output types inside lists/dataclasses - self.output.from_chat(chat) + with dn.task_span( + f"prompt parse - {self.output.tag}", + attributes={"rigging.type": "prompt.parse"}, + ): + dn.log_input("message", chat.last) + self.output.from_chat(chat) except ValidationError as e: next_pipeline.add( Message.from_model( @@ -836,10 +841,36 @@ def say_hello(name: str) -> str: raise NotImplementedError( "pipeline.on_failed='skip' cannot be used for prompt methods that return one object", ) + if pipeline.on_failed == "include" and not isinstance(self.output, ChatOutput): + raise NotImplementedError( + "pipeline.on_failed='include' cannot be used with prompts that process outputs", + ) async def run(*args: P.args, **kwargs: P.kwargs) -> R: - results = await self.bind_many(pipeline)(1, *args, **kwargs) - return results[0] 
+ name = get_qualified_name(self.func) if self.func else "" + with dn.task_span( + f"prompt - {name}", + attributes={"prompt_name": name, "rigging.type": "prompt.run"}, + ): + dn.log_inputs(**self._bind_args(*args, **kwargs)) + content = self.render(*args, **kwargs) + _pipeline = ( + pipeline.fork(content) + .using(*self.tools, max_depth=self.max_tool_rounds) + .then(self._then_parse, max_depth=self.max_parsing_rounds, as_task=False) + .then(*self.then_callbacks) + .map(*self.map_callbacks) + .watch(*self.watch_callbacks) + .with_(self.params) + ) + + if self.system_prompt: + _pipeline.chat.inject_system_content(self.system_prompt) + + chat = await _pipeline.run() + output = self.process(chat) + dn.log_output("output", output) + return output run.__signature__ = self.__signature__ # type: ignore [attr-defined] run.__name__ = self.__name__ @@ -867,7 +898,7 @@ def bind_many( def say_hello(name: str) -> str: \"""Say hello to {{ name }}\""" - await say_hello.bind("gpt-3.5-turbo")(5, "the world") + await say_hello.bind_many("gpt-4.1")(5, "the world") ``` """ pipeline = self._resolve_to_pipeline(other) @@ -878,17 +909,17 @@ def say_hello(name: str) -> str: async def run_many(count: int, /, *args: P.args, **kwargs: P.kwargs) -> list[R]: name = get_qualified_name(self.func) if self.func else "" - with tracer.span( - f"Prompt {name}()" if count == 1 else f"Prompt {name}() (x{count})", - count=count, - name=name, - arguments=self._bind_args(*args, **kwargs), + with dn.task_span( + f"prompt - {name} (x{count})", + label=f"prompt_{name}", + attributes={"prompt_name": name, "rigging.type": "prompt.run_many"}, ) as span: + dn.log_inputs(**self._bind_args(*args, **kwargs)) content = self.render(*args, **kwargs) _pipeline = ( pipeline.fork(content) .using(*self.tools, max_depth=self.max_tool_rounds) - .then(self._then_parse, max_depth=self.max_parsing_rounds) + .then(self._then_parse, max_depth=self.max_parsing_rounds, as_task=False) .then(*self.then_callbacks) 
.map(*self.map_callbacks) .watch(*self.watch_callbacks) @@ -899,35 +930,13 @@ async def run_many(count: int, /, *args: P.args, **kwargs: P.kwargs) -> list[R]: _pipeline.chat.inject_system_content(self.system_prompt) chats = await _pipeline.run_many(count) - - # TODO: I can't remember why we don't just pass the watch_callbacks to the pipeline - # Maybe it has something to do with uniqueness and merging? - - def wrap_watch_callback(callback: "WatchChatCallback") -> "WatchChatCallback": - async def traced_watch_callback(chats: list[Chat]) -> None: - callback_name = get_qualified_name(callback) - with tracer.span( - f"Watch with {callback_name}()", - callback=callback_name, - chat_count=len(chats), - chat_ids=[str(c.uuid) for c in chats], - ): - await callback(chats) - - return traced_watch_callback - - coros = [ - wrap_watch_callback(watch)(chats) - for watch in self.watch_callbacks - if watch not in pipeline.watch_callbacks - ] - await asyncio.gather(*coros) - - results = [self.process(chat) for chat in chats] - span.set_attribute("results", results) - return results + outputs = [self.process(chat) for chat in chats] + span.log_output("outputs", outputs) + return outputs run_many.__rg_prompt__ = self # type: ignore [attr-defined] + run_many.__name__ = self.__name__ + run_many.__doc__ = self.__doc__ return run_many @@ -953,7 +962,7 @@ def bind_over( def say_hello(name: str) -> str: \"""Say hello to {{ name }}\""" - await say_hello.bind("gpt-3.5-turbo")(["gpt-4o", "gpt-4"], "the world") + await say_hello.bind_over()(["gpt-4o", "gpt-4.1", "o4-mini"], "the world") ``` """ include_original = other is not None @@ -980,7 +989,7 @@ async def run_over( _pipeline = ( pipeline.fork(content) .using(*self.tools, max_depth=self.max_tool_rounds) - .then(self._then_parse, max_depth=self.max_parsing_rounds) + .then(self._then_parse, max_depth=self.max_parsing_rounds, as_task=False) .then(*self.then_callbacks) .map(*self.map_callbacks) .watch(*self.watch_callbacks) @@ -992,13 +1001,6 
@@ async def run_over( chats = await _pipeline.run_over(*generators, include_original=include_original) - coros = [ - watch(chats) - for watch in self.watch_callbacks - if watch not in pipeline.watch_callbacks - ] - await asyncio.gather(*coros) - return [self.process(chat) for chat in chats] run_over.__rg_prompt__ = self # type: ignore [attr-defined] diff --git a/rigging/tokenizer/__init__.py b/rigging/tokenizer/__init__.py index 4180205..a8c2b02 100644 --- a/rigging/tokenizer/__init__.py +++ b/rigging/tokenizer/__init__.py @@ -2,6 +2,8 @@ Tokenizers encode chats and associated message data into tokens for training and inference. """ +import typing as t + from rigging.tokenizer.base import ( TokenizedChat, Tokenizer, @@ -31,3 +33,9 @@ def get_transformers_lazy() -> type[Tokenizer]: "get_tokenizer", "register_tokenizer", ] + + +def __getattr__(name: str) -> t.Any: + if name == "TransformersTokenizer": + return get_transformers_lazy() + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/rigging/tokenizer/__init__.pyi b/rigging/tokenizer/__init__.pyi new file mode 100644 index 0000000..1c74d9a --- /dev/null +++ b/rigging/tokenizer/__init__.pyi @@ -0,0 +1,17 @@ +from rigging.tokenizer.base import ( + TokenizedChat, + Tokenizer, + TokenSlice, + get_tokenizer, + register_tokenizer, +) +from rigging.tokenizer.transformers_ import TransformersTokenizer + +__all__ = [ + "TokenSlice", + "TokenizedChat", + "Tokenizer", + "TransformersTokenizer", + "get_tokenizer", + "register_tokenizer", +] diff --git a/rigging/tokenizer/transformers_.py b/rigging/tokenizer/transformers_.py index ec24aef..8f2616c 100644 --- a/rigging/tokenizer/transformers_.py +++ b/rigging/tokenizer/transformers_.py @@ -34,7 +34,7 @@ class TransformersTokenizer(Tokenizer): def tokenizer(self) -> "PreTrainedTokenizer": """The underlying `PreTrainedTokenizer` instance.""" if self._tokenizer is None: - self._tokenizer = AutoTokenizer.from_pretrained(self.model) + self._tokenizer = 
AutoTokenizer.from_pretrained(self.model) # nosec return self._tokenizer @classmethod diff --git a/rigging/tools/base.py b/rigging/tools/base.py index d6ff410..f8d8a6f 100644 --- a/rigging/tools/base.py +++ b/rigging/tools/base.py @@ -8,11 +8,19 @@ import re import typing as t import warnings -from dataclasses import dataclass, field from functools import cached_property +import dreadnode as dn import typing_extensions as te -from pydantic import BaseModel, TypeAdapter, ValidationError, field_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PrivateAttr, + TypeAdapter, + ValidationError, + field_validator, +) from pydantic_xml import attr from rigging.error import Stop, ToolDefinitionError, ToolWarning @@ -23,7 +31,6 @@ make_from_schema, make_from_signature, ) -from rigging.tracing import tracer from rigging.util import deref_json if t.TYPE_CHECKING: @@ -103,6 +110,8 @@ class FunctionCall(BaseModel): class ToolCall(BaseModel): + model_config = ConfigDict(json_schema_extra={"rigging.type": "tool_call"}) + id: str type: t.Literal["function"] = "function" function: FunctionCall @@ -133,8 +142,7 @@ def _is_unbound_method(func: t.Any) -> bool: return is_method is not hasattr(func, "__self__") -@dataclass -class Tool(t.Generic[P, R]): +class Tool(BaseModel, t.Generic[P, R]): """Base class for representing a tool to a generator.""" name: str @@ -143,7 +151,10 @@ class Tool(t.Generic[P, R]): """A description of the tool.""" parameters_schema: dict[str, t.Any] """The JSON schema for the tool's parameters.""" - fn: t.Callable[P, R] + fn: t.Callable[P, R] = Field( # type: ignore [assignment] + default_factory=lambda: lambda *args, **kwargs: None, # noqa: ARG005 + exclude=True, + ) """The function to call.""" catch: bool | set[type[Exception]] = False """ @@ -156,13 +167,9 @@ class Tool(t.Generic[P, R]): truncate: int | None = None """If set, the maximum number of characters to truncate any tool output to.""" - _signature: inspect.Signature | None = 
field(default=None, init=False, repr=False) - _type_adapter: TypeAdapter[t.Any] | None = field( - default=None, - init=False, - repr=False, - ) - _model: type[Model] | None = field(default=None, init=False, repr=False) + _signature: inspect.Signature | None = PrivateAttr(default=None, init=False) + _type_adapter: TypeAdapter[t.Any] | None = PrivateAttr(default=None, init=False) + _model: type[Model] | None = PrivateAttr(default=None, init=False) # In general we are split between 2 strategies for handling the data translations: # @@ -281,7 +288,7 @@ def empty_func(*args, **kwargs): # type: ignore [no-untyped-def] # noqa: ARG001 ) self._signature = signature - self.__signature__ = signature # type: ignore [attr-defined] + self.__signature__ = signature # type: ignore [misc] self.__name__ = self.name # type: ignore [attr-defined] self.__doc__ = self.description @@ -351,7 +358,12 @@ async def handle_tool_call( # noqa: PLR0912 from rigging.message import ContentText, ContentTypes, Message - with tracer.span(f"Tool {self.name}()", name=self.name) as span: + with dn.task_span( + f"tool - {self.name}", + attributes={"tool_name": self.name, "rigging.type": "tool"}, + ) as task: + dn.log_input("tool_call", tool_call) + if tool_call.name != self.name: warnings.warn( f"Tool call name mismatch: {tool_call.name} != {self.name}", @@ -361,7 +373,7 @@ async def handle_tool_call( # noqa: PLR0912 return Message.from_model(SystemErrorModel(content="Invalid tool call.")), True if hasattr(tool_call, "id") and isinstance(tool_call.id, str): - span.set_attribute("tool_call_id", tool_call.id) + task.set_attribute("tool_call_id", tool_call.id) result: t.Any stop = False @@ -372,8 +384,9 @@ async def handle_tool_call( # noqa: PLR0912 kwargs = json.loads(tool_call.function.arguments) if self._type_adapter is not None: kwargs = self._type_adapter.validate_python(kwargs) - span.set_attribute("arguments", kwargs) + dn.log_inputs(**kwargs) except (json.JSONDecodeError, ValidationError) as e: + 
task.set_exception(e) result = ErrorModel.from_exception(e) # Call the function @@ -388,17 +401,18 @@ async def handle_tool_call( # noqa: PLR0912 raise result # noqa: TRY301 except Stop as e: result = f"<{TOOL_STOP_TAG}>{e.message}" - span.set_attribute("stop", True) + task.set_attribute("stop", True) stop = True except Exception as e: if self.catch is True or ( not isinstance(self.catch, bool) and isinstance(e, tuple(self.catch)) ): + task.set_exception(e) result = ErrorModel.from_exception(e) else: raise - span.set_attribute("result", result) + dn.log_output("output", result) message = Message(role="tool", tool_call_id=tool_call.id) @@ -521,7 +535,7 @@ def __get__(self, instance: t.Any, owner: t.Any) -> "Tool[P, R]": catch=self.catch, ) - bound_tool.__signature__ = self.__signature__ # type: ignore [attr-defined] + bound_tool.__signature__ = self.__signature__ # type: ignore [misc] bound_tool._signature = self._signature # noqa: SLF001 bound_tool._type_adapter = self._type_adapter # noqa: SLF001 bound_tool._model = self._model # noqa: SLF001 diff --git a/rigging/tools/robopages.py b/rigging/tools/robopages.py index a84b08f..e3fe335 100644 --- a/rigging/tools/robopages.py +++ b/rigging/tools/robopages.py @@ -92,10 +92,10 @@ def robopages(url: str, *, name_filter: str | None = None) -> list[Tool[..., t.A tools.append( Tool( - function.name, - function.description or "", - function.parameters or {}, - make_execute_on_server(url, function.name), + name=function.name, + description=function.description or "", + parameters_schema=function.parameters or {}, + fn=make_execute_on_server(url, function.name), ), ) diff --git a/rigging/tracing.py b/rigging/tracing.py deleted file mode 100644 index ff9e103..0000000 --- a/rigging/tracing.py +++ /dev/null @@ -1,32 +0,0 @@ -import typing as t - -import logfire_api - -Span = logfire_api.LogfireSpan - - -class Tracer(logfire_api.Logfire): - def span( - self, - msg_template: str, - /, - *, - _tags: t.Sequence[str] | None = None, - 
_span_name: str | None = None, - _level: t.Any | None = None, - _links: t.Any = (), - **attributes: t.Any, - ) -> logfire_api.LogfireSpan: - # Pass msg_template as the span name - # to avoid weird fstring behaviors - return super().span( - msg_template, - _tags=_tags, - _span_name=msg_template, - _level=_level, - _links=_links, - **attributes, - ) - - -tracer = Tracer(otel_scope="rigging") diff --git a/rigging/transform/xml_tools.py b/rigging/transform/xml_tools.py index da4b239..1baf998 100644 --- a/rigging/transform/xml_tools.py +++ b/rigging/transform/xml_tools.py @@ -4,7 +4,7 @@ import uuid import warnings -import xmltodict # type: ignore[import-untyped] +import xmltodict # type: ignore [import-untyped] from pydantic.fields import FieldInfo from pydantic_xml import attr diff --git a/tests/test_chat.py b/tests/test_chat.py index 66e0152..e625aa2 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -171,7 +171,7 @@ def test_message_double_content_part_separation() -> None: def test_chat_generator_id() -> None: generator = get_generator("gpt-3.5") chat = Chat([], generator=generator) - assert chat.generator_id == "litellm!gpt-3.5" + assert chat.generator_id == "gpt-3.5" other = Chat([]) assert other.generator_id is None diff --git a/tests/test_chat_pipeline.py b/tests/test_chat_pipeline.py index 96f6d69..ae16f04 100644 --- a/tests/test_chat_pipeline.py +++ b/tests/test_chat_pipeline.py @@ -188,13 +188,36 @@ async def double_chats(chats: list[Chat]) -> list[Chat]: return chats + new_chats chats = ( - await generator.chat([{"role": "user", "content": "Hello"}]).map(double_chats).run_many(1) + await generator.chat([{"role": "user", "content": "Hello"}]) + .map(double_chats) + .run_many(1, mode="merged") ) assert len(chats) == 2 assert chats[0].last.content == "Response 1" assert chats[1].last.content == "Modified: Response 1" + # in parallel mode, we expect only one chat per internal pipeline + + chats = ( + await generator.chat([{"role": "user", "content": 
"Hello"}]) + .map(double_chats) + .run_many(1, mode="parallel") + ) + + assert len(chats) == 1 + assert chats[0].last.content == "Modified: Response 1" + + chats = ( + await generator.chat([{"role": "user", "content": "Hello"}]) + .map(double_chats) + .run_many(2, mode="parallel") + ) + + assert len(chats) == 2 + assert chats[0].last.content == "Modified: Response 1" + assert chats[1].last.content == "Modified: Response 1" + @pytest.mark.asyncio async def test_watch_callback() -> None: diff --git a/tests/test_completion_pipeline.py b/tests/test_completion_pipeline.py index 1530bd6..8387130 100644 --- a/tests/test_completion_pipeline.py +++ b/tests/test_completion_pipeline.py @@ -9,7 +9,7 @@ def test_completion_generator_id() -> None: generator = get_generator("gpt-3.5") completion = Completion("foo", "bar", generator) - assert completion.generator_id == "litellm!gpt-3.5" + assert completion.generator_id == "gpt-3.5" completion.generator = None assert completion.generator_id is None diff --git a/tests/test_generator_ids.py b/tests/test_generator_ids.py index 134e407..a53560e 100644 --- a/tests/test_generator_ids.py +++ b/tests/test_generator_ids.py @@ -48,10 +48,10 @@ def test_get_generator_with_params(identifier: str, valid_params: GenerateParams @pytest.mark.parametrize( "identifier", [ - ("litellm!test_model,max_tokens=1024,top_p=0.1"), - ("litellm!custom,temperature=1.0,max_tokens=100,api_base=https://localhost:8000"), - ("litellm!many/model/slashes,stop=a;b;c;"), - ("litellm!with_cls_args,max_connections=10"), + ("test_model,max_tokens=1024,top_p=0.1"), + ("custom,temperature=1.0,max_tokens=100,api_base=https://localhost:8000"), + ("many/model/slashes,stop=a;b;c;"), + ("http!with_cls_args"), ], ) def test_identifier_roundtrip(identifier: str) -> None: diff --git a/tests/test_tool.py b/tests/test_tool.py index f462c26..dfcdb50 100644 --- a/tests/test_tool.py +++ b/tests/test_tool.py @@ -24,8 +24,8 @@ def simple_function(name: str, age: int) -> str: assert 
tool.name == "simple_function" assert tool.description == "A simple function that returns a greeting." - assert "_signature" in tool.__dict__ - assert "_type_adapter" in tool.__dict__ + assert getattr(tool, "_signature", None) is not None + assert getattr(tool, "_type_adapter", None) is not None # Check schema assert "name" in tool.parameters_schema["properties"]