diff --git a/app.py b/app.py index 74b882b..7cbb45e 100644 --- a/app.py +++ b/app.py @@ -1,6 +1,15 @@ import gradio as gr import os +# Load optional environment variables from .env file if it exists +if os.path.exists('.env'): + with open('.env') as f: + for line in f: + line = line.strip() + if line and not line.startswith('#') and '=' in line: + key, value = line.split('=', 1) + os.environ[key.strip()] = value.strip().strip('"') + os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" from datetime import datetime @@ -350,6 +359,9 @@ def get_seed(randomize_seed: bool, seed: int) -> int: def image_to_3d( image: Image.Image, + multi_images: List[Image.Image], + use_multi_image: bool, + multi_image_mode: str, seed: int, resolution: str, ss_guidance_strength: float, @@ -367,36 +379,205 @@ def image_to_3d( req: gr.Request, progress=gr.Progress(track_tqdm=True), ) -> str: - # --- Sampling --- - outputs, latents = pipeline.run( - image, - seed=seed, - preprocess_image=False, - sparse_structure_sampler_params={ - "steps": ss_sampling_steps, - "guidance_strength": ss_guidance_strength, - "guidance_rescale": ss_guidance_rescale, - "rescale_t": ss_rescale_t, - }, - shape_slat_sampler_params={ - "steps": shape_slat_sampling_steps, - "guidance_strength": shape_slat_guidance_strength, - "guidance_rescale": shape_slat_guidance_rescale, - "rescale_t": shape_slat_rescale_t, - }, - tex_slat_sampler_params={ - "steps": tex_slat_sampling_steps, - "guidance_strength": tex_slat_guidance_strength, - "guidance_rescale": tex_slat_guidance_rescale, - "rescale_t": tex_slat_rescale_t, - }, - pipeline_type={ + """ + Generate a 3D model from single or multiple images. + + Args: + image (Image.Image): Single input image (used when use_multi_image is False). + multi_images (List[Image.Image]): List of input images for multi-view generation. + use_multi_image (bool): Whether to use multi-image mode. 
+ multi_image_mode (str): Fusion mode for multi-image conditioning. + - 'stochastic': Cycles through images sequentially (memory efficient) + - 'multidiffusion': Averages predictions from all images (higher quality) + seed (int): Random seed for generation. + resolution (str): Output resolution ('512', '1024', or '1536'). + ss_guidance_strength (float): Guidance strength for sparse structure sampling. + ss_guidance_rescale (float): Guidance rescale for sparse structure sampling. + ss_sampling_steps (int): Number of sampling steps for sparse structure. + ss_rescale_t (float): Rescale parameter for sparse structure sampling. + shape_slat_guidance_strength (float): Guidance strength for shape latent sampling. + shape_slat_guidance_rescale (float): Guidance rescale for shape latent sampling. + shape_slat_sampling_steps (int): Number of sampling steps for shape latent. + shape_slat_rescale_t (float): Rescale parameter for shape latent sampling. + tex_slat_guidance_strength (float): Guidance strength for texture latent sampling. + tex_slat_guidance_rescale (float): Guidance rescale for texture latent sampling. + tex_slat_sampling_steps (int): Number of sampling steps for texture latent. + tex_slat_rescale_t (float): Rescale parameter for texture latent sampling. + req (gr.Request): Gradio request object. + progress (gr.Progress): Gradio progress tracker. + + Returns: + str: JSON string containing rendered images and model state. 
+ """ + # --- Check if multi-image mode is enabled --- + if use_multi_image: + if multi_images is None or len(multi_images) == 0: + raise gr.Error("Please upload images in the Multi-Image gallery before generating") + + # Multi-image processing - ensure all images are PIL Images + images_to_process = [] + for idx, img in enumerate(multi_images): + if img is not None: + # Handle different formats Gradio might return + if isinstance(img, tuple): + # Gallery returns (PIL.Image, caption/metadata) + img = img[0] # Extract the image from the tuple + elif isinstance(img, dict) and 'name' in img: + # Gradio sometimes returns dict with 'name' key pointing to file path + img_path = img['name'] + img = Image.open(img_path) + elif isinstance(img, str): + # File path string + img = Image.open(img) + elif isinstance(img, np.ndarray): + # NumPy array + img = Image.fromarray(img) + + # Verify we have a PIL Image + if not isinstance(img, Image.Image): + continue + + images_to_process.append(img) + + if len(images_to_process) == 0: + raise gr.Error("No valid images could be processed. 
Please check the image format.") + + # Get conditioning from multiple images + torch.manual_seed(seed) + pipeline_type = { "512": "512", "1024": "1024_cascade", "1536": "1536_cascade", - }[resolution], - return_latent=True, - ) + }[resolution] + + # Process each image and stack along batch dimension + cond_list_512 = [pipeline.get_cond([img], 512)['cond'] for img in images_to_process] + stacked_cond_512 = torch.cat(cond_list_512, dim=0) + cond_512 = { + 'cond': stacked_cond_512, + 'neg_cond': torch.zeros_like(stacked_cond_512[:1]) + } + + cond_1024 = None + if resolution != '512': + cond_list_1024 = [pipeline.get_cond([img], 1024)['cond'] for img in images_to_process] + stacked_cond_1024 = torch.cat(cond_list_1024, dim=0) + cond_1024 = { + 'cond': stacked_cond_1024, + 'neg_cond': torch.zeros_like(stacked_cond_1024[:1]) + } + + # Sample sparse structure + ss_res = {'512': 32, '1024': 64, '1024_cascade': 32, '1536_cascade': 32}[pipeline_type] + with pipeline.inject_sampler_multi_image('sparse_structure_sampler', len(images_to_process), ss_sampling_steps, mode=multi_image_mode): + coords = pipeline.sample_sparse_structure( + cond_512, ss_res, + num_samples=1, + sampler_params={ + "steps": ss_sampling_steps, + "guidance_strength": ss_guidance_strength, + "guidance_rescale": ss_guidance_rescale, + "rescale_t": ss_rescale_t, + } + ) + + # Sample shape latent + if pipeline_type == '512': + with pipeline.inject_sampler_multi_image('shape_slat_sampler', len(images_to_process), shape_slat_sampling_steps, mode=multi_image_mode): + shape_slat = pipeline.sample_shape_slat( + cond_512, pipeline.models['shape_slat_flow_model_512'], + coords, { + "steps": shape_slat_sampling_steps, + "guidance_strength": shape_slat_guidance_strength, + "guidance_rescale": shape_slat_guidance_rescale, + "rescale_t": shape_slat_rescale_t, + } + ) + tex_cond = cond_512 + tex_model = pipeline.models['tex_slat_flow_model_512'] + res = 512 + elif pipeline_type == '1024': + with 
pipeline.inject_sampler_multi_image('shape_slat_sampler', len(images_to_process), shape_slat_sampling_steps, mode=multi_image_mode): + shape_slat = pipeline.sample_shape_slat( + cond_1024, pipeline.models['shape_slat_flow_model_1024'], + coords, { + "steps": shape_slat_sampling_steps, + "guidance_strength": shape_slat_guidance_strength, + "guidance_rescale": shape_slat_guidance_rescale, + "rescale_t": shape_slat_rescale_t, + } + ) + tex_cond = cond_1024 + tex_model = pipeline.models['tex_slat_flow_model_1024'] + res = 1024 + elif pipeline_type in ['1024_cascade', '1536_cascade']: + target_res = 1024 if pipeline_type == '1024_cascade' else 1536 + with pipeline.inject_sampler_multi_image('shape_slat_sampler', len(images_to_process), shape_slat_sampling_steps, mode=multi_image_mode): + shape_slat, res = pipeline.sample_shape_slat_cascade( + cond_512, cond_1024, + pipeline.models['shape_slat_flow_model_512'], + pipeline.models['shape_slat_flow_model_1024'], + 512, target_res, + coords, { + "steps": shape_slat_sampling_steps, + "guidance_strength": shape_slat_guidance_strength, + "guidance_rescale": shape_slat_guidance_rescale, + "rescale_t": shape_slat_rescale_t, + }, + max_num_tokens=49152 + ) + tex_cond = cond_1024 + tex_model = pipeline.models['tex_slat_flow_model_1024'] + + # Sample texture latent + with pipeline.inject_sampler_multi_image('tex_slat_sampler', len(images_to_process), tex_slat_sampling_steps, mode=multi_image_mode): + tex_slat = pipeline.sample_tex_slat( + tex_cond, tex_model, + shape_slat, { + "steps": tex_slat_sampling_steps, + "guidance_strength": tex_slat_guidance_strength, + "guidance_rescale": tex_slat_guidance_rescale, + "rescale_t": tex_slat_rescale_t, + } + ) + + latents = (shape_slat, tex_slat, res) + outputs = pipeline.decode_latent(shape_slat, tex_slat, res) + else: + # Single-image processing + if image is None: + raise gr.Error("No image provided") + + outputs, latents = pipeline.run( + image, + seed=seed, + preprocess_image=False, + 
sparse_structure_sampler_params={ + "steps": ss_sampling_steps, + "guidance_strength": ss_guidance_strength, + "guidance_rescale": ss_guidance_rescale, + "rescale_t": ss_rescale_t, + }, + shape_slat_sampler_params={ + "steps": shape_slat_sampling_steps, + "guidance_strength": shape_slat_guidance_strength, + "guidance_rescale": shape_slat_guidance_rescale, + "rescale_t": shape_slat_rescale_t, + }, + tex_slat_sampler_params={ + "steps": tex_slat_sampling_steps, + "guidance_strength": tex_slat_guidance_strength, + "guidance_rescale": tex_slat_guidance_rescale, + "rescale_t": tex_slat_rescale_t, + }, + pipeline_type={ + "512": "512", + "1024": "1024_cascade", + "1536": "1536_cascade", + }[resolution], + return_latent=True, + ) + mesh = outputs[0] mesh.simplify(16777216) # nvdiffrast limit images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap) @@ -517,40 +698,72 @@ def extract_glb( with gr.Blocks(delete_cache=(600, 600)) as demo: gr.Markdown(""" ## Image to 3D Asset with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2) - * Upload an image (preferably with an alpha-masked foreground object) and click Generate to create a 3D asset. - * Click Extract GLB to export and download the generated GLB file if you're satisfied with the result. Otherwise, try another time. + * **Single Image Mode**: Upload one image and click Generate to create a 3D asset. + * **Multi-Image Mode**: Enable multi-image mode and upload multiple views for better 3D reconstruction. + * Click Extract GLB to export and download the generated GLB file if you're satisfied with the result. + * **Default settings are optimized for defect detection and anomaly inspection** (high resolution, detail preservation, multi-view consistency). 
""") - + with gr.Row(): with gr.Column(scale=1, min_width=360): - image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=400) - + # Mode selector + use_multi_image = gr.Checkbox(label="Enable Multi-Image Mode", value=False) + + # Single image input + image_prompt = gr.Image( + label="Single Image Input", + format="png", + image_mode="RGBA", + type="pil", + height=400, + visible=True + ) + + # Multi-image gallery + multi_image_prompt = gr.Gallery( + label="Multi-Image Input (Upload 2-8 views)", + format="png", + type="pil", + height=400, + columns=4, + visible=False + ) + + # Multi-image fusion mode + multi_image_mode = gr.Radio( + ["stochastic", "multidiffusion"], + label="Multi-Image Fusion Mode", + value="multidiffusion", + info="Stochastic: cycles through images (fast). Multidiffusion: averages all images (slower, better quality)", + visible=False + ) + resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024") - seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1) - randomize_seed = gr.Checkbox(label="Randomize Seed", value=True) - decimation_target = gr.Slider(100000, 1000000, label="Decimation Target", value=500000, step=10000) - texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024) - + seed = gr.Slider(0, MAX_SEED, label="Seed", value=42, step=1) + randomize_seed = gr.Checkbox(label="Randomize Seed", value=False) + decimation_target = gr.Slider(100000, 1000000, label="Decimation Target", value=900000, step=10000) + texture_size = gr.Slider(1024, 4096, label="Texture Size", value=4096, step=1024) + generate_btn = gr.Button("Generate") - with gr.Accordion(label="Advanced Settings", open=False): + with gr.Accordion(label="Advanced Settings", open=False): gr.Markdown("Stage 1: Sparse Structure Generation") with gr.Row(): - ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1) - ss_guidance_rescale = gr.Slider(0.0, 1.0, 
label="Guidance Rescale", value=0.7, step=0.01) - ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) + ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=8.5, step=0.1) + ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.75, step=0.01) + ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=18, step=1) ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1) gr.Markdown("Stage 2: Shape Generation") with gr.Row(): - shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1) - shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01) - shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) + shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=8.5, step=0.1) + shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.6, step=0.01) + shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=18, step=1) shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1) gr.Markdown("Stage 3: Material Generation") with gr.Row(): - tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1) - tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01) - tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) + tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=4.0, step=0.1) + tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.3, step=0.01) + tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=18, step=1) tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1) with gr.Column(scale=10): @@ -581,7 +794,30 @@ def extract_glb( # Handlers 
demo.load(start_session)
demo.unload(end_session)


def toggle_multi_image_mode(enabled: bool):
    """Swap the visible input widgets when multi-image mode is toggled.

    Returns gr.update(...) objects (rather than freshly constructed
    components, as the original did) so that only the `visible` property
    changes and every other widget property is left untouched.

    Args:
        enabled: current value of the "Enable Multi-Image Mode" checkbox.

    Returns:
        dict mapping each affected component to its visibility update.
    """
    return {
        image_prompt: gr.update(visible=not enabled),
        multi_image_prompt: gr.update(visible=enabled),
        multi_image_mode: gr.update(visible=enabled),
    }


use_multi_image.change(
    toggle_multi_image_mode,
    inputs=[use_multi_image],
    outputs=[image_prompt, multi_image_prompt, multi_image_mode],
)

# NOTE(review): the image_prompt.upload / generate_btn chain wiring is split
# across diff hunks at this point in the patch and is left as in the original
# file (not reconstructed here).

# NOTE(review): binding 0.0.0.0 exposes the app on every network interface —
# confirm this deployment is intentional (consider the GRADIO_SERVER_NAME
# environment variable instead of hard-coding).
demo.launch(server_name="0.0.0.0", server_port=7860, css=css, head=head)
"""
Multi-image to 3D generation using TRELLIS 2.

Demonstrates multi-image conditioning with two fusion modes:
- 'stochastic': cycles through images at each step (memory efficient)
- 'multidiffusion': averages all images at each step (higher quality)
"""

import os

# These must be set before cv2 / torch are imported.
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"  # Can save GPU memory

import cv2
import imageio
import torch
from PIL import Image

import o_voxel
from trellis2.pipelines import Trellis2ImageTo3DPipeline
from trellis2.renderers import EnvMap
from trellis2.utils import render_utils

# --- Configuration -----------------------------------------------------------
IMAGE_PATHS = [
    "assets/example_image/T.png",
    "assets/example_image/T.png",  # Add more image paths here
]
FUSION_MODE = 'multidiffusion'  # 'stochastic' or 'multidiffusion'
RESOLUTION = '1024'             # '512', '1024', or '1536'


def build_cond(pipeline, images, res):
    """Encode every image at resolution `res` and stack the per-image
    conditionings along the batch dimension (one row per view).  A single
    zero embedding serves as the unconditional branch."""
    stacked = torch.cat([pipeline.get_cond([img], res)['cond'] for img in images], dim=0)
    return {'cond': stacked, 'neg_cond': torch.zeros_like(stacked[:1])}


def main():
    """Run the full multi-image -> mesh -> video/GLB pipeline."""
    # 1. Environment map for PBR rendering.
    envmap = EnvMap(torch.tensor(
        cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
        dtype=torch.float32, device='cuda'
    ))

    # 2. Load the pipeline.
    pipeline = Trellis2ImageTo3DPipeline.from_pretrained("microsoft/TRELLIS.2-4B")
    pipeline.cuda()

    # 3. Input views.
    images = [Image.open(path) for path in IMAGE_PATHS]

    # 4. Conditioning features (512 always; 1024 only for larger pipelines).
    torch.manual_seed(42)
    cond_512 = build_cond(pipeline, images, 512)
    cond_1024 = build_cond(pipeline, images, 1024) if RESOLUTION != '512' else None

    # 5. Multi-image sampling.
    pipeline_type = {'512': '512', '1024': '1024_cascade', '1536': '1536_cascade'}[RESOLUTION]
    ss_res = {'512': 32, '1024': 64, '1024_cascade': 32, '1536_cascade': 32}[pipeline_type]
    ss_steps = pipeline.sparse_structure_sampler_params.get('steps', 12)
    shape_steps = pipeline.shape_slat_sampler_params.get('steps', 12)
    tex_steps = pipeline.tex_slat_sampler_params.get('steps', 12)
    n = len(images)

    # Stage 1: sparse structure.
    with pipeline.inject_sampler_multi_image('sparse_structure_sampler', n, ss_steps, mode=FUSION_MODE):
        coords = pipeline.sample_sparse_structure(cond_512, ss_res, num_samples=1, sampler_params={})

    # Stage 2: shape latent.  The RESOLUTION map only yields '512' or a
    # cascade type, so the original's plain-'1024' branch was dead code.
    if pipeline_type == '512':
        with pipeline.inject_sampler_multi_image('shape_slat_sampler', n, shape_steps, mode=FUSION_MODE):
            shape_slat = pipeline.sample_shape_slat(
                cond_512, pipeline.models['shape_slat_flow_model_512'], coords, {})
        tex_cond, tex_model, res = cond_512, pipeline.models['tex_slat_flow_model_512'], 512
    else:  # '1024_cascade' or '1536_cascade'
        target_res = 1024 if pipeline_type == '1024_cascade' else 1536
        with pipeline.inject_sampler_multi_image('shape_slat_sampler', n, shape_steps, mode=FUSION_MODE):
            shape_slat, res = pipeline.sample_shape_slat_cascade(
                cond_512, cond_1024,
                pipeline.models['shape_slat_flow_model_512'],
                pipeline.models['shape_slat_flow_model_1024'],
                512, target_res, coords, {}, max_num_tokens=49152,
            )
        tex_cond, tex_model = cond_1024, pipeline.models['tex_slat_flow_model_1024']

    # Stage 3: texture latent.
    with pipeline.inject_sampler_multi_image('tex_slat_sampler', n, tex_steps, mode=FUSION_MODE):
        tex_slat = pipeline.sample_tex_slat(tex_cond, tex_model, shape_slat, {})

    # 6. Decode to a mesh and simplify (nvdiffrast vertex limit).
    mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
    mesh.simplify(16777216)

    # 7. Turntable video.
    video = render_utils.make_pbr_vis_frames(render_utils.render_video(mesh, envmap=envmap))
    imageio.mimsave("sample_multi.mp4", video, fps=15)

    # 8. GLB export.
    glb = o_voxel.postprocess.to_glb(
        vertices=mesh.vertices,
        faces=mesh.faces,
        attr_volume=mesh.attrs,
        coords=mesh.coords,
        attr_layout=mesh.layout,
        voxel_size=mesh.voxel_size,
        aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
        decimation_target=1000000,
        texture_size=4096,
        remesh=True,
        remesh_band=1,
        remesh_project=0,
        verbose=True,
    )
    glb.export("sample_multi.glb", extension_webp=True)


if __name__ == "__main__":
    main()
def inject_sampler_multi_image(
    self,
    sampler_name: str,
    num_images: int,
    num_steps: int,
    mode: Literal['stochastic', 'multidiffusion'] = 'stochastic',
):
    """
    Temporarily patch a sampler's ``_inference_model`` so a single sampling
    run can be conditioned on several images.  Returns a context manager; the
    patch is reverted on exit — now guaranteed via try/finally even when the
    managed body raises (the original leaked the patch on error).

    Args:
        sampler_name (str): Attribute name of the sampler to patch
            ('sparse_structure_sampler', 'shape_slat_sampler', or
            'tex_slat_sampler').
        num_images (int): Number of conditioning images stacked batch-wise
            in ``cond``.
        num_steps (int): Step count of the sampler run; used only to warn
            when 'stochastic' mode cannot visit every image at least once.
        mode (str): Fusion strategy.
            - 'stochastic': use one image per model call, cycling through
              them in order (memory efficient).
            - 'multidiffusion': average the predictions of every image at
              each call (higher quality, num_images forward passes per call).

    Raises:
        ValueError: if ``mode`` is not a supported strategy (raised on enter,
            before the sampler is touched).
    """
    from contextlib import contextmanager

    @contextmanager
    def _inject():
        # Validate before mutating anything, so an invalid mode leaves the
        # sampler untouched (the original raised after patching state).
        if mode not in ('stochastic', 'multidiffusion'):
            raise ValueError(f"Unsupported mode: {mode}")

        sampler = getattr(self, sampler_name)
        original_inference = sampler._inference_model
        # Kept under the old attribute name for backward compatibility with
        # code that introspects the injected sampler.
        sampler._old_inference_model = original_inference

        if mode == 'stochastic':
            if num_images > num_steps:
                print(f"\033[93mWarning: number of conditioning images is greater than number of steps for {sampler_name}. "
                      "This may lead to performance degradation.\033[0m")

            call_index = 0  # cycles 0..num_images-1 across successive calls

            def _patched(model, x_t, t, cond, **kwargs):
                nonlocal call_index
                i = call_index % num_images
                call_index += 1
                # Condition on a single image slice; everything else is
                # delegated to the sampler's original implementation.
                return original_inference(model, x_t, t, cond=cond[i:i + 1], **kwargs)

        else:  # 'multidiffusion'
            from .samplers import FlowEulerSampler

            def _average_pred(model, x_t, t, cond, **kwargs):
                # One forward pass per conditioning image, averaged
                # (hoisted: the original duplicated this in both branches).
                preds = [
                    FlowEulerSampler._inference_model(sampler, model, x_t, t, cond[i:i + 1], **kwargs)
                    for i in range(len(cond))
                ]
                return sum(preds) / len(preds)

            def _patched(model, x_t, t, cond, neg_cond, guidance_strength,
                         guidance_interval, guidance_rescale=0.0, **kwargs):
                pred = _average_pred(model, x_t, t, cond, **kwargs)
                if not (guidance_interval[0] <= t <= guidance_interval[1]):
                    return pred  # outside the CFG interval: plain average

                neg_pred = FlowEulerSampler._inference_model(sampler, model, x_t, t, neg_cond, **kwargs)
                # Classifier-free guidance.
                pred_cfg = guidance_strength * pred + (1 - guidance_strength) * neg_pred

                if guidance_rescale > 0:
                    # CFG rescale: match the std of the guided x0 prediction
                    # to the unguided one, then blend.
                    x_0_pos = sampler._pred_to_xstart(x_t, t, pred)
                    x_0_cfg = sampler._pred_to_xstart(x_t, t, pred_cfg)
                    std_pos = x_0_pos.std(dim=list(range(1, x_0_pos.ndim)), keepdim=True)
                    std_cfg = x_0_cfg.std(dim=list(range(1, x_0_cfg.ndim)), keepdim=True)
                    x_0 = guidance_rescale * (x_0_cfg * (std_pos / std_cfg)) \
                        + (1 - guidance_rescale) * x_0_cfg
                    pred_cfg = sampler._xstart_to_pred(x_t, t, x_0)
                return pred_cfg

        sampler._inference_model = _patched
        try:
            yield
        finally:
            # Always restore, even if sampling raised.
            sampler._inference_model = original_inference
            del sampler._old_inference_model

    return _inject()