# Source code for landmark.nerf_components.model_components.fields.anchor_decoder

from typing import Optional

import torch
from einops import repeat
from torch import nn

from landmark.nerf_components.model.base_module import BaseModule


class AnchorDecoder(BaseModule):
    """Decode anchors into neural Gaussians.

    Every visible anchor is expanded into ``n_offsets`` Gaussians. Their
    opacity, color and scale/rotation are predicted by small MLPs conditioned
    on the anchor feature, the normalized direction from the camera center to
    the anchor and, optionally, the camera-to-anchor distance, the anchor level
    and an appearance code.

    Args:
        use_feat_bank (bool): Whether to use feature bank.
        add_opacity_dist (bool): Whether the opacity MLP also receives the
            camera-to-anchor distance.
        add_color_dist (bool): Whether the color MLP also receives the
            camera-to-anchor distance.
        add_cov_dist (bool): Whether the covariance MLP also receives the
            camera-to-anchor distance.
        view_dim (int): Dimension of the view.
        feat_dim (int): Dimension of the features.
        n_offsets (int): Number of offsets (Gaussians) per anchor.
        appearance_dim (int, optional): Dimension of the appearance code. Defaults to 0.
        add_level (bool, optional): Whether to add level. Defaults to False.
        dist2level (str, optional): Distance-to-level mode; ``"progressive"``
            enables progressive opacity scaling in ``forward``. Defaults to None.

    Attributes:
        use_feat_bank (bool): Whether to use feature bank.
        add_opacity_dist (bool): Whether to add opacity distribution.
        add_color_dist (bool): Whether to add color distribution.
        add_cov_dist (bool): Whether to add covariance distribution.
        add_level (bool): Whether to add level.
        dist2level (str): Distance to level.
        view_dim (int): Dimension of the view.
        feat_dim (int): Dimension of the features.
        n_offsets (int): Number of offsets.
        appearance_dim (int): Dimension of the appearance code.
        mlp_feature_bank (nn.Sequential): MLP for feature bank (only when ``use_feat_bank``).
        mlp_opacity (nn.Sequential): MLP for opacity (``n_offsets`` outputs, tanh).
        mlp_cov (nn.Sequential): MLP for covariance (``7 * n_offsets`` outputs:
            3 scale + 4 rotation per offset).
        mlp_color (nn.Sequential): MLP for color (``3 * n_offsets`` outputs, sigmoid).

    Methods:
        forward(camera_center, anchor, anchor_feat, offset, scaling, visible_mask=None,
            level=None, prog_ratio=None, transition_mask=None, appearance_code=None):
            Forward pass of the Anchor Decoder.
    """

    def __init__(
        self,
        use_feat_bank: bool,
        add_opacity_dist: bool,
        add_color_dist: bool,
        add_cov_dist: bool,
        view_dim: int,
        feat_dim: int,
        n_offsets: int,
        appearance_dim: int = 0,
        add_level: bool = False,
        dist2level: Optional[str] = None,
    ):
        super().__init__()
        self.use_feat_bank = use_feat_bank
        self.add_opacity_dist = add_opacity_dist
        self.add_color_dist = add_color_dist
        self.add_cov_dist = add_cov_dist
        self.add_level = add_level
        self.dist2level = dist2level
        self.view_dim = view_dim
        self.feat_dim = feat_dim
        self.n_offsets = n_offsets
        self.appearance_dim = appearance_dim

        # Each optional conditioning signal contributes exactly one input column.
        # NOTE: do not add new local variables here -- save_init_kwargs(locals())
        # below captures every local of this method.
        opacity_dist_dim = 1 if self.add_opacity_dist else 0
        cov_dist_dim = 1 if self.add_cov_dist else 0
        color_dist_dim = 1 if self.add_color_dist else 0
        level_dim = 1 if self.add_level else 0

        if self.use_feat_bank:
            # Predicts 3 softmax weights used to blend multi-resolution slices of the feature.
            self.mlp_feature_bank = nn.Sequential(
                nn.Linear(self.view_dim + level_dim, self.feat_dim),
                nn.ReLU(True),
                nn.Linear(self.feat_dim, 3),
                nn.Softmax(dim=1),
            ).cuda()

        self.mlp_opacity = nn.Sequential(
            nn.Linear(self.feat_dim + self.view_dim + opacity_dist_dim + level_dim, self.feat_dim),
            nn.ReLU(True),
            nn.Linear(self.feat_dim, self.n_offsets),
            nn.Tanh(),
        ).cuda()

        self.mlp_cov = nn.Sequential(
            nn.Linear(self.feat_dim + self.view_dim + cov_dist_dim + level_dim, self.feat_dim),
            nn.ReLU(True),
            nn.Linear(self.feat_dim, 7 * self.n_offsets),
        ).cuda()

        self.mlp_color = nn.Sequential(
            nn.Linear(self.feat_dim + self.view_dim + color_dist_dim + level_dim + self.appearance_dim, self.feat_dim),
            nn.ReLU(True),
            nn.Linear(self.feat_dim, 3 * self.n_offsets),
            nn.Sigmoid(),
        ).cuda()

        self.save_init_kwargs(locals())  # save for converting to fusion kernel

    def forward(
        self,
        camera_center: torch.Tensor,
        anchor: torch.Tensor,
        anchor_feat: torch.Tensor,
        offset: torch.Tensor,
        scaling: torch.Tensor,
        visible_mask: Optional[torch.Tensor] = None,
        level: Optional[torch.Tensor] = None,
        prog_ratio: Optional[torch.Tensor] = None,
        transition_mask: Optional[torch.Tensor] = None,
        appearance_code: Optional[torch.Tensor] = None,
    ):
        """Forward pass of the Anchor Decoder.

        Args:
            camera_center (torch.Tensor): Camera center.
            anchor (torch.Tensor): Anchor positions, one row of 3 coordinates per anchor.
            anchor_feat (torch.Tensor): Anchor features.
            offset (torch.Tensor): Per-anchor offsets; viewed as rows of 3.
            scaling (torch.Tensor): Per-anchor scaling with 6 channels: the first 3
                scale the offsets, the last 3 scale the predicted Gaussian scales.
            visible_mask (torch.Tensor, optional): Boolean mask of anchors to decode.
                Defaults to None, meaning every anchor is decoded.
            level (torch.Tensor, optional): Per-anchor level, used only when
                ``add_level`` is set. Defaults to None.
            prog_ratio (torch.Tensor, optional): Progress ratio, used only when
                ``dist2level == "progressive"``. Defaults to None.
            transition_mask (torch.Tensor, optional): Transition mask; anchors
                outside it get a progress ratio of 1. Defaults to None.
            appearance_code (torch.Tensor, optional): Appearance code; its first
                entry is broadcast to every visible anchor. Defaults to None.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
            Optional[torch.Tensor], Optional[torch.Tensor]]:
                - xyz (torch.Tensor): XYZ coordinates of the kept Gaussians.
                - color (torch.Tensor): Color.
                - opacity (torch.Tensor): Opacity (only entries with positive neural opacity).
                - scaling (torch.Tensor): Scaling.
                - rot (torch.Tensor): Normalized rotation (4 components).
                - neural_opacity (torch.Tensor, optional): Unmasked neural opacity;
                  None when not training.
                - mask (torch.Tensor, optional): Mask of kept Gaussians; None when
                  not training.
        """
        # View-frustum filtering for acceleration: only decode visible anchors.
        if visible_mask is None:
            visible_mask = torch.ones(anchor.shape[0], dtype=torch.bool, device=anchor.device)
        anchor = anchor[visible_mask]
        feat = anchor_feat[visible_mask]
        grid_offsets = offset[visible_mask]
        grid_scaling = scaling[visible_mask]

        # Level conditioning (feature from Octree-GS) applies only when enabled AND provided.
        use_level = self.add_level and level is not None
        if use_level:
            level = level[visible_mask]

        # Direction from the camera center to each anchor, plus its length.
        ob_view = anchor - camera_center
        ob_dist = ob_view.norm(dim=1, keepdim=True)
        ob_view = ob_view / ob_dist

        if self.use_feat_bank:
            cat_view = torch.cat([ob_view, level], dim=1) if use_level else ob_view
            bank_weight = self.mlp_feature_bank(cat_view).unsqueeze(dim=1)  # [n, 1, 3]
            # Blend three resolutions of the feature: every 4th channel repeated x4,
            # every 2nd channel repeated x2, and the full feature.
            feat = feat.unsqueeze(dim=-1)
            feat = (
                feat[:, ::4, :1].repeat([1, 4, 1]) * bank_weight[:, :, :1]
                + feat[:, ::2, :1].repeat([1, 2, 1]) * bank_weight[:, :, 1:2]
                + feat[:, ::1, :1] * bank_weight[:, :, 2:]
            )
            feat = feat.squeeze(dim=-1)  # [n, c]

        # Build MLP inputs with and without the distance column. These must match the
        # in-features computed in __init__ (each optional signal adds one column).
        if use_level:
            cat_local_view = torch.cat([feat, ob_view, ob_dist, level], dim=1)  # [N, c+3+1+1]
            cat_local_view_wodist = torch.cat([feat, ob_view, level], dim=1)  # [N, c+3+1]
        else:
            cat_local_view = torch.cat([feat, ob_view, ob_dist], dim=1)  # [N, c+3+1]
            cat_local_view_wodist = torch.cat([feat, ob_view], dim=1)  # [N, c+3]

        opacity_input = cat_local_view if self.add_opacity_dist else cat_local_view_wodist
        neural_opacity = self.mlp_opacity(opacity_input)  # [N, k]

        if self.dist2level == "progressive" and prog_ratio is not None and transition_mask is not None:
            # Boolean indexing returns a copy, so the in-place write does not touch the caller's tensor.
            prog = prog_ratio[visible_mask]
            transition_mask = transition_mask[visible_mask]
            prog[~transition_mask] = 1.0
            neural_opacity = neural_opacity * prog

        # Keep only Gaussians with positive opacity (tanh output can be negative).
        neural_opacity = neural_opacity.reshape([-1, 1])
        mask = (neural_opacity > 0.0).view(-1)
        opacity = neural_opacity[mask]

        color_input = cat_local_view if self.add_color_dist else cat_local_view_wodist
        if self.appearance_dim > 0 and appearance_code is not None:
            # Broadcast the first appearance code over all N visible anchors. The [N, 1]
            # ones column makes the product 2-D ([N, appearance_dim]) so it can be
            # concatenated along dim=1.
            # NOTE(review): assumes appearance_code[0] has appearance_dim entries.
            appearance = torch.ones_like(cat_local_view[:, :1]) * appearance_code[0]
            color_input = torch.cat([color_input, appearance], dim=1)
        color = self.mlp_color(color_input)
        color = color.reshape([anchor.shape[0] * self.n_offsets, 3])  # [mask]

        cov_input = cat_local_view if self.add_cov_dist else cat_local_view_wodist
        scale_rot = self.mlp_cov(cov_input)
        scale_rot = scale_rot.reshape([anchor.shape[0] * self.n_offsets, 7])  # [mask]

        offsets = grid_offsets.view([-1, 3])  # [mask]

        # Repeat per-anchor values once per offset, pack everything, and apply the
        # opacity mask with a single gather.
        concatenated = torch.cat([grid_scaling, anchor], dim=-1)
        concatenated_repeated = repeat(concatenated, "n (c) -> (n k) (c)", k=self.n_offsets)
        concatenated_all = torch.cat([concatenated_repeated, color, scale_rot, offsets], dim=-1)
        masked = concatenated_all[mask]
        scaling_repeat, repeat_anchor, color, scale_rot, offsets = masked.split([6, 3, 3, 7, 3], dim=-1)

        # Post-process: last 3 scaling channels scale the sigmoid-activated predicted
        # scales; first 3 scale the offsets.
        scaling = scaling_repeat[:, 3:] * torch.sigmoid(scale_rot[:, :3])
        rot = torch.nn.functional.normalize(scale_rot[:, 3:7])
        offsets = offsets * scaling_repeat[:, :3]
        xyz = repeat_anchor + offsets

        if self.training:
            return xyz, color, opacity, scaling, rot, neural_opacity, mask
        return xyz, color, opacity, scaling, rot, None, None