# Source code for pyrevs.strategies.ams.extension

"""An extension class for the AMS strategy."""

import json
import logging
from pathlib import Path
from typing import TypeVar
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from pyrevs.core import CoreDB
from pyrevs.database import Database
from pyrevs.database import StrategyDatabaseExtension
from pyrevs.trajectory import Trajectory
from .sql import AMSDB

_logger = logging.getLogger(__name__)

# Generic type parameters used to annotate Trajectory[T_Noise, T_State]
T_Noise = TypeVar("T_Noise")
T_State = TypeVar("T_State")
class AMSDatabaseExtension(StrategyDatabaseExtension):
    """An extension class for the AMS strategy.

    Holds the AMS-specific state (splitting iterations data stored in an
    AMSDB instance, plus a few metadata values serialized to JSON) on top of
    the core trajectory database.

    Attributes:
        _nsplitting: maximum number of splitting iterations
        _ams_db: an instance of AMSDB, extending the SQL database
        _init_ensemble_done: flag indicating that the initial ensemble is finished
        _tdb: the core trajectory database, set by ``initialize`` or
            ``initialize_from_database``
    """

    def __init__(self) -> None:
        self._nsplitting: int = -1
        self._ams_db: AMSDB | None = None
        self._init_ensemble_done: bool = False

    def initialize(self, nsplitting: int, tdb: Database) -> None:
        """Initialize the AMS database extension.

        Args:
            nsplitting: maximum number of splitting iterations
            tdb: the core trajectory database
        """
        self._nsplitting = nsplitting
        self._tdb = tdb
        self._ams_db = AMSDB(self._req_tdb_db().engine())

    def initialize_from_database(self, tdb: Database) -> None:
        """Initialize the AMS database extension from an existing database.

        The AMS metadata are read back from disk using ``deserialize``.

        Args:
            tdb: the core trajectory database
        """
        self._tdb = tdb
        self._ams_db = AMSDB(self._req_tdb_db().engine())
        self.deserialize()

    def _metadata_path(self) -> Path:
        """Return the path of the JSON file holding the AMS metadata."""
        return Path(self._tdb.name()) / "ams_metadata.json"

    def serialize(self) -> None:
        """Serialize the extension metadata to a JSON file."""
        data = {
            "nsplitting": self._nsplitting,
            "init_ensemble_done": self._init_ensemble_done,
        }
        with self._metadata_path().open("w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)

    def deserialize(self) -> None:
        """Deserialize the extension metadata from the JSON file."""
        with self._metadata_path().open("r", encoding="utf-8") as f:
            data = json.load(f)
        self._nsplitting = data["nsplitting"]
        self._init_ensemble_done = data["init_ensemble_done"]

    def _req_ams_db(self) -> AMSDB:
        """Return the AMSDB instance.

        Raises:
            RuntimeError: if the AMSDB has not been created yet
        """
        if self._ams_db is None:
            err_msg = "AMSDB has not been initialized ! Call initialize()"
            # Not inside an except block: logger.exception would attach a
            # meaningless "NoneType: None" traceback.
            _logger.error(err_msg)
            raise RuntimeError(err_msg)
        return self._ams_db

    def _req_tdb_db(self) -> CoreDB:
        """Return the core SQL database of the trajectory database.

        Raises:
            RuntimeError: if the trajectory database has no pool database
        """
        pool_db = self._tdb.get_pool_db()
        if pool_db is None:
            err_msg = "CoreDB has not been initialized ! Database is not ready"
            _logger.error(err_msg)
            raise RuntimeError(err_msg)
        return pool_db

    def k_split(self) -> int:
        """Get the current splitting iteration index.

        The current splitting index is ksplit + bias of the last entry of the
        splitting iterations SQL table, where bias is the number of branching
        events in that last iteration.

        Returns:
            Internal splitting iteration index
        """
        return self._req_ams_db().get_k_split()

    def init_ensemble_done(self) -> bool:
        """Get the initial ensemble status flag.

        Returns:
            the flag indicating that the initial ensemble is finished
        """
        return self._init_ensemble_done

    def set_init_ensemble_flag(self, status: bool) -> None:
        """Change the initial ensemble status flag.

        The metadata file is rewritten when the database is stored on disk.

        Args:
            status: the new status
        """
        self._init_ensemble_done = status
        if self._tdb.to_disk():
            self.serialize()

    def get_ongoing(self) -> list[int] | None:
        """Get the list of ongoing trajectories if any.

        Returns:
            Either a list of integers identifying the ongoing trajectories,
            or None if nothing was left to do
        """
        return self._req_ams_db().get_ongoing()

    def weights(self) -> npt.NDArray[np.number]:
        """Return the splitting iterations weights."""
        return self._req_ams_db().get_weights()

    def mark_last_splitting_iteration_as_done(self) -> None:
        """Flag the last splitting iteration as done."""
        self._req_ams_db().mark_last_iteration_as_completed()

    def _next_iteration_weight(self, bias: int) -> float:
        """Return the ensemble weight after restarting ``bias`` trajectories.

        Args:
            bias: the number of restarted trajectories in the iteration

        Returns:
            the last stored weight times (1 - bias / n_traj)
        """
        # Prepend 1.0 so that a weight of 1.0 is used as the starting point
        # when no splitting iteration weight is stored yet.
        weights = np.insert(self._req_ams_db().get_weights(), 0, 1.0)
        return weights[-1] * (1.0 - float(bias) / float(self._tdb.n_traj()))

    def append_splitting_iteration_data(
        self,
        ksplit: int,
        bias: int,
        discarded_ids: list[int],
        ancestor_ids: list[int],
        min_vals: list[float],
        min_max: list[float],
    ) -> None:
        """Append a set of splitting data to internal list.

        Args:
            ksplit : The splitting iteration index
            bias : The number of restarted trajectories, also ref. to as bias
            discarded_ids : The list of discarded trajectory ids
            ancestor_ids : The list of trajectories used to restart (ancestors)
            min_vals : The list of minimum values
            min_max : The score minimum and maximum values

        Raises:
            ValueError if the provided ksplit is incompatible with the db state
        """
        ams_db = self._req_ams_db()

        # If the incoming splitting index differs from the one expected by
        # the database, the splitting history is inconsistent.
        k_db = self.k_split()
        if ksplit != k_db:
            ams_db.dump_file_json()
            err_msg = (
                f"Attempting to add splitting iteration with splitting index {ksplit} "
                f"incompatible with the last entry of the database {k_db} !"
            )
            _logger.error(err_msg)
            raise ValueError(err_msg)

        # Check that the new min of maxes is larger than at the previous step
        ams_db.check_new_min_of_maxes(min_max[0])

        new_weight = self._next_iteration_weight(bias)
        ams_db.add_splitting_data(
            ksplit, bias, new_weight, discarded_ids, ancestor_ids, min_vals, min_max
        )

    def update_splitting_iteration_data(
        self,
        ksplit: int,
        bias: int,
        discarded_ids: list[int],
        ancestor_ids: list[int],
        min_vals: list[float],
        min_max: list[float],
    ) -> None:
        """Update the last set of splitting data to internal list.

        Args:
            ksplit : The splitting iteration index
            bias : The number of restarted trajectories, also ref. to as bias
            discarded_ids : The list of discarded trajectory ids
            ancestor_ids : The list of trajectories used to restart (ancestors)
            min_vals : The list of minimum values
            min_max : The score minimum and maximum values

        Raises:
            ValueError if the provided ksplit is incompatible with the db state
        """
        ams_db = self._req_ams_db()

        # The last entry already accounts for its bias, so the expected
        # database index is ksplit + bias.
        k_expected = ksplit + bias
        k_db = self.k_split()
        if k_expected != k_db:
            ams_db.dump_file_json()
            err_msg = (
                f"Attempting to update splitting iteration with splitting index {k_expected} "
                f"incompatible with the last entry of the database {k_db} !"
            )
            _logger.error(err_msg)
            raise ValueError(err_msg)

        new_weight = self._next_iteration_weight(bias)
        ams_db.update_splitting_data(
            ksplit, bias, new_weight, discarded_ids, ancestor_ids, min_vals, min_max
        )

    def update_trajectories_weights(self) -> None:
        """Update the weights of all the trajectories.

        Using the current splitting iteration weight (the last stored one).
        The weight is also pushed to the diagnostics of the core database.
        """
        tweight = float(self.weights()[-1])
        to_disk = self._tdb.to_disk()
        for traj in self._tdb.traj_list():
            traj.set_weight(tweight)
            if to_disk:
                self._req_tdb_db().update_trajectory_weight(traj.id(), tweight)
        # Update the diagDB in the core database
        self._tdb.update_diagnostic_weights(tweight)

    def get_event_probability(self) -> float:
        """Return the event probability.

        Returns 0.0 while some trajectories are not terminated yet.
        """
        if self._tdb.count_terminated_traj() < self._tdb.n_traj():
            return 0.0

        ams_db = self._req_ams_db()
        # Prepend 1.0 to the stored weights: weights[i] is then paired with
        # biases[i] in the normalization below.
        weights = np.insert(ams_db.get_weights(), 0, 1.0)
        biases = ams_db.get_biases()

        # Normalization: n_traj * last weight + sum_i biases[i] * weights[i]
        w = self._tdb.n_traj() * weights[-1]
        for i, bias in enumerate(biases):
            w += bias * weights[i]
        return float(self._tdb.count_converged_traj() * weights[-1] / w)

    def plot_min_max_span(self, fname: str | None = None) -> None:
        """Plot the evolution of the ensemble min/max during iterations.

        Args:
            fname: output image file; defaults to "<db stem>_minmax.png"
        """
        pltfile = fname if fname else Path(self._tdb.name()).stem + "_minmax.png"
        plt.figure(figsize=(6, 4))
        # Columns: 0 -> iteration, 1 -> min of maxes, 2 -> max of maxes
        min_max_data = self._req_ams_db().get_minmax()
        plt.plot(min_max_data[:, 0], min_max_data[:, 1], linewidth=1.0, label="min of maxes")
        plt.plot(min_max_data[:, 0], min_max_data[:, 2], linewidth=1.0, label="max of maxes")
        plt.grid(linestyle="dotted")
        ax = plt.gca()
        ax.set_ylim(0.0, 1.0)
        ax.set_xlim(0.0, np.max(min_max_data[:, 0]))
        ax.legend()
        plt.tight_layout()
        plt.savefig(pltfile, dpi=300)
        plt.clf()
        plt.close()

    def _get_location_and_indices_at_k(self, k_in: int) -> list[tuple[str, int]]:
        """Return the location and indices of active trajectory at k_in.

        Location here can be either 'active' or "archive" depending on whether
        the trajectory we are interested in is still in the current active
        list or in the archived list.

        Args:
            k_in : the index of the splitting iteration

        Returns:
            A list of tuple with the location and index of each trajectory
            active at iteration k
        """
        ams_db = self._req_ams_db()
        archived = self._tdb.archived_traj_list()

        # Start from the current active list; entries are (location, index)
        # pairs, the actual trajectories are fetched later by the caller.
        active_list_index = [("active", i) for i in range(self._tdb.n_traj())]

        # Walk the splitting iterations backward, from the last one down to
        # k_in, mapping each discarded trajectory to its archived copy.
        idx_in_archive = self._tdb.archived_traj_list_len() - 1
        for k in range(ams_db.get_iteration_count() - 1, k_in - 1, -1):
            splitting_data = ams_db.fetch_splitting_data(k)
            if not splitting_data:
                continue
            _, nbranch, _, discarded, _, _, _, status = splitting_data
            if status == "locked":
                continue
            # The archives of iteration k are the last nbranch entries not
            # yet consumed by later iterations.
            for discarded_idx in discarded:
                for i in range(idx_in_archive, idx_in_archive - nbranch, -1):
                    if archived[i].id() == discarded_idx:
                        active_list_index[discarded_idx] = ("archive", i)
            idx_in_archive -= nbranch

        return active_list_index

    def get_trajectory_active_at_k(self, k_in: int) -> list[Trajectory[T_Noise, T_State]]:
        """Return the list of trajectory active at a given splitting iteration.

        To explore the ensemble evolution during splitting iterations, it is
        useful to reconstruct the list of active trajectories at the beginning
        of any given splitting iteration. Note that k here is not the
        splitting index, but the iteration index. Since more than one child
        can be spawned at each splitting iteration, the two might differ.

        Args:
            k_in : the index of the splitting iteration

        Returns:
            The list of trajectories active at the beginning of iteration k

        Raises:
            ValueError: if k_in is beyond the stored splitting iterations
            RuntimeError: if no archived trajectories are stored
        """
        n_iter = self._req_ams_db().get_iteration_count()
        if k_in >= n_iter:
            err_msg = (
                f"Attempting to read splitting iteration {k_in} data "
                f"larger than stored data {n_iter}"
            )
            _logger.error(err_msg)
            raise ValueError(err_msg)

        # Reconstruction relies on the archived trajectories
        if self._tdb.archived_traj_list_len() == 0:
            err_msg = "Cannot reconstruct active set without stored archives !"
            _logger.error(err_msg)
            raise RuntimeError(err_msg)

        # First get the location and indices of the trajectories active at k
        active_list_index = self._get_location_and_indices_at_k(k_in)

        # Retrieve the trajectories from the active/archived lists
        active_list_at_k = []
        for location, idx in active_list_index:
            if location == "active":
                active_list_at_k.append(self._tdb.traj_list()[idx])
            elif location == "archive":
                active_list_at_k.append(self._tdb.archived_traj_list()[idx])
        return active_list_at_k