# Source code for landmark.nerf_components.ray_samplers.uniform_sampler
from typing import Union
import torch
from landmark.nerf_components.data import Rays, Samples
from .base_sampler import BaseSampler
# [docs]
class UniformSampler(BaseSampler):
    """Sample a fixed number of points along each ray.

    ``forward`` supports two placement strategies:

    * default (``sample_within_hull=False``): samples start at the distance
      where the ray enters the axis-aligned bounding box (clamped to
      ``near_far``) and advance in increments of ``step_size``;
    * ``sample_within_hull=True``: samples are spread linearly between
      ``near`` and the distance at which the ray reaches the z-coordinate of
      ``aabb[0]`` (or ``near_far[-1]`` for rays whose z-direction is >= 0).

    With ``random_sampling=True`` the sample positions are jittered,
    otherwise they are deterministic.

    Args:
        num_samples (int): default number of samples on each ray.
        step_size (float): spacing between consecutive samples in the default
            strategy.
        aabb (torch.Tensor): bounding box; ``aabb[0]`` is the lower corner and
            ``aabb[1]`` the upper corner.
        near_far (list): ``[near, far]`` distance bounds along the ray.
    """

    def __init__(self, num_samples: int, step_size: float, aabb: torch.Tensor, near_far: list, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_samples = num_samples
        self.step_size = step_size
        self.aabb = aabb
        self.near_far = near_far
        self.save_init_kwargs(locals())  # save for converting to fusion kernel

    def forward(
        self,
        rays: Union[Rays, torch.Tensor],
        num_samples: int = -1,
        random_sampling: bool = False,
        sample_within_hull: bool = False,
    ) -> (Samples, torch.Tensor):
        """Sample points along the rays.

        Args:
            rays (Rays | torch.Tensor): rays to sample. The last dimension
                holds the origin in columns 0-2, the direction in columns 3-5
                and, optionally, the camera index in column 6.
            num_samples (int): number of samples on each ray. If ``None`` or
                not positive (the default is ``-1``), ``self.num_samples`` is
                used.
            random_sampling (bool): jitter the sample positions. In the
                default strategy one random offset in ``[0, 1)`` steps is
                drawn per ray; in the hull strategy every sample is drawn
                uniformly inside its own bin.
            sample_within_hull (bool): choose the hull strategy described in
                the class docstring instead of the bounding-box entry one.

        Returns:
            (Samples): sampled points.
            (torch.Tensor): boolean mask that is ``True`` for the points lying
                inside the bounding box (the negation of the out-of-box mask).
        """
        rays_data = rays.data if isinstance(rays, Rays) else rays
        device = rays_data.device
        ray_origin = rays_data[..., :3]
        ray_dirs = rays_data[..., 3:6]
        # The camera-index column is optional: only read it when it exists so
        # that plain 6-column ray tensors do not raise an IndexError.
        camera_indice = rays_data[..., 6] if rays_data.shape[-1] > 6 else None

        if num_samples is None or num_samples <= 0:
            num_samples = self.num_samples

        near, far = self.near_far

        # Move the box to the rays' device once, before branching: both
        # strategies compare the sampled points against it further down.
        self.aabb = self.aabb.to(device)
        aabb = self.aabb

        if not sample_within_hull:
            # Replace exact zeros in the direction to avoid dividing by zero.
            vec = torch.where(ray_dirs == 0, torch.full_like(ray_dirs, 1e-6), ray_dirs)
            rate_a = (aabb[1] - ray_origin) / vec
            rate_b = (aabb[0] - ray_origin) / vec
            # Slab test: the ray is inside the box once it has crossed the
            # nearer plane on every axis, i.e. at the largest per-axis minimum.
            t_min = torch.minimum(rate_a, rate_b).amax(-1).clamp(min=near, max=far)

            rng = torch.arange(num_samples)[None].float()
            if random_sampling:
                # One shared random offset per ray, applied to all its samples.
                rng = rng.repeat(ray_dirs.shape[-2], 1)
                rng += torch.rand_like(rng[:, [0]])
            step = self.step_size * rng.to(device)
            z_vals = t_min[..., None] + step
        else:
            # Distance at which each ray reaches the z-coordinate of aabb[0].
            o_z = ray_origin[:, -1:] - aabb[0, 2].item()
            d_z = ray_dirs[:, -1:]
            far = -(o_z / d_z)
            # Rays whose z-direction is >= 0 never reach that plane in front of
            # the origin; fall back to the configured far bound for them. The
            # fill value takes `far`'s device and dtype, so this also works on
            # CPU and with non-float32 rays.
            far_bound = float(self.near_far[-1])
            far = torch.where(d_z >= 0, torch.full_like(far, far_bound), far)

            t_vals = torch.linspace(0.0, 1.0, steps=num_samples, device=device)
            z_vals = near * (1.0 - t_vals) + far * t_vals
            if random_sampling:
                # Stratified sampling: draw one uniform sample inside each bin
                # delimited by the midpoints of neighbouring z values.
                mids = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
                upper = torch.cat([mids, z_vals[..., -1:]], -1)
                lower = torch.cat([z_vals[..., :1], mids], -1)
                t_rand = torch.rand(z_vals.shape, device=device)
                z_vals = lower + (upper - lower) * t_rand

        rays_pts = ray_origin[..., None, :] + ray_dirs[..., None, :] * z_vals[..., None]

        if camera_indice is not None:
            # Give every sample the camera index of the ray it belongs to.
            rays_pts_idxs = camera_indice.unsqueeze(-1).repeat(1, rays_pts.shape[1]).type(torch.long)
        else:
            rays_pts_idxs = None

        samples_dirs = ray_dirs.clone().view(-1, 1, 3).expand(rays_pts.shape)

        is_rays = isinstance(rays, Rays)
        rank = rays.rank if is_rays else None
        group = rays.group if is_rays else None
        samples = Samples(
            xyz=rays_pts, dirs=samples_dirs, z_vals=z_vals, camera_idx=rays_pts_idxs, rank=rank, group=group
        )
        samples = samples.to(device)

        # A point is outside the box when any coordinate is below the lower
        # corner or above the upper corner.
        mask_outbbox = ((aabb[0] > rays_pts) | (rays_pts > aabb[1])).any(dim=-1)
        return samples, ~mask_outbbox