diff --git a/cosmos_framework/inference/args.py b/cosmos_framework/inference/args.py index 3c234f7..bee239c 100644 --- a/cosmos_framework/inference/args.py +++ b/cosmos_framework/inference/args.py @@ -454,7 +454,7 @@ def _build_vision_data(self, model_config: "OmniMoTModelConfig", sample_meta: Sa if self.vision_path and "://" in self.vision_path: raise ValueError("Must call `download()` before building vision data") - # Reasoner mode treats ``vision_path`` as a PIL image source; resolution/fps/num_frames are unused. + # Reasoner mode treats ``vision_path`` as an image (PIL) or video (mp4) source; resolution/fps/num_frames are unused. if sample_meta.model_mode.is_reasoner: self.condition_frame_indexes_vision = self.condition_frame_indexes_vision or [] self.condition_video_keep = self.condition_video_keep or "first" @@ -609,6 +609,7 @@ class ReasonerDataArgs(ArgsBase): top_p: _ReasonerTopP | None = None repetition_penalty: _ReasonerRepetitionPenalty | None = None presence_penalty: float | None = None + video_fps: pydantic.PositiveFloat | None = None class ReasonerDataOverrides(OverridesBase): @@ -629,6 +630,8 @@ class ReasonerDataOverrides(OverridesBase): """CTRL/HF-style multiplicative repetition penalty (>0). ``1.0`` is identity.""" presence_penalty: float | None = None """Additive presence penalty (any sign). ``0.0`` is identity.""" + video_fps: pydantic.PositiveFloat | None = None + """Frames per second to sample from a video vision_path. 
None -> decoder default (2.0).""" def _build_reasoner_data(self, model_config: "OmniMoTModelConfig", sample_meta: SampleMeta): if not sample_meta.model_mode.is_reasoner: diff --git a/cosmos_framework/inference/args_test.py b/cosmos_framework/inference/args_test.py index 3bf3703..bd78439 100644 --- a/cosmos_framework/inference/args_test.py +++ b/cosmos_framework/inference/args_test.py @@ -15,6 +15,7 @@ ModelMode, OmniSampleOverrides, OmniSetupOverrides, + ReasonerDataOverrides, ) from cosmos_framework.inference.common.config import structure_config @@ -156,3 +157,13 @@ def test_sample_args(tmp_path: Path): assert text2image_args.num_steps == 50 assert text2image_args.guidance == 4.0 assert text2image_args.shift == 3.0 + + +def test_reasoner_video_fps_defaults_none(): + ov = ReasonerDataOverrides() + assert ov.video_fps is None + + +def test_reasoner_video_fps_accepts_positive_float(): + ov = ReasonerDataOverrides(video_fps=2.0) + assert ov.video_fps == 2.0 diff --git a/cosmos_framework/inference/defaults/reasoner/sample_args.json b/cosmos_framework/inference/defaults/reasoner/sample_args.json index cc53991..e7a25ad 100644 --- a/cosmos_framework/inference/defaults/reasoner/sample_args.json +++ b/cosmos_framework/inference/defaults/reasoner/sample_args.json @@ -6,5 +6,6 @@ "top_k": null, "top_p": null, "repetition_penalty": 1.0, - "presence_penalty": 0.0 + "presence_penalty": 0.0, + "video_fps": null } diff --git a/cosmos_framework/inference/inference.py b/cosmos_framework/inference/inference.py index 83ff665..3078e4d 100644 --- a/cosmos_framework/inference/inference.py +++ b/cosmos_framework/inference/inference.py @@ -14,11 +14,15 @@ import cattrs.preconf.json import safetensors.torch import torch +import torchvision.io from PIL import Image +from qwen_vl_utils.vision_process import smart_nframes from torch.utils._pytree import tree_map_only from torch.utils.data import Dataset from typing_extensions import Self +from cosmos_framework.configs.base.defaults.compile 
import CompileConfig +from cosmos_framework.configs.base.defaults.parallelism import ParallelismConfig from cosmos_framework.inference.args import ( ModelMode, NegativeMetadataMode, @@ -26,6 +30,7 @@ OmniSetupArgs, ) from cosmos_framework.inference.common.args import ( + VIDEO_EXTENSIONS, CheckpointType, ConfigFileType, ParallelismArgs, @@ -46,13 +51,11 @@ pil_to_conditioning_frames, resize_pil_image, ) -from cosmos_framework.utils import log -from cosmos_framework.tools.visualize.video import save_img_or_video -from cosmos_framework.configs.base.defaults.compile import CompileConfig -from cosmos_framework.configs.base.defaults.parallelism import ParallelismConfig from cosmos_framework.model.vfm.omni_mot_model import OmniMoTModel -from cosmos_framework.model.vfm.vlm.qwen3_vl.utils import _SYSTEM_PROMPT_IMAGE_EDITING from cosmos_framework.model.vfm.upsampler.prompts import is_upsampled_prompt +from cosmos_framework.model.vfm.vlm.qwen3_vl.utils import _SYSTEM_PROMPT_IMAGE_EDITING +from cosmos_framework.tools.visualize.video import save_img_or_video +from cosmos_framework.utils import log if TYPE_CHECKING: from cosmos_framework.configs.base.defaults.model_config import OmniMoTModelConfig @@ -463,14 +466,44 @@ def _get_prompt_sample_data(sample_args: OmniSampleArgs, model: OmniMoTModel, *, return out +def _decode_reasoner_video(vision_path: str, video_fps: float | None) -> dict[str, Any]: + """Decode a local video file into the frame-list payload the Qwen3-VL processor expects. + + Returns ``{"frames": [PIL.Image, ...], "fps": float}``. Uses the same + ``torchvision.io.read_video`` decode the rest of the inference path relies on + (no ``decord`` dependency), then uniformly samples frames toward ``video_fps`` + (default 2.0) via Qwen's ``smart_nframes``. 
The repo ``Qwen3VLProcessor`` runs + with ``do_sample_frames=False``, so it consumes this pre-sampled frame list + as-is and handles its own per-frame resize.""" + frames, _, info = torchvision.io.read_video(str(vision_path), pts_unit="sec") # [T,H,W,C] uint8 + total_frames = int(frames.shape[0]) + if total_frames == 0: + raise ValueError(f"Decoded zero frames from reasoner video: {vision_path}") + src_fps = float(info.get("video_fps") or 0.0) or 1.0 + target_fps = video_fps if video_fps is not None else 2.0 + nframes = smart_nframes({"fps": target_fps}, total_frames=total_frames, video_fps=src_fps) + idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist() + pil_frames = [Image.fromarray(frames[i].numpy()) for i in idx] + sample_fps = nframes / total_frames * src_fps + return {"frames": pil_frames, "fps": sample_fps} + + def _get_reasoner_sample_data(sample_args: OmniSampleArgs, model: OmniMoTModel) -> dict[str, Any]: - """Sample batch for reasoner text generation: prompt + optional conditioning image.""" + """Sample batch for reasoner text generation: prompt + optional conditioning image or video.""" image: Image.Image | None = None + video: dict[str, Any] | None = None if sample_args.vision_path is not None: - image = Image.open(sample_args.vision_path).convert("RGB") + if Path(sample_args.vision_path).suffix.lower() in VIDEO_EXTENSIONS: + video = _decode_reasoner_video(str(sample_args.vision_path), sample_args.video_fps) + else: + image = Image.open(sample_args.vision_path).convert("RGB") + # Both keys are emitted for every sample (``None`` when absent) so the batch + # builder can positionally align them and the three-way homogeneity check in + # ``_generate_reasoner_batch`` reliably detects an image/video/text mix. 
return { model.input_caption_key: [sample_args.prompt], "reasoner_images": [image], + "reasoner_videos": [video], } @@ -1655,13 +1688,28 @@ def _generate_reasoner_batch( prompts: list[str] = data_batch[self.model.input_caption_key] raw_images: list[Image.Image | None] = data_batch["reasoner_images"] - n_set = sum(img is not None for img in raw_images) - if 0 < n_set < len(raw_images): + raw_videos: list[dict[str, Any] | None] | None = data_batch.get("reasoner_videos") + + n_img = sum(img is not None for img in raw_images) + n_vid = sum(v is not None for v in (raw_videos or [])) + if n_img and n_vid: + raise ValueError( + "Reasoner batch mixes image- and video-conditioned samples. Split into separate batches." + ) + if 0 < n_img < len(raw_images): raise ValueError( "Reasoner batch mixes image-conditioned and text-only samples " - f"({n_set}/{len(raw_images)} have vision_path). Split into separate batches." + f"({n_img}/{len(raw_images)} have an image vision_path). Split into separate batches." + ) + if raw_videos is not None and 0 < n_vid < len(raw_videos): + raise ValueError( + "Reasoner batch mixes video-conditioned and text-only samples " + f"({n_vid}/{len(raw_videos)} have a video vision_path). Split into separate batches." 
) - images: list[Image.Image] | None = cast(list[Image.Image], raw_images) if n_set == len(raw_images) else None + images: list[Image.Image] | None = cast(list[Image.Image], raw_images) if n_img == len(raw_images) else None + videos: list[dict[str, Any]] | None = ( + cast(list[dict[str, Any]], raw_videos) if raw_videos is not None and n_vid == len(raw_videos) else None + ) try: with sync_distributed_errors(): @@ -1686,6 +1734,7 @@ def _generate_reasoner_batch( prompts, max_new_tokens=sample_args_list[0].max_new_tokens, images=images, + videos=videos, do_sample=sample_args_list[0].do_sample, temperature=sample_args_list[0].temperature, top_k=sample_args_list[0].top_k, diff --git a/cosmos_framework/inference/inference_test.py b/cosmos_framework/inference/inference_test.py index d1f501b..b8ec16b 100644 --- a/cosmos_framework/inference/inference_test.py +++ b/cosmos_framework/inference/inference_test.py @@ -169,6 +169,7 @@ def _make_reasoner_sample_args(**overrides: Any) -> SimpleNamespace: model_mode=ModelMode.REASONER, prompt="Describe a robotic arm.", vision_path=None, + video_fps=None, max_new_tokens=8, do_sample=False, temperature=1.0, @@ -189,7 +190,11 @@ def test_get_sample_data_reasoner_text_only() -> None: out = inference.get_sample_data(sample_args, model, device="cpu") - assert out == {"caption": ["Describe a robotic arm."], "reasoner_images": [None]} + assert out == { + "caption": ["Describe a robotic arm."], + "reasoner_images": [None], + "reasoner_videos": [None], + } @pytest.mark.L0 @@ -205,13 +210,35 @@ def test_get_sample_data_reasoner_with_image(tmp_path: Path) -> None: out = inference.get_sample_data(sample_args, model, device="cpu") - assert list(out) == ["caption", "reasoner_images"] + assert list(out) == ["caption", "reasoner_images", "reasoner_videos"] assert out["caption"] == ["Describe a robotic arm."] + assert out["reasoner_videos"] == [None] assert len(out["reasoner_images"]) == 1 assert out["reasoner_images"][0].size == (8, 8) assert 
out["reasoner_images"][0].mode == "RGB" +@pytest.mark.L0 +def test_get_sample_data_reasoner_with_video(monkeypatch: pytest.MonkeyPatch) -> None: + """A video ``vision_path`` routes through ``_decode_reasoner_video`` into ``reasoner_videos``. + + The decoder is monkeypatched (real decode needs torchvision + an actual clip); + this asserts the routing/contract, not the decode itself.""" + from cosmos_framework.inference import inference + + decoded = {"frames": ["F0", "F1"], "fps": 2.0} + monkeypatch.setattr(inference, "_decode_reasoner_video", lambda path, fps: decoded) + model = SimpleNamespace(input_caption_key="caption") + sample_args = _make_reasoner_sample_args(vision_path="/tmp/clip.mp4", video_fps=2.0) + + out = inference.get_sample_data(sample_args, model, device="cpu") + + assert out["caption"] == ["Describe a robotic arm."] + assert out["reasoner_videos"] == [decoded] + assert out["reasoner_images"] == [None] + assert "video_sampling_kwargs" not in out + + @pytest.mark.L0 def test_reasoner_defaults_json_round_trip() -> None: import json as _json @@ -349,3 +376,5 @@ def test_reasoner_defaults_validate_against_overrides() -> None: filtered = {k: v for k, v in defaults.items() if k in OmniSampleOverrides.model_fields} assert set(defaults) - set(filtered) == set(), f"defaults has unknown fields: {set(defaults) - set(filtered)}" OmniSampleOverrides.model_validate(filtered) + + diff --git a/cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py b/cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py index 03f0c3f..c643629 100644 --- a/cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py +++ b/cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py @@ -276,6 +276,8 @@ def generate_reasoner_text( *, pixel_values: torch.Tensor | None = None, image_grid_thw: torch.Tensor | None = None, + pixel_values_videos: torch.Tensor | None = None, + video_grid_thw: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, eos_token_id: int | list[int] | None = None, 
pad_token_id: int | None = None, @@ -296,9 +298,11 @@ def generate_reasoner_text( prompts through this single entry point: pass ``pixel_values`` + ``image_grid_thw`` (and optionally ``attention_mask``) for image-conditioned prefill via the Qwen3-VL - visual encoder, or omit them for text-only prefill. Uses the - und-pathway weights (those WITHOUT the ``_moe_gen`` suffix) plus - ``embed_tokens`` / ``norm`` / ``lm_head``; the generation pathway + visual encoder, or omit them for text-only prefill. Video + conditioning is also supported via ``pixel_values_videos`` + + ``video_grid_thw``; the image and video pairs are mutually exclusive. + Uses the und-pathway weights (those WITHOUT the ``_moe_gen`` suffix) + plus ``embed_tokens`` / ``norm`` / ``lm_head``; the generation pathway and all VFM-level multimodal embedders / heads (``vae2llm``, ``llm2vae``, ``sound2llm``, etc.) are bypassed. @@ -327,6 +331,8 @@ def generate_reasoner_text( max_new_tokens=max_new_tokens, pixel_values=pixel_values, image_grid_thw=image_grid_thw, + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, attention_mask=attention_mask, eos_token_id=eos_token_id, pad_token_id=pad_token_id, diff --git a/cosmos_framework/model/vfm/mot/unified_mot.py b/cosmos_framework/model/vfm/mot/unified_mot.py index 4908e7a..673d0a0 100644 --- a/cosmos_framework/model/vfm/mot/unified_mot.py +++ b/cosmos_framework/model/vfm/mot/unified_mot.py @@ -1494,6 +1494,8 @@ def _impl_generate_reasoner_text( *, pixel_values: torch.Tensor | None = None, image_grid_thw: torch.Tensor | None = None, + pixel_values_videos: torch.Tensor | None = None, + video_grid_thw: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, eos_token_id: int | list[int] | None = None, pad_token_id: int | None = None, @@ -1550,10 +1552,9 @@ def _impl_generate_reasoner_text( ``Qwen3VLProcessor`` emits — pass it through unchanged. Moved to the prompt's device internally. 
``None`` (default) means text-only prompt; in that case the multimodal prefill - path is skipped entirely. Videos are *not* supported here — - this function has no ``pixel_values_videos`` / ``video_grid_thw`` - parameters; for I2V conditioning, frames must be passed as - images. + path is skipped entirely. For video conditioning, pass ``pixel_values_videos`` + + ``video_grid_thw`` instead (mutually exclusive with the image + pair). image_grid_thw: Optional ``[num_images, 3]`` long tensor giving ``(t, h, w)`` — the temporal / height / width feature-grid size per image as produced by ``Qwen3VLProcessor`` (``t`` is @@ -1643,11 +1644,15 @@ def _impl_generate_reasoner_text( if (pixel_values is None) != (image_grid_thw is None): raise ValueError("pixel_values and image_grid_thw must be provided together.") + if (pixel_values_videos is None) != (video_grid_thw is None): + raise ValueError("pixel_values_videos and video_grid_thw must be provided together.") + if pixel_values is not None and pixel_values_videos is not None: + raise ValueError("Reasoner conditions on one medium at a time: pass image OR video, not both.") _prefill_start = time.time() mrope_position_deltas: torch.Tensor | None = None - if pixel_values is None: + if pixel_values is None and pixel_values_videos is None: hidden = model.reasoner_forward(input_ids, cache=cache) # [B,T_prompt,hidden_size] else: if not hasattr(causal_lm, "visual"): @@ -1663,6 +1668,8 @@ def _impl_generate_reasoner_text( input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw, + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, attention_mask=attention_mask, ) hidden = model.reasoner_forward( @@ -1936,6 +1943,8 @@ def generate_reasoner_text( *, pixel_values: torch.Tensor | None = None, image_grid_thw: torch.Tensor | None = None, + pixel_values_videos: torch.Tensor | None = None, + video_grid_thw: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, eos_token_id: int | 
list[int] | None = None, pad_token_id: int | None = None, @@ -1956,6 +1965,8 @@ def generate_reasoner_text( the Qwen3-VL visual encoder; omit them for text-only prefill. The two arguments are mutually required: passing exactly one raises ``ValueError`` inside :func:`_impl_generate_reasoner_text`. + Video conditioning is also supported via ``pixel_values_videos`` + + ``video_grid_thw``; the image and video pairs are mutually exclusive. Uses the und-pathway weights (those WITHOUT the ``_moe_gen`` suffix) plus the model-level ``embed_tokens`` / ``norm`` / ``lm_head``, and — @@ -1970,6 +1981,8 @@ def generate_reasoner_text( max_new_tokens=max_new_tokens, pixel_values=pixel_values, image_grid_thw=image_grid_thw, + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, attention_mask=attention_mask, eos_token_id=eos_token_id, pad_token_id=pad_token_id, @@ -2064,6 +2077,8 @@ def generate_reasoner_text( *, pixel_values: torch.Tensor | None = None, image_grid_thw: torch.Tensor | None = None, + pixel_values_videos: torch.Tensor | None = None, + video_grid_thw: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, eos_token_id: int | list[int] | None = None, pad_token_id: int | None = None, @@ -2084,6 +2099,8 @@ def generate_reasoner_text( the Qwen3-VL visual encoder; omit them for text-only prefill. The two arguments are mutually required: passing exactly one raises ``ValueError`` inside :func:`_impl_generate_reasoner_text`. + Video conditioning is also supported via ``pixel_values_videos`` + + ``video_grid_thw``; the image and video pairs are mutually exclusive. 
Uses the und-pathway weights (those WITHOUT the ``_moe_gen`` suffix) plus the model-level ``embed_tokens`` / ``norm`` / ``lm_head``, and — @@ -2099,6 +2116,8 @@ def generate_reasoner_text( max_new_tokens=max_new_tokens, pixel_values=pixel_values, image_grid_thw=image_grid_thw, + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, attention_mask=attention_mask, eos_token_id=eos_token_id, pad_token_id=pad_token_id, diff --git a/cosmos_framework/model/vfm/omni_mot_model.py b/cosmos_framework/model/vfm/omni_mot_model.py index 9f71a00..429cbfc 100644 --- a/cosmos_framework/model/vfm/omni_mot_model.py +++ b/cosmos_framework/model/vfm/omni_mot_model.py @@ -3763,6 +3763,7 @@ def generate_reasoner_text( max_new_tokens: int, *, images: list[Any] | None = None, + videos: list[Any] | None = None, prompt_builder: Callable[[str], list[dict[str, Any]]] | None = None, do_sample: bool = False, temperature: float | None = 1.0, @@ -3779,8 +3780,10 @@ def generate_reasoner_text( (or wraps the prompt as a single user message when no callback is given), (b) tokenizes it — text-only via :meth:`tokenize_text`, or multimodal via ``self.vlm_processor.apply_chat_template`` when - ``images`` is supplied (which lowers the chat into ``input_ids``, - ``attention_mask``, ``pixel_values``, and ``image_grid_thw``), (c) + ``images`` or ``videos`` is supplied (the image path lowers the chat + into ``input_ids``, ``attention_mask``, ``pixel_values``, and + ``image_grid_thw``; the video path yields ``pixel_values_videos`` and + ``video_grid_thw`` instead), (c) runs the reasoner-only AR decode loop through ``self.net.generate_reasoner_text`` (the lower-level token-driven pass-through that delegates to ``unified_mot._impl_generate_reasoner_text``), @@ -3835,6 +3838,14 @@ def generate_reasoner_text( ``processor.apply_chat_template``, so any input it accepts works (file path ``str``, ``PIL.Image.Image``, ``np.ndarray``, or a CHW / HWC tensor). 
+ videos: Optional per-prompt conditioning videos (mutually + exclusive with ``images``). Each entry must be a + ``{"frames": [...PIL...], "fps": float}`` payload + (pre-decoded by the caller, e.g. via + ``_decode_reasoner_video``). The frames list and fps are + forwarded into the ``{"type": "video", "video": frames, + "fps": fps}`` chat block so the processor produces + ``pixel_values_videos`` / ``video_grid_thw``. prompt_builder: Optional callback that maps a raw prompt string to a chat-style messages list (e.g. :func:`projects.cosmos3.vfm.upsampler.prompts.build_messages` @@ -3895,32 +3906,40 @@ def generate_reasoner_text( Raises: ValueError: If ``images`` length does not match ``inputs`` - length. - RuntimeError: If ``images`` is supplied but the live VLM - processor does not implement ``apply_chat_template`` - (i.e., the VLM is configured as text-only). + length, or if ``videos`` length does not match ``inputs`` + length. Also raised if both ``images`` and ``videos`` are + supplied simultaneously (only one medium is allowed per + call). + RuntimeError: If ``images`` or ``videos`` is supplied but the + live VLM processor does not implement + ``apply_chat_template`` (i.e., the VLM is configured as + text-only). """ # Decide whether the multimodal flow is in play, and validate the # image-list contract here so the failure happens before any # decoding work — far easier to debug than a downstream # ``apply_chat_template`` error. 
- use_multimodal = images is not None + if images is not None and videos is not None: + raise ValueError("generate_reasoner_text conditions on one medium at a time: pass `images` OR `videos`, not both.") + use_image = images is not None + use_video = videos is not None + use_multimodal = use_image or use_video + media = images if use_image else videos if use_multimodal: - assert images is not None # narrowed by `use_multimodal` - if len(images) != len(inputs): + assert media is not None # narrowed by `use_multimodal` + if len(media) != len(inputs): raise ValueError( - f"generate_reasoner_text: `images` length ({len(images)}) " + f"generate_reasoner_text: media length ({len(media)}) " f"must equal `inputs` length ({len(inputs)}) for the " - "image-conditioned flow." + "vision-conditioned flow." ) if not callable(getattr(self.vlm_processor, "apply_chat_template", None)): raise RuntimeError( - "generate_reasoner_text(images=...) requires a multimodal " + "generate_reasoner_text(images=/videos=...) requires a multimodal " "VLM processor (e.g. Qwen3VLProcessor) but the live processor " f"{type(self.vlm_processor).__name__!r} does not implement " "apply_chat_template — the live VLM is configured as text-only." ) - # Resolve EOS / pad ids internally so callers don't have to know # about VLM-specific id wiring. EOS comes from the cached VLM # special-tokens dict (set in ``set_up_tokenizers``); pad mirrors @@ -3957,22 +3976,18 @@ def generate_reasoner_text( messages = [{"role": "user", "content": prompt}] if use_multimodal: - assert images is not None # narrowed by `use_multimodal` - # Replace the LAST user message's content with a Qwen3-VL - # multimodal block (image + text). Earlier messages - # (system, prior turns) are kept verbatim so any chat - # scaffolding the callback added still governs the - # assistant response. 
+ assert media is not None # narrowed by `use_multimodal` last_user = messages[-1] last_text = last_user["content"] if isinstance(last_user.get("content"), str) else "" + if use_video: + media_item: dict[str, Any] = {"type": "video", "video": media[idx]["frames"], "fps": media[idx]["fps"]} + else: + media_item = {"type": "image", "image": media[idx]} multimodal_messages = list(messages[:-1]) multimodal_messages.append( { "role": "user", - "content": [ - {"type": "image", "image": images[idx]}, - {"type": "text", "text": last_text}, - ], + "content": [media_item, {"type": "text", "text": last_text}], } ) processor_inputs = self.vlm_processor.apply_chat_template( @@ -3981,31 +3996,48 @@ def generate_reasoner_text( add_generation_prompt=True, return_tensors="pt", ) - # ``Qwen3VLProcessor.apply_chat_template`` strips the - # leading batch dim from ``input_ids`` / ``attention_mask`` - # (see its inline comment); restore it so the inner - # token-level call sees ``[B=1, T_prompt]``. inner_input_ids = processor_inputs["input_ids"].to(device).unsqueeze(0) inner_attention_mask = processor_inputs["attention_mask"].to(device).unsqueeze(0) - inner_pixel_values = processor_inputs["pixel_values"].to(device) # [N_patches,C,H,W] - inner_image_grid_thw = processor_inputs["image_grid_thw"].to(device) # [num_images,3] - out_ids = self.net.generate_reasoner_text( - input_ids=inner_input_ids, - max_new_tokens=max_new_tokens, - pixel_values=inner_pixel_values, - image_grid_thw=inner_image_grid_thw, - attention_mask=inner_attention_mask, - eos_token_id=eos_id, - pad_token_id=pad_id, - do_sample=do_sample, - temperature=temperature if temperature is not None else 1.0, - top_k=top_k, - top_p=top_p, - repetition_penalty=repetition_penalty, - presence_penalty=presence_penalty, - seed=seed, - return_only_new_tokens=True, - ) + if use_video: + inner_pixel_values_videos = processor_inputs["pixel_values_videos"].to(device) + inner_video_grid_thw = processor_inputs["video_grid_thw"].to(device) + 
out_ids = self.net.generate_reasoner_text( + input_ids=inner_input_ids, + max_new_tokens=max_new_tokens, + pixel_values_videos=inner_pixel_values_videos, + video_grid_thw=inner_video_grid_thw, + attention_mask=inner_attention_mask, + eos_token_id=eos_id, + pad_token_id=pad_id, + do_sample=do_sample, + temperature=temperature if temperature is not None else 1.0, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + presence_penalty=presence_penalty, + seed=seed, + return_only_new_tokens=True, + ) + else: + inner_pixel_values = processor_inputs["pixel_values"].to(device) # [N_patches,C,H,W] + inner_image_grid_thw = processor_inputs["image_grid_thw"].to(device) # [num_images,3] + out_ids = self.net.generate_reasoner_text( + input_ids=inner_input_ids, + max_new_tokens=max_new_tokens, + pixel_values=inner_pixel_values, + image_grid_thw=inner_image_grid_thw, + attention_mask=inner_attention_mask, + eos_token_id=eos_id, + pad_token_id=pad_id, + do_sample=do_sample, + temperature=temperature if temperature is not None else 1.0, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + presence_penalty=presence_penalty, + seed=seed, + return_only_new_tokens=True, + ) else: # Text-only path. 
Pull the system prompt (if any) and # the last user message text out of the messages list, diff --git a/cosmos_framework/model/vfm/vlm/qwen3_vl/utils.py b/cosmos_framework/model/vfm/vlm/qwen3_vl/utils.py index 5a8ecff..4768b08 100644 --- a/cosmos_framework/model/vfm/vlm/qwen3_vl/utils.py +++ b/cosmos_framework/model/vfm/vlm/qwen3_vl/utils.py @@ -497,8 +497,10 @@ def get_placeholder_mask( def prepare_multimodal_reasoner_inputs( causal_lm: Any, input_ids: torch.Tensor, # [B,T_prompt] - pixel_values: torch.Tensor, # [N_patches,C,H,W] - image_grid_thw: torch.Tensor, # [num_images,3] + pixel_values: torch.Tensor | None = None, # [N_patches,C,H,W] + image_grid_thw: torch.Tensor | None = None, # [num_images,3] + pixel_values_videos: torch.Tensor | None = None, # [N_patches,C,H,W] + video_grid_thw: torch.Tensor | None = None, # [num_videos,3] attention_mask: Optional[torch.Tensor] = None, ) -> tuple[ torch.Tensor, # inputs_embeds [B,T_prompt,hidden_size] @@ -525,11 +527,11 @@ def prepare_multimodal_reasoner_inputs( ``*TextModel.reasoner_forward`` instead of HF's full ``self.language_model(...)`` forward, so HF's ``past_key_values`` / ``cache_position`` lifecycle is replaced by - the AR loop's :class:`ReasonerKVCache` lifecycle. Videos and - dual image+video paths are not supported here; only - ``image_grid_thw`` is consumed — matching the public - ``generate_reasoner_text`` API, which has no - ``pixel_values_videos`` / ``video_grid_thw`` parameters. + the AR loop's :class:`ReasonerKVCache` lifecycle. Either the + image pair (``pixel_values`` + ``image_grid_thw``) or the + video pair (``pixel_values_videos`` + ``video_grid_thw``) is consumed — + not both. The video recipe mirrors the image recipe but routes through + the video placeholder mask and ``video_grid_thw`` rope index. 
Validation: ``get_placeholder_mask`` raises ``ValueError`` if the number of image placeholder tokens in ``input_ids`` does not match @@ -574,21 +576,36 @@ def prepare_multimodal_reasoner_inputs( mrope_position_deltas: Per-sample rope delta used by the caller to extend positions during decode. """ + is_video = pixel_values_videos is not None inputs_embeds = causal_lm.model.embed_tokens(input_ids).clone() # [B,T_prompt,hidden_size] - pixel_values = pixel_values.to(device=inputs_embeds.device) - image_grid_thw = image_grid_thw.to(device=inputs_embeds.device) - - image_embeds, deepstack_visual_embeds = get_image_features(causal_lm, pixel_values, image_grid_thw) - image_embeds = torch.cat(image_embeds, dim=0).to(device=inputs_embeds.device, dtype=inputs_embeds.dtype) - image_mask, _video_mask = get_placeholder_mask( - causal_lm, - input_ids, - inputs_embeds=inputs_embeds, - image_features=image_embeds, - ) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # [B,T_prompt,hidden_size] - visual_pos_masks = image_mask[..., 0] # [B,T_prompt] + if is_video: + pixel_values_videos = pixel_values_videos.to(device=inputs_embeds.device) + video_grid_thw = video_grid_thw.to(device=inputs_embeds.device) + # get_video_features == get_image_features (same visual tower); reuse the free helper. 
+ video_embeds, deepstack_visual_embeds = get_image_features(causal_lm, pixel_values_videos, video_grid_thw) + video_embeds = torch.cat(video_embeds, dim=0).to(device=inputs_embeds.device, dtype=inputs_embeds.dtype) + _image_mask, video_mask = get_placeholder_mask( + causal_lm, + input_ids, + inputs_embeds=inputs_embeds, + video_features=video_embeds, + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # [B,T_prompt,hidden_size] + visual_pos_masks = video_mask[..., 0] # [B,T_prompt] + else: + pixel_values = pixel_values.to(device=inputs_embeds.device) + image_grid_thw = image_grid_thw.to(device=inputs_embeds.device) + image_embeds, deepstack_visual_embeds = get_image_features(causal_lm, pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0).to(device=inputs_embeds.device, dtype=inputs_embeds.dtype) + image_mask, _video_mask = get_placeholder_mask( + causal_lm, + input_ids, + inputs_embeds=inputs_embeds, + image_features=image_embeds, + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # [B,T_prompt,hidden_size] + visual_pos_masks = image_mask[..., 0] # [B,T_prompt] deepstack_visual_embeds = [ embed.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype) for embed in deepstack_visual_embeds @@ -597,7 +614,8 @@ def prepare_multimodal_reasoner_inputs( position_ids, mrope_position_deltas = get_rope_index( causal_lm, input_ids=input_ids, - image_grid_thw=image_grid_thw, + image_grid_thw=None if is_video else image_grid_thw, + video_grid_thw=video_grid_thw if is_video else None, attention_mask=attention_mask, ) diff --git a/docs/inference.md b/docs/inference.md index 1580147..aa46b75 100644 --- a/docs/inference.md +++ b/docs/inference.md @@ -19,6 +19,7 @@ ______________________________________________________________________ - [Sample Arguments](#sample-arguments) - [Text](#text) - [Vision (Image/Video)](#vision-imagevideo) + - [Reasoner](#reasoner) - [Action](#action) - [Custom 
Defaults](#custom-defaults) - [Guardrails](#guardrails) @@ -196,6 +197,14 @@ Generation arguments: Outputs `vision.jpg` or `vision.mp4` depending on `num_frames`. +### Reasoner + +For `model_mode=reasoner`, `vision_path` may point to an **image** (`.jpg`/`.png`/…) or a **video** (`.mp4`). A video is decoded with `torchvision.io.read_video`, uniformly subsampled toward `video_fps`, and the sampled frames are passed to the Qwen3-VL processor. + +- `video_fps`: target frames per second to sample from the video; must be positive. Defaults to 2.0 when unset (`null`). + +Example: [`inputs/reasoner/reasoner_video.json`](../inputs/reasoner/reasoner_video.json). + ### Action Common arguments: diff --git a/inputs/reasoner/reasoner_video.json b/inputs/reasoner/reasoner_video.json new file mode 100644 index 0000000..1c68308 --- /dev/null +++ b/inputs/reasoner/reasoner_video.json @@ -0,0 +1,5 @@ +{ + "model_mode": "reasoner", + "prompt": "Describe what is happening in this video in one sentence.", + "vision_path": "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/2b17a2413bd86b2cf9b03823637108851e4ddf2d/inputs/vision/robot_pouring.mp4" +}