From 9d9e357708493beadd04a40e2cf4b636a7de4180 Mon Sep 17 00:00:00 2001 From: zhtshr <547553598@qq.com> Date: Mon, 2 Feb 2026 15:01:47 +0800 Subject: [PATCH 1/8] load utils --- lightx2v/disagg/__init__.py | 1 + lightx2v/disagg/protocol.py | 0 lightx2v/disagg/services/__init__.py | 1 + lightx2v/disagg/utils.py | 246 +++++++++++++++++++++++++++ 4 files changed, 248 insertions(+) create mode 100644 lightx2v/disagg/__init__.py create mode 100644 lightx2v/disagg/protocol.py create mode 100644 lightx2v/disagg/services/__init__.py create mode 100644 lightx2v/disagg/utils.py diff --git a/lightx2v/disagg/__init__.py b/lightx2v/disagg/__init__.py new file mode 100644 index 000000000..45a7b4af0 --- /dev/null +++ b/lightx2v/disagg/__init__.py @@ -0,0 +1 @@ +# Disaggregation package initialization diff --git a/lightx2v/disagg/protocol.py b/lightx2v/disagg/protocol.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightx2v/disagg/services/__init__.py b/lightx2v/disagg/services/__init__.py new file mode 100644 index 000000000..54fa46a47 --- /dev/null +++ b/lightx2v/disagg/services/__init__.py @@ -0,0 +1 @@ +# Services package initialization diff --git a/lightx2v/disagg/utils.py b/lightx2v/disagg/utils.py new file mode 100644 index 000000000..204a244a0 --- /dev/null +++ b/lightx2v/disagg/utils.py @@ -0,0 +1,246 @@ +import json +import logging +import os +import torch +import torch.distributed as dist +from typing import Dict, Any, Optional + +from lightx2v_platform.base.global_var import AI_DEVICE +from lightx2v.utils.envs import GET_DTYPE +from lightx2v.utils.utils import find_torch_model_path + +from lightx2v.models.input_encoders.hf.wan.t5.model import T5EncoderModel +from lightx2v.models.input_encoders.hf.wan.xlm_roberta.model import CLIPModel +from lightx2v.models.video_encoders.hf.wan.vae import WanVAE +from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny, Wan2_2_VAE_tiny +from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE +from lightx2v.models.networks.wan.model import WanModel +from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper +from lightx2v.utils.set_config import get_default_config, set_config as set_config_base + +logger = logging.getLogger(__name__) + +class ConfigObj: + """Helper class to convert dictionary to object with attributes""" + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + +def set_config( + model_path, + task, + model_cls, + config_path=None, + **kwargs + ): + """ + Load configuration for Wan model. 
+ """ + # Create arguments object similar to what set_config expects + args_dict = { + "task": task, + "model_path": model_path, + "model_cls": model_cls, + "config_json": config_path, + } + args_dict.update(kwargs) + + # Convert to object for set_config compatibility + args = ConfigObj(**args_dict) + + # Use existing set_config from utils + config = set_config_base(args) + + return config + +def build_wan_model_with_lora(wan_module, config, model_kwargs, lora_configs, model_type="high_noise_model"): + lora_dynamic_apply = config.get("lora_dynamic_apply", False) + + if lora_dynamic_apply: + if model_type in ["high_noise_model", "low_noise_model"]: + # For wan2.2 + lora_name_to_info = {item["name"]: item for item in lora_configs} + lora_path = lora_name_to_info[model_type]["path"] + lora_strength = lora_name_to_info[model_type]["strength"] + else: + # For wan2.1 + lora_path = lora_configs[0]["path"] + lora_strength = lora_configs[0]["strength"] + + model_kwargs["lora_path"] = lora_path + model_kwargs["lora_strength"] = lora_strength + model = wan_module(**model_kwargs) + else: + assert not config.get("dit_quantized", False), "Online LoRA only for quantized models; merging LoRA is unsupported." + assert not config.get("lazy_load", False), "Lazy load mode does not support LoRA merging." + model = wan_module(**model_kwargs) + lora_wrapper = WanLoraWrapper(model) + if model_type in ["high_noise_model", "low_noise_model"]: + lora_configs = [lora_config for lora_config in lora_configs if lora_config["name"] == model_type] + lora_wrapper.apply_lora(lora_configs, model_type=model_type) + return model + +def load_wan_text_encoder(config: Dict[str, Any]): + # offload config + t5_offload = config.get("t5_cpu_offload", config.get("cpu_offload")) + if t5_offload: + t5_device = torch.device("cpu") + else: + t5_device = torch.device(AI_DEVICE) + tokenizer_path = os.path.join(config["model_path"], "google/umt5-xxl") + # quant_config + t5_quantized = config.get("t5_quantized", False) + if t5_quantized: + t5_quant_scheme = config.get("t5_quant_scheme", None) + assert t5_quant_scheme is not None + tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0] + t5_model_name = f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth" + t5_quantized_ckpt = find_torch_model_path(config, "t5_quantized_ckpt", t5_model_name) + t5_original_ckpt = None + else: + t5_quant_scheme = None + t5_quantized_ckpt = None + t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth" + t5_original_ckpt = find_torch_model_path(config, "t5_original_ckpt", t5_model_name) + + text_encoder = T5EncoderModel( + text_len=config["text_len"], + dtype=torch.bfloat16, + device=t5_device, + checkpoint_path=t5_original_ckpt, + tokenizer_path=tokenizer_path, + shard_fn=None, + cpu_offload=t5_offload, + t5_quantized=t5_quantized, + t5_quantized_ckpt=t5_quantized_ckpt, + quant_scheme=t5_quant_scheme, + load_from_rank0=config.get("load_from_rank0", False), + lazy_load=config.get("t5_lazy_load", False), + ) + # Return single encoder to match original returning list + text_encoders = [text_encoder] + return text_encoders + +def load_wan_image_encoder(config: Dict[str, Any]): + image_encoder = None + if config["task"] in ["i2v", "flf2v", "animate", "s2v"] and config.get("use_image_encoder", True): + # offload config + clip_offload = config.get("clip_cpu_offload", config.get("cpu_offload", False)) + if clip_offload: + clip_device = torch.device("cpu") + else: + clip_device = torch.device(AI_DEVICE) + # quant_config + clip_quantized = config.get("clip_quantized", False) + if 
clip_quantized: + clip_quant_scheme = config.get("clip_quant_scheme", None) + assert clip_quant_scheme is not None + tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0] + clip_model_name = f"models_clip_open-clip-xlm-roberta-large-vit-huge-14-{tmp_clip_quant_scheme}.pth" + clip_quantized_ckpt = find_torch_model_path(config, "clip_quantized_ckpt", clip_model_name) + clip_original_ckpt = None + else: + clip_quantized_ckpt = None + clip_quant_scheme = None + clip_model_name = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" + clip_original_ckpt = find_torch_model_path(config, "clip_original_ckpt", clip_model_name) + + image_encoder = CLIPModel( + dtype=torch.float16, + device=clip_device, + checkpoint_path=clip_original_ckpt, + clip_quantized=clip_quantized, + clip_quantized_ckpt=clip_quantized_ckpt, + quant_scheme=clip_quant_scheme, + cpu_offload=clip_offload, + use_31_block=config.get("use_31_block", True), + load_from_rank0=config.get("load_from_rank0", False), + ) + + return image_encoder + +def get_vae_parallel(config: Dict[str, Any]): + if isinstance(config.get("parallel", False), bool): + return config.get("parallel", False) + if isinstance(config.get("parallel", False), dict): + return config.get("parallel", {}).get("vae_parallel", True) + return False + +def load_wan_vae_encoder(config: Dict[str, Any]): + vae_name = config.get("vae_name", "Wan2.1_VAE.pth") + if config.get("model_cls", "") == "wan2.2": + vae_cls = Wan2_2_VAE + else: + vae_cls = WanVAE + + # offload config + vae_offload = config.get("vae_cpu_offload", config.get("cpu_offload")) + if vae_offload: + vae_device = torch.device("cpu") + else: + vae_device = torch.device(AI_DEVICE) + + vae_config = { + "vae_path": find_torch_model_path(config, "vae_path", vae_name), + "device": vae_device, + "parallel": get_vae_parallel(config), + "use_tiling": config.get("use_tiling_vae", False), + "cpu_offload": vae_offload, + "dtype": GET_DTYPE(), + "load_from_rank0": config.get("load_from_rank0", False), + "use_lightvae": config.get("use_lightvae", False), + } + if config["task"] not in ["i2v", "flf2v", "animate", "vace", "s2v"]: + return None + else: + return vae_cls(**vae_config) + +def load_wan_vae_decoder(config: Dict[str, Any]): + + vae_name = config.get("vae_name", "Wan2.1_VAE.pth") + tiny_vae_name = "taew2_1.pth" + + if config.get("model_cls", "") == "wan2.2": + vae_cls = Wan2_2_VAE + tiny_vae_cls = Wan2_2_VAE_tiny + tiny_vae_name = "taew2_2.pth" + else: + vae_cls = WanVAE + tiny_vae_cls = WanVAE_tiny + tiny_vae_name = "taew2_1.pth" + + # offload config + vae_offload = config.get("vae_cpu_offload", config.get("cpu_offload")) + if vae_offload: + vae_device = torch.device("cpu") + else: + vae_device = torch.device(AI_DEVICE) + + vae_config = { + "vae_path": find_torch_model_path(config, "vae_path", vae_name), + "device": vae_device, + "parallel": get_vae_parallel(config), + "use_tiling": config.get("use_tiling_vae", False), + "cpu_offload": vae_offload, + "use_lightvae": config.get("use_lightvae", False), + "dtype": GET_DTYPE(), + "load_from_rank0": config.get("load_from_rank0", False), + } + if config.get("use_tae", False): + tae_path = find_torch_model_path(config, "tae_path", tiny_vae_name) + vae_decoder = tiny_vae_cls(vae_path=tae_path, device=AI_DEVICE, need_scaled=config.get("need_scaled", False)).to(AI_DEVICE) + else: + vae_decoder = vae_cls(**vae_config) + return vae_decoder + +def load_wan_transformer(config: Dict[str, Any]): + if config["cpu_offload"]: + init_device = torch.device("cpu") + else: + init_device 
= torch.device(AI_DEVICE) + wan_model_kwargs = {"model_path": config["model_path"], "config": config, "device": init_device} + lora_configs = config.get("lora_configs") + if not lora_configs: + model = WanModel(**wan_model_kwargs) + else: + model = build_wan_model_with_lora(WanModel, config, wan_model_kwargs, lora_configs, model_type="wan2.1") + return model From 4c73d7d7b722e14700b7d55c4667dcdd643ef1c9 Mon Sep 17 00:00:00 2001 From: zhtshr <547553598@qq.com> Date: Mon, 2 Feb 2026 18:46:08 +0800 Subject: [PATCH 2/8] fix load bugs and add a test --- lightx2v/disagg/utils.py | 67 ++++++++++++++++ lightx2v/disagg/wan_t2v.py | 152 +++++++++++++++++++++++++++++++++++++ 2 files changed, 219 insertions(+) create mode 100644 lightx2v/disagg/wan_t2v.py diff --git a/lightx2v/disagg/utils.py b/lightx2v/disagg/utils.py index 204a244a0..71dc35379 100644 --- a/lightx2v/disagg/utils.py +++ b/lightx2v/disagg/utils.py @@ -30,11 +30,31 @@ def set_config( task, model_cls, config_path=None, + attn_mode="flash_attn2", + rope_type="torch", + infer_steps=50, + target_video_length=81, + target_height=480, + target_width=832, + sample_guide_scale=5.0, + sample_shift=5.0, + fps=16, + aspect_ratio="16:9", + boundary=0.900, + boundary_step_index=2, + denoising_step_list=None, + audio_fps=24000, + double_precision_rope=True, + norm_modulate_backend="torch", + distilled_sigma_values=None, **kwargs ): """ Load configuration for Wan model. """ + if denoising_step_list is None: + denoising_step_list = [1000, 750, 500, 250] + # Create arguments object similar to what set_config expects args_dict = { "task": task, @@ -42,6 +62,53 @@ def set_config( "model_cls": model_cls, "config_json": config_path, } + + # Simulate logic from LightX2VPipeline.create_generator + # which calls set_infer_config / set_infer_config_json + # Here we directly populate args_dict with the required inference config + + if config_path is not None: + with open(config_path, "r") as f: + config_json_content = json.load(f) + args_dict.update(config_json_content) + else: + # Replicating set_infer_config logic + if model_cls == "ltx2": + args_dict["distilled_sigma_values"] = distilled_sigma_values + args_dict["infer_steps"] = len(distilled_sigma_values) - 1 if distilled_sigma_values is not None else infer_steps + else: + args_dict["infer_steps"] = infer_steps + + args_dict["target_width"] = target_width + args_dict["target_height"] = target_height + args_dict["target_video_length"] = target_video_length + args_dict["sample_guide_scale"] = sample_guide_scale + args_dict["sample_shift"] = sample_shift + + if sample_guide_scale == 1 or (model_cls == "z_image" and sample_guide_scale == 0): + args_dict["enable_cfg"] = False + else: + args_dict["enable_cfg"] = True + + args_dict["rope_type"] = rope_type + args_dict["fps"] = fps + args_dict["aspect_ratio"] = aspect_ratio + args_dict["boundary"] = boundary + args_dict["boundary_step_index"] = boundary_step_index + args_dict["denoising_step_list"] = denoising_step_list + args_dict["audio_fps"] = audio_fps + args_dict["double_precision_rope"] = double_precision_rope + + if model_cls.startswith("wan"): + # Set all attention types to the provided attn_mode + args_dict["self_attn_1_type"] = attn_mode + args_dict["cross_attn_1_type"] = attn_mode + args_dict["cross_attn_2_type"] = attn_mode + elif model_cls in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill", "qwen_image", "longcat_image", "ltx2", "z_image"]: + args_dict["attn_type"] = attn_mode + + args_dict["norm_modulate_backend"] = norm_modulate_backend + 
args_dict.update(kwargs) # Convert to object for set_config compatibility diff --git a/lightx2v/disagg/wan_t2v.py b/lightx2v/disagg/wan_t2v.py new file mode 100644 index 000000000..bff5eb4c7 --- /dev/null +++ b/lightx2v/disagg/wan_t2v.py @@ -0,0 +1,152 @@ +import os +import torch +import logging +from loguru import logger +from lightx2v.disagg.utils import ( + load_wan_text_encoder, + load_wan_vae_decoder, + load_wan_transformer, + set_config, +) +from lightx2v.models.schedulers.wan.scheduler import WanScheduler +from lightx2v.utils.utils import seed_all, save_to_video, wan_vae_to_comfy +from lightx2v.utils.input_info import init_empty_input_info +from lightx2v.utils.envs import GET_DTYPE + +# Setup basic logging +logging.basicConfig(level=logging.INFO) + +def main(): + # 1. Configuration + model_path = "/root/LightX2V/models/Wan-AI/Wan2.1-T2V-1.3B" + task = "t2v" + model_cls = "wan2.1" + save_result_path = "/root/LightX2V/save_results/test_disagg.mp4" + + # Generation parameters + seed = 42 + prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." + negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + + # Initialize configuration + # Note: We pass generation parameters as kwargs to override defaults/config file settings + config = set_config( + model_path=model_path, + task=task, + model_cls=model_cls, + # Configuration parameters from pipe.create_generator in original example + attn_mode= "sage_attn2", + infer_steps=50, + target_height=480, + target_width=832, + target_video_length=81, + sample_guide_scale=5.0, + sample_shift=5.0, + fps=16, + # Default parameters usually set in LightX2VPipeline or config setup + enable_cfg=True + ) + + logger.info(f"Config initialized for task: {task}") + seed_all(seed) + + # 2. Load Models + logger.info("Loading models...") + + # Text Encoder (T5) + text_encoders = load_wan_text_encoder(config) + text_encoder = text_encoders[0] + + # Transformer (WanModel) + model = load_wan_transformer(config) + + # VAE Decoder + vae_decoder = load_wan_vae_decoder(config) + + logger.info("Models loaded successfully.") + + # 3. Initialize Scheduler + # Only supporting basic NoCaching scheduler for this simple example as per default config + scheduler = WanScheduler(config) + model.set_scheduler(scheduler) + + # 4. 
Run Inference Pipeline + + # 4.1 Text Encoding + logger.info("Running text encoding...") + text_len = config.get("text_len", 512) + + # Context (Prompt) + context = text_encoder.infer([prompt]) + context = torch.stack([torch.cat([u, u.new_zeros(text_len - u.size(0), u.size(1))]) for u in context]) + + # Context Null (Negative Prompt) for CFG + if config.get("enable_cfg", False): + context_null = text_encoder.infer([negative_prompt]) + context_null = torch.stack([torch.cat([u, u.new_zeros(text_len - u.size(0), u.size(1))]) for u in context_null]) + else: + context_null = None + + text_encoder_output = { + "context": context, + "context_null": context_null, + } + + # 4.2 Prepare Inputs for Transformer + # Wan T2V input construction + # We need to construct the 'inputs' dictionary expected by model.infer + + # Calculate latent shape + # Logic from DefaultRunner.get_latent_shape_with_target_hw + latent_h = config["target_height"] // config["vae_stride"][1] + latent_w = config["target_width"] // config["vae_stride"][2] + latent_shape = [ + config.get("num_channels_latents", 16), + (config["target_video_length"] - 1) // config["vae_stride"][0] + 1, + latent_h, + latent_w, + ] + + inputs = { + "text_encoder_output": text_encoder_output, + "image_encoder_output": None # T2V usually doesn't need image encoder output unless specified + } + + # 4.3 Scheduler Preparation + logger.info("Preparing scheduler...") + scheduler.prepare(seed=seed, latent_shape=latent_shape, image_encoder_output=None) + + # 4.4 Denoising Loop + logger.info("Starting denoising loop...") + infer_steps = scheduler.infer_steps + + for step_index in range(infer_steps): + logger.info(f"Step {step_index + 1}/{infer_steps}") + + # Pre-step + scheduler.step_pre(step_index=step_index) + + # Model Inference + model.infer(inputs) + + # Post-step + scheduler.step_post() + + latents = scheduler.latents + + # 4.5 VAE Decoding + logger.info("Decoding latents...") + # Decode latents to video frames + # latents need to be cast to correct dtype usually + gen_video = vae_decoder.decode(latents.to(GET_DTYPE())) + + # 5. 
Post-processing and Saving + logger.info("Post-processing video...") + gen_video_final = wan_vae_to_comfy(gen_video) + + logger.info(f"Saving video to {save_result_path}...") + save_to_video(gen_video_final, save_result_path, fps=config.get("fps", 16), method="ffmpeg") + logger.info("Done!") + +if __name__ == "__main__": + main() From 0cacc4b74ba6244e3333ef44b0f4d9687ba4c318 Mon Sep 17 00:00:00 2001 From: zhtshr <547553598@qq.com> Date: Tue, 3 Feb 2026 15:57:39 +0800 Subject: [PATCH 3/8] add i2v test --- lightx2v/disagg/examples/wan_i2v.py | 235 ++++++++++++++++++++++ lightx2v/disagg/{ => examples}/wan_t2v.py | 0 lightx2v/disagg/utils.py | 72 ++++++- 3 files changed, 301 insertions(+), 6 deletions(-) create mode 100644 lightx2v/disagg/examples/wan_i2v.py rename lightx2v/disagg/{ => examples}/wan_t2v.py (100%) diff --git a/lightx2v/disagg/examples/wan_i2v.py b/lightx2v/disagg/examples/wan_i2v.py new file mode 100644 index 000000000..d543eb3d9 --- /dev/null +++ b/lightx2v/disagg/examples/wan_i2v.py @@ -0,0 +1,235 @@ +import logging + +import numpy as np +import torch +import torchvision.transforms.functional as TF +from loguru import logger +from PIL import Image + +from lightx2v.disagg.utils import ( + load_wan_image_encoder, + load_wan_text_encoder, + load_wan_vae_decoder, + load_wan_vae_encoder, + load_wan_transformer, + set_config, +) +from lightx2v.models.schedulers.wan.scheduler import WanScheduler +from lightx2v.utils.envs import GET_DTYPE +from lightx2v.utils.utils import save_to_video, seed_all, wan_vae_to_comfy +from lightx2v_platform.base.global_var import AI_DEVICE + +# Setup basic logging +logging.basicConfig(level=logging.INFO) + + +def read_image_input(image_path): + img_ori = Image.open(image_path).convert("RGB") + img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE) + return img, img_ori + + +def get_latent_shape_with_lat_hw(config, latent_h, latent_w): + return [ + config.get("num_channels_latents", 16), + (config["target_video_length"] - 1) // config["vae_stride"][0] + 1, + latent_h, + latent_w, + ] + + +def compute_latent_shape_from_image(config, image_tensor): + h, w = image_tensor.shape[2:] + aspect_ratio = h / w + max_area = config["target_height"] * config["target_width"] + + latent_h = round( + np.sqrt(max_area * aspect_ratio) + // config["vae_stride"][1] + // config["patch_size"][1] + * config["patch_size"][1] + ) + latent_w = round( + np.sqrt(max_area / aspect_ratio) + // config["vae_stride"][2] + // config["patch_size"][2] + * config["patch_size"][2] + ) + latent_shape = get_latent_shape_with_lat_hw(config, latent_h, latent_w) + return latent_shape, latent_h, latent_w + + +def get_vae_encoder_output(vae_encoder, config, first_frame, latent_h, latent_w): + h = latent_h * config["vae_stride"][1] + w = latent_w * config["vae_stride"][2] + + msk = torch.ones( + 1, + config["target_video_length"], + latent_h, + latent_w, + device=torch.device(AI_DEVICE), + ) + msk[:, 1:] = 0 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, latent_h, latent_w) + msk = msk.transpose(1, 2)[0] + + vae_input = torch.concat( + [ + torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1), + torch.zeros(3, config["target_video_length"] - 1, h, w), + ], + dim=1, + ).to(AI_DEVICE) + + vae_encoder_out = vae_encoder.encode(vae_input.unsqueeze(0).to(GET_DTYPE())) + vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(GET_DTYPE()) + return 
vae_encoder_out + + +def main(): + # 1. Configuration + model_path = "/root/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B" + task = "i2v" + model_cls = "wan2.2_moe" + + # Generation parameters + seed = 42 + prompt = ( + "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. " + "The fluffy-furred feline gazes directly at the camera with a relaxed expression. " + "Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, " + "and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if " + "savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details " + "and the refreshing atmosphere of the seaside." + ) + negative_prompt = ( + "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止," + "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部," + "画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景," + "三条腿,背景人很多,倒着走" + ) + image_path = "/root/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B/examples/i2v_input.JPG" + save_result_path = "/root/LightX2V/save_results/wan_i2v_A14B_disagg.mp4" + + # Initialize configuration + config = set_config( + model_path=model_path, + task=task, + model_cls=model_cls, + attn_mode="sage_attn2", + infer_steps=40, + target_height=480, + target_width=832, + target_video_length=81, + sample_guide_scale=[3.5, 3.5], + sample_shift=5.0, + fps=16, + enable_cfg=True, + use_image_encoder=False, + cpu_offload=True, + offload_granularity="block", + text_encoder_offload=True, + image_encoder_offload=False, + vae_offload=False, + ) + + logger.info(f"Config initialized for task: {task}") + seed_all(seed) + + # 2. Load Models + logger.info("Loading models...") + + text_encoders = load_wan_text_encoder(config) + text_encoder = text_encoders[0] + + image_encoder = load_wan_image_encoder(config) + + model = load_wan_transformer(config) + + vae_encoder = load_wan_vae_encoder(config) + vae_decoder = load_wan_vae_decoder(config) + + logger.info("Models loaded successfully.") + + # 3. Initialize Scheduler + scheduler = WanScheduler(config) + model.set_scheduler(scheduler) + + # 4. 
Run Inference Pipeline + + # 4.1 Text Encoding + logger.info("Running text encoding...") + text_len = config.get("text_len", 512) + + context = text_encoder.infer([prompt]) + context = torch.stack([torch.cat([u, u.new_zeros(text_len - u.size(0), u.size(1))]) for u in context]) + + if config.get("enable_cfg", False): + context_null = text_encoder.infer([negative_prompt]) + context_null = torch.stack([torch.cat([u, u.new_zeros(text_len - u.size(0), u.size(1))]) for u in context_null]) + else: + context_null = None + + text_encoder_output = { + "context": context, + "context_null": context_null, + } + + # 4.2 Image Encoding + VAE Encoding + logger.info("Running image encoding...") + img, _ = read_image_input(image_path) + + if image_encoder is not None: + clip_encoder_out = image_encoder.visual([img]).squeeze(0).to(GET_DTYPE()) + else: + clip_encoder_out = None + + if vae_encoder is None: + raise RuntimeError("VAE encoder is required for i2v task but was not loaded.") + + latent_shape, latent_h, latent_w = compute_latent_shape_from_image(config, img) + vae_encoder_out = get_vae_encoder_output(vae_encoder, config, img, latent_h, latent_w) + + image_encoder_output = { + "clip_encoder_out": clip_encoder_out, + "vae_encoder_out": vae_encoder_out, + } + + inputs = { + "text_encoder_output": text_encoder_output, + "image_encoder_output": image_encoder_output, + } + + # 4.3 Scheduler Preparation + logger.info("Preparing scheduler...") + scheduler.prepare(seed=seed, latent_shape=latent_shape, image_encoder_output=image_encoder_output) + + # 4.4 Denoising Loop + logger.info("Starting denoising loop...") + infer_steps = scheduler.infer_steps + + for step_index in range(infer_steps): + logger.info(f"Step {step_index + 1}/{infer_steps}") + scheduler.step_pre(step_index=step_index) + model.infer(inputs) + scheduler.step_post() + + latents = scheduler.latents + + # 4.5 VAE Decoding + logger.info("Decoding latents...") + gen_video = vae_decoder.decode(latents.to(GET_DTYPE())) + + # 5. 
Post-processing and Saving + logger.info("Post-processing video...") + gen_video_final = wan_vae_to_comfy(gen_video) + + logger.info(f"Saving video to {save_result_path}...") + save_to_video(gen_video_final, save_result_path, fps=config.get("fps", 16), method="ffmpeg") + logger.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/lightx2v/disagg/wan_t2v.py b/lightx2v/disagg/examples/wan_t2v.py similarity index 100% rename from lightx2v/disagg/wan_t2v.py rename to lightx2v/disagg/examples/wan_t2v.py diff --git a/lightx2v/disagg/utils.py b/lightx2v/disagg/utils.py index 71dc35379..a5abc0eb8 100644 --- a/lightx2v/disagg/utils.py +++ b/lightx2v/disagg/utils.py @@ -47,6 +47,11 @@ def set_config( double_precision_rope=True, norm_modulate_backend="torch", distilled_sigma_values=None, + cpu_offload=False, + offload_granularity="block", + text_encoder_offload=False, + image_encoder_offload=False, + vae_offload=False, **kwargs ): """ @@ -61,6 +66,11 @@ def set_config( "model_path": model_path, "model_cls": model_cls, "config_json": config_path, + "cpu_offload": cpu_offload, + "offload_granularity": offload_granularity, + "t5_cpu_offload": text_encoder_offload, # Map to internal keys + "clip_cpu_offload": image_encoder_offload, # Map to internal keys + "vae_cpu_offload": vae_offload, # Map to internal keys } # Simulate logic from LightX2VPipeline.create_generator @@ -304,10 +314,60 @@ def load_wan_transformer(config: Dict[str, Any]): init_device = torch.device("cpu") else: init_device = torch.device(AI_DEVICE) - wan_model_kwargs = {"model_path": config["model_path"], "config": config, "device": init_device} - lora_configs = config.get("lora_configs") - if not lora_configs: - model = WanModel(**wan_model_kwargs) + + if config.get("model_cls") == "wan2.1": + wan_model_kwargs = {"model_path": config["model_path"], "config": config, "device": init_device} + lora_configs = config.get("lora_configs") + if not lora_configs: + model = WanModel(**wan_model_kwargs) + else: + model = build_wan_model_with_lora(WanModel, config, wan_model_kwargs, lora_configs, model_type="wan2.1") + return model + elif config.get("model_cls") == "wan2.2_moe": + from lightx2v.models.runners.wan.wan_runner import MultiModelStruct + + high_noise_model_path = os.path.join(config["model_path"], "high_noise_model") + if config.get("dit_quantized", False) and config.get("high_noise_quantized_ckpt", None): + high_noise_model_path = config["high_noise_quantized_ckpt"] + elif config.get("high_noise_original_ckpt", None): + high_noise_model_path = config["high_noise_original_ckpt"] + + low_noise_model_path = os.path.join(config["model_path"], "low_noise_model") + if config.get("dit_quantized", False) and config.get("low_noise_quantized_ckpt", None): + low_noise_model_path = config["low_noise_quantized_ckpt"] + elif not config.get("dit_quantized", False) and config.get("low_noise_original_ckpt", None): + low_noise_model_path = config["low_noise_original_ckpt"] + + if not config.get("lazy_load", False) and not config.get("unload_modules", False): + lora_configs = config.get("lora_configs") + high_model_kwargs = { + "model_path": high_noise_model_path, + "config": config, + "device": init_device, + "model_type": "wan2.2_moe_high_noise", + } + low_model_kwargs = { + "model_path": low_noise_model_path, + "config": config, + "device": init_device, + "model_type": "wan2.2_moe_low_noise", + } + if not lora_configs: + high_noise_model = WanModel(**high_model_kwargs) + low_noise_model = WanModel(**low_model_kwargs) + else: + 
high_noise_model = build_wan_model_with_lora(WanModel, config, high_model_kwargs, lora_configs, model_type="high_noise_model") + low_noise_model = build_wan_model_with_lora(WanModel, config, low_model_kwargs, lora_configs, model_type="low_noise_model") + + return MultiModelStruct([high_noise_model, low_noise_model], config, config.get("boundary", 0.875)) + else: + model_struct = MultiModelStruct([None, None], config, config.get("boundary", 0.875)) + model_struct.low_noise_model_path = low_noise_model_path + model_struct.high_noise_model_path = high_noise_model_path + model_struct.init_device = init_device + return model_struct else: - model = build_wan_model_with_lora(WanModel, config, wan_model_kwargs, lora_configs, model_type="wan2.1") - return model + logger.error(f"Unsupported model_cls: {config.get('model_cls')}") + raise ValueError(f"Unsupported model_cls: {config.get('model_cls')}") + + From 8f359208c2e2d0818c7041cd25e9f53b395faad9 Mon Sep 17 00:00:00 2001 From: jasonzhang517 Date: Wed, 4 Feb 2026 17:19:40 +0800 Subject: [PATCH 4/8] services (local version) --- lightx2v/disagg/examples/wan_i2v.py | 13 +- lightx2v/disagg/examples/wan_i2v_service.py | 98 +++++++++++ lightx2v/disagg/examples/wan_t2v.py | 4 +- lightx2v/disagg/examples/wan_t2v_service.py | 79 +++++++++ lightx2v/disagg/services/base.py | 25 +++ lightx2v/disagg/services/encoder.py | 180 ++++++++++++++++++++ lightx2v/disagg/services/transformer.py | 93 ++++++++++ lightx2v/disagg/utils.py | 7 + 8 files changed, 488 insertions(+), 11 deletions(-) create mode 100644 lightx2v/disagg/examples/wan_i2v_service.py create mode 100644 lightx2v/disagg/examples/wan_t2v_service.py create mode 100644 lightx2v/disagg/services/base.py create mode 100644 lightx2v/disagg/services/encoder.py create mode 100644 lightx2v/disagg/services/transformer.py diff --git a/lightx2v/disagg/examples/wan_i2v.py b/lightx2v/disagg/examples/wan_i2v.py index d543eb3d9..d6e60edf8 100644 --- a/lightx2v/disagg/examples/wan_i2v.py +++ b/lightx2v/disagg/examples/wan_i2v.py @@ -13,6 +13,7 @@ load_wan_vae_encoder, load_wan_transformer, set_config, + read_image_input, ) from lightx2v.models.schedulers.wan.scheduler import WanScheduler from lightx2v.utils.envs import GET_DTYPE @@ -23,12 +24,6 @@ logging.basicConfig(level=logging.INFO) -def read_image_input(image_path): - img_ori = Image.open(image_path).convert("RGB") - img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE) - return img, img_ori - - def get_latent_shape_with_lat_hw(config, latent_h, latent_w): return [ config.get("num_channels_latents", 16), @@ -90,7 +85,7 @@ def get_vae_encoder_output(vae_encoder, config, first_frame, latent_h, latent_w) def main(): # 1. 
Configuration - model_path = "/root/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B" + model_path = "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B" task = "i2v" model_cls = "wan2.2_moe" @@ -110,8 +105,8 @@ def main(): "画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景," "三条腿,背景人很多,倒着走" ) - image_path = "/root/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B/examples/i2v_input.JPG" - save_result_path = "/root/LightX2V/save_results/wan_i2v_A14B_disagg.mp4" + image_path = "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B/examples/i2v_input.JPG" + save_result_path = "/root/zht/LightX2V/save_results/wan_i2v_A14B_disagg.mp4" # Initialize configuration config = set_config( diff --git a/lightx2v/disagg/examples/wan_i2v_service.py b/lightx2v/disagg/examples/wan_i2v_service.py new file mode 100644 index 000000000..0ca92e67e --- /dev/null +++ b/lightx2v/disagg/examples/wan_i2v_service.py @@ -0,0 +1,98 @@ +import logging +import os +import torch +from loguru import logger + +from lightx2v.disagg.utils import set_config +from lightx2v.disagg.services.encoder import EncoderService +from lightx2v.disagg.services.transformer import TransformerService +from lightx2v.utils.utils import seed_all + +# Setup basic logging +logging.basicConfig(level=logging.INFO) + +def main(): + # 1. Configuration + model_path = "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B" + task = "i2v" + model_cls = "wan2.2_moe" + + # Generation parameters + seed = 42 + prompt = ( + "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. " + "The fluffy-furred feline gazes directly at the camera with a relaxed expression. " + "Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, " + "and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if " + "savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details " + "and the refreshing atmosphere of the seaside." + ) + negative_prompt = ( + "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止," + "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部," + "画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景," + "三条腿,背景人很多,倒着走" + ) + image_path = "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B/examples/i2v_input.JPG" + save_result_path = "/root/zht/LightX2V/save_results/wan_i2v_A14B_disagg_service.mp4" + + # Initialize configuration + config = set_config( + model_path=model_path, + task=task, + model_cls=model_cls, + attn_mode="sage_attn2", + infer_steps=40, + target_height=480, + target_width=832, + target_video_length=81, + sample_guide_scale=[3.5, 3.5], + sample_shift=5.0, + fps=16, + enable_cfg=True, + use_image_encoder=False, + cpu_offload=True, + offload_granularity="block", + text_encoder_offload=True, + image_encoder_offload=False, + vae_offload=False, + ) + + config.image_path = image_path + config.prompt = prompt + config.negative_prompt = negative_prompt + config.save_path = save_result_path + + logger.info(f"Config initialized for task: {task}") + seed_all(seed) + + # 2. Add seed to config so services use it + config.seed = seed + + # 3. Initialize Services + logger.info("Initializing Encoder Service...") + encoder_service = EncoderService(config) + + logger.info("Initializing Transformer Service...") + transformer_service = TransformerService(config) + + # 4. 
Run Process + + # 4.1 Encoder Step + logger.info("Running Encoder Service...") + encoder_results = encoder_service.process() + + inputs = { + "text_encoder_output": encoder_results["text_encoder_output"], + "image_encoder_output": encoder_results["image_encoder_output"], + "latent_shape": encoder_results["latent_shape"], + } + + # 4.2 Transformer Step + logger.info("Running Transformer Service...") + result_path = transformer_service.process(inputs=inputs) + + logger.info(f"Video generation completed. Saved to: {result_path}") + +if __name__ == "__main__": + main() diff --git a/lightx2v/disagg/examples/wan_t2v.py b/lightx2v/disagg/examples/wan_t2v.py index bff5eb4c7..99586b2a2 100644 --- a/lightx2v/disagg/examples/wan_t2v.py +++ b/lightx2v/disagg/examples/wan_t2v.py @@ -18,10 +18,10 @@ def main(): # 1. Configuration - model_path = "/root/LightX2V/models/Wan-AI/Wan2.1-T2V-1.3B" + model_path = "/root/zht/LightX2V/models/Wan-AI/Wan2.1-T2V-1.3B" task = "t2v" model_cls = "wan2.1" - save_result_path = "/root/LightX2V/save_results/test_disagg.mp4" + save_result_path = "/root/zht/LightX2V/save_results/test_disagg.mp4" # Generation parameters seed = 42 diff --git a/lightx2v/disagg/examples/wan_t2v_service.py b/lightx2v/disagg/examples/wan_t2v_service.py new file mode 100644 index 000000000..c6c469488 --- /dev/null +++ b/lightx2v/disagg/examples/wan_t2v_service.py @@ -0,0 +1,79 @@ +import logging +from loguru import logger + +from lightx2v.disagg.services.encoder import EncoderService +from lightx2v.disagg.services.transformer import TransformerService +from lightx2v.disagg.utils import set_config +from lightx2v.utils.utils import seed_all + +# Setup basic logging +logging.basicConfig(level=logging.INFO) + + +def main(): + # 1. Configuration + model_path = "/root/zht/LightX2V/models/Wan-AI/Wan2.1-T2V-1.3B" + task = "t2v" + model_cls = "wan2.1" + save_result_path = "/root/zht/LightX2V/save_results/test_disagg.mp4" + + # Generation parameters + seed = 42 + prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." + negative_prompt = ( + "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰," + "最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部," + "畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + ) + + # Initialize configuration (same as wan_t2v.py) + config = set_config( + model_path=model_path, + task=task, + model_cls=model_cls, + attn_mode="sage_attn2", + infer_steps=50, + target_height=480, + target_width=832, + target_video_length=81, + sample_guide_scale=5.0, + sample_shift=5.0, + fps=16, + enable_cfg=True, + ) + + logger.info(f"Config initialized for task: {task}") + seed_all(seed) + + # Add seed into config so services can use it if needed + config.seed = seed + config.prompt = prompt + config.negative_prompt = negative_prompt + config.save_path = save_result_path + + # 2. Initialize Services + logger.info("Initializing Encoder Service...") + encoder_service = EncoderService(config) + + logger.info("Initializing Transformer Service...") + transformer_service = TransformerService(config) + + # 3. Encoding (via EncoderService) + logger.info("Running Encoder Service...") + encoder_results = encoder_service.process() + + inputs = { + "text_encoder_output": encoder_results["text_encoder_output"], + "image_encoder_output": encoder_results["image_encoder_output"], + "latent_shape": encoder_results["latent_shape"], + } + + # 4. 
Run Transformer Service
+    logger.info("Running Transformer Service...")
+    result_path = transformer_service.process(inputs=inputs)
+
+    logger.info(f"Video generation completed. Saved to: {result_path}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/lightx2v/disagg/services/base.py b/lightx2v/disagg/services/base.py
new file mode 100644
index 000000000..4e777fac5
--- /dev/null
+++ b/lightx2v/disagg/services/base.py
@@ -0,0 +1,25 @@
+import logging
+import torch
+from abc import ABC, abstractmethod
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class BaseService(ABC):
+    def __init__(self, config):
+        """
+        Base initialization for all services.
+        Args:
+            config: A dictionary or object containing configuration parameters.
+        """
+        self.config = config
+        self.logger = logger
+        self.logger.info(f"Initializing {self.__class__.__name__} with config: {config}")
+
+    @abstractmethod
+    def load_models(self):
+        """
+        Abstract method to load necessary models.
+        """
+        pass
diff --git a/lightx2v/disagg/services/encoder.py b/lightx2v/disagg/services/encoder.py
new file mode 100644
index 000000000..3b171bc8b
--- /dev/null
+++ b/lightx2v/disagg/services/encoder.py
@@ -0,0 +1,180 @@
+import torch
+import numpy as np
+from typing import Dict, Any, Optional
+
+from lightx2v.disagg.services.base import BaseService
+from lightx2v.utils.envs import GET_DTYPE
+from lightx2v.utils.utils import seed_all
+from lightx2v_platform.base.global_var import AI_DEVICE
+from lightx2v.disagg.utils import (
+    load_wan_text_encoder,
+    load_wan_image_encoder,
+    load_wan_vae_encoder,
+    read_image_input,
+)
+
+class EncoderService(BaseService):
+    def __init__(self, config):
+        super().__init__(config)
+        self.text_encoder = None
+        self.image_encoder = None
+        self.vae_encoder = None
+
+        # Load models based on config
+        self.load_models()
+
+        # Seed everything if seed is in config
+        if "seed" in self.config:
+            seed_all(self.config["seed"])
+
+    def load_models(self):
+        self.logger.info("Loading Encoder Models...")
+
+        # T5 Text Encoder
+        text_encoders = load_wan_text_encoder(self.config)
+        self.text_encoder = text_encoders[0] if text_encoders else None
+
+        # CLIP Image Encoder (Optional per usage in wan_i2v.py)
+        if self.config.get("use_image_encoder", False):
+            self.image_encoder = load_wan_image_encoder(self.config)
+
+        # VAE Encoder (required for the i2v task).
+        # load_wan_vae_encoder inspects the task and offload flags in the config
+        # and returns None when the current task does not need a VAE encoder,
+        # so it is safe to call it unconditionally here; process() raises a
+        # RuntimeError if the i2v path later needs the encoder but it was not
+        # loaded, mirroring the check in wan_i2v.py. 
+ self.vae_encoder = load_wan_vae_encoder(self.config) + + self.logger.info("Encoder Models loaded successfully.") + + def _get_latent_shape_with_lat_hw(self, latent_h, latent_w): + return [ + self.config.get("num_channels_latents", 16), + (self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1, + latent_h, + latent_w, + ] + + def _compute_latent_shape_from_image(self, image_tensor: torch.Tensor): + h, w = image_tensor.shape[2:] + aspect_ratio = h / w + max_area = self.config["target_height"] * self.config["target_width"] + + latent_h = round( + np.sqrt(max_area * aspect_ratio) + // self.config["vae_stride"][1] + // self.config["patch_size"][1] + * self.config["patch_size"][1] + ) + latent_w = round( + np.sqrt(max_area / aspect_ratio) + // self.config["vae_stride"][2] + // self.config["patch_size"][2] + * self.config["patch_size"][2] + ) + latent_shape = self._get_latent_shape_with_lat_hw(latent_h, latent_w) + return latent_shape, latent_h, latent_w + + def _get_vae_encoder_output(self, first_frame: torch.Tensor, latent_h: int, latent_w: int): + h = latent_h * self.config["vae_stride"][1] + w = latent_w * self.config["vae_stride"][2] + + msk = torch.ones( + 1, + self.config["target_video_length"], + latent_h, + latent_w, + device=torch.device(AI_DEVICE), + ) + msk[:, 1:] = 0 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, latent_h, latent_w) + msk = msk.transpose(1, 2)[0] + + vae_input = torch.concat( + [ + torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1), + torch.zeros(3, self.config["target_video_length"] - 1, h, w), + ], + dim=1, + ).to(AI_DEVICE) + + vae_encoder_out = self.vae_encoder.encode(vae_input.unsqueeze(0).to(GET_DTYPE())) + vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(GET_DTYPE()) + return vae_encoder_out + + def process(self) -> Dict[str, Any]: + """ + Generates encoder outputs from prompt and image input. + """ + self.logger.info("Starting processing in EncoderService...") + + prompt = self.config.get("prompt") + negative_prompt = self.config.get("negative_prompt") + if prompt is None: + raise ValueError("prompt is required in config.") + + # 1. Text Encoding + text_len = self.config.get("text_len", 512) + + context = self.text_encoder.infer([prompt]) + context = torch.stack([torch.cat([u, u.new_zeros(text_len - u.size(0), u.size(1))]) for u in context]) + + if self.config.get("enable_cfg", False): + if negative_prompt is None: + raise ValueError("negative_prompt is required in config when enable_cfg is True.") + context_null = self.text_encoder.infer([negative_prompt]) + context_null = torch.stack([torch.cat([u, u.new_zeros(text_len - u.size(0), u.size(1))]) for u in context_null]) + else: + context_null = None + + text_encoder_output = { + "context": context, + "context_null": context_null, + } + + task = self.config.get("task") + + if task == "t2v": + latent_h = self.config["target_height"] // self.config["vae_stride"][1] + latent_w = self.config["target_width"] // self.config["vae_stride"][2] + latent_shape = [ + self.config.get("num_channels_latents", 16), + (self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1, + latent_h, + latent_w, + ] + image_encoder_output = None + elif task == "i2v": + image_path = self.config.get("image_path") + if image_path is None: + raise ValueError("image_path is required for i2v task.") + + # 2. 
Image Encoding + VAE Encoding + img, _ = read_image_input(image_path) + + clip_encoder_out = None + if self.image_encoder is not None: + # Assuming image_encoder.visual handles list of images + clip_encoder_out = self.image_encoder.visual([img]).squeeze(0).to(GET_DTYPE()) + + if self.vae_encoder is None: + raise RuntimeError("VAE encoder is required but was not loaded.") + + latent_shape, latent_h, latent_w = self._compute_latent_shape_from_image(img) + vae_encoder_out = self._get_vae_encoder_output(img, latent_h, latent_w) + + image_encoder_output = { + "clip_encoder_out": clip_encoder_out, + "vae_encoder_out": vae_encoder_out, + } + else: + raise ValueError(f"Unsupported task: {task}") + + # Return both outputs and potentially latent_shape if needed downstream + return { + "text_encoder_output": text_encoder_output, + "image_encoder_output": image_encoder_output, + "latent_shape": latent_shape # Often needed by scheduler downstream + } diff --git a/lightx2v/disagg/services/transformer.py b/lightx2v/disagg/services/transformer.py new file mode 100644 index 000000000..e966e9adb --- /dev/null +++ b/lightx2v/disagg/services/transformer.py @@ -0,0 +1,93 @@ +import torch +import logging +from typing import Dict, Any, List, Optional + +from lightx2v.disagg.services.base import BaseService +from lightx2v.disagg.utils import ( + load_wan_transformer, + load_wan_vae_decoder, +) +from lightx2v.models.schedulers.wan.scheduler import WanScheduler +from lightx2v.utils.envs import GET_DTYPE +from lightx2v.utils.utils import seed_all, save_to_video, wan_vae_to_comfy + +class TransformerService(BaseService): + def __init__(self, config): + super().__init__(config) + self.transformer = None + self.vae_decoder = None + self.scheduler = None + + self.load_models() + + # Set global seed if present in config, though specific process calls might reuse it + if "seed" in self.config: + seed_all(self.config["seed"]) + + def load_models(self): + self.logger.info("Loading Transformer Models...") + + self.transformer = load_wan_transformer(self.config) + self.vae_decoder = load_wan_vae_decoder(self.config) + + # Initialize scheduler + self.scheduler = WanScheduler(self.config) + self.transformer.set_scheduler(self.scheduler) + + self.logger.info("Transformer Models loaded successfully.") + + def process(self, inputs: Dict[str, Any]): + """ + Executes the diffusion process and video decoding. + + Args: + inputs: Dictionary containing 'text_encoder_output', 'image_encoder_output', and 'latent_shape'. 
+ """ + self.logger.info("Starting processing in TransformerService...") + + seed = self.config.get("seed") + save_path = self.config.get("save_path") + if seed is None: + raise ValueError("seed is required in config.") + if save_path is None: + raise ValueError("save_path is required in config.") + + image_encoder_output = inputs.get("image_encoder_output") + latent_shape = inputs.get("latent_shape") + if latent_shape is None: + raise ValueError("latent_shape is required in inputs.") + + # Scheduler Preparation + self.logger.info(f"Preparing scheduler with seed {seed}...") + self.scheduler.prepare(seed=seed, latent_shape=latent_shape, image_encoder_output=image_encoder_output) + + # Denoising Loop + self.logger.info("Starting denoising loop...") + infer_steps = self.scheduler.infer_steps + + for step_index in range(infer_steps): + if step_index % 10 == 0: + self.logger.info(f"Step {step_index + 1}/{infer_steps}") + self.scheduler.step_pre(step_index=step_index) + self.transformer.infer(inputs) + self.scheduler.step_post() + + latents = self.scheduler.latents + + # VAE Decoding + self.logger.info("Decoding latents...") + if self.vae_decoder is None: + raise RuntimeError("VAE decoder is not loaded.") + + gen_video = self.vae_decoder.decode(latents.to(GET_DTYPE())) + + # Post-processing + self.logger.info("Post-processing video...") + gen_video_final = wan_vae_to_comfy(gen_video) + + # Saving + self.logger.info(f"Saving video to {save_path}...") + save_to_video(gen_video_final, save_path, fps=self.config.get("fps", 16), method="ffmpeg") + self.logger.info("Done!") + + return save_path diff --git a/lightx2v/disagg/utils.py b/lightx2v/disagg/utils.py index a5abc0eb8..2c05af939 100644 --- a/lightx2v/disagg/utils.py +++ b/lightx2v/disagg/utils.py @@ -3,6 +3,8 @@ import os import torch import torch.distributed as dist +import torchvision.transforms.functional as TF +from PIL import Image from typing import Dict, Any, Optional from lightx2v_platform.base.global_var import AI_DEVICE @@ -25,6 +27,11 @@ class ConfigObj: def __init__(self, **kwargs): self.__dict__.update(kwargs) +def read_image_input(image_path): + img_ori = Image.open(image_path).convert("RGB") + img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE) + return img, img_ori + def set_config( model_path, task, From ff5e7738486a1802b6a57e269bb26c14edd63120 Mon Sep 17 00:00:00 2001 From: zhtshr <547553598@qq.com> Date: Fri, 6 Feb 2026 13:57:12 +0800 Subject: [PATCH 5/8] add disaggregation --- lightx2v/disagg/conn.py | 318 ++++++++++++++++++++ lightx2v/disagg/examples/wan_i2v_service.py | 54 ++-- lightx2v/disagg/examples/wan_t2v_service.py | 54 ++-- lightx2v/disagg/mooncake.py | 108 +++++++ lightx2v/disagg/protocol.py | 40 +++ lightx2v/disagg/services/encoder.py | 123 +++++++- lightx2v/disagg/services/transformer.py | 174 ++++++++++- lightx2v/disagg/utils.py | 46 ++- 8 files changed, 856 insertions(+), 61 deletions(-) create mode 100644 lightx2v/disagg/conn.py create mode 100644 lightx2v/disagg/mooncake.py diff --git a/lightx2v/disagg/conn.py b/lightx2v/disagg/conn.py new file mode 100644 index 000000000..fb3a43358 --- /dev/null +++ b/lightx2v/disagg/conn.py @@ -0,0 +1,318 @@ +from __future__ import annotations + +import asyncio +import logging +import struct +import threading +from functools import cache +from typing import Dict, List, Optional, Tuple +from enum import Enum + +import numpy as np +import numpy.typing as npt +import zmq +from aiohttp import web + +from lightx2v.disagg.mooncake import 
MooncakeTransferEngine + +logger = logging.getLogger(__name__) + +class DisaggregationMode(Enum): + NULL = "null" + ENCODE = "encode" + TRANSFORMER = "transformer" + +def group_concurrent_contiguous( + src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64] +) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]: + src_groups = [] + dst_groups = [] + current_src = [src_indices[0]] + current_dst = [dst_indices[0]] + + for i in range(1, len(src_indices)): + src_contiguous = src_indices[i] == src_indices[i - 1] + 1 + dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1 + if src_contiguous and dst_contiguous: + current_src.append(src_indices[i]) + current_dst.append(dst_indices[i]) + else: + src_groups.append(current_src) + dst_groups.append(current_dst) + current_src = [src_indices[i]] + current_dst = [dst_indices[i]] + + src_groups.append(current_src) + dst_groups.append(current_dst) + + return src_groups, dst_groups + + +class DataArgs: + engine_rank: int + data_ptrs: list[int] + data_lens: list[int] + data_item_lens: list[int] + ib_device: str + + +class DataPoll: + Failed = 0 + Bootstrapping = 1 + WaitingForInput = 2 + Transferring = 3 + Success = 4 + + +RequestPoolType = Dict[int, List[int]] +WaitingPoolType = Dict[ + int, Tuple[str, list[int]] +] +DATASENDER_POLLING_PORT = 17788 +DATARECEIVER_POLLING_PORT = 27788 + + +class DataManager: + # TODO: make it general and support multiple transfer backend before merging + def __init__(self, args: DataArgs, disaggregation_mode: DisaggregationMode): + self.engine = MooncakeTransferEngine() + self.data_args = args + self.disaggregation_mode = disaggregation_mode + self.request_pool: RequestPoolType = {} + self.request_status: Dict[int, DataPoll] = {} + self.server_socket = zmq.Context().socket(zmq.PULL) + self.register_buffer_to_engine() + if self.disaggregation_mode == DisaggregationMode.ENCODE: + self.waiting_pool: WaitingPoolType = {} + self.transfer_event = threading.Event() + self.start_encode_thread() + elif self.disaggregation_mode == DisaggregationMode.TRANSFORMER: + self.start_transformer_thread() + else: + raise ValueError( + f"Unsupported DisaggregationMode: {self.disaggregation_mode}" + ) + + def register_buffer_to_engine(self): + for data_ptr, data_len in zip( + self.data_args.data_ptrs, self.data_args.data_lens + ): + self.engine.register(data_ptr, data_len) + + def update_data_args(self, data_args: DataArgs): + self.data_args = data_args + self.register_buffer_to_engine() + + @cache + def _connect(self, endpoint: str): + socket = zmq.Context().socket(zmq.PUSH) + socket.connect(endpoint) + return socket + + def send_data( + self, + mooncake_session_id: str, + encode_data_ptrs: List[int], + transformer_ptrs: list[int], + ): + tensor_num = int(len(self.data_args.data_ptrs)) + for tensor_id in range(tensor_num): + encode_addr = encode_data_ptrs[tensor_id] + item_len = self.data_args.data_item_lens[tensor_id] + transformer_addr = transformer_ptrs[tensor_id] + + # TODO: mooncake transfer engine can do async transfer. 
Do async later + status = self.engine.transfer_sync( + mooncake_session_id, + encode_addr, + transformer_addr, + item_len, + ) + if status != 0: + return status + return 0 + + def sync_status_to_transformer_endpoint(self, remote: str, room: int): + if ":" in remote: + remote = remote.split(":")[0] + self._connect( + "tcp://" + + remote + + ":" + + str(DATARECEIVER_POLLING_PORT + self.data_args.engine_rank) + ).send_multipart( + [ + str(room).encode("ascii"), + str(self.request_status[room]).encode("ascii"), + ] + ) + + def start_encode_thread(self): + sender_rank_port = DATASENDER_POLLING_PORT + self.data_args.engine_rank + self.server_socket.bind("tcp://*:" + str(sender_rank_port)) + + def encode_thread(): + while True: + ( + endpoint, + mooncake_session_id, + bootstrap_room, + transformer_ptrs, + ) = self.server_socket.recv_multipart() + if bootstrap_room.decode("ascii") == "None": + continue + endpoint = endpoint.decode("ascii") + mooncake_session_id = mooncake_session_id.decode("ascii") + bootstrap_room = int(bootstrap_room.decode("ascii")) + transformer_ptrs = list( + struct.unpack(f"{len(transformer_ptrs)//8}Q", transformer_ptrs) + ) + self.waiting_pool[bootstrap_room] = ( + endpoint, + mooncake_session_id, + transformer_ptrs, + ) + self.transfer_event.set() + + threading.Thread(target=encode_thread).start() + + def transfer_thread(): + while True: + self.transfer_event.wait() + self.transfer_event.clear() + bootstrap_room_ready = self.request_pool.keys() + bootstrap_room_request = self.waiting_pool.keys() + for room in list(bootstrap_room_request): + if room not in list(bootstrap_room_ready): + continue + status = DataPoll.Transferring + self.request_status[room] = status + ( + endpoint, + mooncake_session_id, + transformer_ptrs, + ) = self.waiting_pool.pop(room) + self.sync_status_to_transformer_endpoint(endpoint, room) + encode_data_ptrs = self.request_pool.pop(room) + ret = self.send_data( + mooncake_session_id, + encode_data_ptrs, + transformer_ptrs, + ) + if ret != 0: + status = DataPoll.Failed + self.sync_status_to_transformer_endpoint(endpoint, room) + continue + status = DataPoll.Success + self.request_status[room] = status + self.sync_status_to_transformer_endpoint(endpoint, room) + + threading.Thread(target=transfer_thread).start() + + def start_transformer_thread(self): + receiver_rank_port = DATARECEIVER_POLLING_PORT + self.data_args.engine_rank + self.server_socket.bind("tcp://*:" + str(receiver_rank_port)) + + def transformer_thread(): + while True: + (bootstrap_room, status) = self.server_socket.recv_multipart() + status = int(status.decode("ascii")) + bootstrap_room = int(bootstrap_room.decode("ascii")) + self.request_status[bootstrap_room] = status + + threading.Thread(target=transformer_thread).start() + + def enqueue_request( + self, + bootstrap_room: int, + data_ptrs: List[int], + ): + self.request_pool[bootstrap_room] = data_ptrs + self.request_status[bootstrap_room] = DataPoll.WaitingForInput + if self.disaggregation_mode == DisaggregationMode.ENCODE: + self.transfer_event.set() + + def check_status(self, bootstrap_room: int): + if ( + self.disaggregation_mode == DisaggregationMode.TRANSFORMER + and self.request_status[bootstrap_room] == DataPoll.Success + ): + if bootstrap_room in self.request_pool: + self.request_pool.pop(bootstrap_room) + + return self.request_status[bootstrap_room] + + def set_status(self, bootstrap_room: int, status: DataPoll): + self.request_status[bootstrap_room] = status + + def get_localhost(self): + return self.engine.get_localhost() 
+ + def get_session_id(self): + return self.engine.get_session_id() + + +class DataSender: + + def __init__(self, mgr: DataManager, bootstrap_addr: str, bootstrap_room: int): + self.data_mgr = mgr + self.bootstrap_room = bootstrap_room + self.data_mgr.set_status(bootstrap_room, DataPoll.WaitingForInput) + + def init(self, num_data_indices: int): + self.num_data_indices = num_data_indices + + def send(self, data_ptrs: List[int]): + self.data_mgr.enqueue_request(self.bootstrap_room, data_ptrs) + + def poll(self) -> DataPoll: + return self.data_mgr.check_status(self.bootstrap_room) + + def failure_exception(self): + raise Exception("Fake DataSender Exception") + + +class DataReceiver: + + def __init__( + self, mgr: DataManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None + ): + self.bootstrap_room = bootstrap_room + self.bootstrap_addr = bootstrap_addr + self.data_mgr = mgr + self.encode_server_url = ( + bootstrap_addr.split(":")[0] + + ":" + + str(DATASENDER_POLLING_PORT + self.data_mgr.data_args.engine_rank) + ) + self.transformer_ip = self.data_mgr.get_localhost() + self.session_id = self.data_mgr.get_session_id() + self.data_mgr.set_status(bootstrap_room, DataPoll.WaitingForInput) + + @cache + def _connect(self, endpoint: str): + socket = zmq.Context().socket(zmq.PUSH) + socket.connect(endpoint) + return socket + + def init(self, data_indices: npt.NDArray[np.int64]): + self.data_mgr.enqueue_request(self.bootstrap_room, data_indices) + packed_data_ptrs = b"".join( + struct.pack("Q", ptr) for ptr in self.data_mgr.data_args.data_ptrs + ) + self._connect("tcp://" + self.encode_server_url).send_multipart( + [ + self.transformer_ip.encode("ascii"), + self.session_id.encode("ascii"), + str(self.bootstrap_room).encode("ascii"), + packed_data_ptrs, + data_indices.tobytes(), + ] + ) + + def poll(self) -> DataPoll: + return self.data_mgr.check_status(self.bootstrap_room) + + def failure_exception(self): + raise Exception("Fake DataReceiver Exception") + diff --git a/lightx2v/disagg/examples/wan_i2v_service.py b/lightx2v/disagg/examples/wan_i2v_service.py index 0ca92e67e..6e0185406 100644 --- a/lightx2v/disagg/examples/wan_i2v_service.py +++ b/lightx2v/disagg/examples/wan_i2v_service.py @@ -56,6 +56,8 @@ def main(): text_encoder_offload=True, image_encoder_offload=False, vae_offload=False, + data_bootstrap_addr="127.0.0.1", + data_bootstrap_room=0, ) config.image_path = image_path @@ -69,30 +71,36 @@ def main(): # 2. Add seed to config so services use it config.seed = seed - # 3. Initialize Services - logger.info("Initializing Encoder Service...") - encoder_service = EncoderService(config) - - logger.info("Initializing Transformer Service...") - transformer_service = TransformerService(config) + # 3. Define service threads + import threading - # 4. Run Process - - # 4.1 Encoder Step - logger.info("Running Encoder Service...") - encoder_results = encoder_service.process() - - inputs = { - "text_encoder_output": encoder_results["text_encoder_output"], - "image_encoder_output": encoder_results["image_encoder_output"], - "latent_shape": encoder_results["latent_shape"], - } - - # 4.2 Transformer Step - logger.info("Running Transformer Service...") - result_path = transformer_service.process(inputs=inputs) - - logger.info(f"Video generation completed. 
Saved to: {result_path}") + def run_encoder(): + logger.info("Initializing Encoder Service...") + encoder_service = EncoderService(config) + logger.info("Running Encoder Service...") + encoder_service.process() + encoder_service.release_memory() + + def run_transformer(): + logger.info("Initializing Transformer Service...") + transformer_service = TransformerService(config) + logger.info("Running Transformer Service...") + result_path = transformer_service.process() + logger.info(f"Video generation completed. Saved to: {result_path}") + transformer_service.release_memory() + + # 4. Start threads + encoder_thread = threading.Thread(target=run_encoder) + transformer_thread = threading.Thread(target=run_transformer) + + logger.info("Starting services in separate threads...") + encoder_thread.start() + transformer_thread.start() + + # 5. Wait for completion + encoder_thread.join() + transformer_thread.join() + logger.info("All services finished.") if __name__ == "__main__": main() diff --git a/lightx2v/disagg/examples/wan_t2v_service.py b/lightx2v/disagg/examples/wan_t2v_service.py index c6c469488..09dc11501 100644 --- a/lightx2v/disagg/examples/wan_t2v_service.py +++ b/lightx2v/disagg/examples/wan_t2v_service.py @@ -40,6 +40,8 @@ def main(): sample_shift=5.0, fps=16, enable_cfg=True, + data_bootstrap_addr="127.0.0.1", + data_bootstrap_room=0, ) logger.info(f"Config initialized for task: {task}") @@ -51,28 +53,36 @@ def main(): config.negative_prompt = negative_prompt config.save_path = save_result_path - # 2. Initialize Services - logger.info("Initializing Encoder Service...") - encoder_service = EncoderService(config) - - logger.info("Initializing Transformer Service...") - transformer_service = TransformerService(config) - - # 3. Encoding (via EncoderService) - logger.info("Running Encoder Service...") - encoder_results = encoder_service.process() - - inputs = { - "text_encoder_output": encoder_results["text_encoder_output"], - "image_encoder_output": encoder_results["image_encoder_output"], - "latent_shape": encoder_results["latent_shape"], - } - - # 4. Run Transformer Service - logger.info("Running Transformer Service...") - result_path = transformer_service.process(inputs=inputs) - - logger.info(f"Video generation completed. Saved to: {result_path}") + # 2. Define service threads + import threading + + def run_encoder(): + logger.info("Initializing Encoder Service...") + encoder_service = EncoderService(config) + logger.info("Running Encoder Service...") + encoder_service.process() + encoder_service.release_memory() + + def run_transformer(): + logger.info("Initializing Transformer Service...") + transformer_service = TransformerService(config) + logger.info("Running Transformer Service...") + result_path = transformer_service.process() + logger.info(f"Video generation completed. Saved to: {result_path}") + transformer_service.release_memory() + + # 3. Start threads + encoder_thread = threading.Thread(target=run_encoder) + transformer_thread = threading.Thread(target=run_transformer) + + logger.info("Starting services in separate threads...") + encoder_thread.start() + transformer_thread.start() + + # 4. 
Wait for completion + encoder_thread.join() + transformer_thread.join() + logger.info("All services finished.") if __name__ == "__main__": diff --git a/lightx2v/disagg/mooncake.py b/lightx2v/disagg/mooncake.py new file mode 100644 index 000000000..adee5804b --- /dev/null +++ b/lightx2v/disagg/mooncake.py @@ -0,0 +1,108 @@ +import json +import logging +import os +import uuid +import socket +import struct +import pickle +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch + +logger = logging.getLogger(__name__) + +@dataclass +class MooncakeTransferEngineConfig: + local_hostname: str + metadata_server: str + protocol: str + device_name: str + + @staticmethod + def from_file(file_path: str) -> "MooncakeTransferEngineConfig": + with open(file_path) as fin: + config = json.load(fin) + return MooncakeTransferEngineConfig( + local_hostname=config.get("local_hostname", None), + metadata_server=config.get("metadata_server"), + protocol=config.get("protocol", "rdma"), + device_name=config.get("device_name", ""), + ) + + @staticmethod + def load_from_env() -> "MooncakeTransferEngineConfig": + config_file_path = os.getenv("MOONCAKE_CONFIG_PATH") + if config_file_path is None: + raise ValueError("The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") + return MooncakeTransferEngineConfig.from_file(config_file_path) + + +class MooncakeTransferEngine: + def __init__(self): + self.engine = None + try: + from mooncake.engine import TransferEngine + self.engine = TransferEngine() + except ImportError as e: + logger.warning( + "Please install mooncake by following the instructions at " + "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " + "to run with MooncakeTransferEngine." + ) + # We allow continuing without engine for non-transfer operations or testing structure + + try: + self.config = MooncakeTransferEngineConfig.load_from_env() + logger.info("Mooncake Configuration loaded successfully.") + except Exception as e: + logger.error(f"Failed to load Mooncake config: {e}") + raise + + session_suffix = "_" + str(uuid.uuid4()) + self.session_id = self.config.local_hostname + session_suffix + self.initialize( + self.session_id, + self.config.metadata_server, + self.config.protocol, + self.config.device_name, + ) + + def register(self, ptr, length): + if self.engine: + self.engine.register_memory(ptr, length) + + def deregister(self, ptr): + if self.engine: + self.engine.unregister_memory(ptr) + + def initialize( + self, + local_hostname: str, + metadata_server: str, + protocol: str, + device_name: str, + ) -> None: + """Initialize the mooncake instance.""" + if self.engine: + self.engine.initialize(local_hostname, metadata_server, protocol, device_name) + + def transfer_sync( + self, session_id: str, buffer: int, peer_buffer_address: int, length: int + ) -> int: + """Synchronously transfer data to the specified address.""" + if self.engine: + ret = self.engine.transfer_sync_write( + session_id, buffer, peer_buffer_address, length + ) + if ret < 0: + logger.error("Transfer Return Error") + raise Exception("Transfer Return Error") + return ret + return -1 + + def get_localhost(self): + return self.config.local_hostname + + def get_session_id(self): + return self.session_id diff --git a/lightx2v/disagg/protocol.py b/lightx2v/disagg/protocol.py index e69de29bb..1f8ff21ba 100644 --- a/lightx2v/disagg/protocol.py +++ b/lightx2v/disagg/protocol.py @@ -0,0 +1,40 @@ +import zmq +import torch +import pickle +import logging +from dataclasses 
import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union +from lightx2v.disagg.mooncake import MooncakeTransferEngine + +logger = logging.getLogger(__name__) + +@dataclass +class TensorMetadata: + id: int + shape: Tuple[int, ...] + dtype: torch.dtype + nbytes: int + +@dataclass +class AllocationRequest: + """ + Request sent from Sender (Encoder) to Receiver (Transformer). + - bootstrap_room: Unique ID for the transfer slot/session. + - config: Inference config used to estimate upper-bound buffer sizes. + """ + bootstrap_room: str + config: Dict[str, Any] + +@dataclass +class RemoteBuffer: + addr: int + session_id: str + nbytes: int + +@dataclass +class MemoryHandle: + """ + Handle sent from Receiver (Transformer) to Sender (Encoder). + - buffers: List of remote buffer details corresponding to the tensor_specs in the request. + """ + buffers: List[RemoteBuffer] diff --git a/lightx2v/disagg/services/encoder.py b/lightx2v/disagg/services/encoder.py index 3b171bc8b..323e70128 100644 --- a/lightx2v/disagg/services/encoder.py +++ b/lightx2v/disagg/services/encoder.py @@ -1,8 +1,10 @@ +import math import torch import numpy as np -from typing import Dict, Any, Optional +from typing import Dict, Any, Optional, List from lightx2v.disagg.services.base import BaseService +from lightx2v.disagg.conn import DataArgs, DataManager, DataSender, DisaggregationMode, DataPoll from lightx2v.utils.envs import GET_DTYPE from lightx2v.utils.utils import seed_all from lightx2v_platform.base.global_var import AI_DEVICE @@ -19,6 +21,10 @@ def __init__(self, config): self.text_encoder = None self.image_encoder = None self.vae_encoder = None + self.engine_rank = 0 + self.data_mgr = None + self.data_sender = None + self._rdma_buffers: List[torch.Tensor] = [] # Load models based on config self.load_models() @@ -27,6 +33,23 @@ def __init__(self, config): if "seed" in self.config: seed_all(self.config["seed"]) + data_bootstrap_addr = self.config.get("data_bootstrap_addr", "127.0.0.1") + data_bootstrap_room = self.config.get("data_bootstrap_room", 0) + + if data_bootstrap_addr is not None and data_bootstrap_room is not None: + data_ptrs, data_lens, data_item_lens = self.alloc_bufs() + data_args = DataArgs( + engine_rank=self.engine_rank, + data_ptrs=data_ptrs, + data_lens=data_lens, + data_item_lens=data_item_lens, + ib_device=self.config.get("ib_device", ""), + ) + self.data_mgr = DataManager(data_args, DisaggregationMode.ENCODE) + self.data_sender = DataSender( + self.data_mgr, data_bootstrap_addr, int(data_bootstrap_room) + ) + def load_models(self): self.logger.info("Loading Encoder Models...") @@ -104,6 +127,54 @@ def _get_vae_encoder_output(self, first_frame: torch.Tensor, latent_h: int, late vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(GET_DTYPE()) return vae_encoder_out + def alloc_bufs(self): + text_len = int(self.config.get("text_len", 512)) + enable_cfg = bool(self.config.get("enable_cfg", False)) + use_image_encoder = bool(self.config.get("use_image_encoder", True)) + task = self.config.get("task", "i2v") + + text_dim = int(self.config.get("text_encoder_dim", 4096)) + clip_dim = int(self.config.get("clip_embed_dim", 1024)) + z_dim = int(self.config.get("vae_z_dim", 16)) + + vae_stride = self.config.get("vae_stride", (4, 8, 8)) + stride_t = int(vae_stride[0]) + stride_h = int(vae_stride[1]) + stride_w = int(vae_stride[2]) + + target_video_length = int(self.config.get("target_video_length", 81)) + target_height = int(self.config.get("target_height", 480)) + target_width = 
int(self.config.get("target_width", 832)) + + t_prime = 1 + (target_video_length - 1) // stride_t + h_prime = int(math.ceil(target_height / stride_h)) + w_prime = int(math.ceil(target_width / stride_w)) + + self._rdma_buffers = [] + data_ptrs: List[int] = [] + data_lens: List[int] = [] + data_item_lens: List[int] = [] + + def _alloc_buffer(shape, dtype): + buf = torch.empty(shape, dtype=dtype, device=torch.device(f"cuda:{self.engine_rank}")) + self._rdma_buffers.append(buf) + nbytes = buf.numel() * buf.element_size() + data_ptrs.append(buf.data_ptr()) + data_lens.append(nbytes) + data_item_lens.append(nbytes) + + _alloc_buffer((1, text_len, text_dim), GET_DTYPE()) + if enable_cfg: + _alloc_buffer((1, text_len, text_dim), GET_DTYPE()) + + if task == "i2v": + if use_image_encoder: + _alloc_buffer((clip_dim,), GET_DTYPE()) + _alloc_buffer((z_dim + 4, t_prime, h_prime, w_prime), GET_DTYPE()) + + _alloc_buffer((4,), torch.int64) + return data_ptrs, data_lens, data_item_lens + def process(self) -> Dict[str, Any]: """ Generates encoder outputs from prompt and image input. @@ -171,10 +242,46 @@ def process(self) -> Dict[str, Any]: } else: raise ValueError(f"Unsupported task: {task}") - - # Return both outputs and potentially latent_shape if needed downstream - return { - "text_encoder_output": text_encoder_output, - "image_encoder_output": image_encoder_output, - "latent_shape": latent_shape # Often needed by scheduler downstream - } + + if self.data_mgr is not None and self.data_sender is not None: + buffer_index = 0 + self._rdma_buffers[buffer_index].copy_(context) + buffer_index += 1 + if self.config.get("enable_cfg", False): + self._rdma_buffers[buffer_index].copy_(context_null) + buffer_index += 1 + + if task == "i2v": + if self.config.get("use_image_encoder", True): + if image_encoder_output.get("clip_encoder_out") is not None: + self._rdma_buffers[buffer_index].copy_(image_encoder_output["clip_encoder_out"]) + else: + self._rdma_buffers[buffer_index].zero_() + buffer_index += 1 + + vae_buf = self._rdma_buffers[buffer_index] + vae_buf.zero_() + vae_flat = vae_buf.view(-1) + src_flat = image_encoder_output["vae_encoder_out"].reshape(-1) + vae_flat[: src_flat.numel()].copy_(src_flat) + buffer_index += 1 + + latent_tensor = torch.tensor(latent_shape, device=AI_DEVICE, dtype=torch.int64) + self._rdma_buffers[buffer_index].copy_(latent_tensor) + + buffer_ptrs = [buf.data_ptr() for buf in self._rdma_buffers] + self.data_sender.send(buffer_ptrs) + + import time + while True: + status = self.data_sender.poll() + if status == DataPoll.Success: + break + time.sleep(0.01) + + def release_memory(self): + """ + Releases the RDMA buffers and clears GPU cache. 
+ """ + self._rdma_buffers = [] + torch.cuda.empty_cache() diff --git a/lightx2v/disagg/services/transformer.py b/lightx2v/disagg/services/transformer.py index e966e9adb..517aa369a 100644 --- a/lightx2v/disagg/services/transformer.py +++ b/lightx2v/disagg/services/transformer.py @@ -1,12 +1,18 @@ import torch import logging +import numpy as np from typing import Dict, Any, List, Optional from lightx2v.disagg.services.base import BaseService +from lightx2v.disagg.conn import DataArgs, DataManager, DataReceiver, DisaggregationMode, DataPoll from lightx2v.disagg.utils import ( + estimate_encoder_buffer_sizes, load_wan_transformer, load_wan_vae_decoder, ) +from lightx2v.disagg.protocol import AllocationRequest, MemoryHandle, RemoteBuffer +from lightx2v.disagg.mooncake import MooncakeTransferEngine +from lightx2v_platform.base.global_var import AI_DEVICE from lightx2v.models.schedulers.wan.scheduler import WanScheduler from lightx2v.utils.envs import GET_DTYPE from lightx2v.utils.utils import seed_all, save_to_video, wan_vae_to_comfy @@ -17,6 +23,11 @@ def __init__(self, config): self.transformer = None self.vae_decoder = None self.scheduler = None + self.transfer_engine = None + self._rdma_buffers: List[torch.Tensor] = [] + self.engine_rank = 1 + self.data_mgr = None + self.data_receiver = None self.load_models() @@ -24,6 +35,33 @@ def __init__(self, config): if "seed" in self.config: seed_all(self.config["seed"]) + data_bootstrap_addr = self.config.get("data_bootstrap_addr", "127.0.0.1") + data_bootstrap_room = self.config.get("data_bootstrap_room", 0) + + if data_bootstrap_addr is None or data_bootstrap_room is None: + return + + request = AllocationRequest( + bootstrap_room=str(data_bootstrap_room), + config=self.config, + ) + handle = self.alloc_memory(request) + data_ptrs = [buf.addr for buf in handle.buffers] + data_lens = [buf.nbytes for buf in handle.buffers] + + data_args = DataArgs( + engine_rank=self.engine_rank, + data_ptrs=data_ptrs, + data_lens=data_lens, + data_item_lens=data_lens, + ib_device=None, + ) + self.data_mgr = DataManager(data_args, DisaggregationMode.TRANSFORMER) + self.data_receiver = DataReceiver( + self.data_mgr, data_bootstrap_addr, int(data_bootstrap_room) + ) + self.data_receiver.init() + def load_models(self): self.logger.info("Loading Transformer Models...") @@ -36,14 +74,128 @@ def load_models(self): self.logger.info("Transformer Models loaded successfully.") - def process(self, inputs: Dict[str, Any]): + def alloc_memory(self, request: AllocationRequest) -> MemoryHandle: """ - Executes the diffusion process and video decoding. - + Estimate upper-bound memory for encoder results and allocate GPU buffers. + Args: - inputs: Dictionary containing 'text_encoder_output', 'image_encoder_output', and 'latent_shape'. + request: AllocationRequest containing config and tensor specs. + + Returns: + MemoryHandle with RDMA-registered buffer addresses. 
+ """ + config = request.config + estimated_sizes = estimate_encoder_buffer_sizes(config) + buffer_sizes = estimated_sizes + + if self.transfer_engine is None: + self.transfer_engine = MooncakeTransferEngine() + + self._rdma_buffers = [] + buffers: List[RemoteBuffer] = [] + for nbytes in buffer_sizes: + if nbytes <= 0: + continue + buf = torch.empty((nbytes,), dtype=torch.uint8, device=torch.device(f"cuda:{self.engine_rank}")) + ptr = buf.data_ptr() + self.transfer_engine.register(ptr, nbytes) + self._rdma_buffers.append(buf) + buffers.append( + RemoteBuffer(addr=ptr, session_id=self.transfer_engine.get_session_id(), nbytes=nbytes) + ) + + return MemoryHandle(buffers=buffers) + + def process(self): + """ + Executes the diffusion process and video decoding. """ self.logger.info("Starting processing in TransformerService...") + + # Poll for data from EncoderService + import time + if self.data_receiver is not None: + while True: + status = self.data_receiver.poll() + if status == DataPoll.Success: + break + time.sleep(0.01) + else: + self.logger.warning("DataReceiver is not initialized. Using dummy or existing data if any.") + pass + + # Reconstruct inputs from _rdma_buffers + text_len = int(self.config.get("text_len", 512)) + text_dim = int(self.config.get("text_encoder_dim", 4096)) + clip_dim = int(self.config.get("clip_embed_dim", 1024)) + z_dim = int(self.config.get("vae_z_dim", 16)) + + vae_stride = self.config.get("vae_stride", (4, 8, 8)) + target_video_length = int(self.config.get("target_video_length", 81)) + target_height = int(self.config.get("target_height", 480)) + target_width = int(self.config.get("target_width", 832)) + + t_prime = 1 + (target_video_length - 1) // int(vae_stride[0]) + h_prime = int(np.ceil(target_height / int(vae_stride[1]))) + w_prime = int(np.ceil(target_width / int(vae_stride[2]))) + + enable_cfg = bool(self.config.get("enable_cfg", False)) + task = self.config.get("task", "i2v") + use_image_encoder = bool(self.config.get("use_image_encoder", True)) + + buffer_index = 0 + + # 1. 
Text Context + context_buf = self._rdma_buffers[buffer_index] + context = context_buf.view(GET_DTYPE()).reshape(1, text_len, text_dim) + buffer_index += 1 + + context_null = None + if enable_cfg: + context_null_buf = self._rdma_buffers[buffer_index] + context_null = context_null_buf.view(GET_DTYPE()).reshape(1, text_len, text_dim) + buffer_index += 1 + + text_encoder_output = { + "context": context, + "context_null": context_null, + } + + image_encoder_output = {} + clip_encoder_out = None + vae_encoder_out_padded = None + + if task == "i2v": + if use_image_encoder: + clip_buf = self._rdma_buffers[buffer_index] + clip_encoder_out = clip_buf.view(GET_DTYPE()).reshape(clip_dim) + buffer_index += 1 + + vae_buf = self._rdma_buffers[buffer_index] + vae_encoder_out_padded = vae_buf.view(GET_DTYPE()).reshape(z_dim + 4, t_prime, h_prime, w_prime) + buffer_index += 1 + + latent_shape_buf = self._rdma_buffers[buffer_index] + latent_shape = latent_shape_buf.view(torch.int64).tolist() + + vae_encoder_out = None + if vae_encoder_out_padded is not None: + valid_t = latent_shape[1] + valid_h = latent_shape[2] + valid_w = latent_shape[3] + vae_encoder_out = vae_encoder_out_padded[:, :valid_t, :valid_h, :valid_w] + + if task == "i2v": + image_encoder_output["clip_encoder_out"] = clip_encoder_out + image_encoder_output["vae_encoder_out"] = vae_encoder_out + else: + image_encoder_output = None + + inputs = { + "text_encoder_output": text_encoder_output, + "image_encoder_output": image_encoder_output, + "latent_shape": latent_shape, + } seed = self.config.get("seed") save_path = self.config.get("save_path") @@ -52,8 +204,6 @@ def process(self, inputs: Dict[str, Any]): if save_path is None: raise ValueError("save_path is required in config.") - image_encoder_output = inputs.get("image_encoder_output") - latent_shape = inputs.get("latent_shape") if latent_shape is None: raise ValueError("latent_shape is required in inputs.") @@ -91,3 +241,15 @@ def process(self, inputs: Dict[str, Any]): self.logger.info("Done!") return save_path + + def release_memory(self): + """ + Releases the RDMA buffers, deregisters them from transfer engine, and clears GPU cache. 
+ """ + if self._rdma_buffers: + for buf in self._rdma_buffers: + if self.transfer_engine: + self.transfer_engine.deregister(buf.data_ptr()) + self._rdma_buffers = [] + + torch.cuda.empty_cache() diff --git a/lightx2v/disagg/utils.py b/lightx2v/disagg/utils.py index 2c05af939..ddd9489ec 100644 --- a/lightx2v/disagg/utils.py +++ b/lightx2v/disagg/utils.py @@ -1,11 +1,12 @@ import json import logging +import math import os import torch import torch.distributed as dist import torchvision.transforms.functional as TF from PIL import Image -from typing import Dict, Any, Optional +from typing import Dict, Any, List, Optional from lightx2v_platform.base.global_var import AI_DEVICE from lightx2v.utils.envs import GET_DTYPE @@ -377,4 +378,45 @@ def load_wan_transformer(config: Dict[str, Any]): logger.error(f"Unsupported model_cls: {config.get('model_cls')}") raise ValueError(f"Unsupported model_cls: {config.get('model_cls')}") - +def estimate_encoder_buffer_sizes(config: Dict[str, Any]) -> List[int]: + text_len = int(config.get("text_len", 512)) + enable_cfg = bool(config.get("enable_cfg", False)) + use_image_encoder = bool(config.get("use_image_encoder", True)) + task = config.get("task", "i2v") + + text_dim = int(config.get("text_encoder_dim", 4096)) + clip_dim = int(config.get("clip_embed_dim", 1024)) + z_dim = int(config.get("vae_z_dim", 16)) + + vae_stride = config.get("vae_stride", (4, 8, 8)) + stride_t = int(vae_stride[0]) + stride_h = int(vae_stride[1]) + stride_w = int(vae_stride[2]) + + target_video_length = int(config.get("target_video_length", 81)) + target_height = int(config.get("target_height", 480)) + target_width = int(config.get("target_width", 832)) + + t_prime = 1 + (target_video_length - 1) // stride_t + h_prime = int(math.ceil(target_height / stride_h)) + w_prime = int(math.ceil(target_width / stride_w)) + + bytes_per_elem = torch.tensor([], dtype=torch.float32).element_size() + int_bytes_per_elem = torch.tensor([], dtype=torch.int64).element_size() + + buffer_sizes = [] + context_bytes = text_len * text_dim * bytes_per_elem + buffer_sizes.append(context_bytes) + if enable_cfg: + buffer_sizes.append(context_bytes) + + if task == "i2v": + if use_image_encoder: + buffer_sizes.append(clip_dim * bytes_per_elem) + vae_bytes = (z_dim + 4) * t_prime * h_prime * w_prime * bytes_per_elem + buffer_sizes.append(vae_bytes) + + latent_shape_bytes = 4 * int_bytes_per_elem + buffer_sizes.append(latent_shape_bytes) + + return buffer_sizes From 4fa76469624336f6a1a9385b435938d77f614c30 Mon Sep 17 00:00:00 2001 From: zhtshr <547553598@qq.com> Date: Fri, 6 Feb 2026 16:00:28 +0800 Subject: [PATCH 6/8] fix mooncake config --- configs/mooncake_config.json | 6 ++++++ lightx2v/disagg/mooncake.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) create mode 100644 configs/mooncake_config.json diff --git a/configs/mooncake_config.json b/configs/mooncake_config.json new file mode 100644 index 000000000..3ca2cd12c --- /dev/null +++ b/configs/mooncake_config.json @@ -0,0 +1,6 @@ +{ + "local_hostname": "localhost", + "metadata_server": "P2PHANDSHAKE", + "protocol": "rdma", + "device_name": "" +} \ No newline at end of file diff --git a/lightx2v/disagg/mooncake.py b/lightx2v/disagg/mooncake.py index adee5804b..f80a38b83 100644 --- a/lightx2v/disagg/mooncake.py +++ b/lightx2v/disagg/mooncake.py @@ -32,7 +32,7 @@ def from_file(file_path: str) -> "MooncakeTransferEngineConfig": @staticmethod def load_from_env() -> "MooncakeTransferEngineConfig": - config_file_path = os.getenv("MOONCAKE_CONFIG_PATH") + 
config_file_path = os.getenv("MOONCAKE_CONFIG_PATH", "configs/mooncake_config.json")
         if config_file_path is None:
             raise ValueError("The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
         return MooncakeTransferEngineConfig.from_file(config_file_path)

From e861a8ec03ee34c8b9ad5048be934768cd4440e8 Mon Sep 17 00:00:00 2001
From: zhtshr <547553598@qq.com>
Date: Fri, 6 Feb 2026 16:03:38 +0800
Subject: [PATCH 7/8] fix some bugs

---
 lightx2v/disagg/conn.py                     | 9 +++++----
 lightx2v/disagg/examples/wan_t2v_service.py | 8 ++++----
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/lightx2v/disagg/conn.py b/lightx2v/disagg/conn.py
index fb3a43358..7de25bace 100644
--- a/lightx2v/disagg/conn.py
+++ b/lightx2v/disagg/conn.py
@@ -7,6 +7,7 @@
 from functools import cache
 from typing import Dict, List, Optional, Tuple
 from enum import Enum
+from dataclasses import dataclass

 import numpy as np
 import numpy.typing as npt
@@ -48,12 +49,13 @@ def group_concurrent_contiguous(
     return src_groups, dst_groups


+@dataclass
 class DataArgs:
     engine_rank: int
     data_ptrs: list[int]
     data_lens: list[int]
     data_item_lens: list[int]
-    ib_device: str
+    ib_device: Optional[str] = None


 class DataPoll:
@@ -295,18 +297,17 @@ def _connect(self, endpoint: str):
         socket.connect(endpoint)
         return socket

-    def init(self, data_indices: npt.NDArray[np.int64]):
-        self.data_mgr.enqueue_request(self.bootstrap_room, data_indices)
+    def init(self):
         packed_data_ptrs = b"".join(
             struct.pack("Q", ptr) for ptr in self.data_mgr.data_args.data_ptrs
         )
+        self.data_mgr.enqueue_request(self.bootstrap_room, packed_data_ptrs)
         self._connect("tcp://" + self.encode_server_url).send_multipart(
             [
                 self.transformer_ip.encode("ascii"),
                 self.session_id.encode("ascii"),
                 str(self.bootstrap_room).encode("ascii"),
                 packed_data_ptrs,
-                data_indices.tobytes(),
             ]
         )

diff --git a/lightx2v/disagg/examples/wan_t2v_service.py b/lightx2v/disagg/examples/wan_t2v_service.py
index 09dc11501..9beb992f6 100644
--- a/lightx2v/disagg/examples/wan_t2v_service.py
+++ b/lightx2v/disagg/examples/wan_t2v_service.py
@@ -48,10 +48,10 @@ def main():
     seed_all(seed)

     # Add seed into config so services can use it if needed
-    config.seed = seed
-    config.prompt = prompt
-    config.negative_prompt = negative_prompt
-    config.save_path = save_result_path
+    config["seed"] = seed
+    config["prompt"] = prompt
+    config["negative_prompt"] = negative_prompt
+    config["save_path"] = save_result_path

     # 2.
Define service threads import threading From b12e39ceeab7c3c3dfc016a5abdb03c90b898024 Mon Sep 17 00:00:00 2001 From: zhtshr <547553598@qq.com> Date: Wed, 11 Feb 2026 16:30:10 +0800 Subject: [PATCH 8/8] enable disaggregation --- lightx2v/disagg/conn.py | 25 ++-- lightx2v/disagg/examples/mooncake_client.py | 78 ++++++++++ lightx2v/disagg/examples/mooncake_server.py | 71 +++++++++ lightx2v/disagg/examples/wan_i2v_service.py | 11 +- lightx2v/disagg/mooncake.py | 18 ++- lightx2v/disagg/services/encoder.py | 157 +++++++++++++------- lightx2v/disagg/services/transformer.py | 157 ++++++++++++++++---- lightx2v/disagg/utils.py | 3 + 8 files changed, 418 insertions(+), 102 deletions(-) create mode 100644 lightx2v/disagg/examples/mooncake_client.py create mode 100644 lightx2v/disagg/examples/mooncake_server.py diff --git a/lightx2v/disagg/conn.py b/lightx2v/disagg/conn.py index 7de25bace..1db1f4f55 100644 --- a/lightx2v/disagg/conn.py +++ b/lightx2v/disagg/conn.py @@ -4,6 +4,7 @@ import logging import struct import threading +import torch from functools import cache from typing import Dict, List, Optional, Tuple from enum import Enum @@ -51,7 +52,8 @@ def group_concurrent_contiguous( @dataclass class DataArgs: - engine_rank: int + sender_engine_rank: int + receiver_engine_rank: int data_ptrs: list[int] data_lens: list[int] data_item_lens: list[int] @@ -101,10 +103,6 @@ def register_buffer_to_engine(self): ): self.engine.register(data_ptr, data_len) - def update_data_args(self, data_args: DataArgs): - self.data_args = data_args - self.register_buffer_to_engine() - @cache def _connect(self, endpoint: str): socket = zmq.Context().socket(zmq.PUSH) @@ -141,7 +139,7 @@ def sync_status_to_transformer_endpoint(self, remote: str, room: int): "tcp://" + remote + ":" - + str(DATARECEIVER_POLLING_PORT + self.data_args.engine_rank) + + str(DATARECEIVER_POLLING_PORT + self.data_args.receiver_engine_rank) ).send_multipart( [ str(room).encode("ascii"), @@ -150,7 +148,8 @@ def sync_status_to_transformer_endpoint(self, remote: str, room: int): ) def start_encode_thread(self): - sender_rank_port = DATASENDER_POLLING_PORT + self.data_args.engine_rank + sender_rank_port = DATASENDER_POLLING_PORT + self.data_args.sender_engine_rank + logger.info("Encoder sender_rank_port=%s", sender_rank_port) self.server_socket.bind("tcp://*:" + str(sender_rank_port)) def encode_thread(): @@ -169,6 +168,13 @@ def encode_thread(): transformer_ptrs = list( struct.unpack(f"{len(transformer_ptrs)//8}Q", transformer_ptrs) ) + logger.info( + "Encoder received ZMQ: endpoint=%s session_id=%s room=%s transformer_ptrs=%s", + endpoint, + mooncake_session_id, + bootstrap_room, + transformer_ptrs, + ) self.waiting_pool[bootstrap_room] = ( endpoint, mooncake_session_id, @@ -212,7 +218,7 @@ def transfer_thread(): threading.Thread(target=transfer_thread).start() def start_transformer_thread(self): - receiver_rank_port = DATARECEIVER_POLLING_PORT + self.data_args.engine_rank + receiver_rank_port = DATARECEIVER_POLLING_PORT + self.data_args.receiver_engine_rank self.server_socket.bind("tcp://*:" + str(receiver_rank_port)) def transformer_thread(): @@ -285,8 +291,9 @@ def __init__( self.encode_server_url = ( bootstrap_addr.split(":")[0] + ":" - + str(DATASENDER_POLLING_PORT + self.data_mgr.data_args.engine_rank) + + str(DATASENDER_POLLING_PORT + self.data_mgr.data_args.sender_engine_rank) ) + logger.info("DataReceiver encode_server_url=%s", self.encode_server_url) self.transformer_ip = self.data_mgr.get_localhost() self.session_id = 
self.data_mgr.get_session_id() self.data_mgr.set_status(bootstrap_room, DataPoll.WaitingForInput) diff --git a/lightx2v/disagg/examples/mooncake_client.py b/lightx2v/disagg/examples/mooncake_client.py new file mode 100644 index 000000000..daa2515b4 --- /dev/null +++ b/lightx2v/disagg/examples/mooncake_client.py @@ -0,0 +1,78 @@ + + +import numpy as np +import zmq +import torch +from mooncake.engine import TransferEngine + +def main(): + # Initialize ZMQ context and socket + context = zmq.Context() + socket = context.socket(zmq.PULL) + socket.connect(f"tcp://localhost:5555") + + # Wait for buffer info from server + print("Waiting for server buffer information...") + buffer_info = socket.recv_json() + server_session_id = buffer_info["session_id"] + server_ptr = buffer_info["ptr"] + server_len = buffer_info["len"] + print(f"Received server info - Session ID: {server_session_id}") + print(f"Server buffer address: {server_ptr}, length: {server_len}") + + # Initialize client engine + HOSTNAME = "localhost" # localhost for simple demo + METADATA_SERVER = "P2PHANDSHAKE" # [ETCD_SERVER_URL, P2PHANDSHAKE, ...] + PROTOCOL = "rdma" # [rdma, tcp, ...] + DEVICE_NAME = "" # auto discovery if empty + + client_engine = TransferEngine() + client_engine.initialize( + HOSTNAME, + METADATA_SERVER, + PROTOCOL, + DEVICE_NAME + ) + session_id = f"{HOSTNAME}:{client_engine.get_rpc_port()}" + + # Allocate and initialize client buffer (1MB) + client_buffer = torch.ones(1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0")) # Fill with ones + client_ptr = client_buffer.data_ptr() + client_len = client_buffer.element_size() * client_buffer.nelement() + + # Register memory with Mooncake + if PROTOCOL == "rdma": + ret_value = client_engine.register_memory(client_ptr, client_len) + if ret_value != 0: + print("Mooncake memory registration failed.") + raise RuntimeError("Mooncake memory registration failed.") + + print(f"Client initialized with session ID: {session_id}") + + # Transfer data from client to server + print("Transferring data to server...") + for _ in range(10): + ret = client_engine.transfer_sync_write( + server_session_id, + client_ptr, + server_ptr, + min(client_len, server_len) # Transfer minimum of both lengths + ) + + if ret >= 0: + print("Transfer successful!") + else: + print("Transfer failed!") + + # Cleanup + if PROTOCOL == "rdma": + ret_value = client_engine.unregister_memory(client_ptr) + if ret_value != 0: + print("Mooncake memory deregistration failed.") + raise RuntimeError("Mooncake memory deregistration failed.") + + socket.close() + context.term() + +if __name__ == "__main__": + main() diff --git a/lightx2v/disagg/examples/mooncake_server.py b/lightx2v/disagg/examples/mooncake_server.py new file mode 100644 index 000000000..ca0d1e78c --- /dev/null +++ b/lightx2v/disagg/examples/mooncake_server.py @@ -0,0 +1,71 @@ + +import numpy as np +import zmq +import torch +from mooncake.engine import TransferEngine + +def main(): + # Initialize ZMQ context and socket + context = zmq.Context() + socket = context.socket(zmq.PUSH) + socket.bind("tcp://*:5555") # Bind to port 5555 for buffer info + + HOSTNAME = "localhost" # localhost for simple demo + METADATA_SERVER = "P2PHANDSHAKE" # [ETCD_SERVER_URL, P2PHANDSHAKE, ...] + PROTOCOL = "rdma" # [rdma, tcp, ...] 
+ DEVICE_NAME = "" # auto discovery if empty + + # Initialize server engine + server_engine = TransferEngine() + server_engine.initialize( + HOSTNAME, + METADATA_SERVER, + PROTOCOL, + DEVICE_NAME + ) + session_id = f"{HOSTNAME}:{server_engine.get_rpc_port()}" + + # Allocate memory on server side (1MB buffer) + server_buffer = torch.zeros(1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:1")) + server_ptr = server_buffer.data_ptr() + server_len = server_buffer.element_size() * server_buffer.nelement() + + # Register memory with Mooncake + if PROTOCOL == "rdma": + ret_value = server_engine.register_memory(server_ptr, server_len) + if ret_value != 0: + print("Mooncake memory registration failed.") + raise RuntimeError("Mooncake memory registration failed.") + + print(f"Server initialized with session ID: {session_id}") + print(f"Server buffer address: {server_ptr}, length: {server_len}") + + # Send buffer info to client + buffer_info = { + "session_id": session_id, + "ptr": server_ptr, + "len": server_len + } + socket.send_json(buffer_info) + print("Buffer information sent to client") + + # Keep server running + try: + while True: + input("Press Ctrl+C to exit...") + except KeyboardInterrupt: + print("\nShutting down server...") + finally: + # Cleanup + if PROTOCOL == "rdma": + ret_value = server_engine.unregister_memory(server_ptr) + if ret_value != 0: + print("Mooncake memory deregistration failed.") + raise RuntimeError("Mooncake memory deregistration failed.") + + socket.close() + context.term() + +if __name__ == "__main__": + main() + \ No newline at end of file diff --git a/lightx2v/disagg/examples/wan_i2v_service.py b/lightx2v/disagg/examples/wan_i2v_service.py index 6e0185406..124a99b40 100644 --- a/lightx2v/disagg/examples/wan_i2v_service.py +++ b/lightx2v/disagg/examples/wan_i2v_service.py @@ -60,16 +60,16 @@ def main(): data_bootstrap_room=0, ) - config.image_path = image_path - config.prompt = prompt - config.negative_prompt = negative_prompt - config.save_path = save_result_path + config["image_path"] = image_path + config["prompt"] = prompt + config["negative_prompt"] = negative_prompt + config["save_path"] = save_result_path logger.info(f"Config initialized for task: {task}") seed_all(seed) # 2. Add seed to config so services use it - config.seed = seed + config["seed"] = seed # 3. 
Define service threads import threading @@ -79,6 +79,7 @@ def run_encoder(): encoder_service = EncoderService(config) logger.info("Running Encoder Service...") encoder_service.process() + logger.info("Encoder Service completed.") encoder_service.release_memory() def run_transformer(): diff --git a/lightx2v/disagg/mooncake.py b/lightx2v/disagg/mooncake.py index f80a38b83..6579d284c 100644 --- a/lightx2v/disagg/mooncake.py +++ b/lightx2v/disagg/mooncake.py @@ -59,22 +59,30 @@ def __init__(self): logger.error(f"Failed to load Mooncake config: {e}") raise - session_suffix = "_" + str(uuid.uuid4()) - self.session_id = self.config.local_hostname + session_suffix + # session_suffix = "_" + str(uuid.uuid4()) self.initialize( - self.session_id, + self.config.local_hostname, self.config.metadata_server, self.config.protocol, self.config.device_name, ) + # session_suffix = ":" + self.engine.get_rpc_port() + # self.session_id = self.config.local_hostname + session_suffix + self.session_id = f"{self.config.local_hostname}:{self.engine.get_rpc_port()}" def register(self, ptr, length): if self.engine: - self.engine.register_memory(ptr, length) + ret = self.engine.register_memory(ptr, length) + if ret != 0: + logger.error("Mooncake memory registration failed.") + raise RuntimeError("Mooncake memory registration failed.") def deregister(self, ptr): if self.engine: - self.engine.unregister_memory(ptr) + ret = self.engine.unregister_memory(ptr) + if ret != 0: + logger.error("Mooncake memory deregistration failed.") + raise RuntimeError("Mooncake memory deregistration failed.") def initialize( self, diff --git a/lightx2v/disagg/services/encoder.py b/lightx2v/disagg/services/encoder.py index 323e70128..351a1f0ab 100644 --- a/lightx2v/disagg/services/encoder.py +++ b/lightx2v/disagg/services/encoder.py @@ -1,5 +1,6 @@ -import math import torch +import hashlib +import json import numpy as np from typing import Dict, Any, Optional, List @@ -13,6 +14,7 @@ load_wan_image_encoder, load_wan_vae_encoder, read_image_input, + estimate_encoder_buffer_sizes, ) class EncoderService(BaseService): @@ -21,7 +23,8 @@ def __init__(self, config): self.text_encoder = None self.image_encoder = None self.vae_encoder = None - self.engine_rank = 0 + self.sender_engine_rank = int(self.config.get("sender_engine_rank", 0)) + self.receiver_engine_rank = int(self.config.get("receiver_engine_rank", 1)) self.data_mgr = None self.data_sender = None self._rdma_buffers: List[torch.Tensor] = [] @@ -37,13 +40,14 @@ def __init__(self, config): data_bootstrap_room = self.config.get("data_bootstrap_room", 0) if data_bootstrap_addr is not None and data_bootstrap_room is not None: - data_ptrs, data_lens, data_item_lens = self.alloc_bufs() + data_ptrs, data_lens = self.alloc_bufs() data_args = DataArgs( - engine_rank=self.engine_rank, + sender_engine_rank=self.sender_engine_rank, + receiver_engine_rank=self.receiver_engine_rank, data_ptrs=data_ptrs, data_lens=data_lens, - data_item_lens=data_item_lens, - ib_device=self.config.get("ib_device", ""), + data_item_lens=data_lens, + ib_device=None, ) self.data_mgr = DataManager(data_args, DisaggregationMode.ENCODE) self.data_sender = DataSender( @@ -128,52 +132,25 @@ def _get_vae_encoder_output(self, first_frame: torch.Tensor, latent_h: int, late return vae_encoder_out def alloc_bufs(self): - text_len = int(self.config.get("text_len", 512)) - enable_cfg = bool(self.config.get("enable_cfg", False)) - use_image_encoder = bool(self.config.get("use_image_encoder", True)) - task = self.config.get("task", "i2v") 
- - text_dim = int(self.config.get("text_encoder_dim", 4096)) - clip_dim = int(self.config.get("clip_embed_dim", 1024)) - z_dim = int(self.config.get("vae_z_dim", 16)) - - vae_stride = self.config.get("vae_stride", (4, 8, 8)) - stride_t = int(vae_stride[0]) - stride_h = int(vae_stride[1]) - stride_w = int(vae_stride[2]) - - target_video_length = int(self.config.get("target_video_length", 81)) - target_height = int(self.config.get("target_height", 480)) - target_width = int(self.config.get("target_width", 832)) - - t_prime = 1 + (target_video_length - 1) // stride_t - h_prime = int(math.ceil(target_height / stride_h)) - w_prime = int(math.ceil(target_width / stride_w)) - + # torch.cuda.set_device(self.sender_engine_rank) + buffer_sizes = estimate_encoder_buffer_sizes(self.config) self._rdma_buffers = [] data_ptrs: List[int] = [] data_lens: List[int] = [] - data_item_lens: List[int] = [] - def _alloc_buffer(shape, dtype): - buf = torch.empty(shape, dtype=dtype, device=torch.device(f"cuda:{self.engine_rank}")) + for nbytes in buffer_sizes: + if nbytes <= 0: + continue + buf = torch.empty( + (nbytes,), + dtype=torch.uint8, + # device=torch.device(f"cuda:{self.sender_engine_rank}"), + ) self._rdma_buffers.append(buf) - nbytes = buf.numel() * buf.element_size() data_ptrs.append(buf.data_ptr()) data_lens.append(nbytes) - data_item_lens.append(nbytes) - - _alloc_buffer((1, text_len, text_dim), GET_DTYPE()) - if enable_cfg: - _alloc_buffer((1, text_len, text_dim), GET_DTYPE()) - - if task == "i2v": - if use_image_encoder: - _alloc_buffer((clip_dim,), GET_DTYPE()) - _alloc_buffer((z_dim + 4, t_prime, h_prime, w_prime), GET_DTYPE()) - _alloc_buffer((4,), torch.int64) - return data_ptrs, data_lens, data_item_lens + return data_ptrs, data_lens def process(self) -> Dict[str, Any]: """ @@ -206,6 +183,7 @@ def process(self) -> Dict[str, Any]: } task = self.config.get("task") + clip_encoder_out = None if task == "t2v": latent_h = self.config["target_height"] // self.config["vae_stride"][1] @@ -225,7 +203,6 @@ def process(self) -> Dict[str, Any]: # 2. Image Encoding + VAE Encoding img, _ = read_image_input(image_path) - clip_encoder_out = None if self.image_encoder is not None: # Assuming image_encoder.visual handles list of images clip_encoder_out = self.image_encoder.visual([img]).squeeze(0).to(GET_DTYPE()) @@ -243,23 +220,70 @@ def process(self) -> Dict[str, Any]: else: raise ValueError(f"Unsupported task: {task}") + self.logger.info("Encode processing completed. 
Preparing to send data...") + if self.data_mgr is not None and self.data_sender is not None: + def _buffer_view(buf: torch.Tensor, dtype: torch.dtype, shape: tuple[int, ...]) -> torch.Tensor: + view = torch.empty(0, dtype=dtype, device=buf.device) + view.set_(buf.untyped_storage(), 0, shape) + return view + + def _sha256_tensor(tensor: Optional[torch.Tensor]) -> Optional[str]: + if tensor is None: + return None + data_tensor = tensor.detach() + if data_tensor.dtype == torch.bfloat16: + data_tensor = data_tensor.to(torch.float32) + data = data_tensor.contiguous().cpu().numpy().tobytes() + return hashlib.sha256(data).hexdigest() + + text_len = int(self.config.get("text_len", 512)) + text_dim = int(self.config.get("text_encoder_dim", 4096)) + clip_dim = int(self.config.get("clip_embed_dim", 1024)) + z_dim = int(self.config.get("vae_z_dim", 16)) + + vae_stride = self.config.get("vae_stride", (4, 8, 8)) + stride_t = int(vae_stride[0]) + stride_h = int(vae_stride[1]) + stride_w = int(vae_stride[2]) + + target_video_length = int(self.config.get("target_video_length", 81)) + target_height = int(self.config.get("target_height", 480)) + target_width = int(self.config.get("target_width", 832)) + + t_prime = 1 + (target_video_length - 1) // stride_t + h_prime = int(np.ceil(target_height / stride_h)) + w_prime = int(np.ceil(target_width / stride_w)) + buffer_index = 0 - self._rdma_buffers[buffer_index].copy_(context) + context_buf = _buffer_view( + self._rdma_buffers[buffer_index], GET_DTYPE(), (1, text_len, text_dim) + ) + context_buf.copy_(context) buffer_index += 1 if self.config.get("enable_cfg", False): - self._rdma_buffers[buffer_index].copy_(context_null) + context_null_buf = _buffer_view( + self._rdma_buffers[buffer_index], GET_DTYPE(), (1, text_len, text_dim) + ) + context_null_buf.copy_(context_null) buffer_index += 1 if task == "i2v": if self.config.get("use_image_encoder", True): + clip_buf = _buffer_view( + self._rdma_buffers[buffer_index], GET_DTYPE(), (clip_dim,) + ) if image_encoder_output.get("clip_encoder_out") is not None: - self._rdma_buffers[buffer_index].copy_(image_encoder_output["clip_encoder_out"]) + clip_buf.copy_(image_encoder_output["clip_encoder_out"]) else: - self._rdma_buffers[buffer_index].zero_() + clip_buf.zero_() buffer_index += 1 - vae_buf = self._rdma_buffers[buffer_index] + vae_buf = _buffer_view( + self._rdma_buffers[buffer_index], + GET_DTYPE(), + (z_dim + 4, t_prime, h_prime, w_prime), + ) vae_buf.zero_() vae_flat = vae_buf.view(-1) src_flat = image_encoder_output["vae_encoder_out"].reshape(-1) @@ -267,7 +291,32 @@ def process(self) -> Dict[str, Any]: buffer_index += 1 latent_tensor = torch.tensor(latent_shape, device=AI_DEVICE, dtype=torch.int64) - self._rdma_buffers[buffer_index].copy_(latent_tensor) + latent_buf = _buffer_view( + self._rdma_buffers[buffer_index], torch.int64, (4,) + ) + latent_buf.copy_(latent_tensor) + buffer_index += 1 + + meta = { + "version": 1, + "context_shape": list(context.shape), + "context_hash": _sha256_tensor(context), + "context_null_shape": list(context_null.shape) if context_null is not None else None, + "context_null_hash": _sha256_tensor(context_null), + "clip_shape": list(clip_encoder_out.shape) if clip_encoder_out is not None else None, + "clip_hash": _sha256_tensor(clip_encoder_out), + "vae_shape": list(image_encoder_output["vae_encoder_out"].shape) if image_encoder_output is not None else None, + "vae_hash": _sha256_tensor(image_encoder_output["vae_encoder_out"]) if image_encoder_output is not None else None, + 
"latent_shape": list(latent_shape), + "latent_hash": _sha256_tensor(latent_tensor), + } + meta_bytes = json.dumps(meta, ensure_ascii=True).encode("utf-8") + meta_buf = _buffer_view(self._rdma_buffers[buffer_index], torch.uint8, (self._rdma_buffers[buffer_index].numel(),)) + if meta_bytes and len(meta_bytes) > meta_buf.numel(): + raise ValueError("metadata buffer too small for hash/shape payload") + meta_buf.zero_() + if meta_bytes: + meta_buf[: len(meta_bytes)].copy_(torch.from_numpy(np.frombuffer(meta_bytes, dtype=np.uint8))) buffer_ptrs = [buf.data_ptr() for buf in self._rdma_buffers] self.data_sender.send(buffer_ptrs) @@ -283,5 +332,9 @@ def release_memory(self): """ Releases the RDMA buffers and clears GPU cache. """ - self._rdma_buffers = [] + if self._rdma_buffers: + for buf in self._rdma_buffers: + if self.data_mgr is not None: + self.data_mgr.engine.deregister(buf.data_ptr()) + self._rdma_buffers = [] torch.cuda.empty_cache() diff --git a/lightx2v/disagg/services/transformer.py b/lightx2v/disagg/services/transformer.py index 517aa369a..1c3336636 100644 --- a/lightx2v/disagg/services/transformer.py +++ b/lightx2v/disagg/services/transformer.py @@ -1,4 +1,7 @@ import torch +import torch.nn.functional as F +import hashlib +import json import logging import numpy as np from typing import Dict, Any, List, Optional @@ -11,7 +14,6 @@ load_wan_vae_decoder, ) from lightx2v.disagg.protocol import AllocationRequest, MemoryHandle, RemoteBuffer -from lightx2v.disagg.mooncake import MooncakeTransferEngine from lightx2v_platform.base.global_var import AI_DEVICE from lightx2v.models.schedulers.wan.scheduler import WanScheduler from lightx2v.utils.envs import GET_DTYPE @@ -23,9 +25,9 @@ def __init__(self, config): self.transformer = None self.vae_decoder = None self.scheduler = None - self.transfer_engine = None self._rdma_buffers: List[torch.Tensor] = [] - self.engine_rank = 1 + self.sender_engine_rank = int(self.config.get("sender_engine_rank", 0)) + self.receiver_engine_rank = int(self.config.get("receiver_engine_rank", 1)) self.data_mgr = None self.data_receiver = None @@ -50,7 +52,8 @@ def __init__(self, config): data_lens = [buf.nbytes for buf in handle.buffers] data_args = DataArgs( - engine_rank=self.engine_rank, + sender_engine_rank=self.sender_engine_rank, + receiver_engine_rank=self.receiver_engine_rank, data_ptrs=data_ptrs, data_lens=data_lens, data_item_lens=data_lens, @@ -88,20 +91,26 @@ def alloc_memory(self, request: AllocationRequest) -> MemoryHandle: estimated_sizes = estimate_encoder_buffer_sizes(config) buffer_sizes = estimated_sizes - if self.transfer_engine is None: - self.transfer_engine = MooncakeTransferEngine() + # torch.cuda.set_device(self.receiver_engine_rank) self._rdma_buffers = [] buffers: List[RemoteBuffer] = [] for nbytes in buffer_sizes: if nbytes <= 0: continue - buf = torch.empty((nbytes,), dtype=torch.uint8, device=torch.device(f"cuda:{self.engine_rank}")) + buf = torch.empty((nbytes,), dtype=torch.uint8, #device=torch.device(f"cuda:{self.receiver_engine_rank}") + ) ptr = buf.data_ptr() - self.transfer_engine.register(ptr, nbytes) self._rdma_buffers.append(buf) + session_id = self.data_mgr.get_session_id() if self.data_mgr is not None else "" buffers.append( - RemoteBuffer(addr=ptr, session_id=self.transfer_engine.get_session_id(), nbytes=nbytes) + RemoteBuffer(addr=ptr, session_id=session_id, nbytes=nbytes) + ) + + if buffers: + self.logger.info( + "Transformer allocated RDMA buffers: %s", + [buf.addr for buf in buffers], ) return MemoryHandle(buffers=buffers) 
@@ -111,6 +120,37 @@ def process(self): Executes the diffusion process and video decoding. """ self.logger.info("Starting processing in TransformerService...") + + def _buffer_view(buf: torch.Tensor, dtype: torch.dtype, shape: tuple[int, ...]) -> torch.Tensor: + view = torch.empty(0, dtype=dtype, device=buf.device) + view.set_(buf.untyped_storage(), 0, shape) + if view.device != torch.device(AI_DEVICE): + view = view.to(torch.device(AI_DEVICE)) + return view + + def _align_vae_to_latents(vae: torch.Tensor, target_t: int, target_h: int, target_w: int) -> torch.Tensor: + if vae is None: + return vae + _, t, h, w = vae.shape + t_slice = min(t, target_t) + h_slice = min(h, target_h) + w_slice = min(w, target_w) + vae = vae[:, :t_slice, :h_slice, :w_slice] + pad_t = target_t - t_slice + pad_h = target_h - h_slice + pad_w = target_w - w_slice + if pad_t > 0 or pad_h > 0 or pad_w > 0: + vae = F.pad(vae, (0, pad_w, 0, pad_h, 0, pad_t)) + return vae + + def _sha256_tensor(tensor: Optional[torch.Tensor]) -> Optional[str]: + if tensor is None: + return None + data_tensor = tensor.detach() + if data_tensor.dtype == torch.bfloat16: + data_tensor = data_tensor.to(torch.float32) + data = data_tensor.contiguous().cpu().numpy().tobytes() + return hashlib.sha256(data).hexdigest() # Poll for data from EncoderService import time @@ -118,6 +158,7 @@ def process(self): while True: status = self.data_receiver.poll() if status == DataPoll.Success: + self.logger.info("Data received successfully in TransformerService.") break time.sleep(0.01) else: @@ -129,7 +170,7 @@ def process(self): text_dim = int(self.config.get("text_encoder_dim", 4096)) clip_dim = int(self.config.get("clip_embed_dim", 1024)) z_dim = int(self.config.get("vae_z_dim", 16)) - + vae_stride = self.config.get("vae_stride", (4, 8, 8)) target_video_length = int(self.config.get("target_video_length", 81)) target_height = int(self.config.get("target_height", 480)) @@ -138,24 +179,47 @@ def process(self): t_prime = 1 + (target_video_length - 1) // int(vae_stride[0]) h_prime = int(np.ceil(target_height / int(vae_stride[1]))) w_prime = int(np.ceil(target_width / int(vae_stride[2]))) - + enable_cfg = bool(self.config.get("enable_cfg", False)) task = self.config.get("task", "i2v") use_image_encoder = bool(self.config.get("use_image_encoder", True)) buffer_index = 0 - - # 1. 
Text Context + context_buf = self._rdma_buffers[buffer_index] - context = context_buf.view(GET_DTYPE()).reshape(1, text_len, text_dim) buffer_index += 1 - - context_null = None + + context_null_buf = None if enable_cfg: context_null_buf = self._rdma_buffers[buffer_index] - context_null = context_null_buf.view(GET_DTYPE()).reshape(1, text_len, text_dim) buffer_index += 1 - + + clip_buf = None + vae_buf = None + if task == "i2v": + if use_image_encoder: + clip_buf = self._rdma_buffers[buffer_index] + buffer_index += 1 + + vae_buf = self._rdma_buffers[buffer_index] + buffer_index += 1 + + latent_buf = self._rdma_buffers[buffer_index] + buffer_index += 1 + + meta_buf = self._rdma_buffers[buffer_index] + meta_bytes = _buffer_view(meta_buf, torch.uint8, (meta_buf.numel(),)).detach().contiguous().cpu().numpy().tobytes() + meta_str = meta_bytes.split(b"\x00", 1)[0].decode("utf-8") if meta_bytes else "" + meta = json.loads(meta_str) if meta_str else {} + + context_shape = tuple(meta.get("context_shape") or (1, text_len, text_dim)) + context = _buffer_view(context_buf, GET_DTYPE(), context_shape) + + context_null = None + if enable_cfg and context_null_buf is not None: + context_null_shape = tuple(meta.get("context_null_shape") or (1, text_len, text_dim)) + context_null = _buffer_view(context_null_buf, GET_DTYPE(), context_null_shape) + text_encoder_output = { "context": context, "context_null": context_null, @@ -164,19 +228,17 @@ def process(self): image_encoder_output = {} clip_encoder_out = None vae_encoder_out_padded = None - + if task == "i2v": - if use_image_encoder: - clip_buf = self._rdma_buffers[buffer_index] - clip_encoder_out = clip_buf.view(GET_DTYPE()).reshape(clip_dim) - buffer_index += 1 + if use_image_encoder and clip_buf is not None: + clip_shape = tuple(meta.get("clip_shape") or (clip_dim,)) + clip_encoder_out = _buffer_view(clip_buf, GET_DTYPE(), clip_shape) - vae_buf = self._rdma_buffers[buffer_index] - vae_encoder_out_padded = vae_buf.view(GET_DTYPE()).reshape(z_dim + 4, t_prime, h_prime, w_prime) - buffer_index += 1 - - latent_shape_buf = self._rdma_buffers[buffer_index] - latent_shape = latent_shape_buf.view(torch.int64).tolist() + if vae_buf is not None: + vae_shape = tuple(meta.get("vae_shape") or (z_dim + 4, t_prime, h_prime, w_prime)) + vae_encoder_out_padded = _buffer_view(vae_buf, GET_DTYPE(), vae_shape) + + latent_shape = _buffer_view(latent_buf, torch.int64, (4,)).tolist() vae_encoder_out = None if vae_encoder_out_padded is not None: @@ -184,12 +246,45 @@ def process(self): valid_h = latent_shape[2] valid_w = latent_shape[3] vae_encoder_out = vae_encoder_out_padded[:, :valid_t, :valid_h, :valid_w] + vae_encoder_out = _align_vae_to_latents(vae_encoder_out, valid_t, valid_h, valid_w) if task == "i2v": image_encoder_output["clip_encoder_out"] = clip_encoder_out image_encoder_output["vae_encoder_out"] = vae_encoder_out else: image_encoder_output = None + + if meta: + if meta.get("context_shape") is not None and list(context.shape) != meta.get("context_shape"): + raise ValueError("context shape mismatch between encoder and transformer") + if meta.get("context_hash") is not None and _sha256_tensor(context) != meta.get("context_hash"): + raise ValueError("context hash mismatch between encoder and transformer") + if enable_cfg: + if meta.get("context_null_shape") is not None and context_null is not None: + if list(context_null.shape) != meta.get("context_null_shape"): + raise ValueError("context_null shape mismatch between encoder and transformer") + if 
meta.get("context_null_hash") is not None: + if _sha256_tensor(context_null) != meta.get("context_null_hash"): + raise ValueError("context_null hash mismatch between encoder and transformer") + if task == "i2v": + if meta.get("clip_shape") is not None and clip_encoder_out is not None: + if list(clip_encoder_out.shape) != meta.get("clip_shape"): + raise ValueError("clip shape mismatch between encoder and transformer") + if meta.get("clip_hash") is not None: + if _sha256_tensor(clip_encoder_out) != meta.get("clip_hash"): + raise ValueError("clip hash mismatch between encoder and transformer") + if meta.get("vae_shape") is not None and vae_encoder_out is not None: + if list(vae_encoder_out.shape) != meta.get("vae_shape"): + raise ValueError("vae shape mismatch between encoder and transformer") + if meta.get("vae_hash") is not None: + if _sha256_tensor(vae_encoder_out) != meta.get("vae_hash"): + raise ValueError("vae hash mismatch between encoder and transformer") + if meta.get("latent_shape") is not None and list(latent_shape) != meta.get("latent_shape"): + raise ValueError("latent_shape mismatch between encoder and transformer") + if meta.get("latent_hash") is not None: + latent_tensor = torch.tensor(latent_shape, device=AI_DEVICE, dtype=torch.int64) + if _sha256_tensor(latent_tensor) != meta.get("latent_hash"): + raise ValueError("latent_shape hash mismatch between encoder and transformer") inputs = { "text_encoder_output": text_encoder_output, @@ -248,8 +343,8 @@ def release_memory(self): """ if self._rdma_buffers: for buf in self._rdma_buffers: - if self.transfer_engine: - self.transfer_engine.deregister(buf.data_ptr()) + if self.data_mgr is not None: + self.data_mgr.engine.deregister(buf.data_ptr()) self._rdma_buffers = [] torch.cuda.empty_cache() diff --git a/lightx2v/disagg/utils.py b/lightx2v/disagg/utils.py index ddd9489ec..0a1936b67 100644 --- a/lightx2v/disagg/utils.py +++ b/lightx2v/disagg/utils.py @@ -419,4 +419,7 @@ def estimate_encoder_buffer_sizes(config: Dict[str, Any]) -> List[int]: latent_shape_bytes = 4 * int_bytes_per_elem buffer_sizes.append(latent_shape_bytes) + # Metadata buffer for integrity checks (hashes + shapes) + buffer_sizes.append(4096) + return buffer_sizes
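
The 4096-byte slot appended above is the integrity-check channel: the encoder serialises tensor shapes and SHA-256 hashes to JSON, zero-pads the payload into the fixed-size buffer, and the transformer parses it back by splitting at the first null byte before comparing hashes. A minimal, transport-free sketch of that convention follows; it mirrors the _sha256_tensor helpers in the services, while pack_meta/unpack_meta are illustrative names rather than functions defined in this series.

# Standalone sketch of the fixed-size metadata slot convention (no RDMA involved).
import hashlib
import json

import torch

META_BYTES = 4096  # matches the extra buffer appended in estimate_encoder_buffer_sizes


def sha256_tensor(t: torch.Tensor) -> str:
    # NumPy has no bfloat16, so hash a float32 copy, as the services do.
    if t.dtype == torch.bfloat16:
        t = t.to(torch.float32)
    return hashlib.sha256(t.detach().contiguous().cpu().numpy().tobytes()).hexdigest()


def pack_meta(context: torch.Tensor) -> bytes:
    meta = {
        "version": 1,
        "context_shape": list(context.shape),
        "context_hash": sha256_tensor(context),
    }
    payload = json.dumps(meta, ensure_ascii=True).encode("utf-8")
    if len(payload) > META_BYTES:
        raise ValueError("metadata buffer too small for hash/shape payload")
    return payload + b"\x00" * (META_BYTES - len(payload))


def unpack_meta(buf: bytes) -> dict:
    payload = buf.split(b"\x00", 1)[0]
    return json.loads(payload) if payload else {}


if __name__ == "__main__":
    ctx = torch.randn(1, 512, 4096, dtype=torch.float32)
    meta = unpack_meta(pack_meta(ctx))
    assert meta["context_hash"] == sha256_tensor(ctx)
    print("metadata round-trip ok:", meta["context_shape"])

Because the hash is taken over a float32 copy, both sides must upcast consistently before comparing, and the shape fields let the receiver rebuild tensor views from the raw byte buffers instead of hard-coding the config defaults.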