Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 47 additions & 21 deletions basilisk/completion_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
SystemMessage,
)
from basilisk.decorators import ensure_no_task_running
from basilisk.provider_engine.stream_chunk_type import StreamChunkType
from basilisk.sound_manager import play_sound, stop_sound
from basilisk.views.enhanced_error_dialog import show_enhanced_error_dialog

Expand Down Expand Up @@ -79,6 +80,7 @@ def __init__(
self._stop_completion = False
self.last_time = 0
self.stream_buffer: str = ""
self._reasoning_started: bool = False

@ensure_no_task_running
def start_completion(
Expand Down Expand Up @@ -171,25 +173,47 @@ def _handle_completion(self, engine: BaseEngine, **kwargs: dict[str, Any]):
if success:
wx.CallAfter(self._completion_finished_success)

def _handle_stream_chunk(
    self, chunk: str | tuple[str, Any], message_block: MessageBlock
):
    """Dispatch one streaming chunk to the right handler.

    Plain ``str`` chunks are accumulated into the stream buffer; tuple
    chunks carry a ``(type, data)`` payload (currently only
    ``"citation"`` is recognized).

    Args:
        chunk: Either raw response text or a ``(type, data)`` pair.
        message_block: The message block receiving the streamed response.
    """
    if isinstance(chunk, str):
        # Raw text: just buffer it; flushing is handled elsewhere.
        self.stream_buffer += chunk
    elif isinstance(chunk, tuple):
        chunk_type, chunk_data = chunk
        if chunk_type == "citation":
            # Lazily create the citation list on the first citation chunk.
            if not message_block.response.citations:
                message_block.response.citations = []
            message_block.response.citations.append(chunk_data)
        else:
            # Unknown typed chunks are logged, not raised, so one odd
            # chunk cannot abort an otherwise healthy stream.
            logger.warning(
                "Unknown chunk type in streaming response: %s", chunk_type
            )

def _append_stream_and_maybe_flush(
    self, message_block: MessageBlock, fragment: str
) -> None:
    """Accumulate ``fragment`` and flush at a sentence boundary.

    The buffer is flushed to ``message_block`` as soon as it matches
    ``RE_STREAM_BUFFER``; otherwise text keeps accumulating.
    """
    buffered = self.stream_buffer + fragment
    self.stream_buffer = buffered
    if not RE_STREAM_BUFFER.match(buffered):
        return
    self.flush_stream_buffer(message_block)

def _handle_stream_chunk(
self, chunk: tuple[str, Any], message_block: MessageBlock
) -> None:
chunk_type, chunk_data = chunk
if chunk_type == StreamChunkType.CITATION:
cits = message_block.response.citations
if cits is None:
cits = []
message_block.response.citations = cits
cits.append(chunk_data)
return
if chunk_type == StreamChunkType.REASONING:
if not self._reasoning_started:
self._reasoning_started = True
self._append_stream_and_maybe_flush(
message_block, f"```think\n{chunk_data}"
)
else:
self._append_stream_and_maybe_flush(message_block, chunk_data)
return
if chunk_type == StreamChunkType.CONTENT:
prefix = ""
if self._reasoning_started:
self._reasoning_started = False
prefix = "\n```\n\n"
self._append_stream_and_maybe_flush(
message_block, prefix + chunk_data
)
return
logger.warning(
"Unknown chunk type in streaming response: %s", chunk_type
)

def flush_stream_buffer(self, message_block: MessageBlock) -> None:
"""Flush the stream buffer to the message block."""
if self.stream_buffer:
Expand Down Expand Up @@ -218,18 +242,22 @@ def _handle_streaming_completion(
True if streaming was handled successfully, False if stopped
"""
new_block.response = Message(role=MessageRoleEnum.ASSISTANT, content="")
self._reasoning_started = False

# Notify that streaming has started
if self.on_stream_start:
wx.CallAfter(self.on_stream_start, new_block, system_message)

for chunk in engine.completion_response_with_stream(response):
for chunk in engine.completion_response_with_stream(
response, new_block
):
if self._stop_completion or global_vars.app_should_exit:
logger.debug("Stopping completion")
return False
self._handle_stream_chunk(chunk, new_block)

# Notify that streaming has finished
if self._reasoning_started:
self._reasoning_started = False
self.stream_buffer += "\n```"
self.flush_stream_buffer(new_block)
if self.on_stream_finish:
wx.CallAfter(self.on_stream_finish, new_block)
Expand Down Expand Up @@ -259,7 +287,6 @@ def _handle_non_streaming_completion(
response=response, new_block=new_block, **kwargs
)

# Notify that non-streaming completion has finished
if self.on_non_stream_finish:
wx.CallAfter(
self.on_non_stream_finish, completed_block, system_message
Expand All @@ -276,7 +303,6 @@ def _handle_stream_buffer(self, buffer: str):
if self.on_stream_chunk:
self.on_stream_chunk(buffer)

# Play periodic sound during streaming
new_time = time.time()
if new_time - self.last_time > 4:
play_sound("chat_response_pending")
Expand Down
16 changes: 15 additions & 1 deletion basilisk/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,35 @@
)
from .account_config import get_account_config as accounts
from .config_enums import (
ACCOUNT_MODEL_SORT_KEYS,
MODEL_SORT_KEYS,
AccountModelSortKeyEnum,
AccountSource,
AutomaticUpdateModeEnum,
KeyStorageMethodEnum,
LogLevelEnum,
ModelSortKeyEnum,
ReleaseChannelEnum,
)
from .conversation_profile import ConversationProfile
from .conversation_profile import (
get_conversation_profile_config as conversation_profiles,
)
from .main_config import BasiliskConfig
from .main_config import (
MODEL_METADATA_CACHE_TTL_HOURS_MAX,
MODEL_METADATA_CACHE_TTL_HOURS_MIN,
BasiliskConfig,
)
from .main_config import get_basilisk_config as conf

__all__ = [
"ACCOUNT_MODEL_SORT_KEYS",
"Account",
"AccountModelSortKeyEnum",
"MODEL_SORT_KEYS",
"ModelSortKeyEnum",
"MODEL_METADATA_CACHE_TTL_HOURS_MAX",
"MODEL_METADATA_CACHE_TTL_HOURS_MIN",
"AccountManager",
"AccountOrganization",
"AccountSource",
Expand Down
38 changes: 37 additions & 1 deletion basilisk/config/account_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@
from basilisk.consts import APP_NAME
from basilisk.provider import Provider, get_provider, providers

from .config_enums import AccountSource, KeyStorageMethodEnum
from .config_enums import (
ACCOUNT_MODEL_SORT_KEYS,
AccountSource,
KeyStorageMethodEnum,
)
from .config_helper import (
BasiliskBaseSettings,
get_settings_config_dict,
Expand Down Expand Up @@ -141,6 +145,23 @@ class Account(BaseModel):
pattern=CUSTOM_BASE_URL_PATTERN,
description="Custom base URL for the API provider. Must be a valid HTTP/HTTPS URL.",
)
# Per-account overrides for model-list sorting; None means "fall back
# to the global preference".
model_sort_key: Optional[str] = Field(
    default=None,
    description="Override default model sort key for this account. None = use preference.",
)
model_sort_reverse: Optional[bool] = Field(
    default=None,
    description="Override default model sort reverse for this account. None = use preference.",
)

@model_validator(mode="after")
def _validate_model_sort_key(self) -> "Account":
    """Silently reset an unrecognized sort key to None (use preference).

    An invalid persisted value degrades gracefully instead of failing
    to load the account. ``object.__setattr__`` is used to bypass
    pydantic assignment validation inside the validator.
    """
    if (
        self.model_sort_key is not None
        and self.model_sort_key not in ACCOUNT_MODEL_SORT_KEYS
    ):
        object.__setattr__(self, "model_sort_key", None)
    return self

def __init__(self, **data: Any):
"""Initialize an account instance. If an error occurs, log the error and raise an exception."""
Expand Down Expand Up @@ -441,6 +462,21 @@ def get_account_from_info(self, value: AccountInfo) -> Optional[Account]:
return self.accounts[0]
return self.accounts[index]

def can_set_as_default(self, account: Optional[Account]) -> bool:
    """Check if the account can be set as default.

    Any account (including from environment variables) can be set as
    default.

    Args:
        account: The account to check.

    Returns:
        True if the account exists and is not already the default.
    """
    # isinstance() already rejects None, so the former separate
    # truthiness check (`not account`) was redundant — pydantic model
    # instances are always truthy.
    if not isinstance(account, Account):
        return False
    return self.default_account != account

def set_default_account(self, value: Optional[Account]):
"""Set the default account for the configuration.

Expand Down
27 changes: 27 additions & 0 deletions basilisk/config/config_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,30 @@ def get_labels(cls) -> dict[AutomaticUpdateModeEnum, str]:
# Translators: A label for the automatic update mode in the settings dialog
cls.INSTALL: _("Install new version"),
}


class ModelSortKeyEnum(enum.StrEnum):
    """Sort keys for the global model list (preferences).

    Member values are exported as plain strings via the module-level
    ``MODEL_SORT_KEYS`` tuple; keep the two in sync implicitly by only
    adding members here.
    """

    # "none" keeps the list unsorted.
    NONE = "none"
    NAME = "name"
    RELEASE_DATE = "release_date"
    MAX_OUTPUT = "max_output"
    CONTEXT_WINDOW = "context_window"


class AccountModelSortKeyEnum(enum.StrEnum):
    """Account override for model list sort. DEFAULT uses global preference.

    Mirrors :class:`ModelSortKeyEnum` plus the extra ``DEFAULT`` member;
    values are exported as plain strings via ``ACCOUNT_MODEL_SORT_KEYS``.
    """

    # Defer to the global (preferences) sort key.
    DEFAULT = "default"
    NONE = "none"
    NAME = "name"
    RELEASE_DATE = "release_date"
    MAX_OUTPUT = "max_output"
    CONTEXT_WINDOW = "context_window"


# Plain-string views of the enums above, used for cheap membership
# checks (e.g. validating a persisted account's model_sort_key) without
# needing the enum types at the call site.
MODEL_SORT_KEYS: tuple[str, ...] = tuple(m.value for m in ModelSortKeyEnum)
ACCOUNT_MODEL_SORT_KEYS: tuple[str, ...] = tuple(
    m.value for m in AccountModelSortKeyEnum
)
70 changes: 64 additions & 6 deletions basilisk/config/conversation_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,16 @@ class ConversationProfile(BaseModel):
# Sampling / generation parameters. Every Optional field below must
# stay None when ai_model_info is unset (enforced by check_model_params).
max_tokens: Optional[int] = Field(default=None)
temperature: Optional[float] = Field(default=None)
top_p: Optional[float] = Field(default=None)
frequency_penalty: Optional[float] = Field(default=None)
presence_penalty: Optional[float] = Field(default=None)
seed: Optional[int] = Field(default=None)
top_k: Optional[int] = Field(default=None)
stop: Optional[list[str]] = Field(default=None)
stream_mode: bool = Field(default=True)
# Reasoning controls. When a model is set, the budget must be a
# non-negative integer and effort one of minimal/low/medium/high/max
# (validated in _validate_reasoning_params_with_model).
reasoning_mode: bool = Field(default=False)
reasoning_budget_tokens: Optional[int] = Field(default=None)
reasoning_effort: Optional[str] = Field(default=None)
reasoning_adaptive: bool = Field(default=False)

def __init__(self, **data: Any):
"""Initialize a conversation profile with the provided data.
Expand Down Expand Up @@ -236,6 +245,58 @@ def check_same_provider(self) -> ConversationProfile:
)
return self

def _validate_params_without_model(self) -> None:
    """Raise ValueError if any model param is set when ai_model_info is None.

    Raises:
        ValueError: For the first parameter found to be non-None.
    """
    # These internal validation messages are deliberately untranslated,
    # matching the other checks here and the original validator's
    # messages — previously only the two reasoning messages were
    # inconsistently wrapped in _().
    checks = [
        (self.max_tokens, "Max tokens must be None without model"),
        (self.temperature, "Temperature must be None without model"),
        (self.top_p, "Top P must be None without model"),
        (
            self.frequency_penalty,
            "Frequency penalty must be None without model",
        ),
        (
            self.presence_penalty,
            "Presence penalty must be None without model",
        ),
        (self.seed, "Seed must be None without model"),
        (self.top_k, "Top K must be None without model"),
        (self.stop, "Stop must be None without model"),
        (
            self.reasoning_budget_tokens,
            "Reasoning budget must be None without model",
        ),
        (
            self.reasoning_effort,
            "Reasoning effort must be None without model",
        ),
    ]
    for value, message in checks:
        if value is not None:
            raise ValueError(message)

def _validate_reasoning_params_with_model(self) -> None:
    """Raise ValueError if reasoning params are invalid when model is present.

    Raises:
        ValueError: If ``reasoning_budget_tokens`` is negative, or if
            ``reasoning_effort`` (compared case-insensitively) is not a
            supported level.
    """
    # Allowed effort levels, hoisted to the top of the function: the
    # previous code declared this UPPER_CASE "constant" mid-function,
    # which misleadingly suggested module scope.
    # NOTE(review): the stored reasoning_effort is not normalized to
    # lowercase here — confirm downstream consumers also compare
    # case-insensitively.
    allowed_efforts = ("minimal", "low", "medium", "high", "max")
    if (
        self.reasoning_budget_tokens is not None
        and self.reasoning_budget_tokens < 0
    ):
        # Translators: Error when reasoning budget tokens is negative
        raise ValueError(
            _(
                "reasoning_budget_tokens must be None or a non-negative integer"
            )
        )
    if self.reasoning_effort is not None:
        effort = self.reasoning_effort.lower()
        if effort not in allowed_efforts:
            # Translators: Error when reasoning effort has an unsupported value
            raise ValueError(
                _("reasoning_effort must be None or one of: %s")
                % ", ".join(allowed_efforts)
            )

@model_validator(mode="after")
def check_model_params(self) -> ConversationProfile:
    """Validate that model parameters are consistent with the model.

    Without an AI model, no model parameter may be set; with one, the
    reasoning parameters must be within their supported ranges.

    Raises:
        ValueError: If model parameters are set without an AI model.
    """
    has_model = self.ai_model_info is not None
    if has_model:
        self._validate_reasoning_params_with_model()
    else:
        self._validate_params_without_model()
    return self


Expand Down
Loading
Loading