Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions scripts/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def configure_arg_parser():
parser.add_argument("seeds", nargs="*", type=int, help="Random seeds")
parser.add_argument("-m", "--mode", type=str, default="ddpm", help="Sampling mode")
parser.add_argument("-f", "--freq", type=int, default=1, help="Sampling step frequency")
parser.add_argument("-s", "--guidance_strength", type=float, default=0.0, help="Guidance strength")
parser.add_argument("-i", "--dataset_dir", type=str, help="Input file for generation")
parser.add_argument("-o", "--output_dir", type=str, help="Output directory for sampling result")
parser.add_argument("-d", "--device_id", type=int, default=0, help="GPU device id")
Expand All @@ -34,6 +35,7 @@ def main(
seeds: list[int],
mode: str,
freq: int,
guidance_strength: float,
dataset_dir: str,
output_dir: str,
device_id: int,
Expand All @@ -52,7 +54,7 @@ def main(

device = f"cuda:{device_id}" if torch.cuda.is_available() else "cpu"

_, _, enc_dim, dec_dim = get_components(config.base_name)
_, _, enc_dim, dec_dim = get_components(config.base_name, config.encoder.pretrained, **config.decoder)
model = DiDi.load_from_checkpoint(model_path, enc_dim=enc_dim, dec_dim=dec_dim, map_location=device)
model.eval()

Expand All @@ -64,7 +66,7 @@ def main(
context.append(utterance)
joined_context, _ = preprocess(context, "")
raw_context = context_tokenizer(joined_context, **tokenizer_kwargs).to(device)
reply = sample(raw_context, model, mode, freq, context_tokenizer)[0]
reply = sample(raw_context, model, mode, freq, guidance_strength, context_tokenizer)[0]
context.append(reply)
print("DiDi:", reply)
except KeyboardInterrupt:
Expand Down
2 changes: 2 additions & 0 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ def main(config_path: str, dataset_dir: str, ckpt_dir: str = None, resume: str =
train_dataset.vocab_size,
config.encoder.freeze,
pad_idx=train_dataset.pad_idx,
bos_idx=train_dataset.bos_idx,
eos_idx=train_dataset.eos_idx,
batch_decoder=batch_decoder,
**config.didi,
)
Expand Down
8 changes: 8 additions & 0 deletions src/data/commonsense_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ def vocab_size(self) -> int:
def pad_idx(self) -> int:
return self.context_tokenizer.pad_token_id

@property
def bos_idx(self) -> int:
return self.context_tokenizer.bos_token_id or self.context_tokenizer.cls_token_id

@property
def eos_idx(self) -> int:
return self.context_tokenizer.eos_token_id or self.context_tokenizer.sep_token_id

def __iter__(self) -> Iterator[tuple[str, str]]:
n_epochs = 0
while True:
Expand Down
8 changes: 8 additions & 0 deletions src/data/reddit_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ def vocab_size(self) -> int:
def pad_idx(self) -> int:
return self.context_tokenizer.pad_token_id

@property
def bos_idx(self) -> int:
return self.context_tokenizer.bos_token_id or self.context_tokenizer.cls_token_id

@property
def eos_idx(self) -> int:
return self.context_tokenizer.eos_token_id or self.context_tokenizer.sep_token_id

def __iter__(self) -> Iterator[tuple[str, str]]:
for file in self.files:
zero_rank_info(f"Reading file: {file}")
Expand Down
42 changes: 40 additions & 2 deletions src/diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ def __init__(
schedule: str,
step_freq: int,
pad_idx: int,
bos_idx: int,
eos_idx: int,
context_dropout_prob: float = 0.0,
guidance_strength: float = 0.0,
tie_weights: bool = False,
lr: float = 0.0001,
weight_decay: float = 0.0,
Expand All @@ -91,11 +95,17 @@ def __init__(
super().__init__()
self.save_hyperparameters(ignore=[encoder, decoder])
self.diffusion_steps = diffusion_steps
self.pad_idx = pad_idx
self.step_freq = step_freq
self.encoder_dim = enc_dim
self.decoder_dim = dec_dim

self.dropout_prob = context_dropout_prob
self.w = guidance_strength

self.pad_idx = pad_idx
self.bos_idx = bos_idx
self.eos_idx = eos_idx

self.emb = nn.Embedding(vocabulary_size, dec_dim, padding_idx=pad_idx)
self.time_embeds = nn.Embedding(diffusion_steps + 1, dec_dim)

Expand Down Expand Up @@ -153,6 +163,22 @@ def _encode_context(self, encoder_input_ids, encoder_attention_mask):
context = self.adapter(context)
return context

def dropout_context(self, context, dropout_prob):
out_context = context.copy()
batch_size = context.input_ids.shape[0]

empty_context = torch.full_like(context.input_ids[0], self.pad_idx)
empty_mask = torch.zeros_like(context.attention_mask[0])
empty_context[0] = self.bos_idx
empty_context[1] = self.eos_idx
empty_mask[0] = 1
empty_mask[1] = 1

condition = torch.rand((batch_size, 1), device=context.input_ids.device) < dropout_prob
out_context["input_ids"] = torch.where(condition, empty_context, context.input_ids)
out_context["attention_mask"] = torch.where(condition, empty_mask, context.attention_mask)
return out_context

def forward(
self,
encoder_input_ids: torch.Tensor = None,
Expand Down Expand Up @@ -183,6 +209,10 @@ def forward(

def training_step(self, batch: list, batch_idx: int):
raw_context, target = batch

if self.dropout_prob:
raw_context = self.dropout_context(raw_context, self.dropout_prob)

emb = self.emb(target.input_ids)
x_0 = get_x0(emb, self.std_0)
noise = torch.randn_like(x_0)
Expand Down Expand Up @@ -224,7 +254,15 @@ def training_step(self, batch: list, batch_idx: int):
def validation_step(self, batch: list, batch_idx: int):
raw_context, target = batch
max_trg_len = target.input_ids.shape[1]
logits = sample(raw_context, self, self.sampling_mode, self.step_freq, max_len=max_trg_len, raw_output=True)
logits = sample(
raw_context,
self,
self.sampling_mode,
self.step_freq,
guidance_strength=self.w,
max_len=max_trg_len,
raw_output=True,
)
predictions = logits.argmax(-1)

self.val_ce.append(calculate_batch_ce(logits, target.input_ids, target.attention_mask).item())
Expand Down
58 changes: 43 additions & 15 deletions src/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,31 @@


@torch.no_grad()
def sample(raw_context, model, mode, step_freq, tokenizer=None, max_len=-1, raw_output=False, skip_special=True):
def sample(
raw_context,
model,
mode,
step_freq,
guidance_strength=0.0,
tokenizer=None,
max_len=-1,
raw_output=False,
skip_special=True,
):
input_ids = raw_context.input_ids
emb = model.emb(input_ids)[:, :max_len]

x_t = torch.randn_like(emb) * model.sigmas[-1]

cached_context = None
empty_cached_context = None
ones = torch.ones((emb.shape[0], 1), dtype=torch.long, device=emb.device)
noise = torch.empty_like(emb)

if mode == "ddpm":
logits = sample_ddpm(model, x_t, raw_context, cached_context, noise, ones, step_freq)
logits = sample_ddpm(
model, x_t, raw_context, cached_context, empty_cached_context, noise, ones, step_freq, guidance_strength
)
elif mode == "euler":
logits = sample_euler(model, x_t, raw_context, cached_context, noise, ones, step_freq)
else:
Expand All @@ -33,31 +46,46 @@ def sample(raw_context, model, mode, step_freq, tokenizer=None, max_len=-1, raw_
return select_reply(replies)


def sample_ddpm(model, x_t, raw_context, cached_context, noise, ones, step_freq):
def guided_step(model, x_t, t, raw_context, cached_context, empty_cached_context, ones, guidance_strength):
x_0, cached_context = model(
encoder_input_ids=raw_context.input_ids,
encoder_attention_mask=raw_context.attention_mask,
decoder_inputs_embeds=x_t,
time_ids=t * ones,
context=cached_context,
)

if guidance_strength:
empty_context = model.dropout_context(raw_context, 1)
x_0_uncond, empty_cached_context = model(
encoder_input_ids=empty_context.input_ids,
encoder_attention_mask=empty_context.attention_mask,
decoder_inputs_embeds=x_t,
time_ids=t * ones,
context=empty_cached_context,
)
x_0 = (1 + guidance_strength) * x_0 - guidance_strength * x_0_uncond
return x_0, cached_context, empty_cached_context


def sample_ddpm(
model, x_t, raw_context, cached_context, empty_cached_context, noise, ones, step_freq, guidance_strength
):
diffusion_steps = model.diffusion_steps
timesteps = range(diffusion_steps, 1, -step_freq)

x_t = scale_input(x_t, model.sigmas[-1])

for t in timesteps:
x_0, cached_context = model(
encoder_input_ids=raw_context.input_ids,
encoder_attention_mask=raw_context.attention_mask,
decoder_inputs_embeds=x_t,
time_ids=t * ones,
context=cached_context,
x_0, cached_context, empty_cached_context = guided_step(
model, x_t, t, raw_context, cached_context, empty_cached_context, ones, guidance_strength
)

sigma_t = model.sigmas[max(t - step_freq, 1)]
noise.normal_(0, 1)
x_t = scale_input(x_0 + sigma_t * noise, sigma_t)

x_0, _ = model(
encoder_attention_mask=raw_context.attention_mask,
decoder_inputs_embeds=x_t,
time_ids=ones,
context=cached_context,
)
x_0, *_ = guided_step(model, x_t, 1, raw_context, cached_context, empty_cached_context, ones, guidance_strength)

logits = model.classifier(x_0)
return logits
Expand Down