38 changes: 30 additions & 8 deletions src/agents/agents/llm.py
@@ -31,12 +31,21 @@
 
 # @backoff.on_exception(backoff.expo, litellm.OpenAIError, max_tries=100)
 def completion_with_backoff(**kwargs):
-    litellm.api_key = os.environ["OPENAI_API_KEY"]
-    litellm.api_base = os.environ.get("OPENAI_BASE_URL")
-
-    if os.environ.get("OPENAI_API_KEY") is None:
+    api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("AZURE_API_KEY")
+    if api_key is None:
         raise ValueError("OPENAI_API_KEY is not set")
+
+    api_type = os.environ.get("OPENAI_API_TYPE", "openai")
+    api_base = os.environ.get("OPENAI_BASE_URL") or os.environ.get("AZURE_API_BASE")
+    api_version = os.environ.get("OPENAI_API_VERSION") or os.environ.get("AZURE_API_VERSION")
+
+    litellm.api_key = api_key
+    litellm.api_base = api_base
+    litellm.api_type = api_type
+    # Azure endpoints require an explicit api_version
+    if api_type == "azure" and api_version:
+        litellm.api_version = api_version
 
     while True:
         try:
             return litellm.completion(**kwargs)
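A minimal sketch of how the reworked `completion_with_backoff` is expected to be driven from the environment when targeting Azure; the key, endpoint, API version, and deployment name below are placeholders, not values taken from this PR:

```python
import os

# Hypothetical Azure settings for illustration; substitute real values.
os.environ["AZURE_API_KEY"] = "<key>"  # consulted when OPENAI_API_KEY is unset
os.environ["AZURE_API_BASE"] = "https://my-resource.openai.azure.com"  # placeholder
os.environ["AZURE_API_VERSION"] = "2024-02-01"  # placeholder
os.environ["OPENAI_API_TYPE"] = "azure"  # switches litellm into Azure mode

# The backoff wrapper now picks all of the above up automatically.
response = completion_with_backoff(
    model="azure/my-gpt4-deployment",  # hypothetical deployment name
    messages=[{"role": "user", "content": "ping"}],
)
```

With `OPENAI_API_TYPE` left at its default of `openai`, the same call falls back to the plain OpenAI variables.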
@@ -58,10 +67,16 @@ def __init__(self, config_path_or_dict: Union[str, dict] = None):
         self.temperature: float = self.config_dict.get("temperature", 0.3)
         self.log_path: str = self.config_dict.get("log_path", "logs")
         self.API_KEY: str = self.config_dict.get(
-            "OPENAI_API_KEY", os.environ["OPENAI_API_KEY"]
+            "OPENAI_API_KEY", os.environ.get("OPENAI_API_KEY") or os.environ.get("AZURE_API_KEY")
         )
         self.API_BASE = self.config_dict.get(
-            "OPENAI_BASE_URL", os.environ.get("OPENAI_BASE_URL")
+            "OPENAI_BASE_URL", os.environ.get("OPENAI_BASE_URL") or os.environ.get("AZURE_API_BASE")
+        )
+        self.API_TYPE: str = self.config_dict.get(
+            "OPENAI_API_TYPE", os.environ.get("OPENAI_API_TYPE", "openai")
+        )
+        self.API_VERSION: str = self.config_dict.get(
+            "OPENAI_API_VERSION", os.environ.get("OPENAI_API_VERSION") or os.environ.get("AZURE_API_VERSION")
         )
         self.MAX_CHAT_MESSAGES: int = self.config_dict.get("max_chat_messages", 10)
         self.ACTIVE_MODE: bool = self.config_dict.get("ACTIVE_MODE", False)
@@ -76,6 +91,8 @@ def __init__(self, config: LLMConfig) -> None:
         self.log_path = self.config.log_path
         self.API_KEY = self.config.API_KEY
         self.API_BASE = self.config.API_BASE
+        self.api_type = self.config.API_TYPE
+        self.api_version = self.config.API_VERSION
         self.MAX_CHAT_MESSAGES = self.config.MAX_CHAT_MESSAGES
         self.ACTIVE_MODE = self.config.ACTIVE_MODE
         self.SAVE_LOGS = self.config.SAVE_LOGS
@@ -123,6 +140,11 @@ def get_response(
         litellm.api_key = self.API_KEY
         if self.API_BASE:
             litellm.api_base = self.API_BASE
+        litellm.api_type = self.api_type
+        if self.api_type == "azure" and self.api_version:
+            litellm.api_version = self.api_version
+
+        provider = "azure" if self.api_type == "azure" else "openai"
 
         messages = (
             [{"role": "system", "content": system_prompt}] if system_prompt else []
@@ -152,7 +174,7 @@ def get_response(
                 tool_choice=tool_choice,
                 temperature=self.temperature,
                 response_format=response_format,
-                custom_llm_provider="openai",
+                custom_llm_provider=provider,
             )
         else:
             response = completion_with_backoff(
@@ … @@
                 temperature=self.temperature,
                 stream=stream,
                 response_format=response_format,
-                custom_llm_provider="openai",
+                custom_llm_provider=provider,
             )
 
         if response.choices[0].message.get("tool_calls"):
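The `provider` string computed in `get_response` is what litellm uses to route the request. Roughly, and assuming a made-up deployment name:

```python
import litellm

# With custom_llm_provider="azure", litellm treats `model` as an Azure
# deployment name; with "openai" it is an ordinary OpenAI model id.
response = litellm.completion(
    model="my-gpt4-deployment",  # hypothetical Azure deployment
    messages=[{"role": "user", "content": "ping"}],
    custom_llm_provider="azure",
)
```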
12 changes: 6 additions & 6 deletions src/agents/utils/prompts.py
@@ -65,12 +65,12 @@
 LAST_PROMPT_TEMPLATE = "{last}"
 
 DEFAULT_NODE_PROMPT_TEMPLATES = {
-    # "style": STYLE_PROMPT_TEMPLATE,
-    # "task": TASK_PROMPT_TEMPLATE,
-    # "rule": RULE_PROMPT_TEMPLATE,
-    # "demonstrations": DEMONSTRATIONS_PROMPT_TEMPLATE,
-    # "output": OUTPUT_PROMPT_TEMPLATE,
-    # "last": LAST_PROMPT_TEMPLATE,
+    "style": STYLE_PROMPT_TEMPLATE,
+    "task": TASK_PROMPT_TEMPLATE,
+    "rule": RULE_PROMPT_TEMPLATE,
+    "demonstrations": DEMONSTRATIONS_PROMPT_TEMPLATE,
+    "output": OUTPUT_PROMPT_TEMPLATE,
+    "last": LAST_PROMPT_TEMPLATE,
 }
 
 # The following prompt templates are used in the config generation
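Since the node templates are now enabled by default, downstream prompt assembly presumably consumes them along these lines. This is a sketch that assumes each template exposes a single placeholder named after its key, as the visible `LAST_PROMPT_TEMPLATE = "{last}"` does; the node values are made up:

```python
# Made-up node values; only keys present in both dicts are rendered.
node_values = {"task": "Summarize the ticket.", "last": "Answer in one sentence."}

prompt = "\n".join(
    template.format(**{key: node_values[key]})
    for key, template in DEFAULT_NODE_PROMPT_TEMPLATES.items()
    if key in node_values
)
```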