Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions easy_generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import argparse
import subprocess
import sys
import os
import tempfile
import shutil

def main():
parser = argparse.ArgumentParser(description="Simplified interface for Wan2.1 text-to-video generation.")
parser.add_argument("--prompt", type=str, required=True, help="The text prompt for video generation.")
parser.add_argument("--model_size", type=str, choices=["1.3B", "14B"], default="1.3B", help="The model size to use (1.3B or 14B).")
parser.add_argument("--model_dir", type=str, required=True, help="The directory containing the downloaded models (e.g., the folder where Wan2.1-T2V-1.3B is located).")
parser.add_argument("--output_path", type=str, default="output.mp4", help="The path to save the generated video.")

args = parser.parse_args()

task = f"t2v-{args.model_size}"
if args.model_size == "1.3B":
ckpt_dir_name = "Wan2.1-T2V-1.3B"
size = "832*480"
else: # 14B
ckpt_dir_name = "Wan2.1-T2V-14B"
size = "1280*720"

ckpt_dir = os.path.join(args.model_dir, ckpt_dir_name)

if not os.path.isdir(ckpt_dir):
print(f"Error: Checkpoint directory not found at {ckpt_dir}", file=sys.stderr)
sys.exit(1)

with tempfile.TemporaryDirectory() as temp_save_dir:
command = [
sys.executable,
"generate.py",
"--task", task,
"--ckpt_dir", ckpt_dir,
"--prompt", args.prompt,
"--size", size,
"--save_dir", temp_save_dir,
]

# Add some default arguments from README to improve quality for 1.3B model
if args.model_size == "1.3B":
command.extend(["--sample_guide_scale", "6", "--sample_shift", "8"])

print("Running generate.py... This may take a few minutes.")
try:
subprocess.run(command, check=True)
except subprocess.CalledProcessError as e:
print(f"Error running generate.py: {e}", file=sys.stderr)
sys.exit(1)
except FileNotFoundError:
print(f"Error: generate.py not found in the current directory.", file=sys.stderr)
sys.exit(1)

# Find the generated video file
generated_files = [f for f in os.listdir(temp_save_dir) if f.endswith(".mp4")]
if not generated_files:
print("Error: No video file was generated.", file=sys.stderr)
sys.exit(1)

generated_file_path = os.path.join(temp_save_dir, generated_files[0])

# Move the file to the desired output path
shutil.move(generated_file_path, args.output_path)
print(f"Video saved to {args.output_path}")

if __name__ == "__main__":
main()