77 changes: 52 additions & 25 deletions src/mmgp/offload.py
@@ -106,6 +106,8 @@ class clock:
def __init__(self):
self.start_time = 0
self.end_time = 0
self.parameters_ref = {}
self.cotenants_map = {}

@classmethod
def start(cls):
@@ -885,7 +887,7 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 2*
for n, p in module.named_parameters(recurse = False):
if tied_w != None and n == tied_w[0]:
if isinstance( named_modules[tied_w[1]], QModuleMixin) :
setattr(module, n, None) # release refs to tied weights if source is going to be quantized
# otherwise don't force load as it will be loaded in the source anyway
else:
_force_load_parameter(p)
@@ -1630,32 +1632,38 @@ def detach_hook(self, module):
last_offload_obj = None
class offload:
def __init__(self):
self.active_models = []
self.active_models_ids = []
self.active_subcaches = {}
self.models = {}
self.cotenants_map = {
"text_encoder": ["vae", "text_encoder_2"],
"text_encoder_2": ["vae", "text_encoder"],
}
self.verboseLevel = 0
self.blocks_of_modules = {}
self.blocks_of_modules_sizes = {}
self.anyCompiledModule = False
self.device_mem_capacity = 0
self.last_reserved_mem_check = 0
self.default_stream = None
self.transfer_stream = None
if torch.cuda.is_available():
self.device_mem_capacity = torch.cuda.get_device_properties(0).total_memory
self.last_reserved_mem_check = time.time()
self.default_stream = torch.cuda.default_stream(torch.device("cuda"))
self.transfer_stream = torch.cuda.Stream()
else:
import psutil
self.device_mem_capacity = psutil.virtual_memory().total
self.last_reserved_mem_check = 0
self.default_stream = None
self.transfer_stream = None
self.loaded_blocks = {}
self.prev_blocks_names = {}
self.next_blocks_names = {}
self.preloaded_blocks_per_model = {}
self.async_transfers = False
self.parameters_ref = {}
self.max_reservable_memory = 0

global last_offload_obj
last_offload_obj = self
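
The fallback added to __init__ reduces to one decision: take memory capacity and streams from the CUDA device when one exists, otherwise fall back to system RAM and null streams. A minimal standalone sketch of that pattern (the function name is hypothetical; it assumes psutil is installed, as the patch itself does):

import torch

def detect_memory_capacity() -> int:
    # Prefer CUDA device memory when a GPU is present; otherwise report
    # total system RAM, mirroring the patched offload.__init__.
    if torch.cuda.is_available():
        return torch.cuda.get_device_properties(0).total_memory
    import psutil  # same lazily-imported fallback dependency as the patch
    return psutil.virtual_memory().total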


def add_module_to_blocks(self, model_id, blocks_name, submodule, prev_block_name, submodule_name):

@@ -1743,8 +1751,10 @@ def _move_loras(self, loras_active_adapters, loras_modules, to_GPU):

@torch.compiler.disable()
def gpu_load_blocks(self, model_id, blocks_name, preload = False):

if not torch.cuda.is_available():
if self.verboseLevel >= 1:
print("CUDA is not available. Skipping gpu_load_blocks.")
return

entry_name = model_id if blocks_name is None else model_id + "/" + blocks_name

@@ -1831,7 +1841,11 @@ def cpu_to_gpu(stream_to_use, blocks_params): #, record_for_stream = None

@torch.compiler.disable()
def gpu_unload_blocks(self, model_id, blocks_name):
if not torch.cuda.is_available():
if self.verboseLevel >= 1:
print("CUDA is not available. Skipping gpu_unload_blocks.")
return

if blocks_name != None and blocks_name == self.loaded_blocks[model_id]:
self.loaded_blocks[model_id] = None

@@ -1873,6 +1887,10 @@ def gpu_unload_blocks(self, model_id, blocks_name):

# @torch.compiler.disable()
def gpu_load(self, model_id):
if not torch.cuda.is_available():
if self.verboseLevel >= 1:
print("CUDA is not available. Skipping gpu_load.")
return
model = self.models[model_id]
self.active_models.append(model)
self.active_models_ids.append(model_id)
@@ -1881,6 +1899,11 @@ def gpu_load(self, model_id):
self.gpu_load_blocks(model_id, block_name, True)
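
The same availability guard now opens gpu_load_blocks, gpu_unload_blocks, gpu_load and unload_all. If it spreads further, it could be factored into a decorator; a hypothetical sketch, not part of this PR:

import functools
import torch

def requires_cuda(method):
    # Hypothetical helper: make a method a verbose no-op on CPU-only
    # machines, replicating the guard repeated in the methods above.
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if not torch.cuda.is_available():
            if self.verboseLevel >= 1:
                print(f"CUDA is not available. Skipping {method.__name__}.")
            return None
        return method(self, *args, **kwargs)
    return wrapper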

def unload_all(self):
if not torch.cuda.is_available():
if self.verboseLevel >= 1:
print("CUDA is not available. Skipping unload_all.")
return

for model_id in self.active_models_ids:
self.gpu_unload_blocks(model_id, None)
for block_name in self.preloaded_blocks_per_model[model_id]:
@@ -1904,6 +1927,9 @@ def unload_all(self):
self.last_reserved_mem_check = time.time()

def move_args_to_gpu(self, dtype, *args, **kwargs):
if not torch.cuda.is_available():
return args, kwargs

new_args= []
new_kwargs={}

@@ -1926,6 +1952,8 @@ def move_args_to_gpu(self, dtype, *args, **kwargs):
return new_args, new_kwargs

def ready_to_check_mem(self):
if not torch.cuda.is_available():
return False
if self.anyCompiledModule:
return
cur_clock = time.time()
@@ -1937,6 +1965,8 @@ def ready_to_check_mem(self):


def empty_cache_if_needed(self):
if not torch.cuda.is_available():
return
mem_reserved = torch.cuda.memory_reserved()
mem_threshold = 0.9*self.device_mem_capacity
if mem_reserved >= mem_threshold:
@@ -2301,8 +2331,6 @@ def release(self):
torch.cuda.empty_cache()




def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, partialPinning = False, loras = None, quantizeTransformer = True, extraModelsToQuantize = None, quantizationType = qint8, budgets= 0, workingVRAM = None, asyncTransfers = True, compile = False, convertWeightsFloatTo = torch.bfloat16, perc_reserved_mem_max = 0, coTenantsMap = None, verboseLevel = -1):
"""Hook to a pipeline or a group of modules in order to reduce their VRAM requirements:
pipe_or_dict_of_modules : the pipeline object or a dictionary of modules of the model
@@ -2361,8 +2389,6 @@ def get_parsed_budget(b):
self.models = models

extraModelsToQuantize = extraModelsToQuantize if extraModelsToQuantize is not None else []
if not isinstance(extraModelsToQuantize, list):
extraModelsToQuantize = [extraModelsToQuantize]
if quantizeTransformer:
extraModelsToQuantize.append("transformer")
models_to_quantize = extraModelsToQuantize
@@ -2580,9 +2606,9 @@ def print_size_range(n,start_num,prev_num, prev_size ):
if prev_num < 0:
print(f"Size of submodel '{n}': {prev_size/ONE_MB:.1f} MB")
elif prev_num - start_num <=1:
print(f"Size of submodel '{n+ str(start_num)}': {prev_size/ONE_MB:.1f} MB")
print(f"Size of submodel '{n+str(start_num)}': {prev_size/ONE_MB:.1f} MB")
else:
print(f"Size of submodel '{n+ str(start_num) +'-'+ str(prev_num)}': {(prev_num-start_num+1)*prev_size/ONE_MB:.1f} MB ({prev_size/ONE_MB:.1f} MB x {prev_num-start_num+1})")
print(f"Size of submodel '{n+str(start_num) +'-'+ str(prev_num)}': {(prev_num-start_num+1)*prev_size/ONE_MB:.1f} MB ({prev_size/ONE_MB:.1f} MB x {prev_num-start_num+1})")

for n, size in self.blocks_of_modules_sizes.items():
size = int(size / 10000)* 10000
@@ -2597,8 +2623,9 @@ def print_size_range(n,start_num,prev_num, prev_size ):
print_size_range(prev_pre,start_num,prev_num, prev_size )


if torch.cuda.is_available():
    torch.set_default_device('cuda')
    torch.cuda.empty_cache()
gc.collect()

return self
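
For context, a typical call into the patched entry point. A usage sketch in which the diffusers pipeline and model name are illustrative rather than taken from this PR:

from diffusers import StableDiffusionPipeline  # any pipe exposing submodules as attributes works
from mmgp import offload

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# With this change the call below also runs on a CPU-only machine: the GPU
# load/unload hooks become no-ops and budgeting falls back to system RAM.
offload.all(pipe, quantizeTransformer=True, verboseLevel=1)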