Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions .secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@
},
{
"path": "detect_secrets.filters.heuristic.is_templated_secret"
},
{
"path": "detect_secrets.filters.regex.should_exclude_file",
"pattern": [
"examples/*"
]
}
],
"results": {
Expand All @@ -133,14 +139,14 @@
"filename": "docs/topics/generators.mdx",
"hashed_secret": "ef5225a03e4f9cc953ab3c4dd41f5c4db7dc2e5b",
"is_verified": false,
"line_number": 342
"line_number": 360
},
{
"type": "Secret Keyword",
"filename": "docs/topics/generators.mdx",
"hashed_secret": "eb6256c862c356b375aafa760fa1851e33aa62a9",
"is_verified": false,
"line_number": 366
"line_number": 384
}
],
"tests/test_http_spec.py": [
Expand All @@ -153,5 +159,5 @@
}
]
},
"generated_at": "2025-07-23T00:26:51Z"
"generated_at": "2025-09-03T21:56:13Z"
}
8 changes: 8 additions & 0 deletions docs/api/error.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ max_rounds = max_rounds

The number of rounds which was exceeded.

GeneratorWarning
----------------

Base class for all generator warnings.

This is used to indicate that something unexpected happened during the generator execution,
but it is not critical enough to stop the execution.

InvalidGeneratorError
---------------------

Expand Down
13 changes: 9 additions & 4 deletions docs/api/generator.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -1934,7 +1934,9 @@ get\_generator

```python
get_generator(
identifier: str, *, params: GenerateParams | None = None
identifier: str,
*,
params: GenerateParams | dict[str, Any] | None = None,
) -> Generator
```

Expand Down Expand Up @@ -1964,7 +1966,7 @@ You can also specify arguments to the generator by comma-separating them:
(`str`)
–The identifier string to use to get a generator.
* **`params`**
(`GenerateParams | None`, default:
(`GenerateParams | dict[str, Any] | None`, default:
`None`
)
–The generation parameters to use for the generator.
Expand All @@ -1982,8 +1984,9 @@ You can also specify arguments to the generator by comma-separating them:

<Accordion title="Source code in rigging/generator/base.py" icon="code">
```python
@lru_cache(maxsize=128)
def get_generator(identifier: str, *, params: GenerateParams | None = None) -> Generator:
def get_generator(
identifier: str, *, params: GenerateParams | dict[str, t.Any] | None = None
) -> Generator:
"""
Get a generator by an identifier string. Uses LiteLLM by default.

Expand Down Expand Up @@ -2080,6 +2083,8 @@ def get_generator(identifier: str, *, params: GenerateParams | None = None) -> G
if isinstance(v, str) and v.lower() in ["true", "false"]:
init_kwargs[k] = v.lower() == "true"

params = GenerateParams(**params) if isinstance(params, dict) else params

try:
merged_params = GenerateParams(**kwargs).merge_with(params)
except Exception as e:
Expand Down
18 changes: 9 additions & 9 deletions docs/api/model.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -237,15 +237,15 @@ def from_text(
# Walk through any fields which are strings, and dedent them

for field_name, field_info in cls.model_fields.items():
if isinstance(field_info, XmlEntityInfo) and field_info.annotation == str: # noqa: E721
if isinstance(field_info, XmlEntityInfo) and field_info.annotation is str:
model.__dict__[field_name] = textwrap.dedent(
model.__dict__[field_name]
).strip()

extracted.append((model, slice_))
except Exception as e: # noqa: BLE001
extracted.append((e, slice_))
continue
continue

# sort back to original order
extracted.sort(key=lambda x: x[1].start)
Expand Down Expand Up @@ -471,11 +471,7 @@ def preprocess_with_cdata(cls, content: str) -> str:
}
else:
field_map = {
(
field_info.path
if isinstance(field_info, XmlEntityInfo) and field_info.path
else field_name
): field_info
cls._get_field_xml_name(field_name, field_info): field_info
for field_name, field_info in cls.model_fields.items()
if isinstance(field_info, XmlEntityInfo)
and field_info.location == EntityLocation.ELEMENT
Expand Down Expand Up @@ -722,7 +718,7 @@ def xml_example(cls) -> str:
isinstance(field_info, XmlEntityInfo)
and field_info.location == EntityLocation.ATTRIBUTE
):
path = field_info.path or field_name
path = cls._get_field_xml_name(field_name, field_info)
example = str(next(iter(field_info.examples or []), "")).replace('"', "&quot;")
attribute_parts.append(f'{path}="{example}"')
else:
Expand All @@ -732,7 +728,7 @@ def xml_example(cls) -> str:
lines.append(f"<{cls.__xml_tag__}{attr_string}>")

for field_name, field_info in element_fields.items():
path = (isinstance(field_info, XmlEntityInfo) and field_info.path) or field_name
path = cls._get_field_xml_name(field_name, field_info)
description = field_info.description
example = str(next(iter(field_info.examples or []), ""))

Expand Down Expand Up @@ -912,12 +908,16 @@ def make_from_schema(
for field_name, field_schema in properties.items():
field_type, field_info = _process_field(field_name, field_schema)

# Use the field name as alias if it differs from python naming conventions
alias = field_name if field_name != field_name.replace("-", "_") else None

fields[field_name] = (
field_type,
field_cls(
default=... if field_name in required else None,
description=field_schema.get("description", ""),
title=field_schema.get("title", ""),
alias=alias,
**field_info,
)
if isinstance(field_info, dict)
Expand Down
26 changes: 22 additions & 4 deletions docs/topics/generators.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ openai/o3-mini
gemini/gemini-2.5-pro
claude-4-sonnet-latest
vllm_hosted/meta-llama/Llama-3.1-8B-Instruct
ollama/qwen3
ollama_chat/qwen3

openai/gpt-4,api_key=sk-1234
anthropic/claude-3-7-haiku-latest,stop=output:;---,seed=1337
Expand Down Expand Up @@ -184,12 +184,30 @@ NAME ID SIZE PROCESSOR UNTIL
qwen3:0.6b 7df6b6e09427 2.3 GB 100% GPU 4 minutes from now
```

Using this model in Rigging is as simple as using the `ollama/` or `ollama_chat/` prefixes:
<Warning>
Ollama is configured with a maximum context length on the server, which defaults to 4096 tokens. This limit does not change depending on the model, and changing it requires updating the server configuration.

If the input messages to the API would exceed this length, Ollama will silently truncate them to fit in the context window. This behavior can cause unexpected generation results due to missing context and is very difficult to detect in Rigging.

We make a best effort to detect this by monitoring model responses and checking whether the reported number of input tokens is far lower than expected for the input messages we just sent. If this is observed, the following warning will be emitted:

```
GeneratorWarning: Input messages may have been truncated ...
```

When in doubt, monitor the Ollama server logs for the following:

```bash
... msg="truncating input prompt" limit=4096 prompt=6767 keep=4 new=409
```
</Warning>

Using this model in Rigging is as simple as using the `ollama_chat/` (recommended) or `ollama/` prefixes:

```python
import rigging as rg

qwen = rg.get_generator("ollama/qwen3:0.6b")
qwen = rg.get_generator("ollama_chat/qwen3:0.6b")

chat = await qwen.chat("Hello!").run()
print(chat.conversation)
Expand All @@ -211,7 +229,7 @@ If you are running the Ollama server somewhere besides localhost, just pass the

```python
qwen = rg.get_generator(
"ollama/qwen3:0.6b,api_base=http://remote-server:11434"
"ollama_chat/qwen3:0.6b,api_base=http://remote-server:11434"
)
```
</Note>
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "rigging"
version = "3.3.2"
version = "3.3.3"
description = "LLM Interaction Framework"
authors = ["Nick Landers <monoxgas@gmail.com>"]
license = "MIT"
Expand Down
9 changes: 9 additions & 0 deletions rigging/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,15 @@ class TokenizerWarning(Warning):
"""


class GeneratorWarning(Warning):
    """
    Base class for all generator warnings.

    This is used to indicate that something unexpected happened during generator
    execution, but it is not critical enough to stop the execution.
    """


# System Exceptions


Expand Down
8 changes: 5 additions & 3 deletions rigging/generator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import functools
import inspect
import typing as t
from functools import lru_cache

from loguru import logger
from pydantic import (
Expand Down Expand Up @@ -738,8 +737,9 @@ def encode_value(val: t.Any) -> t.Any:
return identifier


@lru_cache(maxsize=128)
def get_generator(identifier: str, *, params: GenerateParams | None = None) -> Generator:
def get_generator(
identifier: str, *, params: GenerateParams | dict[str, t.Any] | None = None
) -> Generator:
"""
Get a generator by an identifier string. Uses LiteLLM by default.

Expand Down Expand Up @@ -836,6 +836,8 @@ def decode_value(value: str) -> t.Any:
if isinstance(v, str) and v.lower() in ["true", "false"]:
init_kwargs[k] = v.lower() == "true"

params = GenerateParams(**params) if isinstance(params, dict) else params

try:
merged_params = GenerateParams(**kwargs).merge_with(params)
except Exception as e:
Expand Down
39 changes: 37 additions & 2 deletions rigging/generator/litellm_.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import datetime
import re
import typing as t
import warnings

from loguru import logger

from rigging.error import GeneratorWarning
from rigging.generator.base import (
Fixup,
GeneratedMessage,
Expand Down Expand Up @@ -206,7 +208,7 @@ async def supports_function_calling(self) -> bool | None:
):
self._supports_function_calling = True
except Exception as e: # noqa: BLE001
logger.warning(f"Failed to check for function calling support: {e}")
logger.warning(f"Error checking for function calling support: {e}")
span.set_attribute("error", str(e))

span.set_attribute("supports_function_calling", self._supports_function_calling)
Expand Down Expand Up @@ -236,6 +238,37 @@ async def _ensure_delay_between_requests(self) -> None:
# This seems like a brittle feature at the moment, so we'll
# leave it out for now.

def _warn_on_input_truncation(
    self, messages: list[Message], response: "GeneratedMessage"
) -> None:
    """
    Emit a `GeneratorWarning` if the input messages look like they were truncated.

    Ollama has a known behavior where it silently truncates input messages
    rather than returning an error or any other API indication. This method
    attempts to detect such truncation by comparing a rough estimate of the
    input length with the input token count reported in the response usage.
    The check is a heuristic and is not foolproof.

    See:
    - https://github.com/ollama/ollama/issues/7043
    - https://github.com/ollama/ollama/issues/7987
    - https://github.com/ollama/ollama/issues/4967

    Args:
        messages: The input messages that were sent for generation.
        response: The parsed response whose `usage` is inspected.
    """
    # We can't check without usage info
    usage = response.usage
    if not usage:
        return

    # Get a general view of how long we might expect the input prompt to be,
    # using a generous estimate of 4 characters per token
    total_characters = sum(len(message.content) for message in messages)
    input_tokens_estimate = total_characters // 4

    # Nothing to report if the accepted input tokens meet or exceed the estimate
    if usage.input_tokens >= input_tokens_estimate:
        return

    warnings.warn(
        "Input messages may have been truncated - see https://github.com/ollama/ollama/issues/7043 "
        f"(input tokens: {usage.input_tokens} < estimate: {input_tokens_estimate})",
        GeneratorWarning,
        # NOTE(review): stacklevel=2 attributes the warning to this method's
        # direct caller - confirm that is the intended location to report.
        stacklevel=2,
    )

def _parse_model_response(
self,
response: "ModelResponse",
Expand Down Expand Up @@ -359,7 +392,9 @@ async def _generate_message(
)

self._last_request_time = datetime.datetime.now(tz=datetime.timezone.utc)
return self._parse_model_response(response)
generated = self._parse_model_response(response)
self._warn_on_input_truncation(list(messages), generated)
return generated

async def _generate_text(self, text: str, params: GenerateParams) -> GeneratedText:
import litellm
Expand Down
Loading
Loading