from __future__ import annotations
import copy
import json
import logging
import shutil
import time
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from typing import Generic
from typing import TypeVar
from typing import cast
import numpy as np
import numpy.typing as npt
from pytams.diagdb import DiagDB
from pytams.diagnostic import DiagnosticPlugin
from pytams.diagnostic import diagnosticfactory
from pytams.snapshot import Snapshot
from pytams.xmlutils import dict_to_xml
from pytams.xmlutils import make_xml_snapshot
from pytams.xmlutils import read_xml_snapshot
if TYPE_CHECKING:
from pytams.fmodel import ForwardModelBaseClass
_logger = logging.getLogger(__name__)
[docs]
# Generic type parameters for the noise increments and model states
# carried by a Trajectory / Snapshot.
# Fix: removed stray "[docs]" Sphinx-extraction artifact that would raise
# a NameError at import time.
T_Noise = TypeVar("T_Noise")
T_State = TypeVar("T_State")
[docs]
class WallTimeLimitError(Exception):
    """Raised when a trajectory advance hits its wall time limit."""
[docs]
def get_index_from_id(identity: str) -> tuple[int, int]:
    """Helper to get trajectory index from ID string.

    The slicing assumes the ID ends with a zero-padded 5-character index
    field and a zero-padded 4-character branching field separated by a
    single character (e.g. "traj_00042_0007").

    Args:
        identity : trajectory ID

    Returns:
        trajectory index and number of branching
    """
    index_field = identity[-10:-5]
    branch_field = identity[-4:]
    return int(index_field), int(branch_field)
[docs]
class Trajectory(Generic[T_Noise, T_State]):
"""A class defining a stochastic trajectory.
The trajectory class is a container for time-ordered snapshots.
It contains an instance of the forward model, current and end times, and
a list of the model snapshots. Note that the class uses a plain list of snapshots
and not a more computationally efficient data structure such as a numpy array
for convenience. It is assumed that the computational cost of running TAMS
resides in the forward model and the overhead of the trajectory class is negligible.
It also provides the forward model with the necessary context to advance in time,
a method to move forward in time, methods to save/load the trajectory to/from disk
as well as accessors to the trajectory history (time, state, score, ...).
The _computed_steps variable stores the number of steps actually taken by the
trajectory. It differs from the _step variable when a trajectory is branched
from an ancestor.
Attributes:
_parameters_full : the full parameters dictionary
_tid : the trajectory index
_checkFile : the trajectory checkpoint file
_workdir : the model working directory
_score_max : the maximum score
_snaps : a list of snapshots
_step : the current step counter
_computed_steps : the number of steps explicitly advanced by the trajectory
_t_cur : the current time
_t_end : the end time
_dt : the stochastic time step size
"""
# Maximum number of Trajectory objects and history length
# These are used in assembling the unique ID for each trajectory
def __init__(
    self,
    traj_id: int,
    weight: float,
    fmodel_t: type[ForwardModelBaseClass[T_Noise, T_State]] | None,
    parameters: dict[Any, Any],
    workdir: Path | None = None,
    frozen: bool = False,
) -> None:
    """Create a trajectory.

    Args:
        traj_id: a int for the trajectory index
        weight: the trajectory weight in the ensemble
        fmodel_t: the forward model type
        parameters: a dictionary of input parameters
        workdir: an optional working directory
        frozen: whether the trajectory is frozen (no fmodel)

    Raises:
        ValueError: if 'end_time' or 'step_size' is missing from the
            'trajectory' section of the input parameters
    """
    # Stash away the full parameters dict
    self._parameters_full: dict[Any, Any] = parameters
    traj_params = parameters.get("trajectory", {})
    if "end_time" not in traj_params or "step_size" not in traj_params:
        err_msg = "Trajectory 'end_time' and 'step_size' must be specified in the input file !"
        _logger.error(err_msg)
        # Fix: attach the message to the exception (was a bare `raise ValueError`)
        raise ValueError(err_msg)
    # The workdir is a runtime parameter, not saved in the chkfile.
    self._tid: int = traj_id
    self._workdir: Path = Path.cwd() if workdir is None else workdir
    self._score_max: float = -1.0e12
    self._has_ended: bool = False
    self._has_converged: bool = False
    self._computed_steps: int = 0
    self._weight: float = weight
    # TAMS is expected to start at t = 0.0, but the forward model
    # itself can have a different internal starting point
    # or an entirely different time scale.
    self._step: int = 0
    self._t_cur: float = 0.0
    self._t_end: float = traj_params.get("end_time")
    self._dt: float = traj_params.get("step_size")
    # Trajectory convergence is defined by a target score, with
    # the score provided by the forward model, mapping the model state to
    # a s \in [0,1]. A default value of 0.95 is provided.
    self._convergedVal: float = traj_params.get("targetscore", 0.95)
    # List of snapshots
    self._snaps: list[Snapshot[T_Noise, T_State]] = []
    # When using sparse state or for other reasons, the noise for the next few
    # steps might be already available. This backlog is used to store them.
    self.noise_backlog: list[T_Noise] = []
    # Keep track of the branching history during TAMS iterations
    self._branching_history: list[int] = []
    # For large models, the state may not be available at each snapshot due
    # to memory constraint (both in-memory and on-disk). Sparse state can
    # be specified. Finally, writing a chkfile to disk at each step might
    # incur a performance hit and is by default disabled.
    self._sparse_state_int: int = traj_params.get("sparse_freq", 1)
    self._sparse_state_beg: int = traj_params.get("sparse_start", 0) + 1
    self._write_chkfile_all: bool = traj_params.get("chkfile_dump_all", False)
    self._checkFile: Path = Path(f"{self.idstr()}.xml")
    # Each trajectory has its own instance of the forward model
    if frozen or fmodel_t is None:
        self._fmodel = None
    else:
        # Consistency fix: reuse unique_id() instead of duplicating the
        # tid * (max_nbranch + 1) + nbranching expression inline.
        self._fmodel = fmodel_t(self.unique_id(), parameters, self._workdir)
    # Diagnostics
    # They are initialized upon the first call so that
    # the diagnostic database access is not pickled in the task
    self._has_diagnostics = self._check_for_diag_request()
    self._diagplugins: list[DiagnosticPlugin] = []
    self._ddb: DiagDB | None = None
    self._initialized_diags = False
[docs]
def set_checkfile(self, path: Path) -> None:
"""Setter of the trajectory checkFile.
Args:
path: the new checkFile
"""
self._checkFile = path
[docs]
def set_workdir(self, path: Path) -> None:
"""Setter of the trajectory working directory.
And propagate the workdir to the forward model.
Args:
path: the new working directory
"""
self._workdir = path
if self._fmodel is not None:
self._fmodel.set_workdir(path)
[docs]
def get_workdir(self) -> Path:
"""Get the trajectory working directory.
Returns:
the working directory
"""
return self._workdir
[docs]
def id(self) -> int:
"""Return trajectory index.
This is the index of the trajectory in the ensemble.
Returns:
the trajectory id
"""
return self._tid
[docs]
def unique_id(self) -> int:
"""Return trajectory unique Id.
Combining the index and the number of branching events.
This makes a unique ID.
Returns:
the trajectory unique id
"""
return self._tid * (self.max_nbranch + 1) + self.get_nbranching()
[docs]
def idstr(self) -> str:
    """Return the zero-padded string ID of this trajectory.

    Returns:
        the trajectory id as a string
    """
    nb = self.get_nbranching()
    return form_trajectory_id(self._tid, nb)
[docs]
def advance(self, t_end: float = 1.0e12, walltime: float = 1.0e12) -> None:
    """Advance the trajectory to a prescribed end time.

    This is the main time loop of the trajectory object.
    Unless specified otherwise, the trajectory will advance until
    the end time is reached or the model has converged.
    If the walltime limit is reached, a WallTimeLimitError exception is raised.
    Note that this exception is treated as a warning not an error by the
    TAMS workers.

    Args:
        t_end: the end time of the advance
        walltime: a walltime limit to advance the model to t_end

    Returns:
        None

    Raises:
        WallTimeLimitError: if the walltime limit is reached
        RuntimeError: if the model advance run into a problem
    """
    start_time = time.monotonic()
    remaining_time = walltime - time.monotonic() + start_time
    # Never step beyond the trajectory's own end time
    end_time = min(t_end, self._t_end)
    # A frozen trajectory has no forward model and cannot be advanced
    if not self._fmodel:
        err_msg = f"Trajectory {self.idstr()} is frozen, without forward model. Advance() deactivated."
        _logger.error(err_msg)
        raise RuntimeError(err_msg)
    # Main loop: stop on end time, convergence, or when less than 5% of
    # the walltime budget remains
    while self._t_cur < end_time and not self._has_converged and remaining_time >= 0.05 * walltime:
        # Do a single step and keep track of remaining walltime
        _ = self._one_step()
        # Initialize diagnostic now, access to diagdb
        # no longer needs to be pickled at this point
        if not self._initialized_diags and self._has_diagnostics:
            self._setup_diagnostics()
        # Perform any diagnostic requested
        for plugin in self._diagplugins:
            plugin.update(self._snaps[-2], self._snaps[-1])
        remaining_time = walltime - time.monotonic() + start_time
    if self._t_cur >= self._t_end or self._has_converged:
        self._has_ended = True
    if self._has_ended:
        # Let the model release its resources
        self._fmodel.clear()
        # Clear the diagnostic
        if self._ddb is not None and self._initialized_diags:
            self._ddb.close()
            self._ddb = None
            self._diagplugins = []
            self._initialized_diags = False
    # Loop exited because of the walltime budget: signal the caller
    if remaining_time < 0.05 * walltime:
        warn_msg = f"{self.idstr()} ran out of time in advance()"
        _logger.warning(warn_msg)
        raise WallTimeLimitError(warn_msg)
def _one_step(self) -> float:
    """Perform a single step of the forward model.

    Perform a single time step of the forward model. This
    function will also set the noise to use for the next step
    in the forward model if a backlog is available.

    Returns:
        the score after the step

    Raises:
        RuntimeError: if the trajectory is frozen (no forward model)
        Exception: whatever the forward model advance raises
    """
    if not self._fmodel:
        err_msg = f"Trajectory {self.idstr()} is frozen, without forward model. Advance() deactivated."
        # Fix: not inside an exception handler -> use error() rather
        # than exception() (which would log a bogus traceback)
        _logger.error(err_msg)
        raise RuntimeError(err_msg)
    # Add the initial snapshot to the list
    if self._step == 0:
        self.setup_noise()
        self._append_snapshot()
    # Trigger storing the end state of the current time step
    # if the next trajectory snapshot needs it
    need_end_state = (self._sparse_state_beg + self._step + 1) % self._sparse_state_int == 0
    try:
        dt = self._fmodel.advance(self._dt, need_end_state)
    except Exception:
        err_msg = f"ForwardModel advance error at step {self._step:08}"
        _logger.exception(err_msg)
        raise
    self._step += 1
    self._t_cur = self._t_cur + dt
    score = self._fmodel.score()
    # Prepare the noise for the next step
    self.setup_noise()
    # Append a snapshot at the beginning of the time step
    self._append_snapshot(score)
    if self._write_chkfile_all:
        self.store()
    self._score_max = max(self._score_max, score)
    # The default ABC method simply check for a score above
    # the target value, but concrete implementations can override
    # with more complex convergence criteria
    self._has_converged = self._fmodel.check_convergence(self._step, self._t_cur, score, self._convergedVal)
    # Increment the computed step counter
    self._computed_steps += 1
    return score
def _check_for_diag_request(self) -> bool:
"""Check if any diagnostics are requested in the parameters.
Returns:
A boolean indicating if diagnostics are requested
"""
diag_list = self._parameters_full.get("tams", {}).get("diagnostics", [])
# Return True only if the list exists and has at least one entry
return isinstance(diag_list, list) and len(diag_list) > 0
def _setup_diagnostics(self) -> None:
    """Open the diagnostic database and instantiate the plugins.

    Deferred to the first advance step so that the database handle is
    never pickled along with the trajectory task.
    """
    if self._fmodel is not None:
        # Database lives next to the cwd, or two levels above the
        # per-trajectory working directory
        db_file = (
            "./diagDB.db"
            if self._workdir == Path.cwd()
            else (self._workdir.parents[1] / "./diagDB.db").absolute().as_posix()
        )
        self._ddb = DiagDB(db_file)
        self._diagplugins = diagnosticfactory(
            self._parameters_full,
            self.unique_id(),
            self._weight,
            self._workdir,
            self._fmodel.diagnostic_hook,
            self._ddb,
        )
        self._initialized_diags = True
def _branch_diagnostics(
    self, ancestor_id: int, discarded_id: int, child_id: int, child_weight: float, score_threshold: float
) -> None:
    """Duplicate diagnostics entries while branching.

    Args:
        ancestor_id: the ID of the ancestor traj to duplicate
        discarded_id: the ID of the discarded traj
        child_id: the ID of the child traj
        child_weight: the weight of the child traj
        score_threshold: the score threshold up to which duplication is needed
    """
    # Same database location logic as _setup_diagnostics
    db_file = (
        "./diagDB.db"
        if self._workdir == Path.cwd()
        else (self._workdir.parents[1] / "./diagDB.db").absolute().as_posix()
    )
    ddb = DiagDB(db_file)
    ddb.duplicate_diagnostic_history(ancestor_id, discarded_id, child_id, child_weight, score_threshold)
    ddb.close()
[docs]
def setup_noise(self) -> None:
"""Prepare the noise for the next step."""
# Set the noise for the next model step
# if a noise backlog is available, use it otherwise
# make a new noise increment
if self._fmodel:
if self.noise_backlog:
self._fmodel.noise = self.noise_backlog.pop()
else:
self._fmodel.noise = self._fmodel.make_noise()
def _append_snapshot(self, score: float | None = None) -> None:
    """Append the current model data as a snapshot.

    Args:
        score: the snapshot score; when None, it is queried from the
            forward model
    """
    # No-op for frozen trajectories (no forward model)
    if self._fmodel:
        # States are stored sparsely; the initial snapshot always keeps one
        need_state = (self._sparse_state_beg + self._step) % self._sparse_state_int == 0 or self._step == 0
        self._snaps.append(
            Snapshot[T_Noise, T_State](
                time=self._t_cur,
                # Fix: test against None — a legitimate score of 0.0 is
                # falsy and was previously recomputed/overwritten
                score=score if score is not None else self._fmodel.score(),
                noise=self._fmodel.noise,
                state=self._fmodel.get_current_state() if need_state else None,
            ),
        )
@classmethod
def restore_from_checkfile(
    cls,
    checkfile: Path,
    metadata: dict[str, Any],
    fmodel_t: type[ForwardModelBaseClass[T_Noise, T_State]],
    parameters: dict[Any, Any],
    workdir: Path | None = None,
    frozen: bool = False,
) -> Trajectory[T_Noise, T_State]:
    """Return a trajectory restored from an XML chkfile.

    Args:
        checkfile: the XML checkpoint file to read
        metadata: the trajectory metadata dictionary
        fmodel_t: the forward model type
        parameters: a dictionary of input parameters
        workdir: an optional working directory
        frozen: whether the trajectory is frozen (no fmodel)

    Raises:
        FileNotFoundError: if the checkpoint file does not exist
    """
    # Fix: removed a duplicated @classmethod decorator
    if not checkfile.exists():
        err_msg = f"Trajectory {checkfile} does not exist."
        # Fix: not inside an exception handler -> error(), and attach the
        # message to the exception (was a bare `raise FileNotFoundError`)
        _logger.error(err_msg)
        raise FileNotFoundError(err_msg)
    rest_traj: Trajectory[T_Noise, T_State] = Trajectory.init_from_metadata(
        metadata, fmodel_t, parameters, workdir, frozen
    )
    # Read in trajectory data
    tree = ET.parse(checkfile.absolute())
    root = tree.getroot()
    snapshots = root.find("snapshots")
    if snapshots is not None:
        for snap in snapshots:
            # Renamed local (was `time`, shadowing the time module)
            snap_time, score, noise, state = read_xml_snapshot(snap)
            rest_traj._snaps.append(
                Snapshot[T_Noise, T_State](time=snap_time, score=score, noise=noise, state=state)
            )
    # If the trajectory is frozen, that is all we need. Otherwise
    # handle sparse state, noise backlog and necessary fmodel initialization
    if rest_traj._fmodel:
        # Remove snapshots from the list until a state is available
        for k in range(len(rest_traj._snaps) - 1, -1, -1):
            if not rest_traj._snaps[k].has_state:
                # Append the noise history to the backlog
                rest_traj.noise_backlog.append(rest_traj._snaps[k].noise)
                rest_traj._snaps.pop()
            else:
                # Because the noise in the snapshot is the noise
                # used to reach the next state, append the last to the backlog too
                rest_traj.noise_backlog.append(rest_traj._snaps[k].noise)
                break
        # Current step with python indexing, so remove 1
        rest_traj.set_current_time_and_step(rest_traj._snaps[-1].time, len(rest_traj._snaps) - 1)
        # Ensure everything is set to start the time stepping loop
        rest_traj.setup_noise()
        # mypy is not that bright, need explicit check here
        if rest_traj._snaps[-1].state is not None:
            rest_traj._fmodel.set_current_state(rest_traj._snaps[-1].state)
        # Enable the model to perform tweaks after a trajectory restore
        rest_traj._fmodel.post_trajectory_restore_hook(len(rest_traj._snaps) - 1, rest_traj.current_time())
    return rest_traj
@classmethod
def branch_from_trajectory(
    cls,
    from_traj: Trajectory[T_Noise, T_State],
    rst_traj: Trajectory[T_Noise, T_State],
    score: float,
    new_weight: float,
) -> Trajectory:
    """Branch a child trajectory off an ancestor.

    The child reuses the beginning of the ancestor's history, up to the
    first snapshot with a score strictly above the threshold. Although
    rst_traj is provided as an argument, it is only used to set the
    metadata of the branched trajectory.

    Args:
        from_traj: an already existing trajectory to restart from
        rst_traj: the trajectory being restarted
        score: a threshold score
        new_weight: the weight of the child trajectory
    """
    # Initialize the new trajectory object (handles path and ID logic)
    child = cls._init_branched_trajectory(from_traj, rst_traj, new_weight)
    # Nothing to copy from an empty ancestor
    if not from_traj._snaps:
        return child
    # Locate the branching points
    split_idx, state_idx = cls._find_branch_indices(from_traj, score)
    # Transfer the data from the ancestor to the child
    cls._transfer_data(from_traj, child, split_idx, state_idx)
    # Finalize branching
    cls._finalize_branch(from_traj.unique_id(), rst_traj.unique_id(), child, new_weight)
    return child
@classmethod
def _init_branched_trajectory(
    cls, from_traj: Trajectory[T_Noise, T_State], rst_traj: Trajectory[T_Noise, T_State], weight: float
) -> Trajectory[T_Noise, T_State]:
    """Initialize a new trajectory for branching.

    Builds an empty trajectory reusing the restarted trajectory's index
    with an incremented branching counter, and sets its working
    directory, checkpoint file and branching history.

    Args:
        from_traj: an already existing trajectory to restart from
        rst_traj: the trajectory being restarted
        weight: the weight of the child trajectory

    Returns:
        the freshly initialized (still empty) child trajectory
    """
    # New ID: same trajectory index, branching counter + 1
    tid, nb = get_index_from_id(rst_traj.idstr())
    new_name = form_trajectory_id(tid, nb + 1)
    # Workdir/chkfile siblings of the restarted trajectory's ones
    new_workdir = Path(rst_traj.get_workdir().parents[0] / new_name)
    fmodel_t = type(from_traj._fmodel) if from_traj._fmodel else None
    new_traj: Trajectory[T_Noise, T_State] = Trajectory(
        traj_id=rst_traj.id(),
        weight=weight,
        fmodel_t=fmodel_t,
        parameters=from_traj._parameters_full,
        workdir=new_workdir,
    )
    xml_name = f"{new_traj.idstr()}.xml"
    new_traj.set_checkfile(Path(rst_traj.get_checkfile().parents[0] / xml_name))
    # Copy history if not empty
    if from_traj._snaps:
        new_traj._branching_history = copy.deepcopy(rst_traj._branching_history)
        new_traj._branching_history.append(from_traj.id())
    return new_traj
@staticmethod
def _find_branch_indices(from_traj: Trajectory[T_Noise, T_State], score_threshold: float) -> tuple[int, int]:
"""Finds the split index and the last index containing a state.
To ensure that TAMS converges, branching occurs on
the first snapshot with a score *strictly* above the target.
Args:
from_traj: the ancestor trajectory
score_threshold: the score threshold up to which we must copy
Returns:
a tuple with the index of the first snapshot above threshold and
the index of the last snapshot with a state (equal or below high_score_idx).
"""
high_score_idx = 0
last_state_idx = 0
for i, snap in enumerate(from_traj._snaps):
if snap.has_state:
last_state_idx = i
if snap.score > score_threshold:
high_score_idx = i
break
return high_score_idx, last_state_idx
@staticmethod
def _transfer_data(
    from_traj: Trajectory[T_Noise, T_State], rest_traj: Trajectory[T_Noise, T_State], high_idx: int, state_idx: int
) -> None:
    """Transfer data from the ancestor to the child.

    Snapshots up to (and including) the last one holding a state are
    copied verbatim; the noise of the snapshots between that state and
    the branching point is queued into the child's noise backlog.

    Args:
        from_traj: the ancestor trajectory
        rest_traj: the child trajectory
        high_idx: the highest index of the snapshot to transfer
        state_idx: the highest index of the snapshot with noise to transfer
    """
    # Prepend existing backlog if states align
    if state_idx == from_traj.get_last_state_id() and from_traj.noise_backlog:
        rest_traj.noise_backlog.extend(reversed(from_traj.noise_backlog))
    # Separate snapshots from noise backlog
    # Everything up to the last state is a full snapshot
    rest_traj._snaps = from_traj._snaps[: state_idx + 1]
    # Everything between last state and high_score_idx becomes noise backlog
    for k in range(state_idx, high_idx + 1):
        rest_traj.noise_backlog.append(from_traj._snaps[k].noise)
    # Final reverse so that pop() in setup_noise() consumes the noises in
    # chronological order (oldest noise ends up at the list tail)
    rest_traj.noise_backlog.reverse()
@staticmethod
def _finalize_branch(
    ancestor_id: int, discarded_id: int, rest_traj: Trajectory[T_Noise, T_State], weight: float
) -> None:
    """Finalize metadata, model state, and diagnostics of a branched trajectory.

    Args:
        ancestor_id: the unique ID of the ancestor traj
        discarded_id: the unique ID of the discarded traj
        rest_traj: the child trajectory
        weight: the weight of the new child trajectory
    """
    # Resume from the last transferred snapshot
    rest_traj._t_cur = rest_traj._snaps[-1].time
    rest_traj._step = len(rest_traj._snaps) - 1
    if rest_traj._fmodel:
        # Stage the next noise (from the transferred backlog when available)
        rest_traj.setup_noise()
        last_snap = rest_traj._snaps[-1]
        if last_snap.state is not None:
            rest_traj._fmodel.set_current_state(last_snap.state)
        rest_traj.update_metadata()
        # Let the model adjust itself after branching
        rest_traj._fmodel.post_trajectory_branching_hook(rest_traj._step, rest_traj._t_cur)
    if rest_traj._has_diagnostics:
        rest_traj._branch_diagnostics(
            ancestor_id, discarded_id, rest_traj.unique_id(), weight, rest_traj._score_max
        )
[docs]
def store(self, traj_file: Path | None = None, write_metadata_json: bool = False) -> None:
    """Write the trajectory data to an XML chkfile.

    By default the trajectory is written to the file held in the
    self._checkFile attribute, unless an explicit path is given.
    The metadata are not written to file by default, as the TAMS database
    handles metadata in an SQL file. It can be triggered when using
    trajectories stand-alone.

    Args:
        traj_file: an optional path file to store the trajectory to
        write_metadata_json: an optional boolean to also write the metadata json dict
    """
    # Assemble the XML tree: parameters then the snapshot list
    root = ET.Element(self.idstr())
    root.append(dict_to_xml("params", self._parameters_full["trajectory"]))
    snaps_xml = ET.SubElement(root, "snapshots")
    for idx, snap in enumerate(self._snaps):
        snaps_xml.append(make_xml_snapshot(idx, snap.time, snap.score, snap.noise, snap.state))
    tree = ET.ElementTree(root)
    ET.indent(tree, space="\t", level=0)
    target = self._checkFile if traj_file is None else traj_file
    tree.write(target.as_posix())
    # Separately dump the metadata if requested
    if write_metadata_json:
        json_path = target.parent / Path(target.stem + ".json")
        with json_path.open("w") as f:
            f.write(self.serialize_metadata_json())
[docs]
def set_weight(self, weight: float) -> None:
"""Set the trajectory weight.
Args:
weight: the (new) trajectory weight in the ensemble
"""
self._weight = weight
[docs]
def set_current_time_and_step(self, time: float, step: int) -> None:
"""Set the current time and step."""
self._t_cur = time
self._step = step
[docs]
def current_time(self) -> float:
"""Return the current trajectory time."""
return self._t_cur
[docs]
def step_size(self) -> float:
"""Return the time step size."""
return self._dt
[docs]
def score_max(self) -> float:
"""Return the maximum of the score function."""
return self._score_max
[docs]
def is_converged(self) -> bool:
"""Return True for converged trajectory."""
return self._has_converged
[docs]
def has_ended(self) -> bool:
"""Return True for terminated trajectory."""
return self._has_ended
[docs]
def has_started(self) -> bool:
"""Return True if computation has started."""
return self._t_cur > 0.0
[docs]
def get_checkfile(self) -> Path:
"""Return the trajectory check file name."""
return self._checkFile
[docs]
def get_time_array(self) -> npt.NDArray[np.number]:
"""Return the trajectory time instants."""
return np.array([snap.time for snap in self._snaps], dtype=np.float64)
[docs]
def get_score_array(self) -> npt.NDArray[np.number]:
"""Return the trajectory scores."""
return np.array([snap.score for snap in self._snaps], dtype=np.float64)
[docs]
def get_noise_array(self) -> npt.NDArray[Any]:
"""Return the trajectory noises."""
if not self._snaps:
return np.array([])
return np.array([snap.noise for snap in self._snaps])
[docs]
def get_state_list(self) -> list[tuple[int, T_State | None]]:
"""Return a list of states and associated indices.
Returns:
A list of tuples with index and states
"""
return [(k, self._snaps[k].state) for k in range(len(self._snaps)) if self._snaps[k].has_state]
[docs]
def get_length(self) -> int:
"""Return the trajectory length."""
return len(self._snaps)
[docs]
def get_nbranching(self) -> int:
"""Return the number of branching events."""
if len(self._branching_history) > self.max_nbranch:
err_msg = f"Branching history size exceeds maximum of {self.max_nbranch}!"
_logger.exception(err_msg)
raise RuntimeError(err_msg)
return len(self._branching_history)
[docs]
def get_computed_steps_count(self) -> int:
"""Return the number of compute steps taken."""
return self._computed_steps
[docs]
def get_last_state(self) -> T_State | None:
"""Return the last state in the trajectory."""
for snap in reversed(self._snaps):
if snap.has_state:
return snap.state
return None
[docs]
def get_last_state_id(self) -> int | None:
"""Return the id of the last state in the trajectory."""
for idx, snap in reversed(list(enumerate(self._snaps))):
if snap.has_state:
return idx
return None
@classmethod
[docs]
def delete(self) -> None:
"""Clear the trajectory on-disk data."""
self._checkFile.unlink()
if self._workdir.exists():
shutil.rmtree(self._workdir)