Source code for pytams.worker

"""A set of functions used by TAMS workers."""

import asyncio
import concurrent.futures
import datetime
import functools
import logging
from collections.abc import Callable
from pathlib import Path
from typing import Any
from pytams.sqldb import SQLFile
from pytams.trajectory import Trajectory
from pytams.trajectory import WallTimeLimitError

_logger = logging.getLogger(__name__)


[docs] def update_trajectory_in_sql(traj: Trajectory, sqldb: SQLFile | None = None, db_path: str | None = None) -> None: """Wrapper for update SQL trajectory info. Args: sqldb: the SQL database to update traj: the traj to get the information from db_path: an optional TAMS database path """ if sqldb: checkfile_str = ( traj.get_checkfile().relative_to(Path(db_path)).as_posix() if db_path else traj.get_checkfile().as_posix() ) sqldb.update_trajectory(traj.id(), checkfile_str, traj.serialize_metadata_json())
[docs] def traj_advance_with_exception( traj: Trajectory, walltime: float, sqldb: SQLFile | None = None, db_path: str | None = None ) -> Trajectory: """Advance a trajectory with exception handling. Args: traj: a trajectory walltime: the time limit to advance the trajectory sqldb: a handle to the SQL database db_path: an optional path to the run database Returns: The updated trajectory """ try: traj.advance(walltime=walltime) except WallTimeLimitError: warn_msg = f"Trajectory {traj.idstr()} advance ran out of time !" _logger.warning(warn_msg) except Exception: err_msg = f"Trajectory {traj.idstr()} advance ran into an error !" _logger.exception(err_msg) raise finally: # Update the SQL database if sqldb: if traj.has_ended(): sqldb.mark_trajectory_as_completed(traj.id()) else: sqldb.release_trajectory(traj.id()) update_trajectory_in_sql(traj, sqldb, db_path) # Trigger a checkfile dump if we are provided with # a database path if db_path: traj.store() return traj
[docs] def pool_worker( traj: Trajectory, end_date: datetime.date, sql_path: str | None = None, db_path: str | None = None ) -> Trajectory: """A worker to generate each initial trajectory. Args: traj: a trajectory end_date: the time limit to advance the trajectory sql_path: an optional path to the SQL database db_path: an optional path to the run database Returns: The updated trajectory """ # Get wall time wall_time = -1.0 timedelta: datetime.timedelta = end_date - datetime.datetime.now(tz=datetime.timezone.utc) if timedelta: wall_time = timedelta.total_seconds() if wall_time > 0.0 and not traj.has_ended(): # Try to lock the trajectory in the DB sqldb = None if sql_path: sqldb = SQLFile(sql_path) get_to_work = sqldb.lock_trajectory(traj.id(), allow_completed_lock=True) if not get_to_work: return traj inf_msg = f"Advancing {traj.idstr()} [time left: {wall_time}]" _logger.info(inf_msg) traj = traj_advance_with_exception(traj, wall_time, sqldb, db_path) return traj
[docs] def ms_worker( from_traj: Trajectory, rst_traj: Trajectory, min_val: float, new_weight: float, end_date: datetime.date, sql_path: str | None = None, db_path: str | None = None, ) -> Trajectory: """A worker to restart trajectories. Args: from_traj: a trajectory to restart from rst_traj: the trajectory being restarted min_val: the value of the score function to restart from new_weight: the weight of the new child trajectory end_date: the time limit to advance the trajectory sql_path: a path to the SQL database db_path: an optional path to the run database """ # Get wall time wall_time = -1.0 timedelta: datetime.timedelta = end_date - datetime.datetime.now(tz=datetime.timezone.utc) if timedelta: wall_time = timedelta.total_seconds() sqldb = None if sql_path: sqldb = SQLFile(sql_path) if wall_time > 0.0: # Try to lock the trajectory in the DB if sqldb: get_to_work = sqldb.lock_trajectory(rst_traj.id(), allow_completed_lock=True) if not get_to_work: err_msg = f"Unable to lock trajectory {rst_traj.id()} for branching" _logger.error(err_msg) raise RuntimeError(err_msg) inf_msg = f"Restarting [{rst_traj.id()}] from {from_traj.idstr()} [time left: {wall_time}]" _logger.info(inf_msg) traj = Trajectory.branch_from_trajectory(from_traj, rst_traj, min_val, new_weight) # The branched trajectory has a new checkfile # Update the database to point to the latest one. update_trajectory_in_sql(traj, sqldb, db_path) return traj_advance_with_exception(traj, wall_time, sqldb, db_path) traj = Trajectory.branch_from_trajectory(from_traj, rst_traj, min_val, new_weight) warn_msg = "MS worker ran out of time before advancing trajectory!" _logger.warning(warn_msg) # The branched trajectory has a new checkfile, even if haven't advanced yet # Update the database to point to the latest one. update_trajectory_in_sql(traj, sqldb, db_path) return traj
[docs] async def worker_async( queue: asyncio.Queue[tuple[Callable[..., Any], Trajectory, float, bool, str]], res_queue: asyncio.Queue[asyncio.Future[Trajectory]], executor: concurrent.futures.Executor, ) -> None: """An async worker for the asyncio taskrunner. It wraps the call to one of the above worker functions with access to the queue. Args: queue: a queue from which to get tasks res_queue: a queue to put the results in executor: an executor to launch the work in """ while True: func, *work_unit = await queue.get() loop = asyncio.get_running_loop() traj: asyncio.Future[Trajectory] = await loop.run_in_executor( executor, functools.partial(func, *work_unit), ) await res_queue.put(traj) queue.task_done()