Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ This is the official implementation for [LERF](https://lerf.io).
</div>

# Installation
LERF follows the integration guidelines described [here](https://docs.nerf.studio/en/latest/developer_guides/new_methods.html) for custom methods within Nerfstudio.
LERF follows the integration guidelines described [here](https://docs.nerf.studio/developer_guides/new_methods.html) for custom methods within Nerfstudio.
### 0. Install Nerfstudio dependencies
[Follow these instructions](https://docs.nerf.studio/en/latest/quickstart/installation.html) up to and including "tinycudann" to install dependencies and create an environment
[Follow these instructions](https://docs.nerf.studio/quickstart/installation.html) up to and including "tinycudann" to install dependencies and create an environment
### 1. Clone this repo
`git clone https://github.com/kerrj/lerf`
### 2. Install this repo as a python package
Expand All @@ -23,7 +23,7 @@ Run `ns-train -h`: you should see a list of "subcommands" with lerf, lerf-big, a
# Using LERF
Now that LERF is installed you can play with it!

- Launch training with `ns-train lerf --data <data_folder>`. This specifies a data folder to use. For more details, see [Nerfstudio documentation](https://docs.nerf.studio/en/latest/quickstart/first_nerf.html).
- Launch training with `ns-train lerf --data <data_folder>`. This specifies a data folder to use. For more details, see [Nerfstudio documentation](https://docs.nerf.studio/quickstart/first_nerf.html).
- Connect to the viewer by forwarding the viewer port (we use VSCode to do this), and click the link to `viewer.nerf.studio` provided in the output of the train script
- Within the viewer, you can type text into the textbox, then select the `relevancy_0` output type to visualize relevancy maps.

Expand Down
121 changes: 121 additions & 0 deletions lerf/data/alpha_lerf_datamanager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright 2022 The Nerfstudio Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Datamanager.
"""

from __future__ import annotations

import os.path as osp
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union

import torch
import yaml
from nerfstudio.cameras.rays import RayBundle
from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes
from nerfstudio.model_components.ray_generators import RayGenerator
from nerfstudio.utils.misc import IterableWrapper
from rich.progress import Console

CONSOLE = Console(width=120)

from lerf.data.utils.dino_dataloader import DinoDataloader
from lerf.data.utils.alpha_pyramid_embedding_dataloader import AlphaPyramidEmbeddingDataloader
from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager, VanillaDataManagerConfig


@dataclass
class AlphaLERFDataManagerConfig(VanillaDataManagerConfig):
    """Configuration for the AlphaLERF datamanager.

    Extends the vanilla Nerfstudio datamanager config with parameters that
    describe the multi-scale CLIP patch pyramid.
    """

    # Target class to instantiate from this config.
    _target: Type = field(default_factory=lambda: AlphaLERFDataManager)
    # (min, max) tile size as a fraction of image height for the CLIP patch pyramid.
    patch_tile_size_range: Tuple[int, int] = (0.1, 0.6)
    # Number of discrete tile sizes (pyramid levels) between min and max.
    patch_tile_size_res: int = 7
    # Scaler applied to the patch stride relative to the tile size (smaller = denser tiling).
    patch_stride_scaler: float = 0.5


class AlphaLERFDataManager(VanillaDataManager):  # pylint: disable=abstract-method
    """Datamanager that augments vanilla ray batches with CLIP and DINO features.

    This is pretty much a port over from our old dataloading utilities, and is a little jank
    under the hood. We may clean this up a little bit under the hood with more standard dataloading
    components that can be strung together, but it can be just used as a black box for now since
    only the constructor is likely to change in the future, or maybe passing in step number to the
    next_train and next_eval functions.

    Args:
        config: the AlphaLERFDataManagerConfig used to instantiate the class.
        device: device to place tensors on.
        test_mode: one of "test", "val", "inference".
        world_size: number of distributed processes.
        local_rank: rank of this process.
        **kwargs: must include "image_encoder" (a BaseImageEncoder instance).
    """

    def __init__(
        self,
        config: AlphaLERFDataManagerConfig,
        device: Union[torch.device, str] = "cpu",
        test_mode: Literal["test", "val", "inference"] = "val",
        world_size: int = 1,
        local_rank: int = 0,
        **kwargs,  # pylint: disable=unused-argument
    ):
        super().__init__(
            config=config, device=device, test_mode=test_mode, world_size=world_size, local_rank=local_rank, **kwargs
        )
        self.image_encoder = kwargs["image_encoder"]
        # Stack all training images into a single (N, 3, H, W) tensor for the feature dataloaders.
        images = [self.train_dataset[i]["image"].permute(2, 0, 1)[None, ...] for i in range(len(self.train_dataset))]
        images = torch.cat(images)

        # Feature caches are keyed by the dataset name (and encoder name for CLIP).
        cache_dir = f"outputs/{self.config.dataparser.data.name}"
        clip_cache_path = Path(osp.join(cache_dir, f"clip_{self.image_encoder.name}"))
        dino_cache_path = Path(osp.join(cache_dir, "dino.npy"))
        # NOTE: cache config is sensitive to list vs. tuple, because it checks for dict equality
        self.dino_dataloader = DinoDataloader(
            image_list=images,
            device=self.device,
            cfg={"image_shape": list(images.shape[2:4])},
            cache_path=dino_cache_path,
        )
        torch.cuda.empty_cache()
        # BUGFIX: the pyramid cfg previously hard-coded tile_size_range=[0.05, 0.5],
        # tile_size_res=7 and stride_scaler=0.5, silently ignoring the
        # patch_tile_size_range / patch_tile_size_res / patch_stride_scaler fields
        # declared on AlphaLERFDataManagerConfig. Honor the config instead
        # (converted to list for the cache's dict-equality check).
        self.clip_interpolator = AlphaPyramidEmbeddingDataloader(
            image_list=images,
            device=self.device,
            cfg={
                "tile_size_range": list(self.config.patch_tile_size_range),
                "tile_size_res": self.config.patch_tile_size_res,
                "stride_scaler": self.config.patch_stride_scaler,
                "image_shape": list(images.shape[2:4]),
                "model_name": self.image_encoder.name,
            },
            cache_path=clip_cache_path,
            model=self.image_encoder,
        )

    def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
        """Returns the next batch of data from the train dataloader.

        In addition to the vanilla ray bundle/batch, attaches per-ray CLIP
        embeddings (with their random scales) and DINO features, plus camera
        intrinsics metadata needed by the model.
        """
        self.train_count += 1
        image_batch = next(self.iter_train_image_dataloader)
        assert self.train_pixel_sampler is not None
        batch = self.train_pixel_sampler.sample(image_batch)
        ray_indices = batch["indices"]
        ray_bundle = self.train_ray_generator(ray_indices)
        batch["clip"], clip_scale = self.clip_interpolator(ray_indices)
        batch["dino"] = self.dino_dataloader(ray_indices)
        ray_bundle.metadata["clip_scales"] = clip_scale
        # assume all cameras have the same focal length and image width
        ray_bundle.metadata["fx"] = self.train_dataset.cameras[0].fx.item()
        ray_bundle.metadata["width"] = self.train_dataset.cameras[0].width.item()
        ray_bundle.metadata["fy"] = self.train_dataset.cameras[0].fy.item()
        ray_bundle.metadata["height"] = self.train_dataset.cameras[0].height.item()
        return ray_bundle, batch
116 changes: 116 additions & 0 deletions lerf/data/utils/alpha_patch_embedding_dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import json

import numpy as np
import torch
from nerfstudio.data.utils.feature_dataloader import FeatureDataloader
from nerfstudio.encoders.image_encoder import BaseImageEncoder
from tqdm import tqdm


class PatchEmbeddingDataloader(FeatureDataloader):
    """Embeds overlapping square tiles of each image with a CLIP-style encoder and
    serves bilinearly interpolated embeddings at arbitrary pixel coordinates.

    cfg keys:
        tile_ratio: tile side length as a fraction of image height.
        stride_ratio: stride between tiles as a fraction of the tile size.
        image_shape: (H, W) of the input images.
        model_name: encoder identifier (used for cache validation).
    """

    def __init__(
        self,
        cfg: dict,
        device: torch.device,
        model: BaseImageEncoder,
        image_list: torch.Tensor = None,
        cache_path: str = None,
    ):
        assert "tile_ratio" in cfg
        assert "stride_ratio" in cfg
        assert "image_shape" in cfg
        assert "model_name" in cfg

        # Tile geometry in pixels, derived from image height.
        self.kernel_size = int(cfg["image_shape"][0] * cfg["tile_ratio"])
        self.stride = int(self.kernel_size * cfg["stride_ratio"])
        self.padding = self.kernel_size // 2
        # Tile-center coordinates along each axis; the arange length matches the
        # number of sliding-window positions produced by torch.nn.Unfold.
        self.center_x = (
            (self.kernel_size - 1) / 2
            - self.padding
            + self.stride
            * np.arange(
                np.floor((cfg["image_shape"][0] + 2 * self.padding - (self.kernel_size - 1) - 1) / self.stride + 1)
            )
        )
        self.center_y = (
            (self.kernel_size - 1) / 2
            - self.padding
            + self.stride
            * np.arange(
                np.floor((cfg["image_shape"][1] + 2 * self.padding - (self.kernel_size - 1) - 1) / self.stride + 1)
            )
        )
        self.center_x = torch.from_numpy(self.center_x).half()
        self.center_y = torch.from_numpy(self.center_y).half()
        self.start_x = self.center_x[0].float()
        self.start_y = self.center_y[0].float()

        self.model = model
        self.embed_size = self.model.embedding_dim
        super().__init__(cfg, device, image_list, cache_path)

    def load(self):
        """Load cached embeddings; raises FileNotFoundError/ValueError to trigger create()."""
        cache_info_path = self.cache_path.with_suffix(".info")
        if not cache_info_path.exists():
            raise FileNotFoundError
        with open(cache_info_path, "r") as f:
            cfg = json.loads(f.read())
        # Any config mismatch invalidates the cache (dict equality, so list vs. tuple matters).
        if cfg != self.cfg:
            raise ValueError("Config mismatch")
        self.data = torch.from_numpy(np.load(self.cache_path)).half().to(self.device)

    def create(self, image_list):
        """Embed every tile of every image and stack the results into self.data."""
        assert self.model is not None, "model must be provided to generate features"
        assert image_list is not None, "image_list must be provided to generate features"

        unfold_func = torch.nn.Unfold(
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
        ).to(self.device)

        img_embeds = []
        for img in tqdm(image_list, desc="Embedding images", leave=False):
            img_embeds.append(self._embed_clip_tiles(img.unsqueeze(0), unfold_func))
        self.data = torch.from_numpy(np.stack(img_embeds)).half().to(self.device)

    def __call__(self, img_points):
        """Look up interpolated embeddings.

        Args:
            img_points: (B, 3) long tensor of (img_ind, row, col) pixel coordinates.
        Returns:
            (B, embed_size) interpolated embeddings.
        """
        img_points = img_points.cpu()
        img_ind, img_points_x, img_points_y = img_points[:, 0], img_points[:, 1], img_points[:, 2]

        # Index of the tile whose center lies at or above/left of the query point.
        x_ind = torch.floor((img_points_x - (self.start_x)) / self.stride).long()
        y_ind = torch.floor((img_points_y - (self.start_y)) / self.stride).long()
        return self._interp_inds(img_ind, x_ind, y_ind, img_points_x, img_points_y)

    def _interp_inds(self, img_ind, x_ind, y_ind, img_points_x, img_points_y):
        """Bilinearly interpolate between the four tile embeddings surrounding each point."""
        # NOTE(review): original comment claimed self.data lives on CPU to save GPU
        # memory, but load()/create() move it to self.device; the .to(...) calls
        # below are harmless either way — confirm intended placement.
        img_ind = img_ind.to(self.data.device)
        topleft = self.data[img_ind, x_ind, y_ind].to(self.device)
        topright = self.data[img_ind, x_ind + 1, y_ind].to(self.device)
        botleft = self.data[img_ind, x_ind, y_ind + 1].to(self.device)
        botright = self.data[img_ind, x_ind + 1, y_ind + 1].to(self.device)

        x_stride = self.stride
        y_stride = self.stride
        right_w = ((img_points_x - (self.center_x[x_ind])) / x_stride).to(self.device)  # .half()
        top = torch.lerp(topleft, topright, right_w[:, None])
        bot = torch.lerp(botleft, botright, right_w[:, None])

        bot_w = ((img_points_y - (self.center_y[y_ind])) / y_stride).to(self.device)  # .half()
        return torch.lerp(top, bot, bot_w[:, None])

    def _embed_clip_tiles(self, image, unfold_func):
        """Embed all tiles of one (1, 3, H, W) image; returns a (X, Y, D) numpy grid."""
        # BUGFIX: tiles were previously moved to the hard-coded "cuda" device, which
        # breaks when the dataloader is constructed on any other device; use
        # self.device instead. (Also dropped the no-op torch.cat([image]).)
        tiles = unfold_func(image).permute(2, 0, 1).reshape(-1, 3, self.kernel_size, self.kernel_size).to(self.device)

        with torch.no_grad():
            clip_embeds = self.model.encode_image(tiles)
        clip_embeds /= clip_embeds.norm(dim=-1, keepdim=True)

        clip_embeds = clip_embeds.reshape((self.center_x.shape[0], self.center_y.shape[0], -1))
        # Duplicate the last row/column so x_ind + 1 / y_ind + 1 lookups stay in bounds.
        clip_embeds = torch.concat((clip_embeds, clip_embeds[:, [-1], :]), dim=1)
        clip_embeds = torch.concat((clip_embeds, clip_embeds[[-1], :, :]), dim=0)
        return clip_embeds.detach().cpu().numpy()
125 changes: 125 additions & 0 deletions lerf/data/utils/alpha_pyramid_embedding_dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import json
import os
from pathlib import Path

import numpy as np
import torch
from nerfstudio.data.utils.feature_dataloader import FeatureDataloader
from lerf.data.utils.mask_embedding_dataloader import MaskEmbeddingDataloader
from nerfstudio.encoders.image_encoder import BaseImageEncoder
from tqdm import tqdm


class AlphaPyramidEmbeddingDataloader(FeatureDataloader):
    """Multi-scale ("pyramid") CLIP embedding dataloader.

    Maintains one MaskEmbeddingDataloader per tile size and linearly
    interpolates embeddings between adjacent pyramid levels, either at a
    random scale per query point or at a caller-specified uniform scale.

    cfg keys:
        tile_size_range: (min, max) tile size as a fraction of image height.
        tile_size_res: number of pyramid levels between min and max.
        stride_scaler: stride scaling factor for the larger tile sizes.
        image_shape: (H, W) of the input images.
        model_name: encoder identifier (used for cache validation).
    """

    def __init__(
        self,
        cfg: dict,
        device: torch.device,
        model: BaseImageEncoder,
        image_list: torch.Tensor = None,
        cache_path: str = None,
    ):
        assert "tile_size_range" in cfg
        assert "tile_size_res" in cfg
        assert "stride_scaler" in cfg
        assert "image_shape" in cfg
        assert "model_name" in cfg

        self.tile_sizes = torch.linspace(*cfg["tile_size_range"], cfg["tile_size_res"]).to(device)
        # Per-level stride scalers: small tiles keep stride 1.0, larger tiles use cfg["stride_scaler"].
        self.strider_scaler_list = [self._stride_scaler(tr.item(), cfg["stride_scaler"]) for tr in self.tile_sizes]

        self.model = model
        self.embed_size = self.model.embedding_dim
        # Maps pyramid-level index -> per-level embedding dataloader.
        self.data_dict = {}
        super().__init__(cfg, device, image_list, cache_path)

    def __call__(self, img_points, scale=None):
        """Return (embeddings, scales) for the query points; random scale when scale is None."""
        if scale is None:
            return self._random_scales(img_points)
        return self._uniform_scales(img_points, scale)

    def _stride_scaler(self, tile_ratio, stride_scaler):
        # Ramp from 1.0 at tile_ratio<=0.05 down to stride_scaler at tile_ratio>=0.15.
        return np.interp(tile_ratio, [0.05, 0.15], [1.0, stride_scaler])

    def load(self):
        """Validate the cache config; always raises so each per-level dataloader loads itself."""
        cache_info_path = self.cache_path.with_suffix(".info")

        # check if cache exists
        if not cache_info_path.exists():
            raise FileNotFoundError

        # if config is different, remove all cached content
        with open(cache_info_path, "r") as f:
            cfg = json.loads(f.read())
        if cfg != self.cfg:
            for f in os.listdir(self.cache_path):
                os.remove(os.path.join(self.cache_path, f))
            raise ValueError("Config mismatch")

        raise FileNotFoundError  # trigger create

    def create(self, image_list):
        """Build one per-level embedding dataloader for each tile size."""
        os.makedirs(self.cache_path, exist_ok=True)
        for i, tr in enumerate(tqdm(self.tile_sizes, desc="Scales")):
            stride_scaler = self.strider_scaler_list[i]
            self.data_dict[i] = MaskEmbeddingDataloader(
                cfg={
                    "tile_ratio": tr.item(),
                    "stride_ratio": stride_scaler,
                    "image_shape": self.cfg["image_shape"],
                    "model_name": self.cfg["model_name"],
                },
                device=self.device,
                model=self.model,
                image_list=image_list,
                cache_path=Path(f"{self.cache_path}/level_{i}.npy"),
            )
        # BUGFIX: removed leftover debug `print(image_list.shape)`.

    def save(self):
        """Persist only the pyramid config; each per-level dataloader saves its own data."""
        cache_info_path = self.cache_path.with_suffix(".info")
        with open(cache_info_path, "w") as f:
            f.write(json.dumps(self.cfg))

    def _random_scales(self, img_points):
        """Sample a random scale per point and lerp between the two adjacent pyramid levels.

        Args:
            img_points: (B, 3) tensor of (img_ind, x, y).
        Returns:
            ((B, embed_size) embeddings, (B, 1) normalized scales in [0, 1]).
        """
        img_points = img_points.to(self.device)
        random_scale_bin = torch.randint(self.tile_sizes.shape[0] - 1, size=(img_points.shape[0],), device=self.device)
        random_scale_weight = torch.rand(img_points.shape[0], dtype=torch.float16, device=self.device)

        # Normalized width of one pyramid step in [0, 1].
        stepsize = (self.tile_sizes[1] - self.tile_sizes[0]) / (self.tile_sizes[-1] - self.tile_sizes[0])

        bottom_interp = torch.zeros((img_points.shape[0], self.embed_size), dtype=torch.float16, device=self.device)
        top_interp = torch.zeros((img_points.shape[0], self.embed_size), dtype=torch.float16, device=self.device)

        # Gather the lower/upper level embeddings for each point, grouped by bin.
        for i in range(len(self.tile_sizes) - 1):
            ids = img_points[random_scale_bin == i]
            bottom_interp[random_scale_bin == i] = self.data_dict[i](ids)
            top_interp[random_scale_bin == i] = self.data_dict[i + 1](ids)

        return (
            torch.lerp(bottom_interp, top_interp, random_scale_weight[..., None]),
            (random_scale_bin * stepsize + random_scale_weight * stepsize)[..., None],
        )

    def _uniform_scales(self, img_points, scale):
        """Interpolate all points at one fixed scale; returns normalized embeddings and the scale."""
        # img_points: (B, 3) # (img_ind, x, y)
        scale_bin = torch.floor(
            (scale - self.tile_sizes[0]) / (self.tile_sizes[-1] - self.tile_sizes[0]) * (self.tile_sizes.shape[0] - 1)
        ).to(torch.int64)
        scale_weight = (scale - self.tile_sizes[scale_bin]) / (
            self.tile_sizes[scale_bin + 1] - self.tile_sizes[scale_bin]
        )
        # NOTE(review): this evaluates every pyramid level for every point — fine for
        # visualization queries, but O(levels) more work than strictly needed.
        interp_lst = torch.stack([interp(img_points) for interp in self.data_dict.values()])
        point_inds = torch.arange(img_points.shape[0])
        interp = torch.lerp(
            interp_lst[scale_bin, point_inds],
            interp_lst[scale_bin + 1, point_inds],
            torch.Tensor([scale_weight]).half().to(self.device)[..., None],
        )
        return interp / interp.norm(dim=-1, keepdim=True), scale
Loading