diff --git a/.secrets.baseline b/.secrets.baseline index 413c772..cac7826 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -124,6 +124,12 @@ }, { "path": "detect_secrets.filters.heuristic.is_templated_secret" + }, + { + "path": "detect_secrets.filters.regex.should_exclude_file", + "pattern": [ + "examples/*" + ] } ], "results": { @@ -133,14 +139,14 @@ "filename": "docs/topics/generators.mdx", "hashed_secret": "ef5225a03e4f9cc953ab3c4dd41f5c4db7dc2e5b", "is_verified": false, - "line_number": 342 + "line_number": 360 }, { "type": "Secret Keyword", "filename": "docs/topics/generators.mdx", "hashed_secret": "eb6256c862c356b375aafa760fa1851e33aa62a9", "is_verified": false, - "line_number": 366 + "line_number": 384 } ], "tests/test_http_spec.py": [ @@ -153,5 +159,5 @@ } ] }, - "generated_at": "2025-07-23T00:26:51Z" + "generated_at": "2025-09-03T21:56:13Z" } diff --git a/docs/api/error.mdx b/docs/api/error.mdx index 7ee8cec..fb26640 100644 --- a/docs/api/error.mdx +++ b/docs/api/error.mdx @@ -68,6 +68,14 @@ max_rounds = max_rounds The number of rounds which was exceeded. +GeneratorWarning +---------------- + +Base class for all generator warnings. + +This is used to indicate that something unexpected happened during the generator execution, +but it is not critical enough to stop the execution. + InvalidGeneratorError --------------------- diff --git a/docs/api/generator.mdx b/docs/api/generator.mdx index 866ed02..055820b 100644 --- a/docs/api/generator.mdx +++ b/docs/api/generator.mdx @@ -1934,7 +1934,9 @@ get\_generator ```python get_generator( - identifier: str, *, params: GenerateParams | None = None + identifier: str, + *, + params: GenerateParams | dict[str, Any] | None = None, ) -> Generator ``` @@ -1964,7 +1966,7 @@ You can also specify arguments to the generator by comma-separating them: (`str`) –The identifier string to use to get a generator. 
* **`params`** - (`GenerateParams | None`, default: + (`GenerateParams | dict[str, Any] | None`, default: `None` ) –The generation parameters to use for the generator. @@ -1982,8 +1984,9 @@ You can also specify arguments to the generator by comma-separating them: ```python -@lru_cache(maxsize=128) -def get_generator(identifier: str, *, params: GenerateParams | None = None) -> Generator: +def get_generator( + identifier: str, *, params: GenerateParams | dict[str, t.Any] | None = None +) -> Generator: """ Get a generator by an identifier string. Uses LiteLLM by default. @@ -2080,6 +2083,8 @@ def get_generator(identifier: str, *, params: GenerateParams | None = None) -> G if isinstance(v, str) and v.lower() in ["true", "false"]: init_kwargs[k] = v.lower() == "true" + params = GenerateParams(**params) if isinstance(params, dict) else params + try: merged_params = GenerateParams(**kwargs).merge_with(params) except Exception as e: diff --git a/docs/api/model.mdx b/docs/api/model.mdx index b110948..0407ae5 100644 --- a/docs/api/model.mdx +++ b/docs/api/model.mdx @@ -237,7 +237,7 @@ def from_text( # Walk through any fields which are strings, and dedent them for field_name, field_info in cls.model_fields.items(): - if isinstance(field_info, XmlEntityInfo) and field_info.annotation == str: # noqa: E721 + if isinstance(field_info, XmlEntityInfo) and field_info.annotation is str: model.__dict__[field_name] = textwrap.dedent( model.__dict__[field_name] ).strip() @@ -245,7 +245,7 @@ def from_text( extracted.append((model, slice_)) except Exception as e: # noqa: BLE001 extracted.append((e, slice_)) - continue + continue # sort back to original order extracted.sort(key=lambda x: x[1].start) @@ -471,11 +471,7 @@ def preprocess_with_cdata(cls, content: str) -> str: } else: field_map = { - ( - field_info.path - if isinstance(field_info, XmlEntityInfo) and field_info.path - else field_name - ): field_info + cls._get_field_xml_name(field_name, field_info): field_info for field_name, 
field_info in cls.model_fields.items() if isinstance(field_info, XmlEntityInfo) and field_info.location == EntityLocation.ELEMENT @@ -722,7 +718,7 @@ def xml_example(cls) -> str: isinstance(field_info, XmlEntityInfo) and field_info.location == EntityLocation.ATTRIBUTE ): - path = field_info.path or field_name + path = cls._get_field_xml_name(field_name, field_info) example = str(next(iter(field_info.examples or []), "")).replace('"', "&quot;") attribute_parts.append(f'{path}="{example}"') else: @@ -732,7 +728,7 @@ def xml_example(cls) -> str: lines.append(f"<{cls.__xml_tag__}{attr_string}>") for field_name, field_info in element_fields.items(): - path = (isinstance(field_info, XmlEntityInfo) and field_info.path) or field_name + path = cls._get_field_xml_name(field_name, field_info) description = field_info.description example = str(next(iter(field_info.examples or []), "")) @@ -912,12 +908,16 @@ def make_from_schema( for field_name, field_schema in properties.items(): field_type, field_info = _process_field(field_name, field_schema) + # Use the field name as alias if it differs from python naming conventions + alias = field_name if field_name != field_name.replace("-", "_") else None + fields[field_name] = ( field_type, field_cls( default=... 
if field_name in required else None, description=field_schema.get("description", ""), title=field_schema.get("title", ""), + alias=alias, **field_info, ) if isinstance(field_info, dict) diff --git a/docs/topics/generators.mdx b/docs/topics/generators.mdx index b8c394b..f1b28ec 100644 --- a/docs/topics/generators.mdx +++ b/docs/topics/generators.mdx @@ -22,7 +22,7 @@ openai/o3-mini gemini/gemini-2.5-pro claude-4-sonnet-latest vllm_hosted/meta-llama/Llama-3.1-8B-Instruct -ollama/qwen3 +ollama_chat/qwen3 openai/gpt-4,api_key=sk-1234 anthropic/claude-3-7-haiku-latest,stop=output:;---,seed=1337 @@ -184,12 +184,30 @@ NAME ID SIZE PROCESSOR UNTIL qwen3:0.6b 7df6b6e09427 2.3 GB 100% GPU 4 minutes from now ``` -Using this model in Rigging is as simple as using the `ollama/` or `ollama_chat/` prefixes: + +Ollama is configured with a maximum context length on the server, by default 4096 tokens. This limit does not depend on the model, and changing it requires updating the server configuration. + +If the input messages to the API would exceed this length, Ollama will silently truncate them to fit in the context window. This behavior can cause unexpected generation results due to missing context and is very difficult to detect in Rigging. + +We make a best effort by monitoring model responses and checking whether the reported input token count is far lower than expected for the input messages we just sent. If this is observed, the following warning is emitted. + +``` +GeneratorWarning: Input messages may have been truncated ... +``` + +When in doubt, monitor the Ollama server logs for the following: + +```bash +... 
msg="truncating input prompt" limit=4096 prompt=6767 keep=4 new=409 +``` + + +Using this model in Rigging is as simple as using the `ollama_chat/` (recommended) or `ollama/` prefixes: ```python import rigging as rg -qwen = rg.get_generator("ollama/qwen3:0.6b") +qwen = rg.get_generator("ollama_chat/qwen3:0.6b") chat = await qwen.chat("Hello!").run() print(chat.conversation) @@ -211,7 +229,7 @@ If you are running the Ollama server somewhere besides localhost, just pass the ```python qwen = rg.get_generator( - "ollama/qwen3:0.6b,api_base=http://remote-server:11434" + "ollama_chat/qwen3:0.6b,api_base=http://remote-server:11434" ) ``` diff --git a/pyproject.toml b/pyproject.toml index c69fd44..acddc54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "rigging" -version = "3.3.2" +version = "3.3.3" description = "LLM Interaction Framework" authors = ["Nick Landers "] license = "MIT" diff --git a/rigging/error.py b/rigging/error.py index 321ec09..9d25539 100644 --- a/rigging/error.py +++ b/rigging/error.py @@ -82,6 +82,15 @@ class TokenizerWarning(Warning): """ +class GeneratorWarning(Warning): + """ + Base class for all generator warnings. + + This is used to indicate that something unexpected happened during the generator execution, + but it is not critical enough to stop the execution. 
+ """ + + # System Exceptions diff --git a/rigging/generator/base.py b/rigging/generator/base.py index 4233b3f..64f3884 100644 --- a/rigging/generator/base.py +++ b/rigging/generator/base.py @@ -4,7 +4,6 @@ import functools import inspect import typing as t -from functools import lru_cache from loguru import logger from pydantic import ( @@ -738,8 +737,9 @@ def encode_value(val: t.Any) -> t.Any: return identifier -@lru_cache(maxsize=128) -def get_generator(identifier: str, *, params: GenerateParams | None = None) -> Generator: +def get_generator( + identifier: str, *, params: GenerateParams | dict[str, t.Any] | None = None +) -> Generator: """ Get a generator by an identifier string. Uses LiteLLM by default. @@ -836,6 +836,8 @@ def decode_value(value: str) -> t.Any: if isinstance(v, str) and v.lower() in ["true", "false"]: init_kwargs[k] = v.lower() == "true" + params = GenerateParams(**params) if isinstance(params, dict) else params + try: merged_params = GenerateParams(**kwargs).merge_with(params) except Exception as e: diff --git a/rigging/generator/litellm_.py b/rigging/generator/litellm_.py index f59a829..47c1805 100644 --- a/rigging/generator/litellm_.py +++ b/rigging/generator/litellm_.py @@ -3,9 +3,11 @@ import datetime import re import typing as t +import warnings from loguru import logger +from rigging.error import GeneratorWarning from rigging.generator.base import ( Fixup, GeneratedMessage, @@ -206,7 +208,7 @@ async def supports_function_calling(self) -> bool | None: ): self._supports_function_calling = True except Exception as e: # noqa: BLE001 - logger.warning(f"Failed to check for function calling support: {e}") + logger.warning(f"Error checking for function calling support: {e}") span.set_attribute("error", str(e)) span.set_attribute("supports_function_calling", self._supports_function_calling) @@ -236,6 +238,37 @@ async def _ensure_delay_between_requests(self) -> None: # This seems like a brittle feature at the moment, so we'll # leave it out for 
now. + def _warn_on_input_truncation( + self, messages: list[Message], response: "GeneratedMessage" + ) -> None: + # Ollama has a known behavior where it performs silent truncation + # of input messages rather than return an error or any API indication. + # + # This code attempts to detect such truncation by comparing the expected + # input length with the reported usage - but it's not foolproof. + # + # See: + # - https://github.com/ollama/ollama/issues/7043 + # - https://github.com/ollama/ollama/issues/7987 + # - https://github.com/ollama/ollama/issues/4967 + + # We can't check without usage info + if not response.usage: + return + + # Get a general view of how long we might expect the input prompt to be. + # We'll use a generous 4 chars per token estimate + input_tokens_estimate = int(sum(len(message.content) for message in messages) / 4) + + # Check if the response reports that accepted input tokens are less than this + if response.usage.input_tokens < input_tokens_estimate: + warnings.warn( + f"Input messages may have been truncated - see https://github.com/ollama/ollama/issues/7043 " + f"(input tokens: {response.usage.input_tokens} < estimate: {input_tokens_estimate})", + GeneratorWarning, + stacklevel=2, + ) + def _parse_model_response( self, response: "ModelResponse", @@ -359,7 +392,9 @@ async def _generate_message( ) self._last_request_time = datetime.datetime.now(tz=datetime.timezone.utc) - return self._parse_model_response(response) + generated = self._parse_model_response(response) + self._warn_on_input_truncation(list(messages), generated) + return generated async def _generate_text(self, text: str, params: GenerateParams) -> GeneratedText: import litellm diff --git a/rigging/model.py b/rigging/model.py index bb145cf..aaa6570 100644 --- a/rigging/model.py +++ b/rigging/model.py @@ -99,16 +99,24 @@ def __init_subclass__( ) cls.__xml_tag__ = tag or XmlTagDescriptor() # type: ignore [assignment] + @classmethod + def _get_field_xml_name(cls, field_name: str, 
field_info: t.Any) -> str: + """Get the effective XML name for a field, considering aliases and XML paths.""" + if isinstance(field_info, XmlEntityInfo) and field_info.path: + return field_info.path + + # Check for field alias + if hasattr(field_info, "alias") and field_info.alias: + return field_info.alias + + return field_name + def _postprocess_with_cdata(self, tree: ET.Element) -> ET.Element: # Walk the first elements down and find any that align with str-based fields # If so, we should encode them as CDATA to avoid escaping issues basic_fields = { - ( - field_info.path - if isinstance(field_info, XmlEntityInfo) and field_info.path - else field_name - ): field_info + self._get_field_xml_name(field_name, field_info): field_info for field_name, field_info in self.__class__.model_fields.items() } @@ -357,7 +365,7 @@ def xml_example(cls) -> str: isinstance(field_info, XmlEntityInfo) and field_info.location == EntityLocation.ATTRIBUTE ): - path = field_info.path or field_name + path = cls._get_field_xml_name(field_name, field_info) example = str(next(iter(field_info.examples or []), "")).replace('"', "&quot;") attribute_parts.append(f'{path}="{example}"') else: @@ -367,7 +375,7 @@ def xml_example(cls) -> str: lines.append(f"<{cls.__xml_tag__}{attr_string}>") for field_name, field_info in element_fields.items(): - path = (isinstance(field_info, XmlEntityInfo) and field_info.path) or field_name + path = cls._get_field_xml_name(field_name, field_info) description = field_info.description example = str(next(iter(field_info.examples or []), "")) @@ -421,11 +429,7 @@ def preprocess_with_cdata(cls, content: str) -> str: } else: field_map = { - ( - field_info.path - if isinstance(field_info, XmlEntityInfo) and field_info.path - else field_name - ): field_info + cls._get_field_xml_name(field_name, field_info): field_info for field_name, field_info in cls.model_fields.items() if isinstance(field_info, XmlEntityInfo) and field_info.location == EntityLocation.ELEMENT @@ -600,7 
+604,7 @@ def from_text( # Walk through any fields which are strings, and dedent them for field_name, field_info in cls.model_fields.items(): - if isinstance(field_info, XmlEntityInfo) and field_info.annotation == str: # noqa: E721 + if isinstance(field_info, XmlEntityInfo) and field_info.annotation is str: model.__dict__[field_name] = textwrap.dedent( model.__dict__[field_name] ).strip() @@ -608,7 +612,7 @@ def from_text( extracted.append((model, slice_)) except Exception as e: # noqa: BLE001 extracted.append((e, slice_)) - continue + continue # sort back to original order extracted.sort(key=lambda x: x[1].start) @@ -817,12 +821,16 @@ def make_from_schema( for field_name, field_schema in properties.items(): field_type, field_info = _process_field(field_name, field_schema) + # Use the field name as alias if it differs from python naming conventions + alias = field_name if field_name != field_name.replace("-", "_") else None + fields[field_name] = ( field_type, field_cls( default=... if field_name in required else None, description=field_schema.get("description", ""), title=field_schema.get("title", ""), + alias=alias, **field_info, ) if isinstance(field_info, dict)