Source code for pyrevs.core.config

"""A configuration class to expose limited configuration."""

from __future__ import annotations
from collections.abc import Mapping
from dataclasses import MISSING
from dataclasses import Field
from dataclasses import asdict
from dataclasses import fields
from dataclasses import is_dataclass
from typing import Any
from typing import TypeVar
from typing import cast
from typing import get_type_hints

# Generic type variable: ``Config.load`` returns an instance of the class it
# is given, and ``merge_config`` returns the same type as its arguments.
T = TypeVar("T")
class MergePolicy:
    """A merge policy for configuration dataclasses.

    The values name the strategies understood by ``merge_config``. A config
    dataclass selects one through a ``__merge_policy__`` class attribute;
    ``IMMUTABLE`` is used when that attribute is absent.
    """

    #: The stored value may only be "merged" with a value equal to it.
    IMMUTABLE: str = "immutable"
    #: A newly supplied value replaces the stored one.
    REPLACE: str = "replace"
class Config:
    """Lightweight structured access to configuration.

    pyREVS input parameters are mostly handled through dataclasses to ensure
    default values are available and types are enforced::

        TOML file -> Config -> Typed dataclass

    The typed dataclasses are the final objects used to instantiate pyREVS
    objects (Sampler, Database, Trajectories, ...).

    The Config class provides a simple interface to access configuration
    parameters in a more structured way. To add a new section to the TOML
    input file, create a new dataclass with a ``__section__`` class attribute
    set to the name of the section.
    """

    def __init__(self, data: Mapping[str, Any]) -> None:
        """Wrap a configuration mapping.

        Args:
            data: The raw configuration mapping. It is stored by reference,
                not copied.
        """
        self._data = data

    def _sub_mapping(self, name: str) -> Mapping[str, Any]:
        """Return the mapping stored under ``name``, or ``{}`` if it is absent.

        Every accessor that treats a key as a section goes through this helper
        so that malformed sections are rejected with one consistent error.

        Raises:
            TypeError: If the value stored under ``name`` is not a mapping.
        """
        value = self._data.get(name, {})
        if not isinstance(value, Mapping):
            err_msg = f"Config section {name} is not a mapping !"
            raise TypeError(err_msg)
        return value

    def section(self, name: str) -> Config:
        """Return a configuration section.

        Args:
            name: Name of the section. A missing section yields an empty
                ``Config``.

        Returns:
            A new ``Config`` wrapping the section's mapping.

        Raises:
            TypeError: If the section exists but is not a mapping.
        """
        return Config(self._sub_mapping(name))

    def section_dict(self, name: str) -> dict[str, Any]:
        """Return a section as dict.

        Args:
            name: Name of the section.

        Returns:
            A shallow ``dict`` copy of the section, or ``{}`` if it is absent.

        Raises:
            TypeError: If the section exists but is not a mapping.
        """
        return dict(self._sub_mapping(name))

    def require(self, section: str, key: str) -> Any:
        """Get a required parameter.

        Args:
            section: Name of the section holding the parameter.
            key: Name of the parameter inside that section.

        Returns:
            The stored value.

        Raises:
            ValueError: If the section or the key is missing.
            TypeError: If the section exists but is not a mapping.
        """
        values = self._sub_mapping(section)
        try:
            return values[key]
        except KeyError as exc:
            err_msg = f"Missing required config: [{section}].{key}"
            raise ValueError(err_msg) from exc

    def get(self, section: str, key: str, default: Any = None) -> Any:
        """Get an optional parameter.

        Args:
            section: Name of the section holding the parameter.
            key: Name of the parameter inside that section.
            default: Value returned when the section or the key is missing.

        Returns:
            The stored value, or ``default``.

        Raises:
            TypeError: If the section exists but is not a mapping.
        """
        return self._sub_mapping(section).get(key, default)

    def as_dict(self) -> dict[str, Any]:
        """Return the config Mapping as a new (shallow-copied) dict."""
        return dict(self._data)

    def load(self, cls: type[T]) -> T:
        """Load a dataclass of the provided type from this config.

        If ``cls`` defines a ``__load__`` attribute, it is called with this
        ``Config`` and its result is returned as-is. Otherwise the section named
        by ``cls.__section__`` (or this whole config when that attribute is
        absent) supplies the field values.

        Args:
            cls: The dataclass type to instantiate.

        Returns:
            An instance of ``cls``.

        Raises:
            TypeError: If ``cls`` is not a dataclass, or a section it reads
                is not a mapping.
            ValueError: If a required field has no value.
        """
        if not is_dataclass(cls):
            err_msg = f"{cls} is not a dataclass"
            raise TypeError(err_msg)
        # Use the custom __load__ method if available.
        # This casting is an easy fix for mypy,
        # but defining a Protocol would be more elegant.
        if hasattr(cls, "__load__"):
            return cast("T", cast("Any", cls).__load__(self))
        cfg = self._resolve_section(cls)
        kwargs = self._build_kwargs(cfg, cls)
        return cls(**kwargs)

    def _resolve_section(self, cls: type[T]) -> Config:
        """Return the appropriate Config view for a given dataclass.

        If the dataclass defines a ``__section__`` attribute, that subsection
        of the underlying data is wrapped into a new ``Config`` instance.
        Otherwise, the current instance is returned.

        Args:
            cls: The dataclass type being loaded.

        Returns:
            A ``Config`` instance scoped to the relevant section.

        Raises:
            TypeError: If the resolved section is not a mapping.
        """
        section_name = getattr(cls, "__section__", None)
        if section_name is None:
            return self
        return self.section(section_name)

    def _build_kwargs(self, cfg: Config, cls: type[Any]) -> dict[str, Any]:
        """Construct keyword arguments for dataclass instantiation.

        Iterates over the dataclass fields that are ``__init__`` parameters,
        resolves their values from the configuration data (or defaults), and
        applies recursive loading for nested dataclasses.

        Args:
            cfg: The Config instance scoped to the relevant section.
            cls: The dataclass type to instantiate.

        Returns:
            A dictionary of field names to resolved values.
        """
        type_hints = get_type_hints(cls)
        data = cfg.as_dict()
        kwargs: dict[str, Any] = {}
        for f in fields(cls):
            # Fields declared with init=False are not constructor parameters:
            # passing them to cls(**kwargs) raises TypeError. The dataclass
            # sets them itself (default, default_factory or __post_init__).
            if not f.init:
                continue
            kwargs[f.name] = self._get_field_value(
                cfg, data, f, type_hints[f.name]
            )
        return kwargs

    def _get_field_value(
        self,
        cfg: Config,
        data: Mapping[str, Any],
        f: Field[Any],
        ftype: Any,
    ) -> Any:
        """Resolve the value of a single dataclass field.

        The resolution order is:
            1. Value from configuration data
            2. Field default
            3. Field default factory

        Nested dataclasses are then resolved recursively.

        Args:
            cfg: The Config instance for nested resolution.
            data: The raw mapping for the current section.
            f: The dataclass field definition.
            ftype: The resolved type hint for the field.

        Returns:
            The resolved field value.

        Raises:
            ValueError: If none of the three sources provides a value.
        """
        key = f.name
        if key in data:
            value = data[key]
        elif f.default is not MISSING:
            value = f.default
        elif f.default_factory is not MISSING:
            value = f.default_factory()
        else:
            err_msg = f"Missing required field: {key}"
            raise ValueError(err_msg)
        return self._resolve_nested(cfg, value, ftype)

    def _resolve_nested(self, cfg: Config, value: Any, ftype: Any) -> Any:
        """Resolve nested dataclass values.

        If the field type is a dataclass, this method will:
            - Load it from its declared section of ``cfg`` if it defines
              ``__section__`` (``value`` itself is then not used)
            - Otherwise, load it from ``value`` when that is a mapping

        Anything else (non-dataclass types, or a dataclass field whose value is
        not a mapping, e.g. an instance coming from a default) is returned
        unchanged. Only plain class hints are recognized; composite hints such
        as ``X | None`` are not unwrapped.

        Args:
            cfg: The Config instance used for loading.
            value: The raw value extracted from configuration data.
            ftype: The type hint associated with the field.

        Returns:
            The resolved value, possibly transformed into a dataclass instance.
        """
        if isinstance(ftype, type) and is_dataclass(ftype):
            section_name = getattr(ftype, "__section__", None)
            if section_name is not None:
                return cfg.load(ftype)
            if isinstance(value, Mapping):
                return Config(value).load(ftype)
        return value
def collect_sections(*configs: Any) -> dict[str, Any]:
    """Build a TOML-ready dict from config dataclasses.

    Every non-``None`` argument must be a dataclass instance whose type
    defines a ``__section__`` class attribute. Its fields, converted with
    ``dataclasses.asdict``, are stored under that section name. ``None``
    arguments are skipped, and a later instance with the same section name
    overwrites an earlier one.

    Args:
        *configs: Config dataclass instances (or ``None``).

    Returns:
        A mapping of section name to the section's field dict.

    Raises:
        ValueError: If a non-``None`` argument's type has no ``__section__``.
    """
    sections: dict[str, Any] = {}
    for instance in (candidate for candidate in configs if candidate is not None):
        config_type = type(instance)
        section_name = getattr(config_type, "__section__", None)
        if section_name is None:
            message = f"{config_type.__name__} has no __section__"
            raise ValueError(message)
        sections[section_name] = asdict(instance)
    return sections
def merge_config(old: T, new: T) -> T:
    """Merge two configuration dataclass objects.

    ``None`` means "not provided": when either argument is ``None`` the other
    one is returned unchanged (so two ``None`` values yield ``None``).
    Otherwise the ``__merge_policy__`` class attribute of ``type(old)``
    decides the outcome, falling back to ``MergePolicy.IMMUTABLE``:

    * ``MergePolicy.IMMUTABLE``: ``old`` is returned, but only when it does
      not differ from ``new``.
    * ``MergePolicy.REPLACE``: ``new`` is returned.

    Args:
        old: The currently held configuration object, or ``None``.
        new: The incoming configuration object, or ``None``.

    Returns:
        The merged configuration object.

    Raises:
        ValueError: If an immutable configuration would change, or if the
            policy value is not one of the known ``MergePolicy`` values.
    """
    if old is None or new is None:
        return new if old is None else old
    config_type = type(old)
    policy = getattr(config_type, "__merge_policy__", MergePolicy.IMMUTABLE)
    if policy == MergePolicy.IMMUTABLE:
        if old != new:
            message = f"{config_type.__name__} is immutable"
            raise ValueError(message)
        return old
    if policy == MergePolicy.REPLACE:
        return new
    message = f"Unknown merge policy for {config_type.__name__}"
    raise ValueError(message)