Source code for pyrevs.runner.worker

"""A set of functions used by pyREVS workers."""

from __future__ import annotations
import asyncio
import datetime
import functools
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from pyrevs.core import CoreDB
from pyrevs.trajectory import Trajectory
from pyrevs.trajectory.trajectory import WallTimeLimitError

if TYPE_CHECKING:
    import concurrent.futures
    from collections.abc import Callable
    from pyrevs.strategies.base import TerminationCriterion


# Module-level logger, named after this module, used by the worker functions below.
_logger = logging.getLogger(__name__)


def update_trajectory_in_sql(traj: Trajectory, trajdb: CoreDB | None = None, db_path: str | None = None) -> None:
    """Push a trajectory's checkfile location and metadata to the SQL database.

    Nothing is done when no database handle is supplied.

    Args:
        traj: the trajectory providing the id, checkfile and metadata
        trajdb: the SQL database to update, if any
        db_path: an optional pyREVS database path; when given, the
            checkfile is recorded relative to that path
    """
    if not trajdb:
        return

    checkfile = traj.get_checkfile()
    if db_path:
        # Store a path relative to the run database rather than the full path.
        checkfile = checkfile.relative_to(Path(db_path))
    recorded_path = checkfile.as_posix()

    trajdb.update_trajectory(traj.id(), recorded_path, traj.get_metadata())
def _record_advance_outcome(traj: Trajectory, trajdb: CoreDB | None, db_path: str | None) -> None:
    """Synchronise the SQL database and the checkfile after an advance attempt."""
    if trajdb:
        # Terminated trajectories are flagged as completed; others are released.
        settle = trajdb.mark_trajectory_as_completed if traj.is_terminated() else trajdb.release_trajectory
        settle(traj.id())
        update_trajectory_in_sql(traj, trajdb, db_path)

    # Trigger a checkfile dump if we are provided with a database path.
    # NOTE(review): nesting reconstructed from a whitespace-collapsed source;
    # confirm the dump is meant to be independent of `trajdb`.
    if db_path:
        traj.store()


def traj_advance_with_exception(
    traj: Trajectory,
    termination_criteria: list[TerminationCriterion],
    walltime: float,
    trajdb: CoreDB | None = None,
    db_path: str | None = None,
) -> Trajectory:
    """Advance a trajectory, logging failures and always updating the bookkeeping.

    Running out of wall time is logged as a warning and swallowed. Any
    other exception is logged and re-raised. In every case the database
    and checkfile bookkeeping runs before the function exits.

    Args:
        traj: a trajectory
        termination_criteria: a list of termination criterion for the advance call
        walltime: the time limit to advance the trajectory
        trajdb: a handle to the SQL database
        db_path: an optional path to the run database

    Returns:
        The updated trajectory
    """
    try:
        traj.advance(walltime=walltime, termination_criteria=termination_criteria)
    except WallTimeLimitError:
        _logger.warning(f"Trajectory {traj.idstr()} advance ran out of time !")
    except Exception:
        _logger.exception(f"Trajectory {traj.idstr()} advance ran into an error !")
        raise
    finally:
        _record_advance_outcome(traj, trajdb, db_path)

    return traj
def pool_worker(
    traj: Trajectory,
    termination_criteria: list[TerminationCriterion],
    end_date: datetime.datetime,
    trajdb_path: str | None = None,
    db_path: str | None = None,
) -> Trajectory:
    """A worker to generate each initial trajectory.

    The trajectory is returned untouched when the deadline has passed,
    when it is already terminated, or when the database refuses the lock.

    Args:
        traj: a trajectory
        termination_criteria: a list of termination criterion for the advance call
        end_date: the deadline for advancing the trajectory; it is subtracted
            from a timezone-aware UTC "now", so it must be a timezone-aware
            datetime (a plain date raises TypeError)
        trajdb_path: an optional path to the SQL database
        db_path: an optional path to the run database

    Returns:
        The updated trajectory
    """
    # Seconds left before the deadline; zero or negative means no time left.
    wall_time = (end_date - datetime.datetime.now(tz=datetime.timezone.utc)).total_seconds()
    if wall_time <= 0.0 or traj.is_terminated():
        return traj

    # Try to lock the trajectory in the DB; back off if the lock is refused.
    trajdb = None
    if trajdb_path:
        trajdb = CoreDB(trajdb_path)
        if not trajdb.lock_trajectory(traj.id(), allow_completed_lock=True):
            return traj

    inf_msg = f"Advancing {traj.idstr()} [time left: {wall_time}]"
    _logger.info(inf_msg)

    return traj_advance_with_exception(traj, termination_criteria, wall_time, trajdb, db_path)
def ms_worker(
    from_traj: Trajectory,
    rst_traj: Trajectory,
    min_val: float,
    new_weight: float,
    termination_criteria: list[TerminationCriterion],
    end_date: datetime.datetime,
    trajdb_path: str | None = None,
    db_path: str | None = None,
) -> Trajectory:
    """A worker to restart trajectories.

    The restarted trajectory is always branched from ``from_traj``. It is
    only advanced when some wall time is left before ``end_date``.

    Args:
        from_traj: a trajectory to restart from
        rst_traj: the trajectory being restarted
        min_val: the value of the score function to restart from
        new_weight: the weight of the new child trajectory
        termination_criteria: a list of termination criterion for the advance call
        end_date: the deadline for advancing the trajectory; it is subtracted
            from a timezone-aware UTC "now", so it must be a timezone-aware
            datetime (a plain date raises TypeError)
        trajdb_path: a path to the SQL database
        db_path: an optional path to the run database

    Returns:
        The branched trajectory, advanced if time allowed

    Raises:
        RuntimeError: if the restarted trajectory cannot be locked in the database
    """
    # Seconds left before the deadline; zero or negative means no time left.
    wall_time = (end_date - datetime.datetime.now(tz=datetime.timezone.utc)).total_seconds()

    trajdb = CoreDB(trajdb_path) if trajdb_path else None

    if wall_time <= 0.0:
        traj = Trajectory.branch_from_trajectory(from_traj, rst_traj, min_val, new_weight)

        warn_msg = "MS worker ran out of time before advancing trajectory!"
        _logger.warning(warn_msg)

        # The branched trajectory has a new checkfile, even if it has not
        # advanced yet. Update the database to point to the latest one.
        update_trajectory_in_sql(traj, trajdb, db_path)
        return traj

    # Try to lock the trajectory in the DB
    if trajdb and not trajdb.lock_trajectory(rst_traj.id(), allow_completed_lock=True):
        err_msg = f"Unable to lock trajectory {rst_traj.id()} for branching"
        _logger.error(err_msg)
        raise RuntimeError(err_msg)

    inf_msg = f"Restarting [{rst_traj.id()}] from {from_traj.idstr()} [time left: {wall_time}]"
    _logger.info(inf_msg)

    # NOTE(review): if branching or the SQL update below raises, the lock
    # taken above is not released here - confirm whether callers handle that.
    traj = Trajectory.branch_from_trajectory(from_traj, rst_traj, min_val, new_weight)

    # The branched trajectory has a new checkfile.
    # Update the database to point to the latest one.
    update_trajectory_in_sql(traj, trajdb, db_path)

    return traj_advance_with_exception(traj, termination_criteria, wall_time, trajdb, db_path)
async def worker_async(
    queue: asyncio.Queue[tuple[Callable[..., Any], Trajectory, float, bool, str]],
    res_queue: asyncio.Queue[Trajectory],
    executor: concurrent.futures.Executor,
) -> None:
    """An async worker for the asyncio taskrunner.

    It wraps the call to one of the above worker functions
    with access to the queue. The coroutine loops forever and is
    expected to be cancelled by its owner.

    Args:
        queue: a queue from which to get tasks; each item is a callable
            followed by its positional arguments
        res_queue: a queue to put the results in
        executor: an executor to launch the work in
    """
    # The running loop cannot change during this coroutine: look it up once.
    loop = asyncio.get_running_loop()

    while True:
        func, *work_unit = await queue.get()

        # Awaiting the executor future yields the worker's return value
        # itself, not a future.
        traj: Trajectory = await loop.run_in_executor(
            executor,
            functools.partial(func, *work_unit),
        )
        await res_queue.put(traj)

        # NOTE(review): if `func` raises, task_done() is never reached for
        # that item - confirm that callers do not rely on queue.join() then.
        queue.task_done()