landmark.nerf_components.configs.config_parser 源代码

import argparse
import sys
from importlib.machinery import SourceFileLoader
from pathlib import Path

from landmark.utils.config import Config


def get_default_parser():
    """Reads user command line and uses an argument parser to parse the input arguments.
    Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.

    Returns:
       Parser: Returns the parser with the default arguments, the user may add customized arguments into this parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, help="path to the config file")

    return parser


def parse_console_args(unknown_args):
    if unknown_args is None:
        return None
    parsed_args = {}
    i = 0
    while i < len(unknown_args):
        arg = unknown_args[i]
        if arg.startswith("--"):
            key = arg[2:]
            if i + 1 < len(unknown_args) and not unknown_args[i + 1].startswith("--"):
                value = unknown_args[i + 1]
                parsed_args[key] = value
                i += 1
            else:
                parsed_args[key] = True
        i += 1
    return parsed_args


# Copy from InternLM
[文档] class BaseConfig(Config): """This is a wrapper class for dict objects so that values of which can be accessed as attributes. Args: config (dict): The dict object to be wrapped. """
[文档] @staticmethod def from_file(filename: str, console_args=None, merge_configs=True): """Reads a python file and constructs a corresponding :class:`Config` object. Args: filename (str): Name of the file to construct the return object. Returns: :class:`Config`: A :class:`Config` object constructed with information in the file. Raises: AssertionError: Raises an AssertionError if the file does not exist, or the file is not .py file """ # check config path if isinstance(filename, str): filepath = Path(filename).absolute() elif isinstance(filename, Path): filepath = filename.absolute() assert filepath.exists(), f"{filename} is not found, please check your configuration path" # check extension extension = filepath.suffix assert extension == ".py", "only .py files are supported" # import the config as module remove_path = False if filepath.parent not in sys.path: sys.path.insert(0, (filepath)) remove_path = True module_name = filepath.stem source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath)) module = source_file.load_module() # pylint: disable=W4902,E1120,W1505 # load into config config = BaseConfig() config["config"] = filename assert hasattr(module, "config"), f"{filename} does not have config attribute" config_list = module.config if merge_configs: for cfg in config_list: cfg.check_args() for arg_name, arg_val in cfg: assert arg_name not in config, f"Duplicate attribute {arg_name} in configs." # TODO check duplicated attributes in which configs (frank) config._add_item(arg_name, arg_val) else: # TODO to be tested (frank) for cfg in config_list: cfg.check_args() cfg_cls_name = cfg.__class__.__name__ config._add_item(cfg_cls_name, cfg) # process console args if console_args: for arg_name, arg_val in console_args.items(): assert arg_name in config, f"Unknown argument {arg_name}." arg_type = type(config[arg_name]) assert arg_type in ( int, float, str, bool, ), f"Unsupported type {arg_type} for argument {arg_name}, please set this argument in config file." if arg_type is bool: if isinstance(arg_val, bool): pass elif arg_val.lower() in ["true", "1"]: arg_val = True elif arg_val.lower() in ["false", "0"]: arg_val = False else: raise ValueError(f"Unknown boolean value {arg_val} for argument {arg_name}.") # elif arg_type is list: # try: # arg_val = eval(arg_val) # except Exception: # raise ValueError(f"Unknown list value {arg_val} for argument {arg_name}.") else: try: arg_val = arg_type(arg_val) except Exception as e: raise ValueError(f"Invalid value {arg_val} for argument {arg_name}.") from e config.update({arg_name: arg_val}) # remove module del sys.modules[module_name] if remove_path: sys.path.pop(0) return config
def save_config(self, file_path: str) -> None: # all_attrs = copy.deepcopy(vars(self)) # for config in self.config_list: # all_attrs.update(vars(config)) with open(file_path, "w", encoding="utf-8") as file: for attr_name in sorted(self): file.write(f"{attr_name} = {self[attr_name]}\n")