# Source code for pyrevs.database.database

"""A database class for the pyREVS sampler."""

from __future__ import annotations
import copy
import datetime
import importlib
import logging
import sys
import xml.etree.ElementTree as ET
from dataclasses import dataclass
from importlib.metadata import version
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Generic
from typing import TypeVar
import cloudpickle
import matplotlib.pyplot as plt
import numpy as np
import toml
from pyrevs.core import Config
from pyrevs.core import CoreDB
from pyrevs.diagnostics import DiagDB
from pyrevs.trajectory import Trajectory
from pyrevs.trajectory import TrajectoryConfig
from pyrevs.trajectory.trajectory import form_trajectory_id
from pyrevs.utils.utils import get_module_local_import
from pyrevs.utils.xmlutils import new_element
from pyrevs.utils.xmlutils import xml_to_dict
from .config import DatabaseConfig

if TYPE_CHECKING:
    from pyrevs.core import ForwardModelBaseClass
    from .extension import StrategyDatabaseExtension

_logger = logging.getLogger(__name__)

# Type variables used to parameterize Database, Trajectory and the forward
# model base class (see Generic[T_Noise, T_State] below).
T_Noise = TypeVar("T_Noise")
T_State = TypeVar("T_State")
@dataclass
class DatabaseCoreSpec:
    """A dataclass for the database core specification.

    Attributes:
        ntraj: the requested number of trajectories
        strategy: the name of the sampling strategy
        deterministic: the deterministic flag handed to the trajectories
        diag_configs: the diagnostics configurations, or None when
            no diagnostics are requested
    """

    ntraj: int
    strategy: str
    deterministic: bool
    diag_configs: dict[str, Config] | None
class Database(Generic[T_Noise, T_State]):
    """A database class for the pyREVS sampler.

    The database class is a container for all the trajectory
    (and splitting) data. When the user provides a path to store
    the database, a local folder is created holding a number of
    readable files, any output from the model and SQL files used
    to lock/release trajectories as the sampling algorithm proceeds.

    The readable files are currently in an XML format.

    A database can be loaded independently from the sampling
    algorithm and used for post-processing.

    Attributes:
        _fmodel_t: the forward model type
        _config: the full configuration object
        _database_cfg: the database configuration dataclass
        _model_params: the "model" section of the configuration
        _trajs_db: the list of active trajectories
        _archived_trajs_db: the list of archived (discarded) trajectories
        _sql_db: the SQL trajectory pool handle (None until init_traj_pool)
    """

    def __init__(
        self,
        fmodel_t: type[ForwardModelBaseClass[T_Noise, T_State]],
        config: Config,
        read_only: bool = True,
    ) -> None:
        """Initialize a pyREVS database.

        Initialize the pyREVS database object, bare in-memory or on-disk.
        The on-disk mode is triggered when a path is provided in the
        database section of the configuration. The constructor itself
        does not touch the disk: see `create` and `load`.

        Args:
            fmodel_t: the forward model type
            config: a Config object with the full configuration
            read_only: the database access mode

        Raises:
            ValueError if a path is provided with a format other than XML
        """
        # Access mode
        self._read_only = read_only

        # Metadata and immutable parameters
        self._creation_date = datetime.datetime.now(tz=datetime.timezone.utc)
        self._version = version("pyrevs")
        self._name = "pyREVS_" + fmodel_t.name()
        self._strategy = "unknown"
        self._fmodel_t = fmodel_t
        self._config = config
        self._database_cfg = config.load(DatabaseConfig)
        self._model_params = config.section_dict("model")

        # Trajectory ensemble parameters
        self._deterministic = False
        self._ntraj: int = -1

        # Database format/storage parameters
        # Update the database name and create an absolute path
        if self._database_cfg.path:
            if self._database_cfg.format != "XML":
                err_msg = f"Unsupported pyREVS database format: {self._database_cfg.format} !"
                _logger.error(err_msg)
                raise ValueError(err_msg)
            self._name = f"{self._database_cfg.path}"
        self._abs_path: Path = Path.cwd() / self._name

        # Trajectory ensembles: in-memory
        #  - one for active trajectories
        #  - one for archived (discarded) trajectories
        self._trajs_db: list[Trajectory[T_Noise, T_State]] = []
        self._archived_trajs_db: list[Trajectory[T_Noise, T_State]] = []

        # Trajectory ensemble: persistent SQL database
        self._sql_name: str = ""
        self._sql_db: CoreDB | None = None

        # Diagnostics
        self._diag_configs: dict[str, Config] | None = None

        # Strategy-specific database extension
        self._strategy_extension: StrategyDatabaseExtension | None = None

    @classmethod
    def create(
        cls,
        fmodel_t: type[ForwardModelBaseClass[T_Noise, T_State]],
        config: Config,
    ) -> Database[T_Noise, T_State]:
        """Create a new pyREVS database.

        If the database is not stored on disk, only the trajectory
        pool is initialized after calling the constructor. Otherwise the
        folder structure, the header and the forward model reference
        are written to disk first.

        Args:
            fmodel_t: the forward model type
            config: a Config object with the full configuration

        Return:
            a pyREVS database object
        """
        db = cls(fmodel_t, config, read_only=False)
        if db.to_disk():
            db._setup_folder()
            db._write_metadata()
            db._serialize_fmodel()
        db.init_traj_pool()
        return db

    @classmethod
    def load(cls, a_path: Path, read_only: bool = True) -> Database[T_Noise, T_State]:
        """Instantiate a pyREVS database from disk.

        Args:
            a_path: the path to the database
            read_only: the database access mode

        Return:
            a pyREVS database object

        Raises:
            FileNotFoundError if the database folder, its configuration
                file or its pickled forward model is missing
        """
        if not a_path.exists():
            err_msg = f"Database {a_path} does not exist !"
            _logger.error(err_msg)
            raise FileNotFoundError(err_msg)

        # Load Config
        cfg_path = a_path / "input_params.toml"
        if not cfg_path.exists():
            err_msg = f"Database {cfg_path} does not exist !"
            _logger.error(err_msg)
            raise FileNotFoundError(err_msg)
        with cfg_path.open("r") as f:
            config = Config(toml.load(f))

        # Load pickled forward model
        model_file = Path(a_path / "fmodel.pkl")
        if not model_file.exists():
            err_msg = f"Picked forward model {model_file} is missing from the database !"
            _logger.error(err_msg)
            raise FileNotFoundError(err_msg)

        # SECURITY: unpickling executes arbitrary code, and the dotted path
        # it yields is imported below. Only load databases you trust.
        with model_file.open("rb") as f:
            path = cloudpickle.load(f)
        module_name, cls_name = path.rsplit(".", 1)
        mod = importlib.import_module(module_name)
        fmodel_t = getattr(mod, cls_name)

        db = cls(fmodel_t, config, read_only)
        db._load_metadata()
        db.init_traj_pool()
        return db

    def _setup_folder(self) -> None:
        """Initialize the trajectory database folders on disk."""
        if self.to_disk():
            # Create the top-level folder
            Path(self._name).mkdir()

            # Empty trajectories subfolder
            Path(self._abs_path / "trajectories").mkdir(parents=True)

    def _serialize_fmodel(self) -> None:
        """Write the forward model reference to `fmodel.pkl`."""
        # Serialize the model
        # We need to pickle by value the local modules
        # which might not be available if we move the database
        # Note: only one import depth is handled at this point, we might
        # want to make this recursive in the future
        # NOTE(review): what is dumped below is the dotted "module.Class"
        # string, not the class object, so the by-value registrations do
        # not seem to affect the file content - confirm the intent.
        model_file = Path(self._abs_path / "fmodel.pkl")
        cloudpickle.register_pickle_by_value(sys.modules[self._fmodel_t.__module__])
        for mods in get_module_local_import(self._fmodel_t.__module__):
            cloudpickle.register_pickle_by_value(sys.modules[mods])
        with model_file.open("wb") as f:
            cloudpickle.dump(self._fmodel_t.__module__ + "." + self._fmodel_t.__name__, f)

    def _write_metadata(self) -> None:
        """Write the database Metadata to disk.

        Raises:
            ValueError if the database format is not XML
        """
        if self._database_cfg.format != "XML":
            err_msg = f"Unsupported pyREVS database format: {self._database_cfg.format} !"
            _logger.error(err_msg)
            raise ValueError(err_msg)

        header_file = self.header_file()
        root = ET.Element("header")
        mdata = ET.SubElement(root, "metadata")
        mdata.append(new_element("pyREVS_version", version("pyrevs")))
        mdata.append(new_element("date", self._creation_date))
        mdata.append(new_element("model_t", self._fmodel_t.name()))
        mdata.append(new_element("strategy", self._strategy))
        mdata.append(new_element("ntraj", self.n_traj()))
        tree = ET.ElementTree(root)
        ET.indent(tree, space="\t", level=0)
        tree.write(header_file)

    def _load_metadata(self) -> None:
        """Read the database Metadata from the header.

        Raises:
            ValueError if the database format is not XML
            RuntimeError if the header has no metadata section or if the
                stored model differs from the forward model type in use
        """
        if not self.to_disk():
            return

        if self._database_cfg.format != "XML":
            err_msg = f"Unsupported pyREVS database format: {self._database_cfg.format} !"
            _logger.error(err_msg)
            raise ValueError(err_msg)

        tree = ET.parse(self.header_file())
        root = tree.getroot()
        mdata = root.find("metadata")
        if mdata is None:
            # find() returns None on a malformed header: fail with a clear message
            err_msg = f"Database header {self.header_file()} has no metadata section !"
            _logger.error(err_msg)
            raise RuntimeError(err_msg)

        datafromxml = xml_to_dict(mdata)
        self._ntraj = datafromxml["ntraj"]
        self._strategy = datafromxml["strategy"]
        self._version = datafromxml["pyREVS_version"]
        if self._version != version("pyrevs"):
            warn_msg = f"Database pyREVS version {self._version} is different from {version('pyrevs')}"
            _logger.warning(warn_msg)
        self._creation_date = datafromxml["date"]
        db_model = datafromxml["model_t"]
        if db_model != self._fmodel_t.name():
            err_msg = f"Database model {db_model} is different from call {self._fmodel_t.name()}"
            _logger.error(err_msg)
            raise RuntimeError(err_msg)

    def initialize_core_state(self, spec: DatabaseCoreSpec) -> None:
        """Initialize the core state of the database.

        From a specification dataclass. It also updates the
        metadata on disk if the database is on-disk.

        Args:
            spec: the database core specification
        """
        self._ntraj = spec.ntraj
        self._strategy = spec.strategy
        self._deterministic = spec.deterministic
        self._diag_configs = spec.diag_configs
        self.ping_diag_database()
        if self.to_disk():
            self._write_metadata()

    def init_traj_pool(self) -> None:
        """Initialize the trajectory pool."""
        # If an in-memory database is requested, a temporary (hidden)
        # sql database is still created
        if self.to_disk():
            self._sql_name = f"{self._name}/trajPool.db"
        else:
            self._sql_name = f".sqldb_pyrevs_{np.random.default_rng().integers(0, 999999):06d}.db"

        if self._read_only:
            self._sql_db = CoreDB(self.pool_file(), ro_mode=True)
        else:
            self._sql_db = CoreDB(self.pool_file())

    def attach_extension(self, ext: StrategyDatabaseExtension) -> None:
        """Attach an extension to the database.

        The extension is serialized right away when the database is on-disk.

        Args:
            ext: the strategy-specific database extension
        """
        self._strategy_extension = ext
        if self.to_disk():
            self._strategy_extension.serialize()

    def extension(self) -> StrategyDatabaseExtension:
        """Return the extension attached to the database.

        Raises:
            RuntimeError if no extension has been attached
        """
        if self._strategy_extension is None:
            err_msg = "No strategy extension attached to the database !"
            _logger.error(err_msg)
            raise RuntimeError(err_msg)
        return self._strategy_extension

    def _load_traj_config(self) -> TrajectoryConfig:
        """Load and validate the trajectory configuration."""
        traj_cfg = self._config.load(TrajectoryConfig)
        traj_cfg.validate()
        return traj_cfg

    def init_active_ensemble(self) -> None:
        """Initialize the requested number of trajectories.

        Trajectories already registered in the SQL pool are skipped.
        """
        traj_cfg = self._load_traj_config()
        pool = self._require_pool_db()
        for n in range(self.n_traj()):
            if pool.check_trajectory_exist(n):
                continue
            workdir = Path(self._abs_path / f"trajectories/{form_trajectory_id(n)}") if self.to_disk() else None
            t: Trajectory[T_Noise, T_State] = Trajectory(
                traj_id=n,
                weight=1.0,
                fmodel_t=self._fmodel_t,
                traj_cfg=traj_cfg,
                diag_configs=self._diag_configs,
                model_params=self._model_params,
                workdir=workdir,
                deterministic=self._deterministic,
            )
            self.append_traj(t, True)

    def save_trajectory(self, traj: Trajectory[T_Noise, T_State]) -> None:
        """Save a trajectory to disk in the database.

        Nothing is done for an in-memory database.

        Args:
            traj: the trajectory to save
        """
        if not self.to_disk():
            return
        traj.store()

    def load_data(self, load_archived_trajectories: bool = False) -> None:
        """Load data stored into the database.

        The initialization of the database only populates the
        metadata but not the full trajectories data. Nothing is
        done for an in-memory database.

        Args:
            load_archived_trajectories: whether to load archived trajectories
        """
        if not self.to_disk():
            return

        # Counter for number of trajectory loaded
        n_traj_restored = 0
        n_traj_initialized = 0

        # A read-only database restores its trajectories as frozen
        load_frozen = self._read_only

        traj_cfg = self._load_traj_config()
        pool = self._require_pool_db()

        ntraj_in_db = pool.get_trajectory_count()
        for n in range(ntraj_in_db):
            checkpath, metadata = pool.fetch_trajectory(n)
            traj_checkfile = Path(self._abs_path) / checkpath
            workdir = Path(self._abs_path / f"trajectories/{traj_checkfile.stem}")
            if traj_checkfile.exists():
                n_traj_restored += 1
                self.append_traj(
                    Trajectory.restore_from_checkfile(
                        traj_checkfile,
                        metadata,
                        fmodel_t=self._fmodel_t,
                        traj_cfg=traj_cfg,
                        diag_configs=self._diag_configs,
                        model_params=self._model_params,
                        workdir=workdir,
                        frozen=load_frozen,
                    ),
                    False,
                )
            else:
                # No checkfile yet: rebuild the trajectory from its metadata
                n_traj_initialized += 1
                self.append_traj(
                    Trajectory.init_from_metadata(
                        metadata,
                        fmodel_t=self._fmodel_t,
                        traj_cfg=traj_cfg,
                        diag_configs=self._diag_configs,
                        model_params=self._model_params,
                        workdir=workdir,
                    ),
                    False,
                )

        if n_traj_restored > 0:
            inf_msg = f"{n_traj_restored} active trajectories loaded"
            _logger.info(inf_msg)
            inf_msg = f"{n_traj_initialized} active trajectories initialized"
            _logger.info(inf_msg)

        # Load the archived trajectories if requested.
        # Those are loaded as 'frozen', i.e. the internal model
        # is not available and advance function disabled.
        if load_archived_trajectories:
            self.load_archived_trajectories()

        self.info()

    def load_archived_trajectories(self) -> None:
        """Load the archived trajectories data.

        Only archived trajectories with an existing checkfile are
        restored. Nothing is done for an in-memory database.
        """
        if not self.to_disk():
            return

        n_traj_restored = 0
        traj_cfg = self._load_traj_config()
        pool = self._require_pool_db()

        archived_ntraj_in_db = pool.get_archived_trajectory_count()
        for n in range(archived_ntraj_in_db):
            checkpath, metadata_str = pool.fetch_archived_trajectory(n)
            traj_checkfile = Path(self._abs_path) / checkpath
            workdir = Path(self._abs_path / f"trajectories/{traj_checkfile.stem}")
            if not traj_checkfile.exists():
                continue
            n_traj_restored += 1
            self.append_archived_traj(
                Trajectory.restore_from_checkfile(
                    traj_checkfile,
                    metadata_str,
                    fmodel_t=self._fmodel_t,
                    traj_cfg=traj_cfg,
                    diag_configs=self._diag_configs,
                    model_params=self._model_params,
                    workdir=workdir,
                    frozen=True,
                ),
                False,
            )

        inf_msg = f"{n_traj_restored} archived trajectories loaded"
        _logger.info(inf_msg)

    def name(self) -> str:
        """Accessor to DB name.

        Return:
            DB name
        """
        return self._name

    def strategy(self) -> str:
        """Accessor to DB strategy.

        Return:
            DB strategy
        """
        return self._strategy

    def to_disk(self) -> bool:
        """Check if the database is stored on disk.

        Return:
            True if a path is set in the database configuration
        """
        return self._database_cfg.path is not None

    def append_traj(self, a_traj: Trajectory[T_Noise, T_State], update_db: bool) -> None:
        """Append a Trajectory to the internal list.

        For an on-disk database, the checkfile of the trajectory is
        also set to its location in the `trajectories` subfolder.

        Args:
            a_traj: the trajectory
            update_db: True to update the SQL DB content
        """
        # Also adds it to the SQL pool file.
        # and set the checkfile
        if self.to_disk():
            checkfile_str = f"./trajectories/{a_traj.idstr()}.xml"
            checkfile = Path(self._abs_path) / checkfile_str
            a_traj.set_checkfile(checkfile)
        else:
            checkfile_str = f"{a_traj.idstr()}.xml"

        if update_db:
            self._require_pool_db().add_trajectory(checkfile_str, a_traj.get_metadata())

        self._trajs_db.append(a_traj)

    def append_archived_traj(self, a_traj: Trajectory[T_Noise, T_State], update_db: bool) -> None:
        """Append an archived Trajectory to the internal list.

        Args:
            a_traj: the trajectory
            update_db: True to update the SQL DB content
        """
        checkfile_str = f"./trajectories/{a_traj.idstr()}.xml"
        checkfile = Path(self._abs_path) / checkfile_str
        a_traj.set_checkfile(checkfile)

        if update_db:
            self._require_pool_db().archive_trajectory(checkfile_str, a_traj.get_metadata())

        self._archived_trajs_db.append(a_traj)

    def traj_list(self) -> list[Trajectory[T_Noise, T_State]]:
        """Access to the trajectory list.

        Return:
            Trajectory list
        """
        return self._trajs_db

    def get_traj(self, idx: int) -> Trajectory[T_Noise, T_State]:
        """Access to a given trajectory.

        Args:
            idx: the index

        Return:
            Trajectory

        Raises:
            ValueError if idx is out of range
        """
        if idx < 0 or idx >= len(self._trajs_db):
            err_msg = f"Trying to access a non existing trajectory {idx} !"
            _logger.error(err_msg)
            raise ValueError(err_msg)
        return self._trajs_db[idx]

    def overwrite_traj(self, idx: int, traj: Trajectory[T_Noise, T_State]) -> None:
        """Deep copy a trajectory into internal list.

        Args:
            idx: the index of the trajectory to override
            traj: the new trajectory

        Raises:
            ValueError if idx is out of range
        """
        if idx < 0 or idx >= len(self._trajs_db):
            err_msg = f"Trying to override a non existing trajectory {idx} !"
            _logger.error(err_msg)
            raise ValueError(err_msg)
        self._trajs_db[idx] = copy.deepcopy(traj)

    def _diag_db_path(self) -> Path:
        """Return the path of the diagnostics database file."""
        if self.to_disk():
            return self._abs_path / DiagDB.default_name()
        return Path(DiagDB.default_name())

    def ping_diag_database(self) -> None:
        """Initialize the diagDB from the main process.

        To avoid race condition in creating the DB from workers,
        let it be initialized from the main process here. Nothing is
        done when no diagnostics are configured.
        """
        if self._diag_configs is None:
            return
        ddb = DiagDB(self._diag_db_path().absolute().as_posix())
        ddb.close()

    def update_diagnostic_weights(self, tweight: float) -> None:
        """Update the weights of all the active trajectories.

        Nothing is done when no diagnostics are configured.

        Args:
            tweight: the weight passed to the diagnostics database
        """
        if self._diag_configs is None:
            return
        ddb = DiagDB(self._diag_db_path().absolute().as_posix())
        try:
            ddb.update_all_active_weights(tweight)
        finally:
            # Release the handle even if the update fails
            ddb.close()

    def header_file(self) -> str:
        """Helper returning the DB header file.

        Return:
            Header file
        """
        return f"{self._name}/header.xml"

    def pool_file(self) -> str:
        """Helper returning the DB trajectory pool file.

        Return:
            Pool file (empty string until init_traj_pool is called)
        """
        return self._sql_name

    def get_pool_db(self) -> CoreDB | None:
        """Get the pool SQL database handle."""
        return self._sql_db

    def _require_pool_db(self) -> CoreDB:
        """Internal accessor to the SQL database handle.

        Raises:
            RuntimeError if the SQL database has not been initialized
        """
        if self._sql_db is None:
            err_msg = "The SQL database has not been initialized !"
            _logger.error(err_msg)
            raise RuntimeError(err_msg)
        return self._sql_db

    def is_empty(self) -> bool:
        """Check if the trajectory pool is empty.

        Return:
            True if the SQL pool holds no trajectory
        """
        return self._require_pool_db().get_trajectory_count() == 0

    def traj_list_len(self) -> int:
        """Length of the trajectory list.

        Return:
            Trajectory list length
        """
        return len(self._trajs_db)

    def archived_traj_list(self) -> list[Trajectory[T_Noise, T_State]]:
        """Access to the archived trajectory list.

        Return:
            Archived trajectory list
        """
        return self._archived_trajs_db

    def archived_traj_list_len(self) -> int:
        """Length of the archived trajectory list.

        Return:
            Archived trajectory list length
        """
        return len(self._archived_trajs_db)

    def update_traj_list(self, a_traj_list: list[Trajectory[T_Noise, T_State]]) -> None:
        """Overwrite the trajectory list.

        Args:
            a_traj_list: the new trajectory list
        """
        self._trajs_db = a_traj_list

    def _pool_checkfile_str(self, traj: Trajectory[T_Noise, T_State]) -> str:
        """Return the checkfile string stored in the SQL pool for a trajectory."""
        checkfile = traj.get_checkfile()
        if self.to_disk():
            # On-disk checkfiles are stored relative to the database folder
            return checkfile.relative_to(self._abs_path).as_posix()
        return checkfile.as_posix()

    def archive_trajectory(self, traj: Trajectory[T_Noise, T_State]) -> None:
        """Archive a trajectory about to be discarded.

        Args:
            traj: the trajectory to archive
        """
        # A branched trajectory will be overwritten by the
        # newly generated one in-place in the _trajs_db list.
        self._archived_trajs_db.append(traj)

        # Update the list of archived trajectories in the SQL DB
        checkfile_str = self._pool_checkfile_str(traj)
        self._require_pool_db().archive_trajectory(checkfile_str, traj.get_metadata())

    def lock_trajectory(self, tid: int, allow_completed_lock: bool = False) -> bool:
        """Lock a trajectory in the SQL DB.

        Args:
            tid: the trajectory id
            allow_completed_lock: True if the trajectory can be locked even if it is completed

        Return:
            the value returned by the SQL pool lock request

        Raises:
            SQLAlchemyError if the DB could not be accessed
        """
        return self._require_pool_db().lock_trajectory(tid, allow_completed_lock)

    def unlock_trajectory(self, tid: int, has_terminated: bool) -> None:
        """Unlock a trajectory in the SQL DB.

        A terminated trajectory is marked as completed, otherwise
        it is released.

        Args:
            tid: the trajectory id
            has_terminated: True if the trajectory has terminated

        Raises:
            SQLAlchemyError if the DB could not be accessed
        """
        pool = self._require_pool_db()
        if has_terminated:
            pool.mark_trajectory_as_completed(tid)
        else:
            pool.release_trajectory(tid)

    def update_trajectory(self, traj_id: int, traj: Trajectory[T_Noise, T_State]) -> None:
        """Update a trajectory file in the DB.

        Args:
            traj_id : The trajectory id
            traj : the trajectory to get the data from

        Raises:
            SQLAlchemyError if the DB could not be accessed
        """
        # Same checkfile convention as archive_trajectory, so that an
        # in-memory database does not fail on relative_to()
        checkfile_str = self._pool_checkfile_str(traj)
        self._require_pool_db().update_trajectory(traj_id, checkfile_str, traj.get_metadata())

    def n_traj(self) -> int:
        """Return the number of trajectory used for sampling run.

        Note that this is the requested number of trajectory, not
        the current length of the trajectory pool.

        Return:
            number of trajectory (-1 until set from a spec or the header)
        """
        return self._ntraj

    def path(self) -> str | None:
        """Return the path to the database.

        Return:
            the absolute POSIX path, or None for an in-memory database
        """
        if self.to_disk():
            return self._abs_path.absolute().as_posix()
        return None

    def count_terminated_traj(self) -> int:
        """Return the number of trajectories that terminated."""
        return self._require_pool_db().get_terminated_trajectory_count()

    def count_converged_traj(self) -> int:
        """Return the number of trajectories that converged."""
        return self._require_pool_db().get_converged_trajectory_count()

    def all_converged(self) -> bool:
        """Check if all the trajectory converged."""
        return self.count_converged_traj() == self.n_traj()

    def count_computed_steps(self) -> int:
        """Return the total number of steps taken.

        This total count includes both the active and
        discarded trajectories.
        """
        return self._require_pool_db().get_total_computed_steps()

    def get_event_probability(self) -> float:
        """Return the event probability.

        Default to a Monte-Carlo event probability if no
        strategy extension is attached to the database.

        Return:
            the event probability, 0.0 when the requested number of
            trajectories has not been set to a positive value
        """
        if self._strategy_extension is not None:
            return self._strategy_extension.get_event_probability()

        ntraj = self.n_traj()
        if ntraj <= 0:
            # Unset (-1) or empty ensemble: avoid a negative ratio or a
            # division by zero
            return 0.0
        return self.count_converged_traj() / ntraj

    def _format_info(self) -> str:
        """Build the database summary table shared by info and print_info."""
        db_date_str = str(self._creation_date)
        pretty_line = "####################################################"
        return f"""
            {pretty_line}
            # pyREVS v{self._version:40s} #
            # Creation Date: {db_date_str:33s} #
            # Model: {self._fmodel_t.name():41s} #
            # Sampling strategy: {self._strategy:29s} #
            {pretty_line}
            # Requested # of traj: {self.n_traj():27} #
            # Number of 'Terminated' trajectories: {self.count_terminated_traj():11} #
            # Number of 'Converged' trajectories: {self.count_converged_traj():12} #
            # Current total number of steps: {self.count_computed_steps():17} #
            {pretty_line}
        """

    def info(self) -> None:
        """Print database info to logger."""
        _logger.info(self._format_info())

    def print_info(self) -> None:
        """Print database info to screen."""
        print(self._format_info())  # noqa: T201

    def plot_score_functions(self, fname: str | None = None, plot_archived: bool = False) -> None:
        """Plot the score as function of time for all trajectories.

        Args:
            fname: the output file, default to the database name stem
                followed by `_scores.png`
            plot_archived: whether to also plot the archived trajectories
        """
        pltfile = fname if fname else Path(self._name).stem + "_scores.png"

        plt.figure(figsize=(10, 6))
        try:
            for t in self._trajs_db:
                plt.plot(t.get_time_array(), t.get_score_array(), linewidth=0.8)

            if plot_archived:
                for t in self._archived_trajs_db:
                    plt.plot(t.get_time_array(), t.get_score_array(), linewidth=0.8)

            plt.xlabel(r"$Time$", fontsize="x-large")
            plt.xlim(left=0.0)
            plt.ylabel(r"$Score \; [-]$", fontsize="x-large")
            plt.xticks(fontsize="x-large")
            plt.yticks(fontsize="x-large")
            plt.grid(linestyle="dotted")
            plt.tight_layout()  # to fit everything in the prescribed area
            plt.savefig(pltfile, dpi=300)
        finally:
            # Always release the figure, even if plotting or saving fails
            plt.clf()
            plt.close()

    def __del__(self) -> None:
        """Destructor of the db.

        Delete the hidden SQL database if we do not intend to keep
        the database around.
        """
        # Even if we plan to keep the SQL database around, force
        # deleting the SQL connection
        if hasattr(self, "_sql_db"):
            del self._sql_db

        # __init__ may have raised before the configuration was loaded
        db_cfg = getattr(self, "_database_cfg", None)
        if db_cfg is None or db_cfg.path is not None:
            return

        # Remove the hidden db file. An empty name means the pool was never
        # created: Path("") is the current directory and must not be unlinked.
        sql_name = getattr(self, "_sql_name", "")
        if sql_name:
            Path(sql_name).unlink(missing_ok=True)