# Source code for landmark.render.util.inference_module

import os
from typing import Any, Dict, Mapping, Optional, Union

from torch import nn

from landmark.nerf_components.scene import SceneManager


def preprocess(*args, **kwargs):
    """Pass-through preprocessing hook.

    Args:
        *args: positional inputs, returned untouched.
        **kwargs: keyword inputs, returned untouched.

    Returns:
        tuple: a two-element tuple ``(args, kwargs)`` where ``args`` is the
        tuple of positional inputs and ``kwargs`` is the dict of keyword inputs.
    """
    passthrough = (args, kwargs)
    return passthrough


def postprocess(*args, **kwargs):
    """Pass-through postprocessing hook.

    Collapses the inputs into the most natural return shape:

    * exactly one positional input and no keyword inputs -> that input itself;
    * any keyword inputs -> a tuple of the positional inputs followed by the
      keyword dict as the last element;
    * otherwise -> the tuple of positional inputs.

    Args:
        *args: positional outputs to forward.
        **kwargs: keyword outputs to forward.

    Returns:
        The single positional input, or a tuple as described above.
    """
    if len(args) == 1 and not kwargs:
        return args[0]
    if kwargs:
        # ``list.append`` returns None, so the result must be assembled
        # without relying on its return value.
        return (*args, kwargs)
    return args


def _bind_empty_preprocess_and_postprocess(model):
    """Attach the module-level pass-through hooks to ``model``.

    After the call, ``model.preprocess`` and ``model.postprocess`` refer to the
    module-level ``preprocess`` and ``postprocess`` functions.

    Args:
        model: object that receives the two hook attributes. It must not
            already expose ``preprocess`` or ``postprocess``.

    Raises:
        AssertionError: if ``model`` already has either attribute.
    """
    # Refuse to overwrite hooks the model already provides.
    assert not hasattr(model, "preprocess") and not hasattr(model, "postprocess")
    model.preprocess = preprocess
    model.postprocess = postprocess


class InferenceInterface(nn.Module):
    """
    The definition of the landmark inference interface.

    Provides pass-through ``preprocess`` / ``postprocess`` hooks and a
    ``forward`` that defers to the parent class.
    """

    def __init__(self) -> None:
        """Initialize the underlying ``nn.Module`` state."""
        super().__init__()

    def preprocess(self, *args, **kwargs):
        """Return the inputs unchanged as an ``(args, kwargs)`` pair."""
        return (args, kwargs)

    def postprocess(self, *args, **kwargs):
        """Return the inputs unchanged as an ``(args, kwargs)`` pair."""
        return (args, kwargs)

    def forward(self, *args, **kwargs):
        """Delegate to the parent class's ``forward``.

        NOTE(review): ``nn.Module`` ships no concrete ``forward``, so subclasses
        are presumably expected to override this method - confirm.
        """
        return super().forward(*args, **kwargs)


class InferenceModule(InferenceInterface):
    """Abstract inference module for inference engine adaptation.

    Wraps an internal model and forwards the engine-facing calls to it.

    Args:
        model (Optional[nn.Module]): internal model which is composed of the
            nerf components.
    """

    def __init__(self, model: Optional[nn.Module] = None) -> None:
        """Initializes the InferenceModule class.

        Args:
            model (Optional[nn.Module]): internal model which is composed of
                the nerf components. Defaults to None.
        """
        super().__init__()
        self.model = model
        # Filled in on first access of ``scene_manager``.
        self.scene_mgr = None

    @property
    def inter_model(self):
        """Returns the internal model."""
        return self.model

    @inter_model.setter
    def inter_model(self, model):
        """Set the internal model.

        Args:
            model (Optional[nn.Module]): internal model which is composed of
                the nerf components.
        """
        # NOTE(review): a previously cached ``scene_mgr`` is not reset here -
        # confirm whether that is intended when the model is swapped.
        self.model = model

    @property
    def scene_manager(self):
        """Returns the scene manager.

        On first access, the instance attributes of the internal model are
        scanned for a ``SceneManager`` instance and the match is cached.
        Returns None when no model is set or no scene manager is found.
        """
        # Guard against ``vars(None)``: the model is optional at construction.
        if self.scene_mgr is None and self.model is not None:
            for attr_name in vars(self.model):
                attr_value = getattr(self.model, attr_name)
                if isinstance(attr_value, SceneManager):
                    self.scene_mgr = attr_value
        return self.scene_mgr

    @property
    def merge_config(self):
        """Returns the offload merge config."""
        return self.model.get_merge_config()

    def channel_last(self):
        """Set the model in channel last mode."""
        self.model.channel_last()

    def get_state_dict_from_ckpt(
        self, file_path: Union[str, os.PathLike], map_location: Union[str, Dict[str, str]]
    ) -> Mapping[str, Any]:
        """Get state_dict from the given file path.

        Args:
            file_path (Union[str, os.PathLike]): a string or os.PathLike object
                containing a file name.
            map_location (Union[str, Dict[str, str]]): a string or a dict
                specifying how to remap storage locations.

        Returns:
            Mapping[str, Any]: the model's state dict
        """
        return self.model.get_state_dict_from_ckpt(file_path, map_location)

    def load_from_state_dict(self, state_dict: Mapping[str, Any], strict: bool = False):
        """
        Copies parameters and buffers from :attr:`state_dict` into this module
        and its descendants. If :attr:`strict` is ``True``, then the keys of
        :attr:`state_dict` must exactly match the keys returned by this
        module's :meth:`~torch.nn.Module.state_dict` function.

        The call is forwarded to the internal model.

        Args:
            state_dict (dict): a dict containing parameters and persistent
                buffers.
            strict (bool, optional): whether to strictly enforce that the keys
                in state_dict match the keys returned by this module's
                :meth:`~torch.nn.Module.state_dict` function.
        """
        self.model.load_from_state_dict(state_dict, strict)

    def preprocess(self, *args, **kwargs):
        """
        Define the engine's preprocessing logic. By default, it returns the
        inputs.

        Args:
            args (list): a list containing the input parameters of preprocess.
            kwargs (dict): a dict containing input parameters of preprocess.

        Returns:
            Tuple[list, dict]: containing the output of the model preprocess.
        """
        if self.model is None:
            return super().preprocess(*args, **kwargs)
        return self.model.preprocess(*args, **kwargs)

    def postprocess(self, *args, **kwargs):
        """
        Define the engine's postprocess logic. By default, it returns the
        inputs.

        Args:
            args (list): a list containing the input parameters of postprocess.
            kwargs (dict): a dict containing input parameters of postprocess.

        Returns:
            Tuple[list, dict]: containing the output of the model postprocess.
        """
        if self.model is None:
            return super().postprocess(*args, **kwargs)
        return self.model.postprocess(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """
        Define the engine's forward logic. By default, it uses the forward
        logic of internal model.

        Args:
            args (list): a list containing the input parameters of forward.
            kwargs (dict): a dict containing input parameters of forward.

        Returns:
            The output of calling the internal model with the given inputs.

        Raises:
            AssertionError: if no internal model has been set.
        """
        assert self.model is not None
        return self.model(*args, **kwargs)