diff --git a/point_e/util/pc_to_mesh.py b/point_e/util/pc_to_mesh.py index 14e41cf..2f05977 100644 --- a/point_e/util/pc_to_mesh.py +++ b/point_e/util/pc_to_mesh.py @@ -51,12 +51,12 @@ def int_coord_to_float(int_coords: torch.Tensor) -> torch.Tensor: volume = [] for i in indices: - indices = torch.arange( + batch_indices = torch.arange( i, min(i + batch_size, grid_size**3), step=1, dtype=torch.int64, device=model.device ) - zs = int_coord_to_float(indices % grid_size) - ys = int_coord_to_float(torch.div(indices, grid_size, rounding_mode="trunc") % grid_size) - xs = int_coord_to_float(torch.div(indices, grid_size**2, rounding_mode="trunc")) + zs = int_coord_to_float(batch_indices % grid_size) + ys = int_coord_to_float(torch.div(batch_indices, grid_size, rounding_mode="trunc") % grid_size) + xs = int_coord_to_float(torch.div(batch_indices, grid_size**2, rounding_mode="trunc")) coords = torch.stack([xs, ys, zs], dim=0) with torch.no_grad(): volume.append(model(coords[None], encoded=cond)[0])