Source code for pytams.database

"""A database class for TAMS."""

from __future__ import annotations
import copy
import datetime
import logging
import shutil
import sys
import xml.etree.ElementTree as ET
from importlib.metadata import version
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
import cloudpickle
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import toml
from pytams.sqldb import SQLFile
from pytams.trajectory import Trajectory
from pytams.trajectory import form_trajectory_id
from pytams.xmlutils import new_element
from pytams.xmlutils import xml_to_dict

if TYPE_CHECKING:
    from pytams.fmodel import ForwardModelBaseClass

_logger = logging.getLogger(__name__)


class DatabaseError(Exception):
    """Exception class for TAMS Database."""


class Database:
    """A database class for TAMS.

    The database class for TAMS is a container for all the trajectory
    and splitting data. When the user provides a path to store the
    database, a local folder is created holding a number of readable
    files, any output from the model, and an SQL file used to
    lock/release trajectories as the TAMS algorithm proceeds. The
    readable files are currently in an XML format.

    A database can be loaded independently from the TAMS algorithm
    and used for post-processing.

    Attributes:
        _fmodel_t: the forward model type
        _save_to_disk: boolean to trigger saving the database to disk
        _path: a path to an existing database to restore or a new path
        _restart: a bool to override an existing database
        _parameters: the dictionary of parameters
        _trajs_db: the list of trajectories
        _ksplit: the current splitting iteration
        _l_bias: the list of biases
        _weights: the list of weights
        _ongoing: the list of ongoing branches if a splitting iteration is unfinished
    """

    def __init__(
        self,
        fmodel_t: type[ForwardModelBaseClass],
        params: dict[Any, Any],
        ntraj: int | None = None,
        nsplititer: int | None = None,
    ) -> None:
        """Initialize a TAMS database.

        Initialize the TAMS database object, bare in-memory or on-disk.
        An on-disk database is triggered if a path is provided in the
        parameters dictionary. The user can choose not to append to or
        override an existing database, in which case the existing path
        is copied to a new random name.

        Args:
            fmodel_t: the forward model type
            params: a dictionary of parameters
            ntraj: [OPT] number of trajectories to hold
            nsplititer: [OPT] number of splitting iterations to hold
        """
        self._fmodel_t = fmodel_t

        # Metadata
        self._save_to_disk = False
        self._parameters = params
        self._name = "TAMS_" + fmodel_t.name()
        self._path: str | None = params.get("database", {}).get("path", None)
        if self._path:
            self._save_to_disk = True
            self._restart = params.get("database", {}).get("restart", False)
            self._format = params.get("database", {}).get("format", "XML")
            if self._format not in ["XML"]:
                err_msg = f"Unsupported TAMS database format: {self._format} !"
                _logger.error(err_msg)
                raise DatabaseError(err_msg)
            self._name = f"{self._path}"
        self._abs_path: Path = Path.cwd() / self._name
        self._creation_date = datetime.datetime.now(tz=datetime.timezone.utc)
        self._version = version(__package__)
        self._store_archive = params.get("database", {}).get("archive_discarded", False)

        # Trajectory pools
        self._trajs_db: list[Trajectory] = []
        self._pool_db: SQLFile | None = None
        self._archived_trajs_db: list[Trajectory] = []

        # Splitting data
        self._ksplit = 0
        self._l_bias: list[int] = []
        self._weights: list[float] = [1.0]
        self._minmax: list[npt.NDArray[np.number]] = []
        self._ongoing = None

        # Initialize only metadata at this point
        # so that the object remains lightweight
        self._init_metadata(ntraj, nsplititer)

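    # Construction sketch (hedged: the "database" keys are the ones read in
    # __init__ above; MyModel stands for any user-provided
    # ForwardModelBaseClass subclass and "my_run.tdb" is an illustrative
    # folder name, not a pytams default):
    #
    #   params = {"database": {"path": "my_run.tdb", "archive_discarded": True}}
    #   db = Database(MyModel, params, ntraj=50, nsplititer=200)
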
    def n_traj(self) -> int:
        """Return the number of trajectories used for TAMS.

        Note that this is the requested number of trajectories,
        not the current length of the trajectory pool.

        Return:
            number of trajectories
        """
        return self._ntraj

    def n_split_iter(self) -> int:
        """Return the number of splitting iterations used for TAMS.

        Note that this is the requested number of splitting iterations,
        not the current splitting iteration.

        Return:
            number of splitting iterations
        """
        return self._nsplititer

    def path(self) -> str | None:
        """Return the path to the database."""
        return self._path

    @classmethod
    def load(cls, a_path: Path) -> Database:
        """Instantiate a TAMS database from disk.

        Args:
            a_path: the path to the database

        Return:
            a TAMS database object
        """
        if not a_path.exists():
            err_msg = f"Database {a_path} does not exist !"
            _logger.error(err_msg)
            raise DatabaseError(err_msg)

        # Load the elements necessary to call the constructor,
        # ensuring that the database is not restarted at this step
        db_params = toml.load(a_path / "input_params.toml")
        db_params["database"].update({"restart": False})

        # If a_path differs from the one stored in the
        # database (the DB has been moved), update the path
        if str(a_path) != db_params["database"]["path"]:
            warn_msg = f"Database {db_params['database']['path']} has been moved to {a_path} !"
            _logger.warning(warn_msg)
            db_params["database"]["path"] = str(a_path)

        model_file = Path(a_path / "fmodel.pkl")
        with model_file.open("rb") as f:
            model = cloudpickle.load(f)

        return cls(model, db_params)

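    # Post-processing sketch (hedged: "my_run.tdb" is an illustrative
    # database folder). `load` rebuilds the Database from disk, including
    # the cloudpickled model, without restarting it:
    #
    #   db = Database.load(Path("my_run.tdb"))
    #   db.load_data()
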
    def _init_metadata(self, ntraj: int | None = None, nsplititer: int | None = None) -> None:
        """Initialize the database.

        Initialize database internal metadata (only) and set up the
        database on disk if needed.

        Args:
            ntraj: [OPT] number of trajectories to hold
            nsplititer: [OPT] number of splitting iterations to hold
        """
        # Initialize or load disk-based database metadata
        if self._save_to_disk:
            # Check for an existing database
            db_exists = self._abs_path.exists()

            # Regardless of a pre-existing path, initialize from scratch
            if not db_exists or self._restart:
                if not ntraj:
                    err_msg = "Initializing the TAMS database from scratch requires ntraj !"
                    _logger.error(err_msg)
                    raise DatabaseError(err_msg)
                if not nsplititer:
                    err_msg = "Initializing the TAMS database from scratch requires nsplititer !"
                    _logger.error(err_msg)
                    raise DatabaseError(err_msg)
                self._ntraj = ntraj
                self._nsplititer = nsplititer
                self._setup_tree()
            # Load the database
            else:
                self._load_metadata()
                # Parameters stored in the DB override
                # newly provided parameters.
                with Path(self._abs_path / "input_params.toml").open("r") as f:
                    read_in_params = toml.load(f)
                self._parameters.update(read_in_params)
        # Initialize in-memory database metadata
        else:
            if not ntraj:
                err_msg = "Initializing the TAMS database from scratch requires ntraj !"
                _logger.error(err_msg)
                raise DatabaseError(err_msg)
            if not nsplititer:
                err_msg = "Initializing the TAMS database from scratch requires nsplititer !"
                _logger.error(err_msg)
                raise DatabaseError(err_msg)
            self._ntraj = ntraj
            self._nsplititer = nsplititer

    def _setup_tree(self) -> None:
        """Initialize the trajectory database tree."""
        if self._save_to_disk:
            if self._abs_path.exists():
                rng = np.random.default_rng(12345)
                copy_exists = True
                while copy_exists:
                    random_int = rng.integers(0, 999999)
                    path_rnd = Path.cwd() / f"{self._name}_{random_int:06d}"
                    copy_exists = path_rnd.exists()
                warn_msg = f"Database {self._name} already present. It will be copied to {path_rnd.name}"
                _logger.warning(warn_msg)
                shutil.move(self._name, path_rnd.absolute())

            Path(self._name).mkdir()

            # Save the runtime options
            with Path(self._abs_path / "input_params.toml").open("w") as f:
                toml.dump(self._parameters, f)

            # Header file with metadata and pool DB
            self._write_metadata()

            # Serialize the model
            model_file = Path(self._abs_path / "fmodel.pkl")
            cloudpickle.register_pickle_by_value(sys.modules[self._fmodel_t.__module__])
            with model_file.open("wb") as f:
                cloudpickle.dump(self._fmodel_t, f)

            # Empty trajectories subfolder
            Path(self._abs_path / "trajectories").mkdir(parents=True)

    def _write_metadata(self) -> None:
        """Write the database metadata to disk."""
        if self._format == "XML":
            header_file = self.header_file()
            root = ET.Element("header")
            mdata = ET.SubElement(root, "metadata")
            mdata.append(new_element("pyTAMS_version", version(__package__)))
            mdata.append(new_element("date", self._creation_date))
            mdata.append(new_element("model_t", self._fmodel_t.name()))
            mdata.append(new_element("ntraj", self._ntraj))
            mdata.append(new_element("nsplititer", self._nsplititer))
            tree = ET.ElementTree(root)
            ET.indent(tree, space="\t", level=0)
            tree.write(header_file)

            # Initialize the splitting data file
            self.save_splitting_data()

            # Initialize the SQL pool file
            self._pool_db = SQLFile(self.pool_file())
        else:
            err_msg = f"Unsupported TAMS database format: {self._format} !"
            _logger.error(err_msg)
            raise DatabaseError(err_msg)

    def _load_metadata(self) -> None:
        """Read the database metadata from the header."""
        if self._save_to_disk:
            if self._format == "XML":
                tree = ET.parse(self.header_file())
                root = tree.getroot()
                mdata = root.find("metadata")
                datafromxml = xml_to_dict(mdata)
                self._ntraj = datafromxml["ntraj"]
                self._nsplititer = datafromxml["nsplititer"]
                self._version = datafromxml["pyTAMS_version"]
                if self._version != version(__package__):
                    warn_msg = f"Database pyTAMS version {self._version} is different from {version(__package__)}"
                    _logger.warning(warn_msg)
                self._creation_date = datafromxml["date"]
                db_model = datafromxml["model_t"]
                if db_model != self._fmodel_t.name():
                    err_msg = f"Database model {db_model} is different from call {self._fmodel_t.name()}"
                    _logger.error(err_msg)
                    raise DatabaseError(err_msg)

                # Initialize the SQL pool file
                self._pool_db = SQLFile(self.pool_file())
            else:
                err_msg = f"Unsupported TAMS database format: {self._format} !"
                _logger.error(err_msg)
                raise DatabaseError(err_msg)

    def init_pool(self) -> None:
        """Initialize the requested number of trajectories."""
        for n in range(self._ntraj):
            workdir = (
                Path(self._abs_path / f"trajectories/{form_trajectory_id(n)}") if self._save_to_disk else None
            )
            t = Trajectory(
                traj_id=n,
                fmodel_t=self._fmodel_t,
                parameters=self._parameters,
                workdir=workdir,
            )
            self.append_traj(t, True)

    def save_trajectory(self, traj: Trajectory) -> None:
        """Save a trajectory to disk in the database.

        Args:
            traj: the trajectory to save
        """
        if not self._save_to_disk:
            return

        traj.store()

    def save_splitting_data(self, ongoing_trajs: list[int] | None = None) -> None:
        """Write splitting data to the database.

        Args:
            ongoing_trajs: an optional list of ongoing trajectories
        """
        if not self._save_to_disk:
            return

        # Splitting data file
        if self._format == "XML":
            splitting_data_file = f"{self._name}/splittingData.xml"
            root = ET.Element("splitting")
            root.append(new_element("nsplititer", self._nsplititer))
            root.append(new_element("ksplit", self._ksplit))
            root.append(new_element("bias", np.array(self._l_bias, dtype=int)))
            root.append(new_element("weight", np.array(self._weights, dtype=float)))
            root.append(new_element("minmax", np.array(self._minmax, dtype=float)))
            if ongoing_trajs:
                root.append(new_element("ongoing", np.array(ongoing_trajs)))
            tree = ET.ElementTree(root)
            ET.indent(tree, space="\t", level=0)
            tree.write(splitting_data_file)
        else:
            err_msg = f"Unsupported TAMS database format: {self._format} !"
            _logger.error(err_msg)
            raise DatabaseError(err_msg)

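    # The resulting splittingData.xml has roughly the following shape
    # (a sketch: element names are taken from the code above, values are
    # illustrative, and the exact rendering depends on new_element):
    #
    #   <splitting>
    #       <nsplititer>200</nsplititer>
    #       <ksplit>42</ksplit>
    #       <bias>...</bias>
    #       <weight>...</weight>
    #       <minmax>...</minmax>
    #       <ongoing>...</ongoing>   (only during an unfinished iteration)
    #   </splitting>
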
    def _read_splitting_data(self) -> None:
        """Read splitting data."""
        # Read the data file
        if self._format == "XML":
            splitting_data_file = f"{self._name}/splittingData.xml"
            tree = ET.parse(splitting_data_file)
            root = tree.getroot()
            datafromxml = xml_to_dict(root)
            self._nsplititer = datafromxml["nsplititer"]
            self._ksplit = datafromxml["ksplit"]
            self._l_bias = datafromxml["bias"].tolist()
            self._weights = datafromxml["weight"].tolist()
            self._minmax = list(np.reshape(datafromxml["minmax"], [3, -1], order="F").T)
            if "ongoing" in datafromxml:
                self._ongoing = datafromxml["ongoing"].tolist()
        else:
            err_msg = f"Unsupported TAMS database format: {self._format} !"
            _logger.error(err_msg)
            raise DatabaseError(err_msg)

    def load_data(self, load_archived_trajectories: bool = False) -> None:
        """Load the data stored in the database.

        The initialization of the database only populates the metadata,
        not the full trajectory data.

        Args:
            load_archived_trajectories: whether to load archived trajectories
        """
        if not self._save_to_disk:
            return

        if not self._pool_db:
            err_msg = "Database is not initialized !"
            _logger.exception(err_msg)
            raise DatabaseError(err_msg)

        # Counter for the number of trajectories loaded
        n_traj_restored = 0
        ntraj_in_db = self._pool_db.get_trajectory_count()
        for n in range(ntraj_in_db):
            traj_checkfile = Path(self._abs_path) / self._pool_db.fetch_trajectory(n)
            workdir = Path(self._abs_path / f"trajectories/{traj_checkfile.stem}")
            if traj_checkfile.exists():
                n_traj_restored += 1
                self.append_traj(
                    Trajectory.restore_from_checkfile(
                        traj_checkfile,
                        fmodel_t=self._fmodel_t,
                        parameters=self._parameters,
                        workdir=workdir,
                    ),
                    False,
                )
            else:
                t = Trajectory(
                    traj_id=n,
                    fmodel_t=self._fmodel_t,
                    parameters=self._parameters,
                    workdir=workdir,
                )
                self.append_traj(t, False)

        inf_msg = f"{n_traj_restored} trajectories loaded"
        _logger.info(inf_msg)

        # Load splitting data
        self._read_splitting_data()

        # Load the archived trajectories if requested.
        # Those are loaded as 'frozen', i.e. the internal model
        # is not available and the advance function is disabled.
        if load_archived_trajectories:
            archived_ntraj_in_db = self._pool_db.get_archived_trajectory_count()
            for n in range(archived_ntraj_in_db):
                traj_checkfile = Path(self._abs_path) / self._pool_db.fetch_archived_trajectory(n)
                if traj_checkfile.exists():
                    self._archived_trajs_db.append(
                        Trajectory.restore_from_checkfile(
                            traj_checkfile,
                            fmodel_t=self._fmodel_t,
                            parameters=self._parameters,
                            workdir=None,
                            frozen=True,
                        ),
                    )

        self.info()

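    # Once load_data() has run, the full pool is available in memory
    # (sketch; get_score_array is the same accessor used by
    # plot_score_functions below):
    #
    #   db.load_data(load_archived_trajectories=True)
    #   maxima = [max(t.get_score_array()) for t in db.traj_list()]
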
    def name(self) -> str:
        """Accessor to the DB name.

        Return:
            DB name
        """
        return self._name

    def append_traj(self, a_traj: Trajectory, update_db: bool) -> None:
        """Append a Trajectory to the internal list.

        Also add it to the SQL pool file and set its checkfile.

        Args:
            a_traj: the trajectory
            update_db: True to update the SQL DB content
        """
        if self._save_to_disk and self._pool_db:
            checkfile_str = f"./trajectories/{a_traj.idstr()}.xml"
            checkfile = Path(self._abs_path) / checkfile_str
            a_traj.set_checkfile(checkfile)
            if update_db:
                self._pool_db.add_trajectory(checkfile_str)
        self._trajs_db.append(a_traj)

    def traj_list(self) -> list[Trajectory]:
        """Access to the trajectory list.

        Return:
            Trajectory list
        """
        return self._trajs_db

    def get_traj(self, idx: int) -> Trajectory:
        """Access to a given trajectory.

        Args:
            idx: the index

        Return:
            Trajectory

        Raises:
            ValueError if idx is out of range
        """
        if idx < 0 or idx >= len(self._trajs_db):
            err_msg = f"Trying to access a non-existing trajectory {idx} !"
            _logger.error(err_msg)
            raise ValueError(err_msg)

        return self._trajs_db[idx]

    def overwrite_traj(self, idx: int, traj: Trajectory) -> None:
        """Deep copy a trajectory into the internal list.

        Args:
            idx: the index of the trajectory to overwrite
            traj: the new trajectory

        Raises:
            ValueError if idx is out of range
        """
        if idx < 0 or idx >= len(self._trajs_db):
            err_msg = f"Trying to overwrite a non-existing trajectory {idx} !"
            _logger.error(err_msg)
            raise ValueError(err_msg)

        self._trajs_db[idx] = copy.deepcopy(traj)

    def header_file(self) -> str:
        """Helper returning the DB header file.

        Return:
            Header file
        """
        return f"{self._name}/header.xml"

    def pool_file(self) -> str:
        """Helper returning the DB trajectory pool file.

        Return:
            Pool file
        """
        return f"{self._name}/trajPool.db"

    def is_empty(self) -> bool:
        """Check if the list of trajectories is empty.

        Return:
            True if the list of trajectories is empty
        """
        return self.traj_list_len() == 0

    def traj_list_len(self) -> int:
        """Length of the trajectory list.

        Return:
            Trajectory list length
        """
        return len(self._trajs_db)

    def archived_traj_list_len(self) -> int:
        """Length of the archived trajectory list.

        Return:
            Archived trajectory list length
        """
        if not self._store_archive:
            return 0

        return len(self._archived_trajs_db)

    def update_traj_list(self, a_traj_list: list[Trajectory]) -> None:
        """Overwrite the trajectory list.

        Args:
            a_traj_list: the new trajectory list
        """
        self._trajs_db = a_traj_list

    def archive_trajectory(self, traj: Trajectory) -> None:
        """Archive a trajectory about to be discarded.

        Args:
            traj: the trajectory to archive
        """
        if not self._store_archive:
            return

        # A branched trajectory will be overwritten by the
        # newly generated one in-place in the _trajs_db list.
        self._archived_trajs_db.append(traj)

        # Update the list of archived trajectories in the SQL DB
        if self._save_to_disk and self._pool_db:
            checkfile_str = traj.get_checkfile().relative_to(self._abs_path).as_posix()
            self._pool_db.archive_trajectory(checkfile_str)

    def lock_trajectory(self, tid: int, allow_completed_lock: bool = False) -> bool:
        """Lock a trajectory in the SQL DB.

        Args:
            tid: the trajectory id
            allow_completed_lock: True if the trajectory can be locked even if it is completed

        Return:
            True if there is no disk DB, or if the trajectory was successfully locked

        Raises:
            SQLAlchemyError if the DB could not be accessed
        """
        if not self._save_to_disk or not self._pool_db:
            return True

        return self._pool_db.lock_trajectory(tid, allow_completed_lock)

    def unlock_trajectory(self, tid: int, has_ended: bool) -> None:
        """Unlock a trajectory in the SQL DB.

        Args:
            tid: the trajectory id
            has_ended: True if the trajectory has ended

        Raises:
            SQLAlchemyError if the DB could not be accessed
        """
        if not self._save_to_disk or not self._pool_db:
            return

        if has_ended:
            self._pool_db.mark_trajectory_as_completed(tid)
        else:
            self._pool_db.release_trajectory(tid)

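    # Typical worker pattern built on the two methods above (sketch; the
    # advance step stands in for whatever work the caller performs on
    # trajectory tid):
    #
    #   if db.lock_trajectory(tid):
    #       traj = db.get_traj(tid)
    #       ...  # advance the trajectory
    #       db.unlock_trajectory(tid, has_ended=traj.has_ended())
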
    def update_trajectory_file(self, traj_id: int, checkfile: Path) -> None:
        """Update a trajectory file in the DB.

        Args:
            traj_id: the trajectory id
            checkfile: the new checkfile of that trajectory

        Raises:
            SQLAlchemyError if the DB could not be accessed
        """
        if not self._save_to_disk or not self._pool_db:
            return

        checkfile_str = checkfile.relative_to(self._abs_path).as_posix()
        self._pool_db.update_trajectory_file(traj_id, checkfile_str)

    def weights(self) -> list[float]:
        """Splitting iteration weights."""
        return self._weights

    def append_weight(self, weight: float) -> None:
        """Append a weight to the internal list."""
        self._weights.append(weight)

    def biases(self) -> list[int]:
        """Splitting iteration biases."""
        return self._l_bias

    def append_bias(self, bias: int) -> None:
        """Append a bias to the internal list."""
        self._l_bias.append(bias)

    def append_minmax(self, ksplit: int, minofmaxes: np.number, maxofmaxes: np.number) -> None:
        """Append the min/max of the maxes to the internal list."""
        self._minmax.append(np.array([float(ksplit), minofmaxes, maxofmaxes]))

    def k_split(self) -> int:
        """Splitting iteration counter."""
        return self._ksplit

    def done_with_splitting(self) -> bool:
        """Check if we are done with splitting."""
        return self._ksplit >= self._nsplititer

    def reset_ongoing(self) -> None:
        """Reset the list of trajectories undergoing branching."""
        self._ongoing = None

    def get_ongoing(self) -> list[int] | None:
        """Return the list of trajectories undergoing branching, or None."""
        return self._ongoing

    def set_k_split(self, ksplit: int) -> None:
        """Set the splitting iteration counter."""
        self._ksplit = ksplit

    def count_ended_traj(self) -> int:
        """Return the number of trajectories that ended."""
        count = 0
        for t in self._trajs_db:
            if t.has_ended():
                count += 1
        return count

    def count_converged_traj(self) -> int:
        """Return the number of trajectories that converged."""
        count = 0
        for t in self._trajs_db:
            if t.is_converged():
                count += 1
        return count

    def get_transition_probability(self) -> float:
        """Return the transition probability."""
        if self.count_ended_traj() < self._ntraj:
            wrn_msg = "TAMS initialization still ongoing, probability estimate not available yet"
            _logger.warning(wrn_msg)
            return 0.0

        w = self._ntraj * self._weights[-1]
        for i in range(len(self._l_bias)):
            w += self._l_bias[i] * self._weights[i]

        return self.count_converged_traj() * self._weights[-1] / w

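    # The estimator above implements, in the notation of this class: with N
    # requested trajectories, biases b_i, weights w_i, and C converged
    # trajectories after the last splitting iteration,
    #
    #   W = N * w_last + sum_i b_i * w_i
    #   p = C * w_last / W
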
    def info(self) -> None:
        """Print database info to screen."""
        db_date_str = str(self._creation_date)
        pretty_line = "####################################################"
        inf_tbl = f"""
        {pretty_line}
        # TAMS v{self._version:17s} trajectory database #
        # Date: {db_date_str:42s} #
        # Model: {self._fmodel_t.name():41s} #
        {pretty_line}
        # Requested # of traj: {self._ntraj:27} #
        # Requested # of splitting iter: {self._nsplititer:17} #
        # Number of 'Ended' trajectories: {self.count_ended_traj():16} #
        # Number of 'Converged' trajectories: {self.count_converged_traj():12} #
        # Current splitting iter counter: {self._ksplit:16} #
        # Transition probability: {self.get_transition_probability():24} #
        {pretty_line}
        """
        _logger.info(inf_tbl)

    def plot_score_functions(self, fname: str | None = None, plot_archived: bool = False) -> None:
        """Plot the score as a function of time for all trajectories."""
        pltfile = fname if fname else Path(self._name).stem + "_scores.png"
        plt.figure(figsize=(10, 6))
        for t in self._trajs_db:
            plt.plot(t.get_time_array(), t.get_score_array(), linewidth=0.8)
        if plot_archived:
            for t in self._archived_trajs_db:
                plt.plot(t.get_time_array(), t.get_score_array(), linewidth=0.8)
        plt.xlabel(r"$Time$", fontsize="x-large")
        plt.xlim(left=0.0)
        plt.ylabel(r"$Score \; [-]$", fontsize="x-large")
        plt.xticks(fontsize="x-large")
        plt.yticks(fontsize="x-large")
        plt.grid(linestyle="dotted")
        plt.tight_layout()  # to fit everything in the prescribed area
        plt.savefig(pltfile, dpi=300)
        plt.clf()
        plt.close()
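

# Minimal post-processing sketch (hedged: "my_run.tdb" is a hypothetical
# database folder produced by a previous TAMS run, not a pytams default;
# only the public API shown above is used).
if __name__ == "__main__":
    db = Database.load(Path("my_run.tdb"))
    db.load_data(load_archived_trajectories=True)
    db.info()
    db.plot_score_functions("scores.png", plot_archived=True)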