Source code for pytams.tams

import copy
import os
import shutil
import time
import xml.etree.ElementTree as ET
from datetime import datetime
from typing import List
import numpy as np
from pytams.daskutils import DaskRunner
from pytams.trajectory import Trajectory
from pytams.xmlutils import dict_to_xml
from pytams.xmlutils import new_element
from pytams.xmlutils import xml_to_dict


class TAMSError(Exception):
    """Exception class for TAMS."""

    pass

class TAMS:
    """A class implementing TAMS.

    Holds a Trajectory database and the mechanisms to populate,
    explore and perform I/O on the database.
    """

    def __init__(self, fmodel_t, parameters: dict) -> None:
        """Initialize a TAMS run.

        Args:
            fmodel_t: the forward model type
            parameters: a dictionary of input parameters
        """
        self._fmodel_t = fmodel_t
        self.parameters = parameters

        # Parse user-inputs
        self.v = parameters.get("Verbose", False)

        self._saveDB = self.parameters.get("DB_save", False)
        self._prefixDB = self.parameters.get("DB_prefix", "TAMS")
        self._restartDB = self.parameters.get("DB_restart", None)

        self._nTraj = self.parameters.get("nTrajectories", 500)
        self._nSplitIter = self.parameters.get("nSplitIter", 2000)

        self._wallTime = self.parameters.get("wallTime", 600.0)

        # Trajectory Pool
        self._trajs_db = []

        # Trajectory Database
        if self._saveDB:
            self._nameDB = "{}.tdb".format(self._prefixDB)

        # Splitting data
        self._kSplit = 0
        self._l_bias = []
        self._weights = [1]

        # Initialize
        self._startTime = time.monotonic()
        if self._restartDB is not None:
            self.restoreTrajDB()
        else:
            self.initTrajDB()
            self.init_trajectory_pool()

    def initTrajDB(self) -> None:
        """Initialize the trajectory database."""
        if self._saveDB:
            self.verbosePrint(
                "Initializing the trajectories database {}".format(self._nameDB)
            )
            if os.path.exists(self._nameDB) and self._nameDB != self._restartDB:
                rng = np.random.default_rng(12345)
                copy_exists = True
                while copy_exists:
                    random_int = rng.integers(0, 999999)
                    nameDB_rnd = "{}_{:06d}".format(self._nameDB, random_int)
                    copy_exists = os.path.exists(nameDB_rnd)
                print(
                    """
                    TAMS database {} already present but not specified as restart.
                    It will be moved to {}.""".format(
                        self._nameDB, nameDB_rnd
                    )
                )
                shutil.move(self._nameDB, nameDB_rnd)

            os.mkdir(self._nameDB)

            # Header file with metadata
            headerFile = "{}/header.xml".format(self._nameDB)
            root = ET.Element("header")
            mdata = ET.SubElement(root, "metadata")
            mdata.append(new_element("pyTAMS_version", datetime.now()))
            mdata.append(new_element("date", datetime.now()))
            mdata.append(new_element("model_t", self._fmodel_t.name()))
            root.append(dict_to_xml("parameters", self.parameters))
            tree = ET.ElementTree(root)
            ET.indent(tree, space="\t", level=0)
            tree.write(headerFile)

            # Initialize the splitting data file
            self.saveSplittingData(self._nameDB)

            # Dynamically updated file with the trajectory pool
            # Empty for now
            databaseFile = "{}/trajPool.xml".format(self._nameDB)
            root = ET.Element("trajectories")
            root.append(new_element("nTraj", self._nTraj))
            tree = ET.ElementTree(root)
            ET.indent(tree, space="\t", level=0)
            tree.write(databaseFile)

            # Empty trajectories subfolder
            os.mkdir("{}/{}".format(self._nameDB, "trajectories"))

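    # For reference, with the default "TAMS" DB_prefix the method above lays out
    # the on-disk database roughly as follows (a sketch based on the writes in
    # initTrajDB, not an exhaustive description of the file contents):
    #
    #   TAMS.tdb/
    #     header.xml          <- metadata (date, model name) and the run parameters
    #     splittingData.xml   <- kSplit counter, bias and weight arrays
    #     trajPool.xml        <- one entry per started trajectory, pointing to its checkpoint
    #     trajectories/       <- per-trajectory XML checkpoint files
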
    def appendTrajsToDB(self) -> None:
        """Append started trajectories to the pool file."""
        if self._saveDB:
            self.verbosePrint(
                "Appending started trajectories to database {}".format(self._nameDB)
            )
            databaseFile = "{}/trajPool.xml".format(self._nameDB)
            tree = ET.parse(databaseFile)
            root = tree.getroot()
            for T in self._trajs_db:
                T_entry = root.find(T.id())
                if T.hasStarted() and T_entry is None:
                    loc = T.checkFile()
                    root.append(new_element(T.id(), loc))

            ET.indent(tree, space="\t", level=0)
            tree.write(databaseFile)

    def saveSplittingData(self, a_db: str) -> None:
        """Write splitting data to an XML file."""
        # Splitting data file
        splittingDataFile = "{}/splittingData.xml".format(a_db)
        root = ET.Element("Splitting")
        root.append(new_element("kSplit", self._kSplit))
        root.append(new_element("bias", np.array(self._l_bias)))
        root.append(new_element("weight", np.array(self._weights)))
        tree = ET.ElementTree(root)
        ET.indent(tree, space="\t", level=0)
        tree.write(splittingDataFile)

    def readSplittingData(self, a_db: str) -> None:
        """Read splitting data from an XML file."""
        # Read data file
        splittingDataFile = "{}/splittingData.xml".format(a_db)
        tree = ET.parse(splittingDataFile)
        root = tree.getroot()
        datafromxml = xml_to_dict(root)
        self._kSplit = datafromxml["kSplit"]
        self._l_bias = datafromxml["bias"].tolist()
        self._weights = datafromxml["weight"].tolist()

    def restoreTrajDB(self) -> None:
        """Initialize TAMS from a stored trajectory database."""
        if os.path.exists(self._restartDB):
            self.verbosePrint(
                "Restoring from the trajectories database {}".format(self._restartDB)
            )

            # Check the database parameters against the current run
            self.check_database_consistency(self._restartDB)

            # Load splitting data
            self.readSplittingData(self._restartDB)

            # Load trajectories stored in the database when available
            dbFile = "{}/trajPool.xml".format(self._restartDB)
            nTrajRestored = self.loadTrajectoryDB(dbFile)
            self.verbosePrint(
                "--> {} trajectories restored from the database".format(nTrajRestored)
            )
        else:
            raise TAMSError(
                "Could not find the {} TAMS database!".format(self._restartDB)
            )

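    # Restart usage note (a sketch, with hypothetical values): passing e.g.
    # {"DB_restart": "TAMS.tdb", "DB_save": True, "DB_prefix": "TAMS"} in the
    # parameters makes __init__ call restoreTrajDB() instead of creating a
    # fresh pool, reloading the splitting data and any trajectory checkpoints
    # listed in trajPool.xml.
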
    def loadTrajectoryDB(self, dbFile: str) -> int:
        """Load the trajectories stored in the database.

        Args:
            dbFile: the database file

        Returns:
            the number of trajectories loaded
        """
        # Counter for the number of trajectories loaded
        nTrajRestored = 0

        tree = ET.parse(dbFile)
        root = tree.getroot()
        for n in range(self._nTraj):
            trajId = "traj{:06}".format(n)
            T_entry = root.find(trajId)
            if T_entry is not None:
                chkFile = T_entry.text
                if os.path.exists(chkFile):
                    nTrajRestored += 1
                    self._trajs_db.append(
                        Trajectory.restoreFromChk(
                            chkFile,
                            fmodel_t=self._fmodel_t,
                        )
                    )
                else:
                    raise TAMSError(
                        "Could not find the trajectory checkFile {} listed in the TAMS database!".format(
                            chkFile
                        )
                    )
            else:
                self._trajs_db.append(
                    Trajectory(
                        fmodel_t=self._fmodel_t,
                        parameters=self.parameters,
                        trajId="traj{:06}".format(n),
                    )
                )

        return nTrajRestored

    def check_database_consistency(self, a_db: str) -> None:
        """Check the restart database consistency."""
        # Open and load the header
        headerFile = "{}/header.xml".format(a_db)
        tree = ET.parse(headerFile)
        root = tree.getroot()
        headerfromxml = xml_to_dict(root.find("metadata"))
        if self._fmodel_t.name() != headerfromxml["model_t"]:
            raise TAMSError(
                "Trying to restore a TAMS run with the {} model from a database built with the {} model!".format(
                    self._fmodel_t.name(), headerfromxml["model_t"]
                )
            )

        # Parameters stored in the database override any
        # newly modified params
        # TODO: will need to relax this later on
        paramsfromxml = xml_to_dict(root.find("parameters"))
        self.parameters.update(paramsfromxml)

    def verbosePrint(self, message: str) -> None:
        """Print a message only in verbose mode."""
        if self.v:
            print("TAMS-[{}]".format(message))

    def elapsed_time(self) -> float:
        """Return the elapsed wallclock time.

        Time since the initialization of TAMS [seconds].

        Returns:
            the TAMS elapsed time.
        """
        return time.monotonic() - self._startTime

    def remaining_walltime(self) -> float:
        """Return the remaining wallclock time [seconds].

        Returns:
            the TAMS remaining wall time.
        """
        return self._wallTime - self.elapsed_time()

    def out_of_time(self) -> bool:
        """Return True if insufficient walltime remains.

        Returns:
            a boolean indicating wall time availability.
        """
        # Keep a safety margin of 5% of the total walltime
        return self.remaining_walltime() < 0.05 * self._wallTime

    def init_trajectory_pool(self):
        """Initialize the trajectory pool."""
        self.hasEnded = np.full((self._nTraj), False)
        for n in range(self._nTraj):
            self._trajs_db.append(
                Trajectory(
                    fmodel_t=self._fmodel_t,
                    parameters=self.parameters,
                    trajId="traj{:06}".format(n),
                )
            )

    def task_delayed(self, traj: Trajectory) -> Trajectory:
        """A worker to generate one initial trajectory.

        Args:
            traj: a trajectory

        Returns:
            the advanced trajectory
        """
        if not self.out_of_time() and not traj.hasEnded():
            traj.advance(walltime=self.remaining_walltime())
            if self._saveDB:
                traj.setCheckFile(
                    "{}/{}/{}.xml".format(self._nameDB, "trajectories", traj.id())
                )
                traj.store()

        return traj

    def generate_trajectory_pool(self) -> None:
        """Schedule the generation of a pool of stochastic trajectories."""
        self.verbosePrint(
            "Creating the initial pool of {} trajectories".format(self._nTraj)
        )

        with DaskRunner(self.parameters) as runner:
            # Assemble a list of promises
            # All the trajectories are added, even those already done
            tasks_p = []
            for T in self._trajs_db:
                tasks_p.append(runner.make_promise(self.task_delayed, T))

            self._trajs_db = runner.execute_promises(tasks_p)

        # Update the trajectory database
        self.appendTrajsToDB()

        self.verbosePrint("Run time: {} s".format(self.elapsed_time()))

    def worker(
        self, t_end: float, min_idx_list: List[int], rstId: str, min_val: float
    ) -> Trajectory:
        """A worker to restart trajectories.

        Args:
            t_end: a final time
            min_idx_list: the list of trajectories restarted in the current splitting iteration
            rstId: the id of the trajectory being worked on
            min_val: the value of the score function to restart from

        Returns:
            the restarted trajectory
        """
        rng = np.random.default_rng()
        # Pick a trajectory to branch from, excluding those selected
        # for restart in the current splitting iteration
        rest_idx = min_idx_list[0]
        while rest_idx in min_idx_list:
            rest_idx = rng.integers(0, len(self._trajs_db))

        traj = Trajectory.restartFromTraj(self._trajs_db[rest_idx], rstId, min_val)

        traj.advance(walltime=self.remaining_walltime())

        return traj

    def do_multilevel_splitting(self) -> None:
        """Schedule splitting of the initial pool of stochastic trajectories."""
        self.verbosePrint("Using multi-level splitting to get the probability")

        # Initialize the splitting iterations counter
        k = self._kSplit

        with DaskRunner(self.parameters) as runner:
            while k <= self._nSplitIter:
                # Check for walltime
                if self.out_of_time():
                    self.verbosePrint(
                        "Ran out of time after {} splitting iterations".format(k)
                    )
                    break

                # Gather the max score from all trajectories
                # and check for early convergence
                allConverged = True
                maxes = np.zeros(len(self._trajs_db))
                for i in range(len(self._trajs_db)):
                    maxes[i] = self._trajs_db[i].scoreMax()
                    allConverged = allConverged and self._trajs_db[i].isConverged()

                # Exit if our work is done
                if allConverged:
                    self.verbosePrint(
                        "All trajectories converged after {} splitting iterations".format(k)
                    )
                    break

                # Exit if splitting is stalled
                if (np.amax(maxes) - np.amin(maxes)) < 1e-10:
                    raise TAMSError(
                        "Splitting is stalling with all trajectories stuck at a score_max: {}".format(
                            np.amax(maxes)
                        )
                    )

                # Get the nworker lowest-scored trajectories
                min_idx_list = np.argpartition(maxes, runner.dask_nworker)[
                    : runner.dask_nworker
                ]
                min_vals = maxes[min_idx_list]
                self._l_bias.append(len(min_idx_list))
                self._weights.append(
                    self._weights[-1] * (1 - self._l_bias[-1] / self._nTraj)
                )

                # Assemble a list of promises
                tasks_p = []
                for i in range(len(min_idx_list)):
                    tasks_p.append(
                        runner.make_promise(
                            self.worker,
                            1.0e9,
                            min_idx_list,
                            self._trajs_db[min_idx_list[i]].id(),
                            min_vals[i],
                        )
                    )

                restartedTrajs = runner.execute_promises(tasks_p)

                # Update the trajectory pool and database
                k += runner.dask_nworker
                self._kSplit = k
                for i in range(len(min_idx_list)):
                    self._trajs_db[min_idx_list[i]] = copy.deepcopy(restartedTrajs[i])
                    if self._saveDB:
                        self.saveSplittingData(self._nameDB)
                        tid = self._trajs_db[min_idx_list[i]].id()
                        self._trajs_db[min_idx_list[i]].setCheckFile(
                            "{}/{}/{}.xml".format(self._nameDB, "trajectories", tid)
                        )
                        self._trajs_db[min_idx_list[i]].store()

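    # A short worked example of the splitting bookkeeping above, assuming
    # nTrajectories = 500 and runner.dask_nworker = 2 (both hypothetical values):
    # at every splitting iteration the l = 2 lowest-scoring trajectories are
    # discarded and rebranched, so
    #   _l_bias  grows as [2, 2, 2, ...]
    #   _weights grows as [1, 1 - 2/500, (1 - 2/500)**2, ...] = [1, 0.996, 0.992016, ...]
    # i.e. each iteration multiplies the current weight by (1 - l/N).
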
    def compute_probability(self) -> float:
        """Compute the probability using TAMS.

        Returns:
            the transition probability
        """
        self.verbosePrint(
            "Computing the {} rare event probability using TAMS".format(
                self._fmodel_t.name()
            )
        )

        # Skip the pool stage if the splitting iterative
        # process has already started
        skip_pool = self._kSplit > 0

        # Generate the initial trajectory pool
        if not skip_pool:
            self.generate_trajectory_pool()

        # Check for early convergence
        allConverged = True
        for T in self._trajs_db:
            if not T.isConverged():
                allConverged = False
                break

        if not skip_pool and allConverged:
            self.verbosePrint("All trajectories converged prior to splitting!")
            return 1.0

        if self.out_of_time():
            self.verbosePrint("Ran out of walltime! Exiting now.")
            return -1.0

        # Perform multilevel splitting
        self.do_multilevel_splitting()

        W = self._nTraj * self._weights[-1]
        for i in range(len(self._l_bias)):
            W += self._l_bias[i] * self._weights[i]

        # Compute how many trajectories converged to the vicinity of B
        successCount = 0
        for T in self._trajs_db:
            if T.isConverged():
                successCount += 1

        trans_prob = successCount * self._weights[-1] / W

        self.verbosePrint("Run time: {} s".format(self.elapsed_time()))

        return trans_prob

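    # A note on the estimator above: since _weights[0] = 1 and
    # _weights[j+1] = _weights[j] * (1 - _l_bias[j] / nTraj), the normalization
    # telescopes in exact arithmetic:
    #   W = nTraj * _weights[-1] + sum_j _l_bias[j] * _weights[j] = nTraj
    # so the returned value is the usual multilevel-splitting estimate
    #   trans_prob = (#converged / nTraj) * prod_j (1 - _l_bias[j] / nTraj)
    # For instance, with nTraj = 500, three iterations each discarding 2
    # trajectories and 4 converged trajectories (hypothetical numbers):
    #   (4 / 500) * (1 - 2/500)**3 ~= 0.00790
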
    def nTraj(self) -> int:
        """Return the number of trajectories used for TAMS."""
        return self._nTraj
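
# A minimal, hypothetical usage sketch (not part of the module): "MyModel" stands
# for any forward model class compatible with pytams (it must at least expose the
# name() classmethod used above, plus whatever the Trajectory class requires), and
# the parameter keys mirror the defaults read in TAMS.__init__.
#
#   from pytams.tams import TAMS
#   from mymodels import MyModel  # hypothetical forward model
#
#   params = {
#       "nTrajectories": 100,
#       "nSplitIter": 400,
#       "wallTime": 3600.0,
#       "DB_save": True,
#       "DB_prefix": "myrun",
#       "Verbose": True,
#   }
#
#   tams = TAMS(fmodel_t=MyModel, parameters=params)
#   probability = tams.compute_probability()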