Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions QEfficient/transformers/models/pytorch_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,7 @@
QEffQwen3MoeSparseMoeBlock,
)
from QEfficient.transformers.models.qwen3_vl.modeling_qwen3_vl import (
QEffQwen3VLDecoderWrapper,
QEffQwen3VLForConditionalGeneration,
QEffQwen3VLModel,
QEffQwen3VLTextAttention,
Expand Down Expand Up @@ -854,6 +855,7 @@ class SamplerTransform:
QEffPhi3ForCausalLM,
QEffQwen2ForCausalLM,
QEffQwen_2_5_vl_DecoderWrapper,
QEffQwen3VLDecoderWrapper,
}

@classmethod
Expand Down
13 changes: 10 additions & 3 deletions QEfficient/transformers/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class SamplerOutput(ModelOutput):
probs: torch.FloatTensor = None
next_tokens: torch.IntTensor = None
vision_embeds: Optional[torch.FloatTensor] = None # For VLMs
deepstack_features: Optional[torch.FloatTensor] = None # For Qwen3VL
image_idx: Optional[torch.IntTensor] = None # For VLMs
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_repetition_penalty_buffer: Optional[torch.Tensor] = None
Expand Down Expand Up @@ -110,6 +111,7 @@ def sampler_forward(
comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
deepstack_features: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
Expand All @@ -135,7 +137,7 @@ def sampler_forward(
Perform the sampling of next tokens on the QAIC device (instead of the host)
and return the next tokens and/or probability distributions.

The vision_embeds and image_idx parameters are optional
The vision_embeds, deepstack_features, and image_idx parameters are optional
and are used only for VLMs when supported by the original forward function.

Args:
Expand Down Expand Up @@ -195,11 +197,15 @@ def sampler_forward(
past_key_values=past_key_values,
comp_ctx_lengths=comp_ctx_lengths,
)
output_keys = ["logits", "vision_embeds", "image_idx", "past_key_values"]
if batch_index is not None:
forward_kwargs["batch_index"] = batch_index
if deepstack_features is not None:
forward_kwargs["deepstack_features"] = deepstack_features
output_keys.insert(2, "deepstack_features")

logits, vision_embeds, image_idx, past_key_values = self.old_forward(**forward_kwargs)
outputs = dict(logits=logits, vision_embeds=vision_embeds, image_idx=image_idx, past_key_values=past_key_values)
result = self.old_forward(**forward_kwargs)
outputs = dict(zip(output_keys, result))
if position_ids.dim() == 3: # For models using m-rope
position_ids = position_ids[0]
else:
Expand Down Expand Up @@ -356,6 +362,7 @@ def sampler_forward(
probs=probs,
next_tokens=next_tokens, # Return sampled next tokens instead of logits
vision_embeds=outputs.get("vision_embeds", None),
deepstack_features=outputs.get("deepstack_features", None),
image_idx=outputs.get("image_idx", None),
past_key_values=outputs.get("past_key_values", None),
past_repetition_penalty_buffer=past_repetition_penalty_buffer,
Expand Down
75 changes: 72 additions & 3 deletions tests/transformers/sampler/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,19 @@
None, # spec_length
True, # is_vlm
),
pytest.param(
"Qwen/Qwen3-VL-2B-Instruct", # model
(
["https://picsum.photos/id/237/536/354"] * 2,
["Can you describe the image in detail."] * 2,
), # images and prompts
128, # prefill_seq_len
4096, # ctx_len
20, # generation_len
2, # full_batch_size
None, # spec_length
True, # is_vlm
),
]


Expand Down Expand Up @@ -166,9 +179,10 @@ def test_sampler_transform(
mxfp6_matmul=True,
)
if is_vlm:
model_w_sampler_qpc_path = model_w_sampler_qpc_path[1]
model_w_sampler_w_guided_decoding_qpc_path = model_w_sampler_w_guided_decoding_qpc_path[1]
model_wo_sampler_qpc_path = model_wo_sampler_qpc_path[1]
lang_qpc_path = "lang_qpc_path"
model_w_sampler_qpc_path = model_w_sampler_qpc_path[lang_qpc_path]
model_w_sampler_w_guided_decoding_qpc_path = model_w_sampler_w_guided_decoding_qpc_path[lang_qpc_path]
model_wo_sampler_qpc_path = model_wo_sampler_qpc_path[lang_qpc_path]

# Init qaic session
model_w_sampler_session = QAICInferenceSession(model_w_sampler_qpc_path)
Expand Down Expand Up @@ -521,6 +535,61 @@ def test_random_sampling(
]
],
}
elif model == "Qwen/Qwen3-VL-2B-Instruct":
golden_texts = {
"w_sampler": "This is a close-up, top-down photograph of an adorable black puppy resting on weathered wooden flooring",
"wo_sampler": "This is a close-up, top-down photograph of a young black puppy, likely a Labrador Retri",
}
golden_ids = {
"w_sampler": [
[
1986,
374,
264,
3265,
5239,
11,
1909,
14875,
10300,
315,
458,
40608,
3691,
41189,
40119,
389,
9104,
291,
22360,
36148,
]
],
"wo_sampler": [
[
1986,
374,
264,
3265,
5239,
11,
1909,
14875,
10300,
315,
264,
3908,
3691,
41189,
11,
4363,
264,
79276,
10392,
461,
]
],
}
for i in range(full_batch_size):
assert (
tokenizer.decode(model_w_sampler_exec_info.generated_ids[i][:generation_len]) == golden_texts["w_sampler"]
Expand Down
Loading