from __future__ import annotations
import asyncio
import concurrent.futures
import logging
import ntpath
import shutil
from abc import ABCMeta
from abc import abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
import dask
from dask.distributed import Client
from dask.distributed import WorkerPlugin
from dask_jobqueue import SLURMCluster
from typing_extensions import Self
from pyrevs.utils import setup_logger
from .worker import worker_async
if TYPE_CHECKING:
from collections.abc import Callable
from dask.distributed import Worker
from .config import RunnerConfig
_logger = logging.getLogger(__name__)
[docs]
class WorkerLoggerPlugin(WorkerPlugin):
    """Dask worker plugin that configures logging on the workers.

    The logging level and logging file are captured when the plugin is
    built and handed to ``setup_logger`` when ``setup`` is invoked.
    """

    def __init__(self, loglevel: str, logfile: str | None) -> None:
        """Remember the logging settings to apply on the workers.

        Args:
            loglevel: the logging level passed on to ``setup_logger``
            logfile: the logging file passed on to ``setup_logger``, or None
        """
        self._loglevel, self._logfile = loglevel, logfile

    def setup(self, worker: Worker) -> None:
        """Configure logging on the worker.

        Args:
            worker: the dask worker (not used, required by the plugin interface)
        """
        # The worker handle itself is not needed to set up the logger
        del worker
        setup_logger(self._loglevel, self._logfile)
[docs]
class RunnerError(Exception):
    """Error raised by the task runners.

    Carries no extra state: it behaves exactly like ``Exception`` and only
    exists so that runner failures can be told apart from other errors.
    """
[docs]
class BaseRunner(metaclass=ABCMeta):
    """Abstract interface shared by all task runners.

    A runner is used as a context manager: tasks are registered with
    ``make_promise`` and then executed with ``execute_promises``.
    Every method below is abstract and must be provided by subclasses.
    """

    @abstractmethod
    def __init__(
        self,
        runner_cfg: RunnerConfig,
        worker_fn: Callable,
        n_workers: int = 1,
        loglevel: str = "INFO",
        logfile: str | None = None,
    ) -> None:
        """Define the constructor signature shared by all runners.

        Args:
            runner_cfg: a RunnerConfig dataclass
            worker_fn: a synchronous worker function
            n_workers: number of workers
            loglevel: optional logging level
            logfile: optional logging file
        """

    @abstractmethod
    def __enter__(self) -> Self:
        """Enter the ``with`` scope and return the runner itself."""

    @abstractmethod
    def __exit__(self, *args: object) -> None:
        """Executed when leaving the ``with`` scope."""

    @abstractmethod
    def make_promise(self, task: list[Any]) -> None:
        """Register a new task in the list of tasks to tackle.

        Args:
            task: the list of arguments describing the task
        """

    @abstractmethod
    def execute_promises(self) -> Any:
        """Execute the registered promises and return their results."""

    @abstractmethod
    def n_workers(self) -> int:
        """Return the number of workers in the runner."""
[docs]
class AsIORunner(BaseRunner):
    """A task runner class based on asyncio.

    A runner that relies on asyncio to schedule tasks concurrently
    in worker processes. Tasks are added to an internal queue from
    which the workers can take them and put the results back into
    a result queue.

    Attributes:
        _queue: an asyncio.Queue() to place the tasks in
        _rqueue: an asyncio.Queue() where the results are returned
        _n_workers: the number of workers in the runner
        _sync_worker: the synchronous worker function
        _async_worker: the asynchronous worker function
        _loop: the event loop associated with the workers
        _executor: an executor for the workers to work in
        _workers: a list of worker tasks
    """

    def __init__(
        self,
        runner_cfg: RunnerConfig,
        worker_fn: Callable,
        n_workers: int = 1,
        loglevel: str = "INFO",
        logfile: str | None = None,
    ) -> None:
        """Init the asyncio task runner.

        Args:
            runner_cfg: a RunnerConfig dataclass (not used by this runner)
            worker_fn: a synchronous worker function
            n_workers: number of workers
            loglevel: optional logging level
            logfile: optional logging file
        """
        _ = runner_cfg
        self._loglevel = loglevel
        self._logfile = logfile
        self._queue: asyncio.Queue[Any] = asyncio.Queue()
        self._rqueue: asyncio.Queue[Any] = asyncio.Queue()
        self._n_workers: int = n_workers
        self._sync_worker = worker_fn
        self._async_worker = worker_async
        # Loop, executor and workers are created lazily (see __enter__ / run_tasks)
        self._loop: asyncio.AbstractEventLoop | None = None
        self._executor: concurrent.futures.Executor | None = None
        self._workers: list[asyncio.Task[Any]] | None = None

    def __enter__(self) -> Self:
        """Create a dedicated event loop and make it the current one."""
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)
        return self

    def __exit__(self, *args: object) -> None:
        """Cancel the workers, shut the executor down and close the event loop."""
        loop = self._loop
        try:
            if self._workers:
                for worker in self._workers:
                    worker.cancel()
                if loop:
                    # Give the cancelled workers a chance to unwind before the
                    # loop is closed, otherwise they are destroyed while pending.
                    loop.run_until_complete(asyncio.gather(*self._workers, return_exceptions=True))
            if self._executor:
                self._executor.shutdown()
        finally:
            if loop:
                loop.run_until_complete(loop.shutdown_asyncgens())
                loop.close()
                asyncio.set_event_loop(None)
            # Drop the references to the resources that were just torn down
            self._workers = None
            self._executor = None
            self._loop = None

    async def add_task(self, task: list[Any]) -> None:
        """Append a task to the queue.

        Args:
            task: the list of arguments handed to the worker function
        """
        await self._queue.put([self._sync_worker, *task])

    def make_promise(self, task: list[Any]) -> None:
        """A synchronous way to append a task to the queue.

        Args:
            task: the list of arguments handed to the worker function
        """
        # The queue is unbounded, so a non-blocking put always succeeds. This
        # avoids spinning up (and tearing down) a temporary event loop for
        # every task, which would also reset the loop installed by __enter__.
        self._queue.put_nowait([self._sync_worker, *task])

    async def run_tasks(self) -> list[Any]:
        """Create worker tasks and run.

        Initialize the executor and setup the workers (tasks) if not
        already done.
        The join() task is created separately and awaited with the others in
        order to catch any exception coming from the workers as they are generated
        and stop everything as soon as one task fails.

        Returns:
            The list of results found in the result queue.

        Raises:
            The exception of the first failing worker task, if any.
        """
        if not self._workers:
            self._executor = concurrent.futures.ProcessPoolExecutor(
                max_workers=self._n_workers, initializer=setup_logger, initargs=(self._loglevel, self._logfile)
            )
            self._workers = [
                asyncio.create_task(self._async_worker(self._queue, self._rqueue, self._executor))
                for _ in range(self._n_workers)
            ]

        # Create a separate task for the join()
        # and check the tasks status as they are completed
        join_task = asyncio.create_task(self._queue.join())
        try:
            done, _ = await asyncio.wait([join_task, *self._workers], return_when=asyncio.FIRST_COMPLETED)

            # If a worker raised an exception, cancel all other workers
            # and re-raise.
            for task in done:
                if task is join_task:
                    continue
                excep = task.exception()
                if excep is not None:
                    for t in self._workers:
                        if not t.done():
                            t.cancel()
                    raise excep
        finally:
            # Never leave the join() task pending behind us
            if not join_task.done():
                join_task.cancel()

        # Assemble the results list from whatever is in the result queue
        res = []
        while not self._rqueue.empty():
            res.append(self._rqueue.get_nowait())
        return res

    def execute_promises(self) -> Any:
        """A synchronous wrapper to run_tasks.

        Returns:
            The list of results returned by run_tasks.

        Raises:
            RuntimeError: if the runner is used outside of a ``with`` scope
            Exception: any error raised while running the tasks (logged and re-raised)
        """
        if not self._loop:
            err_msg = "AsIORunner has not been initialized."
            # No active exception here: error() avoids logging an empty traceback
            _logger.error(err_msg)
            raise RuntimeError(err_msg)

        try:
            res = self._loop.run_until_complete(self.run_tasks())
        except Exception:
            err_msg = "Error in AsIORunner while executing promises."
            _logger.exception(err_msg)
            raise
        else:
            return res

    def n_workers(self) -> int:
        """Return the number of workers in the runner."""
        return self._n_workers
[docs]
class DaskRunner(BaseRunner):
    """A task runner class based on Dask.

    A runner that relies on dask to schedule
    tasks concurrently in workers.
    """

    def __init__(
        self,
        runner_cfg: RunnerConfig,
        worker_fn: Callable,
        n_workers: int = 1,
        loglevel: str = "INFO",
        logfile: str | None = None,
    ) -> None:
        """Start the Dask cluster and client.

        Args:
            runner_cfg: a RunnerConfig dataclass (its ``dask_config`` is used)
            worker_fn: a synchronous worker function
            n_workers: number of workers
            loglevel: optional logging level
            logfile: optional logging file

        Raises:
            FileNotFoundError: if the specified slurm config file does not exist
            RunnerError: if the dask backend is neither "local" nor "slurm"
        """
        dask_cfg = runner_cfg.dask_config
        self._n_workers: int = n_workers
        self._sync_worker = worker_fn
        self._tlist: list[Any] = []

        backend = dask_cfg.backend
        if backend == "local":
            self.client = Client(threads_per_worker=1, n_workers=self._n_workers)
            self.cluster = None
        elif backend == "slurm":
            self.slurm_config_file = dask_cfg.slurm_config_file
            if self.slurm_config_file:
                if not Path(self.slurm_config_file).exists():
                    err_msg = f"Specified slurm_config_file do not exists: {self.slurm_config_file}"
                    # No active exception here: error() avoids logging an empty traceback
                    _logger.error(err_msg)
                    raise FileNotFoundError(err_msg)

                config_file = ntpath.basename(self.slurm_config_file)
                # "~" is not expanded by shutil: resolve the user config
                # folder explicitly and make sure it exists.
                config_dir = Path("~/.config/dask").expanduser()
                config_dir.mkdir(parents=True, exist_ok=True)
                # NOTE(review): the user file is moved (not copied), so it no
                # longer exists at its original location afterwards - confirm intent.
                shutil.move(
                    self.slurm_config_file,
                    str(config_dir / config_file),
                )
                # Dask reads its yaml files when imported: re-read them so that
                # the freshly installed file is seen by SLURMCluster().
                dask.config.refresh()

                self.cluster = SLURMCluster()
            else:
                ntasks_per_node = dask_cfg.ntasks_per_node if dask_cfg.ntasks_per_node > 0 else dask_cfg.ntasks_per_job
                # NOTE(review): memory and interface are hard-coded here
                self.cluster = SLURMCluster(
                    queue=dask_cfg.queue,
                    cores=dask_cfg.ncores_per_worker,
                    memory="144GB",
                    walltime=dask_cfg.worker_walltime,
                    processes=1,
                    interface="ib0",
                    job_script_prologue=dask_cfg.job_prologue,
                    job_extra_directives=[
                        f"--ntasks={dask_cfg.ntasks_per_job}",
                        f"--tasks-per-node={ntasks_per_node}",
                        "--exclusive",
                    ],
                    job_directives_skip=["--cpus-per-task=", "--mem"],
                )
            self.cluster.scale(jobs=self._n_workers)
            self.client = Client(self.cluster)
        else:
            err_msg = f"Unknown [dask] backend: {backend}"
            _logger.error(err_msg)
            raise RunnerError(err_msg)

        # Setup the worker logging
        self.client.register_plugin(WorkerLoggerPlugin(loglevel, logfile))

    def __enter__(self) -> Self:
        """To enable use of with."""
        return self

    def __exit__(self, *args: object) -> None:
        """Close the client and the cluster (if any) when leaving the with scope."""
        # Disconnect the client first so that it does not lose its
        # scheduler connection while the cluster is being shut down.
        self.client.close()
        if self.cluster:
            self.cluster.close()

    def make_promise(self, task: list[Any]) -> None:
        """Append a task to the internal task list.

        Args:
            task: the list of arguments handed to the worker function
        """
        self._tlist.append(dask.delayed(self._sync_worker)(*task))

    def just_delay(self, obj: Any) -> Any:
        """Delay an object.

        Args:
            obj: the object to wrap with dask.delayed

        Returns:
            The delayed object.
        """
        return dask.delayed(obj)

    def execute_promises(self) -> Any:
        """Execute the internal list of promises.

        The internal task list is cleared only when the computation succeeds.

        Returns:
            A list with the return argument of each promised task.

        Raises:
            Exception if compute fails (logged and re-raised)
        """
        try:
            res = list(dask.compute(*self._tlist))
        except Exception:
            err_msg = "Error in DaskRunner while executing promises."
            _logger.exception(err_msg)
            raise
        else:
            self._tlist.clear()
            return res

    def n_workers(self) -> int:
        """Return the number of workers in the runner."""
        return self._n_workers
[docs]
def make_runner(
    runner_cfg: RunnerConfig,
    worker_fn: Callable,
    loglevel: str = "INFO",
    logfile: str | None = None,
    max_workers: int = -1,
) -> BaseRunner:
    """Factory that instantiates a configured runner.

    Args:
        runner_cfg: a config mapping for the runner
        worker_fn: a worker function
        loglevel: logging level
        logfile: logging file
        max_workers: maximum number of workers, ignored when not strictly positive

    Returns:
        A runner of the type requested in ``runner_cfg.type`` (case-insensitive).

    Raises:
        ValueError: if the runner type is not one of the known runners
    """
    runner_map: dict[str, type[BaseRunner]] = {
        "asyncio": AsIORunner,
        "dask": DaskRunner,
    }

    runner_type = runner_cfg.type.lower()
    runner_cls = runner_map.get(runner_type)
    if runner_cls is None:
        err_msg = f"Unknown runner type: {runner_type}"
        # No active exception here: error() avoids logging an empty traceback
        _logger.error(err_msg)
        raise ValueError(err_msg)

    # Cap the configured number of workers only if a positive limit is given
    n_workers = min(runner_cfg.nworkers, max_workers) if max_workers > 0 else runner_cfg.nworkers

    return runner_cls(
        runner_cfg=runner_cfg,
        worker_fn=worker_fn,
        n_workers=n_workers,
        loglevel=loglevel,
        logfile=logfile,
    )