Source code for pyrevs.strategies.base.termination
"""Defines the interface and simple implementations of termination criteria."""
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import Protocol
from typing import TypeVar
if TYPE_CHECKING:
from pyrevs.core import ForwardModelBaseClass
from pyrevs.trajectory import Trajectory
[docs]
# Type variable filling the first type parameter of ForwardModelBaseClass in
# the should_terminate signatures of this module (presumably the model's noise
# type, judging by the name -- confirm against pyrevs.core).
# NOTE(review): each should_terminate signature uses this TypeVar only once,
# which some type checkers flag; consider Any or a generic protocol instead.
T_Noise = TypeVar("T_Noise")
[docs]
# Type variable filling the second type parameter of ForwardModelBaseClass in
# the should_terminate signatures of this module (presumably the model's state
# type, judging by the name -- confirm against pyrevs.core).
# NOTE(review): each should_terminate signature uses this TypeVar only once,
# which some type checkers flag; consider Any or a generic protocol instead.
T_State = TypeVar("T_State")
[docs]
class TerminationCriterion(Protocol):
    """Interface for deciding when a trajectory should stop.

    A criterion looks at the forward model and/or the trajectory and reports
    whether the trajectory should terminate.
    """

    def should_terminate(
        self,
        model: ForwardModelBaseClass[T_Noise, T_State],
        trajectory: Trajectory,
    ) -> bool:
        """Decide whether the trajectory should terminate.

        Args:
            model: the forward model
            trajectory: the trajectory being evaluated

        Returns:
            True if the trajectory should terminate.

        Raises:
            NotImplementedError: always; concrete criteria must override this.
        """
        raise NotImplementedError()
[docs]
class TimeTerminationCriterion(TerminationCriterion):
    """Stop a trajectory once its current time reaches a fixed end time.

    The criterion fires when ``trajectory.current_time()`` is greater than or
    equal to the configured end time; the forward model is not consulted.
    """

    def __init__(self, end_time: float) -> None:
        """Store the end time.

        Args:
            end_time: time at or after which trajectories terminate
        """
        self._end_time = end_time

    def should_terminate(
        self,
        model: ForwardModelBaseClass[T_Noise, T_State],
        trajectory: Trajectory,
    ) -> bool:
        """Return whether the trajectory has reached the end time.

        Args:
            model: unused; accepted to satisfy the criterion interface
            trajectory: the trajectory whose current time is checked

        Returns:
            True when the current time is greater than or equal to the end time.
        """
        del model  # a purely time-based check does not need the model
        now = trajectory.current_time()
        return now >= self._end_time
[docs]
class LowScoreTerminationCriterion(TerminationCriterion):
    """Stop a trajectory once the model's score drops to a threshold.

    The criterion fires when ``model.score()`` is less than or equal to the
    configured threshold; the trajectory itself is not consulted.
    """

    def __init__(self, score_threshold: float) -> None:
        """Store the score threshold.

        Args:
            score_threshold: score at or below which trajectories terminate
        """
        self._score_threshold = score_threshold

    def should_terminate(
        self,
        model: ForwardModelBaseClass[T_Noise, T_State],
        trajectory: Trajectory,
    ) -> bool:
        """Return whether the model's current score is at or below the threshold.

        Args:
            model: the forward model whose score is checked
            trajectory: unused; accepted to satisfy the criterion interface

        Returns:
            True when the score is less than or equal to the threshold.
        """
        del trajectory  # only the model's score matters for this criterion
        current_score = model.score()
        return current_score <= self._score_threshold
[docs]
class ModelTerminationCriterion(TerminationCriterion):
    """Termination criterion delegated to the forward model.

    Will trigger termination if the forward model decides to, as reported by
    ``model.check_termination`` for the trajectory's current step and time.
    """

    def should_terminate(
        self,
        model: ForwardModelBaseClass[T_Noise, T_State],
        trajectory: Trajectory,
    ) -> bool:
        """Check if the trajectory should terminate.

        Args:
            model: the forward model, queried via ``check_termination``
            trajectory: the trajectory supplying the current step and time

        Returns:
            The value returned by ``model.check_termination``.
        """
        # Both arguments are used here, so no unused-argument silencing is needed.
        return model.check_termination(
            trajectory.current_step(),
            trajectory.current_time(),
        )