# Source code for pyrevs.strategies.base.termination

"""Defines the interface and simple implementations of termination criteria."""

from __future__ import annotations
from typing import TYPE_CHECKING
from typing import Protocol
from typing import TypeVar

if TYPE_CHECKING:
    from pyrevs.core import ForwardModelBaseClass
    from pyrevs.trajectory import Trajectory

# Type parameters used to spell ``ForwardModelBaseClass[T_Noise, T_State]``
# in the ``should_terminate`` signatures of this module.
T_State = TypeVar("T_State")
T_Noise = TypeVar("T_Noise")
class TerminationCriterion(Protocol):
    """Interface every termination criterion implements.

    A criterion inspects the forward model and/or the trajectory and
    answers a single yes/no question: should the trajectory stop now?
    """

    def should_terminate(
        self,
        model: ForwardModelBaseClass[T_Noise, T_State],
        trajectory: Trajectory,
    ) -> bool:
        """Decide whether the trajectory has to stop.

        Args:
            model: the forward model driving the trajectory
            trajectory: the trajectory being advanced

        Returns:
            True when the trajectory should terminate, False otherwise.

        Raises:
            NotImplementedError: always; concrete criteria override this.
        """
        raise NotImplementedError
class TimeTerminationCriterion(TerminationCriterion):
    """Stop a trajectory once its current time reaches a fixed end time.

    Termination is triggered when ``trajectory.current_time()`` is greater
    than or equal to the configured end time.
    """

    def __init__(self, end_time: float) -> None:
        """Store the time horizon.

        Args:
            end_time: the end time at (or beyond) which termination triggers
        """
        self._end_time = end_time

    def should_terminate(
        self,
        model: ForwardModelBaseClass[T_Noise, T_State],
        trajectory: Trajectory,
    ) -> bool:
        """Check if the trajectory should terminate.

        Args:
            model: the forward model (not consulted by this criterion)
            trajectory: the trajectory whose current time is compared

        Returns:
            True if the trajectory's current time is >= the end time.
        """
        del model  # this criterion depends on the trajectory only
        now = trajectory.current_time()
        return now >= self._end_time
class LowScoreTerminationCriterion(TerminationCriterion):
    """Stop a trajectory when the model's score drops to a threshold.

    Termination is triggered when ``model.score()`` is less than or equal
    to the configured score threshold.
    """

    def __init__(self, score_threshold: float) -> None:
        """Store the score threshold.

        Args:
            score_threshold: the score at (or below) which termination triggers
        """
        self._score_threshold = score_threshold

    def should_terminate(
        self,
        model: ForwardModelBaseClass[T_Noise, T_State],
        trajectory: Trajectory,
    ) -> bool:
        """Check if the trajectory should terminate.

        Args:
            model: the forward model whose current score is compared
            trajectory: the trajectory (not consulted by this criterion)

        Returns:
            True if the model's current score is <= the score threshold.
        """
        del trajectory  # this criterion depends on the model only
        current_score = model.score()
        return current_score <= self._score_threshold
class ModelTerminationCriterion(TerminationCriterion):
    """Termination criterion based on the forward model.

    Will trigger termination if the forward model has decided to, by
    delegating to ``model.check_termination`` with the trajectory's current
    step and current time.
    """

    def should_terminate(
        self,
        model: ForwardModelBaseClass[T_Noise, T_State],
        trajectory: Trajectory,
    ) -> bool:
        """Check if the trajectory should terminate.

        Args:
            model: the forward model that makes the termination decision
            trajectory: the trajectory supplying the current step and time

        Returns:
            The result of ``model.check_termination(step, time)``.
        """
        # Both arguments are used here, so no "unused" marker is needed.
        return model.check_termination(
            trajectory.current_step(), trajectory.current_time()
        )