Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
337 changes: 287 additions & 50 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
import gradio as gr

import os


def load_dotenv(path: str = '.env', override: bool = True) -> None:
    """Load KEY=VALUE pairs from a dotenv-style file into ``os.environ``.

    Blank lines, ``#`` comments, and lines without ``=`` are ignored.
    An optional leading ``export `` is accepted, and one matching pair of
    surrounding single or double quotes is removed from the value.

    Args:
        path: Path to the dotenv file. Missing files are silently skipped.
        override: When True (default, matching the original inline code),
            values replace any variables already present in the environment.
    """
    if not os.path.exists(path):
        return
    with open(path) as f:
        for raw in f:
            line = raw.strip()
            if not line or line.startswith('#') or '=' not in line:
                continue
            # Tolerate shell-style "export KEY=VALUE" lines.
            if line.startswith('export '):
                line = line[len('export '):]
            key, _, value = line.partition('=')
            key = key.strip()
            value = value.strip()
            # Strip exactly one layer of matching surrounding quotes; the
            # original strip('"') removed any run of double quotes on either
            # side, which mangles values like '"a"b"'.
            if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
                value = value[1:-1]
            if key and (override or key not in os.environ):
                os.environ[key] = value


# Load optional environment variables from .env file if it exists.
load_dotenv()

os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
from datetime import datetime
Expand Down Expand Up @@ -350,6 +359,9 @@ def get_seed(randomize_seed: bool, seed: int) -> int:

def image_to_3d(
image: Image.Image,
multi_images: List[Image.Image],
use_multi_image: bool,
multi_image_mode: str,
seed: int,
resolution: str,
ss_guidance_strength: float,
Expand All @@ -367,36 +379,205 @@ def image_to_3d(
req: gr.Request,
progress=gr.Progress(track_tqdm=True),
) -> str:
# --- Sampling ---
outputs, latents = pipeline.run(
image,
seed=seed,
preprocess_image=False,
sparse_structure_sampler_params={
"steps": ss_sampling_steps,
"guidance_strength": ss_guidance_strength,
"guidance_rescale": ss_guidance_rescale,
"rescale_t": ss_rescale_t,
},
shape_slat_sampler_params={
"steps": shape_slat_sampling_steps,
"guidance_strength": shape_slat_guidance_strength,
"guidance_rescale": shape_slat_guidance_rescale,
"rescale_t": shape_slat_rescale_t,
},
tex_slat_sampler_params={
"steps": tex_slat_sampling_steps,
"guidance_strength": tex_slat_guidance_strength,
"guidance_rescale": tex_slat_guidance_rescale,
"rescale_t": tex_slat_rescale_t,
},
pipeline_type={
"""
Generate a 3D model from single or multiple images.

Args:
image (Image.Image): Single input image (used when use_multi_image is False).
multi_images (List[Image.Image]): List of input images for multi-view generation.
use_multi_image (bool): Whether to use multi-image mode.
multi_image_mode (str): Fusion mode for multi-image conditioning.
- 'stochastic': Cycles through images sequentially (memory efficient)
- 'multidiffusion': Averages predictions from all images (higher quality)
seed (int): Random seed for generation.
resolution (str): Output resolution ('512', '1024', or '1536').
ss_guidance_strength (float): Guidance strength for sparse structure sampling.
ss_guidance_rescale (float): Guidance rescale for sparse structure sampling.
ss_sampling_steps (int): Number of sampling steps for sparse structure.
ss_rescale_t (float): Rescale parameter for sparse structure sampling.
shape_slat_guidance_strength (float): Guidance strength for shape latent sampling.
shape_slat_guidance_rescale (float): Guidance rescale for shape latent sampling.
shape_slat_sampling_steps (int): Number of sampling steps for shape latent.
shape_slat_rescale_t (float): Rescale parameter for shape latent sampling.
tex_slat_guidance_strength (float): Guidance strength for texture latent sampling.
tex_slat_guidance_rescale (float): Guidance rescale for texture latent sampling.
tex_slat_sampling_steps (int): Number of sampling steps for texture latent.
tex_slat_rescale_t (float): Rescale parameter for texture latent sampling.
req (gr.Request): Gradio request object.
progress (gr.Progress): Gradio progress tracker.

Returns:
str: JSON string containing rendered images and model state.
"""
# --- Check if multi-image mode is enabled ---
if use_multi_image:
if multi_images is None or len(multi_images) == 0:
raise gr.Error("Please upload images in the Multi-Image gallery before generating")

# Multi-image processing - ensure all images are PIL Images
images_to_process = []
for idx, img in enumerate(multi_images):
if img is not None:
# Handle different formats Gradio might return
if isinstance(img, tuple):
# Gallery returns (PIL.Image, caption/metadata)
img = img[0] # Extract the image from the tuple
elif isinstance(img, dict) and 'name' in img:
# Gradio sometimes returns dict with 'name' key pointing to file path
img_path = img['name']
img = Image.open(img_path)
elif isinstance(img, str):
# File path string
img = Image.open(img)
elif isinstance(img, np.ndarray):
# NumPy array
img = Image.fromarray(img)

# Verify we have a PIL Image
if not isinstance(img, Image.Image):
continue

images_to_process.append(img)

if len(images_to_process) == 0:
raise gr.Error("No valid images could be processed. Please check the image format.")

# Get conditioning from multiple images
torch.manual_seed(seed)
pipeline_type = {
"512": "512",
"1024": "1024_cascade",
"1536": "1536_cascade",
}[resolution],
return_latent=True,
)
}[resolution]

# Process each image and stack along batch dimension
cond_list_512 = [pipeline.get_cond([img], 512)['cond'] for img in images_to_process]
stacked_cond_512 = torch.cat(cond_list_512, dim=0)
cond_512 = {
'cond': stacked_cond_512,
'neg_cond': torch.zeros_like(stacked_cond_512[:1])
}

cond_1024 = None
if resolution != '512':
cond_list_1024 = [pipeline.get_cond([img], 1024)['cond'] for img in images_to_process]
stacked_cond_1024 = torch.cat(cond_list_1024, dim=0)
cond_1024 = {
'cond': stacked_cond_1024,
'neg_cond': torch.zeros_like(stacked_cond_1024[:1])
}

# Sample sparse structure
ss_res = {'512': 32, '1024': 64, '1024_cascade': 32, '1536_cascade': 32}[pipeline_type]
with pipeline.inject_sampler_multi_image('sparse_structure_sampler', len(images_to_process), ss_sampling_steps, mode=multi_image_mode):
coords = pipeline.sample_sparse_structure(
cond_512, ss_res,
num_samples=1,
sampler_params={
"steps": ss_sampling_steps,
"guidance_strength": ss_guidance_strength,
"guidance_rescale": ss_guidance_rescale,
"rescale_t": ss_rescale_t,
}
)

# Sample shape latent
if pipeline_type == '512':
with pipeline.inject_sampler_multi_image('shape_slat_sampler', len(images_to_process), shape_slat_sampling_steps, mode=multi_image_mode):
shape_slat = pipeline.sample_shape_slat(
cond_512, pipeline.models['shape_slat_flow_model_512'],
coords, {
"steps": shape_slat_sampling_steps,
"guidance_strength": shape_slat_guidance_strength,
"guidance_rescale": shape_slat_guidance_rescale,
"rescale_t": shape_slat_rescale_t,
}
)
tex_cond = cond_512
tex_model = pipeline.models['tex_slat_flow_model_512']
res = 512
elif pipeline_type == '1024':
with pipeline.inject_sampler_multi_image('shape_slat_sampler', len(images_to_process), shape_slat_sampling_steps, mode=multi_image_mode):
shape_slat = pipeline.sample_shape_slat(
cond_1024, pipeline.models['shape_slat_flow_model_1024'],
coords, {
"steps": shape_slat_sampling_steps,
"guidance_strength": shape_slat_guidance_strength,
"guidance_rescale": shape_slat_guidance_rescale,
"rescale_t": shape_slat_rescale_t,
}
)
tex_cond = cond_1024
tex_model = pipeline.models['tex_slat_flow_model_1024']
res = 1024
elif pipeline_type in ['1024_cascade', '1536_cascade']:
target_res = 1024 if pipeline_type == '1024_cascade' else 1536
with pipeline.inject_sampler_multi_image('shape_slat_sampler', len(images_to_process), shape_slat_sampling_steps, mode=multi_image_mode):
shape_slat, res = pipeline.sample_shape_slat_cascade(
cond_512, cond_1024,
pipeline.models['shape_slat_flow_model_512'],
pipeline.models['shape_slat_flow_model_1024'],
512, target_res,
coords, {
"steps": shape_slat_sampling_steps,
"guidance_strength": shape_slat_guidance_strength,
"guidance_rescale": shape_slat_guidance_rescale,
"rescale_t": shape_slat_rescale_t,
},
max_num_tokens=49152
)
tex_cond = cond_1024
tex_model = pipeline.models['tex_slat_flow_model_1024']

# Sample texture latent
with pipeline.inject_sampler_multi_image('tex_slat_sampler', len(images_to_process), tex_slat_sampling_steps, mode=multi_image_mode):
tex_slat = pipeline.sample_tex_slat(
tex_cond, tex_model,
shape_slat, {
"steps": tex_slat_sampling_steps,
"guidance_strength": tex_slat_guidance_strength,
"guidance_rescale": tex_slat_guidance_rescale,
"rescale_t": tex_slat_rescale_t,
}
)

latents = (shape_slat, tex_slat, res)
outputs = pipeline.decode_latent(shape_slat, tex_slat, res)
else:
# Single-image processing
if image is None:
raise gr.Error("No image provided")

outputs, latents = pipeline.run(
image,
seed=seed,
preprocess_image=False,
sparse_structure_sampler_params={
"steps": ss_sampling_steps,
"guidance_strength": ss_guidance_strength,
"guidance_rescale": ss_guidance_rescale,
"rescale_t": ss_rescale_t,
},
shape_slat_sampler_params={
"steps": shape_slat_sampling_steps,
"guidance_strength": shape_slat_guidance_strength,
"guidance_rescale": shape_slat_guidance_rescale,
"rescale_t": shape_slat_rescale_t,
},
tex_slat_sampler_params={
"steps": tex_slat_sampling_steps,
"guidance_strength": tex_slat_guidance_strength,
"guidance_rescale": tex_slat_guidance_rescale,
"rescale_t": tex_slat_rescale_t,
},
pipeline_type={
"512": "512",
"1024": "1024_cascade",
"1536": "1536_cascade",
}[resolution],
return_latent=True,
)

mesh = outputs[0]
mesh.simplify(16777216) # nvdiffrast limit
images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
Expand Down Expand Up @@ -517,40 +698,72 @@ def extract_glb(
with gr.Blocks(delete_cache=(600, 600)) as demo:
gr.Markdown("""
## Image to 3D Asset with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2)
* Upload an image (preferably with an alpha-masked foreground object) and click Generate to create a 3D asset.
* Click Extract GLB to export and download the generated GLB file if you're satisfied with the result. Otherwise, try another time.
* **Single Image Mode**: Upload one image and click Generate to create a 3D asset.
* **Multi-Image Mode**: Enable multi-image mode and upload multiple views for better 3D reconstruction.
* Click Extract GLB to export and download the generated GLB file if you're satisfied with the result.
* **Default settings are optimized for defect detection and anomaly inspection** (high resolution, detail preservation, multi-view consistency).
""")

with gr.Row():
with gr.Column(scale=1, min_width=360):
image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=400)

# Mode selector
use_multi_image = gr.Checkbox(label="Enable Multi-Image Mode", value=False)

# Single image input
image_prompt = gr.Image(
label="Single Image Input",
format="png",
image_mode="RGBA",
type="pil",
height=400,
visible=True
)

# Multi-image gallery
multi_image_prompt = gr.Gallery(
label="Multi-Image Input (Upload 2-8 views)",
format="png",
type="pil",
height=400,
columns=4,
visible=False
)

# Multi-image fusion mode
multi_image_mode = gr.Radio(
["stochastic", "multidiffusion"],
label="Multi-Image Fusion Mode",
value="multidiffusion",
info="Stochastic: cycles through images (fast). Multidiffusion: averages all images (slower, better quality)",
visible=False
)

resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
decimation_target = gr.Slider(100000, 1000000, label="Decimation Target", value=500000, step=10000)
texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
seed = gr.Slider(0, MAX_SEED, label="Seed", value=42, step=1)
randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
decimation_target = gr.Slider(100000, 1000000, label="Decimation Target", value=900000, step=10000)
texture_size = gr.Slider(1024, 4096, label="Texture Size", value=4096, step=1024)

generate_btn = gr.Button("Generate")

with gr.Accordion(label="Advanced Settings", open=False):
with gr.Accordion(label="Advanced Settings", open=False):
gr.Markdown("Stage 1: Sparse Structure Generation")
with gr.Row():
ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.7, step=0.01)
ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=8.5, step=0.1)
ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.75, step=0.01)
ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=18, step=1)
ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1)
gr.Markdown("Stage 2: Shape Generation")
with gr.Row():
shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01)
shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=8.5, step=0.1)
shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.6, step=0.01)
shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=18, step=1)
shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
gr.Markdown("Stage 3: Material Generation")
with gr.Row():
tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1)
tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01)
tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=4.0, step=0.1)
tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.3, step=0.01)
tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=18, step=1)
tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)

with gr.Column(scale=10):
Expand Down Expand Up @@ -581,7 +794,30 @@ def extract_glb(
# Handlers
demo.load(start_session)
demo.unload(end_session)


# Toggle visibility based on mode
def toggle_multi_image_mode(enabled):
"""
Toggle visibility of UI components based on multi-image mode selection.

Args:
enabled (bool): Whether multi-image mode is enabled.

Returns:
dict: Dictionary of Gradio component updates for visibility toggling.
"""
return {
image_prompt: gr.Image(visible=not enabled),
multi_image_prompt: gr.Gallery(visible=enabled),
multi_image_mode: gr.Radio(visible=enabled)
}

use_multi_image.change(
toggle_multi_image_mode,
inputs=[use_multi_image],
outputs=[image_prompt, multi_image_prompt, multi_image_mode]
)

image_prompt.upload(
preprocess_image,
inputs=[image_prompt],
Expand All @@ -597,7 +833,8 @@ def extract_glb(
).then(
image_to_3d,
inputs=[
image_prompt, seed, resolution,
image_prompt, multi_image_prompt, use_multi_image, multi_image_mode,
seed, resolution,
ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
Expand Down Expand Up @@ -642,4 +879,4 @@ def extract_glb(
)),
}

demo.launch(css=css, head=head)
demo.launch(server_name="0.0.0.0", server_port=7860, css=css, head=head)
Loading