"""A set of functions used by TAMS workers."""importasyncioimportconcurrent.futuresimportdatetimeimportfunctoolsimportloggingfromcollections.abcimportCallablefrompathlibimportPathfromtypingimportAnyfrompytams.databaseimportDatabasefrompytams.trajectoryimportTrajectoryfrompytams.trajectoryimportWallTimeLimitError_logger=logging.getLogger(__name__)
def traj_advance_with_exception(traj: Trajectory, walltime: float, a_db: Database | None) -> Trajectory:
    """Advance a trajectory, logging failures and syncing it back to the database.

    Running out of wall time (``WallTimeLimitError``) is logged as a warning and
    swallowed. Any other exception is logged with its traceback and re-raised.
    Whatever happens, when a database handle is given the trajectory is unlocked
    and saved before this function exits.

    Args:
        traj: the trajectory to advance
        walltime: the wall-clock time budget passed to ``traj.advance``
        a_db: a database handle, or None to skip database bookkeeping

    Returns:
        The trajectory object that was passed in, after the advance attempt.

    Raises:
        Exception: any error from ``traj.advance`` other than ``WallTimeLimitError``.
    """
    try:
        traj.advance(walltime=walltime)
    except WallTimeLimitError:
        # Hitting the time budget is an expected outcome: warn and carry on.
        msg = f"Trajectory {traj.idstr()} advance ran out of time !"
        _logger.warning(msg)
    except Exception:
        msg = f"Trajectory {traj.idstr()} advance ran into an error !"
        _logger.exception(msg)
        raise
    finally:
        # Release the lock and persist progress on success, timeout or error alike.
        if a_db:
            a_db.unlock_trajectory(traj.id(), traj.has_ended())
            a_db.save_trajectory(traj)
    return traj
def pool_worker(traj: Trajectory, end_date: datetime.datetime, db_path: str | None = None) -> Trajectory:
    """A worker to generate each initial trajectory.

    Advances ``traj`` until it ends or the wall-clock deadline is reached.

    Args:
        traj: a trajectory
        end_date: the deadline for advancing the trajectory. It is compared
            against ``datetime.datetime.now(tz=datetime.timezone.utc)``, so it
            must be a timezone-aware ``datetime.datetime``. A plain ``date`` or a
            naive ``datetime`` raises ``TypeError`` on subtraction.
        db_path: a path to a TAMS database or None

    Returns:
        The updated trajectory. ``traj`` is returned untouched when no time is
        left, when it has already ended, or when its database lock cannot be
        acquired.
    """
    # Seconds left before the deadline (zero or negative once it has passed).
    wall_time = (end_date - datetime.datetime.now(tz=datetime.timezone.utc)).total_seconds()
    if wall_time <= 0.0 or traj.has_ended():
        return traj

    db = None
    if db_path:
        db = Database.load(Path(db_path))
        # Only work on the trajectory if it can be locked in the DB.
        if not db.lock_trajectory(traj.id()):
            return traj

    inf_msg = f"Advancing {traj.idstr()} [time left: {wall_time}]"
    _logger.info(inf_msg)
    return traj_advance_with_exception(traj, wall_time, db)
def ms_worker(
    from_traj: Trajectory,
    rst_traj: Trajectory,
    min_val: float,
    end_date: datetime.datetime,
    db_path: str | None = None,
) -> Trajectory:
    """A worker to restart trajectories.

    Branches ``rst_traj`` from ``from_traj`` at score value ``min_val`` using
    ``Trajectory.branch_from_trajectory``. If wall-clock time remains before
    ``end_date``, the branched trajectory is then advanced. Otherwise it is
    returned without being advanced and without any database interaction.

    Args:
        from_traj: a trajectory to restart from
        rst_traj: the trajectory being restarted
        min_val: the value of the score function to restart from
        end_date: the deadline for advancing the trajectory. It must be a
            timezone-aware ``datetime.datetime``, because it is compared against
            ``datetime.datetime.now(tz=datetime.timezone.utc)``.
        db_path: a database path or None

    Returns:
        The branched trajectory, advanced if time remained.

    Raises:
        RuntimeError: if ``rst_traj`` cannot be locked in the database.
    """
    # Seconds left before the deadline (zero or negative once it has passed).
    wall_time = (end_date - datetime.datetime.now(tz=datetime.timezone.utc)).total_seconds()
    if wall_time <= 0.0:
        # Out of time: branch only, without touching the database or advancing.
        return Trajectory.branch_from_trajectory(from_traj, rst_traj, min_val)

    db = None
    if db_path:
        # Fetch a handle to the database pool and lock the trajectory we are branching.
        db = Database.load(Path(db_path))
        if not db.lock_trajectory(rst_traj.id(), True):
            err_msg = f"Unable to lock trajectory {rst_traj.id()} for branching"
            _logger.error(err_msg)
            raise RuntimeError(err_msg)
        # Archive the trajectory we are branching
        db.archive_trajectory(rst_traj)

    inf_msg = f"Restarting [{rst_traj.id()}] from {from_traj.idstr()} [time left: {wall_time}]"
    _logger.info(inf_msg)
    # NOTE(review): if archiving or branching raises, the lock taken above is never
    # released (unlocking only happens inside traj_advance_with_exception).
    # Consider unlocking on failure.
    traj = Trajectory.branch_from_trajectory(from_traj, rst_traj, min_val)

    if db:
        # The branched trajectory has a new checkfile:
        # update the database to point to the latest one.
        db.update_trajectory_file(traj.id(), traj.get_checkfile())

    return traj_advance_with_exception(traj, wall_time, db)
async def worker_async(
    queue: asyncio.Queue[tuple[Callable[..., Any], Trajectory, float, bool, str]],
    res_queue: asyncio.Queue[asyncio.Future[Trajectory]],
    executor: concurrent.futures.Executor,
) -> None:
    """An async worker for the asyncio taskrunner.

    It wraps the call to one of the above worker functions with access to the
    queue. The worker loops forever:
    - it takes an item ``(func, *args)`` from ``queue``;
    - it runs ``func(*args)`` in ``executor`` through ``loop.run_in_executor``;
    - it puts the return value on ``res_queue``.

    ``queue.task_done()`` is called for every item taken, even when the work
    raises, so a ``queue.join()`` elsewhere cannot block forever on a failed
    item. The exception itself still propagates and terminates this worker.

    Args:
        queue: a queue from which to get tasks
        res_queue: a queue to put the results in
        executor: an executor to launch the work in
    """
    # The running loop does not change between iterations; fetch it once.
    loop = asyncio.get_running_loop()
    while True:
        func, *work_unit = await queue.get()
        try:
            # Awaiting run_in_executor yields the worker function's return value.
            # NOTE(review): that value is a Trajectory, not a Future, so the
            # res_queue annotation looks wrong. Confirm with callers.
            traj = await loop.run_in_executor(
                executor,
                functools.partial(func, *work_unit),
            )
            await res_queue.put(traj)
        finally:
            # Always account for the item taken, otherwise queue.join() hangs on failure.
            queue.task_done()