Source code for pyrevs.strategies.montecarlo.montecarlo

"""The MonteCarlo sampling strategy."""

import logging
from pathlib import Path
from typing import Any
from pyrevs.core import Config
from pyrevs.core import RuntimeConfig
from pyrevs.database import Database
from pyrevs.database import DatabaseCoreSpec
from pyrevs.runner import RunnerConfig
from pyrevs.runner import make_runner
from pyrevs.runner import pool_worker
from pyrevs.strategies.base import BaseSamplingStrategy
from pyrevs.strategies.base import TerminationCriterion
from pyrevs.strategies.base import TimeTerminationCriterion
from .config import MCConfig
from .extension import MCDatabaseExtension

_logger = logging.getLogger(__name__)


@BaseSamplingStrategy.register("montecarlo")
class MonteCarlo(BaseSamplingStrategy):
    """A strategy class implementing MonteCarlo.

    Monte-Carlo or Direct Numerical Simulation (DNS) is not per-se a sampling
    strategy tailored for rare events but it provides a baseline for comparison
    with other sampling strategies. An ensemble of size n_traj is constructed and
    the rare-event probability is simply computed as the ratio of the number of
    converged trajectories to the total number of trajectories in the ensemble n_traj.

    In practice, this is the first step of a TAMS or AMS run (depending on the
    termination condition), such that this class is a lightweight version of
    these other strategies.

    Notes:
        This strategy relies on time management provided by BaseSamplingStrategy
        (e.g. ``self._end_date``, ``elapsed_time()``).
    """

    def __init__(
        self,
        fmodel_t: Any,
        runtime_cfg: RuntimeConfig,
        runner_cfg: RunnerConfig,
        strategy_cfg: MCConfig,
        deterministic: bool,
    ) -> None:
        """Initialize a Monte-Carlo object.

        Args:
            fmodel_t: the forward model type
            runtime_cfg: the runtime config (provides ``loglevel`` and ``logfile``)
            runner_cfg: the runner config, passed to ``make_runner``
            strategy_cfg: the montecarlo config
            deterministic: the deterministic flag, recorded in the database core spec

        Raises:
            ValueError: if necessary config parameters are not found
        """
        self._fmodel_t = fmodel_t

        # Validate the strategy config immediately so a bad config is reported
        # at construction time rather than partway through a run.
        self._mc_cfg = strategy_cfg
        self._mc_cfg.validate()

        self._runner_cfg = runner_cfg
        self._loglevel = runtime_cfg.loglevel
        self._logfile = runtime_cfg.logfile
        self._deterministic = deterministic

        # Termination criteria handed to every worker task; a time-based
        # criterion is only added when an end time is configured.
        self._term_crit: list[TerminationCriterion] = []
        if strategy_cfg.end_time is not None:
            self._term_crit.append(TimeTerminationCriterion(strategy_cfg.end_time))

    def generate_trajectory_ensemble(self, tdb: Database) -> None:
        """Schedule the generation of an ensemble of stochastic trajectories.

        Loop over all the trajectories in the database and schedule advancing
        them to either end time or convergence with the runner. The runner will
        use the number of workers specified in the input file under the runner
        section; ``max_workers`` is set to the configured number of trajectories
        (presumably an upper bound on the worker count -- confirm in ``make_runner``).

        The trajectories returned by the runner are sorted by id and written
        back to the database, then the elapsed run time is logged.

        Args:
            tdb: the trajectory database whose trajectories are advanced

        Raises:
            RuntimeError: if the runner fails while executing the trajectory
                tasks; the original exception is chained as ``__cause__``.
        """
        inf_msg = f"Creating a Monte Carlo ensemble of {tdb.n_traj()} trajectories"
        _logger.info(inf_msg)

        with make_runner(
            self._runner_cfg,
            pool_worker,
            loglevel=self._loglevel,
            logfile=self._logfile,
            max_workers=self._mc_cfg.ntrajectories,
        ) as runner:
            for traj in tdb.traj_list():
                # Positional task payload; its order must match what
                # pool_worker unpacks (not visible in this module).
                task = [traj, self._term_crit, self._end_date, tdb.pool_file(), tdb.path()]
                runner.make_promise(task)

            try:
                t_list = runner.execute_promises()
            except Exception as exc:
                err_msg = f"Failed to generate the ensemble of {tdb.n_traj()} trajectories"
                _logger.exception(err_msg)
                raise RuntimeError(err_msg) from exc

        # Re-order list since runner does not guarantee order
        # And update list of trajectories in the database
        t_list.sort(key=lambda finished: finished.id())
        tdb.update_traj_list(t_list)

        inf_msg = f"Run time: {self.elapsed_time()} s"
        _logger.info(inf_msg)

    def compute_probability(self, tdb: Database) -> float:
        """Compute the rare-event probability using MonteCarlo.

        Args:
            tdb: the trajectory database; its trajectory list is updated by
                the ensemble generation before the probability is queried.

        Returns:
            the rare-event probability
        """
        # Generate the initial trajectory ensemble
        self.generate_trajectory_ensemble(tdb)
        return tdb.get_event_probability()

    def _execute_sampling(self, database: Database, plot_diags: bool) -> None:
        """Shallow wrapper to enable sampler.

        Loads the database data, initializes the active ensemble, computes the
        Monte-Carlo event probability, optionally plots the score functions,
        prints the database info and logs the resulting probability.

        Args:
            database: the trajectory database to sample into
            plot_diags: if True, plot the trajectory score functions to
                ``Score_MCEnd.png``; a warning is logged if that file already exists.
        """
        database.load_data()

        # Initialize an empty trajectory ensemble
        database.init_active_ensemble()

        inf_msg = f"Computing {self._fmodel_t.name()} rare event probability using MonteCarlo"
        _logger.info(inf_msg)

        proba = self.compute_probability(database)

        # Plot trajectory database scores
        if plot_diags:
            pltfile = "Score_MCEnd.png"
            if Path(pltfile).exists():
                wrn_msg = f"Attempting to overwrite the plot file {pltfile}"
                _logger.warning(wrn_msg)
            database.plot_score_functions(pltfile)

        database.info()

        inf_msg = f"Event probability: {proba}"
        _logger.info(inf_msg)

    def initialize_database_schema(self, database: Database, diag_configs: dict[str, Config] | None) -> None:
        """Initialize database core state.

        Builds a ``DatabaseCoreSpec`` tagged with the ``"montecarlo"`` strategy,
        initializes the database core state from it, then creates, initializes
        and attaches a Monte-Carlo database extension (kept on ``self._db_ext``).

        Args:
            database: the database to initialize
            diag_configs: optional diagnostic configurations forwarded to the core spec
        """
        spec = DatabaseCoreSpec(
            ntraj=self._mc_cfg.ntrajectories,
            strategy="montecarlo",
            deterministic=self._deterministic,
            diag_configs=diag_configs,
        )
        database.initialize_core_state(spec)

        # Setup MC extension
        self._db_ext = MCDatabaseExtension()
        self._db_ext.initialize(database)
        database.attach_extension(self._db_ext)