diff --git a/configs/mooncake_config.json b/configs/mooncake_config.json new file mode 100644 index 00000000..3ca2cd12 --- /dev/null +++ b/configs/mooncake_config.json @@ -0,0 +1,6 @@ +{ + "local_hostname": "localhost", + "metadata_server": "P2PHANDSHAKE", + "protocol": "rdma", + "device_name": "" +} \ No newline at end of file diff --git a/lightx2v/disagg/__init__.py b/lightx2v/disagg/__init__.py new file mode 100644 index 00000000..45a7b4af --- /dev/null +++ b/lightx2v/disagg/__init__.py @@ -0,0 +1 @@ +# Disaggregation package initialization diff --git a/lightx2v/disagg/conn.py b/lightx2v/disagg/conn.py new file mode 100644 index 00000000..1db1f4f5 --- /dev/null +++ b/lightx2v/disagg/conn.py @@ -0,0 +1,326 @@ +from __future__ import annotations + +import asyncio +import logging +import struct +import threading +import torch +from functools import cache +from typing import Dict, List, Optional, Tuple +from enum import Enum +from dataclasses import dataclass + +import numpy as np +import numpy.typing as npt +import zmq +from aiohttp import web + +from lightx2v.disagg.mooncake import MooncakeTransferEngine + +logger = logging.getLogger(__name__) + +class DisaggregationMode(Enum): + NULL = "null" + ENCODE = "encode" + TRANSFORMER = "transformer" + +def group_concurrent_contiguous( + src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64] +) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]: + src_groups = [] + dst_groups = [] + current_src = [src_indices[0]] + current_dst = [dst_indices[0]] + + for i in range(1, len(src_indices)): + src_contiguous = src_indices[i] == src_indices[i - 1] + 1 + dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1 + if src_contiguous and dst_contiguous: + current_src.append(src_indices[i]) + current_dst.append(dst_indices[i]) + else: + src_groups.append(current_src) + dst_groups.append(current_dst) + current_src = [src_indices[i]] + current_dst = [dst_indices[i]] + + src_groups.append(current_src) + dst_groups.append(current_dst) + + return src_groups, dst_groups + + +@dataclass +class DataArgs: + sender_engine_rank: int + receiver_engine_rank: int + data_ptrs: list[int] + data_lens: list[int] + data_item_lens: list[int] + ib_device: Optional[str] = None + + +class DataPoll: + Failed = 0 + Bootstrapping = 1 + WaitingForInput = 2 + Transferring = 3 + Success = 4 + + +RequestPoolType = Dict[int, List[int]] +WaitingPoolType = Dict[ + int, Tuple[str, list[int]] +] +DATASENDER_POLLING_PORT = 17788 +DATARECEIVER_POLLING_PORT = 27788 + + +class DataManager: + # TODO: make it general and support multiple transfer backend before merging + def __init__(self, args: DataArgs, disaggregation_mode: DisaggregationMode): + self.engine = MooncakeTransferEngine() + self.data_args = args + self.disaggregation_mode = disaggregation_mode + self.request_pool: RequestPoolType = {} + self.request_status: Dict[int, DataPoll] = {} + self.server_socket = zmq.Context().socket(zmq.PULL) + self.register_buffer_to_engine() + if self.disaggregation_mode == DisaggregationMode.ENCODE: + self.waiting_pool: WaitingPoolType = {} + self.transfer_event = threading.Event() + self.start_encode_thread() + elif self.disaggregation_mode == DisaggregationMode.TRANSFORMER: + self.start_transformer_thread() + else: + raise ValueError( + f"Unsupported DisaggregationMode: {self.disaggregation_mode}" + ) + + def register_buffer_to_engine(self): + for data_ptr, data_len in zip( + self.data_args.data_ptrs, self.data_args.data_lens + ): + self.engine.register(data_ptr, 
data_len) + + @cache + def _connect(self, endpoint: str): + socket = zmq.Context().socket(zmq.PUSH) + socket.connect(endpoint) + return socket + + def send_data( + self, + mooncake_session_id: str, + encode_data_ptrs: List[int], + transformer_ptrs: list[int], + ): + tensor_num = int(len(self.data_args.data_ptrs)) + for tensor_id in range(tensor_num): + encode_addr = encode_data_ptrs[tensor_id] + item_len = self.data_args.data_item_lens[tensor_id] + transformer_addr = transformer_ptrs[tensor_id] + + # TODO: mooncake transfer engine can do async transfer. Do async later + status = self.engine.transfer_sync( + mooncake_session_id, + encode_addr, + transformer_addr, + item_len, + ) + if status != 0: + return status + return 0 + + def sync_status_to_transformer_endpoint(self, remote: str, room: int): + if ":" in remote: + remote = remote.split(":")[0] + self._connect( + "tcp://" + + remote + + ":" + + str(DATARECEIVER_POLLING_PORT + self.data_args.receiver_engine_rank) + ).send_multipart( + [ + str(room).encode("ascii"), + str(self.request_status[room]).encode("ascii"), + ] + ) + + def start_encode_thread(self): + sender_rank_port = DATASENDER_POLLING_PORT + self.data_args.sender_engine_rank + logger.info("Encoder sender_rank_port=%s", sender_rank_port) + self.server_socket.bind("tcp://*:" + str(sender_rank_port)) + + def encode_thread(): + while True: + ( + endpoint, + mooncake_session_id, + bootstrap_room, + transformer_ptrs, + ) = self.server_socket.recv_multipart() + if bootstrap_room.decode("ascii") == "None": + continue + endpoint = endpoint.decode("ascii") + mooncake_session_id = mooncake_session_id.decode("ascii") + bootstrap_room = int(bootstrap_room.decode("ascii")) + transformer_ptrs = list( + struct.unpack(f"{len(transformer_ptrs)//8}Q", transformer_ptrs) + ) + logger.info( + "Encoder received ZMQ: endpoint=%s session_id=%s room=%s transformer_ptrs=%s", + endpoint, + mooncake_session_id, + bootstrap_room, + transformer_ptrs, + ) + self.waiting_pool[bootstrap_room] = ( + endpoint, + mooncake_session_id, + transformer_ptrs, + ) + self.transfer_event.set() + + threading.Thread(target=encode_thread).start() + + def transfer_thread(): + while True: + self.transfer_event.wait() + self.transfer_event.clear() + bootstrap_room_ready = self.request_pool.keys() + bootstrap_room_request = self.waiting_pool.keys() + for room in list(bootstrap_room_request): + if room not in list(bootstrap_room_ready): + continue + status = DataPoll.Transferring + self.request_status[room] = status + ( + endpoint, + mooncake_session_id, + transformer_ptrs, + ) = self.waiting_pool.pop(room) + self.sync_status_to_transformer_endpoint(endpoint, room) + encode_data_ptrs = self.request_pool.pop(room) + ret = self.send_data( + mooncake_session_id, + encode_data_ptrs, + transformer_ptrs, + ) + if ret != 0: + status = DataPoll.Failed + self.sync_status_to_transformer_endpoint(endpoint, room) + continue + status = DataPoll.Success + self.request_status[room] = status + self.sync_status_to_transformer_endpoint(endpoint, room) + + threading.Thread(target=transfer_thread).start() + + def start_transformer_thread(self): + receiver_rank_port = DATARECEIVER_POLLING_PORT + self.data_args.receiver_engine_rank + self.server_socket.bind("tcp://*:" + str(receiver_rank_port)) + + def transformer_thread(): + while True: + (bootstrap_room, status) = self.server_socket.recv_multipart() + status = int(status.decode("ascii")) + bootstrap_room = int(bootstrap_room.decode("ascii")) + self.request_status[bootstrap_room] = status + + 
threading.Thread(target=transformer_thread).start() + + def enqueue_request( + self, + bootstrap_room: int, + data_ptrs: List[int], + ): + self.request_pool[bootstrap_room] = data_ptrs + self.request_status[bootstrap_room] = DataPoll.WaitingForInput + if self.disaggregation_mode == DisaggregationMode.ENCODE: + self.transfer_event.set() + + def check_status(self, bootstrap_room: int): + if ( + self.disaggregation_mode == DisaggregationMode.TRANSFORMER + and self.request_status[bootstrap_room] == DataPoll.Success + ): + if bootstrap_room in self.request_pool: + self.request_pool.pop(bootstrap_room) + + return self.request_status[bootstrap_room] + + def set_status(self, bootstrap_room: int, status: DataPoll): + self.request_status[bootstrap_room] = status + + def get_localhost(self): + return self.engine.get_localhost() + + def get_session_id(self): + return self.engine.get_session_id() + + +class DataSender: + + def __init__(self, mgr: DataManager, bootstrap_addr: str, bootstrap_room: int): + self.data_mgr = mgr + self.bootstrap_room = bootstrap_room + self.data_mgr.set_status(bootstrap_room, DataPoll.WaitingForInput) + + def init(self, num_data_indices: int): + self.num_data_indices = num_data_indices + + def send(self, data_ptrs: List[int]): + self.data_mgr.enqueue_request(self.bootstrap_room, data_ptrs) + + def poll(self) -> DataPoll: + return self.data_mgr.check_status(self.bootstrap_room) + + def failure_exception(self): + raise Exception("Fake DataSender Exception") + + +class DataReceiver: + + def __init__( + self, mgr: DataManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None + ): + self.bootstrap_room = bootstrap_room + self.bootstrap_addr = bootstrap_addr + self.data_mgr = mgr + self.encode_server_url = ( + bootstrap_addr.split(":")[0] + + ":" + + str(DATASENDER_POLLING_PORT + self.data_mgr.data_args.sender_engine_rank) + ) + logger.info("DataReceiver encode_server_url=%s", self.encode_server_url) + self.transformer_ip = self.data_mgr.get_localhost() + self.session_id = self.data_mgr.get_session_id() + self.data_mgr.set_status(bootstrap_room, DataPoll.WaitingForInput) + + @cache + def _connect(self, endpoint: str): + socket = zmq.Context().socket(zmq.PUSH) + socket.connect(endpoint) + return socket + + def init(self): + packed_data_ptrs = b"".join( + struct.pack("Q", ptr) for ptr in self.data_mgr.data_args.data_ptrs + ) + self.data_mgr.enqueue_request(self.bootstrap_room, packed_data_ptrs) + self._connect("tcp://" + self.encode_server_url).send_multipart( + [ + self.transformer_ip.encode("ascii"), + self.session_id.encode("ascii"), + str(self.bootstrap_room).encode("ascii"), + packed_data_ptrs, + ] + ) + + def poll(self) -> DataPoll: + return self.data_mgr.check_status(self.bootstrap_room) + + def failure_exception(self): + raise Exception("Fake DataReceiver Exception") + diff --git a/lightx2v/disagg/examples/mooncake_client.py b/lightx2v/disagg/examples/mooncake_client.py new file mode 100644 index 00000000..daa2515b --- /dev/null +++ b/lightx2v/disagg/examples/mooncake_client.py @@ -0,0 +1,78 @@ + + +import numpy as np +import zmq +import torch +from mooncake.engine import TransferEngine + +def main(): + # Initialize ZMQ context and socket + context = zmq.Context() + socket = context.socket(zmq.PULL) + socket.connect(f"tcp://localhost:5555") + + # Wait for buffer info from server + print("Waiting for server buffer information...") + buffer_info = socket.recv_json() + server_session_id = buffer_info["session_id"] + server_ptr = buffer_info["ptr"] + server_len = 
buffer_info["len"] + print(f"Received server info - Session ID: {server_session_id}") + print(f"Server buffer address: {server_ptr}, length: {server_len}") + + # Initialize client engine + HOSTNAME = "localhost" # localhost for simple demo + METADATA_SERVER = "P2PHANDSHAKE" # [ETCD_SERVER_URL, P2PHANDSHAKE, ...] + PROTOCOL = "rdma" # [rdma, tcp, ...] + DEVICE_NAME = "" # auto discovery if empty + + client_engine = TransferEngine() + client_engine.initialize( + HOSTNAME, + METADATA_SERVER, + PROTOCOL, + DEVICE_NAME + ) + session_id = f"{HOSTNAME}:{client_engine.get_rpc_port()}" + + # Allocate and initialize client buffer (1MB) + client_buffer = torch.ones(1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0")) # Fill with ones + client_ptr = client_buffer.data_ptr() + client_len = client_buffer.element_size() * client_buffer.nelement() + + # Register memory with Mooncake + if PROTOCOL == "rdma": + ret_value = client_engine.register_memory(client_ptr, client_len) + if ret_value != 0: + print("Mooncake memory registration failed.") + raise RuntimeError("Mooncake memory registration failed.") + + print(f"Client initialized with session ID: {session_id}") + + # Transfer data from client to server + print("Transferring data to server...") + for _ in range(10): + ret = client_engine.transfer_sync_write( + server_session_id, + client_ptr, + server_ptr, + min(client_len, server_len) # Transfer minimum of both lengths + ) + + if ret >= 0: + print("Transfer successful!") + else: + print("Transfer failed!") + + # Cleanup + if PROTOCOL == "rdma": + ret_value = client_engine.unregister_memory(client_ptr) + if ret_value != 0: + print("Mooncake memory deregistration failed.") + raise RuntimeError("Mooncake memory deregistration failed.") + + socket.close() + context.term() + +if __name__ == "__main__": + main() diff --git a/lightx2v/disagg/examples/mooncake_server.py b/lightx2v/disagg/examples/mooncake_server.py new file mode 100644 index 00000000..ca0d1e78 --- /dev/null +++ b/lightx2v/disagg/examples/mooncake_server.py @@ -0,0 +1,71 @@ + +import numpy as np +import zmq +import torch +from mooncake.engine import TransferEngine + +def main(): + # Initialize ZMQ context and socket + context = zmq.Context() + socket = context.socket(zmq.PUSH) + socket.bind("tcp://*:5555") # Bind to port 5555 for buffer info + + HOSTNAME = "localhost" # localhost for simple demo + METADATA_SERVER = "P2PHANDSHAKE" # [ETCD_SERVER_URL, P2PHANDSHAKE, ...] + PROTOCOL = "rdma" # [rdma, tcp, ...] 
+ DEVICE_NAME = "" # auto discovery if empty + + # Initialize server engine + server_engine = TransferEngine() + server_engine.initialize( + HOSTNAME, + METADATA_SERVER, + PROTOCOL, + DEVICE_NAME + ) + session_id = f"{HOSTNAME}:{server_engine.get_rpc_port()}" + + # Allocate memory on server side (1MB buffer) + server_buffer = torch.zeros(1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:1")) + server_ptr = server_buffer.data_ptr() + server_len = server_buffer.element_size() * server_buffer.nelement() + + # Register memory with Mooncake + if PROTOCOL == "rdma": + ret_value = server_engine.register_memory(server_ptr, server_len) + if ret_value != 0: + print("Mooncake memory registration failed.") + raise RuntimeError("Mooncake memory registration failed.") + + print(f"Server initialized with session ID: {session_id}") + print(f"Server buffer address: {server_ptr}, length: {server_len}") + + # Send buffer info to client + buffer_info = { + "session_id": session_id, + "ptr": server_ptr, + "len": server_len + } + socket.send_json(buffer_info) + print("Buffer information sent to client") + + # Keep server running + try: + while True: + input("Press Ctrl+C to exit...") + except KeyboardInterrupt: + print("\nShutting down server...") + finally: + # Cleanup + if PROTOCOL == "rdma": + ret_value = server_engine.unregister_memory(server_ptr) + if ret_value != 0: + print("Mooncake memory deregistration failed.") + raise RuntimeError("Mooncake memory deregistration failed.") + + socket.close() + context.term() + +if __name__ == "__main__": + main() + \ No newline at end of file diff --git a/lightx2v/disagg/examples/wan_i2v.py b/lightx2v/disagg/examples/wan_i2v.py new file mode 100644 index 00000000..d6e60edf --- /dev/null +++ b/lightx2v/disagg/examples/wan_i2v.py @@ -0,0 +1,230 @@ +import logging + +import numpy as np +import torch +import torchvision.transforms.functional as TF +from loguru import logger +from PIL import Image + +from lightx2v.disagg.utils import ( + load_wan_image_encoder, + load_wan_text_encoder, + load_wan_vae_decoder, + load_wan_vae_encoder, + load_wan_transformer, + set_config, + read_image_input, +) +from lightx2v.models.schedulers.wan.scheduler import WanScheduler +from lightx2v.utils.envs import GET_DTYPE +from lightx2v.utils.utils import save_to_video, seed_all, wan_vae_to_comfy +from lightx2v_platform.base.global_var import AI_DEVICE + +# Setup basic logging +logging.basicConfig(level=logging.INFO) + + +def get_latent_shape_with_lat_hw(config, latent_h, latent_w): + return [ + config.get("num_channels_latents", 16), + (config["target_video_length"] - 1) // config["vae_stride"][0] + 1, + latent_h, + latent_w, + ] + + +def compute_latent_shape_from_image(config, image_tensor): + h, w = image_tensor.shape[2:] + aspect_ratio = h / w + max_area = config["target_height"] * config["target_width"] + + latent_h = round( + np.sqrt(max_area * aspect_ratio) + // config["vae_stride"][1] + // config["patch_size"][1] + * config["patch_size"][1] + ) + latent_w = round( + np.sqrt(max_area / aspect_ratio) + // config["vae_stride"][2] + // config["patch_size"][2] + * config["patch_size"][2] + ) + latent_shape = get_latent_shape_with_lat_hw(config, latent_h, latent_w) + return latent_shape, latent_h, latent_w + + +def get_vae_encoder_output(vae_encoder, config, first_frame, latent_h, latent_w): + h = latent_h * config["vae_stride"][1] + w = latent_w * config["vae_stride"][2] + + msk = torch.ones( + 1, + config["target_video_length"], + latent_h, + latent_w, + 
device=torch.device(AI_DEVICE), + ) + msk[:, 1:] = 0 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, latent_h, latent_w) + msk = msk.transpose(1, 2)[0] + + vae_input = torch.concat( + [ + torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1), + torch.zeros(3, config["target_video_length"] - 1, h, w), + ], + dim=1, + ).to(AI_DEVICE) + + vae_encoder_out = vae_encoder.encode(vae_input.unsqueeze(0).to(GET_DTYPE())) + vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(GET_DTYPE()) + return vae_encoder_out + + +def main(): + # 1. Configuration + model_path = "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B" + task = "i2v" + model_cls = "wan2.2_moe" + + # Generation parameters + seed = 42 + prompt = ( + "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. " + "The fluffy-furred feline gazes directly at the camera with a relaxed expression. " + "Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, " + "and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if " + "savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details " + "and the refreshing atmosphere of the seaside." + ) + negative_prompt = ( + "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止," + "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部," + "画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景," + "三条腿,背景人很多,倒着走" + ) + image_path = "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B/examples/i2v_input.JPG" + save_result_path = "/root/zht/LightX2V/save_results/wan_i2v_A14B_disagg.mp4" + + # Initialize configuration + config = set_config( + model_path=model_path, + task=task, + model_cls=model_cls, + attn_mode="sage_attn2", + infer_steps=40, + target_height=480, + target_width=832, + target_video_length=81, + sample_guide_scale=[3.5, 3.5], + sample_shift=5.0, + fps=16, + enable_cfg=True, + use_image_encoder=False, + cpu_offload=True, + offload_granularity="block", + text_encoder_offload=True, + image_encoder_offload=False, + vae_offload=False, + ) + + logger.info(f"Config initialized for task: {task}") + seed_all(seed) + + # 2. Load Models + logger.info("Loading models...") + + text_encoders = load_wan_text_encoder(config) + text_encoder = text_encoders[0] + + image_encoder = load_wan_image_encoder(config) + + model = load_wan_transformer(config) + + vae_encoder = load_wan_vae_encoder(config) + vae_decoder = load_wan_vae_decoder(config) + + logger.info("Models loaded successfully.") + + # 3. Initialize Scheduler + scheduler = WanScheduler(config) + model.set_scheduler(scheduler) + + # 4. 
Run Inference Pipeline + + # 4.1 Text Encoding + logger.info("Running text encoding...") + text_len = config.get("text_len", 512) + + context = text_encoder.infer([prompt]) + context = torch.stack([torch.cat([u, u.new_zeros(text_len - u.size(0), u.size(1))]) for u in context]) + + if config.get("enable_cfg", False): + context_null = text_encoder.infer([negative_prompt]) + context_null = torch.stack([torch.cat([u, u.new_zeros(text_len - u.size(0), u.size(1))]) for u in context_null]) + else: + context_null = None + + text_encoder_output = { + "context": context, + "context_null": context_null, + } + + # 4.2 Image Encoding + VAE Encoding + logger.info("Running image encoding...") + img, _ = read_image_input(image_path) + + if image_encoder is not None: + clip_encoder_out = image_encoder.visual([img]).squeeze(0).to(GET_DTYPE()) + else: + clip_encoder_out = None + + if vae_encoder is None: + raise RuntimeError("VAE encoder is required for i2v task but was not loaded.") + + latent_shape, latent_h, latent_w = compute_latent_shape_from_image(config, img) + vae_encoder_out = get_vae_encoder_output(vae_encoder, config, img, latent_h, latent_w) + + image_encoder_output = { + "clip_encoder_out": clip_encoder_out, + "vae_encoder_out": vae_encoder_out, + } + + inputs = { + "text_encoder_output": text_encoder_output, + "image_encoder_output": image_encoder_output, + } + + # 4.3 Scheduler Preparation + logger.info("Preparing scheduler...") + scheduler.prepare(seed=seed, latent_shape=latent_shape, image_encoder_output=image_encoder_output) + + # 4.4 Denoising Loop + logger.info("Starting denoising loop...") + infer_steps = scheduler.infer_steps + + for step_index in range(infer_steps): + logger.info(f"Step {step_index + 1}/{infer_steps}") + scheduler.step_pre(step_index=step_index) + model.infer(inputs) + scheduler.step_post() + + latents = scheduler.latents + + # 4.5 VAE Decoding + logger.info("Decoding latents...") + gen_video = vae_decoder.decode(latents.to(GET_DTYPE())) + + # 5. Post-processing and Saving + logger.info("Post-processing video...") + gen_video_final = wan_vae_to_comfy(gen_video) + + logger.info(f"Saving video to {save_result_path}...") + save_to_video(gen_video_final, save_result_path, fps=config.get("fps", 16), method="ffmpeg") + logger.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/lightx2v/disagg/examples/wan_i2v_service.py b/lightx2v/disagg/examples/wan_i2v_service.py new file mode 100644 index 00000000..124a99b4 --- /dev/null +++ b/lightx2v/disagg/examples/wan_i2v_service.py @@ -0,0 +1,107 @@ +import logging +import os +import torch +from loguru import logger + +from lightx2v.disagg.utils import set_config +from lightx2v.disagg.services.encoder import EncoderService +from lightx2v.disagg.services.transformer import TransformerService +from lightx2v.utils.utils import seed_all + +# Setup basic logging +logging.basicConfig(level=logging.INFO) + +def main(): + # 1. Configuration + model_path = "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B" + task = "i2v" + model_cls = "wan2.2_moe" + + # Generation parameters + seed = 42 + prompt = ( + "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. " + "The fluffy-furred feline gazes directly at the camera with a relaxed expression. " + "Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, " + "and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if " + "savoring the sea breeze and warm sunlight. 
A close-up shot highlights the feline's intricate details " + "and the refreshing atmosphere of the seaside." + ) + negative_prompt = ( + "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止," + "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部," + "画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景," + "三条腿,背景人很多,倒着走" + ) + image_path = "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B/examples/i2v_input.JPG" + save_result_path = "/root/zht/LightX2V/save_results/wan_i2v_A14B_disagg_service.mp4" + + # Initialize configuration + config = set_config( + model_path=model_path, + task=task, + model_cls=model_cls, + attn_mode="sage_attn2", + infer_steps=40, + target_height=480, + target_width=832, + target_video_length=81, + sample_guide_scale=[3.5, 3.5], + sample_shift=5.0, + fps=16, + enable_cfg=True, + use_image_encoder=False, + cpu_offload=True, + offload_granularity="block", + text_encoder_offload=True, + image_encoder_offload=False, + vae_offload=False, + data_bootstrap_addr="127.0.0.1", + data_bootstrap_room=0, + ) + + config["image_path"] = image_path + config["prompt"] = prompt + config["negative_prompt"] = negative_prompt + config["save_path"] = save_result_path + + logger.info(f"Config initialized for task: {task}") + seed_all(seed) + + # 2. Add seed to config so services use it + config["seed"] = seed + + # 3. Define service threads + import threading + + def run_encoder(): + logger.info("Initializing Encoder Service...") + encoder_service = EncoderService(config) + logger.info("Running Encoder Service...") + encoder_service.process() + logger.info("Encoder Service completed.") + encoder_service.release_memory() + + def run_transformer(): + logger.info("Initializing Transformer Service...") + transformer_service = TransformerService(config) + logger.info("Running Transformer Service...") + result_path = transformer_service.process() + logger.info(f"Video generation completed. Saved to: {result_path}") + transformer_service.release_memory() + + # 4. Start threads + encoder_thread = threading.Thread(target=run_encoder) + transformer_thread = threading.Thread(target=run_transformer) + + logger.info("Starting services in separate threads...") + encoder_thread.start() + transformer_thread.start() + + # 5. Wait for completion + encoder_thread.join() + transformer_thread.join() + logger.info("All services finished.") + +if __name__ == "__main__": + main() diff --git a/lightx2v/disagg/examples/wan_t2v.py b/lightx2v/disagg/examples/wan_t2v.py new file mode 100644 index 00000000..99586b2a --- /dev/null +++ b/lightx2v/disagg/examples/wan_t2v.py @@ -0,0 +1,152 @@ +import os +import torch +import logging +from loguru import logger +from lightx2v.disagg.utils import ( + load_wan_text_encoder, + load_wan_vae_decoder, + load_wan_transformer, + set_config, +) +from lightx2v.models.schedulers.wan.scheduler import WanScheduler +from lightx2v.utils.utils import seed_all, save_to_video, wan_vae_to_comfy +from lightx2v.utils.input_info import init_empty_input_info +from lightx2v.utils.envs import GET_DTYPE + +# Setup basic logging +logging.basicConfig(level=logging.INFO) + +def main(): + # 1. Configuration + model_path = "/root/zht/LightX2V/models/Wan-AI/Wan2.1-T2V-1.3B" + task = "t2v" + model_cls = "wan2.1" + save_result_path = "/root/zht/LightX2V/save_results/test_disagg.mp4" + + # Generation parameters + seed = 42 + prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." 
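+    # Negative prompt (Chinese): enumerates artifacts to suppress under CFG, e.g. camera shake,
+    # over-saturated colors, overexposure, blurred detail, extra or fused fingers, deformed limbs,
+    # static frames and cluttered backgrounds.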
+ negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + + # Initialize configuration + # Note: We pass generation parameters as kwargs to override defaults/config file settings + config = set_config( + model_path=model_path, + task=task, + model_cls=model_cls, + # Configuration parameters from pipe.create_generator in original example + attn_mode= "sage_attn2", + infer_steps=50, + target_height=480, + target_width=832, + target_video_length=81, + sample_guide_scale=5.0, + sample_shift=5.0, + fps=16, + # Default parameters usually set in LightX2VPipeline or config setup + enable_cfg=True + ) + + logger.info(f"Config initialized for task: {task}") + seed_all(seed) + + # 2. Load Models + logger.info("Loading models...") + + # Text Encoder (T5) + text_encoders = load_wan_text_encoder(config) + text_encoder = text_encoders[0] + + # Transformer (WanModel) + model = load_wan_transformer(config) + + # VAE Decoder + vae_decoder = load_wan_vae_decoder(config) + + logger.info("Models loaded successfully.") + + # 3. Initialize Scheduler + # Only supporting basic NoCaching scheduler for this simple example as per default config + scheduler = WanScheduler(config) + model.set_scheduler(scheduler) + + # 4. Run Inference Pipeline + + # 4.1 Text Encoding + logger.info("Running text encoding...") + text_len = config.get("text_len", 512) + + # Context (Prompt) + context = text_encoder.infer([prompt]) + context = torch.stack([torch.cat([u, u.new_zeros(text_len - u.size(0), u.size(1))]) for u in context]) + + # Context Null (Negative Prompt) for CFG + if config.get("enable_cfg", False): + context_null = text_encoder.infer([negative_prompt]) + context_null = torch.stack([torch.cat([u, u.new_zeros(text_len - u.size(0), u.size(1))]) for u in context_null]) + else: + context_null = None + + text_encoder_output = { + "context": context, + "context_null": context_null, + } + + # 4.2 Prepare Inputs for Transformer + # Wan T2V input construction + # We need to construct the 'inputs' dictionary expected by model.infer + + # Calculate latent shape + # Logic from DefaultRunner.get_latent_shape_with_target_hw + latent_h = config["target_height"] // config["vae_stride"][1] + latent_w = config["target_width"] // config["vae_stride"][2] + latent_shape = [ + config.get("num_channels_latents", 16), + (config["target_video_length"] - 1) // config["vae_stride"][0] + 1, + latent_h, + latent_w, + ] + + inputs = { + "text_encoder_output": text_encoder_output, + "image_encoder_output": None # T2V usually doesn't need image encoder output unless specified + } + + # 4.3 Scheduler Preparation + logger.info("Preparing scheduler...") + scheduler.prepare(seed=seed, latent_shape=latent_shape, image_encoder_output=None) + + # 4.4 Denoising Loop + logger.info("Starting denoising loop...") + infer_steps = scheduler.infer_steps + + for step_index in range(infer_steps): + logger.info(f"Step {step_index + 1}/{infer_steps}") + + # Pre-step + scheduler.step_pre(step_index=step_index) + + # Model Inference + model.infer(inputs) + + # Post-step + scheduler.step_post() + + latents = scheduler.latents + + # 4.5 VAE Decoding + logger.info("Decoding latents...") + # Decode latents to video frames + # latents need to be cast to correct dtype usually + gen_video = vae_decoder.decode(latents.to(GET_DTYPE())) + + # 5. 
Post-processing and Saving + logger.info("Post-processing video...") + gen_video_final = wan_vae_to_comfy(gen_video) + + logger.info(f"Saving video to {save_result_path}...") + save_to_video(gen_video_final, save_result_path, fps=config.get("fps", 16), method="ffmpeg") + logger.info("Done!") + +if __name__ == "__main__": + main() diff --git a/lightx2v/disagg/examples/wan_t2v_service.py b/lightx2v/disagg/examples/wan_t2v_service.py new file mode 100644 index 00000000..9beb992f --- /dev/null +++ b/lightx2v/disagg/examples/wan_t2v_service.py @@ -0,0 +1,89 @@ +import logging +from loguru import logger + +from lightx2v.disagg.services.encoder import EncoderService +from lightx2v.disagg.services.transformer import TransformerService +from lightx2v.disagg.utils import set_config +from lightx2v.utils.utils import seed_all + +# Setup basic logging +logging.basicConfig(level=logging.INFO) + + +def main(): + # 1. Configuration + model_path = "/root/zht/LightX2V/models/Wan-AI/Wan2.1-T2V-1.3B" + task = "t2v" + model_cls = "wan2.1" + save_result_path = "/root/zht/LightX2V/save_results/test_disagg.mp4" + + # Generation parameters + seed = 42 + prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." + negative_prompt = ( + "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰," + "最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部," + "畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + ) + + # Initialize configuration (same as wan_t2v.py) + config = set_config( + model_path=model_path, + task=task, + model_cls=model_cls, + attn_mode="sage_attn2", + infer_steps=50, + target_height=480, + target_width=832, + target_video_length=81, + sample_guide_scale=5.0, + sample_shift=5.0, + fps=16, + enable_cfg=True, + data_bootstrap_addr="127.0.0.1", + data_bootstrap_room=0, + ) + + logger.info(f"Config initialized for task: {task}") + seed_all(seed) + + # Add seed into config so services can use it if needed + config["seed"] = seed + config["prompt"] = prompt + config["negative_prompt"] = negative_prompt + config["save_path"] = save_result_path + + # 2. Define service threads + import threading + + def run_encoder(): + logger.info("Initializing Encoder Service...") + encoder_service = EncoderService(config) + logger.info("Running Encoder Service...") + encoder_service.process() + encoder_service.release_memory() + + def run_transformer(): + logger.info("Initializing Transformer Service...") + transformer_service = TransformerService(config) + logger.info("Running Transformer Service...") + result_path = transformer_service.process() + logger.info(f"Video generation completed. Saved to: {result_path}") + transformer_service.release_memory() + + # 3. Start threads + encoder_thread = threading.Thread(target=run_encoder) + transformer_thread = threading.Thread(target=run_transformer) + + logger.info("Starting services in separate threads...") + encoder_thread.start() + transformer_thread.start() + + # 4. 
Wait for completion + encoder_thread.join() + transformer_thread.join() + logger.info("All services finished.") + + +if __name__ == "__main__": + main() diff --git a/lightx2v/disagg/mooncake.py b/lightx2v/disagg/mooncake.py new file mode 100644 index 00000000..6579d284 --- /dev/null +++ b/lightx2v/disagg/mooncake.py @@ -0,0 +1,116 @@ +import json +import logging +import os +import uuid +import socket +import struct +import pickle +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch + +logger = logging.getLogger(__name__) + +@dataclass +class MooncakeTransferEngineConfig: + local_hostname: str + metadata_server: str + protocol: str + device_name: str + + @staticmethod + def from_file(file_path: str) -> "MooncakeTransferEngineConfig": + with open(file_path) as fin: + config = json.load(fin) + return MooncakeTransferEngineConfig( + local_hostname=config.get("local_hostname", None), + metadata_server=config.get("metadata_server"), + protocol=config.get("protocol", "rdma"), + device_name=config.get("device_name", ""), + ) + + @staticmethod + def load_from_env() -> "MooncakeTransferEngineConfig": + config_file_path = os.getenv("MOONCAKE_CONFIG_PATH", "/root/zht/LightX2V/configs/mooncake_config.json") + if config_file_path is None: + raise ValueError("The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") + return MooncakeTransferEngineConfig.from_file(config_file_path) + + +class MooncakeTransferEngine: + def __init__(self): + self.engine = None + try: + from mooncake.engine import TransferEngine + self.engine = TransferEngine() + except ImportError as e: + logger.warning( + "Please install mooncake by following the instructions at " + "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " + "to run with MooncakeTransferEngine." 
+ ) + # We allow continuing without engine for non-transfer operations or testing structure + + try: + self.config = MooncakeTransferEngineConfig.load_from_env() + logger.info("Mooncake Configuration loaded successfully.") + except Exception as e: + logger.error(f"Failed to load Mooncake config: {e}") + raise + + # session_suffix = "_" + str(uuid.uuid4()) + self.initialize( + self.config.local_hostname, + self.config.metadata_server, + self.config.protocol, + self.config.device_name, + ) + # session_suffix = ":" + self.engine.get_rpc_port() + # self.session_id = self.config.local_hostname + session_suffix + self.session_id = f"{self.config.local_hostname}:{self.engine.get_rpc_port()}" + + def register(self, ptr, length): + if self.engine: + ret = self.engine.register_memory(ptr, length) + if ret != 0: + logger.error("Mooncake memory registration failed.") + raise RuntimeError("Mooncake memory registration failed.") + + def deregister(self, ptr): + if self.engine: + ret = self.engine.unregister_memory(ptr) + if ret != 0: + logger.error("Mooncake memory deregistration failed.") + raise RuntimeError("Mooncake memory deregistration failed.") + + def initialize( + self, + local_hostname: str, + metadata_server: str, + protocol: str, + device_name: str, + ) -> None: + """Initialize the mooncake instance.""" + if self.engine: + self.engine.initialize(local_hostname, metadata_server, protocol, device_name) + + def transfer_sync( + self, session_id: str, buffer: int, peer_buffer_address: int, length: int + ) -> int: + """Synchronously transfer data to the specified address.""" + if self.engine: + ret = self.engine.transfer_sync_write( + session_id, buffer, peer_buffer_address, length + ) + if ret < 0: + logger.error("Transfer Return Error") + raise Exception("Transfer Return Error") + return ret + return -1 + + def get_localhost(self): + return self.config.local_hostname + + def get_session_id(self): + return self.session_id diff --git a/lightx2v/disagg/protocol.py b/lightx2v/disagg/protocol.py new file mode 100644 index 00000000..1f8ff21b --- /dev/null +++ b/lightx2v/disagg/protocol.py @@ -0,0 +1,40 @@ +import zmq +import torch +import pickle +import logging +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union +from lightx2v.disagg.mooncake import MooncakeTransferEngine + +logger = logging.getLogger(__name__) + +@dataclass +class TensorMetadata: + id: int + shape: Tuple[int, ...] + dtype: torch.dtype + nbytes: int + +@dataclass +class AllocationRequest: + """ + Request sent from Sender (Encoder) to Receiver (Transformer). + - bootstrap_room: Unique ID for the transfer slot/session. + - config: Inference config used to estimate upper-bound buffer sizes. + """ + bootstrap_room: str + config: Dict[str, Any] + +@dataclass +class RemoteBuffer: + addr: int + session_id: str + nbytes: int + +@dataclass +class MemoryHandle: + """ + Handle sent from Receiver (Transformer) to Sender (Encoder). + - buffers: List of remote buffer details corresponding to the tensor_specs in the request. 
+ """ + buffers: List[RemoteBuffer] diff --git a/lightx2v/disagg/services/__init__.py b/lightx2v/disagg/services/__init__.py new file mode 100644 index 00000000..54fa46a4 --- /dev/null +++ b/lightx2v/disagg/services/__init__.py @@ -0,0 +1 @@ +# Services package initialization diff --git a/lightx2v/disagg/services/base.py b/lightx2v/disagg/services/base.py new file mode 100644 index 00000000..4e777fac --- /dev/null +++ b/lightx2v/disagg/services/base.py @@ -0,0 +1,25 @@ +import logging +import torch +from abc import ABC, abstractmethod + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class BaseService(ABC): + def __init__(self, config): + """ + Base initialization for all services. + Args: + config: A dictionary or object containing configuration parameters. + """ + self.config = config + self.logger = logger + self.logger.info(f"Initializing {self.__class__.__name__} with config: {config}") + + @abstractmethod + def load_models(self): + """ + Abstract method to load necessary models. + """ + pass diff --git a/lightx2v/disagg/services/encoder.py b/lightx2v/disagg/services/encoder.py new file mode 100644 index 00000000..351a1f0a --- /dev/null +++ b/lightx2v/disagg/services/encoder.py @@ -0,0 +1,340 @@ +import torch +import hashlib +import json +import numpy as np +from typing import Dict, Any, Optional, List + +from lightx2v.disagg.services.base import BaseService +from lightx2v.disagg.conn import DataArgs, DataManager, DataSender, DisaggregationMode, DataPoll +from lightx2v.utils.envs import GET_DTYPE +from lightx2v.utils.utils import seed_all +from lightx2v_platform.base.global_var import AI_DEVICE +from lightx2v.disagg.utils import ( + load_wan_text_encoder, + load_wan_image_encoder, + load_wan_vae_encoder, + read_image_input, + estimate_encoder_buffer_sizes, +) + +class EncoderService(BaseService): + def __init__(self, config): + super().__init__(config) + self.text_encoder = None + self.image_encoder = None + self.vae_encoder = None + self.sender_engine_rank = int(self.config.get("sender_engine_rank", 0)) + self.receiver_engine_rank = int(self.config.get("receiver_engine_rank", 1)) + self.data_mgr = None + self.data_sender = None + self._rdma_buffers: List[torch.Tensor] = [] + + # Load models based on config + self.load_models() + + # Seed everything if seed is in config + if "seed" in self.config: + seed_all(self.config["seed"]) + + data_bootstrap_addr = self.config.get("data_bootstrap_addr", "127.0.0.1") + data_bootstrap_room = self.config.get("data_bootstrap_room", 0) + + if data_bootstrap_addr is not None and data_bootstrap_room is not None: + data_ptrs, data_lens = self.alloc_bufs() + data_args = DataArgs( + sender_engine_rank=self.sender_engine_rank, + receiver_engine_rank=self.receiver_engine_rank, + data_ptrs=data_ptrs, + data_lens=data_lens, + data_item_lens=data_lens, + ib_device=None, + ) + self.data_mgr = DataManager(data_args, DisaggregationMode.ENCODE) + self.data_sender = DataSender( + self.data_mgr, data_bootstrap_addr, int(data_bootstrap_room) + ) + + def load_models(self): + self.logger.info("Loading Encoder Models...") + + # T5 Text Encoder + text_encoders = load_wan_text_encoder(self.config) + self.text_encoder = text_encoders[0] if text_encoders else None + + # CLIP Image Encoder (Optional per usage in wan_i2v.py) + if self.config.get("use_image_encoder", False): + self.image_encoder = load_wan_image_encoder(self.config) + + # VAE Encoder (Required for I2V) + # Note: wan_i2v.py logic: if vae_encoder is None: 
raise RuntimeError + # But we only load if needed or always? Let's check the config flags. + # It seems always loaded for I2V task, but might be offloaded. + # For simplicity of this service, we load it if the task implies it or just try to load. + # But `load_wan_vae_encoder` will look at the config. + self.vae_encoder = load_wan_vae_encoder(self.config) + + self.logger.info("Encoder Models loaded successfully.") + + def _get_latent_shape_with_lat_hw(self, latent_h, latent_w): + return [ + self.config.get("num_channels_latents", 16), + (self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1, + latent_h, + latent_w, + ] + + def _compute_latent_shape_from_image(self, image_tensor: torch.Tensor): + h, w = image_tensor.shape[2:] + aspect_ratio = h / w + max_area = self.config["target_height"] * self.config["target_width"] + + latent_h = round( + np.sqrt(max_area * aspect_ratio) + // self.config["vae_stride"][1] + // self.config["patch_size"][1] + * self.config["patch_size"][1] + ) + latent_w = round( + np.sqrt(max_area / aspect_ratio) + // self.config["vae_stride"][2] + // self.config["patch_size"][2] + * self.config["patch_size"][2] + ) + latent_shape = self._get_latent_shape_with_lat_hw(latent_h, latent_w) + return latent_shape, latent_h, latent_w + + def _get_vae_encoder_output(self, first_frame: torch.Tensor, latent_h: int, latent_w: int): + h = latent_h * self.config["vae_stride"][1] + w = latent_w * self.config["vae_stride"][2] + + msk = torch.ones( + 1, + self.config["target_video_length"], + latent_h, + latent_w, + device=torch.device(AI_DEVICE), + ) + msk[:, 1:] = 0 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, latent_h, latent_w) + msk = msk.transpose(1, 2)[0] + + vae_input = torch.concat( + [ + torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1), + torch.zeros(3, self.config["target_video_length"] - 1, h, w), + ], + dim=1, + ).to(AI_DEVICE) + + vae_encoder_out = self.vae_encoder.encode(vae_input.unsqueeze(0).to(GET_DTYPE())) + vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(GET_DTYPE()) + return vae_encoder_out + + def alloc_bufs(self): + # torch.cuda.set_device(self.sender_engine_rank) + buffer_sizes = estimate_encoder_buffer_sizes(self.config) + self._rdma_buffers = [] + data_ptrs: List[int] = [] + data_lens: List[int] = [] + + for nbytes in buffer_sizes: + if nbytes <= 0: + continue + buf = torch.empty( + (nbytes,), + dtype=torch.uint8, + # device=torch.device(f"cuda:{self.sender_engine_rank}"), + ) + self._rdma_buffers.append(buf) + data_ptrs.append(buf.data_ptr()) + data_lens.append(nbytes) + + return data_ptrs, data_lens + + def process(self) -> Dict[str, Any]: + """ + Generates encoder outputs from prompt and image input. + """ + self.logger.info("Starting processing in EncoderService...") + + prompt = self.config.get("prompt") + negative_prompt = self.config.get("negative_prompt") + if prompt is None: + raise ValueError("prompt is required in config.") + + # 1. 
Text Encoding + text_len = self.config.get("text_len", 512) + + context = self.text_encoder.infer([prompt]) + context = torch.stack([torch.cat([u, u.new_zeros(text_len - u.size(0), u.size(1))]) for u in context]) + + if self.config.get("enable_cfg", False): + if negative_prompt is None: + raise ValueError("negative_prompt is required in config when enable_cfg is True.") + context_null = self.text_encoder.infer([negative_prompt]) + context_null = torch.stack([torch.cat([u, u.new_zeros(text_len - u.size(0), u.size(1))]) for u in context_null]) + else: + context_null = None + + text_encoder_output = { + "context": context, + "context_null": context_null, + } + + task = self.config.get("task") + clip_encoder_out = None + + if task == "t2v": + latent_h = self.config["target_height"] // self.config["vae_stride"][1] + latent_w = self.config["target_width"] // self.config["vae_stride"][2] + latent_shape = [ + self.config.get("num_channels_latents", 16), + (self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1, + latent_h, + latent_w, + ] + image_encoder_output = None + elif task == "i2v": + image_path = self.config.get("image_path") + if image_path is None: + raise ValueError("image_path is required for i2v task.") + + # 2. Image Encoding + VAE Encoding + img, _ = read_image_input(image_path) + + if self.image_encoder is not None: + # Assuming image_encoder.visual handles list of images + clip_encoder_out = self.image_encoder.visual([img]).squeeze(0).to(GET_DTYPE()) + + if self.vae_encoder is None: + raise RuntimeError("VAE encoder is required but was not loaded.") + + latent_shape, latent_h, latent_w = self._compute_latent_shape_from_image(img) + vae_encoder_out = self._get_vae_encoder_output(img, latent_h, latent_w) + + image_encoder_output = { + "clip_encoder_out": clip_encoder_out, + "vae_encoder_out": vae_encoder_out, + } + else: + raise ValueError(f"Unsupported task: {task}") + + self.logger.info("Encode processing completed. 
Preparing to send data...") + + if self.data_mgr is not None and self.data_sender is not None: + def _buffer_view(buf: torch.Tensor, dtype: torch.dtype, shape: tuple[int, ...]) -> torch.Tensor: + view = torch.empty(0, dtype=dtype, device=buf.device) + view.set_(buf.untyped_storage(), 0, shape) + return view + + def _sha256_tensor(tensor: Optional[torch.Tensor]) -> Optional[str]: + if tensor is None: + return None + data_tensor = tensor.detach() + if data_tensor.dtype == torch.bfloat16: + data_tensor = data_tensor.to(torch.float32) + data = data_tensor.contiguous().cpu().numpy().tobytes() + return hashlib.sha256(data).hexdigest() + + text_len = int(self.config.get("text_len", 512)) + text_dim = int(self.config.get("text_encoder_dim", 4096)) + clip_dim = int(self.config.get("clip_embed_dim", 1024)) + z_dim = int(self.config.get("vae_z_dim", 16)) + + vae_stride = self.config.get("vae_stride", (4, 8, 8)) + stride_t = int(vae_stride[0]) + stride_h = int(vae_stride[1]) + stride_w = int(vae_stride[2]) + + target_video_length = int(self.config.get("target_video_length", 81)) + target_height = int(self.config.get("target_height", 480)) + target_width = int(self.config.get("target_width", 832)) + + t_prime = 1 + (target_video_length - 1) // stride_t + h_prime = int(np.ceil(target_height / stride_h)) + w_prime = int(np.ceil(target_width / stride_w)) + + buffer_index = 0 + context_buf = _buffer_view( + self._rdma_buffers[buffer_index], GET_DTYPE(), (1, text_len, text_dim) + ) + context_buf.copy_(context) + buffer_index += 1 + if self.config.get("enable_cfg", False): + context_null_buf = _buffer_view( + self._rdma_buffers[buffer_index], GET_DTYPE(), (1, text_len, text_dim) + ) + context_null_buf.copy_(context_null) + buffer_index += 1 + + if task == "i2v": + if self.config.get("use_image_encoder", True): + clip_buf = _buffer_view( + self._rdma_buffers[buffer_index], GET_DTYPE(), (clip_dim,) + ) + if image_encoder_output.get("clip_encoder_out") is not None: + clip_buf.copy_(image_encoder_output["clip_encoder_out"]) + else: + clip_buf.zero_() + buffer_index += 1 + + vae_buf = _buffer_view( + self._rdma_buffers[buffer_index], + GET_DTYPE(), + (z_dim + 4, t_prime, h_prime, w_prime), + ) + vae_buf.zero_() + vae_flat = vae_buf.view(-1) + src_flat = image_encoder_output["vae_encoder_out"].reshape(-1) + vae_flat[: src_flat.numel()].copy_(src_flat) + buffer_index += 1 + + latent_tensor = torch.tensor(latent_shape, device=AI_DEVICE, dtype=torch.int64) + latent_buf = _buffer_view( + self._rdma_buffers[buffer_index], torch.int64, (4,) + ) + latent_buf.copy_(latent_tensor) + buffer_index += 1 + + meta = { + "version": 1, + "context_shape": list(context.shape), + "context_hash": _sha256_tensor(context), + "context_null_shape": list(context_null.shape) if context_null is not None else None, + "context_null_hash": _sha256_tensor(context_null), + "clip_shape": list(clip_encoder_out.shape) if clip_encoder_out is not None else None, + "clip_hash": _sha256_tensor(clip_encoder_out), + "vae_shape": list(image_encoder_output["vae_encoder_out"].shape) if image_encoder_output is not None else None, + "vae_hash": _sha256_tensor(image_encoder_output["vae_encoder_out"]) if image_encoder_output is not None else None, + "latent_shape": list(latent_shape), + "latent_hash": _sha256_tensor(latent_tensor), + } + meta_bytes = json.dumps(meta, ensure_ascii=True).encode("utf-8") + meta_buf = _buffer_view(self._rdma_buffers[buffer_index], torch.uint8, (self._rdma_buffers[buffer_index].numel(),)) + if meta_bytes and len(meta_bytes) > 
meta_buf.numel(): + raise ValueError("metadata buffer too small for hash/shape payload") + meta_buf.zero_() + if meta_bytes: + meta_buf[: len(meta_bytes)].copy_(torch.from_numpy(np.frombuffer(meta_bytes, dtype=np.uint8))) + + buffer_ptrs = [buf.data_ptr() for buf in self._rdma_buffers] + self.data_sender.send(buffer_ptrs) + + import time + while True: + status = self.data_sender.poll() + if status == DataPoll.Success: + break + time.sleep(0.01) + + def release_memory(self): + """ + Releases the RDMA buffers and clears GPU cache. + """ + if self._rdma_buffers: + for buf in self._rdma_buffers: + if self.data_mgr is not None: + self.data_mgr.engine.deregister(buf.data_ptr()) + self._rdma_buffers = [] + torch.cuda.empty_cache() diff --git a/lightx2v/disagg/services/transformer.py b/lightx2v/disagg/services/transformer.py new file mode 100644 index 00000000..1c333663 --- /dev/null +++ b/lightx2v/disagg/services/transformer.py @@ -0,0 +1,350 @@ +import torch +import torch.nn.functional as F +import hashlib +import json +import logging +import numpy as np +from typing import Dict, Any, List, Optional + +from lightx2v.disagg.services.base import BaseService +from lightx2v.disagg.conn import DataArgs, DataManager, DataReceiver, DisaggregationMode, DataPoll +from lightx2v.disagg.utils import ( + estimate_encoder_buffer_sizes, + load_wan_transformer, + load_wan_vae_decoder, +) +from lightx2v.disagg.protocol import AllocationRequest, MemoryHandle, RemoteBuffer +from lightx2v_platform.base.global_var import AI_DEVICE +from lightx2v.models.schedulers.wan.scheduler import WanScheduler +from lightx2v.utils.envs import GET_DTYPE +from lightx2v.utils.utils import seed_all, save_to_video, wan_vae_to_comfy + +class TransformerService(BaseService): + def __init__(self, config): + super().__init__(config) + self.transformer = None + self.vae_decoder = None + self.scheduler = None + self._rdma_buffers: List[torch.Tensor] = [] + self.sender_engine_rank = int(self.config.get("sender_engine_rank", 0)) + self.receiver_engine_rank = int(self.config.get("receiver_engine_rank", 1)) + self.data_mgr = None + self.data_receiver = None + + self.load_models() + + # Set global seed if present in config, though specific process calls might reuse it + if "seed" in self.config: + seed_all(self.config["seed"]) + + data_bootstrap_addr = self.config.get("data_bootstrap_addr", "127.0.0.1") + data_bootstrap_room = self.config.get("data_bootstrap_room", 0) + + if data_bootstrap_addr is None or data_bootstrap_room is None: + return + + request = AllocationRequest( + bootstrap_room=str(data_bootstrap_room), + config=self.config, + ) + handle = self.alloc_memory(request) + data_ptrs = [buf.addr for buf in handle.buffers] + data_lens = [buf.nbytes for buf in handle.buffers] + + data_args = DataArgs( + sender_engine_rank=self.sender_engine_rank, + receiver_engine_rank=self.receiver_engine_rank, + data_ptrs=data_ptrs, + data_lens=data_lens, + data_item_lens=data_lens, + ib_device=None, + ) + self.data_mgr = DataManager(data_args, DisaggregationMode.TRANSFORMER) + self.data_receiver = DataReceiver( + self.data_mgr, data_bootstrap_addr, int(data_bootstrap_room) + ) + self.data_receiver.init() + + def load_models(self): + self.logger.info("Loading Transformer Models...") + + self.transformer = load_wan_transformer(self.config) + self.vae_decoder = load_wan_vae_decoder(self.config) + + # Initialize scheduler + self.scheduler = WanScheduler(self.config) + self.transformer.set_scheduler(self.scheduler) + + self.logger.info("Transformer Models 
loaded successfully.") + + def alloc_memory(self, request: AllocationRequest) -> MemoryHandle: + """ + Estimate upper-bound memory for encoder results and allocate GPU buffers. + + Args: + request: AllocationRequest containing config and tensor specs. + + Returns: + MemoryHandle with RDMA-registered buffer addresses. + """ + config = request.config + estimated_sizes = estimate_encoder_buffer_sizes(config) + buffer_sizes = estimated_sizes + + # torch.cuda.set_device(self.receiver_engine_rank) + + self._rdma_buffers = [] + buffers: List[RemoteBuffer] = [] + for nbytes in buffer_sizes: + if nbytes <= 0: + continue + buf = torch.empty((nbytes,), dtype=torch.uint8, #device=torch.device(f"cuda:{self.receiver_engine_rank}") + ) + ptr = buf.data_ptr() + self._rdma_buffers.append(buf) + session_id = self.data_mgr.get_session_id() if self.data_mgr is not None else "" + buffers.append( + RemoteBuffer(addr=ptr, session_id=session_id, nbytes=nbytes) + ) + + if buffers: + self.logger.info( + "Transformer allocated RDMA buffers: %s", + [buf.addr for buf in buffers], + ) + + return MemoryHandle(buffers=buffers) + + def process(self): + """ + Executes the diffusion process and video decoding. + """ + self.logger.info("Starting processing in TransformerService...") + + def _buffer_view(buf: torch.Tensor, dtype: torch.dtype, shape: tuple[int, ...]) -> torch.Tensor: + view = torch.empty(0, dtype=dtype, device=buf.device) + view.set_(buf.untyped_storage(), 0, shape) + if view.device != torch.device(AI_DEVICE): + view = view.to(torch.device(AI_DEVICE)) + return view + + def _align_vae_to_latents(vae: torch.Tensor, target_t: int, target_h: int, target_w: int) -> torch.Tensor: + if vae is None: + return vae + _, t, h, w = vae.shape + t_slice = min(t, target_t) + h_slice = min(h, target_h) + w_slice = min(w, target_w) + vae = vae[:, :t_slice, :h_slice, :w_slice] + pad_t = target_t - t_slice + pad_h = target_h - h_slice + pad_w = target_w - w_slice + if pad_t > 0 or pad_h > 0 or pad_w > 0: + vae = F.pad(vae, (0, pad_w, 0, pad_h, 0, pad_t)) + return vae + + def _sha256_tensor(tensor: Optional[torch.Tensor]) -> Optional[str]: + if tensor is None: + return None + data_tensor = tensor.detach() + if data_tensor.dtype == torch.bfloat16: + data_tensor = data_tensor.to(torch.float32) + data = data_tensor.contiguous().cpu().numpy().tobytes() + return hashlib.sha256(data).hexdigest() + + # Poll for data from EncoderService + import time + if self.data_receiver is not None: + while True: + status = self.data_receiver.poll() + if status == DataPoll.Success: + self.logger.info("Data received successfully in TransformerService.") + break + time.sleep(0.01) + else: + self.logger.warning("DataReceiver is not initialized. 
Using dummy or existing data if any.") + pass + + # Reconstruct inputs from _rdma_buffers + text_len = int(self.config.get("text_len", 512)) + text_dim = int(self.config.get("text_encoder_dim", 4096)) + clip_dim = int(self.config.get("clip_embed_dim", 1024)) + z_dim = int(self.config.get("vae_z_dim", 16)) + + vae_stride = self.config.get("vae_stride", (4, 8, 8)) + target_video_length = int(self.config.get("target_video_length", 81)) + target_height = int(self.config.get("target_height", 480)) + target_width = int(self.config.get("target_width", 832)) + + t_prime = 1 + (target_video_length - 1) // int(vae_stride[0]) + h_prime = int(np.ceil(target_height / int(vae_stride[1]))) + w_prime = int(np.ceil(target_width / int(vae_stride[2]))) + + enable_cfg = bool(self.config.get("enable_cfg", False)) + task = self.config.get("task", "i2v") + use_image_encoder = bool(self.config.get("use_image_encoder", True)) + + buffer_index = 0 + + context_buf = self._rdma_buffers[buffer_index] + buffer_index += 1 + + context_null_buf = None + if enable_cfg: + context_null_buf = self._rdma_buffers[buffer_index] + buffer_index += 1 + + clip_buf = None + vae_buf = None + if task == "i2v": + if use_image_encoder: + clip_buf = self._rdma_buffers[buffer_index] + buffer_index += 1 + + vae_buf = self._rdma_buffers[buffer_index] + buffer_index += 1 + + latent_buf = self._rdma_buffers[buffer_index] + buffer_index += 1 + + meta_buf = self._rdma_buffers[buffer_index] + meta_bytes = _buffer_view(meta_buf, torch.uint8, (meta_buf.numel(),)).detach().contiguous().cpu().numpy().tobytes() + meta_str = meta_bytes.split(b"\x00", 1)[0].decode("utf-8") if meta_bytes else "" + meta = json.loads(meta_str) if meta_str else {} + + context_shape = tuple(meta.get("context_shape") or (1, text_len, text_dim)) + context = _buffer_view(context_buf, GET_DTYPE(), context_shape) + + context_null = None + if enable_cfg and context_null_buf is not None: + context_null_shape = tuple(meta.get("context_null_shape") or (1, text_len, text_dim)) + context_null = _buffer_view(context_null_buf, GET_DTYPE(), context_null_shape) + + text_encoder_output = { + "context": context, + "context_null": context_null, + } + + image_encoder_output = {} + clip_encoder_out = None + vae_encoder_out_padded = None + + if task == "i2v": + if use_image_encoder and clip_buf is not None: + clip_shape = tuple(meta.get("clip_shape") or (clip_dim,)) + clip_encoder_out = _buffer_view(clip_buf, GET_DTYPE(), clip_shape) + + if vae_buf is not None: + vae_shape = tuple(meta.get("vae_shape") or (z_dim + 4, t_prime, h_prime, w_prime)) + vae_encoder_out_padded = _buffer_view(vae_buf, GET_DTYPE(), vae_shape) + + latent_shape = _buffer_view(latent_buf, torch.int64, (4,)).tolist() + + vae_encoder_out = None + if vae_encoder_out_padded is not None: + valid_t = latent_shape[1] + valid_h = latent_shape[2] + valid_w = latent_shape[3] + vae_encoder_out = vae_encoder_out_padded[:, :valid_t, :valid_h, :valid_w] + vae_encoder_out = _align_vae_to_latents(vae_encoder_out, valid_t, valid_h, valid_w) + + if task == "i2v": + image_encoder_output["clip_encoder_out"] = clip_encoder_out + image_encoder_output["vae_encoder_out"] = vae_encoder_out + else: + image_encoder_output = None + + if meta: + if meta.get("context_shape") is not None and list(context.shape) != meta.get("context_shape"): + raise ValueError("context shape mismatch between encoder and transformer") + if meta.get("context_hash") is not None and _sha256_tensor(context) != meta.get("context_hash"): + raise ValueError("context hash mismatch 
between encoder and transformer") + if enable_cfg: + if meta.get("context_null_shape") is not None and context_null is not None: + if list(context_null.shape) != meta.get("context_null_shape"): + raise ValueError("context_null shape mismatch between encoder and transformer") + if meta.get("context_null_hash") is not None: + if _sha256_tensor(context_null) != meta.get("context_null_hash"): + raise ValueError("context_null hash mismatch between encoder and transformer") + if task == "i2v": + if meta.get("clip_shape") is not None and clip_encoder_out is not None: + if list(clip_encoder_out.shape) != meta.get("clip_shape"): + raise ValueError("clip shape mismatch between encoder and transformer") + if meta.get("clip_hash") is not None: + if _sha256_tensor(clip_encoder_out) != meta.get("clip_hash"): + raise ValueError("clip hash mismatch between encoder and transformer") + if meta.get("vae_shape") is not None and vae_encoder_out is not None: + if list(vae_encoder_out.shape) != meta.get("vae_shape"): + raise ValueError("vae shape mismatch between encoder and transformer") + if meta.get("vae_hash") is not None: + if _sha256_tensor(vae_encoder_out) != meta.get("vae_hash"): + raise ValueError("vae hash mismatch between encoder and transformer") + if meta.get("latent_shape") is not None and list(latent_shape) != meta.get("latent_shape"): + raise ValueError("latent_shape mismatch between encoder and transformer") + if meta.get("latent_hash") is not None: + latent_tensor = torch.tensor(latent_shape, device=AI_DEVICE, dtype=torch.int64) + if _sha256_tensor(latent_tensor) != meta.get("latent_hash"): + raise ValueError("latent_shape hash mismatch between encoder and transformer") + + inputs = { + "text_encoder_output": text_encoder_output, + "image_encoder_output": image_encoder_output, + "latent_shape": latent_shape, + } + + seed = self.config.get("seed") + save_path = self.config.get("save_path") + if seed is None: + raise ValueError("seed is required in config.") + if save_path is None: + raise ValueError("save_path is required in config.") + + if latent_shape is None: + raise ValueError("latent_shape is required in inputs.") + + # Scheduler Preparation + self.logger.info(f"Preparing scheduler with seed {seed}...") + self.scheduler.prepare(seed=seed, latent_shape=latent_shape, image_encoder_output=image_encoder_output) + + # Denoising Loop + self.logger.info("Starting denoising loop...") + infer_steps = self.scheduler.infer_steps + + for step_index in range(infer_steps): + if step_index % 10 == 0: + self.logger.info(f"Step {step_index + 1}/{infer_steps}") + self.scheduler.step_pre(step_index=step_index) + self.transformer.infer(inputs) + self.scheduler.step_post() + + latents = self.scheduler.latents + + # VAE Decoding + self.logger.info("Decoding latents...") + if self.vae_decoder is None: + raise RuntimeError("VAE decoder is not loaded.") + + gen_video = self.vae_decoder.decode(latents.to(GET_DTYPE())) + + # Post-processing + self.logger.info("Post-processing video...") + gen_video_final = wan_vae_to_comfy(gen_video) + + # Saving + self.logger.info(f"Saving video to {save_path}...") + save_to_video(gen_video_final, save_path, fps=self.config.get("fps", 16), method="ffmpeg") + self.logger.info("Done!") + + return save_path + + def release_memory(self): + """ + Releases the RDMA buffers, deregisters them from transfer engine, and clears GPU cache. 
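+
+        Safe to call more than once: after the first call the buffer list is empty, so
+        subsequent calls only clear the CUDA cache.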
+ """ + if self._rdma_buffers: + for buf in self._rdma_buffers: + if self.data_mgr is not None: + self.data_mgr.engine.deregister(buf.data_ptr()) + self._rdma_buffers = [] + + torch.cuda.empty_cache() diff --git a/lightx2v/disagg/utils.py b/lightx2v/disagg/utils.py new file mode 100644 index 00000000..0a1936b6 --- /dev/null +++ b/lightx2v/disagg/utils.py @@ -0,0 +1,425 @@ +import json +import logging +import math +import os +import torch +import torch.distributed as dist +import torchvision.transforms.functional as TF +from PIL import Image +from typing import Dict, Any, List, Optional + +from lightx2v_platform.base.global_var import AI_DEVICE +from lightx2v.utils.envs import GET_DTYPE +from lightx2v.utils.utils import find_torch_model_path + +from lightx2v.models.input_encoders.hf.wan.t5.model import T5EncoderModel +from lightx2v.models.input_encoders.hf.wan.xlm_roberta.model import CLIPModel +from lightx2v.models.video_encoders.hf.wan.vae import WanVAE +from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny, Wan2_2_VAE_tiny +from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE +from lightx2v.models.networks.wan.model import WanModel +from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper +from lightx2v.utils.set_config import get_default_config, set_config as set_config_base + +logger = logging.getLogger(__name__) + +class ConfigObj: + """Helper class to convert dictionary to object with attributes""" + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + +def read_image_input(image_path): + img_ori = Image.open(image_path).convert("RGB") + img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE) + return img, img_ori + +def set_config( + model_path, + task, + model_cls, + config_path=None, + attn_mode="flash_attn2", + rope_type="torch", + infer_steps=50, + target_video_length=81, + target_height=480, + target_width=832, + sample_guide_scale=5.0, + sample_shift=5.0, + fps=16, + aspect_ratio="16:9", + boundary=0.900, + boundary_step_index=2, + denoising_step_list=None, + audio_fps=24000, + double_precision_rope=True, + norm_modulate_backend="torch", + distilled_sigma_values=None, + cpu_offload=False, + offload_granularity="block", + text_encoder_offload=False, + image_encoder_offload=False, + vae_offload=False, + **kwargs + ): + """ + Load configuration for Wan model. 
+ """ + if denoising_step_list is None: + denoising_step_list = [1000, 750, 500, 250] + + # Create arguments object similar to what set_config expects + args_dict = { + "task": task, + "model_path": model_path, + "model_cls": model_cls, + "config_json": config_path, + "cpu_offload": cpu_offload, + "offload_granularity": offload_granularity, + "t5_cpu_offload": text_encoder_offload, # Map to internal keys + "clip_cpu_offload": image_encoder_offload, # Map to internal keys + "vae_cpu_offload": vae_offload, # Map to internal keys + } + + # Simulate logic from LightX2VPipeline.create_generator + # which calls set_infer_config / set_infer_config_json + # Here we directly populate args_dict with the required inference config + + if config_path is not None: + with open(config_path, "r") as f: + config_json_content = json.load(f) + args_dict.update(config_json_content) + else: + # Replicating set_infer_config logic + if model_cls == "ltx2": + args_dict["distilled_sigma_values"] = distilled_sigma_values + args_dict["infer_steps"] = len(distilled_sigma_values) - 1 if distilled_sigma_values is not None else infer_steps + else: + args_dict["infer_steps"] = infer_steps + + args_dict["target_width"] = target_width + args_dict["target_height"] = target_height + args_dict["target_video_length"] = target_video_length + args_dict["sample_guide_scale"] = sample_guide_scale + args_dict["sample_shift"] = sample_shift + + if sample_guide_scale == 1 or (model_cls == "z_image" and sample_guide_scale == 0): + args_dict["enable_cfg"] = False + else: + args_dict["enable_cfg"] = True + + args_dict["rope_type"] = rope_type + args_dict["fps"] = fps + args_dict["aspect_ratio"] = aspect_ratio + args_dict["boundary"] = boundary + args_dict["boundary_step_index"] = boundary_step_index + args_dict["denoising_step_list"] = denoising_step_list + args_dict["audio_fps"] = audio_fps + args_dict["double_precision_rope"] = double_precision_rope + + if model_cls.startswith("wan"): + # Set all attention types to the provided attn_mode + args_dict["self_attn_1_type"] = attn_mode + args_dict["cross_attn_1_type"] = attn_mode + args_dict["cross_attn_2_type"] = attn_mode + elif model_cls in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill", "qwen_image", "longcat_image", "ltx2", "z_image"]: + args_dict["attn_type"] = attn_mode + + args_dict["norm_modulate_backend"] = norm_modulate_backend + + args_dict.update(kwargs) + + # Convert to object for set_config compatibility + args = ConfigObj(**args_dict) + + # Use existing set_config from utils + config = set_config_base(args) + + return config + +def build_wan_model_with_lora(wan_module, config, model_kwargs, lora_configs, model_type="high_noise_model"): + lora_dynamic_apply = config.get("lora_dynamic_apply", False) + + if lora_dynamic_apply: + if model_type in ["high_noise_model", "low_noise_model"]: + # For wan2.2 + lora_name_to_info = {item["name"]: item for item in lora_configs} + lora_path = lora_name_to_info[model_type]["path"] + lora_strength = lora_name_to_info[model_type]["strength"] + else: + # For wan2.1 + lora_path = lora_configs[0]["path"] + lora_strength = lora_configs[0]["strength"] + + model_kwargs["lora_path"] = lora_path + model_kwargs["lora_strength"] = lora_strength + model = wan_module(**model_kwargs) + else: + assert not config.get("dit_quantized", False), "Online LoRA only for quantized models; merging LoRA is unsupported." + assert not config.get("lazy_load", False), "Lazy load mode does not support LoRA merging." 
+ model = wan_module(**model_kwargs) + lora_wrapper = WanLoraWrapper(model) + if model_type in ["high_noise_model", "low_noise_model"]: + lora_configs = [lora_config for lora_config in lora_configs if lora_config["name"] == model_type] + lora_wrapper.apply_lora(lora_configs, model_type=model_type) + return model + +def load_wan_text_encoder(config: Dict[str, Any]): + # offload config + t5_offload = config.get("t5_cpu_offload", config.get("cpu_offload")) + if t5_offload: + t5_device = torch.device("cpu") + else: + t5_device = torch.device(AI_DEVICE) + tokenizer_path = os.path.join(config["model_path"], "google/umt5-xxl") + # quant_config + t5_quantized = config.get("t5_quantized", False) + if t5_quantized: + t5_quant_scheme = config.get("t5_quant_scheme", None) + assert t5_quant_scheme is not None + tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0] + t5_model_name = f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth" + t5_quantized_ckpt = find_torch_model_path(config, "t5_quantized_ckpt", t5_model_name) + t5_original_ckpt = None + else: + t5_quant_scheme = None + t5_quantized_ckpt = None + t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth" + t5_original_ckpt = find_torch_model_path(config, "t5_original_ckpt", t5_model_name) + + text_encoder = T5EncoderModel( + text_len=config["text_len"], + dtype=torch.bfloat16, + device=t5_device, + checkpoint_path=t5_original_ckpt, + tokenizer_path=tokenizer_path, + shard_fn=None, + cpu_offload=t5_offload, + t5_quantized=t5_quantized, + t5_quantized_ckpt=t5_quantized_ckpt, + quant_scheme=t5_quant_scheme, + load_from_rank0=config.get("load_from_rank0", False), + lazy_load=config.get("t5_lazy_load", False), + ) + # Return single encoder to match original returning list + text_encoders = [text_encoder] + return text_encoders + +def load_wan_image_encoder(config: Dict[str, Any]): + image_encoder = None + if config["task"] in ["i2v", "flf2v", "animate", "s2v"] and config.get("use_image_encoder", True): + # offload config + clip_offload = config.get("clip_cpu_offload", config.get("cpu_offload", False)) + if clip_offload: + clip_device = torch.device("cpu") + else: + clip_device = torch.device(AI_DEVICE) + # quant_config + clip_quantized = config.get("clip_quantized", False) + if clip_quantized: + clip_quant_scheme = config.get("clip_quant_scheme", None) + assert clip_quant_scheme is not None + tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0] + clip_model_name = f"models_clip_open-clip-xlm-roberta-large-vit-huge-14-{tmp_clip_quant_scheme}.pth" + clip_quantized_ckpt = find_torch_model_path(config, "clip_quantized_ckpt", clip_model_name) + clip_original_ckpt = None + else: + clip_quantized_ckpt = None + clip_quant_scheme = None + clip_model_name = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" + clip_original_ckpt = find_torch_model_path(config, "clip_original_ckpt", clip_model_name) + + image_encoder = CLIPModel( + dtype=torch.float16, + device=clip_device, + checkpoint_path=clip_original_ckpt, + clip_quantized=clip_quantized, + clip_quantized_ckpt=clip_quantized_ckpt, + quant_scheme=clip_quant_scheme, + cpu_offload=clip_offload, + use_31_block=config.get("use_31_block", True), + load_from_rank0=config.get("load_from_rank0", False), + ) + + return image_encoder + +def get_vae_parallel(config: Dict[str, Any]): + if isinstance(config.get("parallel", False), bool): + return config.get("parallel", False) + if isinstance(config.get("parallel", False), dict): + return config.get("parallel", {}).get("vae_parallel", True) + return False + +def 
load_wan_vae_encoder(config: Dict[str, Any]): + vae_name = config.get("vae_name", "Wan2.1_VAE.pth") + if config.get("model_cls", "") == "wan2.2": + vae_cls = Wan2_2_VAE + else: + vae_cls = WanVAE + + # offload config + vae_offload = config.get("vae_cpu_offload", config.get("cpu_offload")) + if vae_offload: + vae_device = torch.device("cpu") + else: + vae_device = torch.device(AI_DEVICE) + + vae_config = { + "vae_path": find_torch_model_path(config, "vae_path", vae_name), + "device": vae_device, + "parallel": get_vae_parallel(config), + "use_tiling": config.get("use_tiling_vae", False), + "cpu_offload": vae_offload, + "dtype": GET_DTYPE(), + "load_from_rank0": config.get("load_from_rank0", False), + "use_lightvae": config.get("use_lightvae", False), + } + if config["task"] not in ["i2v", "flf2v", "animate", "vace", "s2v"]: + return None + else: + return vae_cls(**vae_config) + +def load_wan_vae_decoder(config: Dict[str, Any]): + + vae_name = config.get("vae_name", "Wan2.1_VAE.pth") + tiny_vae_name = "taew2_1.pth" + + if config.get("model_cls", "") == "wan2.2": + vae_cls = Wan2_2_VAE + tiny_vae_cls = Wan2_2_VAE_tiny + tiny_vae_name = "taew2_2.pth" + else: + vae_cls = WanVAE + tiny_vae_cls = WanVAE_tiny + tiny_vae_name = "taew2_1.pth" + + # offload config + vae_offload = config.get("vae_cpu_offload", config.get("cpu_offload")) + if vae_offload: + vae_device = torch.device("cpu") + else: + vae_device = torch.device(AI_DEVICE) + + vae_config = { + "vae_path": find_torch_model_path(config, "vae_path", vae_name), + "device": vae_device, + "parallel": get_vae_parallel(config), + "use_tiling": config.get("use_tiling_vae", False), + "cpu_offload": vae_offload, + "use_lightvae": config.get("use_lightvae", False), + "dtype": GET_DTYPE(), + "load_from_rank0": config.get("load_from_rank0", False), + } + if config.get("use_tae", False): + tae_path = find_torch_model_path(config, "tae_path", tiny_vae_name) + vae_decoder = tiny_vae_cls(vae_path=tae_path, device=AI_DEVICE, need_scaled=config.get("need_scaled", False)).to(AI_DEVICE) + else: + vae_decoder = vae_cls(**vae_config) + return vae_decoder + +def load_wan_transformer(config: Dict[str, Any]): + if config["cpu_offload"]: + init_device = torch.device("cpu") + else: + init_device = torch.device(AI_DEVICE) + + if config.get("model_cls") == "wan2.1": + wan_model_kwargs = {"model_path": config["model_path"], "config": config, "device": init_device} + lora_configs = config.get("lora_configs") + if not lora_configs: + model = WanModel(**wan_model_kwargs) + else: + model = build_wan_model_with_lora(WanModel, config, wan_model_kwargs, lora_configs, model_type="wan2.1") + return model + elif config.get("model_cls") == "wan2.2_moe": + from lightx2v.models.runners.wan.wan_runner import MultiModelStruct + + high_noise_model_path = os.path.join(config["model_path"], "high_noise_model") + if config.get("dit_quantized", False) and config.get("high_noise_quantized_ckpt", None): + high_noise_model_path = config["high_noise_quantized_ckpt"] + elif config.get("high_noise_original_ckpt", None): + high_noise_model_path = config["high_noise_original_ckpt"] + + low_noise_model_path = os.path.join(config["model_path"], "low_noise_model") + if config.get("dit_quantized", False) and config.get("low_noise_quantized_ckpt", None): + low_noise_model_path = config["low_noise_quantized_ckpt"] + elif not config.get("dit_quantized", False) and config.get("low_noise_original_ckpt", None): + low_noise_model_path = config["low_noise_original_ckpt"] + + if not config.get("lazy_load", 
False) and not config.get("unload_modules", False): + lora_configs = config.get("lora_configs") + high_model_kwargs = { + "model_path": high_noise_model_path, + "config": config, + "device": init_device, + "model_type": "wan2.2_moe_high_noise", + } + low_model_kwargs = { + "model_path": low_noise_model_path, + "config": config, + "device": init_device, + "model_type": "wan2.2_moe_low_noise", + } + if not lora_configs: + high_noise_model = WanModel(**high_model_kwargs) + low_noise_model = WanModel(**low_model_kwargs) + else: + high_noise_model = build_wan_model_with_lora(WanModel, config, high_model_kwargs, lora_configs, model_type="high_noise_model") + low_noise_model = build_wan_model_with_lora(WanModel, config, low_model_kwargs, lora_configs, model_type="low_noise_model") + + return MultiModelStruct([high_noise_model, low_noise_model], config, config.get("boundary", 0.875)) + else: + model_struct = MultiModelStruct([None, None], config, config.get("boundary", 0.875)) + model_struct.low_noise_model_path = low_noise_model_path + model_struct.high_noise_model_path = high_noise_model_path + model_struct.init_device = init_device + return model_struct + else: + logger.error(f"Unsupported model_cls: {config.get('model_cls')}") + raise ValueError(f"Unsupported model_cls: {config.get('model_cls')}") + +def estimate_encoder_buffer_sizes(config: Dict[str, Any]) -> List[int]: + text_len = int(config.get("text_len", 512)) + enable_cfg = bool(config.get("enable_cfg", False)) + use_image_encoder = bool(config.get("use_image_encoder", True)) + task = config.get("task", "i2v") + + text_dim = int(config.get("text_encoder_dim", 4096)) + clip_dim = int(config.get("clip_embed_dim", 1024)) + z_dim = int(config.get("vae_z_dim", 16)) + + vae_stride = config.get("vae_stride", (4, 8, 8)) + stride_t = int(vae_stride[0]) + stride_h = int(vae_stride[1]) + stride_w = int(vae_stride[2]) + + target_video_length = int(config.get("target_video_length", 81)) + target_height = int(config.get("target_height", 480)) + target_width = int(config.get("target_width", 832)) + + t_prime = 1 + (target_video_length - 1) // stride_t + h_prime = int(math.ceil(target_height / stride_h)) + w_prime = int(math.ceil(target_width / stride_w)) + + bytes_per_elem = torch.tensor([], dtype=torch.float32).element_size() + int_bytes_per_elem = torch.tensor([], dtype=torch.int64).element_size() + + buffer_sizes = [] + context_bytes = text_len * text_dim * bytes_per_elem + buffer_sizes.append(context_bytes) + if enable_cfg: + buffer_sizes.append(context_bytes) + + if task == "i2v": + if use_image_encoder: + buffer_sizes.append(clip_dim * bytes_per_elem) + vae_bytes = (z_dim + 4) * t_prime * h_prime * w_prime * bytes_per_elem + buffer_sizes.append(vae_bytes) + + latent_shape_bytes = 4 * int_bytes_per_elem + buffer_sizes.append(latent_shape_bytes) + + # Metadata buffer for integrity checks (hashes + shapes) + buffer_sizes.append(4096) + + return buffer_sizes
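+
+# Illustrative sketch (an assumption about intended usage, not code exercised here):
+# the byte counts returned by estimate_encoder_buffer_sizes are meant to back
+# pre-allocated GPU buffers that are then registered with the Mooncake transfer
+# engine, roughly:
+#
+#     sizes = estimate_encoder_buffer_sizes(config)
+#     rdma_buffers = [torch.empty(n, dtype=torch.uint8, device=AI_DEVICE) for n in sizes]
+#     for buf in rdma_buffers:
+#         engine.register(buf.data_ptr(), buf.numel())
+#
+# The ordering of the sizes mirrors the buffer_index walk on the transformer side:
+# context, optional context_null (CFG), optional clip / vae buffers (i2v), the int64
+# latent_shape buffer, and finally the 4096-byte metadata buffer used for hash checks.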