Training API

This chapter describes the training-related APIs.

landmark.train.nerf_trainer.NeRFTrainer

class landmark.train.nerf_trainer.NeRFTrainer(config)

Base class for Neural Radiance Field (NeRF) trainer.

This class provides a common structure for training NeRF models, handling initialization, training, evaluation, and logging. It is meant to be subclassed by concrete implementations that define the model architecture, the optimizer, and the training and evaluation procedures.
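The extension pattern described above can be sketched with Python's abc module. This is a minimal, self-contained illustration and not the landmark implementation: TrainerSketch and ToyTrainer are hypothetical names, the toy "model" is a single scalar fitted with plain gradient descent instead of a NeRF network, and the order in which __init__ calls the hooks is an assumption.

```python
from abc import ABC, abstractmethod


class TrainerSketch(ABC):
    """Hypothetical stand-in for a NeRFTrainer-style base class (not the landmark API)."""

    def __init__(self, config: dict):
        self.config = config
        self.metrics: dict = {}
        # Assumed hook order: validate first, then build model and optimizer.
        self.check_args()
        self.model = self.create_model()
        self.optimizer = self.create_optimizer()

    @abstractmethod
    def check_args(self): ...

    @abstractmethod
    def create_model(self): ...

    @abstractmethod
    def create_optimizer(self): ...

    @abstractmethod
    def train(self, *args, **kwargs): ...

    @abstractmethod
    def evaluation(self, *args, **kwargs): ...

    def eval(self):
        # Concrete driver: delegate to the subclass and record what it returns.
        self.metrics.update(self.evaluation())
        return self.metrics


class ToyTrainer(TrainerSketch):
    """Fits y = w * x with gradient descent; a placeholder for a real NeRF model."""

    def check_args(self):
        if self.config.get("lr", 0) <= 0:
            raise ValueError("lr must be positive")

    def create_model(self):
        return {"w": 0.0}

    def create_optimizer(self):
        return {"lr": self.config["lr"]}

    def train(self, steps=100):
        data = [(x, 2.0 * x) for x in range(1, 5)]
        for _ in range(steps):
            # Gradient of the mean squared error with respect to w.
            grad = sum(2 * (self.model["w"] * x - y) * x for x, y in data) / len(data)
            self.model["w"] -= self.optimizer["lr"] * grad

    def evaluation(self):
        return {"w": self.model["w"]}


if __name__ == "__main__":
    trainer = ToyTrainer({"lr": 0.01})
    trainer.train(steps=100)
    print(trainer.eval())  # w should be close to 2.0
```

Instantiating the base class directly raises TypeError because its abstract methods are unimplemented, which is how the "must be implemented by subclasses" contract is enforced in plain Python.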

config

Configuration object containing all settings for the training process.

Type: Config

device

The device (CPU/GPU) on which the model will be trained.

Type: torch.device

model

The neural network model to be trained.

Type: nn.Module

module

An additional module, if any, used during training.

Type: nn.Module

optimizer

The optimizer used for training the model.

Type: torch.optim.Optimizer

data_mgr

Manages the dataset for training and evaluation.

Type: DatasetManager

scene_mgr

Manages scene-specific configurations and data.

Type: SceneManager

metrics

A dictionary to store various metrics computed during training.

Type: dict

logfolder

Path to the folder where logs and outputs are saved.

Type: str

optim_dir

Directory within logfolder for saving optimizer states.

Type: str

save_folder

Directory for saving test images.

Type: str

writer

TensorBoard writer for logging metrics.

Type: SummaryWriter

abstract check_args()

Checks the consistency and validity of the arguments provided in the config. This method must be implemented by subclasses.
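As an illustration of what check_args() might verify, the sketch below separates validity checks (each value in range on its own) from consistency checks (settings that must agree with each other). HypotheticalConfig and all of its fields are invented for this example and do not reflect landmark's actual Config schema.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class HypotheticalConfig:
    # Field names are illustrative only, not landmark's Config.
    n_iters: int = 30000
    batch_size: int = 4096
    lr: float = 0.02
    render_only: bool = False
    ckpt: Optional[str] = None


def check_args(config: HypotheticalConfig) -> None:
    """Raise ValueError on invalid or mutually inconsistent settings."""
    # Validity: each value must be in range on its own.
    if config.n_iters <= 0:
        raise ValueError(f"n_iters must be positive, got {config.n_iters}")
    if config.batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {config.batch_size}")
    if config.lr <= 0:
        raise ValueError(f"lr must be positive, got {config.lr}")
    # Consistency: some settings only make sense together.
    if config.render_only and config.ckpt is None:
        raise ValueError("render_only=True requires a checkpoint path in ckpt")
```

Failing fast here, before the model and optimizer are built, avoids discovering a bad setting only after expensive setup.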

abstract create_model()

Creates the model to be trained. This method must be implemented by subclasses.

abstract create_optimizer()

Creates the optimizer for training the model. This method must be implemented by subclasses.

eval()

Evaluates the model on the test dataset, computes metrics, and saves the results.
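The text does not say which metrics eval() computes. PSNR is a common choice for NeRF, so here is a minimal sketch assuming images flattened to lists of floats in [0, 1]. The helpers psnr and evaluate_views are hypothetical, and averaging per-view PSNR is one convention among several.

```python
import math
from typing import Dict, List, Sequence


def psnr(pred: Sequence[float], gt: Sequence[float], max_val: float = 1.0) -> float:
    """Peak signal-to-noise ratio between two flattened images in [0, max_val]."""
    if len(pred) != len(gt) or not pred:
        raise ValueError("images must be non-empty and the same size")
    mse = sum((p - g) ** 2 for p, g in zip(pred, gt)) / len(pred)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)


def evaluate_views(preds: List[List[float]], gts: List[List[float]]) -> Dict[str, float]:
    """Average PSNR over all test views, returned as a metrics-style dict."""
    scores = [psnr(p, g) for p, g in zip(preds, gts)]
    return {"psnr": sum(scores) / len(scores)}
```

A uniform error of 0.1 per pixel gives an MSE of 0.01 and therefore a PSNR of 20 dB; the returned dict has the same shape as the metrics attribute described above.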

abstract evaluation(*args, **kwargs)

Evaluates the model. This method must be implemented by subclasses.

init_train_env()

Initializes the training environment, including setting up distributed training and logging directories.
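A minimal sketch of what init_train_env() might do, assuming torchrun-style RANK and WORLD_SIZE environment variables and single-process defaults when they are absent. The function signature and the subfolder names optim and imgs_test are hypothetical; the text above only states that optim_dir lives inside logfolder and that save_folder holds test images.

```python
import os
from pathlib import Path
from typing import Dict


def init_train_env(base_dir: str, exp_name: str) -> Dict[str, object]:
    """Hypothetical sketch: resolve process rank and create log/output directories."""
    # torchrun exports RANK and WORLD_SIZE; fall back to a single process otherwise.
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))

    logfolder = Path(base_dir) / exp_name
    optim_dir = logfolder / "optim"        # optimizer states, inside logfolder
    save_folder = logfolder / "imgs_test"  # rendered test images
    # Only rank 0 touches the filesystem so that processes do not race.
    if rank == 0:
        for d in (logfolder, optim_dir, save_folder):
            d.mkdir(parents=True, exist_ok=True)
    return {
        "rank": rank,
        "world_size": world_size,
        "logfolder": str(logfolder),
        "optim_dir": str(optim_dir),
        "save_folder": str(save_folder),
    }
```

In a real multi-process run, non-zero ranks would typically wait on a barrier (for example torch.distributed.barrier()) before writing into these directories.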

abstract train(*args, **kwargs)

Trains the model. This method must be implemented by subclasses.