Training API
This chapter describes the training-related APIs.
landmark.train.nerf_trainer.NeRFTrainer
- class landmark.train.nerf_trainer.NeRFTrainer(config)
Base class for Neural Radiance Field (NeRF) trainers.
This class provides a common structure for training NeRF models, covering initialization, training, evaluation, and logging. It is meant to be subclassed: concrete implementations define the model architecture, the optimizer, and the training and evaluation procedures.
- config
Configuration object containing all settings for the training process.
- Type:
Config
- device
The device (CPU/GPU) on which the model will be trained.
- Type:
torch.device
- model
The neural network model to be trained.
- Type:
nn.Module
- module
An additional module, if any, used during training.
- Type:
nn.Module
- optimizer
The optimizer used for training the model.
- Type:
torch.optim.Optimizer
- data_mgr
Manages the dataset for training and evaluation.
- Type:
DatasetManager
- scene_mgr
Manages scene-specific configurations and data.
- Type:
SceneManager
- metrics
A dictionary to store various metrics computed during training.
- Type:
dict
- logfolder
Path to the folder where logs and outputs are saved.
- Type:
str
- optim_dir
Directory within logfolder for saving optimizer states.
- Type:
str
- save_folder
Directory for saving test images.
- Type:
str
- writer
TensorBoard writer for logging metrics.
- Type:
SummaryWriter
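One plausible way to fill the `metrics` dictionary and mirror the same values to the TensorBoard `writer` is sketched below. It is an illustration only: the helper name `record_metrics`, the list-of-values layout of `metrics`, and the `train/` tag prefix are hypothetical choices, not part of this API. The only real interface it relies on is `SummaryWriter.add_scalar(tag, scalar_value, global_step)` from `torch.utils.tensorboard`; to stay dependency-free, the helper accepts any object that has that method.

```python
from typing import Dict, List, Protocol


class ScalarWriter(Protocol):
    """Anything exposing SummaryWriter-style add_scalar(tag, value, global_step)."""

    def add_scalar(self, tag: str, scalar_value: float, global_step: int) -> None: ...


def record_metrics(metrics: Dict[str, List[float]], writer: ScalarWriter,
                   step: int, **values: float) -> None:
    """Hypothetical helper: append each value to metrics[name] and log it.

    Assumes metrics maps a metric name to the list of its values over time.
    """
    for name, value in values.items():
        metrics.setdefault(name, []).append(value)
        # The "train/" prefix is an arbitrary, assumed tag namespace.
        writer.add_scalar(f"train/{name}", value, global_step=step)


# Usage with TensorBoard (requires torch):
#   from torch.utils.tensorboard import SummaryWriter
#   writer = SummaryWriter(log_dir=logfolder)
#   record_metrics(trainer.metrics, writer, step, loss=loss.item())
```

Keeping the raw values in `metrics` as well as in TensorBoard lets a trainer compute summaries (for example, a running mean) without reading the event files back.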
- abstract check_args()
Checks the consistency and validity of the arguments provided in the config.
- abstract create_model()
Creates the model to be trained. This method must be implemented by subclasses.
- abstract create_optimizer()
Creates the optimizer for training the model. This method must be implemented by subclasses.
- abstract evaluation(*args, **kwargs)
Evaluates the model. This method must be implemented by subclasses.
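A minimal sketch of the subclassing pattern these abstract methods imply, written with Python's `abc` module. `TrainerBase` is a hypothetical, dependency-free stand-in for `NeRFTrainer`, not the real class: the order in which its constructor calls `check_args()`, `create_model()`, and `create_optimizer()` is an assumption, and the toy `ToyTrainer` returns plain dicts where a real subclass would return an `nn.Module` and a `torch.optim.Optimizer`.

```python
import abc
from typing import Any, Dict


class TrainerBase(abc.ABC):
    """Hypothetical stand-in mirroring the documented abstract interface."""

    def __init__(self, config: Dict[str, Any]) -> None:
        self.config = config
        # Assumed call order: validate config, then build model, then optimizer.
        self.check_args()
        self.model = self.create_model()
        self.optimizer = self.create_optimizer()

    @abc.abstractmethod
    def check_args(self) -> None: ...

    @abc.abstractmethod
    def create_model(self) -> Any: ...

    @abc.abstractmethod
    def create_optimizer(self) -> Any: ...

    @abc.abstractmethod
    def evaluation(self, *args: Any, **kwargs: Any) -> Any: ...


class ToyTrainer(TrainerBase):
    """Toy subclass; 'lr' is a hypothetical config key."""

    def check_args(self) -> None:
        # Reject configs whose learning rate is missing or non-positive.
        if self.config.get("lr", 0) <= 0:
            raise ValueError("config['lr'] must be positive")

    def create_model(self) -> Dict[str, float]:
        # A single scalar parameter stands in for an nn.Module.
        return {"w": 0.0}

    def create_optimizer(self) -> Dict[str, float]:
        # A dict of hyperparameters stands in for a torch.optim.Optimizer.
        return {"lr": self.config["lr"]}

    def evaluation(self, target: float) -> float:
        # Squared error between the toy parameter and a target value.
        return (self.model["w"] - target) ** 2


trainer = ToyTrainer({"lr": 0.1})
print(trainer.evaluation(2.0))  # 4.0
```

Because the four methods are declared abstract, instantiating the base directly (or a subclass that omits one of them) raises `TypeError`. In a real subclass, `create_optimizer()` would typically build something like `torch.optim.Adam(self.model.parameters(), lr=...)`.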