-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsampling.py
More file actions
64 lines (53 loc) · 2.78 KB
/
Copy pathsampling.py
File metadata and controls
64 lines (53 loc) · 2.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def rays_sampling(H, W, F, c2w):
    """Build one world-space ray per pixel of an H x W pinhole camera.

    Args:
        H: image height in pixels.
        W: image width in pixels.
        F: focal length in pixels. (Inside this function the name shadows the
            module-level ``torch.nn.functional as F`` alias.)
        c2w: camera-to-world pose tensor. Only ``c2w[:3, :3]`` (rotation) and
            ``c2w[:3, -1]`` (camera origin) are read; presumably a 3x4 or 4x4
            matrix -- confirm against callers.

    Returns:
        ``(rays_o, rays_d)``, both of shape (H, W, 3), with dtype/device taken
        from ``c2w``. ``rays_o`` is a broadcast (expanded) view of the camera
        origin; ``rays_d`` holds the rotated, un-normalised directions.
    """
    cols = torch.arange(W, dtype=torch.float32).to(c2w)
    rows = torch.arange(H, dtype=torch.float32).to(c2w)
    # 'xy' indexing gives two (H, W) grids: px holds the column index,
    # py the row index of each pixel.
    px, py = torch.meshgrid(cols, rows, indexing='xy')

    # Camera-space directions: x to the right, y up (image rows grow downward,
    # hence the negation), and the camera looking down -z.
    x_cam = (px - W * .5) / F
    y_cam = -(py - H * .5) / F
    z_cam = -torch.ones_like(px)
    cam_dirs = torch.stack([x_cam, y_cam, z_cam], dim=-1)

    rotation = c2w[:3, :3]
    # Broadcast multiply-and-sum == rotation @ d for every pixel direction d.
    rays_d = (cam_dirs[..., None, :] * rotation).sum(dim=-1)
    rays_o = c2w[:3, -1].expand(rays_d.shape)
    return rays_o, rays_d
def stratified_sampling(rays_o, rays_d, near, far, n_samples, device, *, perturb=True):
    """Sample ``n_samples`` depths per ray by stratified (jittered) sampling.

    ``[near, far]`` is split into ``n_samples`` evenly spaced depths. With
    ``perturb`` enabled, each depth is replaced by a uniform draw inside its
    stratum. The stratum boundaries are the midpoints between neighbouring
    depths, clamped to ``near``/``far`` at the ends.

    Args:
        rays_o: ray origins; assumed shape (..., 3) -- the leading dims define
            the ray batch.
        rays_d: ray directions, broadcast-compatible with ``rays_o``.
        near: nearest sampling depth.
        far: farthest sampling depth.
        n_samples: number of depths per ray.
        device: device on which the depth tensors are created.
        perturb: if True (default), jitter every ray independently within its
            strata. If False, return the evenly spaced depths unchanged, which
            is useful for deterministic evaluation.

    Returns:
        ``(pts, z_vals)``: ``pts`` has shape (..., n_samples, 3) and holds
        ``rays_o + rays_d * z``; ``z_vals`` has shape (..., n_samples) and is
        non-decreasing along the last dim.
    """
    t_vals = torch.linspace(0.0, 1.0, n_samples, device=device)
    z_vals = near * (1 - t_vals) + far * t_vals
    # One row of depths per ray; the leading dims follow rays_o.
    z_vals = z_vals.expand(list(rays_o.shape[:-1]) + [n_samples])

    if perturb:
        mids = (z_vals[..., 1:] + z_vals[..., :-1]) / 2
        upper = torch.cat((mids, z_vals[..., -1:]), dim=-1)
        lower = torch.cat((z_vals[..., :1], mids), dim=-1)
        # Independent jitter per ray. A single [n_samples] draw shared by all
        # rays would make every ray in the batch sample identical depths.
        t_rand = torch.rand(z_vals.shape, device=device)
        z_vals = lower + (upper - lower) * t_rand

    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., None]
    return pts, z_vals
def sample_pdf(bins, weights, n_sample, device):
    """Draw samples from a piecewise-constant PDF by inverse-transform sampling.

    Args:
        bins: bin edges, shape (..., B); assumed sorted ascending along the
            last dim.
        weights: unnormalised, presumably non-negative mass for each bin, shape
            (..., B - 1), with the same leading dims as ``bins``.
        n_sample: number of samples to draw per row.
        device: device for the uniform random draws. It must match the device
            of ``bins``/``weights``.

    Returns:
        Tensor of shape (..., n_sample). Each sample is placed by linearly
        interpolating inside the bin its uniform draw falls into.

    Any number of leading batch dims is supported. Earlier versions
    hard-coded a 2-D (rays, bins) layout.
    """
    # The epsilon keeps every bin's probability positive, so rows whose
    # weights are all zero still yield a valid (uniform-ish) PDF instead of NaNs.
    pdf = (weights + 1e-5) / torch.sum(weights + 1e-5, dim=-1, keepdim=True)
    cdf = torch.cumsum(pdf, dim=-1)
    cdf = torch.cat((torch.zeros_like(cdf[..., :1]), cdf), dim=-1)  # (..., B)

    u = torch.rand(list(cdf.shape[:-1]) + [n_sample], device=device)
    # right=True: u == cdf[k] maps to the bin starting at edge k.
    ids = torch.searchsorted(cdf, u, right=True)
    below = torch.clamp(ids - 1, min=0)
    above = torch.clamp(ids, max=cdf.shape[-1] - 1)
    ids_g = torch.stack([below, above], dim=-1)  # (..., n_sample, 2)

    # Broadcast cdf/bins to (..., n_sample, B) so they can be gathered with
    # ids_g, whatever the number of leading dims.
    matched_shape = list(ids_g.shape[:-1]) + [cdf.shape[-1]]
    cdf_val = torch.gather(cdf[..., None, :].expand(matched_shape), dim=-1, index=ids_g)
    bins_val = torch.gather(bins[..., None, :].expand(matched_shape), dim=-1, index=ids_g)

    cdf_d = cdf_val[..., 1] - cdf_val[..., 0]
    # Guard against division by ~0 for near-empty bins or clamped edges.
    cdf_d = torch.where(cdf_d < 1e-5, torch.ones_like(cdf_d), cdf_d)
    t = (u - cdf_val[..., 0]) / cdf_d
    samples = bins_val[..., 0] + t * (bins_val[..., 1] - bins_val[..., 0])
    return samples
def hierarachical_sampling(rays_o, rays_d, z_vals, weights, n_samples, device):
    """Add importance-sampled depths to each ray and return the merged set.

    Midpoints between consecutive coarse depths serve as bin edges for
    ``sample_pdf``, with the interior coarse weights as the per-bin mass. The
    new depths are detached from the autograd graph, concatenated with the
    coarse ones, and sorted.

    Args:
        rays_o: ray origins, presumably (..., 3).
        rays_d: ray directions, presumably (..., 3).
        z_vals: coarse depths, (..., N).
        weights: per-sample weights for the coarse depths; assumed (..., N)
            so that ``weights[..., 1:-1]`` lines up with the N - 2 bins
            between the N - 1 midpoints -- confirm against caller.
        n_samples: number of extra depths to draw per ray.
        device: forwarded to ``sample_pdf`` for its random draws.

    Returns:
        ``(pts, z_vals_combined, new_z_samples)``:
        ``pts`` is (..., N + n_samples, 3); ``z_vals_combined`` is the sorted
        union of coarse and new depths; ``new_z_samples`` is (..., n_samples).
    """
    bin_edges = .5 * (z_vals[..., :-1] + z_vals[..., 1:])
    # First/last weights are dropped: only interior samples have a full
    # midpoint-bounded interval around them.
    interior_weights = weights[..., 1:-1]
    # No gradient flows back through the sampling positions.
    fine_z = sample_pdf(bin_edges, interior_weights, n_samples, device).detach()

    merged = torch.cat([z_vals, fine_z], dim=-1)
    merged_sorted = torch.sort(merged, dim=-1).values

    # [N_rays, N_samples + n_samples, 3]
    points = rays_o[..., None, :] + rays_d[..., None, :] * merged_sorted[..., None]
    return points, merged_sorted, fine_z
# if __name__ == "__main__":
# rays_o = torch.rand((25,3))
# rays_d = torch.rand((25,3))
# near = 2
# far = 6
# n_samples=20
# stratified_sampling(rays_o, rays_d, near, far, n_samples, device="cpu")
# print()