Source code for pyrevs.strategies.ams.config
from dataclasses import dataclass
from dataclasses import field
from pyrevs.core import MergePolicy
@dataclass(frozen=True)
class AMSConfig:
    """AMS strategy configuration.

    Immutable (frozen) dataclass holding the options of the AMS strategy.
    Every field carries a human-readable description under the ``"doc"``
    key of its dataclass ``metadata``.

    The defaults of ``ntrajectories``, ``nsplititer`` and ``end_time`` are
    non-positive placeholders that :meth:`validate` rejects, so they have to
    be overridden with positive values wherever :meth:`validate` requires
    them. Validation is not run automatically by this class; callers must
    invoke :meth:`validate` explicitly.
    """

    # NOTE(review): presumably the name of the configuration section this
    # dataclass is populated from -- confirm against the config loader.
    __section__ = "ams"
    # NOTE(review): presumably tells the config machinery that values of this
    # section must not be merged/overridden -- confirm against MergePolicy.
    __merge_policy__ = MergePolicy.IMMUTABLE

    ntrajectories: int = field(
        default=-1,
        metadata={
            "doc": "Number of trajectories to sample. Sensible value checked upon initialization.",
        },
    )

    nsplititer: int = field(
        default=-1,
        metadata={
            "doc": "Number of splitting iterations. Sensible value checked upon initialization.",
        },
    )

    variant: str = field(
        default="tams",
        metadata={
            # Kept in sync with the set of variants accepted by validate().
            "doc": "Variant of AMS to use (one of [tams, ams, hams])",
        },
    )

    l_j: int = field(
        default=1,
        metadata={
            "doc": "Number of score function levels discarded at each splitting iteration",
        },
    )

    init_ensemble_only: bool = field(
        default=False,
        metadata={
            "doc": "Whether or not to stop after initializing the trajectory ensemble",
        },
    )

    end_time: float = field(
        default=-1.0,
        metadata={
            # validate() requires a positive end time for both TAMS and HAMS.
            "doc": "The end time of the trajectory (TAMS, HAMS)",
        },
    )

    min_score: float | None = field(
        default=None,
        metadata={
            # validate() requires a minimum score for both AMS and HAMS.
            "doc": "The minimum score of the trajectory (AMS, HAMS)",
        },
    )

    def validate(self) -> None:
        """Validate AMS configuration.

        The checks run in a fixed order and only the first failing one is
        reported.

        Raises:
            ValueError: If ``ntrajectories`` or ``nsplititer`` is not strictly
                positive, if ``variant`` is not one of ``"tams"``, ``"ams"``
                or ``"hams"``, or if a setting required by the selected
                variant is missing (``end_time > 0`` for TAMS, ``min_score``
                for AMS, both for HAMS).
        """
        if self.ntrajectories <= 0:
            err_msg = " AMSConfig.ntrajectories must be > 0"
            raise ValueError(err_msg)

        if self.nsplititer <= 0:
            err_msg = " AMSConfig.nsplititer must be > 0"
            raise ValueError(err_msg)

        if self.variant not in ("tams", "ams", "hams"):
            err_msg = " AMSConfig.variant must be one of ['tams', 'ams', 'hams']"
            raise ValueError(err_msg)

        # Variant-specific requirements.
        if self.variant == "tams" and self.end_time <= 0.0:
            err_msg = " AMSConfig.end_time must be > 0 for TAMS"
            raise ValueError(err_msg)

        if self.variant == "ams" and self.min_score is None:
            err_msg = " AMSConfig.min_score must be set for AMS"
            raise ValueError(err_msg)

        if self.variant == "hams" and (self.min_score is None or self.end_time <= 0.0):
            err_msg = " AMSConfig.min_score and end_time must be set for HAMS"
            raise ValueError(err_msg)