Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 78 additions & 22 deletions k_diffusion/external.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,75 @@
import math

import torch
from torch import nn
from torch import nn, Tensor
from typing import Protocol, Generic, TypeVar, TYPE_CHECKING

from . import sampling, utils

class AbstractModel(Protocol):
def __call__(self, *args, **kwargs) -> Tensor: ...

class DenoiserModel(AbstractModel):
def __call__(self, x: Tensor, t: Tensor, *args, **kwargs) -> Tensor: ...

class CompVisModel(AbstractModel):
alphas_cumprod: Tensor
def apply_model(self, x: Tensor, t: Tensor, cond: Tensor) -> Tensor: ...

class WrappedModelProto(AbstractModel):
def sigma_to_t(self, sigma: Tensor) -> Tensor: ...
def t_to_sigma(self, t: Tensor) -> Tensor: ...
def discretize_sigma(self, sigma: Tensor) -> Tensor: ...

# the 'default' arg of TypeVar isn't valid at runtime, but amazingly seems to be utilised
# at compile-time by some type-checkers
# https://github.com/python/mypy/issues/4236#issuecomment-344660299
if TYPE_CHECKING:
TModel = TypeVar('TModel', bound=AbstractModel, default=AbstractModel)
TDenoiserModel = TypeVar('TDenoiserModel', bound=DenoiserModel, default=DenoiserModel)
TCompVisModel = TypeVar('TCompVisModel', bound=CompVisModel, default=CompVisModel)
else:
TModel = TypeVar('TModel', bound=AbstractModel)
TDenoiserModel = TypeVar('TDenoiserModel', bound=DenoiserModel)
TCompVisModel = TypeVar('TCompVisModel', bound=CompVisModel)

class BaseModelWrapper(nn.Module, WrappedModelProto, Generic[TModel]):
inner_model: TModel

"""The base wrapper class for the k-diffusion model wrapper idiom. Model
wrappers should subclass this class and customize the behavior of the
wrapped model by implementing or overriding methods."""
def __init__(self, model: TModel):
super().__init__()
self.inner_model = model

def __dir__(self):
return list(set(super().__dir__() + dir(self.inner_model)))

class VDenoiser(nn.Module):
def __getattr__(self, name):
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self.inner_model, name)

def forward(self, *args, **kwargs):
return self.inner_model(*args, **kwargs)

def sigma_to_t(self, sigma: Tensor) -> Tensor:
return sigma

def t_to_sigma(self, t: Tensor) -> Tensor:
return t

def discretize_sigma(self, sigma: Tensor) -> Tensor:
return self.t_to_sigma(self.sigma_to_t(sigma))


class VDenoiser(BaseModelWrapper[TDenoiserModel]):
"""A v-diffusion-pytorch model wrapper for k-diffusion."""

def __init__(self, inner_model):
super().__init__()
self.inner_model = inner_model
def __init__(self, model: TDenoiserModel):
super().__init__(model)
self.sigma_data = 1.

def get_scalings(self, sigma):
Expand All @@ -38,12 +96,12 @@ def forward(self, input, sigma, **kwargs):
return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip


class DiscreteSchedule(nn.Module):
class DiscreteSchedule(BaseModelWrapper[TModel]):
"""A mapping between continuous noise levels (sigmas) and a list of discrete noise
levels."""

def __init__(self, sigmas, quantize):
super().__init__()
def __init__(self, sigmas, quantize, model: TModel):
super().__init__(model)
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())
self.quantize = quantize
Expand Down Expand Up @@ -86,13 +144,12 @@ def t_to_sigma(self, t):
return log_sigma.exp()


class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
class DiscreteEpsDDPMDenoiser(DiscreteSchedule[TModel]):
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
noise)."""

def __init__(self, model, alphas_cumprod, quantize):
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
self.inner_model = model
def __init__(self, model: TModel, alphas_cumprod, quantize):
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize, model)
self.sigma_data = 1.

def get_scalings(self, sigma):
Expand All @@ -115,10 +172,10 @@ def forward(self, input, sigma, **kwargs):
return input + eps * c_out


class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
class OpenAIDenoiser(DiscreteEpsDDPMDenoiser[TModel]):
"""A wrapper for OpenAI diffusion models."""

def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
def __init__(self, model: TModel, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32)
super().__init__(model, alphas_cumprod, quantize=quantize)
self.has_learned_sigmas = has_learned_sigmas
Expand All @@ -130,22 +187,21 @@ def get_eps(self, *args, **kwargs):
return model_output


class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
class CompVisDenoiser(DiscreteEpsDDPMDenoiser[TCompVisModel]):
"""A wrapper for CompVis diffusion models."""

def __init__(self, model, quantize=False, device='cpu'):
def __init__(self, model: TCompVisModel, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize)

def get_eps(self, *args, **kwargs):
return self.inner_model.apply_model(*args, **kwargs)


class DiscreteVDDPMDenoiser(DiscreteSchedule):
class DiscreteVDDPMDenoiser(DiscreteSchedule[TModel]):
"""A wrapper for discrete schedule DDPM models that output v."""

def __init__(self, model, alphas_cumprod, quantize):
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
self.inner_model = model
def __init__(self, model: TModel, alphas_cumprod, quantize):
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize, model)
self.sigma_data = 1.

def get_scalings(self, sigma):
Expand All @@ -169,10 +225,10 @@ def forward(self, input, sigma, **kwargs):
return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip


class CompVisVDenoiser(DiscreteVDDPMDenoiser):
class CompVisVDenoiser(DiscreteVDDPMDenoiser[TCompVisModel]):
"""A wrapper for CompVis diffusion models that output v."""

def __init__(self, model, quantize=False, device='cpu'):
def __init__(self, model: TCompVisModel, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize)

def get_v(self, x, t, cond, **kwargs):
Expand Down
31 changes: 18 additions & 13 deletions k_diffusion/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tqdm.auto import trange, tqdm

from . import utils
from .external import WrappedModelProto


def append_zero(x):
Expand Down Expand Up @@ -115,14 +116,15 @@ def __call__(self, sigma, sigma_next):


@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
def sample_euler(model: WrappedModelProto, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
sigma_hat = model.discretize_sigma(sigma_hat)
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
Expand All @@ -136,7 +138,7 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,


@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
def sample_euler_ancestral(model: WrappedModelProto, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
Expand All @@ -156,14 +158,15 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis


@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
def sample_heun(model: WrappedModelProto, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
sigma_hat = model.discretize_sigma(sigma_hat)
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
Expand All @@ -185,14 +188,15 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None,


@torch.no_grad()
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
def sample_dpm_2(model: WrappedModelProto, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
sigma_hat = model.discretize_sigma(sigma_hat)
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
Expand All @@ -216,7 +220,7 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None,


@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
def sample_dpm_2_ancestral(model: WrappedModelProto, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
Expand Down Expand Up @@ -258,7 +262,7 @@ def fn(tau):


@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
def sample_lms(model: WrappedModelProto, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigmas_cpu = sigmas.detach().cpu().numpy()
Expand All @@ -278,7 +282,7 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o


@torch.no_grad()
def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
def log_likelihood(model: WrappedModelProto, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
v = torch.randint_like(x, 2) * 2 - 1
Expand Down Expand Up @@ -332,8 +336,9 @@ def propose_step(self, error):

class DPMSolver(nn.Module):
"""DPM-Solver. See https://arxiv.org/abs/2206.00927."""
model: WrappedModelProto

def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None):
def __init__(self, model: WrappedModelProto, extra_args=None, eps_callback=None, info_callback=None):
super().__init__()
self.model = model
self.extra_args = {} if extra_args is None else extra_args
Expand Down Expand Up @@ -479,7 +484,7 @@ def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078


@torch.no_grad()
def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
def sample_dpm_fast(model: WrappedModelProto, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
"""DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927."""
if sigma_min <= 0 or sigma_max <= 0:
raise ValueError('sigma_min and sigma_max must not be 0')
Expand All @@ -491,7 +496,7 @@ def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback


@torch.no_grad()
def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
def sample_dpm_adaptive(model: WrappedModelProto, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
"""DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927."""
if sigma_min <= 0 or sigma_max <= 0:
raise ValueError('sigma_min and sigma_max must not be 0')
Expand All @@ -506,7 +511,7 @@ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callbac


@torch.no_grad()
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
def sample_dpmpp_2s_ancestral(model: WrappedModelProto, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
Expand Down Expand Up @@ -540,7 +545,7 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,


@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
def sample_dpmpp_sde(model: WrappedModelProto, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
"""DPM-Solver++ (stochastic)."""
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
Expand Down Expand Up @@ -582,7 +587,7 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N


@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None, warmup_lms=False):
def sample_dpmpp_2m(model: WrappedModelProto, x, sigmas, extra_args=None, callback=None, disable=None, warmup_lms=False):
"""DPM-Solver++(2M)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
Expand Down