landmark.nerf_components.configs.render_config 源代码

from typing import List, Optional, Union

import torch
from typing_extensions import Literal

from .base_configclass import ConfigClass
from .train_config import TrainConfig


[文档] class RenderConfig(ConfigClass): """Render Config Args: model_name (str): The name of the model. expname (str): The name of the experiment. batch_size (int): The size of the batch for processing. basedir (str): The base directory for logging and saving outputs. add_timestamp (bool): Whether to add a timestamp to the output filenames. random_seed (int): The seed for random number generation to ensure reproducibility. debug (bool): Enables debug mode if set to True. N_vis (int): Number of visualizations to generate. ckpt (Optional[str]): Path to a checkpoint file. ckpt_type (Literal["full", "sub"]): Type of checkpoint, either 'full' or 'sub'. kwargs (Optional[str]): Additional keyword arguments. is_train (bool): Flag indicating whether the configuration is for training. distributed (bool): Flag indicating whether distributed training is enabled. device (Union[str, torch.device]): The computing device to use ('cuda' or 'cpu'). DDP (bool): Flag indicating whether Distributed Data Parallel (DDP) is used. channel_parallel (bool): Flag indicating whether channel parallelism is used. branch_parallel (bool): Flag indicating whether branch parallelism is used. plane_division (Optional[list]): Specifies how to divide the input data into planes. channel_parallel_size (Optional[int]): Size of the channel parallelism. model_parallel_and_DDP (bool): Flag indicating whether model parallelism and DDP are combined. test_iterations (List[int]): List of iterations at which to perform testing. save_iterations (List[int]): List of iterations at which to save outputs. quiet (bool): If True, reduces the verbosity of the output. start_checkpoint (Optional[str]): Path to the starting checkpoint for resuming training. compute_extra_metrics (bool): Flag indicating whether to compute additional metrics. """ model_name: str = None expname: str = None batch_size: int = 4096 basedir: str = "./log" add_timestamp: bool = False random_seed: int = 20211202 debug: bool = False N_vis: int = 5 ckpt: Optional[str] = None ckpt_type: Literal["full", "sub"] = "full" kwargs: Optional[str] = None is_train: bool = False distributed: bool = False device: Union[str, torch.device] = torch.device("cuda") DDP: bool = False channel_parallel: bool = False branch_parallel: bool = False plane_division: Optional[list] = None channel_parallel_size: Optional[int] = None model_parallel_and_DDP: bool = False test_iterations: List[int] = [] save_iterations: List[int] = [] quiet: bool = False start_checkpoint: Optional[str] = None compute_extra_metrics: bool = False
[文档] def from_train_config(self, train_config: TrainConfig): """Updates the configuration based on a given TrainConfig instance. Args: train_config (TrainConfig): The training configuration to update this configuration with. """ # TODO add warning when having conflict settings (frank) self.model_name = train_config.model_name if self.model_name is None else self.model_name self.expname = train_config.expname if self.expname is None else self.expname self.basedir = train_config.basedir if self.basedir is None else self.basedir self.add_timestamp = train_config.add_timestamp if self.add_timestamp is None else self.add_timestamp self.random_seed = train_config.random_seed if self.random_seed is None else self.random_seed self.debug = train_config.debug if self.debug is None else self.debug self.N_vis = train_config.N_vis if self.N_vis is None else self.N_vis self.device = train_config.device if self.device is None else self.device self.quiet = train_config.quiet if self.quiet is None else self.quiet self.plane_division = train_config.plane_division if self.plane_division is None else self.plane_division return self
[文档] def check_args(self): """This method ensures that the configuration is valid for rendering""" assert not self.is_train, "RenderConfig is not for training"