Skip to content

Commit 01147bd

Browse files
committed
perf: optimize get_configurations
1 parent 3cb944f commit 01147bd

1 file changed

Lines changed: 53 additions & 51 deletions

File tree

api/core/provider_manager.py

Lines changed: 53 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,9 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
100100
:param tenant_id:
101101
:return:
102102
"""
103+
session = db.session
103104
# Get all provider records of the workspace
104-
provider_name_to_provider_records_dict = self._get_all_providers(tenant_id)
105+
provider_name_to_provider_records_dict = self._get_all_providers(session, tenant_id)
105106

106107
# Initialize trial provider records if not exist
107108
provider_name_to_provider_records_dict = self._init_trial_provider_records(
@@ -118,7 +119,7 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
118119
]
119120

120121
# Get all provider model records of the workspace
121-
provider_name_to_provider_model_records_dict = self._get_all_provider_models(tenant_id)
122+
provider_name_to_provider_model_records_dict = self._get_all_provider_models(session, tenant_id)
122123
for provider_name in list(provider_name_to_provider_model_records_dict.keys()):
123124
provider_id = ModelProviderID(provider_name)
124125
if str(provider_id) not in provider_name_to_provider_model_records_dict:
@@ -131,7 +132,9 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
131132
provider_entities = model_provider_factory.get_providers()
132133

133134
# Get All preferred provider types of the workspace
134-
provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
135+
provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(
136+
session, tenant_id
137+
)
135138
# Ensure that both the original provider name and its ModelProviderID string representation
136139
# are present in the dictionary to handle cases where either form might be used
137140
for provider_name in list(provider_name_to_preferred_model_provider_records_dict.keys()):
@@ -143,15 +146,15 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
143146
)
144147

145148
# Get All provider model settings
146-
provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id)
149+
provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(session, tenant_id)
147150

148151
# Get All load balancing configs
149152
provider_name_to_provider_load_balancing_model_configs_dict = self._get_all_provider_load_balancing_configs(
150-
tenant_id
153+
session, tenant_id
151154
)
152155

153156
# Get All provider model credentials
154-
provider_name_to_provider_model_credentials_dict = self._get_all_provider_model_credentials(tenant_id)
157+
provider_name_to_provider_model_credentials_dict = self._get_all_provider_model_credentials(session, tenant_id)
155158

156159
provider_configurations = ProviderConfigurations(tenant_id=tenant_id)
157160

@@ -403,88 +406,88 @@ def update_default_model_record(
403406
return default_model
404407

405408
@staticmethod
406-
def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
409+
def _get_all_providers(session: Session, tenant_id: str) -> dict[str, list[Provider]]:
407410
provider_name_to_provider_records_dict = defaultdict(list)
408-
with Session(db.engine, expire_on_commit=False) as session:
409-
stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
410-
providers = session.scalars(stmt)
411-
for provider in providers:
412-
# Use provider name with prefix after the data migration
413-
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
411+
stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
412+
providers = session.scalars(stmt)
413+
for provider in providers:
414+
# Use provider name with prefix after the data migration
415+
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
414416
return provider_name_to_provider_records_dict
415417

416418
@staticmethod
417-
def _get_all_provider_models(tenant_id: str) -> dict[str, list[ProviderModel]]:
419+
def _get_all_provider_models(session: Session, tenant_id: str) -> dict[str, list[ProviderModel]]:
418420
"""
419421
Get all provider model records of the workspace.
420422
421423
:param tenant_id: workspace id
422424
:return:
423425
"""
424426
provider_name_to_provider_model_records_dict = defaultdict(list)
425-
with Session(db.engine, expire_on_commit=False) as session:
426-
stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
427-
provider_models = session.scalars(stmt)
428-
for provider_model in provider_models:
429-
provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)
427+
stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
428+
provider_models = session.scalars(stmt)
429+
for provider_model in provider_models:
430+
provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)
430431
return provider_name_to_provider_model_records_dict
431432

432433
@staticmethod
433-
def _get_all_preferred_model_providers(tenant_id: str) -> dict[str, TenantPreferredModelProvider]:
434+
def _get_all_preferred_model_providers(session: Session, tenant_id: str) -> dict[str, TenantPreferredModelProvider]:
434435
"""
435436
Get All preferred provider types of the workspace.
436437
437438
:param tenant_id: workspace id
438439
:return:
439440
"""
440441
provider_name_to_preferred_provider_type_records_dict = {}
441-
with Session(db.engine, expire_on_commit=False) as session:
442-
stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id)
443-
preferred_provider_types = session.scalars(stmt)
444-
provider_name_to_preferred_provider_type_records_dict = {
445-
preferred_provider_type.provider_name: preferred_provider_type
446-
for preferred_provider_type in preferred_provider_types
447-
}
442+
stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id)
443+
preferred_provider_types = session.scalars(stmt)
444+
provider_name_to_preferred_provider_type_records_dict = {
445+
preferred_provider_type.provider_name: preferred_provider_type
446+
for preferred_provider_type in preferred_provider_types
447+
}
448448
return provider_name_to_preferred_provider_type_records_dict
449449

450450
@staticmethod
451-
def _get_all_provider_model_settings(tenant_id: str) -> dict[str, list[ProviderModelSetting]]:
451+
def _get_all_provider_model_settings(session: Session, tenant_id: str) -> dict[str, list[ProviderModelSetting]]:
452452
"""
453453
Get All provider model settings of the workspace.
454454
455455
:param tenant_id: workspace id
456456
:return:
457457
"""
458458
provider_name_to_provider_model_settings_dict = defaultdict(list)
459-
with Session(db.engine, expire_on_commit=False) as session:
460-
stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id)
461-
provider_model_settings = session.scalars(stmt)
462-
for provider_model_setting in provider_model_settings:
463-
provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append(
464-
provider_model_setting
465-
)
459+
stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id)
460+
provider_model_settings = session.scalars(stmt)
461+
for provider_model_setting in provider_model_settings:
462+
provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append(
463+
provider_model_setting
464+
)
465+
466466
return provider_name_to_provider_model_settings_dict
467467

468468
@staticmethod
469-
def _get_all_provider_model_credentials(tenant_id: str) -> dict[str, list[ProviderModelCredential]]:
469+
def _get_all_provider_model_credentials(
470+
session: Session, tenant_id: str
471+
) -> dict[str, list[ProviderModelCredential]]:
470472
"""
471473
Get All provider model credentials of the workspace.
472474
473475
:param tenant_id: workspace id
474476
:return:
475477
"""
476478
provider_name_to_provider_model_credentials_dict = defaultdict(list)
477-
with Session(db.engine, expire_on_commit=False) as session:
478-
stmt = select(ProviderModelCredential).where(ProviderModelCredential.tenant_id == tenant_id)
479-
provider_model_credentials = session.scalars(stmt)
480-
for provider_model_credential in provider_model_credentials:
481-
provider_name_to_provider_model_credentials_dict[provider_model_credential.provider_name].append(
482-
provider_model_credential
483-
)
479+
stmt = select(ProviderModelCredential).where(ProviderModelCredential.tenant_id == tenant_id)
480+
provider_model_credentials = session.scalars(stmt)
481+
for provider_model_credential in provider_model_credentials:
482+
provider_name_to_provider_model_credentials_dict[provider_model_credential.provider_name].append(
483+
provider_model_credential
484+
)
484485
return provider_name_to_provider_model_credentials_dict
485486

486487
@staticmethod
487-
def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]:
488+
def _get_all_provider_load_balancing_configs(
489+
session: Session, tenant_id: str
490+
) -> dict[str, list[LoadBalancingModelConfig]]:
488491
"""
489492
Get All provider load balancing configs of the workspace.
490493
@@ -504,13 +507,12 @@ def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[L
504507
return {}
505508

506509
provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
507-
with Session(db.engine, expire_on_commit=False) as session:
508-
stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id)
509-
provider_load_balancing_configs = session.scalars(stmt)
510-
for provider_load_balancing_config in provider_load_balancing_configs:
511-
provider_name_to_provider_load_balancing_model_configs_dict[
512-
provider_load_balancing_config.provider_name
513-
].append(provider_load_balancing_config)
510+
stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id)
511+
provider_load_balancing_configs = session.scalars(stmt)
512+
for provider_load_balancing_config in provider_load_balancing_configs:
513+
provider_name_to_provider_load_balancing_model_configs_dict[
514+
provider_load_balancing_config.provider_name
515+
].append(provider_load_balancing_config)
514516

515517
return provider_name_to_provider_load_balancing_model_configs_dict
516518

0 commit comments

Comments
 (0)