"""Base class for the stochastic forward model used by pyTAMS.

Concrete stochastic models derive from the class defined in this module.
"""

from abc import ABCMeta
from abc import abstractmethod
from logging import getLogger
from pathlib import Path
from typing import Any
from typing import final

# Logger named after this module; used to report failures in model steps.
_logger = getLogger(__name__)
[docs]classForwardModelBaseClass(metaclass=ABCMeta):"""A base class for the stochastic forward model. pyTAMS relies on a separation of the stochastic model, encapsulating the physics of interest, and the TAMS algorithm itself. The ForwardModelBaseClass defines the API the TAMS algorithm requires from the stochastic model. Concrete model classes must implement all the abstract functions defined in this base class. The base class handles some components needed by TAMS, so that the user does not have to ensure compatibility with TAMS requirements. Attributes: _noise: the noise to be used in the next model step _step: the current stochastic step counter _time: the current stochastic time _workdir: the working directory """@finaldef__init__(self,params:dict[Any,Any],ioprefix:str|None=None,workdir:Path|None=None):"""Base class __init__ method. The ABC init method calls the concrete class init method while performing some common initializations. Additionally, this method create/append to a model dictionary to the parameter dictionary to ensure the 'deterministic' parameter is always available in the model dictionary. Upon initializing the model, a first call to make_noise is made to ensure the proper type is generated. Args: params: a dict containing parameters ioprefix: an optional string defining run folder workdir: an optional path to the working directory """# Initialize common toolingself._noise:Any=Noneself._step:int=0self._time:float=0.0self._workdir:Path=Path.cwd()ifworkdirisNoneelseworkdir# Add the deterministic parameter to the model dictionary# for consistencyifparams.get("model"):params["model"]["deterministic"]=params.get("tams",{}).get("deterministic",False)else:params["model"]={"deterministic":params.get("tams",{}).get("deterministic",False)}# Call the concrete class init methodself._init_model(params,ioprefix)# Generate the first noise increment# to at least get the proper type.self._noise=self.make_noise()@final
[docs]defadvance(self,dt:float,need_end_state:bool)->float:"""Base class advance function of the model. This is the advance function called by TAMS internals. It handles updating the model time and step counter, as well as reusing or generating noise only when needed. It also handles exceptions. Args: dt: the time step size over which to advance need_end_state: whether the step end state is needed Return: Some model will not do exactly dt (e.g. sub-stepping) return the actual dt """try:actual_dt=self._advance(self._step,self._time,dt,self._noise,need_end_state)# Update internal counter. Note that actual_dt may differ# from requested dt in some occasions.self._step=self._step+1self._time=self._time+actual_dtexceptException:err_msg="Advance function ran into an error !"_logger.exception(err_msg)raisereturnactual_dt
@final
[docs]defget_noise(self)->Any:"""Return the model's latest noise increment."""returnself._noise
@final
[docs]defset_noise(self,a_noise:Any)->None:"""Set the model's next noise increment."""self._noise=a_noise
[docs]defset_workdir(self,workdir:Path)->None:"""Setter of the model working directory. Args: workdir: the new working directory """self._workdir=workdir
@abstractmethoddef_init_model(self,params:dict[Any,Any]|None=None,ioprefix:str|None=None)->None:"""Concrete class specific initialization. Args: params: an optional dict containing parameters ioprefix: an optional string defining run folder """@abstractmethoddef_advance(self,step:int,time:float,dt:float,noise:Any,need_end_state:bool)->float:"""Concrete class advance function. This is the model-specific advance function. Args: step: the current step counter time: the starting time of the advance call dt: the time step size over which to advance noise: the noise to be used in the model step need_end_state: whether the step end state is needed Return: Some model will not do exactly dt (e.g. sub-stepping) return the actual dt """@abstractmethod
[docs]defget_current_state(self)->Any:"""Return the current state of the model."""
@abstractmethod
[docs]defset_current_state(self,state:Any)->Any:"""Set the current state of the model."""
@abstractmethod
[docs]defscore(self)->float:"""Return the model's current state score."""
@abstractmethod
[docs]defmake_noise(self)->Any:"""Return the model's latest noise increment."""
@final
[docs]defpost_trajectory_branching_hook(self,step:int,time:float)->None:"""Model post trajectory branching hook. Args: step: the current step counter time: the time of the simulation """self._step=stepself._time=timeself._trajectory_branching_hook()
def_trajectory_branching_hook(self)->None:"""Model-specific post trajectory branching hook."""@final
[docs]defpost_trajectory_restore_hook(self,step:int,time:float)->None:"""Model post trajectory restore hook. Args: step: the current step counter time: the time of the simulation """self._step=stepself._time=timeself._trajectory_restore_hook()
def_trajectory_restore_hook(self)->None:"""Model-specific post trajectory restore hook."""
[docs]defcheck_convergence(self,step:int,time:float,current_score:float,target_score:float)->bool:"""Check if the model has converged. This default implementation checks if the current score is greater than or equal to the target score. The user can override this method to implement a different convergence criterion. Args: step: the current step counter time: the time of the simulation current_score: the current score target_score: the target score """_=(step,time)returncurrent_score>=target_score
def_clear_model(self)->Any:"""Clear the concrete forward model internals."""@classmethod
[docs]defname(cls)->str:"""Return a the model name."""return"BaseClassForwardModel"