From ddc55cf8abeb858fe695390667e0814a977f0cb0 Mon Sep 17 00:00:00 2001 From: AsyncLee Date: Thu, 20 Oct 2022 01:21:52 +0900 Subject: [PATCH] Fix VRAM leak when loading model With the previous code, both pl_sd itself and the state_dict extracted from pl_sd were kept in VRAM, so the program consumed more graphics memory than necessary. Load the checkpoint onto the CPU instead, so only the model itself occupies VRAM. --- scripts/image_variations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/image_variations.py b/scripts/image_variations.py index 28bcee42..6c143ba4 100644 --- a/scripts/image_variations.py +++ b/scripts/image_variations.py @@ -19,7 +19,7 @@ def load_model_from_config(config, ckpt, device, verbose=False): print(f"Loading model from {ckpt}") - pl_sd = torch.load(ckpt, map_location=device) + pl_sd = torch.load(ckpt, map_location='cpu') if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") sd = pl_sd["state_dict"]