@@ -100,8 +100,9 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
100100 :param tenant_id:
101101 :return:
102102 """
103+ session = db .session
103104 # Get all provider records of the workspace
104- provider_name_to_provider_records_dict = self ._get_all_providers (tenant_id )
105+ provider_name_to_provider_records_dict = self ._get_all_providers (session , tenant_id )
105106
106107 # Initialize trial provider records if not exist
107108 provider_name_to_provider_records_dict = self ._init_trial_provider_records (
@@ -118,7 +119,7 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
118119 ]
119120
120121 # Get all provider model records of the workspace
121- provider_name_to_provider_model_records_dict = self ._get_all_provider_models (tenant_id )
122+ provider_name_to_provider_model_records_dict = self ._get_all_provider_models (session , tenant_id )
122123 for provider_name in list (provider_name_to_provider_model_records_dict .keys ()):
123124 provider_id = ModelProviderID (provider_name )
124125 if str (provider_id ) not in provider_name_to_provider_model_records_dict :
@@ -131,7 +132,9 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
131132 provider_entities = model_provider_factory .get_providers ()
132133
133134 # Get All preferred provider types of the workspace
134- provider_name_to_preferred_model_provider_records_dict = self ._get_all_preferred_model_providers (tenant_id )
135+ provider_name_to_preferred_model_provider_records_dict = self ._get_all_preferred_model_providers (
136+ session , tenant_id
137+ )
135138 # Ensure that both the original provider name and its ModelProviderID string representation
136139 # are present in the dictionary to handle cases where either form might be used
137140 for provider_name in list (provider_name_to_preferred_model_provider_records_dict .keys ()):
@@ -143,15 +146,15 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
143146 )
144147
145148 # Get All provider model settings
146- provider_name_to_provider_model_settings_dict = self ._get_all_provider_model_settings (tenant_id )
149+ provider_name_to_provider_model_settings_dict = self ._get_all_provider_model_settings (session , tenant_id )
147150
148151 # Get All load balancing configs
149152 provider_name_to_provider_load_balancing_model_configs_dict = self ._get_all_provider_load_balancing_configs (
150- tenant_id
153+ session , tenant_id
151154 )
152155
153156 # Get All provider model credentials
154- provider_name_to_provider_model_credentials_dict = self ._get_all_provider_model_credentials (tenant_id )
157+ provider_name_to_provider_model_credentials_dict = self ._get_all_provider_model_credentials (session , tenant_id )
155158
156159 provider_configurations = ProviderConfigurations (tenant_id = tenant_id )
157160
@@ -403,88 +406,88 @@ def update_default_model_record(
403406 return default_model
404407
405408 @staticmethod
406- def _get_all_providers (tenant_id : str ) -> dict [str , list [Provider ]]:
409+ def _get_all_providers (session : Session , tenant_id : str ) -> dict [str , list [Provider ]]:
407410 provider_name_to_provider_records_dict = defaultdict (list )
408- with Session (db .engine , expire_on_commit = False ) as session :
409- stmt = select (Provider ).where (Provider .tenant_id == tenant_id , Provider .is_valid == True )
410- providers = session .scalars (stmt )
411- for provider in providers :
412- # Use provider name with prefix after the data migration
413- provider_name_to_provider_records_dict [str (ModelProviderID (provider .provider_name ))].append (provider )
411+ stmt = select (Provider ).where (Provider .tenant_id == tenant_id , Provider .is_valid == True )
412+ providers = session .scalars (stmt )
413+ for provider in providers :
414+ # Use provider name with prefix after the data migration
415+ provider_name_to_provider_records_dict [str (ModelProviderID (provider .provider_name ))].append (provider )
414416 return provider_name_to_provider_records_dict
415417
416418 @staticmethod
417- def _get_all_provider_models (tenant_id : str ) -> dict [str , list [ProviderModel ]]:
419+ def _get_all_provider_models (session : Session , tenant_id : str ) -> dict [str , list [ProviderModel ]]:
418420 """
419421 Get all provider model records of the workspace.
420422
421423 :param tenant_id: workspace id
422424 :return:
423425 """
424426 provider_name_to_provider_model_records_dict = defaultdict (list )
425- with Session (db .engine , expire_on_commit = False ) as session :
426- stmt = select (ProviderModel ).where (ProviderModel .tenant_id == tenant_id , ProviderModel .is_valid == True )
427- provider_models = session .scalars (stmt )
428- for provider_model in provider_models :
429- provider_name_to_provider_model_records_dict [provider_model .provider_name ].append (provider_model )
427+ stmt = select (ProviderModel ).where (ProviderModel .tenant_id == tenant_id , ProviderModel .is_valid == True )
428+ provider_models = session .scalars (stmt )
429+ for provider_model in provider_models :
430+ provider_name_to_provider_model_records_dict [provider_model .provider_name ].append (provider_model )
430431 return provider_name_to_provider_model_records_dict
431432
432433 @staticmethod
433- def _get_all_preferred_model_providers (tenant_id : str ) -> dict [str , TenantPreferredModelProvider ]:
434+ def _get_all_preferred_model_providers (session : Session , tenant_id : str ) -> dict [str , TenantPreferredModelProvider ]:
434435 """
435436 Get All preferred provider types of the workspace.
436437
437438 :param tenant_id: workspace id
438439 :return:
439440 """
440441 provider_name_to_preferred_provider_type_records_dict = {}
441- with Session (db .engine , expire_on_commit = False ) as session :
442- stmt = select (TenantPreferredModelProvider ).where (TenantPreferredModelProvider .tenant_id == tenant_id )
443- preferred_provider_types = session .scalars (stmt )
444- provider_name_to_preferred_provider_type_records_dict = {
445- preferred_provider_type .provider_name : preferred_provider_type
446- for preferred_provider_type in preferred_provider_types
447- }
442+ stmt = select (TenantPreferredModelProvider ).where (TenantPreferredModelProvider .tenant_id == tenant_id )
443+ preferred_provider_types = session .scalars (stmt )
444+ provider_name_to_preferred_provider_type_records_dict = {
445+ preferred_provider_type .provider_name : preferred_provider_type
446+ for preferred_provider_type in preferred_provider_types
447+ }
448448 return provider_name_to_preferred_provider_type_records_dict
449449
450450 @staticmethod
451- def _get_all_provider_model_settings (tenant_id : str ) -> dict [str , list [ProviderModelSetting ]]:
451+ def _get_all_provider_model_settings (session : Session , tenant_id : str ) -> dict [str , list [ProviderModelSetting ]]:
452452 """
453453 Get All provider model settings of the workspace.
454454
455455 :param tenant_id: workspace id
456456 :return:
457457 """
458458 provider_name_to_provider_model_settings_dict = defaultdict (list )
459- with Session ( db . engine , expire_on_commit = False ) as session :
460- stmt = select ( ProviderModelSetting ). where ( ProviderModelSetting . tenant_id == tenant_id )
461- provider_model_settings = session . scalars ( stmt )
462- for provider_model_setting in provider_model_settings :
463- provider_name_to_provider_model_settings_dict [ provider_model_setting . provider_name ]. append (
464- provider_model_setting
465- )
459+ stmt = select ( ProviderModelSetting ). where ( ProviderModelSetting . tenant_id == tenant_id )
460+ provider_model_settings = session . scalars ( stmt )
461+ for provider_model_setting in provider_model_settings :
462+ provider_name_to_provider_model_settings_dict [ provider_model_setting . provider_name ]. append (
463+ provider_model_setting
464+ )
465+
466466 return provider_name_to_provider_model_settings_dict
467467
468468 @staticmethod
469- def _get_all_provider_model_credentials (tenant_id : str ) -> dict [str , list [ProviderModelCredential ]]:
469+ def _get_all_provider_model_credentials (
470+ session : Session , tenant_id : str
471+ ) -> dict [str , list [ProviderModelCredential ]]:
470472 """
471473 Get All provider model credentials of the workspace.
472474
473475 :param tenant_id: workspace id
474476 :return:
475477 """
476478 provider_name_to_provider_model_credentials_dict = defaultdict (list )
477- with Session (db .engine , expire_on_commit = False ) as session :
478- stmt = select (ProviderModelCredential ).where (ProviderModelCredential .tenant_id == tenant_id )
479- provider_model_credentials = session .scalars (stmt )
480- for provider_model_credential in provider_model_credentials :
481- provider_name_to_provider_model_credentials_dict [provider_model_credential .provider_name ].append (
482- provider_model_credential
483- )
479+ stmt = select (ProviderModelCredential ).where (ProviderModelCredential .tenant_id == tenant_id )
480+ provider_model_credentials = session .scalars (stmt )
481+ for provider_model_credential in provider_model_credentials :
482+ provider_name_to_provider_model_credentials_dict [provider_model_credential .provider_name ].append (
483+ provider_model_credential
484+ )
484485 return provider_name_to_provider_model_credentials_dict
485486
486487 @staticmethod
487- def _get_all_provider_load_balancing_configs (tenant_id : str ) -> dict [str , list [LoadBalancingModelConfig ]]:
488+ def _get_all_provider_load_balancing_configs (
489+ session : Session , tenant_id : str
490+ ) -> dict [str , list [LoadBalancingModelConfig ]]:
488491 """
489492 Get All provider load balancing configs of the workspace.
490493
@@ -504,13 +507,12 @@ def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[L
504507 return {}
505508
506509 provider_name_to_provider_load_balancing_model_configs_dict = defaultdict (list )
507- with Session (db .engine , expire_on_commit = False ) as session :
508- stmt = select (LoadBalancingModelConfig ).where (LoadBalancingModelConfig .tenant_id == tenant_id )
509- provider_load_balancing_configs = session .scalars (stmt )
510- for provider_load_balancing_config in provider_load_balancing_configs :
511- provider_name_to_provider_load_balancing_model_configs_dict [
512- provider_load_balancing_config .provider_name
513- ].append (provider_load_balancing_config )
510+ stmt = select (LoadBalancingModelConfig ).where (LoadBalancingModelConfig .tenant_id == tenant_id )
511+ provider_load_balancing_configs = session .scalars (stmt )
512+ for provider_load_balancing_config in provider_load_balancing_configs :
513+ provider_name_to_provider_load_balancing_model_configs_dict [
514+ provider_load_balancing_config .provider_name
515+ ].append (provider_load_balancing_config )
514516
515517 return provider_name_to_provider_load_balancing_model_configs_dict
516518
0 commit comments