# Source code for pyrevs.strategies.base.strategy
"""Defines the generic interface for sampling strategies."""
from __future__ import annotations
import datetime
import logging
from abc import ABC
from abc import abstractmethod
from importlib.metadata import entry_points
from typing import TYPE_CHECKING
from typing import Any
from typing import ClassVar
from typing import TypeVar
from typing import final
if TYPE_CHECKING:
from collections.abc import Callable
from pyrevs.core import Config
from pyrevs.database import Database
_logger = logging.getLogger(__name__)

#: Type variable bound to :class:`BaseSamplingStrategy`. It is used by
#: :meth:`BaseSamplingStrategy.register` so that the decorator is typed as
#: returning the exact subclass it was applied to.
T = TypeVar("T", bound="BaseSamplingStrategy")
[docs]
class BaseSamplingStrategy(ABC):
"""An interface for all rare-event algorithms.
Define the common interface for sampling strategies within
the sampler object.
A registry is used to store all available strategies.
It is managed using a decorator and entry_points.
Subclasses must implement the following methods:
- :meth:`sample`
- :meth:`out_of_time`
"""
# Registry, loaded on first use
_registry: ClassVar[dict[str, type[BaseSamplingStrategy]]] = {}
_strategies_loaded: ClassVar[bool] = False
@classmethod
def _load_strategies(cls) -> None:
"""Load all available strategies."""
if cls._strategies_loaded:
return
eps = entry_points(group="pyrevs.strategies")
for ep in eps:
try:
ep.load()
except (ImportError, AttributeError) as exc: # noqa: PERF203
wrn_msg = f"Failed to load strategy {ep.name}: {exc}"
_logger.warning(wrn_msg)
cls._strategies_loaded = True
@classmethod
[docs]
def register(cls, name: str) -> Callable[[type[T]], type[T]]:
"""Register a new strategy.
Args:
name: the strategy name
"""
key = name.lower()
def decorator(subclass: type[T]) -> type[T]:
if key in cls._registry:
err_msg = f"Strategy {key} already registered"
raise ValueError(err_msg)
cls._registry[key] = subclass
return subclass
return decorator
@classmethod
[docs]
def create(cls, name: str, *args: Any, **kwargs: Any) -> BaseSamplingStrategy:
"""Instantiate a strategy out of the registry.
Args:
name: the strategy name
*args: positional arguments
**kwargs: keyword arguments
"""
cls._load_strategies()
key = name.lower()
try:
return cls._registry[key](*args, **kwargs)
except KeyError as err:
err_msg = f"Unknown strategy type: {key}"
raise ValueError(err_msg) from err
@classmethod
[docs]
def available_strategies(cls) -> list[str]:
"""Return list of registered strategy names."""
cls._load_strategies()
return sorted(cls._registry.keys())
# Time management uses UTC date
_start_date: datetime.datetime
_end_date: datetime.datetime
_min_remaining_time: float
_MIN_REMAINING_TIME_RATIO: ClassVar[float] = 0.05
@final
[docs]
def sample(self, database: Database, walltime: float, plot_diags: bool) -> None:
"""Run the sampling lifecycle.
This method handles walltime bookkeeping and delegates the
algorithm implementation to ``_execute_sampling``.
Subclasses must implement ``_execute_sampling`` and should
regularly check ``out_of_time()`` to terminate gracefully.
Args:
database: Database used for storing results.
walltime: Maximum allowed runtime in seconds.
plot_diags: Whether to enable diagnostic plotting.
"""
self._start_date = self._now()
self._end_date = self._start_date + datetime.timedelta(seconds=walltime)
self._min_remaining_time = self._MIN_REMAINING_TIME_RATIO * walltime
self._execute_sampling(database, plot_diags)
def _now(self) -> datetime.datetime:
"""Return the current time."""
return datetime.datetime.now(tz=datetime.timezone.utc)
[docs]
def remaining_time(self) -> float:
"""Return the remaining wallclock time."""
if not hasattr(self, "_end_date"):
err_msg = "Sampling has not been started. Call 'sample()' first."
raise RuntimeError(err_msg)
return (self._end_date - self._now()).total_seconds()
[docs]
def out_of_time(self) -> bool:
"""Return true if insufficient walltime remains."""
return self.remaining_time() <= self._min_remaining_time
[docs]
def elapsed_time(self) -> float:
"""Return the elapsed wallclock time."""
if not hasattr(self, "_start_date"):
err_msg = "Sampling has not been started. Call 'sample()' first."
raise RuntimeError(err_msg)
return (self._now() - self._start_date).total_seconds()
@abstractmethod
def _execute_sampling(self, database: Database, plot_diags: bool) -> None:
"""Implement the core sampling algorithm.
This method is called by :meth:`sample` after time bookkeeping
has been initialized. Concrete implementations should
implement the actual rare event sampling algorithm.
Implementations should:
- Regularly call :meth:`out_of_time` to respect walltime limits
- Store intermediate and final results in ``database``
- Optionally produce diagnostics if ``plot_diags`` is True
Args:
database: Database used for storing results.
plot_diags: Whether to enable diagnostic plotting.
Notes:
This method should terminate gracefully when time is exhausted.
"""
raise NotImplementedError
@abstractmethod
[docs]
def initialize_database_schema(self, database: Database, diag_configs: dict[str, Config] | None) -> None:
"""Initialize the schema of the database."""
raise NotImplementedError
def __repr__(self) -> str:
"""Return a string representation of the object."""
return f"{self.__class__.__name__}()"