Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
*.pyc
*.log
*.json
*.out
*.om
*.mp4

kernel_meta/
results/
Binary file added examples/wan_animate/animate/image.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/wan_animate/animate/video.mp4
Binary file not shown.
Binary file added examples/wan_animate/replace/image.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/wan_animate/replace/video.mp4
Binary file not shown.
122 changes: 122 additions & 0 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@
"prompt":
"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
},
"animate-14B": {
"prompt": "视频中的人在做动作",
"video": "",
"pose": "",
"mask": "",
},
}


Expand Down Expand Up @@ -240,6 +246,7 @@ def _parse_args():
default="./output/quant_data",
help="Path for calibration data or weight export.")

parser = add_animate_args(parser)
parser = add_attentioncache_args(parser)
parser = add_rainfusion_args(parser)
args = parser.parse_args()
Expand All @@ -248,6 +255,31 @@ def _parse_args():

return args

def add_animate_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register Wan-Animate specific command-line arguments.

    Args:
        parser: The parser to extend.

    Returns:
        The same parser with an "Animate args" option group added.
    """
    group = parser.add_argument_group(title="Animate args")
    group.add_argument(
        "--src_root_path",
        type=str,
        default=None,
        help="Root directory of the preprocessed source data. Default None.")
    group.add_argument(
        "--refert_num",
        type=int,
        default=77,
        help="How many frames used for temporal guidance. Recommended to be 1 or 5."
    )
    # store_true flags already default to False; no explicit default needed.
    group.add_argument(
        "--replace_flag",
        action="store_true",
        help="Whether to use replace.")
    group.add_argument(
        "--use_relighting_lora",
        action="store_true",
        help="Whether to use relighting lora.")
    return parser

def add_attentioncache_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group(title="Attention Cache args")
Expand Down Expand Up @@ -599,6 +631,96 @@ def generate(args):
stream.synchronize()
end = time.time()
logging.info(f"Generating video used time {end - begin: .4f}s")
elif "animate" in args.task:
logging.info("Creating Wan-Animate pipeline.")
wan_animate = wan.WanAnimate(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_sp=(args.ulysses_size > 1),
t5_cpu=args.t5_cpu,
convert_model_dtype=args.convert_model_dtype,
use_relighting_lora=args.use_relighting_lora,
use_vae_parallel=args.vae_parallel,
quant_mode=args.quant_mode,
quant_data_dir=args.quant_data_dir,
)

transformer = wan_animate.noise_model

if args.use_rainfusion:
if args.dit_fsdp:
transformer._fsdp_wrapped_module.rainfusion_config = rainfusion_config
else:
transformer.rainfusion_config = rainfusion_config

if args.tp_size > 1:
logging.info("Initializing Tensor Parallel ...")
applicator = TensorParallelApplicator(args.tp_size, device_map="cpu")
applicator.apply_to_model(transformer)

if args.quant_mode == 2:
logging.info(f"quantize weights saved, will be return")
return

if args.use_attentioncache:
config = CacheConfig(
method="attention_cache",
blocks_count=len(transformer.blocks),
steps_count=args.sample_steps,
step_start=args.start_step,
step_interval=args.attentioncache_interval,
step_end=args.end_step
)
else:
config = CacheConfig(
method="attention_cache",
blocks_count=len(transformer.blocks),
steps_count=args.sample_steps
)

cache = CacheAgent(config)

if args.dit_fsdp:
for block in transformer._fsdp_wrapped_module.blocks:
block._fsdp_wrapped_module.cache = cache
block._fsdp_wrapped_module.args = args
else:
for block in transformer.blocks:
block.cache = cache
block.args = args

logging.info("Warm up 2 steps ...")
video = wan_animate.generate(
src_root_path=args.src_root_path,
replace_flag=args.replace_flag,
refert_num = args.refert_num,
clip_len=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=2,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)

logging.info(f"Generating video ...")
begin = time.time()
video = wan_animate.generate(
src_root_path=args.src_root_path,
replace_flag=args.replace_flag,
refert_num = args.refert_num,
clip_len=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
end = time.time()
logging.info(f"Generating video used time {end - begin: .4f}s")
else:
logging.info("Creating WanI2V pipeline.")
wan_i2v = wan.WanI2V(
Expand Down
11 changes: 11 additions & 0 deletions requirements_animate.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
decord
peft
onnxruntime
pandas
matplotlib
-e git+https://github.com/facebookresearch/sam2.git@0e78a118995e66bb27d78518c4bd9a3e95b4e266#egg=SAM-2
loguru
sentencepiece
numpy==1.26.4
transformers==4.56.0
moviepy
39 changes: 39 additions & 0 deletions scripts/animate/preprocess_data.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Preprocess input assets (driving video + reference image) for Wan-Animate.
# Produces pose/mask/process artifacts under the *_SAVE_PATH directories,
# which generate.py later consumes via --src_root_path.
set -e

MODEL_PATH=/data2/test/zt/scripts/2025_Nov_Proj/wan2-animate/weights/Wan2.2-Animate-14B
CKPT_PATH="${MODEL_PATH}/process_checkpoint"

# Assets for "animate" mode (drive the reference image with the video's motion).
ANIMATE_ASSET_BASE_PATH="../../examples/wan_animate/animate"
ANIMATE_VIDEO_PATH="${ANIMATE_ASSET_BASE_PATH}/video.mp4"
ANIMATE_REFER_PATH="${ANIMATE_ASSET_BASE_PATH}/image.jpeg"
ANIMATE_SAVE_PATH="${ANIMATE_ASSET_BASE_PATH}/process_results"

# Assets for "replace" mode (swap the person in the video with the reference).
REPLACE_ASSET_BASE_PATH="../../examples/wan_animate/replace"
REPLACE_VIDEO_PATH="${REPLACE_ASSET_BASE_PATH}/video.mp4"
REPLACE_REFER_PATH="${REPLACE_ASSET_BASE_PATH}/image.jpeg"
REPLACE_SAVE_PATH="${REPLACE_ASSET_BASE_PATH}/process_results"

# Quote expansions so paths containing spaces do not word-split.
mkdir -p "${ANIMATE_SAVE_PATH}"
mkdir -p "${REPLACE_SAVE_PATH}"

# Animate Preprocess
# python ../../wan/modules/animate/preprocess/preprocess_data.py \
#     --ckpt_path "${CKPT_PATH}" \
#     --video_path "${ANIMATE_VIDEO_PATH}" \
#     --refer_path "${ANIMATE_REFER_PATH}" \
#     --save_path "${ANIMATE_SAVE_PATH}" \
#     --resolution_area 1280 720 \
#     --retarget_flag \
#     --use_flux

# Replace Preprocess
python ../../wan/modules/animate/preprocess/preprocess_data.py \
    --ckpt_path "${CKPT_PATH}" \
    --video_path "${REPLACE_VIDEO_PATH}" \
    --refer_path "${REPLACE_REFER_PATH}" \
    --save_path "${REPLACE_SAVE_PATH}" \
    --resolution_area 1280 720 \
    --iterations 3 \
    --k 7 \
    --w_len 1 \
    --h_len 1 \
    --replace_flag
25 changes: 25 additions & 0 deletions scripts/animate/run_animate.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Launch Wan-Animate generation on 8 devices with FSDP + Ulysses sequence
# parallelism. Expects preprocess_data.sh to have populated SRC_PATH first.
set -e

MODEL_PATH=/data2/test/zt/scripts/2025_Nov_Proj/wan2-animate/weights/Wan2.2-Animate-14B

ANIMATE_ASSET_BASE_PATH="../../examples/wan_animate/animate"
ANIMATE_VIDEO_PATH="${ANIMATE_ASSET_BASE_PATH}/video.mp4"
ANIMATE_REFER_PATH="${ANIMATE_ASSET_BASE_PATH}/image.jpeg"
ANIMATE_SAVE_PATH="${ANIMATE_ASSET_BASE_PATH}/process_results"

REPLACE_ASSET_BASE_PATH="../../examples/wan_animate/replace"
REPLACE_VIDEO_PATH="${REPLACE_ASSET_BASE_PATH}/video.mp4"
REPLACE_REFER_PATH="${REPLACE_ASSET_BASE_PATH}/image.jpeg"
REPLACE_SAVE_PATH="${REPLACE_ASSET_BASE_PATH}/process_results"

# Select which preprocessed asset set to generate from.
SRC_PATH="${REPLACE_SAVE_PATH}"

# export ASCEND_LAUNCH_BLOCKING=1

# Quote expansions so paths containing spaces survive word splitting.
torchrun --nnodes 1 --nproc_per_node 8 ../../generate.py \
    --task animate-14B \
    --ckpt_dir "${MODEL_PATH}" \
    --src_root_path "${SRC_PATH}" \
    --refert_num 1 \
    --dit_fsdp \
    --t5_fsdp \
    --ulysses_size 8 \
    --vae_parallel
28 changes: 28 additions & 0 deletions scripts/animate/run_animate_quant.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Launch Wan-Animate generation with quantization enabled (quant_mode 3),
# reading/writing quantized weights under QUANT_DIR.
set -e

MODEL_PATH=/data2/test/zt/scripts/2025_Nov_Proj/wan2-animate/weights/Wan2.2-Animate-14B

ANIMATE_ASSET_BASE_PATH="../../examples/wan_animate/animate"
ANIMATE_VIDEO_PATH="${ANIMATE_ASSET_BASE_PATH}/video.mp4"
ANIMATE_REFER_PATH="${ANIMATE_ASSET_BASE_PATH}/image.jpeg"
ANIMATE_SAVE_PATH="${ANIMATE_ASSET_BASE_PATH}/process_results"

REPLACE_ASSET_BASE_PATH="../../examples/wan_animate/replace"
REPLACE_VIDEO_PATH="${REPLACE_ASSET_BASE_PATH}/video.mp4"
REPLACE_REFER_PATH="${REPLACE_ASSET_BASE_PATH}/image.jpeg"
REPLACE_SAVE_PATH="${REPLACE_ASSET_BASE_PATH}/process_results"

# Select which preprocessed asset set to generate from.
SRC_PATH="${REPLACE_SAVE_PATH}"

QUANT_MODE=3
# Fixed typo: was QUNAT_DIR (consistent but misspelled).
QUANT_DIR="${MODEL_PATH}/quant_weight"

# Quote expansions so paths containing spaces survive word splitting.
torchrun --nnodes 1 --nproc_per_node 8 ../../generate.py \
    --task animate-14B \
    --ckpt_dir "${MODEL_PATH}" \
    --src_root_path "${SRC_PATH}" \
    --refert_num 1 \
    --dit_fsdp \
    --t5_fsdp \
    --ulysses_size 8 \
    --vae_parallel \
    --quant_data_dir "${QUANT_DIR}" \
    --quant_mode "${QUANT_MODE}"
2 changes: 2 additions & 0 deletions wan/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from . import configs, distributed, modules
from .image2video import WanI2V
# from .speech2video import WanS2V
from .text2video import WanT2V
from .textimage2video import WanTI2V
from .animate import WanAnimate
Loading