landmark.nerf_components.configs.train_config 源代码

from typing import List, Optional, Union

import torch
from typing_extensions import Literal

from .base_configclass import ConfigClass


[文档] class TrainConfig(ConfigClass): """ Train Configuration class for setting up training parameters. Args: model_name (str): Name of the model to be trained. expname (str): Experiment name. start_iters (int, optional): Starting iteration number. Defaults to 0. n_iters (int): Total number of iterations. batch_size (int, optional): Batch size for training. Defaults to 4096. progress_refresh_rate (int, optional): Refresh rate for progress updates. Defaults to 10. wandb (bool, optional): Flag to indicate if Weights & Biases logging is enabled. Defaults to False. tensorboard (bool, optional): Flag to indicate if TensorBoard logging is enabled. Defaults to False. basedir (str, optional): Base directory for logging. Defaults to "./log". add_timestamp (bool, optional): Flag to indicate if timestamp should be added to log filenames. Defaults to False. random_seed (int, optional): Random seed. Defaults to 20211202. debug (bool, optional): Flag to indicate if debug mode is enabled. Defaults to False. N_vis (int, optional): Number of images to visualize. Defaults to 5. vis_every (int, optional): Iteration interval for visualization. Defaults to 10000. skip_save_imgs (bool, optional): Flag to indicate if saving images is skipped. Defaults to False. optim_dir (Optional[str], optional): Directory for optimization checkpoints. Defaults to None. ckpt (Optional[str], optional): Path to a specific checkpoint. Defaults to None. ckpt_type (Literal["full", "sub"], optional): Type of checkpoint. Defaults to "full". kwargs (Optional[str], optional): Additional arguments for checkpointing. Defaults to None. is_train (bool, optional): Flag to indicate if the configuration is for training. Defaults to True. device (Union[str, torch.device], optional): Device to use for training. Defaults to torch.device("cuda"). DDP (bool, optional): Flag to indicate if Distributed Data Parallel is used. Defaults to False. channel_parallel (bool, optional): Flag to indicate if channel parallelism is used. Defaults to False. branch_parallel (bool, optional): Flag to indicate if branch parallelism is used. Defaults to False. plane_division (Optional[list], optional): Division of planes for parallelism. Defaults to [1, 1]. channel_parallel_size (Optional[int], optional): Size for channel parallelism. Defaults to None. model_parallel_and_DDP (bool, optional): Flag to indicate if model parallelism and DDP are combined. Defaults to False. test_iterations (List[int], optional): Iterations for testing. Defaults to []. save_iterations (List[int], optional): Iterations for saving checkpoints. Defaults to []. start_checkpoint (Optional[str], optional): Path to the starting checkpoint. Defaults to None. print_timestamp (bool, optional): Flag to indicate if timestamps should be printed. Defaults to True. """ model_name: str expname: str start_iters: int = 0 n_iters: int batch_size: int = 4096 progress_refresh_rate: int = 10 wandb: bool = False tensorboard: bool = False basedir: str = "./log" add_timestamp: bool = False random_seed: int = 20211202 debug: bool = False N_vis: int = 5 vis_every: int = 10000 skip_save_imgs: bool = False optim_dir: Optional[str] = None ckpt: Optional[str] = None ckpt_type: Literal["full", "sub"] = "full" kwargs: Optional[str] = None is_train: bool = True device: Union[str, torch.device] = torch.device("cuda") DDP: bool = False channel_parallel: bool = False branch_parallel: bool = False plane_division: Optional[list] = [1, 1] channel_parallel_size: Optional[int] = None model_parallel_and_DDP: bool = False # TODO for 3d gaussian temporary (frank) # detect_anomaly: bool = False test_iterations: List[int] = [] save_iterations: List[int] = [] start_checkpoint: Optional[str] = None print_timestamp: bool = True
[文档] def check_args(self): """Performs a sanity check on the training configuration.""" pass