from __future__ import annotations
import logging
import time
import xml.etree.ElementTree as ET
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
import numpy as np
import numpy.typing as npt
from pytams.xmlutils import dict_to_xml
from pytams.xmlutils import make_xml_snapshot
from pytams.xmlutils import new_element
from pytams.xmlutils import read_xml_snapshot
from pytams.xmlutils import xml_to_dict
if TYPE_CHECKING:
from pytams.fmodel import ForwardModelBaseClass
_logger = logging.getLogger(__name__)
class WallTimeLimitError(Exception):
"""Exception for running into wall time limit."""
def get_index_from_id(identity: str) -> tuple[int, int]:
"""Helper to get trajectory index from ID string.
Args:
identity : trajectory ID
Returns:
the trajectory index and the branching count
"""
return int(identity[-10:-5]), int(identity[-4:])
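# Illustrative example (not part of the pytams API), assuming an ID string of the
# form "traj00001_0002" (zero-padded index, underscore, zero-padded branching count);
# the exact format is produced by form_trajectory_id:
#   get_index_from_id("traj00001_0002")  # -> (1, 2)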
@dataclass
class Snapshot:
"""A dataclass defining a snapshot.
Gathering what defines a snapshot into an object.
The time and score are of float type, but the
actual types of the noise and state are entirely
determined by the forward model.
A snapshot may or may not carry a state, in order to
accommodate memory savings.
Attributes:
time : snapshot time
score : score function value
noise : noise to be used to advance from this snapshot to the next
state : model state
"""
time: float
score: float
noise: Any
state: Any | None = None
def has_state(self) -> bool:
"""Check if snapshot has state.
Returns:
bool : True if state is not None
"""
return self.state is not None
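# Illustrative sketch of building snapshots by hand (the noise and state types are
# model-dependent; plain floats are used here purely as placeholders):
#   full = Snapshot(time=0.0, score=0.1, noise=0.02, state=1.5)
#   lean = Snapshot(time=0.1, score=0.2, noise=0.01)  # state omitted to save memory
#   full.has_state()  # -> True
#   lean.has_state()  # -> False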
class Trajectory:
"""A class defining a stochastic trajectory.
The trajectory class is a container for time-ordered snapshots.
It contains an instance of the forward model, current and end times, and
a list of the model snapshots. Note that, for convenience, the class uses a plain
list of snapshots rather than a more computationally efficient data structure such as
a numpy array. It is assumed that the computational cost of running TAMS resides in
the forward model and that the overhead of the trajectory class is negligible.
It also provides the forward model with the necessary context to advance in time,
methods to move forward in time, methods to save/load the trajectory to/from disk,
as well as accessors to the trajectory history (time, state, score, ...).
Attributes:
_parameters_full : the full parameters dictionary
_tid : the trajectory index
_checkFile : the trajectory checkpoint file
_workdir : the model working directory
_score_max : the maximum score
_snaps : a list of snapshots
_step : the current step counter
_t_cur : the current time
_t_end : the end time
_dt : the stochastic time step size
"""
def __init__(
self,
traj_id: int,
fmodel_t: type[ForwardModelBaseClass] | None,
parameters: dict[Any, Any],
workdir: Path | None = None,
frozen: bool = False,
) -> None:
"""Create a trajectory.
Args:
traj_id: an int for the trajectory index
fmodel_t: the forward model type
parameters: a dictionary of input parameters
workdir: an optional working directory
frozen: whether the trajectory is frozen (no fmodel)
"""
# Stash away the full parameters dict
self._parameters_full: dict[Any, Any] = parameters
traj_params = parameters.get("trajectory", {})
if "end_time" not in traj_params or "step_size" not in traj_params:
err_msg = "Trajectory 'end_time' and 'step_size' must be specified in the input file !"
_logger.error(err_msg)
raise ValueError(err_msg)
# The workdir is a runtime parameter, not saved in the chkfile.
self._tid: int = traj_id
self._workdir: Path = Path.cwd() if workdir is None else workdir
self._score_max: float = 0.0
self._has_ended: bool = False
self._has_converged: bool = False
# TAMS is expected to start at t = 0.0, but the forward model
# itself can have a different internal starting point
# or an entirely different time scale.
self._step: int = 0
self._t_cur: float = 0.0
self._t_end: float = traj_params.get("end_time")
self._dt: float = traj_params.get("step_size")
# Trajectory convergence is defined by a target score, with
# the score provided by the forward model mapping the model state to
# a value s in [0, 1]. A default value of 0.95 is provided.
self._convergedVal: float = traj_params.get("targetscore", 0.95)
# List of snapshots
self._snaps: list[Snapshot] = []
# When using sparse state or for other reasons, the noise for the next few
# steps might be already available. This backlog is used to store them.
self._noise_backlog: list[Any] = []
# Keep track of the branching history during TAMS
# iterations
self._branching_history: list[int] = []
# For large models, the state may not be available at each snapshot due
# to memory constraints (both in-memory and on-disk). A sparse state storage
# frequency can be specified. Finally, writing a chkfile to disk at each step
# might incur a performance hit and is disabled by default.
self._sparse_state_int: int = traj_params.get("sparse_freq", 1)
self._sparse_state_beg: int = traj_params.get("sparse_start", 0) + 1
self._write_chkfile_all: bool = traj_params.get("chkfile_dump_all", False)
self._checkFile: Path = Path(f"{self.idstr()}.xml")
# Each trajectory has its own instance of the forward model
if frozen or fmodel_t is None:
self._fmodel = None
else:
self._fmodel = fmodel_t(parameters, self.idstr(), self._workdir)
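# Illustrative sketch of the expected input dictionary. Only 'end_time' and
# 'step_size' are mandatory; the other keys are optional and are shown with their
# default values. "MyModel" is a hypothetical ForwardModelBaseClass subclass.
#   params = {
#       "trajectory": {
#           "end_time": 10.0,          # mandatory
#           "step_size": 0.01,         # mandatory
#           "targetscore": 0.95,       # convergence target
#           "sparse_freq": 1,          # state storage frequency
#           "sparse_start": 0,         # offset for sparse state storage
#           "chkfile_dump_all": False, # write a chkfile at every step
#       }
#   }
#   traj = Trajectory(traj_id=0, fmodel_t=MyModel, parameters=params)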
def set_checkfile(self, path: Path) -> None:
"""Setter of the trajectory checkFile.
Args:
path: the new checkFile
"""
self._checkFile = path
def set_workdir(self, path: Path) -> None:
"""Setter of the trajectory working directory.
The working directory is also propagated to the forward model.
Args:
path: the new working directory
"""
self._workdir = path
if self._fmodel is not None:
self._fmodel.set_workdir(path)
def get_workdir(self) -> Path:
"""Get the trajectory working directory.
Returns:
the working directory
"""
return self._workdir
def id(self) -> int:
"""Return trajectory Id.
Returns:
the trajectory id
"""
return self._tid
def idstr(self) -> str:
"""Return trajectory Id as a padded string.
Returns:
the trajectory id as a string
"""
return form_trajectory_id(self._tid, self.get_nbranching())
def advance(self, t_end: float = 1.0e12, walltime: float = 1.0e12) -> None:
"""Advance the trajectory to a prescribed end time.
This is the main time loop of the trajectory object.
Unless specified otherwise, the trajectory will advance until
the end time is reached or the model has converged.
If the walltime limit is reached, a WallTimeLimitError exception is raised.
Note that this exception is treated as a warning, not an error, by the
TAMS workers.
Args:
t_end: the end time of the advance
walltime: a walltime limit to advance the model to t_end
Returns:
None
Raises:
WallTimeLimitError: if the walltime limit is reached
"""
start_time = time.monotonic()
remaining_time = walltime - time.monotonic() + start_time
end_time = min(t_end, self._t_end)
if not self._fmodel:
err_msg = f"Trajectory {self.idstr()} is frozen, without forward model. Advance() deactivated."
_logger.error(err_msg)
raise RuntimeError(err_msg)
while self._t_cur < end_time and not self._has_converged and remaining_time >= 0.05 * walltime:
# Do a single step and keep track of remaining walltime
_ = self._one_step()
remaining_time = walltime - time.monotonic() + start_time
if self._t_cur >= self._t_end or self._has_converged:
self._has_ended = True
if self._has_ended:
self._fmodel.clear()
if remaining_time < 0.05 * walltime:
warn_msg = f"{self.idstr()} ran out of time in advance()"
_logger.warning(warn_msg)
raise WallTimeLimitError(warn_msg)
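# Illustrative usage sketch ("traj" is a Trajectory built as in __init__ above):
#   try:
#       traj.advance(walltime=3600.0)   # advance until end_time or convergence
#   except WallTimeLimitError:
#       traj.store()                    # checkpoint, then resume in a later run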
def _one_step(self) -> float:
"""Perform a single step of the forward model.
Advance the forward model by a single time step. This
function also sets the noise to use for the next step
in the forward model if a backlog is available.
Returns:
the score function value after the step
"""
if not self._fmodel:
err_msg = f"Trajectory {self.idstr()} is frozen, without forward model. Advance() deactivated."
_logger.error(err_msg)
raise RuntimeError(err_msg)
# Add the initial snapshot to the list
if self._step == 0:
self._setup_noise()
self._append_snapshot()
# Trigger storing the end state of the current time step
# if the next trajectory snapshot needs it
need_end_state = (self._sparse_state_beg + self._step + 1) % self._sparse_state_int == 0
try:
dt = self._fmodel.advance(self._dt, need_end_state)
except Exception:
err_msg = f"ForwardModel advance error at step {self._step:08}"
_logger.exception(err_msg)
raise
self._step += 1
self._t_cur = self._t_cur + dt
score = self._fmodel.score()
# Prepare the noise for the next step
self._setup_noise()
# Append a snapshot at the beginning of the next time step
self._append_snapshot(score)
if self._write_chkfile_all:
self.store()
self._score_max = max(self._score_max, score)
# The default ABC method simply checks for a score above
# the target value, but concrete implementations can override it
# with more complex convergence criteria
self._has_converged = self._fmodel.check_convergence(self._step, self._t_cur, score, self._convergedVal)
return score
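# Worked example of the sparse-state bookkeeping above: with sparse_freq=5 and
# sparse_start=0, _sparse_state_beg is 1, so (1 + step) % 5 == 0 holds for
# steps 4, 9, 14, ... Snapshots at those steps (plus the initial step 0) carry
# a full model state; the others only store time, score and noise.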
def _setup_noise(self) -> None:
"""Prepare the noise for the next step."""
# Set the noise for the next model step
# if a noise backlog is available, use it otherwise
# make a new noise increment
if self._fmodel:
if self._noise_backlog:
self._fmodel.set_noise(self._noise_backlog.pop())
else:
self._fmodel.set_noise(self._fmodel.make_noise())
def _append_snapshot(self, score: float | None = None) -> None:
"""Append the current snapshot to the trajectory list."""
# Append the current snapshot to the trajectory list
if self._fmodel:
need_state = (self._sparse_state_beg + self._step) % self._sparse_state_int == 0 or self._step == 0
self._snaps.append(
Snapshot(
time=self._t_cur,
score=score if score is not None else self._fmodel.score(),
noise=self._fmodel.get_noise(),
state=self._fmodel.get_current_state() if need_state else None,
),
)
@classmethod
def restore_from_checkfile(
cls,
checkfile: Path,
fmodel_t: type[ForwardModelBaseClass],
parameters: dict[Any, Any],
workdir: Path | None = None,
frozen: bool = False,
) -> Trajectory:
"""Return a trajectory restored from an XML chkfile."""
if not checkfile.exists():
err_msg = f"Trajectory {checkfile} does not exist."
_logger.error(err_msg)
raise FileNotFoundError(err_msg)
# Read in trajectory metadata
tree = ET.parse(checkfile.absolute())
root = tree.getroot()
metadata = xml_to_dict(root.find("metadata"))
t_id = metadata["id"]
rest_traj = Trajectory(
traj_id=t_id,
fmodel_t=fmodel_t,
parameters=parameters,
workdir=workdir,
frozen=frozen,
)
rest_traj._t_end = metadata["t_end"]
rest_traj._t_cur = metadata["t_cur"]
rest_traj._dt = metadata["dt"]
rest_traj._score_max = metadata["score_max"]
rest_traj._has_ended = metadata["ended"]
rest_traj._has_converged = metadata["converged"]
rest_traj._branching_history = metadata["branching_history"].tolist()
snapshots = root.find("snapshots")
if snapshots is not None:
for snap in snapshots:
time, score, noise, state = read_xml_snapshot(snap)
rest_traj._snaps.append(Snapshot(time, score, noise, state))
# Current step uses Python indexing, so subtract 1
rest_traj._step = len(rest_traj._snaps) - 1
# If the trajectory is frozen, that is all we need. Otherwise
# handle sparse state, noise backlog and necessary fmodel initialization
if rest_traj._fmodel:
# Remove snapshots from the list until a state is available
need_update = False
for k in range(len(rest_traj._snaps) - 1, -1, -1):
if not rest_traj._snaps[k].has_state():
# Append the noise history to the backlog
rest_traj._noise_backlog.append(rest_traj._snaps[k].noise)
rest_traj._snaps.pop()
need_update = True
else:
# Because the noise in the snapshot is the noise
# used to reach the next state, append the last to the backlog too
rest_traj._noise_backlog.append(rest_traj._snaps[k].noise)
break
rest_traj._t_cur = rest_traj._snaps[-1].time
# Current step uses Python indexing, so subtract 1
rest_traj._step = len(rest_traj._snaps) - 1
# Ensure everything is set to start the time stepping loop
rest_traj._setup_noise()
rest_traj._fmodel.set_current_state(rest_traj._snaps[-1].state)
# Reset score_max, ended and converged
if need_update:
rest_traj.update_metadata()
# Enable the model to perform tweaks
# after a trajectory restore
rest_traj._fmodel.post_trajectory_restore_hook(len(rest_traj._snaps) - 1, rest_traj._t_cur)
return rest_traj
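# Illustrative usage sketch ("MyModel", "params" and the chkfile name are
# placeholders, as in the __init__ example above):
#   traj = Trajectory.restore_from_checkfile(
#       Path("traj00001_0000.xml"), MyModel, params, workdir=Path("run_dir"),
#   )
#   traj.advance()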
@classmethod
def branch_from_trajectory(
cls,
from_traj: Trajectory,
rst_traj: Trajectory,
score: float,
) -> Trajectory:
"""Create a new trajectory.
The beginning of the provided from_traj is loaded, up to and including
the first snapshot whose score is strictly above the given score.
This effectively branches the trajectory.
Although the rst_traj is provided as an argument, it is
only used to set metadata of the branched trajectory.
Args:
from_traj: an already existing trajectory to restart from
rst_traj: the trajectory being restarted
score: a threshold score
Returns:
the branched trajectory
"""
# Check for empty trajectory
if len(from_traj._snaps) == 0:
tid, nb = get_index_from_id(rst_traj.idstr())
new_workdir = Path(rst_traj.get_workdir().parents[0] / form_trajectory_id(tid, nb + 1))
fmodel_t = type(from_traj._fmodel) if from_traj._fmodel else None
rest_traj = Trajectory(
traj_id=rst_traj.id(),
fmodel_t=fmodel_t,
parameters=from_traj._parameters_full,
workdir=new_workdir,
)
rest_traj.set_checkfile(Path(rst_traj.get_checkfile().parents[0] / f"{rest_traj.idstr()}.xml"))
return rest_traj
# To ensure that TAMS converges, branching occurs on
# the first snapshot with a score *strictly* above the threshold.
# Traverse the trajectory until a snapshot with a score
# strictly above the threshold is encountered
high_score_idx = 0
last_snap_with_state = 0
while from_traj._snaps[high_score_idx].score <= score:
high_score_idx += 1
if from_traj._snaps[high_score_idx].has_state():
last_snap_with_state = high_score_idx
# Init empty trajectory
tid, nb = get_index_from_id(rst_traj.idstr())
new_workdir = Path(rst_traj.get_workdir().parents[0] / form_trajectory_id(tid, nb + 1))
fmodel_t = type(from_traj._fmodel) if from_traj._fmodel else None
rest_traj = Trajectory(
traj_id=rst_traj.id(),
fmodel_t=fmodel_t,
parameters=from_traj._parameters_full,
workdir=new_workdir,
)
# Copy (rather than alias) the history so the restarted trajectory is not mutated
rest_traj._branching_history = list(rst_traj._branching_history)
rest_traj._branching_history.append(from_traj.id())
rest_traj.set_checkfile(Path(rst_traj.get_checkfile().parents[0] / f"{rest_traj.idstr()}.xml"))
# Append snapshots up to and including high_score_idx to
# preserve the strictly-above branching behavior
for k in range(high_score_idx + 1):
if k < last_snap_with_state:
rest_traj._snaps.append(from_traj._snaps[k])
elif k == last_snap_with_state:
rest_traj._snaps.append(from_traj._snaps[k])
rest_traj._noise_backlog.append(from_traj._snaps[k].noise)
else:
rest_traj._noise_backlog.append(from_traj._snaps[k].noise)
# Reverse the backlog to ensure correct order
rest_traj._noise_backlog.reverse()
# Update trajectory metadata
rest_traj._t_cur = rest_traj._snaps[-1].time
rest_traj._step = len(rest_traj._snaps) - 1
if rest_traj._fmodel:
rest_traj._setup_noise()
rest_traj._fmodel.set_current_state(rest_traj._snaps[-1].state)
rest_traj.update_metadata()
# Enable the model to perform tweaks
# after a trajectory branching
rest_traj._fmodel.post_trajectory_branching_hook(len(rest_traj._snaps) - 1, rest_traj._t_cur)
return rest_traj
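# Illustrative sketch of a branching step as a TAMS iteration might perform it
# ("donor" and "discarded" are placeholder Trajectory instances, with "donor"
# reaching a higher score than the trajectory being replaced):
#   new_traj = Trajectory.branch_from_trajectory(donor, discarded, discarded.score_max())
#   new_traj.advance()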
def store(self, traj_file: Path | None = None) -> None:
"""Store the trajectory to an XML chkfile."""
root = ET.Element(self.idstr())
root.append(dict_to_xml("params", self._parameters_full["trajectory"]))
mdata = ET.SubElement(root, "metadata")
mdata.append(new_element("id", self._tid))
mdata.append(new_element("t_cur", self._t_cur))
mdata.append(new_element("t_end", self._t_end))
mdata.append(new_element("dt", self._dt))
mdata.append(new_element("score_max", self._score_max))
mdata.append(new_element("ended", self._has_ended))
mdata.append(new_element("converged", self._has_converged))
mdata.append(new_element("branching_history", np.array(self._branching_history, dtype=int)))
snaps_xml = ET.SubElement(root, "snapshots")
for k in range(len(self._snaps)):
snaps_xml.append(
make_xml_snapshot(
k,
self._snaps[k].time,
self._snaps[k].score,
self._snaps[k].noise,
self._snaps[k].state,
),
)
tree = ET.ElementTree(root)
ET.indent(tree, space="\t", level=0)
if traj_file is not None:
tree.write(traj_file.as_posix())
else:
tree.write(self._checkFile.as_posix())
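# Illustrative sketch: write the checkpoint to a custom location instead of the
# default checkFile set at construction ("backup" is a placeholder directory):
#   traj.store(Path("backup") / f"{traj.idstr()}.xml")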
def current_time(self) -> float:
"""Return the current trajectory time."""
return self._t_cur
def step_size(self) -> float:
"""Return the time step size."""
return self._dt
def score_max(self) -> float:
"""Return the maximum of the score function."""
return self._score_max
def is_converged(self) -> bool:
"""Return True for converged trajectory."""
return self._has_converged
def has_ended(self) -> bool:
"""Return True for terminated trajectory."""
return self._has_ended
def has_started(self) -> bool:
"""Return True if computation has started."""
return self._t_cur > 0.0
def get_checkfile(self) -> Path:
"""Return the trajectory check file name."""
return self._checkFile
def get_time_array(self) -> npt.NDArray[np.number]:
"""Return the trajectory time instants."""
times = np.zeros(len(self._snaps))
for k in range(len(self._snaps)):
times[k] = self._snaps[k].time
return times
def get_score_array(self) -> npt.NDArray[np.number]:
"""Return the trajectory scores."""
scores = np.zeros(len(self._snaps))
for k in range(len(self._snaps)):
scores[k] = self._snaps[k].score
return scores
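# Illustrative sketch: the accessors make it easy to inspect a trajectory, e.g.
# plotting the score history ("traj" is a placeholder Trajectory; matplotlib is
# an assumption here, not a pytams dependency):
#   import matplotlib.pyplot as plt
#   plt.plot(traj.get_time_array(), traj.get_score_array())
#   plt.xlabel("time")
#   plt.ylabel("score")
#   plt.show()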
def get_noise_array(self) -> npt.NDArray[Any]:
"""Return the trajectory noises."""
noises = np.zeros(len(self._snaps), dtype=type(self._snaps[0].noise))
for k in range(len(self._snaps)):
noises[k] = self._snaps[k].noise
return noises
def get_length(self) -> int:
"""Return the trajectory length."""
return len(self._snaps)
def get_nbranching(self) -> int:
"""Return the number of branching events."""
return len(self._branching_history)
def get_last_state(self) -> Any | None:
"""Return the last state in the trajectory."""
for snap in reversed(self._snaps):
if snap.has_state():
return snap.state
return None