"""The (T)AMS sampling strategy."""
import logging
from pathlib import Path
from typing import Any
import numpy as np
import numpy.typing as npt
from pyrevs.core import Config
from pyrevs.core import RuntimeConfig
from pyrevs.database import Database
from pyrevs.database import DatabaseCoreSpec
from pyrevs.runner import RunnerConfig
from pyrevs.runner import make_runner
from pyrevs.runner import ms_worker
from pyrevs.runner import pool_worker
from pyrevs.strategies.base import BaseSamplingStrategy
from pyrevs.strategies.base import LowScoreTerminationCriterion
from pyrevs.strategies.base import TerminationCriterion
from pyrevs.strategies.base import TimeTerminationCriterion
from pyrevs.utils.utils import get_min_scored
from .config import AMSConfig
from .extension import AMSDatabaseExtension
_logger = logging.getLogger(__name__)
@BaseSamplingStrategy.register("ams")
[docs]
class AMS(BaseSamplingStrategy):
"""A strategy class implementing (T)AMS.
The interface to (T)AMS, implementing the main steps of
the algorithm.
Initialization of the (T)AMS class requires a forward model
type which encapsulates all the model-specific code, and
an optional list of options.
The algorithm is roughly divided in two steps:
1. Initialization of the trajectory ensemble
2. Splitting iterations
Separate control of the parallelism is provided for
both steps.
All the algorithm data are contained in the (T)AMS database.
For control purposes, a walltime limit is also provided. It is
passed to the workers and leads to the termination of the algorithm
in a state that can be saved to disk and restarted at a later stage.
Attributes:
_fmodel_t: the forward model type
_config: the config object
_wallTime: the walltime limit
_init_ensemble_only: whether or not to stop after initializing the trajectory ensemble
"""
def __init__(
    self,
    fmodel_t: Any,
    runtime_cfg: RuntimeConfig,
    runner_cfg: RunnerConfig,
    strategy_cfg: AMSConfig,
    deterministic: bool,
) -> None:
    """Initialize an AMS strategy.

    Store the configuration objects and assemble the list of
    termination criteria matching the requested variant:

    - "ams": a LowScoreTerminationCriterion built from ``min_score``
    - "tams": a TimeTerminationCriterion built from ``end_time``
    - "hams": a TimeTerminationCriterion followed by a
      LowScoreTerminationCriterion

    Args:
        fmodel_t: the forward model type
        runtime_cfg: the runtime config dataclass (provides the log level and log file)
        runner_cfg: the runner config dataclass
        strategy_cfg: the AMS config dataclass
        deterministic: the flag triggering deterministic runs

    Raises:
        ValueError: if the variant is unknown, or if ``min_score`` is
            not set for the "ams" or "hams" variants
    """
    # NOTE(review): the base class initializer is not invoked here and
    # `_end_date`, read when scheduling tasks, is not assigned in this
    # class -- presumably both are handled by BaseSamplingStrategy; confirm.
    self._fmodel_t = fmodel_t
    self._ams_cfg = strategy_cfg
    self._ams_cfg.validate()
    self._runner_cfg = runner_cfg
    self._loglevel = runtime_cfg.loglevel
    self._logfile = runtime_cfg.logfile
    self._deterministic = deterministic
    # The database extension is created by initialize_database_schema()
    self._db_ext: AMSDatabaseExtension | None = None
    self._term_crit: list[TerminationCriterion] = []

    # Errors below are reported with error() rather than exception():
    # no exception is being handled at this point, so exception() would
    # append a meaningless "NoneType: None" traceback to the log.
    if strategy_cfg.variant == "ams":
        # Check for None. TODO: validate() above does not satisfy mypy
        if self._ams_cfg.min_score is None:
            err_msg = "AMSConfig.min_score must be set for AMS"
            _logger.error(err_msg)
            raise ValueError(err_msg)
        self._term_crit.append(LowScoreTerminationCriterion(self._ams_cfg.min_score))
    elif strategy_cfg.variant == "tams":
        self._term_crit.append(TimeTerminationCriterion(self._ams_cfg.end_time))
    elif strategy_cfg.variant == "hams":
        # Check for None. TODO: validate() above does not satisfy mypy
        if self._ams_cfg.min_score is None:
            err_msg = "AMSConfig.min_score must be set for HAMS"
            _logger.error(err_msg)
            raise ValueError(err_msg)
        self._term_crit.append(TimeTerminationCriterion(self._ams_cfg.end_time))
        self._term_crit.append(LowScoreTerminationCriterion(self._ams_cfg.min_score))
    else:
        err_msg = f"Unknown variant {self._ams_cfg.variant}"
        _logger.error(err_msg)
        raise ValueError(err_msg)
def _req_db_ext(self) -> AMSDatabaseExtension:
    """Return the AMS database extension, failing if it is not set.

    Returns:
        the AMS database extension stored on the strategy

    Raises:
        RuntimeError: if the database extension is still None
    """
    if self._db_ext is None:
        err_msg = "Database extension not initialized"
        # error() and not exception(): no exception is being handled here
        _logger.error(err_msg)
        raise RuntimeError(err_msg)
    return self._db_ext
[docs]
def generate_trajectory_ensemble(self, tdb: Database) -> None:
    """Create the initial ensemble of stochastic trajectories.

    Every trajectory listed in the database is submitted to the
    runner as a task carrying the termination criteria and the
    end date. Once the runner is done, the returned trajectories
    are sorted by id and stored back into the database.

    The runner is created from the runner config, with a maximum
    number of workers set to ``ntrajectories`` from the AMS config.

    Args:
        tdb: the AMS database

    Raises:
        RuntimeError: if the runner fails to execute the tasks
    """
    start_msg = f"Creating the initial ensemble of {tdb.n_traj()} trajectories"
    _logger.info(start_msg)

    with make_runner(
        self._runner_cfg,
        pool_worker,
        loglevel=self._loglevel,
        logfile=self._logfile,
        max_workers=self._ams_cfg.ntrajectories,
    ) as runner:
        # One task per trajectory of the database
        for traj in tdb.traj_list():
            runner.make_promise([traj, self._term_crit, self._end_date, tdb.pool_file(), tdb.path()])

        try:
            advanced = runner.execute_promises()
        except Exception as exc:
            fail_msg = f"Failed to generate the initial ensemble of {tdb.n_traj()} trajectories"
            _logger.exception(fail_msg)
            raise RuntimeError(fail_msg) from exc

    # The runner gives no ordering guarantee: restore the ordering
    # by trajectory id before handing the list back to the database
    advanced.sort(key=lambda traj: traj.id())
    tdb.update_traj_list(advanced)

    # Flag the initial ensemble as done only when every trajectory terminated
    ensemble_complete = tdb.count_terminated_traj() == tdb.n_traj()
    if ensemble_complete:
        self._req_db_ext().set_init_ensemble_flag(True)

    time_msg = f"Run time: {self.elapsed_time()} s"
    _logger.info(time_msg)
[docs]
def check_exit_splitting_loop(self, tdb: Database, k: int) -> tuple[bool, npt.NDArray[np.number]]:
    """Check for exit criterion of the splitting loop.

    The criteria are evaluated in the following order: walltime
    exceeded, all trajectories converged, splitting stalled.

    Args:
        tdb: the AMS database
        k: loop counter

    Returns:
        a tuple with a bool triggering a break of the splitting loop
        and the array of maximum scores across all trajectories

    Raises:
        RuntimeError: if the spread between the largest and the smallest
            maximum score is below STALL_TOL
    """
    # Gather max score from all trajectories
    # and check for early convergence
    n_traj = tdb.traj_list_len()
    maxes = np.zeros(n_traj, dtype=float)
    all_converged = True
    for i in range(n_traj):
        traj = tdb.get_traj(i)
        maxes[i] = traj.score_max()
        all_converged = all_converged and traj.is_converged()

    # Check for walltime
    if self.out_of_time():
        warn_msg = f"Ran out of time after {k} splitting iterations"
        _logger.warning(warn_msg)
        return True, maxes

    # Exit if our work is done
    if all_converged:
        inf_msg = f"All trajectories converged after {k} splitting iterations"
        _logger.info(inf_msg)
        return True, maxes

    # Exit if splitting is stalled
    # NOTE(review): STALL_TOL is neither defined nor imported in the
    # visible part of this module -- confirm it is in scope.
    highest = np.amax(maxes)
    if (highest - np.amin(maxes)) < STALL_TOL:
        err_msg = f"Splitting is stalling with all trajectories stuck at a score_max: {highest}"
        # error() and not exception(): no exception is being handled here
        _logger.error(err_msg)
        raise RuntimeError(err_msg)

    return False, maxes
[docs]
def finish_ongoing_splitting(self, tdb: Database) -> None:
    """Complete a splitting iteration left unfinished by a previous run.

    When a run is interrupted in the middle of a splitting iteration,
    the branched trajectories might not have terminated yet. The
    database extension then lists the trajectories still to be finished:
    they are advanced with the runner, written back to the database
    and the iteration is marked as done.

    Nothing is done if the list of ongoing trajectories is empty.

    Args:
        tdb: the AMS database

    Raises:
        RuntimeError: if the runner fails to advance the trajectories
    """
    db_ext = self._req_db_ext()

    # On restart, branching has already been performed but advancing
    # the branched trajectories to their final time might be pending.
    pending = db_ext.get_ongoing()
    if not pending:
        return

    notice = f"Unfinished splitting iteration detected, traj {pending} need(s) finishing"
    _logger.info(notice)

    with make_runner(
        self._runner_cfg,
        pool_worker,
        loglevel=self._loglevel,
        logfile=self._logfile,
        max_workers=self._ams_cfg.l_j,
    ) as runner:
        for traj_idx in pending:
            unfinished = tdb.get_traj(traj_idx)
            runner.make_promise([unfinished, self._term_crit, self._end_date, tdb.pool_file(), tdb.path()])

        try:
            completed = runner.execute_promises()
        except Exception as exc:
            fail_msg = f"Failed to finish branching {len(pending)} trajectories"
            _logger.exception(fail_msg)
            raise RuntimeError(fail_msg) from exc

        _logger.info("Done with unfinished")

        # Store the finished trajectories back at their own id
        for traj in completed:
            tdb.overwrite_traj(traj.id(), traj)

        # Close the iteration: update its status in the database
        # and increment the iteration counter
        db_ext.mark_last_splitting_iteration_as_done()
[docs]
def get_restart_at_random(self, tdb: Database, min_idx_list: list[int]) -> list[int]:
    """Get a list of trajectory indices to restart from at random.

    For each entry of min_idx_list, draw at random the index of a
    trajectory that is not in min_idx_list.

    Args:
        tdb: the AMS database
        min_idx_list: list of indices of the trajectories to be restarted,
            which are excluded from the draw

    Returns:
        list of trajectory indices to restart from, one per entry
        of min_idx_list, as native Python int

    Raises:
        RuntimeError: if every trajectory index is in min_idx_list,
            such that no trajectory is left to restart from
    """
    n_traj = tdb.traj_list_len()
    excluded = set(min_idx_list)

    # The rejection sampling below would never terminate if
    # no index outside of min_idx_list exists: fail early instead.
    if min_idx_list and all(idx in excluded for idx in range(n_traj)):
        err_msg = f"No trajectory left to restart from: all {n_traj} trajectories are selected for restart"
        _logger.error(err_msg)
        raise RuntimeError(err_msg)

    # Enable deterministic runs by setting a (different) seed
    # for each splitting iteration
    rng = (
        np.random.default_rng(seed=42 * self._req_db_ext().k_split())
        if self._deterministic
        else np.random.default_rng()
    )

    rest_idx: list[int] = []
    for _ in min_idx_list:
        # Draw until landing outside of the excluded set. Converted to
        # a native int to honor the list[int] return annotation.
        candidate = int(rng.integers(low=0, high=n_traj, dtype=int))
        while candidate in excluded:
            candidate = int(rng.integers(low=0, high=n_traj, dtype=int))
        rest_idx.append(candidate)
    return rest_idx
[docs]
def do_multilevel_splitting(self, tdb: Database, plot_diags: bool) -> None:
    """Schedule splitting of the initial ensemble of stochastic trajectories.

    Perform the multi-level splitting iterations, possibly restarting multiple
    trajectories at each iteration. All the trajectories in an iteration are
    advanced together, such that each iteration takes the maximum duration among
    the branched trajectories.

    If the walltime is exceeded, the splitting loop is stopped and ongoing
    trajectories are flagged in the database in order to finish them upon
    restart.

    The runner will use the number of workers specified in the
    input file under the runner section.

    Args:
        tdb: the AMS database
        plot_diags: whether or not to plot diagnostics

    Raises:
        RuntimeError: if the runner fails to execute the branching tasks
    """
    inf_msg = "Using multi-level splitting to get the probability"
    _logger.info(inf_msg)
    # Finish any unfinished splitting iteration
    self.finish_ongoing_splitting(tdb)
    # Initialize splitting iterations counter from the database,
    # such that a restarted run resumes where the previous one stopped
    k = self._req_db_ext().k_split()
    with make_runner(
        self._runner_cfg,
        ms_worker,
        loglevel=self._loglevel,
        logfile=self._logfile,
        max_workers=self._ams_cfg.l_j,
    ) as runner:
        while k < self._ams_cfg.nsplititer:
            inf_msg = f"Starting AMS iter. {k} with {runner.n_workers()} workers"
            _logger.info(inf_msg)
            # Plot trajectory database scores
            if plot_diags:
                pltfile = f"Score_k{k:05}.png"
                # An existing file only triggers a warning, the plot is written anyway
                if Path(pltfile).exists():
                    wrn_msg = f"Attempting to overwrite the plot file {pltfile}"
                    _logger.warning(wrn_msg)
                tdb.plot_score_functions(pltfile)
            # Get the ensemble maximums and check for early exit conditions
            # (raises RuntimeError if the splitting is stalled)
            early_exit, maxes = self.check_exit_splitting_loop(tdb, k)
            # Get the nworker lower scored trajectories
            # or more if equal score
            min_idx_list, min_vals = get_min_scored(maxes, runner.n_workers())
            # Randomly select trajectory to branch from
            ancestor_idx = self.get_restart_at_random(tdb, min_idx_list)
            n_branch = len(min_idx_list)
            # Update the database with the data of the current
            # iteration. Note that this is done before acting on
            # the early exit flag.
            self._req_db_ext().append_splitting_iteration_data(
                k, n_branch, min_idx_list, ancestor_idx, min_vals.tolist(), [np.min(maxes), np.max(maxes)]
            )
            # Query the current iteration weight
            # to compute the individual weight of each trajectory in the ensemble
            # at the end of the splitting iteration
            new_traj_weight = self._req_db_ext().weights()[-1]
            # Exit the loop if needed
            if early_exit:
                # If AMS converged, final update of the weights.
                if tdb.all_converged():
                    self._req_db_ext().update_trajectories_weights()
                break
            # Assemble a list of promises
            # and archive the discarded trajectories
            for i in range(n_branch):
                # Archive
                tdb.archive_trajectory(tdb.get_traj(min_idx_list[i]))
                # Worker task: ancestor trajectory, discarded trajectory,
                # largest of the selected minimum scores and new weight.
                # NOTE(review): presumably the score level at which the
                # ancestor is branched -- confirm against ms_worker.
                task = [
                    tdb.get_traj(ancestor_idx[i]),
                    tdb.get_traj(min_idx_list[i]),
                    np.max(min_vals),
                    new_traj_weight,
                    self._term_crit,
                    self._end_date,
                    tdb.pool_file(),
                    tdb.path(),
                ]
                runner.make_promise(task)
            try:
                restarted_trajs = runner.execute_promises()
            except Exception as exc:
                err_msg = f"Failed to branch {n_branch} trajectories at iteration {k}"
                _logger.exception(err_msg)
                raise RuntimeError(err_msg) from exc
            # Update the trajectories in the database
            for t in restarted_trajs:
                tdb.overwrite_traj(t.id(), t)
            # Update the weights of all trajectories in the ensemble with the current
            # iteration weight
            self._req_db_ext().update_trajectories_weights()
            if self.out_of_time():
                # Save splitting data with ongoing trajectories
                # but do not increment splitting index yet: the iteration
                # is not marked as done, finish_ongoing_splitting() handles
                # it upon restart
                warn_msg = f"Ran out of time after {k} splitting iterations"
                _logger.warning(warn_msg)
                break
            # Wrap up the iteration by updating its status in the
            # database and incrementing the iteration counter
            self._req_db_ext().mark_last_splitting_iteration_as_done()
            # The counter advances by the number of trajectories
            # branched during this iteration
            k = k + n_branch
[docs]
def compute_probability(self, tdb: Database, plot_diags: bool) -> float:
    """Run the AMS stages and return the rare-event probability.

    The initial ensemble is generated if the database extension does
    not report it as done, then the multi-level splitting is performed
    unless all the trajectories are already converged.

    Args:
        tdb: the AMS database
        plot_diags: whether or not to plot diagnostics

    Returns:
        the rare-event probability, 1.0 if the freshly generated
        ensemble is entirely converged, or -1.0 if the run stopped
        early (walltime exceeded or initial ensemble only)
    """
    start_msg = f"Computing {self._fmodel_t.name()} rare event probability using AMS"
    _logger.info(start_msg)

    # Initial ensemble stage, skipped if already completed
    ensemble_was_pending = not self._req_db_ext().init_ensemble_done()
    if ensemble_was_pending:
        self.generate_trajectory_ensemble(tdb)

    # Early convergence check over the whole ensemble
    every_traj_converged = all(traj.is_converged() for traj in tdb.traj_list())

    if ensemble_was_pending and every_traj_converged:
        done_msg = "All trajectories in the ensemble converged prior to splitting !"
        _logger.info(done_msg)
        return 1.0

    if self.out_of_time():
        time_msg = "Ran out of walltime ! Exiting now."
        _logger.warning(time_msg)
        return -1.0

    if self._ams_cfg.init_ensemble_only:
        stop_msg = "Stopping after the initial ensemble stage !"
        _logger.warning(stop_msg)
        return -1.0

    # Splitting stage, only needed when some trajectories are not converged
    if not every_traj_converged:
        self.do_multilevel_splitting(tdb, plot_diags)

        if self.out_of_time():
            time_msg = "Ran out of walltime ! Exiting now."
            _logger.warning(time_msg)
            return -1.0

    probability = tdb.get_event_probability()

    runtime_msg = f"Run time: {self.elapsed_time()} s"
    _logger.info(runtime_msg)

    # The workers discarded the archived trajectories without the
    # persistent Python process keeping track of them: reload their data
    tdb.load_archived_trajectories()
    tdb.info()

    return probability
def _execute_sampling(self, database: Database, plot_diags: bool) -> None:
    """Shallow wrapper to enable sampler.

    Load the database data, initialize its active ensemble, then
    compute the event probability and log it.

    Args:
        database: the AMS database
        plot_diags: whether or not to plot diagnostics

    Raises:
        RuntimeError: if the database extension is not initialized
    """
    database.load_data()

    # Initialize the active ensemble
    database.init_active_ensemble()

    # Check that the database extension is initialized
    if self._db_ext is None:
        err_msg = "Database extension is not initialized ! Call initialize_database_schema() first"
        # error() and not exception(): no exception is being handled here
        _logger.error(err_msg)
        raise RuntimeError(err_msg)

    proba = self.compute_probability(database, plot_diags)
    inf_msg = f"Event probability: {proba}"
    _logger.info(inf_msg)
[docs]
def initialize_database_schema(self, database: Database, diag_configs: dict[str, Config] | None) -> None:
    """Set up the database core state and attach the AMS extension.

    Args:
        database: the database to initialize
        diag_configs: the diagnostics configs forwarded to the core spec, or None
    """
    # Core state: ensemble size, strategy name and run flags
    core_spec = DatabaseCoreSpec(
        ntraj=self._ams_cfg.ntrajectories,
        strategy="ams",
        deterministic=self._deterministic,
        diag_configs=diag_configs,
    )
    database.initialize_core_state(core_spec)

    # AMS specific data live in a dedicated database extension,
    # stored on the strategy before being initialized and attached
    extension = AMSDatabaseExtension()
    self._db_ext = extension
    extension.initialize(self._ams_cfg.nsplititer, database)
    database.attach_extension(extension)