Bug Description
Full-causal diffusion forcing in Cosmos3 currently requires a NATTEN version that does not appear to be publicly available. Are there plans to release the dev version of NATTEN in the near future, or is there a workaround that can be used in the meantime?
The dependency chain is:
- In cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py, video_temporal_causal=True requires joint_attn_implementation="three_way".
- The three_way + video_temporal_causal path calls generate_temporal_causal_natten_metadata() from the attention setup path.
- generate_temporal_causal_natten_metadata() calls generate_multi_dim_varlen_parameters().
- generate_multi_dim_varlen_parameters() hard-gates on NATTEN >= 0.21.9.dev0 via NATTEN_VARLEN_MULTI_DIM_VERSION = "0.21.9.dev0".
- Public NATTEN appears to top out at 0.21.6.
As a result, the documented/config-supported full-causal diffusion forcing mode cannot run with the public Cosmos Framework dependency set.
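The version gate in the chain above can be checked locally before launching an 8-GPU run. The script below is a hypothetical pre-flight helper, not part of Cosmos Framework: it reads the installed `natten` distribution version through Python's `importlib.metadata` and compares it with the `0.21.9.dev0` threshold taken from `NATTEN_VARLEN_MULTI_DIM_VERSION`. The comparison uses `sort -V`, which only approximates PEP 440 ordering, so the framework's own gate may disagree for edge cases such as dev releases versus final releases.

```shell
#!/usr/bin/env bash
# Hypothetical pre-flight check (not part of Cosmos Framework).
# Threshold copied from NATTEN_VARLEN_MULTI_DIM_VERSION as described in the issue.
REQUIRED_NATTEN="0.21.9.dev0"

# version_at_least INSTALLED REQUIRED
# Exits 0 when INSTALLED >= REQUIRED under `sort -V` ordering:
# `sort -C` succeeds only if REQUIRED already sorts at or before INSTALLED.
# Caveat: `sort -V` is not PEP 440; dev/pre-release suffixes may order differently.
version_at_least() {
  printf '%s\n%s\n' "$2" "$1" | sort -V -C
}

# Read the installed distribution version without importing natten itself.
# Empty if python3 or the natten distribution is missing.
installed="$(python3 -c 'import importlib.metadata as m; print(m.version("natten"))' 2>/dev/null)"

if [ -z "$installed" ]; then
  echo "natten is not installed"
elif version_at_least "$installed" "$REQUIRED_NATTEN"; then
  echo "natten $installed satisfies >= $REQUIRED_NATTEN"
else
  echo "natten $installed is older than $REQUIRED_NATTEN; video_temporal_causal=True with three_way will hit the gate"
fi
```

With the latest public release (0.21.6) installed, the script reports that the version is older than the threshold, matching the failure described above.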
Reproduction Steps
Run the Vision SFT recipe with full-causal diffusion forcing enabled via config overrides:
export DATASET_PATH="$PWD/examples/data/bridge-v2-subset-synthetic-captions/sft_dataset_bridge"
export BASE_CHECKPOINT_PATH="$PWD/checkpoints/Cosmos3-Nano"
export WAN_VAE_PATH="$PWD/checkpoints/Wan2.2_VAE.pth"
torchrun --nproc_per_node=8 -m cosmos_framework.scripts.train \
--sft-toml=examples/toml/sft_config/vision_sft_nano.toml \
model.config.causal_training_strategy=diffusion_forcing \
model.config.video_temporal_causal=True \
model.config.joint_attn_implementation=three_way
Reproducibility:
training_step reaches the three_way + video_temporal_causal attention path