diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index fb31215db5..d628a87ffa 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -18,8 +18,10 @@ def __init__(self, model, schedule="linear", **kwargs): def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + # if attr.device != torch.device("cuda"): + # attr = attr.to(torch.device("cuda")) + if attr.device != torch.device("mps"): + attr = attr.to(torch.float32).to(torch.device("mps")).contiguous() setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): diff --git a/scripts/img2img.py b/scripts/img2img.py index 421e2151d9..e136a715ca 100644 --- a/scripts/img2img.py +++ b/scripts/img2img.py @@ -40,7 +40,8 @@ def load_model_from_config(config, ckpt, verbose=False): print("unexpected keys:") print(u) - model.cuda() + # model.cuda() + model.to("mps") model.eval() return model @@ -199,7 +200,8 @@ def main(): config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu") model = model.to(device) if opt.plms: @@ -242,7 +244,8 @@ def main(): precision_scope = autocast if opt.precision == "autocast" else nullcontext with torch.no_grad(): - with precision_scope("cuda"): + # with precision_scope("cuda"): + with nullcontext("mps"): with model.ema_scope(): tic = time.time() all_samples = list()