Source code for easydel.utils.cli_helpers

# Modified from huggingface and vLLM argparser

import dataclasses
import json
import os
import sys
import types
import typing as tp
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError
from copy import copy
from enum import Enum
from inspect import isclass
from pathlib import Path

import yaml

# Nominal alias over tp.Any. NOTE(review): judging by the name this stands for a
# dataclass *instance*; it is not referenced in the visible part of this module — confirm usage.
DataClass = tp.NewType("DataClass", tp.Any)
# Nominal alias over tp.Any, used in this module's annotations for the dataclass
# *classes* handed to DataClassArgumentParser.
DataClassType = tp.NewType("DataClassType", tp.Any)


def string_to_bool(v: tp.Union[str, bool]) -> bool:
    """
    Interpret a command-line style value as a boolean.

    A value that is already a ``bool`` is returned unchanged. A string is
    lower-cased and matched against the accepted spellings:
    yes/no, true/false, t/f, y/n and 1/0.

    Args:
        v: The boolean, or the textual flag value, to interpret.

    Returns:
        The boolean the value stands for.

    Raises:
        ArgumentTypeError: If the string is not one of the accepted spellings.
    """
    if isinstance(v, bool):
        return v

    spellings = {
        "yes": True,
        "true": True,
        "t": True,
        "y": True,
        "1": True,
        "no": False,
        "false": False,
        "f": False,
        "n": False,
        "0": False,
    }
    verdict = spellings.get(v.lower())
    if verdict is None:
        raise ArgumentTypeError(
            f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
        )
    return verdict
def make_choice_type_function(choices: tp.List[tp.Any]) -> tp.Callable[[str], tp.Any]:
    """
    Build a converter from a choice's string form back to the choice itself.

    Args:
        choices: The allowed values.

    Returns:
        A callable that returns the choice whose ``str()`` equals its argument,
        or the argument unchanged when no choice matches.
    """
    by_text: tp.Dict[str, tp.Any] = {}
    for option in choices:
        # Later choices win when two of them share the same string form.
        by_text[str(option)] = option

    def convert(arg):
        return by_text.get(arg, arg)

    return convert
def Argu(
    *,
    aliases: tp.Optional[tp.Union[str, tp.List[str]]] = None,
    help: tp.Optional[str] = None,
    default: tp.Any = dataclasses.MISSING,
    default_factory: tp.Callable[[], tp.Any] = dataclasses.MISSING,
    metadata: tp.Optional[dict] = None,
    **kwargs,
) -> dataclasses.Field:
    """
    Create a dataclass field whose metadata carries ``aliases`` and ``help``.

    Thin wrapper around ``dataclasses.field`` that stores the two values
    under the ``"aliases"`` and ``"help"`` metadata keys.

    Args:
        aliases: Extra option string(s) for the field; stored only when not None.
        help: Help text for the field; stored only when not None.
        default: Default value, forwarded to ``dataclasses.field``.
        default_factory: Default factory, forwarded to ``dataclasses.field``.
        metadata: Extra metadata entries. The mapping is copied, never modified.
        **kwargs: Any other keyword arguments accepted by ``dataclasses.field``.

    Returns:
        The field object produced by ``dataclasses.field``.
    """
    # Copy so that a metadata dict shared between several fields is not
    # polluted with this field's aliases/help.
    merged_metadata = {} if metadata is None else dict(metadata)
    if aliases is not None:
        merged_metadata["aliases"] = aliases
    if help is not None:
        merged_metadata["help"] = help

    return dataclasses.field(
        metadata=merged_metadata,
        default=default,
        default_factory=default_factory,
        **kwargs,
    )
[docs]class DataClassArgumentParser(ArgumentParser): """ A subclass of argparse.ArgumentParser that automatically generates arguments based on dataclass type hints. It supports additional argparse features (like sub-groups) and can also load configuration from dictionaries, JSON files, or YAML files. """ dataclass_types: tp.Iterable[DataClassType] def __init__( self, dataclass_types: tp.Union[DataClassType, tp.Iterable[DataClassType]], **kwargs: tp.Any, ) -> None: # Use ArgumentDefaultsHelpFormatter to show default values in --help if not specified if "formatter_class" not in kwargs: kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter super().__init__(**kwargs) if dataclasses.is_dataclass(dataclass_types): dataclass_types = [dataclass_types] self.dataclass_types = list(dataclass_types) for dtype in self.dataclass_types: self._add_dataclass_arguments(dtype) @staticmethod def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field) -> None: """ Convert a dataclass field into a corresponding argparse argument. """ # Create long option names (e.g. --my_field and --my-field) long_options = [f"--{field.name}"] if "_" in field.name: long_options.append(f"--{field.name.replace('_', '-')}") # Start with a copy of the field's metadata kwargs = field.metadata.copy() if isinstance(field.type, str): raise RuntimeError( f"Unresolved type detected for field '{field.name}'. " "Ensure that type annotations are fully resolved." 
) # Handle any aliases provided aliases = kwargs.pop("aliases", []) if isinstance(aliases, str): aliases = [aliases] # Process union types (only tp.Optional[T] is supported) origin_type = getattr(field.type, "__origin__", None) if origin_type in (tp.Union, getattr(types, "UnionType", None)): union_args = field.type.__args__ if len(union_args) == 2 and type(None) in union_args: # tp.Optional[T] detected: choose the non-None type field.type = next(arg for arg in union_args if arg is not type(None)) origin_type = getattr(field.type, "__origin__", None) else: raise ValueError( f"Only tp.Optional types (tp.Union[T, None]) are supported for field '{field.name}', got {field.type}." ) # Special handling for booleans bool_kwargs: tp.Dict[str, tp.Any] = {} if field.type is bool: bool_kwargs = copy(kwargs) kwargs["type"] = string_to_bool default_val = False if field.default is dataclasses.MISSING else field.default kwargs["default"] = default_val kwargs["nargs"] = "?" kwargs["const"] = True # Handle tp.Literal or Enum types as a fixed set of choices elif origin_type is tp.Literal or ( isinstance(field.type, type) and issubclass(field.type, Enum) ): if origin_type is tp.Literal: kwargs["choices"] = field.type.__args__ else: kwargs["choices"] = [member.value for member in field.type] kwargs["type"] = make_choice_type_function(kwargs["choices"]) if field.default is not dataclasses.MISSING: kwargs["default"] = field.default else: kwargs["required"] = True # Handle list types (expecting at least one value) elif isclass(field.type) and issubclass(field.type, list): kwargs["type"] = field.type.__args__[0] kwargs["nargs"] = "+" if field.default_factory is not dataclasses.MISSING: kwargs["default"] = field.default_factory() elif field.default is dataclasses.MISSING: kwargs["required"] = True else: kwargs["type"] = field.type if field.default is not dataclasses.MISSING: kwargs["default"] = field.default elif field.default_factory is not dataclasses.MISSING: kwargs["default"] = 
field.default_factory() else: kwargs["required"] = True # Add the main argument to the parser parser.add_argument(*long_options, *aliases, **kwargs) # For boolean fields that default to True, add a complementary --no_* option if field.type is bool and field.default is True: bool_kwargs["default"] = False parser.add_argument( f"--no_{field.name}", f"--no-{field.name.replace('_', '-')}", action="store_false", dest=field.name, **bool_kwargs, ) def _add_dataclass_arguments(self, dtype: DataClassType) -> None: """ Add arguments for all 'init'-enabled fields of a dataclass. """ group_name = getattr(dtype, "_argument_group_name", None) parser = self.add_argument_group(group_name) if group_name else self try: type_hints: tp.Dict[str, type] = tp.get_type_hints(dtype) except NameError as e: raise RuntimeError( f"Type resolution failed for {dtype}. Consider declaring the class in global scope or disabling " "PEP 563 (postponed evaluation of annotations)." ) from e except TypeError as ex: if sys.version_info < (3, 10) and "unsupported operand type(s) for |" in str(ex): python_version = ".".join(map(str, sys.version_info[:3])) raise RuntimeError( f"Type resolution failed for {dtype} on Python {python_version}. " "Please use typing.tp.Union and typing.tp.Optional instead of the | syntax for union types." ) from ex raise for field in dataclasses.fields(dtype): if not field.init: continue field.type = type_hints[field.name] self._parse_dataclass_field(parser, field)
[docs] def parse_args_into_dataclasses( self, args: tp.Optional[tp.List[str]] = None, return_remaining_strings: bool = False, look_for_args_file: bool = True, args_filename: tp.Optional[str] = None, args_file_flag: tp.Optional[str] = None, ) -> tp.Tuple[tp.Any, ...]: """ Parse command-line arguments into instances of the specified dataclass types. Optionally, this method can also look for an external ".args" file or a command-line flag that points to one, and prepend its content to the command-line arguments. Raises: ValueError: If there are any unknown arguments (and return_remaining_strings is False). """ if args_file_flag or args_filename or (look_for_args_file and sys.argv): args_files: tp.List[Path] = [] if args_filename: args_files.append(Path(args_filename)) elif look_for_args_file and sys.argv: args_files.append(Path(sys.argv[0]).with_suffix(".args")) if args_file_flag: # Create a temporary parser to extract file path(s) args_file_parser = ArgumentParser(add_help=False) args_file_parser.add_argument(args_file_flag, type=str, action="append") cfg, args = args_file_parser.parse_known_args(args=args) cmd_args_file_paths = getattr(cfg, args_file_flag.lstrip("-"), None) if cmd_args_file_paths: args_files.extend(Path(p) for p in cmd_args_file_paths) file_args: tp.List[str] = [] for args_file in args_files: if args_file.exists(): file_args.extend(args_file.read_text(encoding="utf-8").split()) if args is None: args = sys.argv[1:] # Command-line arguments take precedence over those in files. 
args = file_args + args namespace, remaining_args = self.parse_known_args(args=args) outputs = [] for dtype in self.dataclass_types: field_names = {f.name for f in dataclasses.fields(dtype) if f.init} init_args = {k: v for k, v in vars(namespace).items() if k in field_names} # Remove used keys from the namespace for key in init_args: delattr(namespace, key) outputs.append(dtype(**init_args)) if namespace.__dict__: outputs.append(namespace) if return_remaining_strings: return (*outputs, remaining_args) elif remaining_args: raise ValueError( f"Some arguments were not used by DataClassArgumentParser: {remaining_args}" ) return tuple(outputs)
[docs] def parse_dict( self, args: tp.Dict[str, tp.Any], allow_extra_keys: bool = False ) -> tp.Tuple[tp.Any, ...]: """ Parse a dictionary of configuration values into dataclass instances. Args: args: Dictionary containing configuration values. allow_extra_keys: If False, raises an exception if unknown keys are present. """ unused_keys = set(args.keys()) outputs = [] for dtype in self.dataclass_types: field_names = {f.name for f in dataclasses.fields(dtype) if f.init} init_args = {k: v for k, v in args.items() if k in field_names} unused_keys -= init_args.keys() outputs.append(dtype(**init_args)) if not allow_extra_keys and unused_keys: raise ValueError(f"Unused keys in configuration: {sorted(unused_keys)}") return tuple(outputs)
[docs] def parse_json_file( self, json_file: tp.Union[str, os.PathLike], allow_extra_keys: bool = False, ) -> tp.Tuple[tp.Any, ...]: """ Load a JSON file and parse it into dataclass instances. """ with open(Path(json_file), encoding="utf-8") as f: data = json.load(f) return self.parse_dict(data, allow_extra_keys=allow_extra_keys)
[docs] def parse_yaml_file( self, yaml_file: tp.Union[str, os.PathLike], allow_extra_keys: bool = False, ) -> tp.Tuple[tp.Any, ...]: """ Load a YAML file and parse it into dataclass instances. """ yaml_text = Path(yaml_file).read_text(encoding="utf-8") data = yaml.safe_load(yaml_text) return self.parse_dict(data, allow_extra_keys=allow_extra_keys)