# Source code for landmark.nerf_components.model_components.fields.embedding

import warnings
from typing import Optional, Union

import torch
from torch import nn


class AppearanceEmbedding(nn.Module):
    """Learnable appearance embedding: one latent vector per image index."""

    def __init__(self, n_imgs: int, n_component: int, device: torch.device) -> None:
        """Initialize the AppearanceEmbedding module.

        Args:
            n_imgs (int): Number of images, i.e. number of rows in the embedding table.
            n_component (int): Dimensionality of each embedding vector.
            device (torch.device): Device on which the embedding table is created.
        """
        super().__init__()
        # Kept for backward compatibility; forward() uses the weight's device instead,
        # because this attribute is not updated when the module is moved with .to().
        self.device = device
        self.embedding = torch.nn.Embedding(num_embeddings=n_imgs, embedding_dim=n_component, device=device)

    def forward(
        self,
        idxs: Optional[torch.Tensor] = None,
        xyz_sampled: Optional[torch.Tensor] = None,
        app_code: Optional[Union[int, torch.Tensor]] = None,
    ) -> Optional[torch.Tensor]:
        """Compute the appearance latent vector(s).

        If ``idxs`` is given, it takes priority over ``app_code``.

        Args:
            idxs (torch.Tensor, optional): Integer indices into the embedding table
                (image indices in ``[0, n_imgs)``).
            xyz_sampled (torch.Tensor, optional): Sampled points. Only its shape is used:
                when ``app_code`` is given, one latent vector is produced per entry of
                ``xyz_sampled.shape[:-1]``.
            app_code (int or torch.Tensor, optional): A single appearance code (image index)
                to broadcast to every sampled point.

        Returns:
            torch.Tensor or None: Latent vectors of shape ``idxs.shape + (n_component,)``,
            or ``xyz_sampled.shape[:-1] + (n_component,)`` when ``app_code`` is used.
            Returns ``None`` (and emits a warning) when neither ``idxs`` nor ``app_code``
            is given.

        Raises:
            ValueError: If ``app_code`` is given without ``idxs`` and ``xyz_sampled`` is None.
        """
        if idxs is not None:
            return self.embedding(idxs)

        if app_code is not None:
            if xyz_sampled is None:
                raise ValueError(
                    "xyz_sampled is required when app_code is given, to determine the output shape."
                )
            # Follow the embedding weights so this keeps working after module.to(...).
            device = self.embedding.weight.device
            # Accept both a Python int and a tensor; the original `.long()` call failed on ints.
            code = torch.as_tensor(app_code, device=device).long()
            fake_idxs = torch.ones(xyz_sampled.shape[:-1], dtype=torch.long, device=device)
            fake_idxs *= code
            return self.embedding(fake_idxs)

        warnings.warn(
            "AppearanceEmbedding.forward called without idxs or app_code; returning None.",
            stacklevel=2,
        )
        return None