Source code for paidiverpy.utils.parallellisation
"""Module for parallelisation utilities."""
import logging
import multiprocessing
import os
import dask
import dask.config
from dask.distributed import Client
from dask.distributed import LocalCluster
from paidiverpy.models.general_config import GeneralConfig
# Logger shared under the package name "paidiverpy" (not __name__); the
# helpers below use it to report configuration and cluster setup.
logger = logging.getLogger("paidiverpy")
[docs]
def get_n_jobs(n_jobs: int) -> int:
    """Determine the number of jobs based on n_jobs parameter.

    Uses SLURM_CPUS_ON_NODE when inside a Slurm allocation so that only the
    CPUs actually allocated to the job are used, not all CPUs visible on the node.
    If that variable is unset or empty, ``multiprocessing.cpu_count()`` is used.
    If it is set but is not a positive integer, a warning is logged and
    ``multiprocessing.cpu_count()`` is used as well.

    Args:
        n_jobs (int): The requested number of jobs. ``-1`` means "use all
            available CPUs"; values greater than 1 are capped at the number of
            available CPUs; any other value yields a single job.

    Returns:
        int: The number of jobs to use (at least 1).
    """
    available = multiprocessing.cpu_count()
    slurm_cpus = os.environ.get("SLURM_CPUS_ON_NODE")
    if slurm_cpus:
        try:
            allocated = int(slurm_cpus)
        except ValueError:
            allocated = 0
        if allocated > 0:
            available = allocated
        else:
            # A malformed or non-positive value would otherwise crash here or
            # yield a zero/negative job count downstream.
            logger.warning(
                "Ignoring invalid SLURM_CPUS_ON_NODE=%r; using %d CPUs",
                slurm_cpus,
                available,
            )
    if n_jobs == -1:
        return available
    if n_jobs > 1:
        return min(n_jobs, available)
    return 1
[docs]
def update_dask_config(dask_config_kwargs: dict) -> None:
"""Update the Dask configuration.
Args:
dask_config_kwargs (dict): Dask configuration keyword arguments.
"""
if dask_config_kwargs is not None:
dask.config.set(dask_config_kwargs)
logger.info("Updated dask configuration settings")
[docs]
def parse_parallellisation_params(config: GeneralConfig | None) -> Client | None:
    """Parse the parallelisation configuration and optionally start a local Dask cluster.

    Any ``dask_config_kwargs`` on the config are applied first via
    ``update_dask_config``. A ``LocalCluster`` is created only when
    ``config.local_cluster`` is not None. Its value is forwarded as keyword
    arguments to ``LocalCluster``, and the cluster is then scaled to the
    resolved ``n_jobs``.

    Args:
        config (GeneralConfig | None): General configuration; ``None`` means
            no Dask settings are applied and no cluster is created.

    Returns:
        dask.distributed.Client | None: Client connected to the new
        ``LocalCluster``, or None if no cluster is configured.

    Raises:
        Exception: Any error raised by ``LocalCluster``, ``scale`` or
            ``Client`` propagates. If scaling or connecting fails, the
            cluster is closed first.
    """
    if config is None:
        return None
    update_dask_config(config.dask_config_kwargs)
    cluster_kwargs = config.local_cluster
    if cluster_kwargs is None:
        return None

    n_jobs = config.n_jobs
    if n_jobs == -1:
        # -1 means "all available CPUs" (see get_n_jobs); resolve it instead
        # of passing a negative worker count to ``scale``.
        n_jobs = get_n_jobs(-1)
    elif n_jobs is None or n_jobs < 1:
        # Never scale to zero/negative workers, which would leave the client
        # with nothing to run tasks on.
        n_jobs = 1

    cluster = LocalCluster(**cluster_kwargs)
    try:
        cluster.scale(n_jobs)
        client = Client(cluster)
    except Exception:
        # Don't leak the cluster (and its workers) if setup fails midway.
        cluster.close()
        raise
    logger.info("Created LocalCluster with Client: %s", client.dashboard_link)
    return client