Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
2a6b1d1
unused
kschmid23 Mar 9, 2026
64534f2
unused
kschmid23 Mar 9, 2026
6b5ac71
down
kschmid23 Mar 9, 2026
22969c5
down
kschmid23 Mar 9, 2026
df4c8ba
stats
kschmid23 Mar 9, 2026
c0146b0
dowmloads models
kschmid23 Mar 9, 2026
d59939e
dowmloads models
kschmid23 Mar 9, 2026
c4ab0d8
down
kschmid23 Mar 9, 2026
b97d6b2
down
kschmid23 Mar 9, 2026
ff1d447
stats
kschmid23 Mar 9, 2026
f7125ae
stats
kschmid23 Mar 9, 2026
99999d5
stats
kschmid23 Mar 9, 2026
760ca16
stats
kschmid23 Mar 9, 2026
f98b6ad
stats
kschmid23 Mar 9, 2026
2399567
stats
kschmid23 Mar 9, 2026
f4a2c7a
down
kschmid23 Mar 9, 2026
4c4e504
down
kschmid23 Mar 10, 2026
a07c116
down
kschmid23 Mar 10, 2026
b455542
dist
kschmid23 Mar 10, 2026
8f080db
dist
kschmid23 Mar 10, 2026
bb997f0
dist
kschmid23 Mar 10, 2026
8b7348f
dist
kschmid23 Mar 10, 2026
4f01e0b
Fix MemoryError: use load_file instead of from_pretrained, keep model…
kschmid23 Mar 10, 2026
cc191a5
fix: use strict=False in load_state_dict for YUME custom params
kschmid23 Mar 10, 2026
7242121
fix: use os.path.join for Windows-compatible path
kschmid23 Mar 10, 2026
2498896
loading
kschmid23 Mar 10, 2026
7008def
default
kschmid23 Mar 10, 2026
5f9e46f
fix: set default device to cpu before InternVL load to avoid get_defa…
kschmid23 Mar 11, 2026
1d3ca30
fix: load safetensors checkpoint directly to device, skip CPU buffer
kschmid23 Mar 11, 2026
be8b7f9
fix: remove low_cpu_mem_usage from InternVL load to avoid meta tensor…
kschmid23 Mar 11, 2026
26631ef
works
kschmid23 Mar 12, 2026
cbb1590
wan
kschmid23 Mar 12, 2026
75ca7b5
req
kschmid23 Mar 12, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,16 @@
*.mp4
**/__pycache__/
**/*.pyc
**/*.pyo

InternVL3-2B-Instruct/
Yume-5B-720P/
_vbench_caption.txt
_vbench_tmp/
outputs/
requirements-cog.txt
temp_caption_3th person.txt
temp_caption_bus.txt
temp_caption_default.txt
temp_caption_gc.txt
temp_caption_kitchen.txt
14 changes: 14 additions & 0 deletions download_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Download Yume model weights from HuggingFace."""

from huggingface_hub import snapshot_download

# Hugging Face repositories to mirror locally, given as "<owner>/<name>" ids.
REPOS = [
    "stdstu123/Yume-5B-720P",
    "OpenGVLab/InternVL3-2B-Instruct",
]

for repo_id in REPOS:
    # Store each repo in ./<name>, i.e. the part of the id after the last "/".
    # These directory names are the ones the sampling script loads by default
    # ("./Yume-5B-720P" and "./InternVL3-2B-Instruct"), so keep them in sync.
    local_dir = f"./{repo_id.split('/')[-1]}"
    print(f"Downloading {repo_id} -> {local_dir}")
    snapshot_download(repo_id=repo_id, local_dir=local_dir)
    print(f"Done: {local_dir}")
Binary file added fastvideo/__pycache__/__init__.cpython-312.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
9 changes: 8 additions & 1 deletion fastvideo/models/mochi_hf/modeling_mochi.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,14 @@
from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
scale_lora_layers, unscale_lora_layers)
from diffusers.utils.torch_utils import maybe_allow_in_graph
from liger_kernel.ops.swiglu import LigerSiLUMulFunction
try:
    from liger_kernel.ops.swiglu import LigerSiLUMulFunction
except ImportError:
    # liger_kernel is treated as optional. ImportError (the parent of
    # ModuleNotFoundError) is caught so the fallback is also used when the
    # package exists but cannot be loaded, not only when it is absent.
    import torch.nn.functional as _F

    class LigerSiLUMulFunction:
        """Plain PyTorch stand-in used when liger_kernel cannot be imported.

        Only the ``apply(gate, hidden_states)`` entry point is provided.
        """

        @staticmethod
        def apply(gate, hidden_states):
            """Return ``silu(gate) * hidden_states`` computed elementwise."""
            return _F.silu(gate) * hidden_states

from fastvideo.models.flash_attn_no_pad import flash_attn_no_pad
from fastvideo.models.mochi_hf.norm import (MochiLayerNormContinuous,
Expand Down
3 changes: 2 additions & 1 deletion fastvideo/sample/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,7 +912,8 @@ def main(args):
local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
dist.init_process_group("nccl", rank=rank, world_size=world_size)
backend = "gloo" if os.name == "nt" else "nccl"
dist.init_process_group(backend, rank=rank, world_size=world_size)

# Set independent cache directories for each rank
os.environ["TRITON_CACHE_DIR"] = f"/tmp/triton_cache_{rank}"
Expand Down
119 changes: 81 additions & 38 deletions fastvideo/sample/sample_5b.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import argparse
import math
import os
import re
import sys
import torchvision
import time
Expand Down Expand Up @@ -867,7 +868,7 @@ def sample_one(


# Generate diverse output videos from identical input conditions
max_area = 704 * 1280
max_area = args.height * args.width
# pixel_values_vid = torch.nn.functional.interpolate(pixel_values_vid, size=(544, 960), mode='bilinear', align_corners=False)

repeat_nums = 1
Expand All @@ -889,29 +890,34 @@ def sample_one(

frame = model_input.shape[1]

main_print(f"[SAMPLE] VAE encoding input ({model_input.shape}) ...")
model_input = torch.cat([wan_i2v.vae.encode([model_input.to(device)[:,:-32].to(device)])[0], \
wan_i2v.vae.encode([model_input.to(device)[:,-32:].to(device)])[0]],dim=1)
wan_i2v.vae.encode([model_input.to(device)[:,-32:].to(device)])[0]],dim=1)
main_print(f"[SAMPLE] VAE encode done -> latent shape {model_input.shape}")

latents = model_input

img = model_input[:,:-latent_frame_zero]


main_print(f"[SAMPLE] wan_i2v.generate (i2v, frame_num={frame}) ...")
with torch.no_grad():
arg_c, arg_null, noise, mask2, img = wan_i2v.generate(
caption[0],
frame_num=frame,
max_area=max_area,
latent_frame_zero=latent_frame_zero,
img=img)
main_print("[SAMPLE] wan_i2v.generate done")
else:
frame = 32
main_print(f"[SAMPLE] wan_i2v.generate (t2v, frame_num={frame}) ...")
with torch.no_grad():
arg_c, arg_null, noise = wan_i2v.generate(
caption[0],
frame_num=32,
max_area=max_area,
latent_frame_zero=latent_frame_zero,)
main_print("[SAMPLE] wan_i2v.generate done")



Expand Down Expand Up @@ -949,15 +955,15 @@ def sample_one(

import time
start_time = time.time()

main_print(f"[SAMPLE] Denoising step_sample={step_sample}/{sample_num-1} steps={sample_step} ...")

if not t2v or step_sample > 0:
latent = torch.cat([img[0][:, :-latent_frame_zero, :, :], latent[:, -latent_frame_zero:, :, :]], dim=1)
#(1. - mask2[0]) * img[0] + mask2[0] * latent
print(latent.shape, "nbxkasbcna090-")
with torch.no_grad():
with torch.autocast("cuda", dtype=torch.bfloat16):

for i in range(sample_step):
for i in tqdm(range(sample_step), desc="Sampling", unit="step"):
latent_model_input = [latent.squeeze(0)]

if not t2v or step_sample>0:
Expand All @@ -981,7 +987,6 @@ def sample_one(
# ])
# timestep = temp_ts.unsqueeze(0)

print(latent_model_input[0].shape,"0-2=ffje0r=----------a")
noise_pred_cond = transformer(latent_model_input, t=timestep, **arg_c)[0]

if i+1 == sample_step:
Expand Down Expand Up @@ -1010,7 +1015,6 @@ def sample_one(
# timestep = torch.stack(timestep)
# temp_ts = timestep.flatten()
# timestep = temp_ts#.unsqueeze(0)
print(latent_model_input[0].shape,"0-2=ffje0r=----------a")
noise_pred_cond = transformer(latent_model_input, t=timestep, flag=False, **arg_c)[0]

# # UniPC
Expand Down Expand Up @@ -1048,11 +1052,12 @@ def sample_one(
else:
model_input = latent

main_print(f"[SAMPLE] VAE decoding latents ...")
torch.cuda.empty_cache()
with torch.autocast("cuda", dtype=torch.bfloat16):
video_cat = scale(vae, model_input[:,-latent_frame_zero:,:,:])
video_cat = scale(vae, model_input[:,-latent_frame_zero:,:,:])
video = video_cat[:,-frame_zero:]
video_all.append(video)

if step_sample > 0:
#if video.shape[1] < frame_zero:
# video = torch.cat([video[:,0].unsqueeze(1).repeat(1,frame_zero-video.shape[1],1,1),video],dim=1)
Expand All @@ -1073,11 +1078,16 @@ def sample_one(
else:
videoid_str = str(videoid)

os.makedirs(video_output_dir, exist_ok=True)
caption_safe = re.sub(r'[\\/:*?"<>|→←↑↓·\s]+', '_', str(caption_ori))[:60]
filename = os.path.join(
video_output_dir,
videoid_str+"_"+str(caption_ori)+"_"+str(repeat_num)+"_"+str(rank)+"_"+str(step_sample)+".mp4",
)
export_to_video(video[0] , filename, fps=16)
video_output_dir,
f"{videoid_str}_{caption_safe}_{repeat_num}_{rank}_{step_sample}.mp4",
)
main_print(f"[SAMPLE] VAE decode done -> video shape {video[0].shape if hasattr(video[0], 'shape') else len(video[0])} frames")
main_print(f"[SAVE] Output path: {filename}")
export_to_video(video[0], filename, fps=args.fps)
main_print(f"[SAVE] Saved: {filename}")


if step_sample + 1 < sample_num:
Expand Down Expand Up @@ -1124,7 +1134,10 @@ def main(args):
local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
dist.init_process_group("nccl", rank=rank, world_size=world_size)
print(f"[rank {rank}] dist init (backend={'gloo' if sys.platform == 'win32' else 'nccl'}) ...")
backend = "gloo" if sys.platform == "win32" else "nccl"
dist.init_process_group(backend, rank=rank, world_size=world_size)
print(f"[rank {rank}] dist init done")

# Set independent cache directories for each rank
os.environ["TRITON_CACHE_DIR"] = f"/tmp/triton_cache_{rank}"
Expand All @@ -1146,20 +1159,20 @@ def main(args):
ckpt_dir = "./Yume-5B-720P"

# Referenced from https://github.com/Wan-Video/Wan2.2
main_print(f"[INIT] Loading wan23.Yume from {ckpt_dir} ...")
wan_i2v = wan23.Yume(
config=cfg,
checkpoint_dir=ckpt_dir,
device_id=device,
)
transformer = wan_i2v.model
)
main_print("[INIT] wan23.Yume loaded")
transformer = wan_i2v.model
transformer = transformer.eval().requires_grad_(False)

main_print(
f" Total Sample parameters = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e6} M"
)
main_print(
f"--> Initializing FSDP with sharding strategy: {args.fsdp_sharding_startegy}"
)
main_print(f"[INIT] Initializing FSDP with sharding strategy: {args.fsdp_sharding_startegy} ...")
fsdp_kwargs, no_split_modules = get_dit_fsdp_kwargs(
transformer,
args.fsdp_sharding_startegy,
Expand All @@ -1170,13 +1183,15 @@ def main(args):
)

if args.resume_from_checkpoint:
main_print(f"[INIT] Resuming from checkpoint: {args.resume_from_checkpoint} ...")
(
transformer,
init_steps,
) = resume_checkpoint(
transformer,
args.resume_from_checkpoint,
)
main_print(f"[INIT] Checkpoint resumed (init_steps={init_steps})")


from safetensors import safe_open
Expand Down Expand Up @@ -1238,12 +1253,14 @@ def main(args):

# transformer.load_state_dict(merged_weights, strict=False)

main_print("[INIT] Casting transformer to bfloat16 and wrapping with FSDP ...")
transformer = transformer.to(torch.bfloat16)
transformer = FSDP(
transformer,
**fsdp_kwargs,
use_orig_params=True,
)
main_print("[INIT] FSDP wrap done")



Expand All @@ -1261,38 +1278,52 @@ def main(args):
#init t5, clip and vae
vae = wan_i2v.vae

main_print("[INIT] dist.barrier ...")
dist.barrier()

main_print("[INIT] barrier passed")

wan_i2v.device = device
main_print("[INIT] Loading denoiser ...")
denoiser = load_denoiser()

print("jpg_dir", args.jpg_dir)
main_print("[INIT] Denoiser loaded")

main_print(f"[DATA] jpg_dir={args.jpg_dir} video_root_dir={args.video_root_dir} T2V={args.T2V}")
image_sample = False
dataset_ddp = None
dataset_length = None
if args.jpg_dir != None and not args.T2V:
dataset_ddp, dataset_length = create_scaled_videos(args.jpg_dir,
total_frames=33,
H1=704,
W1=1280)
main_print(f"[DATA] Building image dataset from {args.jpg_dir} ...")
dataset_ddp, dataset_length = create_scaled_videos(args.jpg_dir,
total_frames=33,
H1=args.height,
W1=args.width)
main_print(f"[DATA] Image dataset ready: {dataset_length} samples")
image_sample = True
elif not args.T2V:
main_print(f"[DATA] Building video dataset from {args.video_root_dir} ...")
dataset_ddp, dataset_length = mp4_data(args.video_root_dir)
main_print(f"[DATA] Video dataset ready: {dataset_length} samples")
image_sample = False

print(dataset_ddp,"dataset_ddpdataset_ddpdataset_ddp")
else:
main_print("[DATA] T2V mode — no dataset loaded")

step_times = deque(maxlen=100)
#image_sample = True
# If you want to load a model using multiple GPUs, please refer to the `Multiple GPUs` section.
path = '/mnt/petrelfs/maoxiaofeng/Yume_v2_release/InternVL3-2B-Instruct'
camption_model = AutoModel.from_pretrained(
path,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
use_flash_attn=True,
trust_remote_code=True).eval().to(device)
_local = os.path.abspath(args.internvl_path)
path = _local if os.path.isdir(_local) else "OpenGVLab/InternVL3-2B-Instruct"
main_print(f"[INIT] Loading InternVL caption model from {path} ...")
# FSDP may leave an active DeviceContext("meta") TorchFunctionMode on the stack.
# torch.set_default_device("cpu") only sets a C++ variable and is overridden by
# the higher-priority Python TorchFunctionMode. Pushing a DeviceContext("cpu")
# via the context manager sits on top of any lingering meta context and wins.
with torch.device("cpu"):
camption_model = AutoModel.from_pretrained(
path,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=False,
use_flash_attn=False,
trust_remote_code=True).eval().to(device)
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
main_print("[INIT] InternVL caption model loaded")

if args.prompt!=None:
prompt1 = args.prompt
Expand All @@ -1304,7 +1335,9 @@ def main(args):
else:
date_len = int(dataset_length)//world_size + 1

main_print(f"[LOOP] Starting inference loop: {date_len-1} step(s), world_size={world_size}")
for step in range(1, date_len):
main_print(f"[LOOP] Step {step}/{date_len-1} starting ...")
start_time = time.time()
torch.cuda.empty_cache()
torch.cuda.empty_cache()
Expand Down Expand Up @@ -1335,6 +1368,7 @@ def main(args):
step_time = time.time() - start_time
step_times.append(step_time)
avg_step_time = sum(step_times) / len(step_times)
main_print(f"[LOOP] Step {step}/{date_len-1} done in {step_time:.1f}s (avg {avg_step_time:.1f}s)")


torch.cuda.empty_cache()
Expand Down Expand Up @@ -1396,6 +1430,9 @@ def main(args):
default=None,
)
parser.add_argument("--num_frames", type=int, default=163)
parser.add_argument("--height", type=int, default=704)
parser.add_argument("--width", type=int, default=1280)
parser.add_argument("--fps", type=int, default=16)
parser.add_argument(
"--logging_dir",
type=str,
Expand Down Expand Up @@ -1571,6 +1608,12 @@ def main(args):
type=str,
default=None,
)
parser.add_argument(
"--internvl_path",
type=str,
default="./InternVL3-2B-Instruct",
help="Path to InternVL3-2B-Instruct model dir or HuggingFace repo ID.",
)
args = parser.parse_args()
main(args)

1 change: 1 addition & 0 deletions fastvideo/sample/sample_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,7 @@ def sample_one(
video_output_dir,
videoid_str+"_"+str(caption_ori)+"_"+str(step_sample)+"_"+str(repeat_num)+".mp4",
)
print(filename)
export_to_video(video[0] , filename, fps=16)

if step_sample + 1 < sample_num:
Expand Down
Binary file modified fastvideo/utils/__pycache__/checkpoint.cpython-312.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added fastvideo/utils/__pycache__/load.cpython-312.pyc
Binary file not shown.
Binary file modified fastvideo/utils/__pycache__/logging_.cpython-312.pyc
Binary file not shown.
Binary file not shown.
Loading