@@ -567,9 +567,22 @@ def LoadModel(self, request, context):
             torchType=torchType,
             variant=variant
         )
-
+
         print(f"LoadModel: After loading - ltx2_pipeline: {self.ltx2_pipeline}, img2vid: {self.img2vid}, txt2vid: {self.txt2vid}, PipelineType: {self.PipelineType}", file=sys.stderr)

+        # Initialize LoRA hotswap support
+        if hasattr(self.pipe, 'enable_lora_hotswap'):
+            try:
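+                # target_rank must cover the largest LoRA rank that will be swapped in later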
+                self.pipe.enable_lora_hotswap(target_rank=128)
+                print("LoRA hotswap enabled", file=sys.stderr)
+            except Exception as e:
+                print(f"Warning: Failed to enable LoRA hotswap: {e}", file=sys.stderr)
+
+        # Initialize LoRA management
+        self._loaded_loras = {}  # {adapter_name: path}
+        self._lora_counter = 0
+
         if CLIPSKIP and request.CLIPSkip != 0:
             self.clip_skip = request.CLIPSkip
         else:
@@ -607,22 +619,30 @@ def LoadModel(self, request, context):
         if mps_available:
             device = "mps"
         self.device = device
+
+        # Load static LoRAs from config (backward compatibility)
         if request.LoraAdapter:
             # Check if it's a local file and not a directory (we load LoRA differently for a safetensors file)
             if os.path.exists(request.LoraAdapter) and not os.path.isdir(request.LoraAdapter):
-                self.pipe.load_lora_weights(request.LoraAdapter)
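+                # Track the adapter in the shared registry so later hotswaps can manage it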
+                adapter_name = f"static_lora_{self._lora_counter}"
+                self.pipe.load_lora_weights(request.LoraAdapter, adapter_name=adapter_name)
+                self._loaded_loras[adapter_name] = request.LoraAdapter
+                self._lora_counter += 1
             else:
                 self.pipe.unet.load_attn_procs(request.LoraAdapter)
+
         if len(request.LoraAdapters) > 0:
-            i = 0
             adapters_name = []
             adapters_weights = []
             for adapter in request.LoraAdapters:
                 if not os.path.isabs(adapter):
                     adapter = os.path.join(request.ModelPath, adapter)
-                self.pipe.load_lora_weights(adapter, adapter_name=f"adapter_{i}")
-                adapters_name.append(f"adapter_{i}")
-                i += 1
+                adapter_name = f"static_lora_{self._lora_counter}"
+                self.pipe.load_lora_weights(adapter, adapter_name=adapter_name)
+                self._loaded_loras[adapter_name] = adapter
+                adapters_name.append(adapter_name)
+                self._lora_counter += 1

             for adapters_weight in request.LoraScales:
                 adapters_weights.append(adapters_weight)
@@ -697,7 +716,119 @@ def load_lora_weights(self, checkpoint_path, multiplier, device, dtype):
             else:
                 curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)

+    def _hotswap_loras(self, requested_adapters, requested_scales):
+        """
+        Hot-swap LoRA adapters without reloading the base model.
+
+        Args:
+            requested_adapters: List of LoRA file paths
+            requested_scales: List of LoRA scales
+        """
+        print(f"LoRA hotswap: requested {len(requested_adapters)} adapters", file=sys.stderr)
+
+        # Resolve relative paths
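+        # NOTE: the base paths below are deployment-specific (assumes LocalAI's /build/models layout)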
+        resolved_adapters = []
+        for adapter in requested_adapters:
+            if not os.path.isabs(adapter):
+                # Try different base paths
+                if adapter.startswith("loras/"):
+                    # Check if it's a semantic name (e.g., "yudaiqiao")
+                    semantic_path = f"/build/models/loras/yiheyuan/{os.path.basename(adapter)}.safetensors"
+                    if os.path.exists(semantic_path):
+                        resolved_adapters.append(semantic_path)
+                    else:
+                        resolved_adapters.append(f"/build/models/{adapter}")
+                else:
+                    resolved_adapters.append(f"/build/models/loras/yiheyuan/{adapter}.safetensors")
+            else:
+                resolved_adapters.append(adapter)
+
+        # Validate all adapters exist
+        for adapter in resolved_adapters:
+            if not os.path.exists(adapter):
+                raise FileNotFoundError(f"LoRA adapter not found: {adapter}")
+
+        # Diff the requested paths against what is already loaded (registry maps adapter name -> path)
+        requested_paths = set(resolved_adapters)
+        current_paths = set(self._loaded_loras.values())
+        to_load = requested_paths - current_paths
+        to_unload = current_paths - requested_paths
+
+        print(f"LoRA hotswap: loading {len(to_load)}, unloading {len(to_unload)}", file=sys.stderr)
+
+        # Unload unused adapters
+        adapters_to_remove = []
+        for adapter_name, adapter_path in self._loaded_loras.items():
+            if adapter_path in to_unload:
+                adapters_to_remove.append(adapter_name)
+
+        if adapters_to_remove:
+            try:
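+                # delete_adapters drops the adapter weights from the pipeline's modules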
+                self.pipe.delete_adapters(adapters_to_remove)
+                for adapter_name in adapters_to_remove:
+                    del self._loaded_loras[adapter_name]
+                print(f"Unloaded LoRA adapters: {adapters_to_remove}", file=sys.stderr)
+            except Exception as e:
+                print(f"Warning: Failed to unload some adapters: {e}", file=sys.stderr)
+
+        # Load new adapters
+        for adapter_path in to_load:
+            adapter_name = f"dynamic_lora_{self._lora_counter}"
+            try:
+                self.pipe.load_lora_weights(adapter_path, adapter_name=adapter_name)
+                self._loaded_loras[adapter_name] = adapter_path
+                self._lora_counter += 1
+                print(f"Loaded LoRA adapter: {adapter_name} -> {adapter_path}", file=sys.stderr)
+            except Exception as e:
+                print(f"Error loading LoRA {adapter_path}: {e}", file=sys.stderr)
+                raise
+
+        # Activate requested adapters with their scales
+        active_names = []
+        active_weights = []
+
+        for i, adapter_path in enumerate(resolved_adapters):
+            # Find the adapter name for this path
+            adapter_name = None
+            for name, path in self._loaded_loras.items():
+                if path == adapter_path:
+                    adapter_name = name
+                    break
+
+            if adapter_name:
+                active_names.append(adapter_name)
+                scale = requested_scales[i] if i < len(requested_scales) else 1.0
+                active_weights.append(scale)
+
+        # Set active adapters
+        if active_names:
+            try:
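+                # set_adapters activates exactly this set; loaded-but-unlisted adapters stay inactive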
+                self.pipe.set_adapters(active_names, adapter_weights=active_weights)
+                print(f"Activated LoRA adapters: {active_names} with weights {active_weights}", file=sys.stderr)
+            except Exception as e:
+                print(f"Error setting adapters: {e}", file=sys.stderr)
+                raise
+        else:
+            # Disable all adapters if none requested
+            try:
+                if hasattr(self.pipe, 'disable_lora'):
+                    self.pipe.disable_lora()
+                print("Disabled all LoRA adapters", file=sys.stderr)
+            except Exception as e:
+                print(f"Warning: Failed to disable LoRA: {e}", file=sys.stderr)
+
     def GenerateImage(self, request, context):
+        # === Dynamic LoRA Hot-swapping ===
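+        # Note: an empty lora_adapters list skips hotswap, leaving previously activated adapters in place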
+        if request.lora_adapters:
+            try:
+                self._hotswap_loras(request.lora_adapters, request.lora_scales)
+            except Exception as e:
+                print(f"Error during LoRA hotswap: {e}", file=sys.stderr)
+                return backend_pb2.Result(success=False, message=f"LoRA hotswap error: {e}")

         prompt = request.positive_prompt
