[Bug] Cosmos Predict2.5 inference quality mismatch with official codebase — wrong sigma schedule #18
Description
Cosmos Predict2.5 inference produces degraded results compared to the official Cosmos codebase
Summary
When running Cosmos-Predict2.5-2B video2world inference with FastGen using the same model checkpoint, text prompt, input image, resolution, and step count, the generated video quality is noticeably worse than with the official cosmos-predict2.5 inference code. After comparing the two codebases, I identified three misalignments in the sample() method of CosmosPredict2DiT (fastgen/networks/cosmos_predict2/network.py).
Reproduction
```shell
python scripts/inference/video_model_inference.py \
  --config fastgen/configs/experiments/CosmosPredict2/config_sft.py \
  --do_student_sampling False --num_steps 35 --fps 24 \
  --neg_prompt_file scripts/inference/prompts/negative_prompt_cosmos.txt \
  --input_image_file scripts/inference/prompts/source_image_paths.txt \
  --num_conditioning_frames 1 \
  - model.guidance_scale=5.0 model.net.is_video2world=True model.input_shape="[16, 21, 60, 104]"
```

Issue 1: Wrong sigma/timestep schedule
The most impactful difference. The official Cosmos uses a Karras sigma schedule, while FastGen uses the default diffusers flow-shift linear schedule.
Official Cosmos (`FlowUniPCMultistepScheduler.set_timesteps` with `use_kerras_sigma=True`):

```python
sigma_max = 200
sigma_min = 0.01
rho = 7
sigmas = np.arange(num_steps + 1) / num_steps
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + sigmas * (min_inv_rho - max_inv_rho)) ** rho
sigmas = sigmas / (1 + sigmas)
timesteps = sigmas * 1000  # cast to int64 (truncation, not rounding)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
```

FastGen (network.py, line ~1153):
```python
self.sample_scheduler = UniPCMultistepScheduler(
    num_train_timesteps=1000,
    prediction_type="flow_prediction",
    use_flow_sigmas=True,
    flow_shift=shift,
)
self.sample_scheduler.set_timesteps(num_inference_steps=num_steps, device=noise.device)
```

This internally computes:

```python
sigmas = np.linspace(1.0, 1/N, N)
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)  # shift=5.0
```

When I switch the FastGen repo to the Karras sigma schedule, quality improves substantially but is still not as good as the original Cosmos results.
Do you have any idea what other factors could account for the remaining degradation?
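To make the mismatch concrete, here is a minimal standalone comparison of the two schedules (pure NumPy, using shift=5.0 and num_steps=35 from the reproduction command, and the sigma_max/sigma_min/rho constants from the official code quoted above):

```python
import numpy as np

num_steps, shift = 35, 5.0
sigma_max, sigma_min, rho = 200, 0.01, 7

# Official Cosmos: Karras rho-schedule, mapped into [0, 1) via s / (1 + s)
ramp = np.arange(num_steps + 1) / num_steps
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
karras = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
karras = karras / (1 + karras)

# FastGen default: flow-shift linear schedule (diffusers use_flow_sigmas path)
lin = np.linspace(1.0, 1 / num_steps, num_steps)
flow = shift * lin / (1 + (shift - 1) * lin)

print(f"first sigma: karras={karras[0]:.4f}  flow={flow[0]:.4f}")
print(f"last  sigma: karras={karras[-1]:.4f}  flow={flow[-1]:.4f}")
```

With these values the final Karras sigma is ~0.0099 while the flow-shift schedule stops at ~0.128, i.e. FastGen never denoises down into the low-noise region that the official schedule reaches, which plausibly explains much of the quality gap.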
Issue 2: Misleading example command in scripts/inference/video_model_inference.py
In scripts/inference/video_model_inference.py, the example command does not specify --neg_prompt_file, so the code falls back to a different negative prompt file rather than the one intended for the Cosmos model. You might also want to fix this :)
```shell
# Video2World: Cosmos Predict2
PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \
  scripts/inference/video_model_inference.py --do_student_sampling False \
  --input_image_file scripts/inference/prompts/source_image_paths.txt --num_conditioning_frames 1 \
  --config fastgen/configs/experiments/CosmosPredict2/config_sft.py \
  - trainer.seed=1 trainer.ddp=True model.guidance_scale=5.0 model.net.is_video2world=True \
  log_config.name=cosmos_v2w_inference
```

Suggested Fix
In fastgen/networks/cosmos_predict2/network.py, the sample() method should:
- Use the Karras sigma schedule matching the official Cosmos `set_timesteps` with `use_kerras_sigma=True`:
```python
# 1. Initialize internal state
self.sample_scheduler.set_timesteps(num_inference_steps=num_steps, device=noise.device)

# 2. Override with the Karras sigma schedule
import numpy as np
sigma_max, sigma_min, rho = 200, 0.01, 7
sigmas = np.arange(num_steps + 1) / num_steps
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + sigmas * (min_inv_rho - max_inv_rho)) ** rho
sigmas = sigmas / (1 + sigmas)
timesteps_np = sigmas * self.sample_scheduler.config.num_train_timesteps
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sample_scheduler.sigmas = torch.from_numpy(sigmas).to("cpu")
self.sample_scheduler.timesteps = torch.from_numpy(timesteps_np).to(device=noise.device, dtype=torch.int64)
self.sample_scheduler.num_inference_steps = len(timesteps_np)
self.sample_scheduler.model_outputs = [None] * self.sample_scheduler.config.solver_order
self.sample_scheduler.lower_order_nums = 0
self.sample_scheduler.last_sample = None
self.sample_scheduler._step_index = None
self.sample_scheduler._begin_index = None
```

- Align the example commands in the scripts with those in the README.
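The scheduler override above can also be factored into a small standalone helper for sanity-checking the schedule independently of diffusers and torch (the helper name `karras_flow_schedule` is my own; the constants match the official code quoted earlier, including the int64 truncation of timesteps):

```python
import numpy as np

def karras_flow_schedule(num_steps, sigma_max=200.0, sigma_min=0.01, rho=7.0,
                         num_train_timesteps=1000):
    """Karras rho-schedule mapped to flow sigmas in [0, 1), plus int64 timesteps."""
    ramp = np.arange(num_steps + 1) / num_steps
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    sigmas = sigmas / (1 + sigmas)
    # Truncation (not rounding), matching the official int64 cast
    timesteps = (sigmas * num_train_timesteps).astype(np.int64)
    sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
    return sigmas, timesteps

sigmas, timesteps = karras_flow_schedule(35)
assert np.all(np.diff(sigmas) < 0)            # strictly decreasing, ending at 0.0
assert len(sigmas) == 37 and len(timesteps) == 36
assert timesteps.dtype == np.int64
```

These invariants (strictly decreasing sigmas with a trailing 0.0, one more sigma than timesteps, integer timesteps) are what the patched scheduler state should satisfy after the override.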