Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions config/dflash/dflash_gemma4_12b.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
target_layer_ids=[5, 17, 29, 41, 46],
mask_token_id=4,
num_anchors=512,

# Enable D2-style real-token prefix features for dflash training.
enable_d2_feature=True,
d2_prefix_weight_base=0.9,
# Disable markov head.
markov_rank=0,

# Disable confidence head.
confidence_head_alpha=0.0,

# CE-only loss.
loss_decay_gamma=4.0,
ce_loss_alpha=1.0,
Expand Down
6 changes: 3 additions & 3 deletions config/dflash/dflash_qwen3_14b.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
target_layer_ids=[1, 10, 19, 28, 37],
mask_token_id=151669,
num_anchors=512,

# Enable D2-style real-token prefix features for dflash training.
enable_d2_feature=True,
d2_prefix_weight_base=0.9,
# Disable markov head.
markov_rank=0,

# Disable confidence head.
confidence_head_alpha=0.0,

# CE-only loss.
loss_decay_gamma=4.0,
ce_loss_alpha=1.0,
Expand Down
9 changes: 5 additions & 4 deletions config/dflash/dflash_qwen3_4b.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from deepspec.trainer import Qwen3DSparkTrainer

BASE_TB_DIR = os.path.expanduser("~/tensorboard")
BASE_CKPT_DIR = os.path.expanduser("~/checkpoints")
project_name = "deepspec"
Expand All @@ -13,13 +14,13 @@
target_layer_ids=[1, 9, 17, 25, 33],
mask_token_id=151669,
num_anchors=512,

# Enable D2-style real-token prefix features for dflash training.
enable_d2_feature=True,
d2_prefix_weight_base=0.9,
# Disable markov head.
markov_rank=0,

# Disable confidence head.
confidence_head_alpha=0.0,

# CE-only loss.
loss_decay_gamma=4.0,
ce_loss_alpha=1.0,
Expand Down Expand Up @@ -56,7 +57,7 @@

def finalize_cfg(cfg):
logging_cfg = dict(cfg["logging"])
project_name=str(cfg['project_name'])
project_name = str(cfg["project_name"])
exp_name = str(cfg["exp_name"])
logging_cfg["checkpoint_dir"] = os.path.join(BASE_CKPT_DIR, project_name, exp_name)
logging_cfg["tensorboard_dir"] = os.path.join(BASE_TB_DIR, project_name, exp_name)
Expand Down
9 changes: 5 additions & 4 deletions config/dflash/dflash_qwen3_8b.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from deepspec.trainer import Qwen3DSparkTrainer

BASE_TB_DIR = os.path.expanduser("~/tensorboard")
BASE_CKPT_DIR = os.path.expanduser("~/checkpoints")
project_name = "deepspec"
Expand All @@ -13,13 +14,13 @@
target_layer_ids=[1, 9, 17, 25, 33],
mask_token_id=151669,
num_anchors=512,

# Enable D2-style real-token prefix features for dflash training.
enable_d2_feature=True,
d2_prefix_weight_base=0.9,
# Disable markov head.
markov_rank=0,

# Disable confidence head.
confidence_head_alpha=0.0,

# CE-only loss.
loss_decay_gamma=4.0,
ce_loss_alpha=1.0,
Expand Down Expand Up @@ -56,7 +57,7 @@

def finalize_cfg(cfg):
logging_cfg = dict(cfg["logging"])
project_name=str(cfg['project_name'])
project_name = str(cfg["project_name"])
exp_name = str(cfg["exp_name"])
logging_cfg["checkpoint_dir"] = os.path.join(BASE_CKPT_DIR, project_name, exp_name)
logging_cfg["tensorboard_dir"] = os.path.join(BASE_TB_DIR, project_name, exp_name)
Expand Down
120 changes: 110 additions & 10 deletions deepspec/modeling/dspark/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class DSparkForwardOutput:
confidence_pred: Optional[torch.Tensor] = None
# [batch_size, num_anchors, block_size, vocab_size]
aligned_target_logits: Optional[torch.Tensor] = None
# [batch_size, num_anchors, block_size]
loss_position_offsets: Optional[torch.Tensor] = None


class AcceptRatePredictor(nn.Module):
Expand All @@ -51,7 +53,10 @@ def forward(self, features):

def extract_context_feature(hidden_states, layer_ids):
    """Concatenate the selected hidden states along the last dimension.

    Args:
        hidden_states: Indexable sequence of tensors. Index ``0`` is used for
            ``layer_id == -1`` and index ``layer_id + 1`` for any other id.
        layer_ids: Iterable of layer ids selecting which entries to gather,
            in the order they should appear in the concatenated output.

    Returns:
        A tensor formed by ``torch.cat`` of the selected entries on ``dim=-1``.
    """
    selected = [
        # -1 is spelled out explicitly even though ``-1 + 1`` is also 0.
        hidden_states[0 if layer_id == -1 else layer_id + 1]
        for layer_id in layer_ids
    ]
    return torch.cat(selected, dim=-1)

Expand All @@ -68,9 +73,9 @@ def validate_target_layer_ids(layer_ids, num_target_layers: int):
f"for num_target_layers={num_target_layers}. "
"-1 denotes the embedding output."
)
assert previous is None or layer_id > previous, (
"target_layer_ids must be strictly increasing."
)
assert (
previous is None or layer_id > previous
), "target_layer_ids must be strictly increasing."
previous = layer_id
return layer_ids

Expand Down Expand Up @@ -140,9 +145,13 @@ def sample_anchor_positions(
keep_mask = torch.zeros(bsz, max_n, dtype=torch.bool, device=device)
return anchors, keep_mask

indices = torch.arange(num_candidates, device=device).unsqueeze(0).expand(
bsz,
-1,
indices = (
torch.arange(num_candidates, device=device)
.unsqueeze(0)
.expand(
bsz,
-1,
)
)
masked_indices = torch.where(
valid,
Expand Down Expand Up @@ -282,9 +291,13 @@ def create_noise_embed(
block_starts = torch.arange(num_blocks, device=device) * block_size
block_starts = block_starts.unsqueeze(0).expand(bsz, -1)
anchor_tokens = torch.gather(input_ids, 1, anchor_positions)
flat_batch_idx = torch.arange(bsz, device=device).unsqueeze(1).expand(
bsz,
num_blocks,
flat_batch_idx = (
torch.arange(bsz, device=device)
.unsqueeze(1)
.expand(
bsz,
num_blocks,
)
)
noise_ids[flat_batch_idx, block_starts] = torch.where(
block_keep_mask,
Expand All @@ -294,6 +307,90 @@ def create_noise_embed(
return embed_tokens(noise_ids)


def sample_d2_prefix_lengths(
    *,
    bsz: int,
    num_blocks: int,
    block_size: int,
    prefix_weight_base: float,
    device: torch.device,
) -> torch.Tensor:
    """Sample one prefix length per (batch, block) pair.

    Candidate lengths are the integers ``min(2, block_size - 1)`` through
    ``block_size - 1`` inclusive. Length ``k`` is drawn with probability
    proportional to ``prefix_weight_base ** k``.

    Args:
        bsz: Batch size.
        num_blocks: Number of blocks per batch row.
        block_size: Number of positions in a block.
        prefix_weight_base: Base of the geometric weighting over lengths.
        device: Device on which the result is allocated.

    Returns:
        A ``torch.long`` tensor of shape ``(bsz, num_blocks)``.
    """
    max_prefix = int(block_size) - 1
    min_prefix = min(2, max_prefix)
    num_samples = bsz * num_blocks

    # Nothing to sample: either there is a single candidate length, or the
    # output is empty (torch.multinomial rejects num_samples == 0).
    if max_prefix <= min_prefix or num_samples == 0:
        return torch.full(
            (bsz, num_blocks),
            min_prefix,
            dtype=torch.long,
            device=device,
        )

    prefix_ids = torch.arange(min_prefix, max_prefix + 1, device=device)
    # Unnormalized weights base**k; multinomial normalizes them internally.
    weights = torch.pow(
        torch.full_like(prefix_ids, float(prefix_weight_base), dtype=torch.float32),
        prefix_ids.float(),
    )
    sample_indices = torch.multinomial(
        weights,
        num_samples=num_samples,
        replacement=True,
    ).reshape(bsz, num_blocks)
    return prefix_ids[sample_indices]


def create_d2_noise_embed(
    embed_tokens: nn.Module,
    input_ids: torch.Tensor,
    anchor_positions: torch.Tensor,
    block_keep_mask: torch.Tensor,
    prefix_lengths: torch.Tensor,
    *,
    mask_token_id: int,
    block_size: int,
) -> torch.Tensor:
    """Embed per-block token ids with a real-token prefix and masked remainder.

    For each block, positions ``anchor + 0 .. anchor + block_size - 1`` are
    read from ``input_ids``. A position keeps its real token only when its
    in-block offset is below the block's prefix length, the block is kept by
    ``block_keep_mask``, and the position lies inside the sequence. Every
    other position is set to ``mask_token_id``.

    Args:
        embed_tokens: Callable applied to the resulting id tensor.
        input_ids: Token ids, unpacked as ``(bsz, seq_len)``.
        anchor_positions: Block start positions; dim 1 is the block count.
        block_keep_mask: Per-block boolean mask, broadcast over offsets.
        prefix_lengths: Per-block number of leading real tokens.
        mask_token_id: Id written to every non-revealed position.
        block_size: Number of positions per block.

    Returns:
        ``embed_tokens`` applied to ids of shape
        ``(bsz, num_blocks * block_size)``.
    """
    batch, total_len = input_ids.shape
    n_blocks = anchor_positions.shape[1]

    in_block = torch.arange(block_size, device=input_ids.device).view(1, 1, -1)
    positions = anchor_positions.unsqueeze(-1) + in_block

    # Clamp before gathering so out-of-range positions read a valid index;
    # they are masked out below via the `positions < total_len` term.
    per_block_ids = input_ids.unsqueeze(1).expand(-1, n_blocks, -1)
    gathered = torch.gather(per_block_ids, 2, positions.clamp(max=total_len - 1))

    reveal = in_block < prefix_lengths.unsqueeze(-1)
    reveal = reveal & block_keep_mask.unsqueeze(-1)
    reveal = reveal & (positions < total_len)

    noise_ids = gathered.masked_fill(~reveal, mask_token_id)
    return embed_tokens(noise_ids.reshape(batch, n_blocks * block_size))


def build_d2_eval_mask(
    *,
    seq_len: int,
    loss_mask: torch.Tensor,
    label_indices: torch.Tensor,
    safe_label_indices: torch.Tensor,
    block_keep_mask: torch.Tensor,
    prefix_lengths: torch.Tensor,
) -> torch.Tensor:
    """Build the boolean mask of in-block positions that contribute to loss.

    A position is selected when its in-block offset is at or beyond the
    block's prefix length and every position from the end of the prefix up
    to and including it has a usable target. A target is usable when
    ``label_indices < seq_len``, the gathered ``loss_mask`` value exceeds
    0.5, and the block is kept by ``block_keep_mask``.

    Args:
        seq_len: Sequence length used to bound ``label_indices``.
        loss_mask: Per-token mask, gathered along its last dimension.
        label_indices: Target indices; dim 1 is blocks, last dim is offsets.
        safe_label_indices: Indices used for the gather into ``loss_mask``;
            presumably ``label_indices`` clamped into range (verify against
            the caller).
        block_keep_mask: Per-block boolean mask, broadcast over offsets.
        prefix_lengths: Per-block prefix length.

    Returns:
        Boolean tensor with the broadcast shape of ``label_indices``.
    """
    n_blocks = label_indices.size(1)
    block_len = label_indices.size(-1)

    offsets = torch.arange(block_len, device=label_indices.device).view(1, 1, -1)
    beyond_prefix = offsets >= prefix_lengths.unsqueeze(-1)

    gathered_loss = torch.gather(
        loss_mask.unsqueeze(1).expand(-1, n_blocks, -1),
        2,
        safe_label_indices,
    )
    usable = label_indices < seq_len
    usable = usable & (gathered_loss > 0.5)
    usable = usable & block_keep_mask.unsqueeze(-1)

    # Prefix positions are forced to 1 so they never break the run; the
    # running product then zeroes everything after the first unusable
    # target past the prefix.
    unbroken = (usable | ~beyond_prefix).to(torch.int32).cumprod(dim=-1).bool()
    return unbroken & beyond_prefix


__all__ = [
"DSparkForwardOutput",
"AcceptRatePredictor",
Expand All @@ -306,4 +403,7 @@ def create_noise_embed(
"log_sampler_stats",
"create_position_ids",
"create_noise_embed",
"sample_d2_prefix_lengths",
"create_d2_noise_embed",
"build_d2_eval_mask",
]
22 changes: 16 additions & 6 deletions deepspec/modeling/dspark/gemma4/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def _validate_required_text_fields(text_config) -> None:
"use_double_wide_mlp",
)
for field in required_fields:
assert hasattr(text_config, field), (
f"target_config.text_config.{field} must be provided."
)
assert hasattr(
text_config, field
), f"target_config.text_config.{field} must be provided."


def build_draft_config(target_config, model_args):
Expand Down Expand Up @@ -75,9 +75,9 @@ def build_draft_config(target_config, model_args):
markov_rank = int(model_args.markov_rank)
assert markov_rank >= 0, f"markov_rank must be >= 0, got {markov_rank}"
if markov_rank > 0:
assert "markov_head_type" in model_args, (
"markov_head_type must be provided when markov_rank > 0."
)
assert (
"markov_head_type" in model_args
), "markov_head_type must be provided when markov_rank > 0."

draft_config.architectures = ["Gemma4DSparkModel"]
draft_config.target_model_type = str(target_config.model_type)
Expand All @@ -91,6 +91,16 @@ def build_draft_config(target_config, model_args):
draft_config.mask_token_id = int(model_args.mask_token_id)
draft_config.target_layer_ids = target_layer_ids
draft_config.num_anchors = int(model_args.num_anchors)
draft_config.enable_d2_feature = bool(
getattr(model_args, "enable_d2_feature", False)
)
draft_config.d2_prefix_weight_base = float(
getattr(model_args, "d2_prefix_weight_base", 0.9)
)
assert draft_config.d2_prefix_weight_base > 0.0, (
"d2_prefix_weight_base must be positive, "
f"got {draft_config.d2_prefix_weight_base}"
)
draft_config.enable_confidence_head = enable_confidence_head
if enable_confidence_head:
draft_config.confidence_head_with_markov = bool(
Expand Down
Loading