Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ols/app/endpoints/ols.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,7 @@ def generate_response(
system_prompt=llm_request.system_prompt,
user_token=user_token,
client_headers=client_headers,
streaming=streaming,
)
history = CacheEntry.cache_entries_to_history(previous_input)
if streaming:
Expand Down
117 changes: 62 additions & 55 deletions ols/app/endpoints/streaming_ols.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from fastapi import APIRouter, Depends, status
from fastapi.responses import StreamingResponse
from langchain_core.messages import ToolMessage

from ols import config, constants
from ols.app.endpoints.ols import (
Expand All @@ -27,6 +26,7 @@
)
from ols.app.models.models import (
Attachment,
ChunkType,
ErrorResponse,
ForbiddenResponse,
LLMRequest,
Expand All @@ -48,9 +48,13 @@
auth_dependency = get_auth_dependency(config.ols_config, virtual_path="/ols-access")


LLM_TOKEN_EVENT = "token" # noqa: S105
LLM_TOOL_CALL_EVENT = "tool_call"
LLM_TOOL_RESULT_EVENT = "tool_result"
STREAM_KEY_EVENT = "event"
STREAM_KEY_DATA = "data"
TOKEN_KEY_ID = "id" # noqa: S105
TOKEN_KEY_TOKEN = "token" # noqa: S105
END_KEY_RAG_CHUNKS = "rag_chunks"
END_KEY_TRUNCATED = "truncated"
END_KEY_TOKEN_COUNTER = "token_counter" # noqa: S105


query_responses: dict[int | str, dict[str, Any]] = {
Expand Down Expand Up @@ -123,7 +127,7 @@ def conversation_request(
)


def format_stream_data(d: dict) -> str:
def format_stream_data(d: dict[str, object]) -> str:
    """Serialize *d* as a single Server-Sent Events ``data:`` entry.

    The mapping is JSON-encoded and terminated with the blank line that
    separates consecutive events in the Event Stream Format.
    """
    return f"data: {json.dumps(d)}\n\n"
Expand All @@ -145,7 +149,7 @@ def stream_start_event(conversation_id: str) -> str:
)


def stream_event(data: dict, event_type: str, media_type: str) -> str:
def stream_event(data: dict[str, object], event_type: str, media_type: str) -> str:
"""Build an item to yield based on media type.

Args:
Expand All @@ -157,18 +161,20 @@ def stream_event(data: dict, event_type: str, media_type: str) -> str:
str: The formatted string or JSON to yield.
"""
if media_type == MEDIA_TYPE_TEXT:
if event_type == LLM_TOKEN_EVENT:
return data["token"]
if event_type == LLM_TOOL_CALL_EVENT:
return f"\nTool call: {json.dumps(data)}\n"
if event_type == LLM_TOOL_RESULT_EVENT:
return f"\nTool result: {json.dumps(data)}\n"
logger.error("Unknown event type: %s", event_type)
return ""
match event_type:
case _ if event_type == TOKEN_KEY_TOKEN:
return data[TOKEN_KEY_TOKEN]
case ChunkType.TOOL_CALL.value:
return f"\nTool call: {json.dumps(data)}\n"
case ChunkType.TOOL_RESULT.value:
return f"\nTool result: {json.dumps(data)}\n"
case _:
logger.error("Unknown event type: %s", event_type)
return ""
return format_stream_data(
{
"event": event_type,
"data": data,
STREAM_KEY_EVENT: event_type,
STREAM_KEY_DATA: data,
}
)

Expand Down Expand Up @@ -288,8 +294,8 @@ def store_data(
conversation_id: str,
llm_request: LLMRequest,
response: str,
tool_calls: list[dict],
tool_results: list[ToolMessage],
tool_calls: list[dict[str, object]],
tool_results: list[dict[str, object]],
attachments: list[Attachment],
query_without_attachments: str,
rag_chunks: list[RagChunk],
Expand Down Expand Up @@ -342,7 +348,7 @@ def store_data(


async def response_processing_wrapper(
generator: AsyncGenerator[Any, None],
generator: AsyncGenerator[StreamedChunk, None],
user_id: str,
conversation_id: str,
llm_request: LLMRequest,
Expand Down Expand Up @@ -372,9 +378,9 @@ async def response_processing_wrapper(
yield stream_start_event(conversation_id)

response: str = ""
rag_chunks: list = []
tool_calls: list = []
tool_results: list = []
rag_chunks: list[RagChunk] = []
tool_calls: list[dict[str, object]] = []
tool_results: list[dict[str, object]] = []
history_truncated: bool = False
idx: int = 0
token_counter: Optional[TokenCounter] = None
Expand All @@ -385,39 +391,40 @@ async def response_processing_wrapper(
msg = f"Expecting StreamedChunk, but got {type(item)}: {item}"
logger.error(msg)
raise ValueError(msg)
if item.type == "tool_call":
tool_calls.append(item.data)
yield stream_event(
data=item.data,
event_type=LLM_TOOL_CALL_EVENT,
media_type=media_type,
)
elif item.type == "tool_result":
tool_results.append(item.data)
yield stream_event(
data=item.data,
event_type=LLM_TOOL_RESULT_EVENT,
media_type=media_type,
)
elif item.type == "text":
response += item.text
yield stream_event(
data={"id": idx, "token": item.text},
event_type=LLM_TOKEN_EVENT,
media_type=media_type,
)
idx += 1
elif item.type == "end":
rag_chunks = item.data["rag_chunks"]
history_truncated = item.data["truncated"]
token_counter = item.data["token_counter"]
else:
msg = (
"Yielded unknown item type from streaming generator, "
f"item: {item}"
)
logger.error(msg)
raise ValueError(msg)
match item.type:
case ChunkType.TOOL_CALL:
tool_calls.append(item.data)
yield stream_event(
data=item.data,
event_type=ChunkType.TOOL_CALL.value,
media_type=media_type,
)
case ChunkType.TOOL_RESULT:
tool_results.append(item.data)
yield stream_event(
data=item.data,
event_type=ChunkType.TOOL_RESULT.value,
media_type=media_type,
)
case ChunkType.TEXT:
response += item.text
yield stream_event(
data={TOKEN_KEY_ID: idx, TOKEN_KEY_TOKEN: item.text},
event_type=TOKEN_KEY_TOKEN,
media_type=media_type,
)
idx += 1
case ChunkType.END:
rag_chunks = item.data[END_KEY_RAG_CHUNKS]
history_truncated = item.data[END_KEY_TRUNCATED]
token_counter = item.data[END_KEY_TOKEN_COUNTER]
case _:
msg = (
"Yielded unknown item type from streaming generator, "
f"item: {item}"
)
logger.error(msg)
raise ValueError(msg)
except PromptTooLongError as summarizer_error:
yield prompt_too_long_error(summarizer_error, media_type)
return # stop execution after error
Expand Down
12 changes: 11 additions & 1 deletion ols/app/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
from collections import OrderedDict
from dataclasses import field
from enum import Enum
from typing import Any, Literal, Optional, Self, Union

from langchain_core.language_models.llms import LLM
Expand Down Expand Up @@ -982,6 +983,15 @@ class ProcessedRequest(BaseModel):
user_token: str


class ChunkType(str, Enum):
    """Supported streamed chunk types.

    Subclassing ``str`` makes each member compare equal to its raw string
    value, so members can be used wherever a plain event-type string is
    expected (e.g. ``ChunkType.TOOL_CALL.value`` vs the member itself).
    """

    # A fragment of the LLM's textual response.
    TEXT = "text"
    # The LLM requested a tool invocation.
    TOOL_CALL = "tool_call"
    # Output produced by an invoked tool.
    TOOL_RESULT = "tool_result"
    # Terminal chunk; carries end-of-stream metadata (rag chunks,
    # truncation flag, token counter).
    END = "end"


@dataclass
class StreamedChunk:
"""Represents a chunk of streamed data from the LLM.
Expand All @@ -992,7 +1002,7 @@ class StreamedChunk:
data: Additional data associated with the chunk (for non-text chunks)
"""

type: Literal["text", "tool_call", "tool_result", "end"]
type: ChunkType
text: str = ""
data: dict[str, Any] = field(default_factory=dict)

Expand Down
Loading