Default to Forward KL? #5

@hank0316

Hi @idanshen,

Thanks for open-sourcing the training code.

I noticed that the default value of alpha in DistilConfig is set to 0:

alpha: float = field(
    default=0.0,
    metadata={
        "help": "Alpha coefficient. If `0.0` (default), the forward KL is used. If `1.0`, the reverse KL is used. If anything in between, the Jensen-Shannon Divergence is used."
    },
)

It is also not overridden in main.py:

config = DistilConfig(
    seed=args.seed,
    use_vllm = True,
    vllm_mode="colocate",
    vllm_tensor_parallel_size=1,
    vllm_gpu_memory_utilization=0.3,
    vllm_enable_sleep_mode=True,
    learning_rate = args.learning_rate,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    logging_steps = 1,
    bf16 = True,
    fp16 = False,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = args.num_prompts_per_batch,
    max_prompt_length = 1024,
    max_completion_length = 1024,
    num_train_epochs = args.num_train_epochs,
    save_steps = 100,
    max_grad_norm = 1,
    report_to = "wandb",
    output_dir = args.output_dir,
    log_completions = False, # True for debugging
    sync_ref_model = True,
    ref_model_sync_steps = 1,
    ref_model_mixup_alpha = args.ref_model_mixup_alpha,
    vllm_importance_sampling_correction = True,
    num_loss_tokens_to_skip = 3,
)
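
For reference, here is the convention I'm assuming when I say "forward" vs. "reverse" KL in distillation: forward KL is KL(p_teacher || p_student) (mass-covering) and reverse KL is KL(p_student || p_teacher) (mode-seeking). A toy sketch just to pin down the notation (my own made-up numbers, not from the repo):

import torch

# Toy categorical distributions standing in for teacher/student next-token
# probabilities over a 4-token vocabulary.
p_teacher = torch.tensor([0.70, 0.20, 0.05, 0.05])
p_student = torch.tensor([0.40, 0.30, 0.20, 0.10])

# Forward KL: KL(teacher || student) -- mass-covering.
forward_kl = (p_teacher * (p_teacher.log() - p_student.log())).sum()

# Reverse KL: KL(student || teacher) -- mode-seeking.
reverse_kl = (p_student * (p_student.log() - p_teacher.log())).sum()

print(forward_kl.item(), reverse_kl.item())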

In the trainer implementation, it seems that when alpha = 0 the trainer uses forward KL rather than reverse KL (Eq. 1 in your paper):

if self.alpha == 0:  # Forward KL
    kl_loss = kl_div(all_logps, teacher_all_logps, reduction="none", log_target=True)
elif self.alpha == 1:  # Reverse KL
    kl_loss = kl_div(teacher_all_logps, all_logps, reduction="none", log_target=True)
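
If I read torch.nn.functional.kl_div correctly, with log_target=True it computes KL(target || input) elementwise, so the first branch above gives KL(teacher || student), i.e. forward KL in the convention I described. A quick sanity check (my own snippet, not from the repo):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
student_logps = torch.log_softmax(torch.randn(4), dim=-1)
teacher_logps = torch.log_softmax(torch.randn(4), dim=-1)

# F.kl_div(input, target, log_target=True) computes, elementwise,
# exp(target) * (target - input), i.e. the pointwise terms of KL(target || input).
lib_kl = F.kl_div(student_logps, teacher_logps, reduction="none", log_target=True).sum()

# Manual KL(teacher || student) for comparison.
manual_forward_kl = (teacher_logps.exp() * (teacher_logps - student_logps)).sum()

print(torch.allclose(lib_kl, manual_forward_kl))  # expected: True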

Is this the intended default behavior? I'm not sure whether I've misunderstood the loss calculation logic.

Looking forward to your clarification!
