-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathvisualize_mesh.py
More file actions
executable file
·97 lines (79 loc) · 2.99 KB
/
visualize_mesh.py
File metadata and controls
executable file
·97 lines (79 loc) · 2.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import argparse
from dataclasses import dataclass, asdict
from typing import Any, Optional
import torch
import tqdm
import stochsync.shared_modules as sm
from stochsync.data import DATASETs
from stochsync.background import BACKGROUNDs
from stochsync.model import MODELs
from stochsync.prior import PRIORs
from stochsync.logger import LOGGERs
from stochsync.utils.config_utils import load_config
from stochsync.utils.extra_utils import ignore_kwargs
from stochsync.utils.print_utils import print_info, print_error, print_warning
class Renderer:
    """
    Render a 3D model from a sequence of camera poses and log every frame.

    Construction resolves the config and instantiates the shared modules
    (``sm.dataset``, ``sm.background``, ``sm.model``, ``sm.logger``) from
    their registries. Calling the instance then renders ``num_cameras``
    frames and hands each one to the logger.
    """

    @ignore_kwargs
    @dataclass
    class Config:
        # Directory under which the output file is written (created on init).
        root_dir: str = "./results/"
        # Registry keys selecting the dataset / background / model / logger.
        dataset: Any = "seq_turnaround"
        background: Any = "solid"
        model: Any = "mesh"
        logger: Any = "renderer"
        # Mesh / texture settings; forwarded to the modules through cfg_dict.
        mesh_path: str = "./data/mesh/face.obj"
        initialization: str = "image"
        texture_path: str = "./data/mesh/face_texture.png"
        # Dataset (camera) parameters
        dist: float = 2.0
        elev: float = 30.0
        fov: float = 72
        width: int = 256
        height: int = 256
        num_cameras: int = 180  # also the number of frames rendered by __call__
        # Logging parameters
        output: str = "rendered.mp4"
        output_type: str = "video"
        fps: int = 15

    def __init__(self, cfg_dict):
        """
        Resolve the config and build the shared modules.

        Args:
            cfg_dict: Mapping of config overrides. It is updated in place with
                the fully resolved values (defaults included) before it is
                passed to every module constructor.
        """
        self.cfg = self.Config(**cfg_dict)
        # Write the resolved defaults back so every module below sees them.
        cfg_dict.update(asdict(self.cfg))

        os.makedirs(self.cfg.root_dir, exist_ok=True)
        output = os.path.join(self.cfg.root_dir, self.cfg.output)
        if os.path.exists(output):
            print_warning(f"Output file {output} already exists. Overwriting...")

        # Build order matters: dataset, background, model, then logger.
        build_plan = (
            ("dataset", DATASETs, self.cfg.dataset),
            ("background", BACKGROUNDs, self.cfg.background),
            ("model", MODELs, self.cfg.model),
            ("logger", LOGGERs, self.cfg.logger),
        )
        for slot, registry, key in build_plan:
            setattr(sm, slot, registry[key](cfg_dict))

        sm.model.prepare_optimization()

    @torch.no_grad()
    def __call__(self) -> Any:
        """
        Render ``cfg.num_cameras`` frames and log each one; returns None.

        For each frame, the camera comes from ``sm.dataset.generate_sample()``.
        The model's rendered image is blended with ``sm.background(camera)``,
        and the result is passed to ``sm.logger``. Logging is finalised with
        ``sm.logger.end_logging()``.
        """
        frame_indices = tqdm.tqdm(range(self.cfg.num_cameras), desc="Rendering")
        for frame_idx in frame_indices:
            camera = sm.dataset.generate_sample()
            render_out = sm.model.render(camera)
            backdrop = sm.background(camera)
            # Blend the background into the frame, weighted by (1 - alpha).
            frame = render_out["image"] + backdrop * (1 - render_out["alpha"])
            sm.logger(frame_idx, camera, frame)
        sm.logger.end_logging()
if __name__ == "__main__":
    # Only --config is declared; all remaining CLI tokens are passed to
    # load_config through cli_args.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        help="Path to a config file; omit to call load_config without one.",
    )
    args, extra = parser.parse_known_args()
    if args.config:
        cfg = load_config(args.config, cli_args=extra)
    else:
        cfg = load_config(cli_args=extra)
    # Replace spaces in the results root and nest this run under its tag.
    cfg.root_dir = os.path.join(cfg.root_dir.replace(" ", "_"), cfg.tag)
    renderer = Renderer(cfg)
    renderer()
    # Report the full path that Renderer.__init__ treats as the output file
    # (root_dir/output), not just the bare file name.
    output_path = os.path.join(renderer.cfg.root_dir, renderer.cfg.output)
    print_info(f"Rendering complete. Output saved to {output_path}")