-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtruncate_opacity.py
More file actions
66 lines (50 loc) · 1.91 KB
/
truncate_opacity.py
File metadata and controls
66 lines (50 loc) · 1.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from argparse import ArgumentParser
import os
import shutil
import torch
from gs3dgs.scene.gaussian_model import GaussianModel
@torch.no_grad()
def truncate_gs(gs: "GaussianModel", threshold: float = 0.1):
    """Drop, in place, every Gaussian of *gs* whose opacity is below *threshold*.

    A Gaussian is kept unless ``opacity < threshold``. Boundary values and NaN
    opacities are therefore kept.

    Args:
        gs: model whose per-Gaussian tensors (``_xyz``, ``_features_dc``,
            ``_features_rest``, ``_scaling``, ``_rotation``, ``_opacity`` and,
            when populated, ``max_radii2D``) are filtered in place.
        threshold: minimum opacity a Gaussian needs in order to survive.
    """
    # reshape(-1) rather than squeeze(): squeeze() turns a single-Gaussian
    # (1, 1) opacity into a 0-d mask, which adds a dimension instead of filtering.
    keep = ~(gs.get_opacity < threshold).reshape(-1)

    def _keep_rows(tensor):
        # Match the mask's device to the tensor being indexed.
        return tensor[keep.to(tensor.device)]

    gs._xyz = _keep_rows(gs._xyz)
    gs._features_dc = _keep_rows(gs._features_dc)
    gs._features_rest = _keep_rows(gs._features_rest)
    gs._scaling = _keep_rows(gs._scaling)
    gs._rotation = _keep_rows(gs._rotation)
    gs._opacity = _keep_rows(gs._opacity)
    # NOTE(review): in upstream 3DGS, load_ply leaves max_radii2D as an empty
    # placeholder, and masking it with an N-element mask would raise.
    # Presumably the same holds for gs3dgs -- confirm. A non-empty,
    # mismatched tensor still raises, as before.
    if gs.max_radii2D.numel() > 0:
        gs.max_radii2D = _keep_rows(gs.max_radii2D)
def truncate(gs_path: str, threshold: float = 0.1):
    """Remove Gaussians with opacity below *threshold* from the PLY at *gs_path*, in place.

    Before the file is rewritten, a copy is saved as ``<gs_path>.bak``. If
    that backup already exists it is left untouched, so re-running the script
    never replaces the original model with an already-truncated one.

    Args:
        gs_path: path to the Gaussian model PLY file; it is overwritten.
        threshold: opacity below which Gaussians are dropped.

    Raises:
        FileNotFoundError: if *gs_path* does not exist.
    """
    # Explicit raise instead of assert: asserts are stripped under ``python -O``.
    if not os.path.exists(gs_path):
        raise FileNotFoundError(f"Gaussian model file {gs_path} does not exist")
    backup_path = gs_path + ".bak"
    if not os.path.exists(backup_path):
        shutil.copy(gs_path, backup_path)
    with torch.no_grad():
        gaussian_refined = GaussianModel(0)
        gaussian_refined.load_ply(gs_path)
        truncate_gs(gaussian_refined, threshold)
        gaussian_refined.save_ply(gs_path)
def main(argv=None):
    """Truncate low-opacity Gaussians in every ``.ply`` file under ``<model_path>/generated``.

    Each matching regular file is passed to ``truncate``, which rewrites it
    in place.

    Args:
        argv: argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, as a plain ``main()`` call always did.

    Raises:
        FileNotFoundError: if ``<model_path>/generated`` does not exist.
    """
    # Set up command line argument parser
    parser = ArgumentParser(
        description="Remove low-opacity Gaussians from generated Gaussian models"
    )
    parser.add_argument(
        "--model_path",
        "-m",
        type=str,
        required=True,
        help="path to the scene Gaussian model",
    )
    parser.add_argument(
        "--threshold",
        default=0.1,
        type=float,
        help="opacity threshold for truncation",
    )
    args = parser.parse_args(argv)

    gs_generated_path = os.path.join(args.model_path, "generated")
    if not os.path.exists(gs_generated_path):
        raise FileNotFoundError(f"Generated Gaussian model path {gs_generated_path} does not exist")

    # sorted(): os.listdir order is arbitrary; this keeps processing and logs deterministic.
    for gs_file_name in sorted(os.listdir(gs_generated_path)):
        gs_path = os.path.join(gs_generated_path, gs_file_name)
        # Only regular *.ply files. This skips the *.ply.bak backups and any
        # directory whose name happens to end in ".ply".
        if not gs_file_name.endswith(".ply") or not os.path.isfile(gs_path):
            continue
        truncate(gs_path, args.threshold)
        print(f"Truncated {gs_path} with threshold {args.threshold}")
# Entry point: run the CLI only when executed as a script, never on import.
if __name__ == "__main__":
    main()