-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy patheval_lightning.py
More file actions
56 lines (42 loc) · 1.63 KB
/
eval_lightning.py
File metadata and controls
56 lines (42 loc) · 1.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from glob import glob
from typing import Any
from natsort import natsorted
import hydra
import pandas as pd
from omegaconf import DictConfig, OmegaConf
import torch
import torchvision
import lightning as L
from tqdm import tqdm
from utils.routines import load_from_model_path
@hydra.main(version_base=None, config_path="conf", config_name="eval")
def main(cfg_eval: DictConfig) -> None:
    """Evaluate a trained model on the test and validation splits.

    Loads the model, scene and training config via ``load_from_model_path``
    (from ``cfg_eval.model_path``, passing ``cfg_eval.source_path`` through).
    For every non-empty split it then runs ``trainer.test`` and writes the
    model's ``test_logs`` to ``eval_logs_<split>.csv``.

    When ``cfg_eval.eval_on_gg`` is set, the loaded model is replaced by a
    ``GaussianGrouping`` model configured with ``cfg_eval.gg_ckpt_folder``.
    In that mode the CSVs are written into that checkpoint folder instead of
    ``cfg.scene.model_path``.

    Args:
        cfg_eval: Hydra config composed from ``conf/eval`` plus CLI overrides.

    Raises:
        ValueError: if ``eval_on_gg`` is set but ``gg_ckpt_folder`` is None.
    """
    model, scene, cfg = load_from_model_path(
        cfg_eval.model_path, source_path=cfg_eval.source_path
    )
    # By default, evaluation CSVs are written next to the evaluated model.
    save_folder = cfg.scene.model_path

    if cfg_eval.eval_on_gg:
        # Imported lazily so the default evaluation path does not depend on it.
        from model.gaussian_grouping import GaussianGrouping

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if cfg_eval.gg_ckpt_folder is None:
            raise ValueError("gg_ckpt_folder must be specified")
        print(f"Loading GaussianGrouping model from {cfg_eval.gg_ckpt_folder}")
        cfg.model.name = "gaussian_grouping"
        cfg.model.gg_ckpt_folder = cfg_eval.gg_ckpt_folder
        save_folder = cfg_eval.gg_ckpt_folder
        model = GaussianGrouping(cfg, scene)

    trainer = L.Trainer(devices=cfg.gpus)

    for subset in ("test", "valid"):
        loader = scene.get_data_loader(subset, shuffle=False)
        if len(loader) == 0:
            # Nothing to evaluate for this split. Skipping also avoids dumping
            # another split's stale `test_logs` under this split's filename.
            continue
        trainer.test(model=model, dataloaders=loader)
        # NOTE(review): assumes the model resets `test_logs` at the start of each
        # test run. If it accumulates instead, the "valid" CSV will also contain
        # the "test" rows; confirm in the model's test hooks.
        logs = pd.DataFrame(model.test_logs)
        logs.to_csv(os.path.join(save_folder, f"eval_logs_{subset}.csv"), index=False)


if __name__ == "__main__":
    main()