From dae59d40ac27e4e7ae3ca89a815b9ce2e14f886e Mon Sep 17 00:00:00 2001 From: dbelgorod Date: Sun, 1 Mar 2026 22:11:13 -0600 Subject: [PATCH 1/7] implements sumcumprod --- scripts/utils_alpha_blending.py | 237 +++++++++++++++++++++++--------- 1 file changed, 170 insertions(+), 67 deletions(-) diff --git a/scripts/utils_alpha_blending.py b/scripts/utils_alpha_blending.py index 69e41e5..d96f105 100644 --- a/scripts/utils_alpha_blending.py +++ b/scripts/utils_alpha_blending.py @@ -1,6 +1,7 @@ import os import sys import torch +import torch.nn as nn from utils_operation import regulate, cumprod @@ -8,19 +9,46 @@ sys.path.append(grandfather_path) from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationLpNorm +from auto_LiRPA.perturbations import PerturbationLinear from collections import defaultdict -DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') bound_opts = { 'conv_mode': 'matrix', 'optimize_bound_args': { - 'iteration': 100, - # 'lr_alpha':0.02, + 'iteration': 100, + # 'lr_alpha':0.02, 'early_stop_patience':5}, -} +} + + +class SumCumProdModel(nn.Module): + """SUMCUMPROD (Algorithm 5) as nn.Module for auto_LiRPA. + + Input: alpha, shape (B, N) + Output: pc, shape (B, 3) — all 3 color channels at once + Colors are fixed (not perturbed). + """ + def __init__(self, colors): + super().__init__() + self.register_buffer('colors', colors) # (B, N, 3) + + def forward(self, alpha): + B, N = alpha.shape + running = torch.ones(B, 1, device=alpha.device) + prefix_list = [] + for i in range(N): + prefix_list.append(running) + running = running * (1 - alpha[:, i:i+1]) + prefix = torch.cat(prefix_list, dim=1) # (B, N) + + alpha = torch.relu(alpha) + color = torch.relu(self.colors) + # prefix and alpha are (B, N), color is (B, N, 3) + return torch.sum(prefix.unsqueeze(-1) * alpha.unsqueeze(-1) * color, dim=1) # (B, 3) def alpha_blending(alpha, colors, method, triu_mask=None): @@ -104,75 +132,150 @@ def alpha_blending_ref(net, input_ref): def alpha_blending_ptb(net, input_ref, input_lb, input_ub, bound_method): N = net.call_model("get_num") gs_batch = net.call_model("get_gs_batch") - bg_color=(net.call_model("get_bg_color_tile")).unsqueeze(0).unsqueeze(-2) #[1, TH, TW, N, 3] + bg_color = net.call_model("get_bg_color_tile").unsqueeze(0).unsqueeze(-2) # (1, TH, TW, 1, 3) - if N==0: + if N == 0: return bg_color.squeeze(-2), bg_color.squeeze(-2) - else: - alphas_int_lb = [] - alphas_int_ub = [] - hl,wl,hu,wu = (net.call_model("get_tile_dict")[key] for key in ["hl", "wl", "hu", "wu"]) + # ── STEP 1: Extract linear bounds of α (per GS batch) from auto_LiRPA ── + alphas_lA = [] + alphas_uA = [] + alphas_lbias = [] + alphas_ubias = [] + alphas_ref = [] + + hl, wl, hu, wu = (net.call_model("get_tile_dict")[key] + for key in ["hl", "wl", "hu", "wu"]) + TH, TW = hu - hl, wu - wl + + ptb = PerturbationLpNorm(x_L=input_lb, x_U=input_ub) + input_ptb = BoundedTensor(input_ref, ptb) + + with torch.no_grad(): + for idx_start in range(0, N, gs_batch): + idx_end = min(idx_start + gs_batch, N) + num_gs = idx_end - idx_start - ptb = PerturbationLpNorm(x_L=input_lb,x_U=input_ub) - input_ptb = BoundedTensor(input_ref, ptb) + net.call_model("update_model_param", idx_start, idx_end, "middle") - with torch.no_grad(): - for i, idx_start in enumerate(range(0, N, gs_batch)): - idx_end = min(idx_start + gs_batch, N) + # Forward pass at input_ref for exact reference alphas + alpha_ref_batch = net(input_ref) # (1, TH*TW*num_gs) + alpha_ref_batch = alpha_ref_batch.reshape(TH * TW, num_gs) + alphas_ref.append(alpha_ref_batch.detach()) - net.call_model("update_model_param",idx_start,idx_end,"middle") - model = BoundedModule(net, input_ref, bound_opts=bound_opts, device=DEVICE) + model = BoundedModule(net, input_ref, bound_opts=bound_opts, device=DEVICE) - # Compute IBP bounds for reference - alpha_ibp_lb, alpha_ibp_ub = model.compute_bounds(x=(input_ptb, ), method="ibp") - reference_interm_bounds = {} - for node in model.nodes(): - if (node.perturbed + # IBP for reference intermediate bounds + model.compute_bounds(x=(input_ptb,), method="ibp") + reference_interm_bounds = {} + for node in model.nodes(): + if (node.perturbed and isinstance(node.lower, torch.Tensor) and isinstance(node.upper, torch.Tensor)): - reference_interm_bounds[node.name] = (node.lower, node.upper) - - # required_A = defaultdict(set) - # required_A[model.output_name[0]].add(model.input_name[0]) - - # Compute linear buond for alpha - alpha_int_lb, alpha_int_ub= model.compute_bounds( - x= (input_ptb, ), - method=bound_method, - reference_bounds=reference_interm_bounds, - ) #[1, TH, TW, N, 4] - - # lower_A, lower_bias = A_dict[model.output_name[0]][model.input_name[0]]['lA'], A_dict[model.output_name[0]][model.input_name[0]]['lbias'] - # upper_A, upper_bias = A_dict[model.output_name[0]][model.input_name[0]]['uA'], A_dict[model.output_name[0]][model.input_name[0]]['ubias'] - # print(f"lower_A shape: {lower_A.shape}, lower_bias shape: {lower_bias.shape}") - # print(f"upper_A shape: {upper_A.shape}, upper_bias shape: {upper_bias.shape}") - - alpha_int_lb = alpha_int_lb.reshape(1, hu-hl, wu-wl, idx_end-idx_start, 1) - alpha_int_ub = alpha_int_ub.reshape(1, hu-hl, wu-wl, idx_end-idx_start, 1) - - alphas_int_lb.append(alpha_int_lb.detach()) - alphas_int_ub.append(alpha_int_ub.detach()) - - del model - torch.cuda.empty_cache() - - alphas_int_lb = torch.cat(alphas_int_lb, dim=-2) - alphas_int_ub = torch.cat(alphas_int_ub, dim=-2) - - # Load Colors within Tile and Add background - colors = net.call_model("get_color_tile") - colors = colors.view(1, 1, 1, alphas_int_lb.size(-2), 3).repeat(1, alpha_int_lb.size(1), alpha_int_lb.size(2), 1, 1) - colors = torch.cat([colors, bg_color], dim = -2) - - ones = torch.ones_like(alphas_int_lb[:, :, :, 0:1, :]) - alphas_int_lb = torch.cat([alphas_int_lb, ones], dim=-2) - alphas_int_ub = torch.cat([alphas_int_ub, ones], dim=-2) - - # Volume Rendering for Interval Bounds - color_alpha_out_lb, color_alpha_out_ub = alpha_blending_interval(alphas_int_lb, alphas_int_ub, colors) - - color_out_lb,alpha_out_lb = color_alpha_out_lb.split([3,1],dim=-1) - color_out_ub,alpha_out_ub = color_alpha_out_ub.split([3,1],dim=-1) - - return color_out_lb.squeeze(-2), color_out_ub.squeeze(-2) + reference_interm_bounds[node.name] = (node.lower, node.upper) + + # CROWN with A matrix extraction + required_A = defaultdict(set) + required_A[model.output_name[0]].add(model.input_name[0]) + + # Must use 'backward' (CROWN) to extract A matrices; + # forward mode does not support return_A. + _, _, A_dict = model.compute_bounds( + x=(input_ptb,), + method='backward', + reference_bounds=reference_interm_bounds, + return_A=True, + needed_A_dict=required_A, + ) + + # lA·x + lbias ≤ α ≤ uA·x + ubias + A_entry = A_dict[model.output_name[0]][model.input_name[0]] + lA = A_entry['lA'].detach() # (1, TH*TW*num_gs, input_dim) + uA = A_entry['uA'].detach() + lbias = A_entry['lbias'].detach() # (1, TH*TW*num_gs) + ubias = A_entry['ubias'].detach() + + # Reshape per-pixel: (1, TH*TW*num_gs, d) → (TH*TW, num_gs, d) + lA = lA.reshape(TH * TW, num_gs, -1) + uA = uA.reshape(TH * TW, num_gs, -1) + lbias = lbias.reshape(TH * TW, num_gs) + ubias = ubias.reshape(TH * TW, num_gs) + + alphas_lA.append(lA) + alphas_uA.append(uA) + alphas_lbias.append(lbias) + alphas_ubias.append(ubias) + + del model + torch.cuda.empty_cache() + + # ── STEP 2: Concatenate α (partial GS) → α (all GS) ── + # Concatenate along the Gaussian dimension (dim=1 for per-pixel A matrices) + full_lA = torch.cat(alphas_lA, dim=1) # (TH*TW, N, input_dim) + full_uA = torch.cat(alphas_uA, dim=1) + full_lbias = torch.cat(alphas_lbias, dim=1) # (TH*TW, N) + full_ubias = torch.cat(alphas_ubias, dim=1) + alpha_ref_all = torch.cat(alphas_ref, dim=1) # (TH*TW, N) + + # Append background Gaussian: α_bg = 1 (fixed, zero A, bias = 1) + num_pixels = TH * TW + input_dim = full_lA.shape[-1] + full_lA = torch.cat([full_lA, + torch.zeros(num_pixels, 1, input_dim, device=DEVICE)], dim=1) + full_uA = torch.cat([full_uA, + torch.zeros(num_pixels, 1, input_dim, device=DEVICE)], dim=1) + full_lbias = torch.cat([full_lbias, + torch.ones(num_pixels, 1, device=DEVICE)], dim=1) + full_ubias = torch.cat([full_ubias, + torch.ones(num_pixels, 1, device=DEVICE)], dim=1) + alpha_ref_flat = torch.cat([alpha_ref_all, + torch.ones(num_pixels, 1, device=DEVICE)], dim=1) # (TH*TW, N+1) + + # Colors + background + colors = net.call_model("get_color_tile") + colors = colors.view(1, 1, 1, N, 3).repeat(1, TH, TW, 1, 1) + colors = torch.cat([colors, bg_color], dim=-2) # (1, TH, TW, N+1, 3) + + # ── STEPS 3–5: PerturbationLinear → BoundedModule → compute_bounds ── + N_total = full_lA.shape[1] # N + 1 (including background) + + # Colors: (1, TH, TW, N_total, 3) → (TH*TW, N_total, 3) + colors_flat = colors.reshape(-1, N_total, 3) + + with torch.no_grad(): + # STEP 3: Wrap α bounds in PerturbationLinear + # PerturbationLinear.concretize() will compose A_blend (from CROWN on + # blending model) with lA/uA and concretize against input_lb/input_ub, + # preserving full correlation: camera x → α → pc. + ptb_linear = PerturbationLinear( + lower_A=full_lA, # (TH*TW, N_total, input_dim) + upper_A=full_uA, + lower_b=full_lbias, # (TH*TW, N_total) + upper_b=full_ubias, + input_lb=input_lb.expand(num_pixels, -1), # (TH*TW, input_dim) + input_ub=input_ub.expand(num_pixels, -1), + ) + alpha_ptb = BoundedTensor(alpha_ref_flat, ptb_linear) + + # STEP 4: Wrap blending step in BoundedModule (all 3 channels at once) + blend_model = SumCumProdModel(colors_flat) + bounded_blend = BoundedModule( + blend_model, alpha_ref_flat, + bound_opts=bound_opts, device=DEVICE + ) + + # STEP 5: CROWN backward → A_blend; PerturbationLinear.concretize + # composes A_blend with lA/uA and concretizes against camera bounds. + pixel_lb, pixel_ub = bounded_blend.compute_bounds( + x=(alpha_ptb,), method="backward" + ) + # pixel_lb, pixel_ub: (TH*TW, 3) + + img_lb = pixel_lb.reshape(1, TH, TW, 3) + img_ub = pixel_ub.reshape(1, TH, TW, 3) + + del bounded_blend + torch.cuda.empty_cache() + + # ── STEP 6: Return image bounds ── + return img_lb, img_ub From bbf81f2869e44424641286b330c4dfff1c8cb4c6 Mon Sep 17 00:00:00 2001 From: dbelgorod Date: Thu, 12 Mar 2026 13:57:26 -0500 Subject: [PATCH 2/7] Changes to make more like example code --- scripts/abstract_gsplat.py | 2 +- scripts/utils_alpha_blending.py | 14 +++++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/scripts/abstract_gsplat.py b/scripts/abstract_gsplat.py index a6e0db1..b79b293 100644 --- a/scripts/abstract_gsplat.py +++ b/scripts/abstract_gsplat.py @@ -73,7 +73,7 @@ def main(setup_dict): os.makedirs(save_folder_full) # Load Trained 3DGS - scene_parameters = torch.load(checkpoint_file, weights_only=False) + scene_parameters = torch.load(checkpoint_file, weights_only=False, map_location=DEVICE) means = scene_parameters['pipeline']['_model.gauss_params.means'].to(DEVICE) quats = scene_parameters['pipeline']['_model.gauss_params.quats'].to(DEVICE) opacities = torch.sigmoid(scene_parameters['pipeline']['_model.gauss_params.opacities']).to(DEVICE) diff --git a/scripts/utils_alpha_blending.py b/scripts/utils_alpha_blending.py index d96f105..f8922f8 100644 --- a/scripts/utils_alpha_blending.py +++ b/scripts/utils_alpha_blending.py @@ -38,17 +38,21 @@ def __init__(self, colors): def forward(self, alpha): B, N = alpha.shape - running = torch.ones(B, 1, device=alpha.device) + log_one_minus_alpha = torch.log(1 - alpha) # (B, N) + + running = torch.zeros(B, 1, device=alpha.device) prefix_list = [] for i in range(N): prefix_list.append(running) - running = running * (1 - alpha[:, i:i+1]) - prefix = torch.cat(prefix_list, dim=1) # (B, N) + running = running + log_one_minus_alpha[:, i:i+1] + + prefix = torch.cat(prefix_list, dim=1) # (B, N) + prefix = -torch.relu(-prefix) # clamp ≤ 0 + prefix = torch.exp(prefix) # (B, N) alpha = torch.relu(alpha) color = torch.relu(self.colors) - # prefix and alpha are (B, N), color is (B, N, 3) - return torch.sum(prefix.unsqueeze(-1) * alpha.unsqueeze(-1) * color, dim=1) # (B, 3) + return torch.sum(prefix.unsqueeze(-1) * alpha.unsqueeze(-1) * color, dim=1) def alpha_blending(alpha, colors, method, triu_mask=None): From 2aa96c68abe817ee058eeba17ad5c7f45520799e Mon Sep 17 00:00:00 2001 From: dbelgorod Date: Thu, 12 Mar 2026 19:04:07 -0500 Subject: [PATCH 3/7] including paramters used to run model --- configs/boeing_airplane/config.yaml | 30 ++++++++++++++++++++++ configs/boeing_airplane/samples.json | 14 ++++++++++ configs/boeing_airplane/traj.json | 38 ++++++++++++++++++++++++++++ 3 files changed, 82 insertions(+) create mode 100644 configs/boeing_airplane/config.yaml create mode 100644 configs/boeing_airplane/samples.json create mode 100644 configs/boeing_airplane/traj.json diff --git a/configs/boeing_airplane/config.yaml b/configs/boeing_airplane/config.yaml new file mode 100644 index 0000000..7af27f1 --- /dev/null +++ b/configs/boeing_airplane/config.yaml @@ -0,0 +1,30 @@ +bound_method: "forward" +render_method: "splatfacto" +case_name: "boeing_airplane" +odd_type: "cylinder" +save_filename: null +debug: true + +width: 256 +height: 256 +fx: 216 +fy: 216 +eps2d: 8.0 + +downsampling_ratio: 8 +tile_size_abstract: 8 +tile_size_render: 24 +min_distance: 0.01 +max_distance: 100.0 +gs_batch: 20 +part: [1, 2, 2] + +data_time: "2026-01-31_235019" +checkpoint_filename: "step-000299999-pruned95.ckpt" + +bg_img_path: null +bg_pure_color: [0.0, 1.0, 0.0] + +save_ref: true +save_bound: true +N_samples: 5 diff --git a/configs/boeing_airplane/samples.json b/configs/boeing_airplane/samples.json new file mode 100644 index 0000000..f90939b --- /dev/null +++ b/configs/boeing_airplane/samples.json @@ -0,0 +1,14 @@ +[ + { + "index": 0, + "pose": [ + 3.096113, + 0.375645, + 3.329813, + -3.141593, + 0.785398, + 1.570796 + ], + "radius": 0.1 + } +] diff --git a/configs/boeing_airplane/traj.json b/configs/boeing_airplane/traj.json new file mode 100644 index 0000000..f24de08 --- /dev/null +++ b/configs/boeing_airplane/traj.json @@ -0,0 +1,38 @@ +[ + { + "index": 0, + "pose": [ + 3.096113, + 0.375645, + 3.329813, + -3.141593, + 0.785398, + 1.570796 + ], + "tangent": [ + 0.0, + -1.0, + 0.0 + ], + "gate": null, + "radius": 0.1 + }, + { + "index": 1, + "pose": [ + 3.096113, + -0.624355, + 3.329813, + -3.141593, + 0.785398, + 1.570796 + ], + "tangent": [ + 0.0, + -1.0, + 0.0 + ], + "gate": null, + "radius": 0.1 + } +] From 728e399dc4d5f68a13144d1cf4c720df6630ed72 Mon Sep 17 00:00:00 2001 From: dbelgorod Date: Thu, 12 Mar 2026 19:11:42 -0500 Subject: [PATCH 4/7] made some fixes --- scripts/render_models.py | 2 +- scripts/utils_alpha_blending.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/render_models.py b/scripts/render_models.py index dd6a0d4..f38a16e 100644 --- a/scripts/render_models.py +++ b/scripts/render_models.py @@ -508,7 +508,7 @@ def render_alpha(self, pose, scene_dict, eps_max=1.0): Ms_pix_12 = Ms_pix[:, :, 1, 2] covs_pix_det = (Ms_pix_00*Ms_pix_11-Ms_pix_01*Ms_pix_10)**2+(Ms_pix_00*Ms_pix_12-Ms_pix_02*Ms_pix_10)**2+(Ms_pix_01*Ms_pix_12-Ms_pix_02*Ms_pix_11)**2 - covs_pix_det += depth*1e-15 # May cause error + covs_pix_det = torch.relu(covs_pix_det - 1e-6) + 1e-6 # clamp min=1e-6, LiRPA-safe covs_pix_00 = Ms_pix_00**2+Ms_pix_01**2+Ms_pix_02**2 covs_pix_01 = Ms_pix_00*Ms_pix_10+Ms_pix_01*Ms_pix_11+Ms_pix_02*Ms_pix_12 diff --git a/scripts/utils_alpha_blending.py b/scripts/utils_alpha_blending.py index f8922f8..7044535 100644 --- a/scripts/utils_alpha_blending.py +++ b/scripts/utils_alpha_blending.py @@ -38,7 +38,8 @@ def __init__(self, colors): def forward(self, alpha): B, N = alpha.shape - log_one_minus_alpha = torch.log(1 - alpha) # (B, N) + one_minus_alpha = torch.relu(1 - alpha - 1e-6) + 1e-6 # clamp min=1e-6, LiRPA-safe + log_one_minus_alpha = torch.log(one_minus_alpha) # (B, N) running = torch.zeros(B, 1, device=alpha.device) prefix_list = [] From dea0f1070723471e07d1f153f00eeaa3fbe713cd Mon Sep 17 00:00:00 2001 From: dbelgorod Date: Fri, 13 Mar 2026 10:19:36 -0500 Subject: [PATCH 5/7] config param changes --- configs/boeing_airplane/config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/boeing_airplane/config.yaml b/configs/boeing_airplane/config.yaml index 7af27f1..2fe33a3 100644 --- a/configs/boeing_airplane/config.yaml +++ b/configs/boeing_airplane/config.yaml @@ -11,7 +11,7 @@ fx: 216 fy: 216 eps2d: 8.0 -downsampling_ratio: 8 +downsampling_ratio: 2 tile_size_abstract: 8 tile_size_render: 24 min_distance: 0.01 @@ -20,7 +20,7 @@ gs_batch: 20 part: [1, 2, 2] data_time: "2026-01-31_235019" -checkpoint_filename: "step-000299999-pruned95.ckpt" +checkpoint_filename: "step-000299999.ckpt" bg_img_path: null bg_pure_color: [0.0, 1.0, 0.0] From a3c79531bbae02b052ac70b8913b1fedd70be092 Mon Sep 17 00:00:00 2001 From: dbelgorod Date: Wed, 18 Mar 2026 17:08:13 -0500 Subject: [PATCH 6/7] bug fix on rendering --- scripts/render_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/render_models.py b/scripts/render_models.py index f38a16e..1de5b8e 100644 --- a/scripts/render_models.py +++ b/scripts/render_models.py @@ -204,7 +204,7 @@ def render_alpha(self, pose, scene_dict, eps_max=1.0): covs_pix_det = (Ms_pix_00*Ms_pix_11-Ms_pix_01*Ms_pix_10)**2+(Ms_pix_00*Ms_pix_12-Ms_pix_02*Ms_pix_10)**2+(Ms_pix_01*Ms_pix_12-Ms_pix_02*Ms_pix_11)**2 # print(torch.min(covs_pix_det), torch.max(covs_pix_det)) - covs_pix_det += depth*1e-20# May cause error + covs_pix_det += depth*1e-15# May cause error covs_pix_00 = Ms_pix_00**2+Ms_pix_01**2+Ms_pix_02**2 covs_pix_01 = Ms_pix_00*Ms_pix_10+Ms_pix_01*Ms_pix_11+Ms_pix_02*Ms_pix_12 @@ -508,7 +508,7 @@ def render_alpha(self, pose, scene_dict, eps_max=1.0): Ms_pix_12 = Ms_pix[:, :, 1, 2] covs_pix_det = (Ms_pix_00*Ms_pix_11-Ms_pix_01*Ms_pix_10)**2+(Ms_pix_00*Ms_pix_12-Ms_pix_02*Ms_pix_10)**2+(Ms_pix_01*Ms_pix_12-Ms_pix_02*Ms_pix_11)**2 - covs_pix_det = torch.relu(covs_pix_det - 1e-6) + 1e-6 # clamp min=1e-6, LiRPA-safe + covs_pix_det += depth*1e-15 covs_pix_00 = Ms_pix_00**2+Ms_pix_01**2+Ms_pix_02**2 covs_pix_01 = Ms_pix_00*Ms_pix_10+Ms_pix_01*Ms_pix_11+Ms_pix_02*Ms_pix_12 From 6de80c3a3c1bda2eb34c20f4aaf2f93104ba937d Mon Sep 17 00:00:00 2001 From: dbelgorod Date: Thu, 19 Mar 2026 03:12:55 -0500 Subject: [PATCH 7/7] update the bound type of image color to linear set as what you did on alpha. After that, store both interval bound and linear bound of image color in the pt files --- scripts/abstract_gsplat.py | 19 ++++++++++-- scripts/utils_alpha_blending.py | 52 +++++++++++++++++++++++++++++---- scripts/utils_save.py | 25 +++++++++++----- 3 files changed, 81 insertions(+), 15 deletions(-) diff --git a/scripts/abstract_gsplat.py b/scripts/abstract_gsplat.py index b79b293..732e284 100644 --- a/scripts/abstract_gsplat.py +++ b/scripts/abstract_gsplat.py @@ -160,6 +160,11 @@ def main(setup_dict): if save_bound: img_lb = np.zeros((height, width,3)) img_ub = np.zeros((height, width,3)) + input_dim = input_center.shape[-1] + img_lA = torch.zeros(height, width, 3, input_dim, device=DEVICE) + img_uA = torch.zeros(height, width, 3, input_dim, device=DEVICE) + img_lbias = torch.zeros(height, width, 3, device=DEVICE) + img_ubias = torch.zeros(height, width, 3, device=DEVICE) # Create Tiles Queue tiles_queue = [ @@ -188,11 +193,15 @@ def main(setup_dict): img_ref[hl:hu, wl:wu, :] = ref_tile_np if save_bound: - lb_tile, ub_tile = alpha_blending_ptb(verf_net, input_center, input_lb, input_ub, bound_method) + lb_tile, ub_tile, lA_tile, uA_tile, lbias_tile, ubias_tile = alpha_blending_ptb(verf_net, input_center, input_lb, input_ub, bound_method) lb_tile_np = lb_tile.squeeze(0).detach().cpu().numpy() # [TH, TW, 3] ub_tile_np = ub_tile.squeeze(0).detach().cpu().numpy() img_lb[hl:hu, wl:wu, :] = lb_tile_np img_ub[hl:hu, wl:wu, :] = ub_tile_np + img_lA[hl:hu, wl:wu, :, :] = lA_tile.squeeze(0) + img_uA[hl:hu, wl:wu, :, :] = uA_tile.squeeze(0) + img_lbias[hl:hu, wl:wu, :] = lbias_tile.squeeze(0) + img_ubias[hl:hu, wl:wu, :] = ubias_tile.squeeze(0) if debug: pbar3.update(1) @@ -210,10 +219,14 @@ def main(setup_dict): save_abstract_record( save_dir=save_folder_full, index = absimg_num, - lower_input = input_lb_org, - upper_input = input_ub_org, + lower_input = input_lb.squeeze(0), + upper_input = input_ub.squeeze(0), lower_img=img_lb_f, upper_img=img_ub_f, + img_lA=img_lA, + img_uA=img_uA, + img_lbias=img_lbias, + img_ubias=img_ubias, point = base_trans, direction = direction, radius = radius, diff --git a/scripts/utils_alpha_blending.py b/scripts/utils_alpha_blending.py index 7044535..d56546c 100644 --- a/scripts/utils_alpha_blending.py +++ b/scripts/utils_alpha_blending.py @@ -140,7 +140,11 @@ def alpha_blending_ptb(net, input_ref, input_lb, input_ub, bound_method): bg_color = net.call_model("get_bg_color_tile").unsqueeze(0).unsqueeze(-2) # (1, TH, TW, 1, 3) if N == 0: - return bg_color.squeeze(-2), bg_color.squeeze(-2) + bg = bg_color.squeeze(-2) # (1, TH, TW, 3) + input_dim = input_ref.shape[-1] + TH, TW = bg.shape[1], bg.shape[2] + zero_A = torch.zeros(1, TH, TW, 3, input_dim, device=bg.device) + return bg, bg, zero_A, zero_A, bg.clone(), bg.clone() # ── STEP 1: Extract linear bounds of α (per GS batch) from auto_LiRPA ── alphas_lA = [] @@ -271,16 +275,54 @@ def alpha_blending_ptb(net, input_ref, input_lb, input_ub, bound_method): # STEP 5: CROWN backward → A_blend; PerturbationLinear.concretize # composes A_blend with lA/uA and concretizes against camera bounds. - pixel_lb, pixel_ub = bounded_blend.compute_bounds( - x=(alpha_ptb,), method="backward" + # Also extract A matrices for linear bounds on pixel color. + required_A_blend = defaultdict(set) + required_A_blend[bounded_blend.output_name[0]].add(bounded_blend.input_name[0]) + + pixel_lb, pixel_ub, A_dict_blend = bounded_blend.compute_bounds( + x=(alpha_ptb,), method="backward", + return_A=True, + needed_A_dict=required_A_blend, ) # pixel_lb, pixel_ub: (TH*TW, 3) + # Extract blend-level A matrices: pixel_color w.r.t. alpha + A_blend_entry = A_dict_blend[bounded_blend.output_name[0]][bounded_blend.input_name[0]] + blend_lA = A_blend_entry['lA'].detach() # (TH*TW, 3, N_total) + blend_uA = A_blend_entry['uA'].detach() + blend_lbias = A_blend_entry['lbias'].detach() # (TH*TW, 3) + blend_ubias = A_blend_entry['ubias'].detach() + + # ── STEP 5b: Compose blend-level A (pixel→alpha) with alpha-level A (alpha→x) ── + # blend_lA: (TH*TW, 3, N_total), full_lA: (TH*TW, N_total, input_dim) + # Result: composite_lA: (TH*TW, 3, input_dim) + # Split into positive and negative parts for correct bound composition + blend_lA_pos = torch.clamp(blend_lA, min=0) + blend_lA_neg = torch.clamp(blend_lA, max=0) + blend_uA_pos = torch.clamp(blend_uA, min=0) + blend_uA_neg = torch.clamp(blend_uA, max=0) + + # Lower bound: pos*lower + neg*upper + composite_lA = torch.bmm(blend_lA_pos, full_lA) + torch.bmm(blend_lA_neg, full_uA) + composite_lbias = blend_lbias + torch.bmm(blend_lA_pos, full_lbias.unsqueeze(-1)).squeeze(-1) \ + + torch.bmm(blend_lA_neg, full_ubias.unsqueeze(-1)).squeeze(-1) + + # Upper bound: pos*upper + neg*lower + composite_uA = torch.bmm(blend_uA_pos, full_uA) + torch.bmm(blend_uA_neg, full_lA) + composite_ubias = blend_ubias + torch.bmm(blend_uA_pos, full_ubias.unsqueeze(-1)).squeeze(-1) \ + + torch.bmm(blend_uA_neg, full_lbias.unsqueeze(-1)).squeeze(-1) + + # Reshape to tile shape: (TH*TW, 3, input_dim) → (1, TH, TW, 3, input_dim) + composite_lA = composite_lA.reshape(1, TH, TW, 3, input_dim) + composite_uA = composite_uA.reshape(1, TH, TW, 3, input_dim) + composite_lbias = composite_lbias.reshape(1, TH, TW, 3) + composite_ubias = composite_ubias.reshape(1, TH, TW, 3) + img_lb = pixel_lb.reshape(1, TH, TW, 3) img_ub = pixel_ub.reshape(1, TH, TW, 3) del bounded_blend torch.cuda.empty_cache() - # ── STEP 6: Return image bounds ── - return img_lb, img_ub + # ── STEP 6: Return image bounds (interval + linear) ── + return img_lb, img_ub, composite_lA, composite_uA, composite_lbias, composite_ubias diff --git a/scripts/utils_save.py b/scripts/utils_save.py index 8e576d5..2073b09 100644 --- a/scripts/utils_save.py +++ b/scripts/utils_save.py @@ -19,6 +19,10 @@ def save_abstract_record( upper_input, lower_img, upper_img, + img_lA=None, + img_uA=None, + img_lbias=None, + img_ubias=None, point=None, direction=None, radius=None, @@ -27,10 +31,13 @@ def save_abstract_record( Save an abstract image record. Required fields: - xl, xu : input lower/upper bounds + xl, xu : input lower/upper bounds (refined coordinate space) lower, upper : image lower/upper bounds (H, W, 3), float32 in [0, 1] - lA, uA : placeholder (None) - lb, ub : placeholder (None) + + Linear bound fields (None if not computed): + lA, uA : linear coefficients (H, W, 3, input_dim), float32 + lb, ub : linear biases (H, W, 3), float32 + Encodes: lA·x + lb ≤ pixel_color ≤ uA·x + ub, where x ∈ [xl, xu] Optional: point, direction, radius @@ -40,6 +47,10 @@ def save_abstract_record( upper_input = _to_float_tensor(upper_input) lower_img = _to_float_tensor(lower_img) upper_img = _to_float_tensor(upper_img) + img_lA = _to_float_tensor(img_lA) + img_uA = _to_float_tensor(img_uA) + img_lbias = _to_float_tensor(img_lbias) + img_ubias = _to_float_tensor(img_ubias) point = _to_float_tensor(point) direction = _to_float_tensor(direction) radius = _to_float_tensor(radius) @@ -50,10 +61,10 @@ def save_abstract_record( "xu": upper_input, "lower": lower_img, "upper": upper_img, - "lA": None, - "uA": None, - "lb": None, - "ub": None, + "lA": img_lA, + "uA": img_uA, + "lb": img_lbias, + "ub": img_ubias, "point": point, "direction": direction, "radius": radius,