Skip to content

Commit 28969ba

Browse files
committed
refresh with latest develop
2 parents 6360fa2 + 8062a9c commit 28969ba

40 files changed

Lines changed: 750 additions & 117 deletions

File tree

flo_ai/flo_ai/arium/builder.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -592,11 +592,12 @@ def from_yaml(
592592
nested_builder = cls.from_yaml(
593593
yaml_file=arium_node.yaml_file,
594594
memory=None,
595-
agents=None,
595+
agents=agents,
596596
routers=None,
597597
base_llm=base_llm,
598-
function_registry=None,
599-
tool_registry=None,
598+
function_registry=function_registry,
599+
tool_registry=tool_registry,
600+
**kwargs,
600601
)
601602
nested_arium = nested_builder.build()
602603

@@ -634,11 +635,12 @@ def from_yaml(
634635
nested_builder = cls.from_yaml(
635636
yaml_str=yaml.dump(sub_config),
636637
memory=None,
637-
agents=None,
638+
agents=agents,
638639
routers=None,
639640
base_llm=base_llm,
640-
function_registry=None,
641-
tool_registry=None,
641+
function_registry=function_registry,
642+
tool_registry=tool_registry,
643+
**kwargs,
642644
)
643645
nested_arium = nested_builder.build()
644646

flo_ai/flo_ai/arium/nodes.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,6 @@ async def run(
194194
variables: Optional[Dict[str, Any]] = None,
195195
**kwargs,
196196
) -> Any:
197-
logger.info(
198-
f"Executing FunctionNode '{self.name}' with inputs: {inputs} variables: {variables} kwargs: {kwargs}"
199-
)
200-
201197
if asyncio.iscoroutinefunction(self.function):
202198
logger.info(f"Executing FunctionNode '{self.name}' as a coroutine function")
203199
result = await self.function(

flo_ai/flo_ai/helpers/llm_factory.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class LLMFactory:
2525
'vertexai',
2626
'rootflo',
2727
'openai_vllm',
28+
'azure_openai',
2829
}
2930

3031
@staticmethod
@@ -57,6 +58,8 @@ def create_llm(model_config: LLMConfigModel, **kwargs) -> 'BaseLLM':
5758
return LLMFactory._create_vertexai_llm(model_config, **kwargs)
5859
elif provider == 'openai_vllm':
5960
return LLMFactory._create_openai_vllm_llm(model_config, **kwargs)
61+
elif provider == 'azure_openai':
62+
return LLMFactory._create_azure_openai_llm(model_config, **kwargs)
6063
else:
6164
return LLMFactory._create_standard_llm(provider, model_config, **kwargs)
6265

@@ -159,6 +162,60 @@ def _create_openai_vllm_llm(model_config: LLMConfigModel, **kwargs) -> 'BaseLLM'
159162
temperature=temperature,
160163
)
161164

165+
@staticmethod
166+
def _create_azure_openai_llm(model_config: LLMConfigModel, **kwargs) -> 'BaseLLM':
167+
"""Create Azure OpenAI LLM instance with endpoint and API version."""
168+
from flo_ai.llm import AzureOpenAI
169+
170+
model_name = model_config.name
171+
if not model_name:
172+
raise ValueError('azure_openai provider requires "name" parameter')
173+
174+
# Endpoint and API version
175+
azure_endpoint = (
176+
kwargs.get('azure_endpoint')
177+
or model_config.azure_endpoint
178+
or os.getenv('AZURE_OPENAI_ENDPOINT')
179+
)
180+
if not azure_endpoint:
181+
raise ValueError(
182+
'azure_openai configuration incomplete. Missing required parameter: '
183+
'azure_endpoint. Provide it in model_config, as a kwarg, or via '
184+
'AZURE_OPENAI_ENDPOINT environment variable.'
185+
)
186+
187+
api_key = (
188+
kwargs.get('api_key')
189+
or model_config.api_key
190+
or os.getenv('AZURE_OPENAI_API_KEY')
191+
)
192+
if not api_key:
193+
raise ValueError(
194+
'azure_openai configuration incomplete. Missing required parameter: '
195+
'api_key. Provide it in model_config, as a kwarg, or via '
196+
'AZURE_OPENAI_API_KEY environment variable.'
197+
)
198+
199+
api_version = (
200+
kwargs.get('azure_api_version')
201+
or model_config.azure_api_version
202+
or os.getenv('AZURE_OPENAI_API_VERSION')
203+
or '2024-12-01-preview'
204+
)
205+
206+
temperature = kwargs.get(
207+
'temperature',
208+
model_config.temperature if model_config.temperature is not None else 0.7,
209+
)
210+
211+
return AzureOpenAI(
212+
model=model_name,
213+
api_key=str(api_key),
214+
azure_endpoint=str(azure_endpoint),
215+
api_version=str(api_version),
216+
temperature=temperature,
217+
)
218+
162219
@staticmethod
163220
def _create_rootflo_llm(model_config: LLMConfigModel, **kwargs) -> 'BaseLLM':
164221
"""Create RootFlo LLM instance with authentication."""

flo_ai/flo_ai/llm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .vertexai_llm import VertexAI
88
from .rootflo_llm import RootFloLLM
99
from .aws_bedrock_llm import AWSBedrock
10+
from .azure_openai_llm import AzureOpenAI
1011

1112
__all__ = [
1213
'BaseLLM',
@@ -18,4 +19,5 @@
1819
'VertexAI',
1920
'RootFloLLM',
2021
'AWSBedrock',
22+
'AzureOpenAI',
2123
]
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
from typing import Dict, Any, List, AsyncIterator, Optional
2+
3+
from openai import AsyncAzureOpenAI
4+
5+
from .base_llm import BaseLLM
6+
from flo_ai.models.chat_message import ImageMessageContent
7+
from flo_ai.tool.base_tool import Tool
8+
from flo_ai.telemetry.instrumentation import (
9+
trace_llm_call,
10+
trace_llm_stream,
11+
llm_metrics,
12+
add_span_attributes,
13+
)
14+
from flo_ai.telemetry import get_tracer
15+
from opentelemetry import trace
16+
17+
18+
class AzureOpenAI(BaseLLM):
    """Azure OpenAI LLM implementation backed by the ``AsyncAzureOpenAI`` client.

    ``model`` is the Azure *deployment name* and is forwarded as ``model`` to
    ``chat.completions.create``. Message/tool formats follow the OpenAI
    chat-completions conventions (legacy ``functions`` / ``function_call``
    fields, ``image_url`` content blocks).
    """

    def __init__(
        self,
        model: str,
        api_key: Optional[str],
        azure_endpoint: str,
        api_version: str = '2024-12-01-preview',
        temperature: float = 0.7,
        custom_headers: Optional[Dict[str, str]] = None,
        **kwargs,
    ):
        """
        Azure OpenAI LLM implementation using the AsyncAzureOpenAI client.

        Args:
            model: Azure deployment name (passed as `model` to chat.completions.create)
            api_key: Azure OpenAI API key
            azure_endpoint: Azure endpoint URL, e.g. https://<resource>.cognitiveservices.azure.com/
            api_version: Azure OpenAI API version
            temperature: Sampling temperature
            custom_headers: Optional additional headers to send with each request
            **kwargs: Extra parameters forwarded to the SDK client / calls

        NOTE(review): ``**kwargs`` is forwarded to ``BaseLLM``, to the
        ``AsyncAzureOpenAI`` constructor, *and* replayed on every
        ``chat.completions.create`` call via ``self.kwargs``. An unrecognized
        client kwarg will raise inside the SDK constructor — confirm the
        triple forwarding is intentional.
        """
        super().__init__(
            model=model, api_key=api_key, temperature=temperature, **kwargs
        )
        self.client = AsyncAzureOpenAI(
            api_key=self.api_key,
            azure_endpoint=azure_endpoint,
            api_version=api_version,
            default_headers=custom_headers,
            **kwargs,
        )
        self.model = model
        # Replayed on every create() call so per-instance defaults stick.
        self.kwargs = kwargs

    @trace_llm_call(provider='azureopenai')
    async def generate(
        self,
        messages: List[Dict[str, Any]],
        functions: Optional[List[Dict[str, Any]]] = None,
        output_schema: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Any:
        """Run a single (non-streaming) chat completion.

        Args:
            messages: Chat messages in OpenAI role/content dict form. The
                caller's list is NOT mutated (a copy is taken when a system
                instruction has to be injected for structured output).
            functions: Optional function specs (legacy ``functions`` field).
            output_schema: Optional JSON schema; when given, JSON-mode output
                is requested and a forced ``function_call`` is set up.
            **kwargs: Extra per-call parameters for chat.completions.create.

        Returns:
            The first choice's ``message`` object from the SDK response.
        """
        # Structured output takes precedence over plain tool calling.
        if output_schema:
            kwargs['response_format'] = {'type': 'json_object'}
            kwargs['functions'] = [
                {
                    'name': output_schema.get('title', 'default'),
                    'parameters': output_schema.get('schema', output_schema),
                }
            ]
            kwargs['function_call'] = {'name': output_schema.get('title', 'default')}

            # Fix: operate on a copy so the caller's message list (and its
            # first message dict) are not mutated as a side effect.
            messages = list(messages)
            if messages and messages[0]['role'] == 'system':
                messages[0] = {
                    **messages[0],
                    'content': messages[0]['content']
                    + '\n\nPlease provide your response in JSON format according to the specified schema.',
                }
            else:
                messages.insert(
                    0,
                    {
                        'role': 'system',
                        'content': 'Please provide your response in JSON format according to the specified schema.',
                    },
                )
        elif functions:
            kwargs['functions'] = functions

        azure_kwargs = {
            'model': self.model,
            'messages': messages,
            'temperature': self.temperature,
            **self.kwargs,
            **kwargs,
        }

        response = await self.client.chat.completions.create(**azure_kwargs)
        message = response.choices[0].message

        # Token accounting + span attributes only when the response actually
        # carries usage info; keeping the tracer block inside this guard
        # avoids referencing `usage` when it was never bound.
        if hasattr(response, 'usage') and response.usage:
            usage = response.usage
            llm_metrics.record_tokens(
                total_tokens=usage.total_tokens,
                prompt_tokens=usage.prompt_tokens,
                completion_tokens=usage.completion_tokens,
                model=self.model,
                provider='azureopenai',
            )

            tracer = get_tracer()
            if tracer:
                current_span = trace.get_current_span()
                add_span_attributes(
                    current_span,
                    {
                        'llm.tokens.prompt': usage.prompt_tokens,
                        'llm.tokens.completion': usage.completion_tokens,
                        'llm.tokens.total': usage.total_tokens,
                    },
                )

        return message

    @trace_llm_stream(provider='azureopenai')
    async def stream(
        self,
        messages: List[Dict[str, Any]],
        functions: Optional[List[Dict[str, Any]]] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Dict[str, Any]]:
        """Stream partial responses from Azure OpenAI Chat Completions API."""
        azure_kwargs = {
            'model': self.model,
            'messages': messages,
            'temperature': self.temperature,
            'stream': True,
            **self.kwargs,
            **kwargs,
        }

        if functions:
            azure_kwargs['functions'] = functions

        response = await self.client.chat.completions.create(**azure_kwargs)
        async for chunk in response:
            # Defensive: chunks may arrive without choices or without a delta.
            choices = getattr(chunk, 'choices', []) or []
            for choice in choices:
                delta = getattr(choice, 'delta', None)
                if delta is None:
                    continue
                content = getattr(delta, 'content', None)
                if content:
                    yield {'content': content}

    def get_message_content(self, response: Any) -> str:
        """Extract plain text from a response: str passthrough, then
        ``.content`` attribute (SDK message objects), then str() fallback."""
        if isinstance(response, str):
            return response
        if hasattr(response, 'content') and response.content is not None:
            return str(response.content)
        return str(response)

    def format_tool_for_llm(self, tool: 'Tool') -> Dict[str, Any]:
        """Format a single tool for Azure OpenAI's API (OpenAI-compatible)."""
        return {
            'name': tool.name,
            'description': tool.description,
            'parameters': {
                'type': 'object',
                'properties': {
                    name: {
                        'type': info['type'],
                        'description': info['description'],
                        # Array parameters need an `items` sub-schema.
                        **(
                            {'items': info['items']}
                            if info.get('type') == 'array' and 'items' in info
                            else {}
                        ),
                    }
                    for name, info in tool.parameters.items()
                },
                # NOTE(review): every parameter is marked required — confirm
                # Tool.parameters never carries optional entries.
                'required': list(tool.parameters.keys()),
            },
        }

    def format_tools_for_llm(self, tools: List['Tool']) -> List[Dict[str, Any]]:
        """Format tools for Azure OpenAI's API (OpenAI-compatible)."""
        return [self.format_tool_for_llm(tool) for tool in tools]

    def format_image_in_message(self, image: ImageMessageContent) -> list[dict]:
        """
        Format an image in the message for Azure OpenAI.

        Azure vision models expect the OpenAI-style `"image_url"` block, for example:
            {
                "type": "image_url",
                "image_url": { "url": "data:image/png;base64,..." }
            }

        Raises:
            ValueError: base64/bytes given without a mime type.
            NotImplementedError: no url, base64, or bytes supplied.
        """
        import base64

        # Remote URL
        if image.url:
            return [
                {
                    'type': 'image_url',
                    'image_url': {
                        'url': image.url,
                    },
                }
            ]

        # Raw base64 string or bytes – construct a data URL
        if image.base64 or image.bytes:
            if not image.mime_type:
                raise ValueError(
                    'Image mime type is required for Azure OpenAI image messages'
                )

            if image.base64:
                b64 = image.base64
            else:
                b64 = base64.b64encode(image.bytes or b'').decode('utf-8')

            data_url = f'data:{image.mime_type};base64,{b64}'

            return [
                {
                    'type': 'image_url',
                    'image_url': {
                        'url': data_url,
                    },
                }
            ]

        raise NotImplementedError(
            f'Image formatting for AzureOpenAI LLM requires either url, base64 data, or bytes. '
            f'Received: url={image.url}, base64={bool(image.base64)}, bytes={bool(image.bytes)}'
        )

0 commit comments

Comments
 (0)