[Bug] Cosmos Predict2.5 inference quality mismatch with official codebase — wrong sigma schedule #18

@csy2077

Description
Cosmos Predict2 inference produces degraded results compared to official Cosmos codebase

Summary

When running Cosmos-Predict2.5-2B video2world inference with FastGen using the same model checkpoint, text prompt, input image, resolution, and step count, the generated video quality is noticeably worse than the official cosmos-predict2.5 inference code. After comparing the two codebases, I identified three misalignments in the sample() method of CosmosPredict2DiT (fastgen/networks/cosmos_predict2/network.py).

Reproduction

python scripts/inference/video_model_inference.py \
    --config fastgen/configs/experiments/CosmosPredict2/config_sft.py \
    --do_student_sampling False --num_steps 35 --fps 24 \
    --neg_prompt_file scripts/inference/prompts/negative_prompt_cosmos.txt \
    --input_image_file scripts/inference/prompts/source_image_paths.txt \
    --num_conditioning_frames 1 \
    - model.guidance_scale=5.0 model.net.is_video2world=True model.input_shape="[16, 21, 60, 104]"

Issue 1: Wrong sigma/timestep schedule

This is the most impactful difference. The official Cosmos code uses a Karras sigma schedule, while FastGen uses the diffusers default flow-shift linear schedule.

Official Cosmos (FlowUniPCMultistepScheduler.set_timesteps with use_kerras_sigma=True):

sigma_max = 200
sigma_min = 0.01
rho = 7
sigmas = np.arange(num_steps + 1) / num_steps
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + sigmas * (min_inv_rho - max_inv_rho)) ** rho
sigmas = sigmas / (1 + sigmas)
timesteps = sigmas * 1000  # cast to int64 (truncation, not rounding)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)

FastGen (network.py line ~1153):

self.sample_scheduler = UniPCMultistepScheduler(
    num_train_timesteps=1000,
    prediction_type="flow_prediction",
    use_flow_sigmas=True,
    flow_shift=shift,
)
self.sample_scheduler.set_timesteps(num_inference_steps=num_steps, device=noise.device)

This internally computes:

sigmas = np.linspace(1.0, 1/N, N)
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)  # shift=5.0
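To make the mismatch concrete, here is a minimal NumPy sketch computing both schedules side by side (the helper names `karras_sigmas` and `flow_shift_sigmas` are mine, not from either codebase):

```python
import numpy as np

def karras_sigmas(num_steps, sigma_max=200.0, sigma_min=0.01, rho=7.0):
    # Karras rho-schedule mapped into flow sigmas, as in the official Cosmos code above
    ramp = np.arange(num_steps + 1) / num_steps
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return sigmas / (1 + sigmas)

def flow_shift_sigmas(num_steps, shift=5.0):
    # Shifted linear flow sigmas, as computed internally with use_flow_sigmas=True
    sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
    return shift * sigmas / (1 + (shift - 1) * sigmas)

karras = karras_sigmas(35)
flow = flow_shift_sigmas(35)
print(karras[0], karras[-1])  # ~0.995 (= 200/201) down to ~0.0099
print(flow[0], flow[-1])      # 1.0 down to ~0.128
```

The two schedules start and end at very different noise levels and distribute steps very differently in between, which plausibly explains most of the quality gap.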

When I switch the fastgen repo to the Karras sigma schedule, output quality improves substantially but still does not match the official Cosmos results.

Do you have any idea what other factors could cause the remaining degradation?

Issue 2: Misleading example command in scripts/inference/video_model_inference.py

The example command in scripts/inference/video_model_inference.py does not pass --neg_prompt_file, so the code falls back to a different negative prompt file instead of the one intended for the Cosmos model. You might want to fix this as well :)

# Video2World: Cosmos Predict2
PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \
    scripts/inference/video_model_inference.py --do_student_sampling False \
    --input_image_file scripts/inference/prompts/source_image_paths.txt --num_conditioning_frames 1 \
    --config fastgen/configs/experiments/CosmosPredict2/config_sft.py \
    - trainer.seed=1 trainer.ddp=True model.guidance_scale=5.0 model.net.is_video2world=True \
    log_config.name=cosmos_v2w_inference

Suggested Fix

In fastgen/networks/cosmos_predict2/network.py, the sample() method should:

  1. Use the Karras sigma schedule matching the official Cosmos set_timesteps with use_kerras_sigma=True:
# Initialize internal state
self.sample_scheduler.set_timesteps(num_inference_steps=num_steps, device=noise.device)

# Override with the Karras sigma schedule
import numpy as np
import torch

sigma_max, sigma_min, rho = 200, 0.01, 7
sigmas = np.arange(num_steps + 1) / num_steps
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + sigmas * (min_inv_rho - max_inv_rho)) ** rho
sigmas = sigmas / (1 + sigmas)
# Compute the integer timesteps before appending the terminal 0.0 sigma
timesteps_np = sigmas * self.sample_scheduler.config.num_train_timesteps
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sample_scheduler.sigmas = torch.from_numpy(sigmas).to("cpu")
self.sample_scheduler.timesteps = torch.from_numpy(timesteps_np).to(device=noise.device, dtype=torch.int64)
self.sample_scheduler.num_inference_steps = len(timesteps_np)

# Reset the UniPC solver state so the override takes effect cleanly
self.sample_scheduler.model_outputs = [None] * self.sample_scheduler.config.solver_order
self.sample_scheduler.lower_order_nums = 0
self.sample_scheduler.last_sample = None
self.sample_scheduler._step_index = None
self.sample_scheduler._begin_index = None
  2. Align the example commands in the scripts with those in the README.
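The overridden schedule can be sanity-checked without torch or diffusers; a pure-NumPy sketch, assuming num_steps=35 and num_train_timesteps=1000 as in the reproduction command:

```python
import numpy as np

num_steps, num_train_timesteps = 35, 1000
sigma_max, sigma_min, rho = 200.0, 0.01, 7.0

ramp = np.arange(num_steps + 1) / num_steps
sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
sigmas = sigmas / (1 + sigmas)
# The int64 cast truncates rather than rounds, matching the official Cosmos behavior
timesteps = (sigmas * num_train_timesteps).astype(np.int64)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)

assert sigmas.shape == (num_steps + 2,) and sigmas[-1] == 0.0
assert np.all(np.diff(sigmas) < 0)  # strictly decreasing down to the terminal 0
assert timesteps[0] == 995          # int(1000 * 200 / 201), truncated not rounded
```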
