# Source code for landmark.nerf_components.model_components.fields.alpha_mask

import torch
import torch.nn
import torch.nn.functional as F


class AlphaGridMask(torch.nn.Module):
    """
    A class for the alpha grid mask.

    Holds a dense alpha volume spanning an axis-aligned bounding box. It answers
    per-point queries by trilinearly sampling that volume with
    ``F.grid_sample``.

    Note:
        The tensors stored here are plain attributes, not registered buffers or
        parameters, so ``Module.to(...)`` does not move them. Call
        :meth:`update_device` to change devices.

    Args:
        device (str): The device to use.
        aabb (torch.Tensor): The axis-aligned bounding box.
        alpha_volume (torch.Tensor): The alpha volume.
    """

    def __init__(self, device, aabb, alpha_volume):
        """
        Initializes an AlphaGridMask object.

        Args:
            device (torch.device | str): The device to be used for computations.
            aabb (torch.Tensor): The axis-aligned bounding box of the volume.
                ``aabb[0]`` is the minimum corner and ``aabb[1]`` the maximum
                corner (presumably shape (2, 3) -- TODO confirm with callers).
            alpha_volume (torch.Tensor): The alpha volume. Only its last three
                dimensions are used, interpreted by ``grid_sample`` as
                (D, H, W). The first point coordinate indexes ``shape[-1]``.
        """
        super().__init__()
        self.device = device
        self.aabb = aabb.to(self.device)
        self.aabbSize = self.aabb[1] - self.aabb[0]
        # 2 / extent: scales an offset from aabb[0] into [0, 2] (see normalize_coord).
        self.invgridSize = 1.0 / self.aabbSize * 2
        # grid_sample needs a 5-D (N, C, D, H, W) input. The volume is moved to
        # the same device as the AABB so sampling never mixes devices.
        self.alpha_volume = alpha_volume.view(1, 1, *alpha_volume.shape[-3:]).to(self.device)
        # Resolution listed in point-coordinate order: grid_sample maps the
        # first grid component to W (shape[-1]) and the last to D (shape[-3]).
        self.gridSize = torch.LongTensor(
            [alpha_volume.shape[-1], alpha_volume.shape[-2], alpha_volume.shape[-3]]
        ).to(self.device)

    def update_device(self, device):
        """
        Moves every tensor held by this AlphaGridMask to ``device``.

        Args:
            device (str): The new device to use.
        """
        self.device = device
        self.aabb = self.aabb.to(device)
        # aabbSize is public state too; keep it on the same device as the rest.
        self.aabbSize = self.aabbSize.to(device)
        self.invgridSize = self.invgridSize.to(device)
        self.gridSize = self.gridSize.to(device)
        self.alpha_volume = self.alpha_volume.to(device)

    def sample_alpha(self, xyz_sampled):
        """
        Samples the alpha values at the given points.

        Args:
            xyz_sampled (torch.Tensor): World-space points whose last dimension
                has size 3. Any leading dimensions are flattened.

        Returns:
            torch.Tensor: A 1-D tensor with one trilinearly interpolated alpha
            value per point. With grid_sample's default ``padding_mode="zeros"``,
            points outside the AABB sample as 0.
        """
        xyz_sampled = self.normalize_coord(xyz_sampled)
        # Points become a (1, N, 1, 1, 3) sampling grid. The output is
        # (1, 1, N, 1, 1), flattened back to (N,).
        grid = xyz_sampled.view(1, -1, 1, 1, 3)
        alpha_vals = F.grid_sample(self.alpha_volume, grid, align_corners=True).view(-1)
        return alpha_vals

    @staticmethod
    def _merge_with_mask(mask, alpha_mask):
        """Combine a boolean pre-mask with alpha results computed for its True entries."""
        # Start from the complement of the pre-mask. Then also invalidate the
        # pre-masked points whose alpha failed the threshold. ``~mask`` is a
        # new tensor, so the caller's mask is not modified.
        mask_invalid = ~mask
        mask_invalid[mask] |= ~alpha_mask
        return ~mask_invalid

    def forward(self, xyz_sampled, filter_thresh: torch.Tensor, mask: torch.Tensor = None):
        """
        Computes the forward pass of the AlphaGridMask.

        Args:
            xyz_sampled (torch.Tensor): The sampled coordinates.
            filter_thresh (torch.Tensor): Points whose alpha is strictly greater
                than this value are kept.
            mask (torch.Tensor, optional): Boolean pre-mask selecting which points
                of ``xyz_sampled`` to test. Defaults to None (test every point).

        Returns:
            torch.Tensor: Without ``mask``, a 1-D boolean tensor with one entry
            per point. With ``mask``, a boolean tensor shaped like ``mask`` that
            is True only where ``mask`` is True and the alpha passes the threshold.
        """
        if mask is not None:
            xyz_sampled = xyz_sampled[mask]
        alpha_mask = self.sample_alpha(xyz_sampled) > filter_thresh
        if mask is None:
            return alpha_mask
        return self._merge_with_mask(mask, alpha_mask)

    def generate_alpha_mask(self, xyz_sampled, filter_thresh=0, mask=None):
        """
        Generates the alpha mask.

        Args:
            xyz_sampled (torch.Tensor): The sampled coordinates.
            filter_thresh (int, optional): Points whose alpha is strictly greater
                than this value are kept. Defaults to 0.
            mask (torch.Tensor, optional): Boolean pre-mask selecting which points
                to test. Defaults to None.

        Returns:
            torch.Tensor or tuple: Without ``mask``, the 1-D boolean alpha mask
            over all points. With ``mask``, a tuple ``(alpha_mask, mask_valid)``.
            ``alpha_mask`` covers only the pre-masked points. ``mask_valid`` is
            shaped like ``mask`` (same as :meth:`forward`).
        """
        if mask is None:
            return self.sample_alpha(xyz_sampled) > filter_thresh
        alpha_mask = self.sample_alpha(xyz_sampled[mask]) > filter_thresh
        return alpha_mask, self._merge_with_mask(mask, alpha_mask)

    def normalize_coord(self, xyz_sampled):
        """
        Normalizes world-space points into grid_sample's [-1, 1] range.

        ``aabb[0]`` maps to -1 and ``aabb[1]`` maps to +1 on every axis.

        Args:
            xyz_sampled (torch.Tensor): The sampled coordinates.

        Returns:
            torch.Tensor: The normalized coordinates.
        """
        return (xyz_sampled - self.aabb[0]) * self.invgridSize - 1