# Source code for landmark.nerf_components.model_components.fields.nerf_branch

import torch
import torch.nn.functional as F
from torch import nn


def positional_encoding(positions, freqs):
    """
    Applies sinusoidal positional encoding to the input positions.

    Every input channel is multiplied by the frequencies ``2**0 ... 2**(freqs - 1)``. The sines
    of all scaled channels come first along the last axis, followed by all their cosines.

    Args:
        positions (torch.Tensor): Input positions tensor of shape (..., D).
        freqs (int): Number of frequency bands.

    Returns:
        torch.Tensor: Encoded tensor of shape (..., 2 * D * F), where D is the size of the last
        dimension of the input and F is the number of frequency bands.
    """
    # Create the bands directly on the input's device rather than on CPU followed by a copy.
    freq_bands = (2 ** torch.arange(freqs, device=positions.device)).float()  # (F,)
    scaled = positions[..., None] * freq_bands  # (..., D, F)
    flat_shape = tuple(positions.shape[:-1]) + (freqs * positions.shape[-1],)
    pts = scaled.reshape(flat_shape)  # (..., DF), channel-major: [x*f0, x*f1, ..., y*f0, ...]
    return torch.cat([torch.sin(pts), torch.cos(pts)], dim=-1)  # (..., 2DF)


def raw2alpha(raw, dists, act_fn=F.relu):
    """
    Converts raw density predictions into alpha (opacity) values.

    The result is ``1 - exp(-act_fn(raw) * dists)``.

    Args:
        raw (torch.Tensor): The raw values, passed through ``act_fn`` before use.
        dists (torch.Tensor): The distances that scale the activated values.
        act_fn (function): The activation applied to ``raw``. Default is F.relu.

    Returns:
        torch.Tensor: The alpha values.
    """
    density = act_fn(raw)
    attenuation = torch.exp(-density * dists)
    return 1.0 - attenuation


def raw2outputs(raw, dists=1):
    """
    Converts raw per-sample predictions into volume-rendering outputs.

    Works for any number of leading batch dimensions; the samples of one ray are expected along
    the second-to-last axis of ``raw``.

    Args:
        raw (torch.Tensor): The raw values of shape (..., N_samples, C). Channels 0-2 are turned
            into colours with a sigmoid and channel 3 is used as the raw density.
        dists (torch.Tensor or float): The distances multiplied with the activated density;
            must broadcast against shape (..., N_samples). Defaults to 1.

    Returns:
        dict: A dictionary with the keys
            "rgb_map": weighted colour sum over the samples, shape (..., 3),
            "weights": per-sample compositing weights, shape (..., N_samples),
            "rgb": per-sample colours, shape (..., N_samples, 3),
            "alpha": per-sample alpha values, shape (..., N_samples).
    """
    rgb = torch.sigmoid(raw[..., :3])  # (..., N_samples, 3)
    alpha = raw2alpha(raw[..., 3], dists)  # (..., N_samples)

    # Exclusive cumulative product of (1 - alpha): prepend a column of ones, drop the last entry.
    # ones_like keeps dtype/device of alpha and works for any number of leading dimensions.
    # The 1e-10 keeps the running product from collapsing to exactly zero.
    leading_ones = torch.ones_like(alpha[..., :1])
    transmittance = torch.cumprod(
        torch.cat([leading_ones, 1.0 - alpha + 1e-10], -1),
        -1,
    )[..., :-1]
    weights = alpha * transmittance

    rgb_map = torch.sum(weights[..., None] * rgb, -2)  # (..., 3)
    return {
        "rgb_map": rgb_map,
        "weights": weights,
        "rgb": rgb,
        "alpha": alpha,
    }


class NeRF(torch.nn.Module):
    """
    Class for NeRF branch: an MLP that maps encoded points and grid features to raw RGB + density.

    Args:
        args (object): The arguments object. The attributes ``nerf_freq``, ``nerf_D``,
            ``nerf_D_a``, ``nerf_W`` and ``encode_app`` are read from it.
        den_n_comp (int): The number of components for the density field.
        app_n_comp (int): The number of components for the appearance field.
        nfreqs (int, optional): The number of frequencies for positional encoding of the points.
            Defaults to None, in which case ``args.nerf_freq`` is used.

    Attributes:
        num_freq (int): The number of frequencies for positional encoding of the points.
        den_n_comp (int): The number of components for the density field.
        app_n_comp (int): The number of components for the appearance field.
        skips (list): Indices of the point layers after which the input is concatenated again.
        use_viewdirs (bool): Flag indicating whether to use view directions (always True here).
        D (int): Number of linear layers in the point (first) stage.
        D_a (int): Number of linear layers in the view (second) stage.
        W (int): Hidden width of the point stage; the view stage uses ``W // 2``.
        encode_app (bool): Flag indicating whether ``app_latent`` is fed to the view stage.

    Methods:
        init_module_v0: Initializes the module.
        forward: Forward pass of the NeRF branch.
    """

    # Number of frequency bands used for the positional encoding of the view direction.
    VIEW_FREQS = 4
    # Number of channels ``app_latent`` must provide when ``encode_app`` is enabled.
    APP_LATENT_DIM = 48

    def __init__(self, args, den_n_comp, app_n_comp, nfreqs=None):  # pure frequency embed
        super().__init__()
        self.num_freq = args.nerf_freq if nfreqs is None else nfreqs
        self.den_n_comp = den_n_comp
        self.app_n_comp = app_n_comp
        self.skips = [4]
        self.use_viewdirs = True
        self.D = args.nerf_D
        self.D_a = args.nerf_D_a
        self.W = args.nerf_W
        self.encode_app = args.encode_app
        self.init_module_v0()

    def init_module_v0(self):
        """
        Initializes the module.

        Creates ``pts_linears`` and ``views_linears`` plus the output heads. Submodule names and
        creation order are part of the contract (state_dict keys, seeded initialization).
        """
        W, D, D_a = self.W, self.D, self.D_a

        # Point stage input: xyz + sin/cos encoding of xyz + density features.
        input_ch = 3 + 2 * 3 * self.num_freq + self.den_n_comp
        # View stage input: direction + sin/cos encoding of direction + appearance features.
        input_ch_views = 3 + 2 * 3 * self.VIEW_FREQS + self.app_n_comp

        # The layer that FOLLOWS a skip index receives the hidden state concatenated with the input.
        self.pts_linears = torch.nn.ModuleList(
            [torch.nn.Linear(input_ch, W)]
            + [
                torch.nn.Linear(W + input_ch, W) if i in self.skips else torch.nn.Linear(W, W)
                for i in range(D - 1)
            ]
        )

        views_in = input_ch_views + W
        if self.encode_app:
            views_in += self.APP_LATENT_DIM
        self.views_linears = nn.ModuleList(
            [torch.nn.Linear(views_in, W // 2)]
            + [torch.nn.Linear(W // 2, W // 2) for _ in range(D_a - 1)]
        )

        if self.use_viewdirs:
            self.feature_linear = nn.Linear(W, W)
            self.alpha_linear = nn.Linear(W, 1)
            self.rgb_linear = nn.Linear(W // 2, 3)
        else:
            self.output_linear = nn.Linear(W, 4)

    def forward(self, pts, viewdir=None, den_feats=None, app_feats=None, app_latent=None):
        """
        Forward pass of the NeRF branch.

        Args:
            pts (torch.Tensor): The input points tensor with 3 channels in the last dimension.
            viewdir (torch.Tensor, optional): The view direction tensor; reshaped to (-1, 3).
                Required while ``use_viewdirs`` is True, despite the None default.
            den_feats (torch.Tensor, optional): The density field features with ``den_n_comp``
                channels. Concatenated unconditionally, so it must be provided.
            app_feats (torch.Tensor, optional): The appearance field features with ``app_n_comp``
                channels. Must be provided when view directions are used.
            app_latent (torch.Tensor, optional): Extra appearance input with ``APP_LATENT_DIM``
                channels. Only used, and then required, when ``encode_app`` is True.

        Returns:
            torch.Tensor: The raw output tensor with 4 channels (RGB followed by density); no
            output activation is applied.

        Raises:
            ValueError: If ``use_viewdirs`` is True and ``viewdir`` is None.
        """
        if self.use_viewdirs and viewdir is None:
            raise ValueError("viewdir must be provided when use_viewdirs is True")

        pts_pe = positional_encoding(pts, self.num_freq)
        inputs_flat = torch.cat([pts, pts_pe, den_feats], -1)

        h = inputs_flat
        for i, layer in enumerate(self.pts_linears):
            h = F.relu(layer(h))
            if i in self.skips:
                # Skip connection: re-inject the raw input for the next layer.
                h = torch.cat([inputs_flat, h], -1)

        if not self.use_viewdirs:
            return self.output_linear(h)

        alpha = self.alpha_linear(h)
        feature = self.feature_linear(h)

        viewdir = viewdir.reshape(-1, 3)  # use reshape when sample multiple pts
        viewdir_pe = positional_encoding(viewdir, self.VIEW_FREQS)
        view_concat = [feature, viewdir, viewdir_pe, app_feats]
        if self.encode_app:
            view_concat.append(app_latent)

        h = torch.cat(view_concat, -1)
        for layer in self.views_linears:
            h = F.relu(layer(h))

        rgb = self.rgb_linear(h)
        return torch.cat([rgb, alpha], -1)