Source code for pytams.taskrunner

from __future__ import annotations
import asyncio
import concurrent.futures
import logging
import shutil
from abc import ABCMeta
from abc import abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
import dask
from dask.distributed import Client
from dask.distributed import WorkerPlugin
from dask_jobqueue import SLURMCluster
from typing_extensions import Self
from pytams.utils import setup_logger
from pytams.worker import worker_async

if TYPE_CHECKING:
    from collections.abc import Callable
    from dask.distributed import Worker

_logger = logging.getLogger(__name__)


class WorkerLoggerPlugin(WorkerPlugin):
    """A plugin to configure logging on each worker."""

    def __init__(self, params: dict[Any, Any]) -> None:
        """Initialize the plugin with the parameters dictionary."""
        self._params = params

    def setup(self, worker: Worker) -> None:
        """Configure logging on the worker.

        Args:
            worker: the dask worker
        """
        # Configure logging on each worker
        _ = worker
        setup_logger(self._params)


class RunnerError(Exception):
    """Exception class for the runner."""


class BaseRunner(metaclass=ABCMeta):
    """An ABC for the task runners."""

    @abstractmethod
    def __init__(
        self,
        params: dict,
        sync_wk: Callable,
        n_workers: int = 1,
    ):
        """A dummy init method."""

    @abstractmethod
    def __enter__(self) -> BaseRunner:
        """Enable use of the with statement."""

    @abstractmethod
    def __exit__(self, *args: object) -> None:
        """Executed when leaving the with scope."""

    @abstractmethod
    def make_promise(self, task: list[Any]) -> None:
        """Log a new task onto the list of tasks to tackle."""

    @abstractmethod
    def execute_promises(self) -> Any:
        """Execute the list of promises."""

    @abstractmethod
    def n_workers(self) -> int:
        """Return the number of workers in the runner."""
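
# Illustrative lifecycle sketch (not part of the module): every BaseRunner
# implementation is driven the same way. `ConcreteRunner`, `my_worker` and the
# content of `params` are hypothetical names used only for illustration.
#
#     def my_worker(x: int) -> int:
#         return x * x
#
#     with ConcreteRunner(params, my_worker, n_workers=2) as runner:
#         runner.make_promise([3])  # queue a call to my_worker(3)
#         runner.make_promise([4])  # queue a call to my_worker(4)
#         results = runner.execute_promises()  # results of the queued tasks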


class AsIORunner(BaseRunner):
    """A task runner class based on asyncio.

    A runner that relies on asyncio to schedule tasks concurrently in
    worker processes. Tasks are added to an internal queue from which
    workers take them, putting the results back into a result queue.
    """

    def __init__(
        self,
        params: dict,
        sync_wk: Callable,
        n_workers: int = 1,
    ):
        """Init the task runner.

        Args:
            params: a dictionary of parameters
            sync_wk: a synchronous worker function
            n_workers: number of workers
        """
        self._params = params
        self._queue: asyncio.Queue[Any] = asyncio.Queue()
        self._rqueue: asyncio.Queue[Any] = asyncio.Queue()
        self._n_workers: int = n_workers
        self._sync_worker = sync_wk
        self._async_worker = worker_async
        self._loop: asyncio.AbstractEventLoop | None = None
        self._executor: concurrent.futures.Executor | None = None
        self._workers: list[asyncio.Task[Any]] | None = None

    def __enter__(self) -> Self:
        """Enable use of the with statement."""
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)
        return self

    def __exit__(self, *args: object) -> None:
        """Executed when leaving the with scope."""
        if self._workers:
            for w in self._workers:
                w.cancel()
        if self._executor:
            self._executor.shutdown()
        if self._loop:
            self._loop.run_until_complete(self._loop.shutdown_asyncgens())
            self._loop.close()
            asyncio.set_event_loop(None)

    async def add_task(self, task: list[Any]) -> None:
        """Append a task to the queue."""
        await self._queue.put([self._sync_worker, *task])

    def make_promise(self, task: list[Any]) -> None:
        """A synchronous wrapper to add_task."""
        if not self._loop:
            err_msg = "AsIORunner has not been initialized."
            _logger.exception(err_msg)
            raise RuntimeError(err_msg)
        # Use the runner's own event loop so the queues are bound to the
        # same loop that later runs the worker tasks.
        self._loop.run_until_complete(self.add_task(task))

    async def run_tasks(self) -> list[Any]:
        """Create worker tasks and run."""
        if not self._workers:
            self._executor = concurrent.futures.ProcessPoolExecutor(
                max_workers=self._n_workers, initializer=setup_logger, initargs=(self._params,)
            )
            self._workers = [
                asyncio.create_task(self._async_worker(self._queue, self._rqueue, self._executor))
                for _ in range(self._n_workers)
            ]

        # Wait until all tasks are processed
        await self._queue.join()

        res = []
        while not self._rqueue.empty():
            res.append(await self._rqueue.get())
        return res

    def execute_promises(self) -> Any:
        """A synchronous wrapper to run_tasks."""
        if not self._loop:
            err_msg = "AsIORunner has not been initialized."
            _logger.exception(err_msg)
            raise RuntimeError(err_msg)
        try:
            res = self._loop.run_until_complete(self.run_tasks())
        except Exception:
            err_msg = "Error in AsIORunner while executing promises."
            _logger.exception(err_msg)
            raise
        else:
            return res

    def n_workers(self) -> int:
        """Return the number of workers in the runner."""
        return self._n_workers
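
# Example usage (a minimal sketch, not part of the module). The tasks run in a
# ProcessPoolExecutor, so the worker must be a picklable module-level function;
# `square` and the empty params dict are assumptions for illustration.
#
#     def square(x: int) -> int:
#         return x * x
#
#     with AsIORunner({}, square, n_workers=2) as runner:
#         for i in range(4):
#             runner.make_promise([i])
#         results = runner.execute_promises()  # square(0)..square(3), order not guaranteed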


class DaskRunner(BaseRunner):
    """A task runner class based on Dask.

    A runner that relies on dask to schedule tasks concurrently in workers.
    """

    def __init__(
        self,
        params: dict,
        sync_wk: Callable,
        n_workers: int = 1,
    ):
        """Start the Dask cluster and client.

        Args:
            params: a dictionary with params
            sync_wk: a synchronous worker function
            n_workers: number of workers
        """
        dask_dict = params.get("dask", {})
        self.dask_backend = dask_dict.get("backend", "local")
        self._n_workers: int = n_workers
        self._sync_worker = sync_wk
        self._tlist: list[Any] = []
        if self.dask_backend == "local":
            self.client = Client(threads_per_worker=1, n_workers=self._n_workers)
            self.cluster = None
        elif self.dask_backend == "slurm":
            self.slurm_config_file = dask_dict.get("slurm_config_file", None)
            if self.slurm_config_file:
                if not Path(self.slurm_config_file).exists():
                    err_msg = f"Specified slurm_config_file does not exist: {self.slurm_config_file}"
                    _logger.exception(err_msg)
                    raise RunnerError(err_msg)
                config_file = Path(self.slurm_config_file).name
                shutil.move(
                    self.slurm_config_file,
                    Path("~/.config/dask").expanduser() / config_file,
                )
                self.cluster = SLURMCluster()
            else:
                self.dask_queue = dask_dict.get("queue", "regular")
                self.dask_ntasks = dask_dict.get("ntasks_per_job", 1)
                self.dask_ntasks_per_node = dask_dict.get("ntasks_per_node", self.dask_ntasks)
                self.dask_nworker_ncore = dask_dict.get("ncores_per_worker", 1)
                self.dask_prologue = dask_dict.get("job_prologue", [])
                self.dask_walltime = dask_dict.get("worker_walltime", "04:00:00")
                self.cluster = SLURMCluster(
                    queue=self.dask_queue,
                    cores=self.dask_nworker_ncore,
                    memory="144GB",
                    walltime=self.dask_walltime,
                    processes=1,
                    interface="ib0",
                    job_script_prologue=self.dask_prologue,
                    job_extra_directives=[
                        f"--ntasks={self.dask_ntasks}",
                        f"--tasks-per-node={self.dask_ntasks_per_node}",
                        "--exclusive",
                    ],
                    job_directives_skip=["--cpus-per-task=", "--mem"],
                )
            self.cluster.scale(jobs=self._n_workers)
            self.client = Client(self.cluster)
        else:
            err_msg = f"Unknown [dask] backend: {self.dask_backend}"
            _logger.exception(err_msg)
            raise RunnerError(err_msg)

        # Setup the worker logging
        self.client.register_plugin(WorkerLoggerPlugin(params))

    def __enter__(self) -> Self:
        """Enable use of the with statement."""
        return self

    def __exit__(self, *args: object) -> None:
        """Executed when leaving the with scope."""
        if self.cluster:
            self.cluster.close()
        self.client.close()

    def make_promise(self, task: list[Any]) -> None:
        """Append a task to the internal task list."""
        self._tlist.append(dask.delayed(self._sync_worker)(*task))

    def just_delay(self, obj: Any) -> Any:
        """Delay an object."""
        return dask.delayed(obj)

    def execute_promises(self) -> Any:
        """Execute the list of pending promises.

        Returns:
            A list with the return argument of each promised task.

        Raises:
            Exception: if compute fails (logged, then re-raised)
        """
        try:
            res = list(dask.compute(*self._tlist))
        except Exception:
            err_msg = "Error in DaskRunner while executing promises."
            _logger.exception(err_msg)
            raise
        else:
            self._tlist.clear()
            return res

    def n_workers(self) -> int:
        """Return the number of workers in the runner."""
        return self._n_workers
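
# Example parameter dictionaries (illustrative sketches): these show the keys
# DaskRunner reads above; every value is a placeholder, not a recommendation,
# and `square` is the hypothetical worker from the earlier sketch.
#
#     params_local = {"dask": {"backend": "local"}}
#
#     params_slurm = {
#         "dask": {
#             "backend": "slurm",
#             "queue": "regular",          # SLURM partition to submit to
#             "ntasks_per_job": 1,
#             "ncores_per_worker": 1,
#             "job_prologue": [],          # extra lines for the job script
#             "worker_walltime": "04:00:00",
#         }
#     }
#
#     with DaskRunner(params_local, square, n_workers=2) as runner:
#         runner.make_promise([3])
#         results = runner.execute_promises()  # [9]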


def get_runner_type(params: dict) -> type[BaseRunner]:
    """Get a runner class from the parameters."""
    runner_map = {
        "dask": DaskRunner,
        "asyncio": AsIORunner,
    }
    runner_str = params.get("runner", {}).get("type", "").lower()
    if runner_str not in runner_map:
        err_msg = f"Unable to get {runner_str} runner."
        _logger.exception(err_msg)
        raise RunnerError(err_msg)
    return runner_map[runner_str]
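
# Example (illustrative sketch): dispatching on the "runner" section of the
# parameters, then driving the selected runner through the common interface.
#
#     params = {"runner": {"type": "asyncio"}}
#     runner_cls = get_runner_type(params)  # -> AsIORunner
#     with runner_cls(params, square, n_workers=2) as runner:
#         runner.make_promise([5])
#         results = runner.execute_promises()  # [25]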