diff --git a/config/dflash/dflash_gemma4_12b.py b/config/dflash/dflash_gemma4_12b.py index b07826b..ded50b2 100644 --- a/config/dflash/dflash_gemma4_12b.py +++ b/config/dflash/dflash_gemma4_12b.py @@ -16,13 +16,13 @@ target_layer_ids=[5, 17, 29, 41, 46], mask_token_id=4, num_anchors=512, - + # Enable D2-style real-token prefix features for dflash training. + enable_d2_feature=True, + d2_prefix_weight_base=0.9, # Disable markov head. markov_rank=0, - # Disable confidence head. confidence_head_alpha=0.0, - # CE-only loss. loss_decay_gamma=4.0, ce_loss_alpha=1.0, diff --git a/config/dflash/dflash_qwen3_14b.py b/config/dflash/dflash_qwen3_14b.py index 131f299..35cfe3f 100644 --- a/config/dflash/dflash_qwen3_14b.py +++ b/config/dflash/dflash_qwen3_14b.py @@ -16,13 +16,13 @@ target_layer_ids=[1, 10, 19, 28, 37], mask_token_id=151669, num_anchors=512, - + # Enable D2-style real-token prefix features for dflash training. + enable_d2_feature=True, + d2_prefix_weight_base=0.9, # Disable markov head. markov_rank=0, - # Disable confidence head. confidence_head_alpha=0.0, - # CE-only loss. loss_decay_gamma=4.0, ce_loss_alpha=1.0, diff --git a/config/dflash/dflash_qwen3_4b.py b/config/dflash/dflash_qwen3_4b.py index b98b394..e49ce8a 100644 --- a/config/dflash/dflash_qwen3_4b.py +++ b/config/dflash/dflash_qwen3_4b.py @@ -1,5 +1,6 @@ import os from deepspec.trainer import Qwen3DSparkTrainer + BASE_TB_DIR = os.path.expanduser("~/tensorboard") BASE_CKPT_DIR = os.path.expanduser("~/checkpoints") project_name = "deepspec" @@ -13,13 +14,13 @@ target_layer_ids=[1, 9, 17, 25, 33], mask_token_id=151669, num_anchors=512, - + # Enable D2-style real-token prefix features for dflash training. + enable_d2_feature=True, + d2_prefix_weight_base=0.9, # Disable markov head. markov_rank=0, - # Disable confidence head. confidence_head_alpha=0.0, - # CE-only loss. loss_decay_gamma=4.0, ce_loss_alpha=1.0, @@ -56,7 +57,7 @@ def finalize_cfg(cfg): logging_cfg = dict(cfg["logging"]) - project_name=str(cfg['project_name']) + project_name = str(cfg["project_name"]) exp_name = str(cfg["exp_name"]) logging_cfg["checkpoint_dir"] = os.path.join(BASE_CKPT_DIR, project_name, exp_name) logging_cfg["tensorboard_dir"] = os.path.join(BASE_TB_DIR, project_name, exp_name) diff --git a/config/dflash/dflash_qwen3_8b.py b/config/dflash/dflash_qwen3_8b.py index cad100c..79885a9 100644 --- a/config/dflash/dflash_qwen3_8b.py +++ b/config/dflash/dflash_qwen3_8b.py @@ -1,5 +1,6 @@ import os from deepspec.trainer import Qwen3DSparkTrainer + BASE_TB_DIR = os.path.expanduser("~/tensorboard") BASE_CKPT_DIR = os.path.expanduser("~/checkpoints") project_name = "deepspec" @@ -13,13 +14,13 @@ target_layer_ids=[1, 9, 17, 25, 33], mask_token_id=151669, num_anchors=512, - + # Enable D2-style real-token prefix features for dflash training. + enable_d2_feature=True, + d2_prefix_weight_base=0.9, # Disable markov head. markov_rank=0, - # Disable confidence head. confidence_head_alpha=0.0, - # CE-only loss. loss_decay_gamma=4.0, ce_loss_alpha=1.0, @@ -56,7 +57,7 @@ def finalize_cfg(cfg): logging_cfg = dict(cfg["logging"]) - project_name=str(cfg['project_name']) + project_name = str(cfg["project_name"]) exp_name = str(cfg["exp_name"]) logging_cfg["checkpoint_dir"] = os.path.join(BASE_CKPT_DIR, project_name, exp_name) logging_cfg["tensorboard_dir"] = os.path.join(BASE_TB_DIR, project_name, exp_name) diff --git a/deepspec/modeling/dspark/common.py b/deepspec/modeling/dspark/common.py index 76d908a..98495aa 100644 --- a/deepspec/modeling/dspark/common.py +++ b/deepspec/modeling/dspark/common.py @@ -38,6 +38,8 @@ class DSparkForwardOutput: confidence_pred: Optional[torch.Tensor] = None # [batch_size, num_anchors, block_size, vocab_size] aligned_target_logits: Optional[torch.Tensor] = None + # [batch_size, num_anchors, block_size] + loss_position_offsets: Optional[torch.Tensor] = None class AcceptRatePredictor(nn.Module): @@ -51,7 +53,10 @@ def forward(self, features): def extract_context_feature(hidden_states, layer_ids): return torch.cat( - [hidden_states[0 if layer_id == -1 else layer_id + 1] for layer_id in layer_ids], + [ + hidden_states[0 if layer_id == -1 else layer_id + 1] + for layer_id in layer_ids + ], dim=-1, ) @@ -68,9 +73,9 @@ def validate_target_layer_ids(layer_ids, num_target_layers: int): f"for num_target_layers={num_target_layers}. " "-1 denotes the embedding output." ) - assert previous is None or layer_id > previous, ( - "target_layer_ids must be strictly increasing." - ) + assert ( + previous is None or layer_id > previous + ), "target_layer_ids must be strictly increasing." previous = layer_id return layer_ids @@ -140,9 +145,13 @@ def sample_anchor_positions( keep_mask = torch.zeros(bsz, max_n, dtype=torch.bool, device=device) return anchors, keep_mask - indices = torch.arange(num_candidates, device=device).unsqueeze(0).expand( - bsz, - -1, + indices = ( + torch.arange(num_candidates, device=device) + .unsqueeze(0) + .expand( + bsz, + -1, + ) ) masked_indices = torch.where( valid, @@ -282,9 +291,13 @@ def create_noise_embed( block_starts = torch.arange(num_blocks, device=device) * block_size block_starts = block_starts.unsqueeze(0).expand(bsz, -1) anchor_tokens = torch.gather(input_ids, 1, anchor_positions) - flat_batch_idx = torch.arange(bsz, device=device).unsqueeze(1).expand( - bsz, - num_blocks, + flat_batch_idx = ( + torch.arange(bsz, device=device) + .unsqueeze(1) + .expand( + bsz, + num_blocks, + ) ) noise_ids[flat_batch_idx, block_starts] = torch.where( block_keep_mask, @@ -294,6 +307,90 @@ def create_noise_embed( return embed_tokens(noise_ids) +def sample_d2_prefix_lengths( + *, + bsz: int, + num_blocks: int, + block_size: int, + prefix_weight_base: float, + device: torch.device, +) -> torch.Tensor: + min_prefix = min(2, int(block_size) - 1) + max_prefix = int(block_size) - 1 + if max_prefix <= min_prefix: + return torch.full( + (bsz, num_blocks), + min_prefix, + dtype=torch.long, + device=device, + ) + + prefix_ids = torch.arange(min_prefix, max_prefix + 1, device=device) + weights = torch.pow( + torch.full_like(prefix_ids, float(prefix_weight_base), dtype=torch.float32), + prefix_ids.float(), + ) + sample_indices = torch.multinomial( + weights, + num_samples=bsz * num_blocks, + replacement=True, + ).reshape(bsz, num_blocks) + return prefix_ids[sample_indices] + + +def create_d2_noise_embed( + embed_tokens: nn.Module, + input_ids: torch.Tensor, + anchor_positions: torch.Tensor, + block_keep_mask: torch.Tensor, + prefix_lengths: torch.Tensor, + *, + mask_token_id: int, + block_size: int, +) -> torch.Tensor: + bsz, seq_len = input_ids.shape + num_blocks = anchor_positions.shape[1] + device = input_ids.device + offsets = torch.arange(block_size, device=device).view(1, 1, -1) + token_positions = anchor_positions.unsqueeze(-1) + offsets + safe_positions = token_positions.clamp(max=seq_len - 1) + real_tokens = torch.gather( + input_ids.unsqueeze(1).expand(-1, num_blocks, -1), + 2, + safe_positions, + ) + visible_prefix = offsets < prefix_lengths.unsqueeze(-1) + valid_positions = token_positions < seq_len + fill_mask = visible_prefix & block_keep_mask.unsqueeze(-1) & valid_positions + mask_tokens = torch.full_like(real_tokens, mask_token_id) + noise_ids = torch.where(fill_mask, real_tokens, mask_tokens) + return embed_tokens(noise_ids.reshape(bsz, num_blocks * block_size)) + + +def build_d2_eval_mask( + *, + seq_len: int, + loss_mask: torch.Tensor, + label_indices: torch.Tensor, + safe_label_indices: torch.Tensor, + block_keep_mask: torch.Tensor, + prefix_lengths: torch.Tensor, +) -> torch.Tensor: + pos_in_block = torch.arange(label_indices.size(-1), device=label_indices.device) + pos_in_block = pos_in_block.view(1, 1, -1) + loss_position_mask = pos_in_block >= prefix_lengths.unsqueeze(-1) + target_valid = label_indices < seq_len + target_loss_mask = torch.gather( + loss_mask.unsqueeze(1).expand(-1, label_indices.size(1), -1), + 2, + safe_label_indices, + ) + target_valid = target_valid & (target_loss_mask > 0.5) + target_valid = target_valid & block_keep_mask.unsqueeze(-1) + contiguous_mask = (target_valid | ~loss_position_mask).to(torch.int32) + return contiguous_mask.cumprod(dim=-1).bool() & loss_position_mask + + __all__ = [ "DSparkForwardOutput", "AcceptRatePredictor", @@ -306,4 +403,7 @@ def create_noise_embed( "log_sampler_stats", "create_position_ids", "create_noise_embed", + "sample_d2_prefix_lengths", + "create_d2_noise_embed", + "build_d2_eval_mask", ] diff --git a/deepspec/modeling/dspark/gemma4/config.py b/deepspec/modeling/dspark/gemma4/config.py index 3c10f16..d7b877a 100644 --- a/deepspec/modeling/dspark/gemma4/config.py +++ b/deepspec/modeling/dspark/gemma4/config.py @@ -44,9 +44,9 @@ def _validate_required_text_fields(text_config) -> None: "use_double_wide_mlp", ) for field in required_fields: - assert hasattr(text_config, field), ( - f"target_config.text_config.{field} must be provided." - ) + assert hasattr( + text_config, field + ), f"target_config.text_config.{field} must be provided." def build_draft_config(target_config, model_args): @@ -75,9 +75,9 @@ def build_draft_config(target_config, model_args): markov_rank = int(model_args.markov_rank) assert markov_rank >= 0, f"markov_rank must be >= 0, got {markov_rank}" if markov_rank > 0: - assert "markov_head_type" in model_args, ( - "markov_head_type must be provided when markov_rank > 0." - ) + assert ( + "markov_head_type" in model_args + ), "markov_head_type must be provided when markov_rank > 0." draft_config.architectures = ["Gemma4DSparkModel"] draft_config.target_model_type = str(target_config.model_type) @@ -91,6 +91,16 @@ def build_draft_config(target_config, model_args): draft_config.mask_token_id = int(model_args.mask_token_id) draft_config.target_layer_ids = target_layer_ids draft_config.num_anchors = int(model_args.num_anchors) + draft_config.enable_d2_feature = bool( + getattr(model_args, "enable_d2_feature", False) + ) + draft_config.d2_prefix_weight_base = float( + getattr(model_args, "d2_prefix_weight_base", 0.9) + ) + assert draft_config.d2_prefix_weight_base > 0.0, ( + "d2_prefix_weight_base must be positive, " + f"got {draft_config.d2_prefix_weight_base}" + ) draft_config.enable_confidence_head = enable_confidence_head if enable_confidence_head: draft_config.confidence_head_with_markov = bool( diff --git a/deepspec/modeling/dspark/gemma4/modeling.py b/deepspec/modeling/dspark/gemma4/modeling.py index c339b97..4e19ce2 100644 --- a/deepspec/modeling/dspark/gemma4/modeling.py +++ b/deepspec/modeling/dspark/gemma4/modeling.py @@ -20,13 +20,16 @@ from deepspec.modeling.dspark.common import ( AcceptRatePredictor, DSparkForwardOutput, + build_d2_eval_mask, build_eval_mask, + create_d2_noise_embed, create_dspark_attention_mask, create_noise_embed, create_position_ids, extract_context_feature, log_sampler_stats, sample_anchor_positions, + sample_d2_prefix_lengths, ) from deepspec.modeling.dspark.markov_head import build_markov_head from deepspec.utils.sampling import sample_tokens @@ -45,9 +48,9 @@ def __init__(self, config, layer_idx: int): else: self.num_key_value_heads = int(config.num_key_value_heads) self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads - assert self.num_attention_heads % self.num_key_value_heads == 0, ( - "num_attention_heads must be divisible by the Gemma4 key/value head count." - ) + assert ( + self.num_attention_heads % self.num_key_value_heads == 0 + ), "num_attention_heads must be divisible by the Gemma4 key/value head count." self.scaling = 1.0 self.attention_dropout = float(config.attention_dropout) self.is_causal = False @@ -173,12 +176,12 @@ class Gemma4DSparkDecoderLayer(GradientCheckpointingLayer): def __init__(self, config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - assert not bool(config.enable_moe_block), ( - "Gemma4 DSpark prototype does not support Gemma4 MoE blocks yet." - ) - assert int(config.hidden_size_per_layer_input) == 0, ( - "Gemma4 DSpark prototype does not support per-layer input gates yet." - ) + assert not bool( + config.enable_moe_block + ), "Gemma4 DSpark prototype does not support Gemma4 MoE blocks yet." + assert ( + int(config.hidden_size_per_layer_input) == 0 + ), "Gemma4 DSpark prototype does not support per-layer input gates yet." self.self_attn = Gemma4DSparkAttention(config=config, layer_idx=layer_idx) self.mlp = Gemma4TextMLP(config, layer_idx) self.input_layernorm = Gemma4RMSNorm( @@ -214,9 +217,9 @@ def forward( ) -> torch.Tensor: del position_ids, output_attentions, use_cache assert hidden_states is not None, "hidden_states must be provided." - assert target_hidden_states is not None, ( - "target_hidden_states must be provided." - ) + assert ( + target_hidden_states is not None + ), "target_hidden_states must be provided." assert position_embeddings is not None, "position_embeddings must be provided." residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -260,9 +263,9 @@ def __init__(self, config) -> None: for field in required_fields: assert hasattr(config, field), f"config.{field} must be provided." if int(config.markov_rank) > 0: - assert hasattr(config, "markov_head_type"), ( - "config.markov_head_type must be provided when markov_rank > 0." - ) + assert hasattr( + config, "markov_head_type" + ), "config.markov_head_type must be provided when markov_rank > 0." if bool(config.enable_confidence_head): assert hasattr(config, "confidence_head_with_markov"), ( "config.confidence_head_with_markov must be provided when " @@ -300,6 +303,14 @@ def __init__(self, config) -> None: self.block_size = int(config.block_size) self.mask_token_id = config.mask_token_id self.num_anchors = int(config.num_anchors) + self.enable_d2_feature = bool(getattr(config, "enable_d2_feature", False)) + self.d2_prefix_weight_base = float( + getattr(config, "d2_prefix_weight_base", 0.9) + ) + assert self.d2_prefix_weight_base > 0.0, ( + "d2_prefix_weight_base must be positive, " + f"got {self.d2_prefix_weight_base}" + ) self.markov_head = build_markov_head(config) @@ -342,9 +353,9 @@ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: softcap = getattr(self.config, "final_logit_softcapping", None) if softcap is not None: softcap = float(softcap) - assert softcap > 0.0, ( - "config.final_logit_softcapping must be positive when provided." - ) + assert ( + softcap > 0.0 + ), "config.final_logit_softcapping must be positive when provided." logits = torch.tanh(logits / softcap) * softcap return logits @@ -464,17 +475,40 @@ def forward( num_anchors=self.num_anchors, device=device, ) - noise_embedding = create_noise_embed( - self.embed_tokens, - input_ids, - anchor_positions, - block_keep_mask, - mask_token_id=self.mask_token_id, - block_size=self.block_size, - ) - context_position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand( - bsz, - -1, + prefix_lengths = None + if self.enable_d2_feature: + prefix_lengths = sample_d2_prefix_lengths( + bsz=bsz, + num_blocks=anchor_positions.size(1), + block_size=self.block_size, + prefix_weight_base=self.d2_prefix_weight_base, + device=device, + ) + noise_embedding = create_d2_noise_embed( + self.embed_tokens, + input_ids, + anchor_positions, + block_keep_mask, + prefix_lengths, + mask_token_id=self.mask_token_id, + block_size=self.block_size, + ) + else: + noise_embedding = create_noise_embed( + self.embed_tokens, + input_ids, + anchor_positions, + block_keep_mask, + mask_token_id=self.mask_token_id, + block_size=self.block_size, + ) + context_position_ids = ( + torch.arange(seq_len, device=device) + .unsqueeze(0) + .expand( + bsz, + -1, + ) ) draft_position_ids = create_position_ids(anchor_positions, self.block_size) full_position_ids = torch.cat([context_position_ids, draft_position_ids], dim=1) @@ -500,11 +534,18 @@ def forward( -1, ) - label_offsets = torch.arange(1, self.block_size + 1, device=device).view( - 1, - 1, - -1, - ) + if self.enable_d2_feature: + label_offsets = torch.arange(0, self.block_size, device=device).view( + 1, + 1, + -1, + ) + else: + label_offsets = torch.arange(1, self.block_size + 1, device=device).view( + 1, + 1, + -1, + ) label_indices = anchor_positions.unsqueeze(-1) + label_offsets safe_label_indices = label_indices.clamp(max=seq_len - 1) safe_label_indices = torch.where( @@ -536,13 +577,29 @@ def forward( ), ) aligned_target_logits = self.compute_logits(aligned_target_hidden) - eval_mask = build_eval_mask( - seq_len=seq_len, - loss_mask=loss_mask, - label_indices=label_indices, - safe_label_indices=safe_label_indices, - block_keep_mask=block_keep_mask, - ) + loss_position_offsets = None + if self.enable_d2_feature: + assert prefix_lengths is not None + eval_mask = build_d2_eval_mask( + seq_len=seq_len, + loss_mask=loss_mask, + label_indices=label_indices, + safe_label_indices=safe_label_indices, + block_keep_mask=block_keep_mask, + prefix_lengths=prefix_lengths, + ) + pos_in_block = torch.arange(self.block_size, device=device).view(1, 1, -1) + loss_position_offsets = (pos_in_block - prefix_lengths.unsqueeze(-1)).clamp( + min=0 + ) + else: + eval_mask = build_eval_mask( + seq_len=seq_len, + loss_mask=loss_mask, + label_indices=label_indices, + safe_label_indices=safe_label_indices, + block_keep_mask=block_keep_mask, + ) anchor_token_ids = torch.gather( input_ids, 1, @@ -595,6 +652,7 @@ def forward( block_keep_mask=block_keep_mask, confidence_pred=confidence_pred, aligned_target_logits=aligned_target_logits, + loss_position_offsets=loss_position_offsets, ) diff --git a/deepspec/modeling/dspark/loss.py b/deepspec/modeling/dspark/loss.py index 88dd56b..1f823f6 100644 --- a/deepspec/modeling/dspark/loss.py +++ b/deepspec/modeling/dspark/loss.py @@ -28,11 +28,17 @@ def _build_loss_weight_mask( block_size: int, device: torch.device, loss_decay_gamma: Optional[float], + loss_position_offsets: Optional[torch.Tensor] = None, ) -> torch.Tensor: loss_weight_mask = eval_mask.to(torch.float32) if loss_decay_gamma is not None and loss_decay_gamma > 0: - positions = torch.arange(block_size, device=device).view(1, 1, -1) - decay_weights = torch.exp(-positions.float() / float(loss_decay_gamma)) + if loss_position_offsets is None: + loss_position_offsets = torch.arange(block_size, device=device).view( + 1, 1, -1 + ) + decay_weights = torch.exp( + -loss_position_offsets.to(torch.float32) / float(loss_decay_gamma) + ) loss_weight_mask = loss_weight_mask * decay_weights return loss_weight_mask @@ -105,6 +111,7 @@ def _collect_local_terms( block_size=block_size, device=device, loss_decay_gamma=loss_decay_gamma, + loss_position_offsets=outputs.loss_position_offsets, ) flat_logits = draft_logits.reshape(-1, vocab_size) flat_targets = target_ids.reshape(-1) @@ -118,9 +125,9 @@ def _collect_local_terms( aligned_target_logits=aligned_target_logits, ) zero = ce_loss_num.new_zeros(()) - assert l1_loss_alpha <= 0 or aligned_target_logits is not None, ( - "aligned_target_logits is required when l1_loss_alpha > 0." - ) + assert ( + l1_loss_alpha <= 0 or aligned_target_logits is not None + ), "aligned_target_logits is required when l1_loss_alpha > 0." if l1_loss_alpha > 0: l1_loss_num, l1_loss_den = _compute_local_l1_term( outputs=outputs, @@ -150,34 +157,30 @@ def _collect_local_terms( confidence_bias_num = zero confidence_cumprod_bias_num = zero if has_confidence: - assert accept_rate_3d is not None, ( - "aligned_target_logits is required when confidence head is enabled." - ) + assert ( + accept_rate_3d is not None + ), "aligned_target_logits is required when confidence head is enabled." confidence_targets = accept_rate_3d.detach() - confidence_errors = F.binary_cross_entropy_with_logits( - outputs.confidence_pred.float(), - confidence_targets, - reduction="none", - ) * loss_weight_mask + confidence_errors = ( + F.binary_cross_entropy_with_logits( + outputs.confidence_pred.float(), + confidence_targets, + reduction="none", + ) + * loss_weight_mask + ) confidence_loss_num = confidence_errors.sum() confidence_loss_den = loss_weight_mask.sum() with torch.no_grad(): confidence_probs = outputs.confidence_pred.float().sigmoid() confidence_error = confidence_probs - accept_rate_3d - confidence_abs_error_num = ( - confidence_error.abs() * loss_weight_mask - ).sum() + confidence_abs_error_num = (confidence_error.abs() * loss_weight_mask).sum() confidence_bias_num = (confidence_error * loss_weight_mask).sum() valid_mask = outputs.eval_mask.to(torch.float32) - confidence_prefix_probs = ( - confidence_probs * valid_mask - ).cumprod(dim=-1) - confidence_prefix_targets = ( - accept_rate_3d * valid_mask - ).cumprod(dim=-1) + confidence_prefix_probs = (confidence_probs * valid_mask).cumprod(dim=-1) + confidence_prefix_targets = (accept_rate_3d * valid_mask).cumprod(dim=-1) confidence_cumprod_bias_num = ( - (confidence_prefix_probs - confidence_prefix_targets) - * loss_weight_mask + (confidence_prefix_probs - confidence_prefix_targets) * loss_weight_mask ).sum() loss_terms = { @@ -277,9 +280,7 @@ def compute_dspark_loss( local_ce_loss = loss_terms["ce_loss_num"] / (loss_terms["ce_loss_den"] + 1e-6) local_l1_loss = local_ce_loss.new_zeros(()) if global_denominators["l1_loss_den"].item() > 0: - local_l1_loss = loss_terms["l1_loss_num"] / ( - loss_terms["l1_loss_den"] + 1e-6 - ) + local_l1_loss = loss_terms["l1_loss_num"] / (loss_terms["l1_loss_den"] + 1e-6) local_confidence_loss = local_ce_loss.new_zeros(()) if has_confidence: local_confidence_loss = loss_terms["confidence_loss_num"] / ( diff --git a/deepspec/modeling/dspark/qwen3/config.py b/deepspec/modeling/dspark/qwen3/config.py index ae71fac..95c3ff7 100644 --- a/deepspec/modeling/dspark/qwen3/config.py +++ b/deepspec/modeling/dspark/qwen3/config.py @@ -30,9 +30,9 @@ def build_draft_config( markov_rank = int(model_args.markov_rank) assert markov_rank >= 0, f"markov_rank must be >= 0, got {markov_rank}" if markov_rank > 0: - assert "markov_head_type" in model_args, ( - "markov_head_type must be provided when markov_rank > 0." - ) + assert ( + "markov_head_type" in model_args + ), "markov_head_type must be provided when markov_rank > 0." draft_config = copy.deepcopy(target_config) draft_config.architectures = ["Qwen3DSparkModel"] @@ -45,6 +45,16 @@ def build_draft_config( draft_config.mask_token_id = int(model_args.mask_token_id) draft_config.target_layer_ids = target_layer_ids draft_config.num_anchors = int(model_args.num_anchors) + draft_config.enable_d2_feature = bool( + getattr(model_args, "enable_d2_feature", False) + ) + draft_config.d2_prefix_weight_base = float( + getattr(model_args, "d2_prefix_weight_base", 0.9) + ) + assert draft_config.d2_prefix_weight_base > 0.0, ( + "d2_prefix_weight_base must be positive, " + f"got {draft_config.d2_prefix_weight_base}" + ) draft_config.enable_confidence_head = enable_confidence_head if enable_confidence_head: draft_config.confidence_head_with_markov = bool( diff --git a/deepspec/modeling/dspark/qwen3/modeling.py b/deepspec/modeling/dspark/qwen3/modeling.py index 2f9c220..55a0ffd 100644 --- a/deepspec/modeling/dspark/qwen3/modeling.py +++ b/deepspec/modeling/dspark/qwen3/modeling.py @@ -20,13 +20,16 @@ from deepspec.modeling.dspark.common import ( AcceptRatePredictor, DSparkForwardOutput, + build_d2_eval_mask, build_eval_mask, + create_d2_noise_embed, create_dspark_attention_mask, create_noise_embed, create_position_ids, extract_context_feature, log_sampler_stats, sample_anchor_positions, + sample_d2_prefix_lengths, ) from deepspec.modeling.dspark.markov_head import build_markov_head from deepspec.utils.sampling import sample_tokens @@ -51,9 +54,7 @@ def __init__(self, config, layer_idx: int): ) self.num_attention_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = ( - self.num_attention_heads // self.num_key_value_heads - ) + self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = False @@ -173,11 +174,11 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[ - Tuple[torch.Tensor, torch.Tensor] - ] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states = self.self_attn( @@ -215,9 +216,9 @@ def __init__(self, config) -> None: for field in required_fields: assert hasattr(config, field), f"config.{field} must be provided." if int(config.markov_rank) > 0: - assert hasattr(config, "markov_head_type"), ( - "config.markov_head_type must be provided when markov_rank > 0." - ) + assert hasattr( + config, "markov_head_type" + ), "config.markov_head_type must be provided when markov_rank > 0." if bool(config.enable_confidence_head): assert hasattr(config, "confidence_head_with_markov"), ( "config.confidence_head_with_markov must be provided when " @@ -248,6 +249,14 @@ def __init__(self, config) -> None: self.block_size = int(config.block_size) self.mask_token_id = config.mask_token_id self.num_anchors = int(config.num_anchors) + self.enable_d2_feature = bool(getattr(config, "enable_d2_feature", False)) + self.d2_prefix_weight_base = float( + getattr(config, "d2_prefix_weight_base", 0.9) + ) + assert self.d2_prefix_weight_base > 0.0, ( + "d2_prefix_weight_base must be positive, " + f"got {self.d2_prefix_weight_base}" + ) # Markov head. self.markov_head = build_markov_head(config) @@ -402,15 +411,36 @@ def forward( num_anchors=self.num_anchors, device=device, ) - noise_embedding = create_noise_embed( - self.embed_tokens, - input_ids, - anchor_positions, - block_keep_mask, - mask_token_id=self.mask_token_id, - block_size=self.block_size, + prefix_lengths = None + if self.enable_d2_feature: + prefix_lengths = sample_d2_prefix_lengths( + bsz=bsz, + num_blocks=anchor_positions.size(1), + block_size=self.block_size, + prefix_weight_base=self.d2_prefix_weight_base, + device=device, + ) + noise_embedding = create_d2_noise_embed( + self.embed_tokens, + input_ids, + anchor_positions, + block_keep_mask, + prefix_lengths, + mask_token_id=self.mask_token_id, + block_size=self.block_size, + ) + else: + noise_embedding = create_noise_embed( + self.embed_tokens, + input_ids, + anchor_positions, + block_keep_mask, + mask_token_id=self.mask_token_id, + block_size=self.block_size, + ) + context_position_ids = ( + torch.arange(seq_len, device=device).unsqueeze(0).expand(bsz, -1) ) - context_position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(bsz, -1) draft_position_ids = create_position_ids(anchor_positions, self.block_size) full_position_ids = torch.cat([context_position_ids, draft_position_ids], dim=1) dspark_attn_mask = create_dspark_attention_mask( @@ -430,9 +460,14 @@ def forward( num_blocks = anchor_positions.size(1) output_hidden_4d = output_hidden.reshape(bsz, num_blocks, self.block_size, -1) - label_offsets = torch.arange(1, self.block_size + 1, device=device).view( - 1, 1, -1 - ) + if self.enable_d2_feature: + label_offsets = torch.arange(0, self.block_size, device=device).view( + 1, 1, -1 + ) + else: + label_offsets = torch.arange(1, self.block_size + 1, device=device).view( + 1, 1, -1 + ) label_indices = anchor_positions.unsqueeze(-1) + label_offsets safe_label_indices = label_indices.clamp(max=seq_len - 1) safe_label_indices = torch.where( @@ -464,13 +499,29 @@ def forward( ), ) aligned_target_logits = self.compute_logits(aligned_target_hidden) - eval_mask = build_eval_mask( - seq_len=seq_len, - loss_mask=loss_mask, - label_indices=label_indices, - safe_label_indices=safe_label_indices, - block_keep_mask=block_keep_mask, - ) + loss_position_offsets = None + if self.enable_d2_feature: + assert prefix_lengths is not None + eval_mask = build_d2_eval_mask( + seq_len=seq_len, + loss_mask=loss_mask, + label_indices=label_indices, + safe_label_indices=safe_label_indices, + block_keep_mask=block_keep_mask, + prefix_lengths=prefix_lengths, + ) + pos_in_block = torch.arange(self.block_size, device=device).view(1, 1, -1) + loss_position_offsets = (pos_in_block - prefix_lengths.unsqueeze(-1)).clamp( + min=0 + ) + else: + eval_mask = build_eval_mask( + seq_len=seq_len, + loss_mask=loss_mask, + label_indices=label_indices, + safe_label_indices=safe_label_indices, + block_keep_mask=block_keep_mask, + ) anchor_token_ids = torch.gather( input_ids, 1, @@ -505,9 +556,9 @@ def forward( confidence_pred = None if self.confidence_head is not None: if self.confidence_head_with_markov: - prev_embeddings = self.markov_head.get_prev_embeddings(prev_token_ids).to( - dtype=output_hidden_4d.dtype - ) + prev_embeddings = self.markov_head.get_prev_embeddings( + prev_token_ids + ).to(dtype=output_hidden_4d.dtype) confidence_features = torch.cat( [output_hidden_4d, prev_embeddings], dim=-1, @@ -523,6 +574,7 @@ def forward( block_keep_mask=block_keep_mask, confidence_pred=confidence_pred, aligned_target_logits=aligned_target_logits, + loss_position_offsets=loss_position_offsets, )