"""A trajectory class for individual MCMC runs."""
from __future__ import annotations
import copy
import json
import logging
import shutil
import time
import xml.etree.ElementTree as ET
from dataclasses import asdict
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from typing import Generic
from typing import TypeVar
from typing import cast
import numpy as np
import numpy.typing as npt
from pyrevs.core import Snapshot
from pyrevs.diagnostics import DiagDB
from pyrevs.diagnostics import DiagnosticPlugin
from pyrevs.diagnostics import diagnosticfactory
from pyrevs.utils.xmlutils import dict_to_xml
from pyrevs.utils.xmlutils import make_xml_snapshot
from pyrevs.utils.xmlutils import read_xml_snapshot
if TYPE_CHECKING:
from pyrevs.core import Config
from pyrevs.core import ForwardModelBaseClass
from pyrevs.strategies.base import TerminationCriterion
from .config import TrajectoryConfig
# Module-level logger used for the error/warning messages emitted by trajectories
_logger = logging.getLogger(__name__)
[docs]
# Type variable for the noise increments; parameterizes ForwardModelBaseClass and Snapshot
T_Noise = TypeVar("T_Noise")
[docs]
# Type variable for the forward model state stored in snapshots
T_State = TypeVar("T_State")
[docs]
class WallTimeLimitError(Exception):
    """Raised when Trajectory.advance() exhausts its walltime budget.

    The pyREVS workers handle it as a warning rather than a hard failure.
    """
[docs]
def get_index_from_id(identity: str) -> tuple[int, int]:
    """Split a trajectory ID string into its numeric components.

    The ID is expected to end with a 5-digit trajectory index, one separator
    character and a 4-digit branching counter (e.g. an ID ending in
    ``00012_0003``).

    Args:
        identity: the trajectory ID string

    Returns:
        a ``(trajectory index, number of branching events)`` tuple
    """
    index_digits = identity[-10:-5]
    branch_digits = identity[-4:]
    return int(index_digits), int(branch_digits)
[docs]
class Trajectory(Generic[T_Noise, T_State]):
    """A class defining a stochastic trajectory.

    Each trajectory is an MCMC simulation of the forward model.
    The trajectory class is a container for time-ordered snapshots.
    It contains an instance of the forward model, current and end times, and
    a list of the model snapshots. Note that the class uses a plain list of snapshots
    and not a more computationally efficient data structure such as a numpy array
    for convenience. It is assumed that the computational cost of running the sampling algorithm
    resides in the forward model and the overhead of the trajectory class is negligible.

    It also provides the forward model with the necessary context to advance in time,
    a method to move forward in time, methods to save/load the trajectory to/from disk,
    as well as accessors to the trajectory history (time, state, score, ...).

    The _computed_steps variable stores the number of steps actually taken by the
    trajectory. It differs from the _step variable when a trajectory is branched
    from an ancestor.

    Attributes:
        _tid : the trajectory index
        _checkFile : the trajectory checkpoint file
        _workdir : the model working directory
        _score_max : the maximum score
        _snaps : a list of snapshots
        _step : the current step counter
        _computed_steps : the number of steps explicitly advanced by the trajectory
        _t_cur : the current time
        _t_end : the end time
        _dt : the stochastic time step size
        _weight : the trajectory weight in the ensemble
        _branching_history : the indices of the ancestors this trajectory was branched from
        noise_backlog : noise increments queued to be replayed by the next steps
    """

    # Maximum number of Trajectory objects and history length
    # These are used in assembling the unique ID for each trajectory
    # NOTE(review): unique_id() and get_nbranching() read `self.max_nbranch`,
    # which is not defined in the visible part of this class — confirm it exists.
def __init__(
    self,
    traj_id: int,
    weight: float,
    fmodel_t: type[ForwardModelBaseClass[T_Noise, T_State]] | None,
    traj_cfg: TrajectoryConfig,
    diag_configs: dict[str, Config] | None = None,
    model_params: dict[str, Any] | None = None,
    workdir: Path | None = None,
    deterministic: bool = False,
    frozen: bool = False,
) -> None:
    """Initialize a trajectory.

    Args:
        traj_id: a int for the trajectory index
        weight: the trajectory weight in the ensemble
        fmodel_t: the forward model type
        traj_cfg: the trajectory configuration
        diag_configs: a dict with diagnostic configurations
        model_params: the model configuration
        workdir: an optional working directory (defaults to the current directory)
        deterministic: whether the trajectory is deterministic
        frozen: whether the trajectory is frozen (no fmodel)
    """
    # Stash away the configuration(s)
    self._traj_cfg = traj_cfg
    self._model_params = model_params
    self._deterministic = deterministic

    # The workdir is a runtime parameter, not saved in the chkfile.
    self._tid: int = traj_id
    self._workdir: Path = Path.cwd() if workdir is None else workdir
    self._score_max: float = -1.0e12
    self._has_converged: bool = False
    self._has_terminated: bool = False
    self._computed_steps: int = 0
    self._weight: float = weight

    # Sampling is expected to start at t = 0.0, but the forward model
    # itself can have a different internal starting point
    # or an entirely different time scale.
    self._step: int = 0
    self._t_cur: float = 0.0
    self._t_end: float = traj_cfg.end_time
    self._dt: float = traj_cfg.step_size

    # List of snapshots. For large models, the state may not be stored at
    # each snapshot due to memory constraints (sparse state).
    self._snaps: list[Snapshot[T_Noise, T_State]] = []

    # When using sparse state or for other reasons, the noise for the next few
    # steps might be already available. This backlog is used to store them.
    self.noise_backlog: list[T_Noise] = []

    # Keep track of the branching history for iterative
    # sampling algorithms
    self._branching_history: list[int] = []

    # Default checkfile, named after the trajectory ID string
    # (requires the branching history to be initialized).
    self._checkFile: Path = Path(f"{self.idstr()}.xml")

    # Each trajectory has its own instance of the forward model, identified
    # by the same unique ID as unique_id() so the two cannot drift apart.
    if frozen or fmodel_t is None:
        self._fmodel = None
    else:
        self._fmodel = fmodel_t(self.unique_id(), deterministic, model_params, self._workdir)

    # Diagnostics
    # They are initialized upon the first call so that
    # the diagnostic database access is not pickled in the task
    self._has_diagnostics = diag_configs is not None
    self._diag_configs: dict[str, Config] | None = diag_configs
    self._diagplugins: list[DiagnosticPlugin] = []
    self._ddb: DiagDB | None = None
    self._initialized_diags = False
[docs]
def set_checkfile(self, path: Path) -> None:
    """Point the trajectory at a new checkpoint file.

    Args:
        path: location of the XML checkfile used by store()
    """
    self._checkFile = path
[docs]
def set_workdir(self, path: Path) -> None:
    """Change the trajectory working directory.

    The new directory is also forwarded to the forward model, when there is one.

    Args:
        path: the new working directory
    """
    self._workdir = path
    fmodel = self._fmodel
    if fmodel is None:
        return
    fmodel.set_workdir(path)
[docs]
def get_workdir(self) -> Path:
    """Get the directory in which the trajectory (and its model) runs.

    Returns:
        the working directory path
    """
    workdir = self._workdir
    return workdir
[docs]
def id(self) -> int:
    """Get the index of this trajectory within the ensemble.

    Unlike unique_id(), this index does not change upon branching.

    Returns:
        the trajectory index
    """
    index = self._tid
    return index
[docs]
def unique_id(self) -> int:
    """Get an ID unique across the ensemble and its branching events.

    The trajectory index is scaled by ``max_nbranch + 1`` and offset by the
    number of branching events undergone so far.

    Returns:
        the trajectory unique id
    """
    stride = self.max_nbranch + 1
    return stride * self._tid + self.get_nbranching()
[docs]
def idstr(self) -> str:
    """Get the zero-padded string form of the trajectory ID.

    Returns:
        the ID string built from the trajectory index and branching count
    """
    nbranch = self.get_nbranching()
    return form_trajectory_id(self._tid, nbranch)
[docs]
def advance(
    self,
    termination_criteria: list[TerminationCriterion] | None = None,
    nstep_end: int = -1,
    t_end: float = -1.0,
    walltime: float = 1.0e12,
) -> None:
    """Advance the trajectory to a prescribed end time.

    This is the main time loop of the trajectory object.
    Unless specified otherwise, the trajectory will advance until
    the end time is reached or the model has converged.
    If the walltime limit is reached before the trajectory terminates,
    a WallTimeLimitError exception is raised. Note that this exception is
    treated as a warning not an error by the pyREVS workers.

    Args:
        termination_criteria: a list of termination criterion
        nstep_end: the step counter value at which to stop (ignored if <= 0)
        t_end: the end time of the advance (ignored if <= 0.0)
        walltime: a walltime limit (in seconds) to advance the model to t_end

    Returns:
        None

    Raises:
        WallTimeLimitError: if the walltime limit is reached before termination
        RuntimeError: if the trajectory is frozen (no forward model); errors
            raised by the forward model during a step are propagated
    """
    # Check if the trajectory is frozen
    if not self._fmodel:
        err_msg = f"Trajectory {self.idstr()} is frozen, without forward model. Advance() deactivated."
        _logger.error(err_msg)
        raise RuntimeError(err_msg)

    start_time = time.monotonic()
    end_time = self._calculate_end_time(t_end)
    criteria = [] if termination_criteria is None else termination_criteria

    def _should_terminate() -> bool:
        # Termination from sampling strategy termination criteria
        # or from runtime arguments (end time, nstep_end)
        return any(c.should_terminate(self._fmodel, self) for c in criteria) or self._runtime_termination(
            end_time, nstep_end
        )

    # Re-evaluate termination: a previous call to advance() with other
    # arguments might have switched the boolean.
    self._has_terminated = _should_terminate()
    timed_out = False
    while not (self._has_terminated or self._has_converged):
        # Do a single model step
        score = self._one_step()

        # Check for termination/convergence
        self._has_converged = self._fmodel.check_convergence(
            self._step, self._t_cur, score, self._traj_cfg.targetscore
        )
        self._has_terminated = _should_terminate()

        # Handle diagnostics
        self._update_diagnostics()

        # Stop on the walltime limit only if work remains: a trajectory that
        # finishes on this very step must not be reported as timed out.
        if not (self._has_terminated or self._has_converged) and time.monotonic() - start_time >= walltime:
            timed_out = True
            break

    # If the model has converged, terminate
    if self._has_converged:
        self._has_terminated = True
    if self._has_terminated:
        self._fmodel.clear()

    # Close the diagnostic database: the handle must not outlive this call
    # (database access is not meant to be pickled); diagnostics are lazily
    # re-created by _update_diagnostics on the next advance().
    if self._ddb is not None and self._initialized_diags:
        self._ddb.close()
        self._ddb = None
        self._diagplugins = []
        self._initialized_diags = False

    if timed_out:
        warn_msg = f"{self.idstr()} ran out of time in advance()"
        _logger.warning(warn_msg)
        raise WallTimeLimitError(warn_msg)
def _runtime_termination(self, end_time: float, nstep_end: int) -> bool:
    """Tell whether the runtime arguments of advance() require stopping.

    A non-positive ``nstep_end`` or ``end_time`` disables the matching check.
    """
    if nstep_end > 0 and self._step >= nstep_end:
        return True
    return end_time > 0.0 and self._t_cur >= end_time
def _calculate_end_time(self, t_end: float) -> float:
    """Pick the earliest positive end time between the configured and requested ones.

    Args:
        t_end: the end time requested for this advance

    Returns:
        the smallest positive value among ``self._t_end`` and ``t_end``,
        or -1.0 when neither is positive
    """
    earliest = -1.0
    for candidate in (self._t_end, t_end):
        if candidate > 0.0 and (earliest < 0.0 or candidate < earliest):
            earliest = candidate
    return earliest
def _update_diagnostics(self) -> None:
    """Feed the last two snapshots to every diagnostic plugin.

    Diagnostics are set up lazily on the first call, so that access to the
    diagnostic database is created after the trajectory has been pickled.
    """
    if self._has_diagnostics:
        if not self._initialized_diags:
            self._setup_diagnostics()
        snaps = self._snaps
        for diag in self._diagplugins:
            diag.update(snaps[-2], snaps[-1])
def _one_step(self) -> float:
    """Perform a single step of the forward model.

    Advances the forward model by one stochastic step of size ``self._dt``,
    updates the step counter and current time, prepares the noise for the
    following step (taken from the backlog if available) and appends the
    resulting snapshot. On the very first step the initial snapshot is
    recorded beforehand.

    Returns:
        the score of the forward model after the step

    Raises:
        RuntimeError: if the trajectory is frozen (no forward model)
        Exception: any error raised by the forward model advance is logged and re-raised
    """
    if not self._fmodel:
        err_msg = f"Trajectory {self.idstr()} is frozen, without forward model. Advance() deactivated."
        # Not inside an except clause: error() rather than exception(), which
        # would append a meaningless "NoneType: None" traceback.
        _logger.error(err_msg)
        raise RuntimeError(err_msg)

    # Add the initial snapshot to the list
    if self._step == 0:
        self.setup_noise()
        self._append_snapshot()

    # Request the model end state only if the next snapshot will store it
    # (same sparse-state criterion as _append_snapshot, evaluated at step + 1).
    need_end_state = (self._traj_cfg.sparse_start + 1 + self._step + 1) % self._traj_cfg.sparse_freq == 0
    try:
        dt = self._fmodel.advance(self._dt, need_end_state)
    except Exception:
        err_msg = f"ForwardModel advance error at step {self._step:08}"
        _logger.exception(err_msg)
        raise

    self._step += 1
    self._t_cur = self._t_cur + dt
    score = self._fmodel.score()

    # Prepare the noise for the next step, then record the snapshot that
    # starts the next time step (it carries that noise).
    self.setup_noise()
    self._append_snapshot(score)
    if self._traj_cfg.chkfile_dump_all:
        self.store()

    self._score_max = max(self._score_max, score)

    # Increment the computed step counter
    self._computed_steps += 1
    return score
def _setup_diagnostics(self) -> None:
    """Open the diagnostic database and instantiate the diagnostic plugins.

    Does nothing for a frozen trajectory (no forward model). Without
    diagnostic configurations, only flags the diagnostics as initialized.
    The database is ``./diagDB.db`` when running in the current directory,
    otherwise ``diagDB.db`` two directory levels above the working directory.
    """
    fmodel = self._fmodel
    if fmodel is None:
        return
    if self._diag_configs is None:
        self._initialized_diags = True
        return
    if self._workdir == Path.cwd():
        db_file = "./diagDB.db"
    else:
        db_file = (self._workdir.parents[1] / "./diagDB.db").absolute().as_posix()
    self._ddb = DiagDB(db_file)
    self._diagplugins = diagnosticfactory(
        self._diag_configs,
        self.unique_id(),
        self._weight,
        self._workdir,
        fmodel.diagnostic_hook,
        self._ddb,
    )
    self._initialized_diags = True
def _branch_diagnostics(
    self, ancestor_id: int, discarded_id: int, child_id: int, child_weight: float, score_threshold: float
) -> None:
    """Duplicate diagnostics entry while branching.

    Opens the diagnostic database, duplicates the ancestor diagnostic
    history for the child, and always closes the database afterwards.

    Args:
        ancestor_id: the ID of the ancestor traj to duplicate
        discarded_id: the ID of the discarded traj
        child_id: the ID of the child traj
        child_weight: the weight of the child traj
        score_threshold: the score threshold up to which duplication is needed
    """
    # The database is looked up in the current directory, or two directory
    # levels above the trajectory working directory.
    if self._workdir == Path.cwd():
        db_file = "./diagDB.db"
    else:
        db_file = (self._workdir.parents[1] / "./diagDB.db").absolute().as_posix()
    ddb = DiagDB(db_file)
    try:
        ddb.duplicate_diagnostic_history(ancestor_id, discarded_id, child_id, child_weight, score_threshold)
    finally:
        # Release the database even if the duplication fails
        ddb.close()
[docs]
def setup_noise(self) -> None:
    """Set the noise the forward model will use on its next step.

    The most recent entry of the noise backlog is consumed first; when the
    backlog is empty a fresh increment is drawn from the model. Frozen
    trajectories (no forward model) are left untouched.
    """
    fmodel = self._fmodel
    if not fmodel:
        return
    fmodel.noise = self.noise_backlog.pop() if self.noise_backlog else fmodel.make_noise()
def _append_snapshot(self, score: float | None = None) -> None:
    """Append the current snapshot to the trajectory list.

    The snapshot records the current time, the score, the noise currently
    set on the forward model and, every ``sparse_freq`` steps (offset by
    ``sparse_start``) and always at step 0, the model state.

    Args:
        score: the score at the current time; queried from the forward
            model when None
    """
    if not self._fmodel:
        return
    need_state = (
        self._traj_cfg.sparse_start + 1 + self._step
    ) % self._traj_cfg.sparse_freq == 0 or self._step == 0
    self._snaps.append(
        Snapshot[T_Noise, T_State](
            time=self._t_cur,
            # Compare against None: a legitimate score of 0.0 must be kept
            # as-is rather than triggering a new query of the model.
            score=score if score is not None else self._fmodel.score(),
            noise=self._fmodel.noise,
            state=self._fmodel.get_current_state() if need_state else None,
        ),
    )
@classmethod
def restore_from_checkfile(
    cls,
    checkfile: Path,
    metadata: dict[str, Any],
    fmodel_t: type[ForwardModelBaseClass[T_Noise, T_State]],
    traj_cfg: TrajectoryConfig,
    diag_configs: dict[str, Config] | None = None,
    model_params: dict[str, Any] | None = None,
    workdir: Path | None = None,
    frozen: bool = False,
) -> Trajectory[T_Noise, T_State]:
    """Return a trajectory restored from an XML chkfile.

    The trajectory is first built from its metadata, then its snapshots are
    read from the chkfile. Unless the trajectory is frozen (or has no
    snapshot), trailing snapshots without a stored state are dropped and
    their noise is moved to the noise backlog, so that advancing the
    trajectory replays the same noise from the last available state.

    Args:
        checkfile: the XML chkfile to read
        metadata: the trajectory metadata
        fmodel_t: the forward model type
        traj_cfg: the trajectory configuration
        diag_configs: a dict with diagnostic configurations
        model_params: the model configuration
        workdir: an optional working directory
        frozen: whether the trajectory is frozen (no fmodel)

    Returns:
        the restored trajectory

    Raises:
        FileNotFoundError: if the chkfile does not exist
    """
    if not checkfile.exists():
        err_msg = f"Trajectory {checkfile} does not exist."
        _logger.error(err_msg)
        raise FileNotFoundError(err_msg)

    rest_traj: Trajectory[T_Noise, T_State] = cls.init_from_metadata(
        metadata, fmodel_t, traj_cfg, diag_configs, model_params, workdir, frozen
    )

    # Read in trajectory data
    tree = ET.parse(checkfile.absolute())
    snapshots = tree.getroot().find("snapshots")
    if snapshots is not None:
        for snap in snapshots:
            # `snap_time` rather than `time`, to avoid shadowing the time module
            snap_time, score, noise, state = read_xml_snapshot(snap)
            rest_traj._snaps.append(Snapshot[T_Noise, T_State](time=snap_time, score=score, noise=noise, state=state))

    # If the trajectory is frozen, that is all we need. A trajectory without
    # any snapshot has nothing to rewind to (and _snaps[-1] below would raise
    # IndexError). Otherwise handle sparse state, noise backlog and the
    # necessary fmodel initialization.
    if not rest_traj._fmodel or not rest_traj._snaps:
        return rest_traj

    # Remove snapshots from the end of the list until a state is available
    for k in range(len(rest_traj._snaps) - 1, -1, -1):
        if not rest_traj._snaps[k].has_state:
            # Append the noise history to the backlog
            rest_traj.noise_backlog.append(rest_traj._snaps[k].noise)
            rest_traj._snaps.pop()
        else:
            # Because the noise in the snapshot is the noise
            # used to reach the next state, append the last to the backlog too
            rest_traj.noise_backlog.append(rest_traj._snaps[k].noise)
            break

    # Current step with python indexing, so remove 1
    rest_traj.set_current_time_and_step(rest_traj._snaps[-1].time, len(rest_traj._snaps) - 1)

    # Ensure everything is set to start the time stepping loop
    rest_traj.setup_noise()
    # mypy is not that bright, need explicit check here
    if rest_traj._snaps[-1].state is not None:
        rest_traj._fmodel.set_current_state(rest_traj._snaps[-1].state)

    # Enable the model to perform tweaks
    # after a trajectory restore
    rest_traj._fmodel.post_trajectory_restore_hook(len(rest_traj._snaps) - 1, rest_traj.current_time())

    return rest_traj
@classmethod
def branch_from_trajectory(
    cls,
    from_traj: Trajectory[T_Noise, T_State],
    rst_traj: Trajectory[T_Noise, T_State],
    score: float,
    new_weight: float,
) -> Trajectory[T_Noise, T_State]:
    """Create a new trajectory by branching from an existing one.

    The beginning of ``from_traj`` is loaded up to the first snapshot whose
    score is strictly above ``score``: snapshots up to the last stored state
    are copied, and the noise of the following steps is queued in the noise
    backlog. This effectively branches the trajectory.
    Although ``rst_traj`` is provided as an argument, it is only used to
    set metadata of the branched trajectory.

    Args:
        from_traj: an already existing trajectory to restart from
        rst_traj: the trajectory being restarted
        score: a threshold score
        new_weight: the weight of the child trajectory

    Returns:
        the child trajectory (empty when ``from_traj`` has no snapshot)
    """
    # Initialize the new trajectory object (handles path and ID logic)
    rest_traj = cls._init_branched_trajectory(from_traj, rst_traj, new_weight)

    # Nothing to copy from an empty ancestor: return the empty trajectory immediately
    if not from_traj._snaps:
        return rest_traj

    # Locate the branching points
    high_score_idx, last_snap_with_state = cls._find_branch_indices(from_traj, score)

    # Transfer the data from the ancestor to the child
    cls._transfer_data(from_traj, rest_traj, high_score_idx, last_snap_with_state)

    # Finalize branching
    cls._finalize_branch(from_traj.unique_id(), rst_traj.unique_id(), rest_traj, new_weight)

    return rest_traj
@classmethod
def _init_branched_trajectory(
    cls, from_traj: Trajectory[T_Noise, T_State], rst_traj: Trajectory[T_Noise, T_State], weight: float
) -> Trajectory[T_Noise, T_State]:
    """Initialize a new (still empty) trajectory for branching.

    The child keeps the index of ``rst_traj`` with its branching counter
    incremented by one. Its working directory and checkfile are named after
    that ID, next to those of ``rst_traj``.

    Args:
        from_traj: an already existing trajectory to restart from
        rst_traj: the trajectory being restarted
        weight: the weight of the child trajectory

    Returns:
        the new child trajectory, without snapshots
    """
    tid, nb = get_index_from_id(rst_traj.idstr())
    new_name = form_trajectory_id(tid, nb + 1)
    new_workdir = rst_traj.get_workdir().parents[0] / new_name
    fmodel_t = type(from_traj._fmodel) if from_traj._fmodel else None

    new_traj: Trajectory[T_Noise, T_State] = cls(
        traj_id=rst_traj.id(),
        weight=weight,
        fmodel_t=fmodel_t,
        traj_cfg=from_traj._traj_cfg,
        diag_configs=from_traj._diag_configs,
        model_params=from_traj._model_params,
        deterministic=from_traj._deterministic,
        workdir=new_workdir,
    )

    # Copy history if not empty
    if from_traj._snaps:
        new_traj._branching_history = copy.deepcopy(rst_traj._branching_history)
        new_traj._branching_history.append(from_traj.id())

    # Name the checkfile after the same ID as the workdir. Deriving it from
    # new_traj.idstr() before the history is copied would use a branching
    # count of 0 and collide with the un-branched trajectory's checkfile.
    new_traj.set_checkfile(rst_traj.get_checkfile().parents[0] / f"{new_name}.xml")

    return new_traj
@staticmethod
def _find_branch_indices(from_traj: Trajectory[T_Noise, T_State], score_threshold: float) -> tuple[int, int]:
    """Locate where the ancestor trajectory must be split.

    Branching happens on the first snapshot whose score is *strictly*
    above the threshold, which ensures that sampling converges.

    Args:
        from_traj: the ancestor trajectory
        score_threshold: the score threshold up to which we must copy

    Returns:
        ``(high_score_idx, last_state_idx)``: the index of the first snapshot
        above threshold (0 if there is none) and the index of the last
        snapshot with a state found while scanning up to that point.
    """
    last_state_idx = 0
    for idx, snap in enumerate(from_traj._snaps):
        if snap.has_state:
            last_state_idx = idx
        if snap.score > score_threshold:
            return idx, last_state_idx
    return 0, last_state_idx
@staticmethod
def _transfer_data(
    from_traj: Trajectory[T_Noise, T_State], rest_traj: Trajectory[T_Noise, T_State], high_idx: int, state_idx: int
) -> None:
    """Transfer data from the ancestor to the child.

    The child's snapshot list is replaced by the ancestor snapshots
    ``[0, state_idx]``. The noise of ancestor snapshots ``state_idx`` to
    ``high_idx`` (inclusive) is added to the child noise backlog, ordered so
    that ``list.pop()`` (as used by ``setup_noise``) returns it in
    chronological order.

    Args:
        from_traj: the ancestor trajectory
        rest_traj: the child trajectory
        high_idx: index of the last ancestor snapshot whose noise is replayed
            (the first snapshot above the branching score)
        state_idx: index of the last ancestor snapshot with a stored state;
            the child's snapshot list ends there
    """
    # Prepend existing backlog if states align
    # NOTE(review): after the final reverse() below, the ancestor backlog sits
    # at the end of the list and is therefore popped *before* the snapshot
    # noises. For an ancestor restored from a chkfile, the last backlog entry
    # is the noise of its last state snapshot (see restore_from_checkfile),
    # which the loop below queues again — confirm this duplication is intended.
    if state_idx == from_traj.get_last_state_id() and from_traj.noise_backlog:
        rest_traj.noise_backlog.extend(reversed(from_traj.noise_backlog))

    # Separate snapshots from noise backlog
    # Everything up to the last state is a full snapshot
    # (the Snapshot objects themselves are shared with the ancestor, not copied)
    rest_traj._snaps = from_traj._snaps[: state_idx + 1]

    # Everything between last state and high_score_idx becomes noise backlog
    for k in range(state_idx, high_idx + 1):
        rest_traj.noise_backlog.append(from_traj._snaps[k].noise)
    # Reverse so that pop() yields the oldest snapshot noise first
    rest_traj.noise_backlog.reverse()
@staticmethod
def _finalize_branch(
    ancestor_id: int, discarded_id: int, rest_traj: Trajectory[T_Noise, T_State], weight: float
) -> None:
    """Bring a freshly branched child trajectory to a runnable state.

    Sets the current time/step from the last copied snapshot and, when the
    child has a forward model, prepares the noise, restores the model state,
    updates the metadata, calls the model branching hook and duplicates the
    diagnostics history.

    Args:
        ancestor_id: the unique ID of the ancestor traj
        discarded_id: the unique ID of the discarded traj
        rest_traj: the child trajectory
        weight: the weight of the new child trajectory
    """
    last_snap = rest_traj._snaps[-1]
    rest_traj._t_cur = last_snap.time
    rest_traj._step = len(rest_traj._snaps) - 1

    # Frozen child: only the time/step bookkeeping applies
    if not rest_traj._fmodel:
        return

    rest_traj.setup_noise()
    if last_snap.state is not None:
        rest_traj._fmodel.set_current_state(last_snap.state)
    rest_traj.update_metadata()
    rest_traj._fmodel.post_trajectory_branching_hook(rest_traj._step, rest_traj._t_cur)

    if rest_traj._has_diagnostics:
        rest_traj._branch_diagnostics(
            ancestor_id, discarded_id, rest_traj.unique_id(), weight, rest_traj._score_max
        )
[docs]
def store(self, traj_file: Path | None = None, write_metadata_json: bool = False) -> None:
    """Write the trajectory data to an XML chkfile.

    By default the trajectory goes to ``self._checkFile``; pass
    ``traj_file`` to write it elsewhere.
    Metadata are not written by default, since the pyREVS database handles
    them in an SQL file; writing a JSON metadata file next to the chkfile
    can be requested when using trajectories stand-alone.

    Args:
        traj_file: an optional path file to store the trajectory to
        write_metadata_json: an optional boolean to also write the metadata json dict
    """
    target = self._checkFile if traj_file is None else traj_file

    root = ET.Element(self.idstr())
    root.append(dict_to_xml("params", asdict(self._traj_cfg)))
    snaps_xml = ET.SubElement(root, "snapshots")
    for k, snap in enumerate(self._snaps):
        snaps_xml.append(make_xml_snapshot(k, snap.time, snap.score, snap.noise, snap.state))

    tree = ET.ElementTree(root)
    ET.indent(tree, space="\t", level=0)
    tree.write(target.as_posix())

    # The JSON metadata sits next to the chkfile, with the same stem
    if write_metadata_json:
        json_path = target.parent / Path(target.stem + ".json")
        with json_path.open("w") as f:
            f.write(self.serialize_metadata_json())
[docs]
def set_weight(self, weight: float) -> None:
    """Update the statistical weight of the trajectory.

    Args:
        weight: the (new) trajectory weight in the ensemble
    """
    self._weight = weight
[docs]
def set_current_time_and_step(self, time: float, step: int) -> None:
    """Overwrite both the current trajectory time and step counter."""
    self._t_cur, self._step = time, step
[docs]
def current_time(self) -> float:
    """Get the time reached by the trajectory so far."""
    t_now = self._t_cur
    return t_now
[docs]
def current_step(self) -> int:
    """Get the value of the trajectory step counter."""
    step = self._step
    return step
[docs]
def step_size(self) -> float:
    """Get the stochastic time step size passed to the forward model."""
    dt = self._dt
    return dt
[docs]
def score_max(self) -> float:
    """Get the highest score reached along the trajectory."""
    best = self._score_max
    return best
[docs]
def is_converged(self) -> bool:
    """Tell whether the forward model reported convergence."""
    converged = self._has_converged
    return converged
[docs]
def is_terminated(self) -> bool:
    """Tell whether the trajectory has stopped advancing for good."""
    terminated = self._has_terminated
    return terminated
[docs]
def has_started(self) -> bool:
    """Tell whether the trajectory time has moved past zero."""
    started = self._t_cur > 0.0
    return started
[docs]
def get_checkfile(self) -> Path:
    """Get the path of the XML checkfile used by store()."""
    chk = self._checkFile
    return chk
[docs]
def get_time_array(self) -> npt.NDArray[np.number]:
    """Collect the snapshot times into a float64 numpy array."""
    times = [snap.time for snap in self._snaps]
    return np.asarray(times, dtype=np.float64)
[docs]
def get_score_array(self) -> npt.NDArray[np.number]:
    """Collect the snapshot scores into a float64 numpy array."""
    scores = [snap.score for snap in self._snaps]
    return np.asarray(scores, dtype=np.float64)
[docs]
def get_noise_array(self) -> npt.NDArray[Any]:
    """Stack the snapshot noises into a numpy array (empty if no snapshot)."""
    return np.array([snap.noise for snap in self._snaps]) if self._snaps else np.array([])
[docs]
def get_state_list(self) -> list[tuple[int, T_State | None]]:
    """List the stored states together with their snapshot index.

    Returns:
        ``(index, state)`` tuples for every snapshot holding a state
    """
    return [(idx, snap.state) for idx, snap in enumerate(self._snaps) if snap.has_state]
[docs]
def get_length(self) -> int:
    """Get the number of snapshots held by the trajectory."""
    nsnaps = len(self._snaps)
    return nsnaps
[docs]
def get_nbranching(self) -> int:
    """Return the number of branching events.

    Returns:
        the length of the branching history

    Raises:
        RuntimeError: if the branching history exceeds ``max_nbranch``
    """
    nbranch = len(self._branching_history)
    if nbranch > self.max_nbranch:
        err_msg = f"Branching history size exceeds maximum of {self.max_nbranch}!"
        # Not inside an except clause: error() rather than exception(), which
        # would append a meaningless "NoneType: None" traceback.
        _logger.error(err_msg)
        raise RuntimeError(err_msg)
    return nbranch
[docs]
def get_computed_steps_count(self) -> int:
    """Get how many steps this trajectory explicitly advanced the model."""
    computed = self._computed_steps
    return computed
[docs]
def get_last_state(self) -> T_State | None:
    """Get the most recent stored state, or None if no snapshot holds one."""
    return next((snap.state for snap in reversed(self._snaps) if snap.has_state), None)
[docs]
def get_last_state_id(self) -> int | None:
    """Get the index of the most recent snapshot holding a state (None if none)."""
    count = len(self._snaps)
    for offset, snap in enumerate(reversed(self._snaps)):
        if snap.has_state:
            return count - 1 - offset
    return None
def delete(self) -> None:
    """Clear the trajectory on-disk data.

    Removes the trajectory checkfile and the trajectory working directory,
    skipping whichever does not exist.
    """
    # A trajectory that was never stored has no checkfile on disk
    self._checkFile.unlink(missing_ok=True)
    if self._workdir.exists():
        shutil.rmtree(self._workdir)