diff --git a/joblib/parallel.py b/joblib/parallel.py
index 77395c4..b95dbb2 100644
--- a/joblib/parallel.py
+++ b/joblib/parallel.py
@@ -38,7 +38,12 @@ _backend = threading.local()
def _register_dask():
"""Register Dask Backend if called with parallel_config(backend="dask")"""
- pass
+    try:
+        # Importing Client verifies that distributed is installed before
+        # the backend is registered.
+        from dask.distributed import Client  # noqa: F401
+        from ._dask import DaskDistributedBackend
+        register_parallel_backend('dask', DaskDistributedBackend)
+    except ImportError as e:
+        raise ImportError(
+            "To use the dask backend, you need to install the 'dask' and "
+            "'distributed' libraries."
+        ) from e
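+# Illustrative usage (assumes dask and distributed are installed): looking
+# up "dask" in EXTERNAL_BACKENDS below triggers _register_dask.
+#
+#     with parallel_config(backend="dask"):
+#         Parallel()(delayed(f)(i) for i in range(10))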
EXTERNAL_BACKENDS = {'dask': _register_dask}
default_parallel_config = {'backend': _Sentinel(default_value=None), 'n_jobs': _Sentinel(default_value=None), 'verbose': _Sentinel(default_value=0), 'temp_folder': _Sentinel(default_value=None), 'max_nbytes': _Sentinel(default_value='1M'), 'mmap_mode': _Sentinel(default_value='r'), 'prefer': _Sentinel(default_value=None), 'require': _Sentinel(default_value=None)}
VALID_BACKEND_HINTS = ('processes', 'threads', None)
@@ -50,15 +55,46 @@ def _get_config_param(param, context_config, key):
Explicitly setting it in Parallel has priority over setting in a
parallel_(config/backend) context manager.
"""
- pass
+    # `default_parallel_config` stores `_Sentinel` markers. Comparing the
+    # parameter with a fresh `_Sentinel()` instance would always be True,
+    # so compare it against the stored sentinel instead.
+    if param is not default_parallel_config[key]:
+        # The parameter was set explicitly on Parallel itself.
+        return param
+    elif context_config is not None and key in context_config:
+        # Fall back to the parallel_(config/backend) context manager.
+        return context_config[key]
+    else:
+        return default_parallel_config[key].default_value
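+# Example of the precedence (hypothetical values): inside
+# parallel_config(verbose=5), an explicit Parallel(verbose=10) resolves to
+# 10, a bare Parallel() resolves to 5, and outside any context the default
+# of 0 applies.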
def get_active_backend(prefer=default_parallel_config['prefer'], require=default_parallel_config['require'], verbose=default_parallel_config['verbose']):
"""Return the active default backend"""
- pass
+    return _get_active_backend(prefer, require, verbose)[0]
def _get_active_backend(prefer=default_parallel_config['prefer'], require=default_parallel_config['require'], verbose=default_parallel_config['verbose']):
"""Return the active default backend"""
- pass
+    config = getattr(_backend, 'config', default_parallel_config)
+    backend = config.get('backend', None)
+
+    if backend is None:
+        if require is not None:
+            if require not in VALID_BACKEND_CONSTRAINTS:
+                raise ValueError(f"Invalid 'require' parameter: {require}")
+            # 'sharedmem' mandates a thread-based backend; any other
+            # constraint falls back to loky (processes) when available.
+            if require == 'sharedmem':
+                backend = ThreadingBackend()
+            else:
+                backend = LokyBackend() if 'loky' in BACKENDS else ThreadingBackend()
+        elif prefer is not None:
+            if prefer not in VALID_BACKEND_HINTS:
+                raise ValueError(f"Invalid 'prefer' parameter: {prefer}")
+            # Hints are soft: 'threads' selects the threading backend,
+            # 'processes' selects loky when it is installed. After the
+            # validation above these are the only two possible values.
+            if prefer == 'threads':
+                backend = ThreadingBackend()
+            else:
+                backend = LokyBackend() if 'loky' in BACKENDS else ThreadingBackend()
+        else:
+            backend = BACKENDS[DEFAULT_BACKEND]()
+
+    if verbose > 0:
+        print(f"Using {backend.__class__.__name__} as parallel backend")
+
+    return backend, config
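+# For instance, _get_active_backend(require='sharedmem') always yields a
+# ThreadingBackend, while prefer='processes' yields a LokyBackend whenever
+# loky is importable; with no hint or constraint, BACKENDS[DEFAULT_BACKEND]
+# is used.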
class parallel_config:
"""Set the default backend or configuration for :class:`~joblib.Parallel`.
@@ -351,7 +387,7 @@ def cpu_count(only_physical_cores=False):
If only_physical_cores is True, do not take hyperthreading / SMT logical
cores into account.
"""
- pass
+    return loky.cpu_count(only_physical_cores=only_physical_cores)
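+# e.g. on a machine with 4 physical cores and SMT enabled, cpu_count()
+# typically reports 8 while cpu_count(only_physical_cores=True) reports 4
+# (illustrative figures).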
def _verbosity_filter(index, verbose):
""" Returns False for indices increasingly apart, the distance
@@ -359,11 +395,16 @@ def _verbosity_filter(index, verbose):
We use a lag increasing as the square of index
"""
- pass
+    if not verbose:
+        return True
+    if verbose > 10 or index == 0:
+        return False
+    # The gap between printed indices grows like sqrt(index) and shrinks as
+    # verbosity increases; clamp the step to 1 to avoid a modulo-by-zero on
+    # small indices. Returning False means "do not filter", i.e. print.
+    step = max(int(sqrt(index) * 10 / verbose), 1)
+    return index % step != 0
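+# e.g. with verbose=1 the step at index 100 is int(10 * 10 / 1) == 100, so
+# index 100 prints while nearby indices are filtered out.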
def delayed(function):
"""Decorator used to capture the arguments of a function."""
- pass
+    @functools.wraps(function)
+    def delayed_function(*args, **kwargs):
+        # Freeze the call into a zero-argument closure that Parallel can
+        # execute later.
+        return lambda: function(*args, **kwargs)
+    return delayed_function
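+# Example: delayed(sqrt)(9) captures the call without running it; invoking
+# the returned closure, delayed(sqrt)(9)(), executes sqrt(9) and gives 3.0.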
class BatchCompletionCallBack(object):
"""Callback to keep track of completed results and schedule the next tasks.
@@ -393,7 +434,7 @@ class BatchCompletionCallBack(object):
def register_job(self, job):
"""Register the object returned by `apply_async`."""
- pass
+        self.job = job
def get_result(self, timeout):
"""Returns the raw result of the task that was submitted.
@@ -412,7 +453,13 @@ class BatchCompletionCallBack(object):
return it or raise. It will block at most `self.timeout` seconds
waiting for retrieval to complete, after that it raises a TimeoutError.
"""
- pass
+        if self.parallel._backend.supports_retrieve_callback:
+            if self.get_status(timeout) == TASK_PENDING:
+                raise TimeoutError("Timeout reached while waiting for result")
+            if self.status == TASK_ERROR:
+                # The stored result is the exception raised by the task.
+                raise self.result
+            return self.result
+        else:
+            # Delegate to the backend's future-like job object.
+            return self.job.get(timeout=timeout)
def get_status(self, timeout):
"""Get the status of the task.
@@ -420,7 +467,14 @@ class BatchCompletionCallBack(object):
This function also checks if the timeout has been reached and register
the TimeoutError outcome when it is the case.
"""
- pass
+        if self.parallel._backend.supports_retrieve_callback:
+            return self.status
+        # multiprocessing-style AsyncResult.wait() returns silently on
+        # timeout instead of raising, so probe readiness explicitly.
+        self.job.wait(timeout=timeout)
+        return TASK_DONE if self.job.ready() else TASK_PENDING
def __call__(self, out):
"""Function called by the callback thread after a job is completed."""
@@ -440,7 +494,12 @@ class BatchCompletionCallBack(object):
def _dispatch_new(self):
"""Schedule the next batch of tasks to be processed."""
- pass
+        with self.parallel._lock:
+            # Nothing to schedule if the run is aborting or the input
+            # iterator has already been fully consumed. The lock must be
+            # re-entrant, since dispatch_one_batch acquires it again.
+            if self.parallel._aborting:
+                return
+            if self.parallel._original_iterator is None:
+                return
+            self.parallel.dispatch_next()
def _retrieve_result(self, out):
"""Fetch and register the outcome of a task.
@@ -449,14 +508,27 @@ class BatchCompletionCallBack(object):
This function is only called by backends that support retrieving
the task result in the callback thread.
"""
- pass
+        # `out` is a (success, result) pair produced by the worker.
+        success, result = out
+        self.result = result
+        if success:
+            self.status = TASK_DONE
+            return True
+        else:
+            self.status = TASK_ERROR
+            return False
def _register_outcome(self, outcome):
"""Register the outcome of a task.
This method can be called only once, future calls will be ignored.
"""
- pass
+        if self.status is not TASK_PENDING:
+            return
+        self.result = outcome
+        if isinstance(outcome, BaseException):
+            self.status = TASK_ERROR
+        else:
+            self.status = TASK_DONE
def register_parallel_backend(name, factory, make_default=False):
"""Register a new Parallel backend factory.
@@ -473,7 +545,10 @@ def register_parallel_backend(name, factory, make_default=False):
.. versionadded:: 0.10
"""
- pass
+    BACKENDS[name] = factory
+    if make_default:
+        global DEFAULT_BACKEND
+        DEFAULT_BACKEND = name
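+# Illustrative registration of a custom backend, where MyCustomBackend is a
+# hypothetical ParallelBackendBase subclass:
+#
+#     register_parallel_backend('custom', MyCustomBackend)
+#     with parallel_config(backend='custom'):
+#         ...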
def effective_n_jobs(n_jobs=-1):
"""Determine the number of jobs that can actually run in parallel
@@ -496,7 +571,11 @@ def effective_n_jobs(n_jobs=-1):
.. versionadded:: 0.10
"""
- pass
+    if n_jobs == 0:
+        raise ValueError('n_jobs == 0 in Parallel has no meaning')
+    elif n_jobs < 0:
+        # Negative values count back from the total: -1 means all CPUs,
+        # -2 all but one, and so on (never fewer than one job).
+        n_jobs = max(cpu_count() + 1 + n_jobs, 1)
+    return n_jobs
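+# e.g. with cpu_count() == 8: effective_n_jobs(-1) -> 8,
+# effective_n_jobs(-2) -> 7, effective_n_jobs(4) -> 4.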
class Parallel(Logger):
""" Helper class for readable parallel mapping.
@@ -832,16 +911,33 @@ class Parallel(Logger):
def _initialize_backend(self):
"""Build a process or thread pool and return the number of workers"""
- pass
+        if self._backend is None:
+            self._backend, _ = _get_active_backend(
+                prefer=self._backend_args['prefer'],
+                require=self._backend_args['require'],
+                verbose=self._backend_args['verbose']
+            )
+
+        # Reuse effective_n_jobs instead of duplicating the negative-n_jobs
+        # arithmetic with mp.cpu_count here.
+        self._n_jobs = effective_n_jobs(self.n_jobs)
+        self._backend.configure(n_jobs=self._n_jobs, parallel=self,
+                                **self._backend_args)
+        return self._n_jobs
def _dispatch(self, batch):
"""Queue the batch for computing, with or without multiprocessing
WARNING: this method is not thread-safe: it should be only called
indirectly via dispatch_one_batch.
-
"""
- pass
+        # Track dispatch counts so that _get_outputs and print_progress can
+        # compare completed against dispatched work. One task per batch in
+        # this sketch.
+        self._n_dispatched_tasks += 1
+        self._n_dispatched_batches += 1
+        if self._backend.supports_retrieve_callback:
+            self._backend.apply_async(batch, callback=self._callback)
+        else:
+            self._pending_outputs.append(self._backend.apply_async(batch))
def dispatch_next(self):
"""Dispatch more data for parallel processing
@@ -849,9 +945,12 @@ class Parallel(Logger):
This method is meant to be called concurrently by the multiprocessing
callback. We rely on the thread-safety of dispatch_one_batch to protect
against concurrent consumption of the unprotected iterator.
-
"""
- pass
+        if self.dispatch_one_batch(self._original_iterator):
+            return True
+        # The input iterator is exhausted: stop the dispatch loop.
+        self._iterating = False
+        return False
def dispatch_one_batch(self, iterator):
"""Prefetch the tasks for the next batch and dispatch them.
@@ -863,41 +962,69 @@ class Parallel(Logger):
lock so calling this function should be thread safe.
"""
-        pass
+        # Minimal sketch: dispatch a single task per call. A fuller
+        # implementation would compute an adaptive batch size and wrap a
+        # slice of tasks in one callable; tasks here are the zero-argument
+        # closures produced by `delayed`.
+        with self._lock:
+            try:
+                task = next(iterator)
+            except StopIteration:
+                return False
+            self._dispatch(task)
+            return True
-
- def _get_batch_size(self):
- """Returns the effective batch size for dispatch"""
- pass
def _print(self, msg):
- """Display the message on stout or stderr depending on verbosity"""
- pass
+ """Display the message on stdout or stderr depending on verbosity"""
+ if self.verbose > 10:
+ print(msg, file=sys.stderr)
+ elif self.verbose:
+ print(msg)
def _is_completed(self):
"""Check if all tasks have been completed"""
- pass
+        return self._n_completed_tasks == self._n_dispatched_tasks
def print_progress(self):
- """Display the process of the parallel execution only a fraction
+ """Display the progress of the parallel execution only a fraction
of time, controlled by self.verbose.
"""
- pass
+        if not self.verbose:
+            return
+        elapsed_time = time.time() - self._start_time
+
+        # Throttle the output: _verbosity_filter returns True for indices
+        # that should be skipped, spacing messages further apart as the run
+        # progresses.
+        if _verbosity_filter(self._n_completed_tasks, self.verbose):
+            return
+        self._print('Done %3i out of %3i | elapsed: %s' %
+                    (self._n_completed_tasks,
+                     self._n_dispatched_tasks,
+                     short_format_time(elapsed_time)))
def _get_outputs(self, iterator, pre_dispatch):
"""Iterator returning the tasks' output as soon as they are ready."""
- pass
+        while self._iterating or self._n_completed_tasks < self._n_dispatched_tasks:
+            # Fail fast: surface the first task error as soon as an abort
+            # is flagged instead of draining the whole queue first.
+            if self._aborting:
+                self._raise_error_fast()
+            if self._n_completed_tasks < self._n_dispatched_tasks:
+                yield self._backend.retrieve()
+                self._n_completed_tasks += 1
+                self.print_progress()
+            elif self._iterating:
+                self.dispatch_next()
+            else:
+                break
def _wait_retrieval(self):
"""Return True if we need to continue retrieving some tasks."""
- pass
+        # Keep retrieving while dispatch is ongoing, tasks are in flight,
+        # or a non-callback backend still holds unconsumed job handles.
+        return (self._iterating or
+                self._n_completed_tasks < self._n_dispatched_tasks or
+                (not self._backend.supports_retrieve_callback and
+                 len(self._pending_outputs) > 0))
def _raise_error_fast(self):
"""If we are aborting, raise if a job caused an error."""
- pass
+        if self._aborting:
+            error = self._backend.get_first_exception()
+            if error is not None:
+                raise error
def _warn_exit_early(self):
- """Warn the user if the generator is gc'ed before being consumned."""
- pass
+ """Warn the user if the generator is gc'ed before being consumed."""
+ warnings.warn("Parallel generator was garbage collected before being fully consumed. "
+ "Some tasks may not have been executed.", UserWarning)
def _get_sequential_output(self, iterable):
"""Separate loop for sequential output.
@@ -905,11 +1032,20 @@ class Parallel(Logger):
This simplifies the traceback in case of errors and reduces the
overhead of calling sequential tasks with `joblib`.
"""
- pass
+        self._start_time = time.time()
+        for func in iterable:
+            # Run the task inline, then report progress; the counter is
+            # updated first so the message reflects completed work.
+            yield func()
+            self._n_completed_tasks += 1
+            self._print('Done %3i tasks | elapsed: %s' %
+                        (self._n_completed_tasks,
+                         short_format_time(time.time() - self._start_time)))
def _reset_run_tracking(self):
"""Reset the counters and flags used to track the execution."""
- pass
+        self._n_completed_tasks = 0
+        self._n_dispatched_tasks = 0
+        self._n_dispatched_batches = 0
+        # _dispatch appends to this list for backends without a retrieve
+        # callback, so it must be reset between runs as well.
+        self._pending_outputs = []
+        self._aborting = False
+        self._iterating = True
def __call__(self, iterable):
"""Main function to dispatch parallel tasks."""
@@ -957,4 +1093,4 @@ class Parallel(Logger):
return output if self.return_generator else list(output)
def __repr__(self):
- return '%s(n_jobs=%s)' % (self.__class__.__name__, self.n_jobs)
\ No newline at end of file
+        return '%s(n_jobs=%s)' % (self.__class__.__name__, self.n_jobs)