back to Reference (Gold) summary
Reference (Gold): joblib
Pytest summary for the test run
status | count |
---|---|
passed | 1441 |
skipped | 35 |
xpassed | 4 |
failed | 2 |
total | 1482 |
collected | 1482 |
Failed pytests:
test_memmapping.py::test_child_raises_parent_exits_cleanly[multiprocessing]
test_memmapping.py::test_child_raises_parent_exits_cleanly[multiprocessing]
test_parallel.py::test_main_thread_renamed_no_warning[multiprocessing]
test_parallel.py::test_main_thread_renamed_no_warning[multiprocessing]
backend = 'multiprocessing' monkeypatch = <_pytest.monkeypatch.MonkeyPatch object at 0x7fc0ecd00fe0> @parametrize('backend', ALL_VALID_BACKENDS) def test_main_thread_renamed_no_warning(backend, monkeypatch): # Check that no default backend relies on the name of the main thread: # https://github.com/joblib/joblib/issues/180#issuecomment-253266247 # Some programs use a different name for the main thread. This is the case # for uWSGI apps for instance. monkeypatch.setattr(target=threading.current_thread(), name='name', value='some_new_name_for_the_main_thread') with warnings.catch_warnings(record=True) as warninfo: results = Parallel(n_jobs=2, backend=backend)( delayed(square)(x) for x in range(3)) assert results == [0, 1, 4] # Due to the default parameters of LokyBackend, there is a chance that # warninfo catches Warnings from worker timeouts. We remove it if it exists warninfo = [w for w in warninfo if "worker timeout" not in str(w.message)] # The multiprocessing backend will raise a warning when detecting that is # started from the non-main thread. Let's check that there is no false # positive because of the name change. > assert len(warninfo) == 0 E assert 2 == 0 E + where 2 = len([, ]) joblib/test/test_parallel.py:199: AssertionError
test_parallel.py::test_main_thread_renamed_no_warning[backend7]
test_parallel.py::test_main_thread_renamed_no_warning[backend7]
backend =monkeypatch = <_pytest.monkeypatch.MonkeyPatch object at 0x7fc0ecd57d40> @parametrize('backend', ALL_VALID_BACKENDS) def test_main_thread_renamed_no_warning(backend, monkeypatch): # Check that no default backend relies on the name of the main thread: # https://github.com/joblib/joblib/issues/180#issuecomment-253266247 # Some programs use a different name for the main thread. This is the case # for uWSGI apps for instance. monkeypatch.setattr(target=threading.current_thread(), name='name', value='some_new_name_for_the_main_thread') with warnings.catch_warnings(record=True) as warninfo: results = Parallel(n_jobs=2, backend=backend)( delayed(square)(x) for x in range(3)) assert results == [0, 1, 4] # Due to the default parameters of LokyBackend, there is a chance that # warninfo catches Warnings from worker timeouts. We remove it if it exists warninfo = [w for w in warninfo if "worker timeout" not in str(w.message)] # The multiprocessing backend will raise a warning when detecting that is # started from the non-main thread. Let's check that there is no false # positive because of the name change. > assert len(warninfo) == 0 E assert 2 == 0 E + where 2 = len([ , ]) joblib/test/test_parallel.py:199: AssertionError
test_parallel.py::test_nested_exception_dispatch[multiprocessing]
test_parallel.py::test_nested_exception_dispatch[multiprocessing]
test_parallel.py::test_nested_exception_dispatch[loky]
test_parallel.py::test_nested_exception_dispatch[loky]
test_parallel.py::test_nested_exception_dispatch[threading]
test_parallel.py::test_nested_exception_dispatch[threading]
Patch diff
diff --git a/joblib/_cloudpickle_wrapper.py b/joblib/_cloudpickle_wrapper.py
index 78a1b36..daf899d 100644
--- a/joblib/_cloudpickle_wrapper.py
+++ b/joblib/_cloudpickle_wrapper.py
@@ -2,9 +2,18 @@
Small shim of loky's cloudpickle_wrapper to avoid failure when
multiprocessing is not available.
"""
+
+
from ._multiprocessing_helpers import mp
+
+
+def _my_wrap_non_picklable_objects(obj, keep_wrapper=True):
+ return obj
+
+
if mp is not None:
from .externals.loky import wrap_non_picklable_objects
else:
wrap_non_picklable_objects = _my_wrap_non_picklable_objects
-__all__ = ['wrap_non_picklable_objects']
+
+__all__ = ["wrap_non_picklable_objects"]
diff --git a/joblib/_dask.py b/joblib/_dask.py
index 726f453..4288ed0 100644
--- a/joblib/_dask.py
+++ b/joblib/_dask.py
@@ -1,30 +1,56 @@
from __future__ import print_function, division, absolute_import
+
import asyncio
import concurrent.futures
import contextlib
+
import time
from uuid import uuid4
import weakref
+
from .parallel import parallel_config
from .parallel import AutoBatchingMixin, ParallelBackendBase
-from ._utils import _TracebackCapturingWrapper, _retrieve_traceback_capturing_wrapped_call
+
+from ._utils import (
+ _TracebackCapturingWrapper,
+ _retrieve_traceback_capturing_wrapped_call
+)
+
try:
import dask
import distributed
except ImportError:
dask = None
distributed = None
+
if dask is not None and distributed is not None:
from dask.utils import funcname
from dask.sizeof import sizeof
- from dask.distributed import Client, as_completed, get_client, secede, rejoin
+ from dask.distributed import (
+ Client,
+ as_completed,
+ get_client,
+ secede,
+ rejoin,
+ )
from distributed.utils import thread_state
+
try:
+ # asyncio.TimeoutError, Python3-only error thrown by recent versions of
+ # distributed
from distributed.utils import TimeoutError as _TimeoutError
except ImportError:
from tornado.gen import TimeoutError as _TimeoutError
+def is_weakrefable(obj):
+ try:
+ weakref.ref(obj)
+ return True
+ except TypeError:
+ return False
+
+
class _WeakKeyDictionary:
"""A variant of weakref.WeakKeyDictionary for unhashable objects.
@@ -42,6 +68,7 @@ class _WeakKeyDictionary:
def __getitem__(self, obj):
ref, val = self._data[id(obj)]
if ref() is not obj:
+ # In case of a race condition with on_destroy.
raise KeyError(obj)
return val
@@ -50,9 +77,12 @@ class _WeakKeyDictionary:
try:
ref, _ = self._data[key]
if ref() is not obj:
+ # In case of race condition with on_destroy.
raise KeyError(obj)
except KeyError:
-
+ # Insert the new entry in the mapping along with a weakref
+ # callback to automatically delete the entry from the mapping
+ # as soon as the object used as key is garbage collected.
def on_destroy(_):
del self._data[key]
ref = weakref.ref(obj, on_destroy)
@@ -61,18 +91,38 @@ class _WeakKeyDictionary:
def __len__(self):
return len(self._data)
+ def clear(self):
+ self._data.clear()
+
+
+def _funcname(x):
+ try:
+ if isinstance(x, list):
+ x = x[0][0]
+ except Exception:
+ pass
+ return funcname(x)
+
def _make_tasks_summary(tasks):
"""Summarize of list of (func, args, kwargs) function calls"""
- pass
+ unique_funcs = {func for func, args, kwargs in tasks}
+
+ if len(unique_funcs) == 1:
+ mixed = False
+ else:
+ mixed = True
+ return len(tasks), mixed, _funcname(tasks)
class Batch:
"""dask-compatible wrapper that executes a batch of tasks"""
-
def __init__(self, tasks):
+ # collect some metadata from the tasks to ease Batch calls
+ # introspection when debugging
self._num_tasks, self._mixed, self._funcname = _make_tasks_summary(
- tasks)
+ tasks
+ )
def __call__(self, tasks=None):
results = []
@@ -82,46 +132,58 @@ class Batch:
return results
def __repr__(self):
- descr = f'batch_of_{self._funcname}_{self._num_tasks}_calls'
+ descr = f"batch_of_{self._funcname}_{self._num_tasks}_calls"
if self._mixed:
- descr = 'mixed_' + descr
+ descr = "mixed_" + descr
return descr
+def _joblib_probe_task():
+ # Noop used by the joblib connector to probe when workers are ready.
+ pass
+
+
class DaskDistributedBackend(AutoBatchingMixin, ParallelBackendBase):
MIN_IDEAL_BATCH_DURATION = 0.2
MAX_IDEAL_BATCH_DURATION = 1.0
supports_retrieve_callback = True
default_n_jobs = -1
- def __init__(self, scheduler_host=None, scatter=None, client=None, loop
- =None, wait_for_workers_timeout=10, **submit_kwargs):
+ def __init__(self, scheduler_host=None, scatter=None,
+ client=None, loop=None, wait_for_workers_timeout=10,
+ **submit_kwargs):
super().__init__()
+
if distributed is None:
- msg = (
- "You are trying to use 'dask' as a joblib parallel backend but dask is not installed. Please install dask to fix this error."
- )
+ msg = ("You are trying to use 'dask' as a joblib parallel backend "
+ "but dask is not installed. Please install dask "
+ "to fix this error.")
raise ValueError(msg)
+
if client is None:
if scheduler_host:
- client = Client(scheduler_host, loop=loop, set_as_default=False
- )
+ client = Client(scheduler_host, loop=loop,
+ set_as_default=False)
else:
try:
client = get_client()
except ValueError as e:
- msg = """To use Joblib with Dask first create a Dask Client
-
- from dask.distributed import Client
- client = Client()
-or
- client = Client('scheduler-address:8786')"""
+ msg = ("To use Joblib with Dask first create a Dask Client"
+ "\n\n"
+ " from dask.distributed import Client\n"
+ " client = Client()\n"
+ "or\n"
+ " client = Client('scheduler-address:8786')")
raise ValueError(msg) from e
+
self.client = client
+
if scatter is not None and not isinstance(scatter, (list, tuple)):
- raise TypeError('scatter must be a list/tuple, got `%s`' % type
- (scatter).__name__)
+ raise TypeError("scatter must be a list/tuple, got "
+ "`%s`" % type(scatter).__name__)
+
if scatter is not None and len(scatter) > 0:
+ # Keep a reference to the scattered data to keep the ids the same
self._scatter = list(scatter)
scattered = self.client.scatter(scatter, broadcast=True)
self.data_futures = {id(x): f for x, f in zip(scatter, scattered)}
@@ -130,20 +192,173 @@ or
self.data_futures = {}
self.wait_for_workers_timeout = wait_for_workers_timeout
self.submit_kwargs = submit_kwargs
- self.waiting_futures = as_completed([], loop=client.loop,
- with_results=True, raise_errors=False)
+ self.waiting_futures = as_completed(
+ [],
+ loop=client.loop,
+ with_results=True,
+ raise_errors=False
+ )
self._results = {}
self._callbacks = {}
+ async def _collect(self):
+ while self._continue:
+ async for future, result in self.waiting_futures:
+ cf_future = self._results.pop(future)
+ callback = self._callbacks.pop(future)
+ if future.status == "error":
+ typ, exc, tb = result
+ cf_future.set_exception(exc)
+ else:
+ cf_future.set_result(result)
+ callback(result)
+ await asyncio.sleep(0.01)
+
def __reduce__(self):
- return DaskDistributedBackend, ()
+ return (DaskDistributedBackend, ())
+
+ def get_nested_backend(self):
+ return DaskDistributedBackend(client=self.client), -1
+
+ def configure(self, n_jobs=1, parallel=None, **backend_args):
+ self.parallel = parallel
+ return self.effective_n_jobs(n_jobs)
+
+ def start_call(self):
+ self._continue = True
+ self.client.loop.add_callback(self._collect)
+ self.call_data_futures = _WeakKeyDictionary()
+
+ def stop_call(self):
+ # The explicit call to clear is required to break a cycling reference
+ # to the futures.
+ self._continue = False
+ # wait for the future collection routine (self._backend._collect) to
+ # finish in order to limit asyncio warnings due to aborting _collect
+ # during a following backend termination call
+ time.sleep(0.01)
+ self.call_data_futures.clear()
+
+ def effective_n_jobs(self, n_jobs):
+ effective_n_jobs = sum(self.client.ncores().values())
+ if effective_n_jobs != 0 or not self.wait_for_workers_timeout:
+ return effective_n_jobs
+
+ # If there is no worker, schedule a probe task to wait for the workers
+ # to come up and be available. If the dask cluster is in adaptive mode
+ # task might cause the cluster to provision some workers.
+ try:
+ self.client.submit(_joblib_probe_task).result(
+ timeout=self.wait_for_workers_timeout
+ )
+ except _TimeoutError as e:
+ error_msg = (
+ "DaskDistributedBackend has no worker after {} seconds. "
+ "Make sure that workers are started and can properly connect "
+ "to the scheduler and increase the joblib/dask connection "
+ "timeout with:\n\n"
+ "parallel_config(backend='dask', wait_for_workers_timeout={})"
+ ).format(self.wait_for_workers_timeout,
+ max(10, 2 * self.wait_for_workers_timeout))
+ raise TimeoutError(error_msg) from e
+ return sum(self.client.ncores().values())
+
+ async def _to_func_args(self, func):
+ itemgetters = dict()
+
+ # Futures that are dynamically generated during a single call to
+ # Parallel.__call__.
+ call_data_futures = getattr(self, 'call_data_futures', None)
+
+ async def maybe_to_futures(args):
+ out = []
+ for arg in args:
+ arg_id = id(arg)
+ if arg_id in itemgetters:
+ out.append(itemgetters[arg_id])
+ continue
+
+ f = self.data_futures.get(arg_id, None)
+ if f is None and call_data_futures is not None:
+ try:
+ f = await call_data_futures[arg]
+ except KeyError:
+ pass
+ if f is None:
+ if is_weakrefable(arg) and sizeof(arg) > 1e3:
+ # Automatically scatter large objects to some of
+ # the workers to avoid duplicated data transfers.
+ # Rely on automated inter-worker data stealing if
+ # more workers need to reuse this data
+ # concurrently.
+ # set hash=False - nested scatter calls (i.e
+ # calling client.scatter inside a dask worker)
+ # using hash=True often raise CancelledError,
+ # see dask/distributed#3703
+ _coro = self.client.scatter(
+ arg,
+ asynchronous=True,
+ hash=False
+ )
+ # Centralize the scattering of identical arguments
+ # between concurrent apply_async callbacks by
+ # exposing the running coroutine in
+ # call_data_futures before it completes.
+ t = asyncio.Task(_coro)
+ call_data_futures[arg] = t
+
+ f = await t
+
+ if f is not None:
+ out.append(f)
+ else:
+ out.append(arg)
+ return out
+
+ tasks = []
+ for f, args, kwargs in func.items:
+ args = list(await maybe_to_futures(args))
+ kwargs = dict(zip(kwargs.keys(),
+ await maybe_to_futures(kwargs.values())))
+ tasks.append((f, args, kwargs))
+
+ return (Batch(tasks), tasks)
+
+ def apply_async(self, func, callback=None):
+
+ cf_future = concurrent.futures.Future()
+ cf_future.get = cf_future.result # achieve AsyncResult API
+
+ async def f(func, callback):
+ batch, tasks = await self._to_func_args(func)
+ key = f'{repr(batch)}-{uuid4().hex}'
+
+ dask_future = self.client.submit(
+ _TracebackCapturingWrapper(batch),
+ tasks=tasks,
+ key=key,
+ **self.submit_kwargs
+ )
+ self.waiting_futures.add(dask_future)
+ self._callbacks[dask_future] = callback
+ self._results[dask_future] = cf_future
+
+ self.client.loop.add_callback(f, func, callback)
+
+ return cf_future
+
+ def retrieve_result_callback(self, out):
+ return _retrieve_traceback_capturing_wrapped_call(out)
def abort_everything(self, ensure_ready=True):
""" Tell the client to cancel any task submitted via this instance
joblib.Parallel will never access those results
"""
- pass
+ with self.waiting_futures.lock:
+ self.waiting_futures.futures.clear()
+ while not self.waiting_futures.queue.empty():
+ self.waiting_futures.queue.get()
@contextlib.contextmanager
def retrieval_context(self):
@@ -152,4 +367,13 @@ or
This removes thread from the worker's thread pool (using 'secede').
Seceding avoids deadlock in nested parallelism settings.
"""
- pass
+ # See 'joblib.Parallel.__call__' and 'joblib.Parallel.retrieve' for how
+ # this is used.
+ if hasattr(thread_state, 'execution_state'):
+ # we are in a worker. Secede to avoid deadlock.
+ secede()
+
+ yield
+
+ if hasattr(thread_state, 'execution_state'):
+ rejoin()
diff --git a/joblib/_memmapping_reducer.py b/joblib/_memmapping_reducer.py
index 5012683..13f5c4a 100644
--- a/joblib/_memmapping_reducer.py
+++ b/joblib/_memmapping_reducer.py
@@ -1,6 +1,10 @@
"""
Reducer using memory mapping for numpy arrays
"""
+# Author: Thomas Moreau <thomas.moreau.2010@gmail.com>
+# Copyright: 2017, Thomas Moreau
+# License: BSD 3 clause
+
from mmap import mmap
import errno
import os
@@ -13,27 +17,63 @@ import warnings
import weakref
from uuid import uuid4
from multiprocessing import util
+
from pickle import whichmodule, loads, dumps, HIGHEST_PROTOCOL, PicklingError
+
try:
WindowsError
except NameError:
WindowsError = type(None)
+
try:
import numpy as np
from numpy.lib.stride_tricks import as_strided
except ImportError:
np = None
+
from .numpy_pickle import dump, load, load_temporary_memmap
from .backports import make_memmap
from .disk import delete_folder
from .externals.loky.backend import resource_tracker
+
+# Some system have a ramdisk mounted by default, we can use it instead of /tmp
+# as the default folder to dump big arrays to share with subprocesses.
SYSTEM_SHARED_MEM_FS = '/dev/shm'
-SYSTEM_SHARED_MEM_FS_MIN_SIZE = int(2000000000.0)
+
+# Minimal number of bytes available on SYSTEM_SHARED_MEM_FS to consider using
+# it as the default folder to dump big arrays to share with subprocesses.
+SYSTEM_SHARED_MEM_FS_MIN_SIZE = int(2e9)
+
+# Folder and file permissions to chmod temporary files generated by the
+# memmapping pool. Only the owner of the Python process can access the
+# temporary files and folder.
FOLDER_PERMISSIONS = stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR
FILE_PERMISSIONS = stat.S_IRUSR | stat.S_IWUSR
+
+# Set used in joblib workers, referencing the filenames of temporary memmaps
+# created by joblib to speed up data communication. In child processes, we add
+# a finalizer to these memmaps that sends a maybe_unlink call to the
+# resource_tracker, in order to free main memory as fast as possible.
JOBLIB_MMAPS = set()
+def _log_and_unlink(filename):
+ from .externals.loky.backend.resource_tracker import _resource_tracker
+ util.debug(
+ "[FINALIZER CALL] object mapping to {} about to be deleted,"
+ " decrementing the refcount of the file (pid: {})".format(
+ os.path.basename(filename), os.getpid()))
+ _resource_tracker.maybe_unlink(filename, "file")
+
+
+def add_maybe_unlink_finalizer(memmap):
+ util.debug(
+ "[FINALIZER ADD] adding finalizer to {} (id {}, filename {}, pid {})"
+ "".format(type(memmap), id(memmap), os.path.basename(memmap.filename),
+ os.getpid()))
+ weakref.finalize(memmap, _log_and_unlink, memmap.filename)
+
+
def unlink_file(filename):
"""Wrapper around os.unlink with a retry mechanism.
@@ -45,7 +85,24 @@ def unlink_file(filename):
it takes for the last reference of the memmap to be closed, yielding (on
Windows) a PermissionError in the resource_tracker loop.
"""
- pass
+ NUM_RETRIES = 10
+ for retry_no in range(1, NUM_RETRIES + 1):
+ try:
+ os.unlink(filename)
+ break
+ except PermissionError:
+ util.debug(
+ '[ResourceTracker] tried to unlink {}, got '
+ 'PermissionError'.format(filename)
+ )
+ if retry_no == NUM_RETRIES:
+ raise
+ else:
+ time.sleep(.2)
+ except FileNotFoundError:
+ # In case of a race condition when deleting the temporary folder,
+ # avoid noisy FileNotFoundError exception in the resource tracker.
+ pass
resource_tracker._CLEANUP_FUNCS['file'] = unlink_file
@@ -62,13 +119,54 @@ class _WeakArrayKeyMap:
def __init__(self):
self._data = {}
+ def get(self, obj):
+ ref, val = self._data[id(obj)]
+ if ref() is not obj:
+ # In case of race condition with on_destroy: could never be
+ # triggered by the joblib tests with CPython.
+ raise KeyError(obj)
+ return val
+
+ def set(self, obj, value):
+ key = id(obj)
+ try:
+ ref, _ = self._data[key]
+ if ref() is not obj:
+ # In case of race condition with on_destroy: could never be
+ # triggered by the joblib tests with CPython.
+ raise KeyError(obj)
+ except KeyError:
+ # Insert the new entry in the mapping along with a weakref
+ # callback to automatically delete the entry from the mapping
+ # as soon as the object used as key is garbage collected.
+ def on_destroy(_):
+ del self._data[key]
+ ref = weakref.ref(obj, on_destroy)
+ self._data[key] = ref, value
+
def __getstate__(self):
- raise PicklingError('_WeakArrayKeyMap is not pickleable')
+ raise PicklingError("_WeakArrayKeyMap is not pickleable")
+
+
+###############################################################################
+# Support for efficient transient pickling of numpy data structures
def _get_backing_memmap(a):
"""Recursively look up the original np.memmap instance base if any."""
- pass
+ b = getattr(a, 'base', None)
+ if b is None:
+ # TODO: check scipy sparse datastructure if scipy is installed
+ # a nor its descendants do not have a memmap base
+ return None
+
+ elif isinstance(b, mmap):
+ # a is already a real memmap instance.
+ return a
+
+ else:
+ # Recursive exploration of the base ancestry
+ return _get_backing_memmap(b)
def _get_temp_dir(pool_folder_name, temp_folder=None):
@@ -101,18 +199,61 @@ def _get_temp_dir(pool_folder_name, temp_folder=None):
whether the temporary folder is written to the system shared memory
folder or some other temporary folder.
"""
- pass
+ use_shared_mem = False
+ if temp_folder is None:
+ temp_folder = os.environ.get('JOBLIB_TEMP_FOLDER', None)
+ if temp_folder is None:
+ if os.path.exists(SYSTEM_SHARED_MEM_FS) and hasattr(os, 'statvfs'):
+ try:
+ shm_stats = os.statvfs(SYSTEM_SHARED_MEM_FS)
+ available_nbytes = shm_stats.f_bsize * shm_stats.f_bavail
+ if available_nbytes > SYSTEM_SHARED_MEM_FS_MIN_SIZE:
+ # Try to see if we have write access to the shared mem
+ # folder only if it is reasonably large (that is 2GB or
+ # more).
+ temp_folder = SYSTEM_SHARED_MEM_FS
+ pool_folder = os.path.join(temp_folder, pool_folder_name)
+ if not os.path.exists(pool_folder):
+ os.makedirs(pool_folder)
+ use_shared_mem = True
+ except (IOError, OSError):
+ # Missing rights in the /dev/shm partition, fallback to regular
+ # temp folder.
+ temp_folder = None
+ if temp_folder is None:
+ # Fallback to the default tmp folder, typically /tmp
+ temp_folder = tempfile.gettempdir()
+ temp_folder = os.path.abspath(os.path.expanduser(temp_folder))
+ pool_folder = os.path.join(temp_folder, pool_folder_name)
+ return pool_folder, use_shared_mem
def has_shareable_memory(a):
"""Return True if a is backed by some mmap buffer directly or not."""
- pass
+ return _get_backing_memmap(a) is not None
-def _strided_from_memmap(filename, dtype, mode, offset, order, shape,
- strides, total_buffer_len, unlink_on_gc_collect):
+def _strided_from_memmap(filename, dtype, mode, offset, order, shape, strides,
+ total_buffer_len, unlink_on_gc_collect):
"""Reconstruct an array view on a memory mapped file."""
- pass
+ if mode == 'w+':
+ # Do not zero the original data when unpickling
+ mode = 'r+'
+
+ if strides is None:
+ # Simple, contiguous memmap
+ return make_memmap(
+ filename, dtype=dtype, shape=shape, mode=mode, offset=offset,
+ order=order, unlink_on_gc_collect=unlink_on_gc_collect
+ )
+ else:
+ # For non-contiguous data, memmap the total enclosing buffer and then
+ # extract the non-contiguous view with the stride-tricks API
+ base = make_memmap(
+ filename, dtype=dtype, shape=total_buffer_len, offset=offset,
+ mode=mode, order=order, unlink_on_gc_collect=unlink_on_gc_collect
+ )
+ return as_strided(base, shape=shape, strides=strides)
def _reduce_memmap_backed(a, m):
@@ -122,12 +263,60 @@ def _reduce_memmap_backed(a, m):
m is expected to be an instance of np.memmap on the top of the ``base``
attribute ancestry of a. ``m.base`` should be the real python mmap object.
"""
- pass
+ # offset that comes from the striding differences between a and m
+ util.debug('[MEMMAP REDUCE] reducing a memmap-backed array '
+ '(shape, {}, pid: {})'.format(a.shape, os.getpid()))
+ try:
+ from numpy.lib.array_utils import byte_bounds
+ except (ModuleNotFoundError, ImportError):
+ # Backward-compat for numpy < 2.0
+ from numpy import byte_bounds
+ a_start, a_end = byte_bounds(a)
+ m_start = byte_bounds(m)[0]
+ offset = a_start - m_start
+
+ # offset from the backing memmap
+ offset += m.offset
+
+ if m.flags['F_CONTIGUOUS']:
+ order = 'F'
+ else:
+ # The backing memmap buffer is necessarily contiguous hence C if not
+ # Fortran
+ order = 'C'
+
+ if a.flags['F_CONTIGUOUS'] or a.flags['C_CONTIGUOUS']:
+ # If the array is a contiguous view, no need to pass the strides
+ strides = None
+ total_buffer_len = None
+ else:
+ # Compute the total number of items to map from which the strided
+ # view will be extracted.
+ strides = a.strides
+ total_buffer_len = (a_end - a_start) // a.itemsize
+
+ return (_strided_from_memmap,
+ (m.filename, a.dtype, m.mode, offset, order, a.shape, strides,
+ total_buffer_len, False))
def reduce_array_memmap_backward(a):
"""reduce a np.array or a np.memmap from a child process"""
- pass
+ m = _get_backing_memmap(a)
+ if isinstance(m, np.memmap) and m.filename not in JOBLIB_MMAPS:
+ # if a is backed by a memmaped file, reconstruct a using the
+ # memmaped file.
+ return _reduce_memmap_backed(a, m)
+ else:
+ # a is either a regular (not memmap-backed) numpy array, or an array
+ # backed by a shared temporary file created by joblib. In the latter
+ # case, in order to limit the lifespan of these temporary files, we
+ # serialize the memmap as a regular numpy array, and decref the
+ # file backing the memmap (done implicitly in a previously registered
+ # finalizer, see ``unlink_on_gc_collect`` for more details)
+ return (
+ loads, (dumps(np.asarray(a), protocol=HIGHEST_PROTOCOL), )
+ )
class ArrayMemmapForwardReducer(object):
@@ -156,14 +345,15 @@ class ArrayMemmapForwardReducer(object):
"""
def __init__(self, max_nbytes, temp_folder_resolver, mmap_mode,
- unlink_on_gc_collect, verbose=0, prewarm=True):
+ unlink_on_gc_collect, verbose=0, prewarm=True):
self._max_nbytes = max_nbytes
self._temp_folder_resolver = temp_folder_resolver
self._mmap_mode = mmap_mode
self.verbose = int(verbose)
- if prewarm == 'auto':
+ if prewarm == "auto":
self._prewarm = not self._temp_folder.startswith(
- SYSTEM_SHARED_MEM_FS)
+ SYSTEM_SHARED_MEM_FS
+ )
else:
self._prewarm = prewarm
self._prewarm = prewarm
@@ -171,68 +361,157 @@ class ArrayMemmapForwardReducer(object):
self._temporary_memmaped_filenames = set()
self._unlink_on_gc_collect = unlink_on_gc_collect
+ @property
+ def _temp_folder(self):
+ return self._temp_folder_resolver()
+
def __reduce__(self):
- args = (self._max_nbytes, None, self._mmap_mode, self.
- _unlink_on_gc_collect)
- kwargs = {'verbose': self.verbose, 'prewarm': self._prewarm}
+ # The ArrayMemmapForwardReducer is passed to the children processes: it
+ # needs to be pickled but the _WeakArrayKeyMap need to be skipped as
+ # it's only guaranteed to be consistent with the parent process memory
+ # garbage collection.
+ # Although this reducer is pickled, it is not needed in its destination
+ # process (child processes), as we only use this reducer to send
+ # memmaps from the parent process to the children processes. For this
+ # reason, we can afford skipping the resolver, (which would otherwise
+ # be unpicklable), and pass it as None instead.
+ args = (self._max_nbytes, None, self._mmap_mode,
+ self._unlink_on_gc_collect)
+ kwargs = {
+ 'verbose': self.verbose,
+ 'prewarm': self._prewarm,
+ }
return ArrayMemmapForwardReducer, args, kwargs
def __call__(self, a):
m = _get_backing_memmap(a)
if m is not None and isinstance(m, np.memmap):
+ # a is already backed by a memmap file, let's reuse it directly
return _reduce_memmap_backed(a, m)
- if (not a.dtype.hasobject and self._max_nbytes is not None and a.
- nbytes > self._max_nbytes):
+
+ if (not a.dtype.hasobject and self._max_nbytes is not None and
+ a.nbytes > self._max_nbytes):
+ # check that the folder exists (lazily create the pool temp folder
+ # if required)
try:
os.makedirs(self._temp_folder)
os.chmod(self._temp_folder, FOLDER_PERMISSIONS)
except OSError as e:
if e.errno != errno.EEXIST:
raise e
+
try:
basename = self._memmaped_arrays.get(a)
except KeyError:
- basename = '{}-{}-{}.pkl'.format(os.getpid(), id(threading.
- current_thread()), uuid4().hex)
+ # Generate a new unique random filename. The process and thread
+ # ids are only useful for debugging purpose and to make it
+ # easier to cleanup orphaned files in case of hard process
+ # kill (e.g. by "kill -9" or segfault).
+ basename = "{}-{}-{}.pkl".format(
+ os.getpid(), id(threading.current_thread()), uuid4().hex)
self._memmaped_arrays.set(a, basename)
filename = os.path.join(self._temp_folder, basename)
+
+ # In case the same array with the same content is passed several
+ # times to the pool subprocess children, serialize it only once
+
is_new_memmap = filename not in self._temporary_memmaped_filenames
+
+ # add the memmap to the list of temporary memmaps created by joblib
self._temporary_memmaped_filenames.add(filename)
+
if self._unlink_on_gc_collect:
- resource_tracker.register(filename, 'file')
+ # Bump reference count of the memmap by 1 to account for
+ # shared usage of the memmap by a child process. The
+ # corresponding decref call will be executed upon calling
+ # resource_tracker.maybe_unlink, registered as a finalizer in
+ # the child.
+ # the incref/decref calls here are only possible when the child
+ # and the parent share the same resource_tracker. It is not the
+ # case for the multiprocessing backend, but it does not matter
+ # because unlinking a memmap from a child process is only
+ # useful to control the memory usage of long-lasting child
+ # processes, while the multiprocessing-based pools terminate
+ # their workers at the end of a map() call.
+ resource_tracker.register(filename, "file")
+
if is_new_memmap:
- resource_tracker.register(filename, 'file')
+ # Incref each temporary memmap created by joblib one extra
+ # time. This means that these memmaps will only be deleted
+ # once an extra maybe_unlink() is called, which is done once
+ # all the jobs have completed (or been canceled) in the
+ # Parallel._terminate_backend() method.
+ resource_tracker.register(filename, "file")
+
if not os.path.exists(filename):
util.debug(
- '[ARRAY DUMP] Pickling new array (shape={}, dtype={}) creating a new memmap at {}'
- .format(a.shape, a.dtype, filename))
+ "[ARRAY DUMP] Pickling new array (shape={}, dtype={}) "
+ "creating a new memmap at {}".format(
+ a.shape, a.dtype, filename))
for dumped_filename in dump(a, filename):
os.chmod(dumped_filename, FILE_PERMISSIONS)
+
if self._prewarm:
+ # Warm up the data by accessing it. This operation ensures
+ # that the disk access required to create the memmapping
+ # file are performed in the reducing process and avoids
+ # concurrent memmap creation in multiple children
+ # processes.
load(filename, mmap_mode=self._mmap_mode).max()
+
else:
util.debug(
- '[ARRAY DUMP] Pickling known array (shape={}, dtype={}) reusing memmap file: {}'
- .format(a.shape, a.dtype, os.path.basename(filename)))
- return load_temporary_memmap, (filename, self._mmap_mode, self.
- _unlink_on_gc_collect)
+ "[ARRAY DUMP] Pickling known array (shape={}, dtype={}) "
+ "reusing memmap file: {}".format(
+ a.shape, a.dtype, os.path.basename(filename)))
+
+ # The worker process will use joblib.load to memmap the data
+ return (
+ (load_temporary_memmap, (filename, self._mmap_mode,
+ self._unlink_on_gc_collect))
+ )
else:
+ # do not convert a into memmap, let pickler do its usual copy with
+ # the default system pickler
util.debug(
- '[ARRAY DUMP] Pickling array (NO MEMMAPPING) (shape={}, dtype={}).'
- .format(a.shape, a.dtype))
- return loads, (dumps(a, protocol=HIGHEST_PROTOCOL),)
+ '[ARRAY DUMP] Pickling array (NO MEMMAPPING) (shape={}, '
+ ' dtype={}).'.format(a.shape, a.dtype))
+ return (loads, (dumps(a, protocol=HIGHEST_PROTOCOL),))
-def get_memmapping_reducers(forward_reducers=None, backward_reducers=None,
- temp_folder_resolver=None, max_nbytes=1000000.0, mmap_mode='r', verbose
- =0, prewarm=False, unlink_on_gc_collect=True, **kwargs):
+def get_memmapping_reducers(
+ forward_reducers=None, backward_reducers=None,
+ temp_folder_resolver=None, max_nbytes=1e6, mmap_mode='r', verbose=0,
+ prewarm=False, unlink_on_gc_collect=True, **kwargs):
"""Construct a pair of memmapping reducer linked to a tmpdir.
This function manage the creation and the clean up of the temporary folders
underlying the memory maps and should be use to get the reducers necessary
to construct joblib pool or executor.
"""
- pass
+ if forward_reducers is None:
+ forward_reducers = dict()
+ if backward_reducers is None:
+ backward_reducers = dict()
+
+ if np is not None:
+ # Register smart numpy.ndarray reducers that detects memmap backed
+ # arrays and that is also able to dump to memmap large in-memory
+ # arrays over the max_nbytes threshold
+ forward_reduce_ndarray = ArrayMemmapForwardReducer(
+ max_nbytes, temp_folder_resolver, mmap_mode, unlink_on_gc_collect,
+ verbose, prewarm=prewarm)
+ forward_reducers[np.ndarray] = forward_reduce_ndarray
+ forward_reducers[np.memmap] = forward_reduce_ndarray
+
+ # Communication from child process to the parent process always
+ # pickles in-memory numpy.ndarray without dumping them as memmap
+ # to avoid confusing the caller and make it tricky to collect the
+ # temporary folder
+ backward_reducers[np.ndarray] = reduce_array_memmap_backward
+ backward_reducers[np.memmap] = reduce_array_memmap_backward
+
+ return forward_reducers, backward_reducers
class TemporaryResourcesManager(object):
@@ -253,14 +532,126 @@ class TemporaryResourcesManager(object):
self._id = uuid4().hex
self._finalizers = {}
if context_id is None:
+ # It would be safer to not assign a default context id (less silent
+ # bugs), but doing this while maintaining backward compatibility
+ # with the previous, context-unaware version get_memmaping_executor
+ # exposes too many low-level details.
context_id = uuid4().hex
self.set_current_context(context_id)
+ def set_current_context(self, context_id):
+ self._current_context_id = context_id
+ self.register_new_context(context_id)
+
+ def register_new_context(self, context_id):
+ # Prepare a sub-folder name specific to a context (usually a unique id
+ # generated by each instance of the Parallel class). Do not create in
+ # advance to spare FS write access if no array is to be dumped).
+ if context_id in self._cached_temp_folders:
+ return
+ else:
+ # During its lifecycle, one Parallel object can have several
+ # executors associated to it (for instance, if a loky worker raises
+ # an exception, joblib shutdowns the executor and instantly
+ # recreates a new one before raising the error - see
+ # ``ensure_ready``. Because we don't want two executors tied to
+ # the same Parallel object (and thus the same context id) to
+ # register/use/delete the same folder, we also add an id specific
+ # to the current Manager (and thus specific to its associated
+ # executor) to the folder name.
+ new_folder_name = (
+ "joblib_memmapping_folder_{}_{}_{}".format(
+ os.getpid(), self._id, context_id)
+ )
+ new_folder_path, _ = _get_temp_dir(
+ new_folder_name, self._temp_folder_root
+ )
+ self.register_folder_finalizer(new_folder_path, context_id)
+ self._cached_temp_folders[context_id] = new_folder_path
+
def resolve_temp_folder_name(self):
"""Return a folder name specific to the currently activated context"""
- pass
+ return self._cached_temp_folders[self._current_context_id]
+
    # resource management API

    def register_folder_finalizer(self, pool_subfolder, context_id):
        """Register an atexit callback that deletes ``pool_subfolder``.

        Register the garbage collector at program exit in case caller forgets
        to call terminate explicitly: note we do not pass any reference to
        ensure that this callback won't prevent garbage collection of
        parallel instance and related file handler resources such as POSIX
        semaphores and pipes.
        """
        pool_module_name = whichmodule(delete_folder, 'delete_folder')
        resource_tracker.register(pool_subfolder, "folder")

        def _cleanup():
            # In some cases the Python runtime seems to set delete_folder to
            # None just before exiting when accessing the delete_folder
            # function from the closure namespace. So instead we reimport
            # the delete_folder function explicitly.
            # https://github.com/joblib/joblib/issues/328
            # We cannot just use from 'joblib.pool import delete_folder'
            # because joblib should only use relative imports to allow
            # easy vendoring.
            delete_folder = __import__(
                pool_module_name, fromlist=['delete_folder']
            ).delete_folder
            try:
                delete_folder(pool_subfolder, allow_non_empty=True)
                resource_tracker.unregister(pool_subfolder, "folder")
            except OSError:
                warnings.warn("Failed to delete temporary folder: {}"
                              .format(pool_subfolder))

        # Keep a handle on the finalizer so it can be cancelled once the
        # folder has been cleaned up explicitly.
        self._finalizers[context_id] = atexit.register(_cleanup)
    def _clean_temporary_resources(self, context_id=None, force=False,
                                   allow_non_empty=False):
        """Clean temporary resources created by a process-based pool.

        Parameters
        ----------
        context_id : str or None
            If None, clean the resources of every registered context,
            otherwise only those of the given context.
        force : bool
            If True, unregister tracked files even if reference counts may be
            off (e.g. after worker crashes) and allow deleting a non-empty
            folder.
        allow_non_empty : bool
            If True, delete the context folder even when it still contains
            files.
        """
        if context_id is None:
            # Iterates over a copy of the cache keys to avoid Error due to
            # iterating over a changing size dictionary.
            for context_id in list(self._cached_temp_folders):
                self._clean_temporary_resources(
                    context_id, force=force, allow_non_empty=allow_non_empty
                )
        else:
            temp_folder = self._cached_temp_folders.get(context_id)
            if temp_folder and os.path.exists(temp_folder):
                for filename in os.listdir(temp_folder):
                    if force:
                        # Some workers have failed and the ref counted might
                        # be off. The workers should have shut down by this
                        # time so forcefully clean up the files.
                        resource_tracker.unregister(
                            os.path.join(temp_folder, filename), "file"
                        )
                    else:
                        resource_tracker.maybe_unlink(
                            os.path.join(temp_folder, filename), "file"
                        )

                # When forcing clean-up, try to delete the folder even if some
                # files are still in it. Otherwise, try to delete the folder
                allow_non_empty |= force

                # Clean up the folder if possible, either if it is empty or
                # if none of the files in it are in used and allow_non_empty.
                try:
                    delete_folder(
                        temp_folder, allow_non_empty=allow_non_empty
                    )
                    # Forget the folder once it has been deleted
                    self._cached_temp_folders.pop(context_id, None)
                    resource_tracker.unregister(temp_folder, "folder")

                    # Also cancel the finalizers that gets triggered at gc.
                    finalizer = self._finalizers.pop(context_id, None)
                    if finalizer is not None:
                        atexit.unregister(finalizer)

                except OSError:
                    # Temporary folder cannot be deleted right now.
                    # This folder will be cleaned up by an atexit
                    # finalizer registered by the memmapping_reducer.
                    pass
diff --git a/joblib/_multiprocessing_helpers.py b/joblib/_multiprocessing_helpers.py
index 6441d34..bde4bc1 100644
--- a/joblib/_multiprocessing_helpers.py
+++ b/joblib/_multiprocessing_helpers.py
@@ -5,31 +5,48 @@ circular dependencies (for instance for the assert_spawning name).
"""
import os
import warnings
+
+
# Obtain possible configuration from the environment, assuming 1 (on)
# by default, upon 0 set to None. Should instructively fail if some non
# 0/1 value is set.
mp = int(os.environ.get('JOBLIB_MULTIPROCESSING', 1)) or None
if mp:
    try:
        # `mp` becomes the multiprocessing module itself when the import
        # succeeds, or None when the platform does not support it.
        import multiprocessing as mp
        import _multiprocessing  # noqa
    except ImportError:
        mp = None

# 2nd stage: validate that locking is available on the system and
# issue a warning if not
if mp is not None:
    try:
        # try to create a named semaphore using SemLock to make sure they are
        # available on this platform. We use the low level object
        # _multiprocessing.SemLock to avoid spawning a resource tracker on
        # Unix system or changing the default backend.
        import tempfile
        from _multiprocessing import SemLock

        _rand = tempfile._RandomNameSequence()
        for i in range(100):
            try:
                # Pick a random, pid-qualified name; retry on collision.
                name = '/joblib-{}-{}'.format(
                    os.getpid(), next(_rand))
                _sem = SemLock(0, 0, 1, name=name, unlink=True)
                del _sem  # cleanup
                break
            except FileExistsError as e:  # pragma: no cover
                if i >= 99:
                    raise FileExistsError(
                        'cannot find name for semaphore') from e
    except (FileExistsError, AttributeError, ImportError, OSError) as e:
        # Semaphores unusable: disable multiprocessing entirely.
        mp = None
        warnings.warn('%s. joblib will operate in serial mode' % (e,))
+
+
+# 3rd stage: backward compat for the assert_spawning helper
if mp is not None:
from multiprocessing.context import assert_spawning
else:
diff --git a/joblib/_parallel_backends.py b/joblib/_parallel_backends.py
index 87fe642..8201c96 100644
--- a/joblib/_parallel_backends.py
+++ b/joblib/_parallel_backends.py
@@ -1,38 +1,61 @@
"""
Backends for embarrassingly parallel code.
"""
+
import gc
import os
import warnings
import threading
import contextlib
from abc import ABCMeta, abstractmethod
-from ._utils import _TracebackCapturingWrapper, _retrieve_traceback_capturing_wrapped_call
+
+from ._utils import (
+ _TracebackCapturingWrapper,
+ _retrieve_traceback_capturing_wrapped_call
+)
+
from ._multiprocessing_helpers import mp
+
if mp is not None:
from .pool import MemmappingPool
from multiprocessing.pool import ThreadPool
from .executor import get_memmapping_executor
+
+ # Import loky only if multiprocessing is present
from .externals.loky import process_executor, cpu_count
from .externals.loky.process_executor import ShutdownExecutorError
class ParallelBackendBase(metaclass=ABCMeta):
"""Helper abc which defines all methods a ParallelBackend must implement"""
+
supports_inner_max_num_threads = False
supports_retrieve_callback = False
default_n_jobs = 1
+
    @property
    def supports_return_generator(self):
        # Streaming results back as a generator relies on the callback-based
        # retrieval mechanism.
        return self.supports_retrieve_callback

    @property
    def supports_timeout(self):
        # Timeout support is likewise tied to callback-based retrieval.
        return self.supports_retrieve_callback

    nesting_level = None

    def __init__(self, nesting_level=None, inner_max_num_threads=None,
                 **kwargs):
        super().__init__(**kwargs)
        # Depth of Parallel-in-Parallel nesting this backend runs at.
        self.nesting_level = nesting_level
        # Explicit cap on threads used by native libraries in workers
        # (see _prepare_worker_env); None means "derive a default".
        self.inner_max_num_threads = inner_max_num_threads

    # Environment variables understood by common native threadpools (OpenMP,
    # BLAS implementations, numba, numexpr) that _prepare_worker_env caps in
    # child processes to avoid oversubscription.
    MAX_NUM_THREADS_VARS = [
        'OMP_NUM_THREADS', 'OPENBLAS_NUM_THREADS', 'MKL_NUM_THREADS',
        'BLIS_NUM_THREADS', 'VECLIB_MAXIMUM_THREADS', 'NUMBA_NUM_THREADS',
        'NUMEXPR_NUM_THREADS',
    ]

    # TBB environment variable toggling inter-process coordination.
    TBB_ENABLE_IPC_VAR = "ENABLE_IPC"
@abstractmethod
def effective_n_jobs(self, n_jobs):
@@ -51,12 +74,10 @@ class ParallelBackendBase(metaclass=ABCMeta):
scheduling overhead and better use of CPU cache prefetching heuristics)
as long as all the workers have enough work to do.
"""
- pass
    @abstractmethod
    def apply_async(self, func, callback=None):
        """Schedule a func to be run"""

    def retrieve_result_callback(self, out):
        """Called within the callback function passed in apply_async.

        The argument of this function is the argument given to a callback in
        the considered backend. It is supposed to return the outcome of a task
        if it succeeded or raise the exception if it failed.
        """

    def configure(self, n_jobs=1, parallel=None, prefer=None, require=None,
                  **backend_args):
        """Reconfigure the backend and return the number of workers.

        This makes it possible to reuse an existing backend instance for
        successive independent calls to Parallel with different parameters.
        """
        self.parallel = parallel
        return self.effective_n_jobs(n_jobs)
    def start_call(self):
        """Call-back method called at the beginning of a Parallel call"""

    def stop_call(self):
        """Call-back method called at the end of a Parallel call"""

    def terminate(self):
        """Shutdown the workers and free the shared memory."""

    def compute_batch_size(self):
        """Determine the optimal batch size"""
        # Static batching by default; AutoBatchingMixin overrides this with
        # an adaptive policy.
        return 1

    def batch_completed(self, batch_size, duration):
        """Callback indicate how long it took to run a batch"""

    def get_exceptions(self):
        """List of exception types to be captured."""
        return []
    def abort_everything(self, ensure_ready=True):
        """Abort any running tasks.

        This is called when an exception has been raised when executing a
        task and all the remaining tasks will be ignored and can therefore be
        aborted to spare computation resources.

        If ensure_ready is True, the backend should be left in an operating
        state so that it can be reused for subsequent Parallel calls, taking
        advantage of reusable workers managed by the backend it-self: if we
        expect no new tasks, there is no point in re-creating new workers.
        """
        # Does nothing by default: to be overridden in subclasses when
        # canceling tasks is possible.
        pass

    def get_nested_backend(self):
        """Backend instance to be used by nested Parallel calls.

        By default a thread-based backend is used for the first level of
        nesting. Beyond, switch to sequential backend to avoid spawning too
        many threads on the host.
        """
        nesting_level = getattr(self, 'nesting_level', 0) + 1
        if nesting_level > 1:
            return SequentialBackend(nesting_level=nesting_level), None
        else:
            return ThreadingBackend(nesting_level=nesting_level), None

    @contextlib.contextmanager
    def retrieval_context(self):
        """Context manager to manage an execution context.

        Calls to Parallel.retrieve will be made inside this context.

        By default, this does nothing. It may be useful for subclasses to
        handle nested parallelism. In particular, it may be required to avoid
        deadlocks if a backend manages a fixed number of workers, when those
        workers may be asked to do nested Parallel calls. Without
        'retrieval_context' this could lead to deadlock, as all the workers
        managed by the backend may be "busy" waiting for the nested parallel
        calls to finish, but the backend has no free workers to execute those
        tasks.
        """
        yield
def _prepare_worker_env(self, n_jobs):
"""Return environment variables limiting threadpools in external libs.
@@ -156,7 +179,30 @@ class ParallelBackendBase(metaclass=ABCMeta):
number of threads to `n_threads` for OpenMP, MKL, Accelerated and
OpenBLAS libraries in the child processes.
"""
- pass
+ explicit_n_threads = self.inner_max_num_threads
+ default_n_threads = max(cpu_count() // n_jobs, 1)
+
+ # Set the inner environment variables to self.inner_max_num_threads if
+ # it is given. Else, default to cpu_count // n_jobs unless the variable
+ # is already present in the parent process environment.
+ env = {}
+ for var in self.MAX_NUM_THREADS_VARS:
+ if explicit_n_threads is None:
+ var_value = os.environ.get(var, default_n_threads)
+ else:
+ var_value = explicit_n_threads
+
+ env[var] = str(var_value)
+
+ if self.TBB_ENABLE_IPC_VAR not in os.environ:
+ # To avoid over-subscription when using TBB, let the TBB schedulers
+ # use Inter Process Communication to coordinate:
+ env[self.TBB_ENABLE_IPC_VAR] = "1"
+ return env
+
+ @staticmethod
+ def in_main_thread():
+ return isinstance(threading.current_thread(), threading._MainThread)
class SequentialBackend(ParallelBackendBase):
@@ -165,6 +211,7 @@ class SequentialBackend(ParallelBackendBase):
Does not use/create any threading objects, and hence has minimal
overhead. Used when n_jobs == 1.
"""
+
uses_threads = True
supports_timeout = False
supports_retrieve_callback = False
@@ -172,46 +219,90 @@ class SequentialBackend(ParallelBackendBase):
    def effective_n_jobs(self, n_jobs):
        """Determine the number of jobs which are going to run in parallel"""
        if n_jobs == 0:
            raise ValueError('n_jobs == 0 in Parallel has no meaning')
        # Sequential execution: always exactly one job.
        return 1

    def apply_async(self, func, callback=None):
        """Schedule a func to be run"""
        # The sequential backend runs tasks inline in Parallel itself, so
        # async scheduling must never be reached.
        raise RuntimeError("Should never be called for SequentialBackend.")

    def retrieve_result_callback(self, out):
        """Never used: sequential results are produced inline."""
        raise RuntimeError("Should never be called for SequentialBackend.")

    def get_nested_backend(self):
        """Return the currently active backend for nested Parallel calls."""
        # import is not top level to avoid cyclic import errors.
        from .parallel import get_active_backend

        # SequentialBackend should neither change the nesting level, the
        # default backend or the number of jobs. Just return the current one.
        return get_active_backend()
class PoolManagerMixin(object):
"""A helper class for managing pool of workers."""
+
_pool = None
    def effective_n_jobs(self, n_jobs):
        """Determine the number of jobs which are going to run in parallel"""
        if n_jobs == 0:
            raise ValueError('n_jobs == 0 in Parallel has no meaning')
        elif mp is None or n_jobs is None:
            # multiprocessing is not available or disabled, fallback
            # to sequential mode
            return 1
        elif n_jobs < 0:
            # Negative values count backward from the total number of CPUs:
            # n_jobs=-1 means "all CPUs", -2 "all but one", etc.
            n_jobs = max(cpu_count() + 1 + n_jobs, 1)
        return n_jobs

    def terminate(self):
        """Shutdown the process or thread pool"""
        if self._pool is not None:
            self._pool.close()
            self._pool.terminate()  # terminate does a join()
            self._pool = None

    def _get_pool(self):
        """Used by apply_async to make it possible to implement lazy init"""
        return self._pool

    def apply_async(self, func, callback=None):
        """Schedule a func to be run"""
        # Here, we need a wrapper to avoid crashes on KeyboardInterruptErrors.
        # We also call the callback on error, to make sure the pool does not
        # wait on crashed jobs.
        return self._get_pool().apply_async(
            _TracebackCapturingWrapper(func), (),
            callback=callback, error_callback=callback
        )

    def retrieve_result_callback(self, out):
        """Mimic concurrent.futures results, raising an error if needed."""
        return _retrieve_traceback_capturing_wrapped_call(out)

    def abort_everything(self, ensure_ready=True):
        """Shutdown the pool and restart a new one with the same parameters"""
        self.terminate()
        if ensure_ready:
            self.configure(n_jobs=self.parallel.n_jobs, parallel=self.parallel,
                           **self.parallel._backend_args)
class AutoBatchingMixin(object):
"""A helper class for automagically batching jobs."""
    # In seconds, should be big enough to hide multiprocessing dispatching
    # overhead.
    # This settings was found by running benchmarks/bench_auto_batching.py
    # with various parameters on various platforms.
    MIN_IDEAL_BATCH_DURATION = .2

    # Should not be too high to avoid stragglers: long jobs running alone
    # on a single worker while other workers have no work to process any more.
    MAX_IDEAL_BATCH_DURATION = 2

    # Batching counters default values
    _DEFAULT_EFFECTIVE_BATCH_SIZE = 1
    _DEFAULT_SMOOTHED_BATCH_DURATION = 0.0
@@ -222,18 +313,89 @@ class AutoBatchingMixin(object):
    def compute_batch_size(self):
        """Determine the optimal batch size.

        Grows the batch size when batches complete faster than
        MIN_IDEAL_BATCH_DURATION and shrinks it when they exceed
        MAX_IDEAL_BATCH_DURATION, resetting the duration estimate whenever
        the size changes.
        """
        old_batch_size = self._effective_batch_size
        batch_duration = self._smoothed_batch_duration
        if (batch_duration > 0 and
                batch_duration < self.MIN_IDEAL_BATCH_DURATION):
            # The current batch size is too small: the duration of the
            # processing of a batch of task is not large enough to hide
            # the scheduling overhead.
            ideal_batch_size = int(old_batch_size *
                                   self.MIN_IDEAL_BATCH_DURATION /
                                   batch_duration)
            # Multiply by two to limit oscilations between min and max.
            ideal_batch_size *= 2

            # dont increase the batch size too fast to limit huge batch sizes
            # potentially leading to starving worker
            batch_size = min(2 * old_batch_size, ideal_batch_size)

            batch_size = max(batch_size, 1)

            self._effective_batch_size = batch_size
            if self.parallel.verbose >= 10:
                self.parallel._print(
                    f"Batch computation too fast ({batch_duration}s.) "
                    f"Setting batch_size={batch_size}."
                )
        elif (batch_duration > self.MAX_IDEAL_BATCH_DURATION and
              old_batch_size >= 2):
            # The current batch size is too big. If we schedule overly long
            # running batches some CPUs might wait with nothing left to do
            # while a couple of CPUs a left processing a few long running
            # batches. Better reduce the batch size a bit to limit the
            # likelihood of scheduling such stragglers.

            # decrease the batch size quickly to limit potential starving
            ideal_batch_size = int(
                old_batch_size * self.MIN_IDEAL_BATCH_DURATION / batch_duration
            )
            # Multiply by two to limit oscilations between min and max.
            batch_size = max(2 * ideal_batch_size, 1)
            self._effective_batch_size = batch_size
            if self.parallel.verbose >= 10:
                self.parallel._print(
                    f"Batch computation too slow ({batch_duration}s.) "
                    f"Setting batch_size={batch_size}."
                )
        else:
            # No batch size adjustment
            batch_size = old_batch_size

        if batch_size != old_batch_size:
            # Reset estimation of the smoothed mean batch duration: this
            # estimate is updated in the multiprocessing apply_async
            # CallBack as long as the batch_size is constant. Therefore
            # we need to reset the estimate whenever we re-tune the batch
            # size.
            self._smoothed_batch_duration = \
                self._DEFAULT_SMOOTHED_BATCH_DURATION

        return batch_size
def batch_completed(self, batch_size, duration):
"""Callback indicate how long it took to run a batch"""
- pass
+ if batch_size == self._effective_batch_size:
+ # Update the smoothed streaming estimate of the duration of a batch
+ # from dispatch to completion
+ old_duration = self._smoothed_batch_duration
+ if old_duration == self._DEFAULT_SMOOTHED_BATCH_DURATION:
+ # First record of duration for this batch size after the last
+ # reset.
+ new_duration = duration
+ else:
+ # Update the exponentially weighted average of the duration of
+ # batch for the current effective size.
+ new_duration = 0.8 * old_duration + 0.2 * duration
+ self._smoothed_batch_duration = new_duration
def reset_batch_stats(self):
"""Reset batch statistics to default values.
This avoids interferences with future jobs.
"""
- pass
+ self._effective_batch_size = self._DEFAULT_EFFECTIVE_BATCH_SIZE
+ self._smoothed_batch_duration = self._DEFAULT_SMOOTHED_BATCH_DURATION
class ThreadingBackend(PoolManagerMixin, ParallelBackendBase):
@@ -250,13 +412,21 @@ class ThreadingBackend(PoolManagerMixin, ParallelBackendBase):
ThreadingBackend is used as the default backend for nested calls.
"""
+
supports_retrieve_callback = True
uses_threads = True
supports_sharedmem = True
def configure(self, n_jobs=1, parallel=None, **backend_args):
"""Build a process or thread pool and return the number of workers"""
- pass
+ n_jobs = self.effective_n_jobs(n_jobs)
+ if n_jobs == 1:
+ # Avoid unnecessary overhead and use sequential backend instead.
+ raise FallbackToBackend(
+ SequentialBackend(nesting_level=self.nesting_level))
+ self.parallel = parallel
+ self._n_jobs = n_jobs
+ return n_jobs
def _get_pool(self):
"""Lazily initialize the thread pool
@@ -264,17 +434,20 @@ class ThreadingBackend(PoolManagerMixin, ParallelBackendBase):
The actual pool of worker threads is only initialized at the first
call to apply_async.
"""
- pass
+ if self._pool is None:
+ self._pool = ThreadPool(self._n_jobs)
+ return self._pool
class MultiprocessingBackend(PoolManagerMixin, AutoBatchingMixin,
- ParallelBackendBase):
+ ParallelBackendBase):
"""A ParallelBackend which will use a multiprocessing.Pool.
Will introduce some communication and memory overhead when exchanging
input and output data with the with the worker Python processes.
However, does not suffer from the Python Global Interpreter Lock.
"""
+
supports_retrieve_callback = True
supports_return_generator = False
@@ -284,40 +457,171 @@ class MultiprocessingBackend(PoolManagerMixin, AutoBatchingMixin,
This also checks if we are attempting to create a nested parallel
loop.
"""
- pass
+ if mp is None:
+ return 1
+
+ if mp.current_process().daemon:
+ # Daemonic processes cannot have children
+ if n_jobs != 1:
+ if inside_dask_worker():
+ msg = (
+ "Inside a Dask worker with daemon=True, "
+ "setting n_jobs=1.\nPossible work-arounds:\n"
+ "- dask.config.set("
+ "{'distributed.worker.daemon': False})"
+ "- set the environment variable "
+ "DASK_DISTRIBUTED__WORKER__DAEMON=False\n"
+ "before creating your Dask cluster."
+ )
+ else:
+ msg = (
+ 'Multiprocessing-backed parallel loops '
+ 'cannot be nested, setting n_jobs=1'
+ )
+ warnings.warn(msg, stacklevel=3)
+ return 1
+
+ if process_executor._CURRENT_DEPTH > 0:
+ # Mixing loky and multiprocessing in nested loop is not supported
+ if n_jobs != 1:
+ warnings.warn(
+ 'Multiprocessing-backed parallel loops cannot be nested,'
+ ' below loky, setting n_jobs=1',
+ stacklevel=3)
+ return 1
+
+ elif not (self.in_main_thread() or self.nesting_level == 0):
+ # Prevent posix fork inside in non-main posix threads
+ if n_jobs != 1:
+ warnings.warn(
+ 'Multiprocessing-backed parallel loops cannot be nested'
+ ' below threads, setting n_jobs=1',
+ stacklevel=3)
+ return 1
+
+ return super(MultiprocessingBackend, self).effective_n_jobs(n_jobs)
def configure(self, n_jobs=1, parallel=None, prefer=None, require=None,
- **memmappingpool_args):
+ **memmappingpool_args):
"""Build a process or thread pool and return the number of workers"""
- pass
+ n_jobs = self.effective_n_jobs(n_jobs)
+ if n_jobs == 1:
+ raise FallbackToBackend(
+ SequentialBackend(nesting_level=self.nesting_level))
+
+ # Make sure to free as much memory as possible before forking
+ gc.collect()
+ self._pool = MemmappingPool(n_jobs, **memmappingpool_args)
+ self.parallel = parallel
+ return n_jobs
def terminate(self):
"""Shutdown the process or thread pool"""
- pass
+ super(MultiprocessingBackend, self).terminate()
+ self.reset_batch_stats()
class LokyBackend(AutoBatchingMixin, ParallelBackendBase):
"""Managing pool of workers with loky instead of multiprocessing."""
+
supports_retrieve_callback = True
supports_inner_max_num_threads = True
def configure(self, n_jobs=1, parallel=None, prefer=None, require=None,
- idle_worker_timeout=300, **memmappingexecutor_args):
+ idle_worker_timeout=300, **memmappingexecutor_args):
"""Build a process executor and return the number of workers"""
- pass
+ n_jobs = self.effective_n_jobs(n_jobs)
+ if n_jobs == 1:
+ raise FallbackToBackend(
+ SequentialBackend(nesting_level=self.nesting_level))
+
+ self._workers = get_memmapping_executor(
+ n_jobs, timeout=idle_worker_timeout,
+ env=self._prepare_worker_env(n_jobs=n_jobs),
+ context_id=parallel._id, **memmappingexecutor_args)
+ self.parallel = parallel
+ return n_jobs
    def effective_n_jobs(self, n_jobs):
        """Determine the number of jobs which are going to run in parallel"""
        if n_jobs == 0:
            raise ValueError('n_jobs == 0 in Parallel has no meaning')
        elif mp is None or n_jobs is None:
            # multiprocessing is not available or disabled, fallback
            # to sequential mode
            return 1
        elif mp.current_process().daemon:
            # Daemonic processes cannot have children
            if n_jobs != 1:
                if inside_dask_worker():
                    msg = (
                        "Inside a Dask worker with daemon=True, "
                        "setting n_jobs=1.\nPossible work-arounds:\n"
                        "- dask.config.set("
                        "{'distributed.worker.daemon': False})\n"
                        "- set the environment variable "
                        "DASK_DISTRIBUTED__WORKER__DAEMON=False\n"
                        "before creating your Dask cluster."
                    )
                else:
                    msg = (
                        'Loky-backed parallel loops cannot be called in a'
                        ' multiprocessing, setting n_jobs=1'
                    )
                warnings.warn(msg, stacklevel=3)

            return 1
        elif not (self.in_main_thread() or self.nesting_level == 0):
            # Prevent posix fork inside in non-main posix threads
            if n_jobs != 1:
                warnings.warn(
                    'Loky-backed parallel loops cannot be nested below '
                    'threads, setting n_jobs=1',
                    stacklevel=3)
            return 1
        elif n_jobs < 0:
            # Negative values count backward from the total number of CPUs.
            n_jobs = max(cpu_count() + 1 + n_jobs, 1)
        return n_jobs
    def apply_async(self, func, callback=None):
        """Schedule a func to be run"""
        future = self._workers.submit(func)
        if callback is not None:
            future.add_done_callback(callback)
        return future

    def retrieve_result_callback(self, out):
        """Return the future's result, translating executor shutdowns."""
        try:
            return out.result()
        except ShutdownExecutorError:
            raise RuntimeError(
                "The executor underlying Parallel has been shutdown. "
                "This is likely due to the garbage collection of a previous "
                "generator from a call to Parallel with return_as='generator'."
                " Make sure the generator is not garbage collected when "
                "submitting a new job or that it is first properly exhausted."
            )

    def terminate(self):
        """Detach from the executor, cleaning this call's temporary files."""
        if self._workers is not None:
            # Don't terminate the workers as we want to reuse them in later
            # calls, but cleanup the temporary resources that the Parallel call
            # created. This 'hack' requires a private, low-level operation.
            self._workers._temp_folder_manager._clean_temporary_resources(
                context_id=self.parallel._id, force=False
            )
            self._workers = None

        self.reset_batch_stats()

    def abort_everything(self, ensure_ready=True):
        """Shutdown the workers and restart a new one with the same parameters
        """
        self._workers.terminate(kill_workers=True)
        self._workers = None

        if ensure_ready:
            self.configure(n_jobs=self.parallel.n_jobs, parallel=self.parallel)
class FallbackToBackend(Exception):
@@ -330,4 +634,16 @@ class FallbackToBackend(Exception):
def inside_dask_worker():
    """Check whether the current function is executed inside a Dask worker.
    """
    # This function can not be in joblib._dask because there would be a
    # circular import:
    # _dask imports _parallel_backend that imports _dask ...
    try:
        from distributed import get_worker
    except ImportError:
        # distributed is not installed: we cannot be inside a Dask worker.
        return False

    try:
        get_worker()
    except ValueError:
        # get_worker raises ValueError when called outside a worker thread.
        return False
    return True
diff --git a/joblib/_store_backends.py b/joblib/_store_backends.py
index 0ce3682..68e207c 100644
--- a/joblib/_store_backends.py
+++ b/joblib/_store_backends.py
@@ -1,4 +1,5 @@
"""Storage providers backends for Memory caching."""
+
from pickle import PicklingError
import re
import os
@@ -12,12 +13,14 @@ import collections
import operator
import threading
from abc import ABCMeta, abstractmethod
+
from .backports import concurrency_safe_rename
from .disk import mkdirp, memstr_to_bytes, rm_subdirs
from .logger import format_time
from . import numpy_pickle
-CacheItemInfo = collections.namedtuple('CacheItemInfo', 'path size last_access'
- )
+
+CacheItemInfo = collections.namedtuple('CacheItemInfo',
+ 'path size last_access')
class CacheWarning(Warning):
@@ -27,12 +30,18 @@ class CacheWarning(Warning):
def concurrency_safe_write(object_to_write, filename, write_func):
    """Writes an object into a unique file in a concurrency-safe way."""
    # Suffix the target name with both the thread id and the process id so
    # that concurrent writers never pick the same temporary file.
    unique_suffix = '.thread-{}-pid-{}'.format(
        id(threading.current_thread()), os.getpid())
    temporary_filename = filename + unique_suffix
    write_func(object_to_write, temporary_filename)
    return temporary_filename
class StoreBackendBase(metaclass=ABCMeta):
"""Helper Abstract Base Class which defines all methods that
a StorageBackend must implement."""
+
location = None
@abstractmethod
@@ -53,7 +62,6 @@ class StoreBackendBase(metaclass=ABCMeta):
-------
a file-like object
"""
- pass
@abstractmethod
def _item_exists(self, location):
@@ -71,7 +79,6 @@ class StoreBackendBase(metaclass=ABCMeta):
-------
True if the item exists, False otherwise
"""
- pass
@abstractmethod
def _move_item(self, src, dst):
@@ -86,7 +93,6 @@ class StoreBackendBase(metaclass=ABCMeta):
dst: string
The destination location of an item
"""
- pass
@abstractmethod
def create_location(self, location):
@@ -98,7 +104,6 @@ class StoreBackendBase(metaclass=ABCMeta):
The location in the store. On a filesystem, this corresponds to a
directory.
"""
- pass
@abstractmethod
def clear_location(self, location):
@@ -110,7 +115,6 @@ class StoreBackendBase(metaclass=ABCMeta):
The location in the store. On a filesystem, this corresponds to a
directory or a filename absolute path
"""
- pass
@abstractmethod
def get_items(self):
@@ -121,7 +125,6 @@ class StoreBackendBase(metaclass=ABCMeta):
The list of items identified by their ids (e.g filename in a
filesystem).
"""
- pass
@abstractmethod
def configure(self, location, verbose=0, backend_options=dict()):
@@ -138,7 +141,6 @@ class StoreBackendBase(metaclass=ABCMeta):
Contains a dictionary of named parameters used to configure the
store backend.
"""
- pass
class StoreBackendMixin(object):
@@ -153,101 +155,320 @@ class StoreBackendMixin(object):
    def load_item(self, call_id, verbose=1, timestamp=None, metadata=None):
        """Load an item from the store given its id as a list of str.

        Raises KeyError when the item is missing (e.g. has been cleared).
        When the backend defines ``mmap_mode``, the pickle is loaded through
        a memory map instead of a file object.
        """
        full_path = os.path.join(self.location, *call_id)

        if verbose > 1:
            # Build a human readable trace line for verbose cache hits.
            ts_string = ('{: <16}'.format(format_time(time.time() - timestamp))
                         if timestamp is not None else '')
            signature = os.path.basename(call_id[0])
            if metadata is not None and 'input_args' in metadata:
                kwargs = ', '.join('{}={}'.format(*item)
                                   for item in metadata['input_args'].items())
                signature += '({})'.format(kwargs)
            msg = '[Memory]{}: Loading {}'.format(ts_string, signature)
            if verbose < 10:
                print('{0}...'.format(msg))
            else:
                print('{0} from {1}'.format(msg, full_path))

        mmap_mode = (None if not hasattr(self, 'mmap_mode')
                     else self.mmap_mode)

        filename = os.path.join(full_path, 'output.pkl')
        if not self._item_exists(filename):
            raise KeyError("Non-existing item (may have been "
                           "cleared).\nFile %s does not exist" % filename)

        # file-like object cannot be used when mmap_mode is set
        if mmap_mode is None:
            with self._open_item(filename, "rb") as f:
                item = numpy_pickle.load(f)
        else:
            item = numpy_pickle.load(filename, mmap_mode=mmap_mode)
        return item
    def dump_item(self, call_id, item, verbose=1):
        """Dump an item in the store at the id given as a list of str.

        Failures (directory races, pickling errors) are reported as warnings
        rather than raised so that caching never breaks the cached call.
        """
        try:
            item_path = os.path.join(self.location, *call_id)
            if not self._item_exists(item_path):
                self.create_location(item_path)
            filename = os.path.join(item_path, 'output.pkl')
            if verbose > 10:
                print('Persisting in %s' % item_path)

            def write_func(to_write, dest_filename):
                with self._open_item(dest_filename, "wb") as f:
                    try:
                        numpy_pickle.dump(to_write, f, compress=self.compress)
                    except PicklingError as e:
                        # TODO(1.5) turn into error
                        warnings.warn(
                            "Unable to cache to disk: failed to pickle "
                            "output. In version 1.5 this will raise an "
                            f"exception. Exception: {e}.",
                            FutureWarning
                        )

            self._concurrency_safe_write(item, filename, write_func)
        except Exception as e:  # noqa: E722
            warnings.warn(
                "Unable to cache to disk. Possibly a race condition in the "
                f"creation of the directory. Exception: {e}.",
                CacheWarning
            )
def clear_item(self, call_id):
"""Clear the item at the id, given as a list of str."""
- pass
+ item_path = os.path.join(self.location, *call_id)
+ if self._item_exists(item_path):
+ self.clear_location(item_path)
def contains_item(self, call_id):
"""Check if there is an item at the id, given as a list of str."""
- pass
+ item_path = os.path.join(self.location, *call_id)
+ filename = os.path.join(item_path, 'output.pkl')
+
+ return self._item_exists(filename)
def get_item_info(self, call_id):
"""Return information about item."""
- pass
+ return {'location': os.path.join(self.location, *call_id)}
def get_metadata(self, call_id):
"""Return actual metadata of an item."""
- pass
+ try:
+ item_path = os.path.join(self.location, *call_id)
+ filename = os.path.join(item_path, 'metadata.json')
+ with self._open_item(filename, 'rb') as f:
+ return json.loads(f.read().decode('utf-8'))
+ except: # noqa: E722
+ return {}
def store_metadata(self, call_id, metadata):
"""Store metadata of a computation."""
- pass
+ try:
+ item_path = os.path.join(self.location, *call_id)
+ self.create_location(item_path)
+ filename = os.path.join(item_path, 'metadata.json')
+
+ def write_func(to_write, dest_filename):
+ with self._open_item(dest_filename, "wb") as f:
+ f.write(json.dumps(to_write).encode('utf-8'))
+
+ self._concurrency_safe_write(metadata, filename, write_func)
+ except: # noqa: E722
+ pass
def contains_path(self, call_id):
"""Check cached function is available in store."""
- pass
+ func_path = os.path.join(self.location, *call_id)
+ return self.object_exists(func_path)
def clear_path(self, call_id):
"""Clear all items with a common path in the store."""
- pass
+ func_path = os.path.join(self.location, *call_id)
+ if self._item_exists(func_path):
+ self.clear_location(func_path)
def store_cached_func_code(self, call_id, func_code=None):
"""Store the code of the cached function."""
- pass
+ func_path = os.path.join(self.location, *call_id)
+ if not self._item_exists(func_path):
+ self.create_location(func_path)
+
+ if func_code is not None:
+ filename = os.path.join(func_path, "func_code.py")
+ with self._open_item(filename, 'wb') as f:
+ f.write(func_code.encode('utf-8'))
def get_cached_func_code(self, call_id):
"""Store the code of the cached function."""
- pass
+ filename = os.path.join(self.location, *call_id, 'func_code.py')
+ try:
+ with self._open_item(filename, 'rb') as f:
+ return f.read().decode('utf-8')
+ except: # noqa: E722
+ raise
def get_cached_func_info(self, call_id):
"""Return information related to the cached function if it exists."""
- pass
+ return {'location': os.path.join(self.location, *call_id)}
def clear(self):
"""Clear the whole store content."""
- pass
+ self.clear_location(self.location)
- def enforce_store_limits(self, bytes_limit, items_limit=None, age_limit
- =None):
+ def enforce_store_limits(
+ self, bytes_limit, items_limit=None, age_limit=None
+ ):
"""
Remove the store's oldest files to enforce item, byte, and age limits.
"""
- pass
-
- def _get_items_to_delete(self, bytes_limit, items_limit=None, age_limit
- =None):
+ items_to_delete = self._get_items_to_delete(
+ bytes_limit, items_limit, age_limit
+ )
+
+ for item in items_to_delete:
+ if self.verbose > 10:
+ print('Deleting item {0}'.format(item))
+ try:
+ self.clear_location(item.path)
+ except OSError:
+ # Even with ignore_errors=True shutil.rmtree can raise OSError
+ # with:
+ # [Errno 116] Stale file handle if another process has deleted
+ # the folder already.
+ pass
+
+ def _get_items_to_delete(
+ self, bytes_limit, items_limit=None, age_limit=None
+ ):
"""
Get items to delete to keep the store under size, file, & age limits.
"""
- pass
+ if isinstance(bytes_limit, str):
+ bytes_limit = memstr_to_bytes(bytes_limit)
+
+ items = self.get_items()
+ if not items:
+ return []
+
+ size = sum(item.size for item in items)
+
+ if bytes_limit is not None:
+ to_delete_size = size - bytes_limit
+ else:
+ to_delete_size = 0
+
+ if items_limit is not None:
+ to_delete_items = len(items) - items_limit
+ else:
+ to_delete_items = 0
+
+ if age_limit is not None:
+ older_item = min(item.last_access for item in items)
+ deadline = datetime.datetime.now() - age_limit
+ else:
+ deadline = None
+
+ if (
+ to_delete_size <= 0 and to_delete_items <= 0
+ and (deadline is None or older_item > deadline)
+ ):
+ return []
+
+ # We want to delete first the cache items that were accessed a
+ # long time ago
+ items.sort(key=operator.attrgetter('last_access'))
+
+ items_to_delete = []
+ size_so_far = 0
+ items_so_far = 0
+
+ for item in items:
+ if (
+ (size_so_far >= to_delete_size)
+ and items_so_far >= to_delete_items
+ and (deadline is None or deadline < item.last_access)
+ ):
+ break
+
+ items_to_delete.append(item)
+ size_so_far += item.size
+ items_so_far += 1
+
+ return items_to_delete
def _concurrency_safe_write(self, to_write, filename, write_func):
"""Writes an object into a file in a concurrency-safe way."""
- pass
+ temporary_filename = concurrency_safe_write(to_write,
+ filename, write_func)
+ self._move_item(temporary_filename, filename)
def __repr__(self):
"""Printable representation of the store location."""
- return '{class_name}(location="{location}")'.format(class_name=self
- .__class__.__name__, location=self.location)
+ return '{class_name}(location="{location}")'.format(
+ class_name=self.__class__.__name__, location=self.location)
class FileSystemStoreBackend(StoreBackendBase, StoreBackendMixin):
"""A StoreBackend used with local or network file systems."""
+
_open_item = staticmethod(open)
_item_exists = staticmethod(os.path.exists)
_move_item = staticmethod(concurrency_safe_rename)
def clear_location(self, location):
"""Delete location on store."""
- pass
+ if (location == self.location):
+ rm_subdirs(location)
+ else:
+ shutil.rmtree(location, ignore_errors=True)
def create_location(self, location):
"""Create object location on store"""
- pass
+ mkdirp(location)
def get_items(self):
"""Returns the whole list of items available in the store."""
- pass
+ items = []
+
+ for dirpath, _, filenames in os.walk(self.location):
+ is_cache_hash_dir = re.match('[a-f0-9]{32}',
+ os.path.basename(dirpath))
+
+ if is_cache_hash_dir:
+ output_filename = os.path.join(dirpath, 'output.pkl')
+ try:
+ last_access = os.path.getatime(output_filename)
+ except OSError:
+ try:
+ last_access = os.path.getatime(dirpath)
+ except OSError:
+ # The directory has already been deleted
+ continue
+
+ last_access = datetime.datetime.fromtimestamp(last_access)
+ try:
+ full_filenames = [os.path.join(dirpath, fn)
+ for fn in filenames]
+ dirsize = sum(os.path.getsize(fn)
+ for fn in full_filenames)
+ except OSError:
+ # Either output_filename or one of the files in
+ # dirpath does not exist any more. We assume this
+ # directory is being cleaned by another process already
+ continue
+
+ items.append(CacheItemInfo(dirpath, dirsize,
+ last_access))
+
+ return items
def configure(self, location, verbose=1, backend_options=None):
"""Configure the store backend.
For this backend, valid store options are 'compress' and 'mmap_mode'
"""
- pass
+ if backend_options is None:
+ backend_options = {}
+
+ # setup location directory
+ self.location = location
+ if not os.path.exists(self.location):
+ mkdirp(self.location)
+
+ # item can be stored compressed for faster I/O
+ self.compress = backend_options.get('compress', False)
+
+ # FileSystemStoreBackend can be used with mmap_mode options under
+ # certain conditions.
+ mmap_mode = backend_options.get('mmap_mode')
+ if self.compress and mmap_mode is not None:
+ warnings.warn('Compressed items cannot be memmapped in a '
+ 'filesystem store. Option will be ignored.',
+ stacklevel=2)
+
+ self.mmap_mode = mmap_mode
+ self.verbose = verbose
diff --git a/joblib/_utils.py b/joblib/_utils.py
index d2feff7..0b7cc64 100644
--- a/joblib/_utils.py
+++ b/joblib/_utils.py
@@ -1,12 +1,27 @@
+# Adapted from https://stackoverflow.com/a/9558001/2536294
+
import ast
from dataclasses import dataclass
import operator as op
+
+
from ._multiprocessing_helpers import mp
+
if mp is not None:
from .externals.loky.process_executor import _ExceptionWithTraceback
-operators = {ast.Add: op.add, ast.Sub: op.sub, ast.Mult: op.mul, ast.Div:
- op.truediv, ast.FloorDiv: op.floordiv, ast.Mod: op.mod, ast.Pow: op.pow,
- ast.USub: op.neg}
+
+
+# supported operators
+operators = {
+ ast.Add: op.add,
+ ast.Sub: op.sub,
+ ast.Mult: op.mul,
+ ast.Div: op.truediv,
+ ast.FloorDiv: op.floordiv,
+ ast.Mod: op.mod,
+ ast.Pow: op.pow,
+ ast.USub: op.neg,
+}
def eval_expr(expr):
@@ -18,7 +33,23 @@ def eval_expr(expr):
>>> eval_expr('1 + 2*3**(4) / (6 + -7)')
-161.0
"""
- pass
+ try:
+ return eval_(ast.parse(expr, mode="eval").body)
+ except (TypeError, SyntaxError, KeyError) as e:
+ raise ValueError(
+ f"{expr!r} is not a valid or supported arithmetic expression."
+ ) from e
+
+
+def eval_(node):
+ if isinstance(node, ast.Constant): # <constant>
+ return node.value
+ elif isinstance(node, ast.BinOp): # <left> <operator> <right>
+ return operators[type(node.op)](eval_(node.left), eval_(node.right))
+ elif isinstance(node, ast.UnaryOp): # <operator> <operand> e.g., -1
+ return operators[type(node.op)](eval_(node.operand))
+ else:
+ raise TypeError(node)
@dataclass(frozen=True)
@@ -27,7 +58,7 @@ class _Sentinel:
default_value: object
def __repr__(self):
- return f'default({self.default_value!r})'
+ return f"default({self.default_value!r})"
class _TracebackCapturingWrapper:
@@ -41,3 +72,12 @@ class _TracebackCapturingWrapper:
return self.func(**kwargs)
except BaseException as e:
return _ExceptionWithTraceback(e)
+
+
+def _retrieve_traceback_capturing_wrapped_call(out):
+ if isinstance(out, _ExceptionWithTraceback):
+ rebuild, args = out.__reduce__()
+ out = rebuild(*args)
+ if isinstance(out, BaseException):
+ raise out
+ return out
diff --git a/joblib/backports.py b/joblib/backports.py
index b7178c8..3a14f10 100644
--- a/joblib/backports.py
+++ b/joblib/backports.py
@@ -4,6 +4,7 @@ Backports of fixes for joblib dependencies
import os
import re
import time
+
from os.path import basename
from multiprocessing import util
@@ -65,24 +66,53 @@ class LooseVersion(Version):
We might rexplore this choice in the future if all major Python projects
introduce a dependency on packaging anyway.
"""
- component_re = re.compile('(\\d+ | [a-z]+ | \\.)', re.VERBOSE)
+
+ component_re = re.compile(r'(\d+ | [a-z]+ | \.)', re.VERBOSE)
def __init__(self, vstring=None):
if vstring:
self.parse(vstring)
+ def parse(self, vstring):
+ # I've given up on thinking I can reconstruct the version string
+ # from the parsed tuple -- so I just store the string here for
+ # use by __str__
+ self.vstring = vstring
+ components = [x for x in self.component_re.split(vstring)
+ if x and x != '.']
+ for i, obj in enumerate(components):
+ try:
+ components[i] = int(obj)
+ except ValueError:
+ pass
+
+ self.version = components
+
def __str__(self):
return self.vstring
def __repr__(self):
return "LooseVersion ('%s')" % str(self)
+ def _cmp(self, other):
+ if isinstance(other, str):
+ other = LooseVersion(other)
+ elif not isinstance(other, LooseVersion):
+ return NotImplemented
+
+ if self.version == other.version:
+ return 0
+ if self.version < other.version:
+ return -1
+ if self.version > other.version:
+ return 1
+
try:
import numpy as np
- def make_memmap(filename, dtype='uint8', mode='r+', offset=0, shape=
- None, order='C', unlink_on_gc_collect=False):
+ def make_memmap(filename, dtype='uint8', mode='r+', offset=0,
+ shape=None, order='C', unlink_on_gc_collect=False):
"""Custom memmap constructor compatible with numpy.memmap.
This function:
@@ -95,10 +125,30 @@ try:
newly-created memmap that sends a maybe_unlink request for the
memmaped file to resource_tracker.
"""
- pass
+ util.debug(
+ "[MEMMAP READ] creating a memmap (shape {}, filename {}, "
+ "pid {})".format(shape, basename(filename), os.getpid())
+ )
+
+ mm = np.memmap(filename, dtype=dtype, mode=mode, offset=offset,
+ shape=shape, order=order)
+ if LooseVersion(np.__version__) < '1.13':
+ mm.offset = offset
+ if unlink_on_gc_collect:
+ from ._memmapping_reducer import add_maybe_unlink_finalizer
+ add_maybe_unlink_finalizer(mm)
+ return mm
except ImportError:
+ def make_memmap(filename, dtype='uint8', mode='r+', offset=0,
+ shape=None, order='C', unlink_on_gc_collect=False):
+ raise NotImplementedError(
+ "'joblib.backports.make_memmap' should not be used "
+ 'if numpy is not installed.')
+
+
if os.name == 'nt':
- access_denied_errors = 5, 13
+ # https://github.com/joblib/joblib/issues/540
+ access_denied_errors = (5, 13)
from os import replace
def concurrency_safe_rename(src, dst):
@@ -107,6 +157,21 @@ if os.name == 'nt':
On Windows os.replace can yield permission errors if executed by two
different processes.
"""
- pass
+        max_sleep_time = 1
+        total_sleep_time = 0
+        sleep_time = 0.001
+        while total_sleep_time < max_sleep_time:
+            try:
+                replace(src, dst)
+                break
+            except Exception as exc:
+                if getattr(exc, 'winerror', None) in access_denied_errors:
+                    time.sleep(sleep_time)
+                    total_sleep_time += sleep_time
+                    sleep_time *= 2
+                else:
+                    raise
+        else:
+            raise
else:
- from os import replace as concurrency_safe_rename
+ from os import replace as concurrency_safe_rename # noqa
diff --git a/joblib/compressor.py b/joblib/compressor.py
index 7a72b91..0d9e261 100644
--- a/joblib/compressor.py
+++ b/joblib/compressor.py
@@ -1,38 +1,49 @@
"""Classes and functions for managing compressors."""
+
import io
import zlib
from joblib.backports import LooseVersion
+
try:
from threading import RLock
except ImportError:
from dummy_threading import RLock
+
try:
import bz2
except ImportError:
bz2 = None
+
try:
import lz4
from lz4.frame import LZ4FrameFile
except ImportError:
lz4 = None
+
try:
import lzma
except ImportError:
lzma = None
-LZ4_NOT_INSTALLED_ERROR = (
- 'LZ4 is not installed. Install it with pip: https://python-lz4.readthedocs.io/'
- )
+
+
+LZ4_NOT_INSTALLED_ERROR = ('LZ4 is not installed. Install it with pip: '
+ 'https://python-lz4.readthedocs.io/')
+
+# Registered compressors
_COMPRESSORS = {}
-_ZFILE_PREFIX = b'ZF'
-_ZLIB_PREFIX = b'x'
+
+# Magic numbers of supported compression file formats.
+_ZFILE_PREFIX = b'ZF' # used with pickle files created before 0.9.3.
+_ZLIB_PREFIX = b'\x78'
_GZIP_PREFIX = b'\x1f\x8b'
_BZ2_PREFIX = b'BZ'
-_XZ_PREFIX = b'\xfd7zXZ'
-_LZMA_PREFIX = b']\x00'
-_LZ4_PREFIX = b'\x04"M\x18'
+_XZ_PREFIX = b'\xfd\x37\x7a\x58\x5a'
+_LZMA_PREFIX = b'\x5d\x00'
+_LZ4_PREFIX = b'\x04\x22\x4D\x18'
-def register_compressor(compressor_name, compressor, force=False):
+def register_compressor(compressor_name, compressor,
+ force=False):
"""Register a new compressor.
Parameters
@@ -42,10 +53,32 @@ def register_compressor(compressor_name, compressor, force=False):
compressor: CompressorWrapper
An instance of a 'CompressorWrapper'.
"""
- pass
+ global _COMPRESSORS
+ if not isinstance(compressor_name, str):
+ raise ValueError("Compressor name should be a string, "
+ "'{}' given.".format(compressor_name))
+
+ if not isinstance(compressor, CompressorWrapper):
+ raise ValueError("Compressor should implement the CompressorWrapper "
+ "interface, '{}' given.".format(compressor))
+
+ if (compressor.fileobj_factory is not None and
+ (not hasattr(compressor.fileobj_factory, 'read') or
+ not hasattr(compressor.fileobj_factory, 'write') or
+ not hasattr(compressor.fileobj_factory, 'seek') or
+ not hasattr(compressor.fileobj_factory, 'tell'))):
+ raise ValueError("Compressor 'fileobj_factory' attribute should "
+ "implement the file object interface, '{}' given."
+ .format(compressor.fileobj_factory))
+
+ if compressor_name in _COMPRESSORS and not force:
+ raise ValueError("Compressor '{}' already registered."
+ .format(compressor_name))
+ _COMPRESSORS[compressor_name] = compressor
-class CompressorWrapper:
+
+class CompressorWrapper():
"""A wrapper around a compressor file object.
Attributes
@@ -68,14 +101,19 @@ class CompressorWrapper:
def compressor_file(self, fileobj, compresslevel=None):
"""Returns an instance of a compressor file object."""
- pass
+ if compresslevel is None:
+ return self.fileobj_factory(fileobj, 'wb')
+ else:
+ return self.fileobj_factory(fileobj, 'wb',
+ compresslevel=compresslevel)
def decompressor_file(self, fileobj):
"""Returns an instance of a decompressor file object."""
- pass
+ return self.fileobj_factory(fileobj, 'rb')
class BZ2CompressorWrapper(CompressorWrapper):
+
prefix = _BZ2_PREFIX
extension = '.bz2'
@@ -85,16 +123,29 @@ class BZ2CompressorWrapper(CompressorWrapper):
else:
self.fileobj_factory = None
+ def _check_versions(self):
+ if bz2 is None:
+ raise ValueError('bz2 module is not compiled on your python '
+ 'standard library.')
+
def compressor_file(self, fileobj, compresslevel=None):
"""Returns an instance of a compressor file object."""
- pass
+ self._check_versions()
+ if compresslevel is None:
+ return self.fileobj_factory(fileobj, 'wb')
+ else:
+ return self.fileobj_factory(fileobj, 'wb',
+ compresslevel=compresslevel)
def decompressor_file(self, fileobj):
"""Returns an instance of a decompressor file object."""
- pass
+ self._check_versions()
+ fileobj = self.fileobj_factory(fileobj, 'rb')
+ return fileobj
class LZMACompressorWrapper(CompressorWrapper):
+
prefix = _LZMA_PREFIX
extension = '.lzma'
_lzma_format_name = 'FORMAT_ALONE'
@@ -106,22 +157,35 @@ class LZMACompressorWrapper(CompressorWrapper):
else:
self.fileobj_factory = None
+ def _check_versions(self):
+ if lzma is None:
+ raise ValueError('lzma module is not compiled on your python '
+ 'standard library.')
+
def compressor_file(self, fileobj, compresslevel=None):
"""Returns an instance of a compressor file object."""
- pass
+ if compresslevel is None:
+ return self.fileobj_factory(fileobj, 'wb',
+ format=self._lzma_format)
+ else:
+ return self.fileobj_factory(fileobj, 'wb',
+ format=self._lzma_format,
+ preset=compresslevel)
def decompressor_file(self, fileobj):
"""Returns an instance of a decompressor file object."""
- pass
+ return lzma.LZMAFile(fileobj, 'rb')
class XZCompressorWrapper(LZMACompressorWrapper):
+
prefix = _XZ_PREFIX
extension = '.xz'
_lzma_format_name = 'FORMAT_XZ'
class LZ4CompressorWrapper(CompressorWrapper):
+
prefix = _LZ4_PREFIX
extension = '.lz4'
@@ -131,15 +195,32 @@ class LZ4CompressorWrapper(CompressorWrapper):
else:
self.fileobj_factory = None
+ def _check_versions(self):
+ if lz4 is None:
+ raise ValueError(LZ4_NOT_INSTALLED_ERROR)
+ lz4_version = lz4.__version__
+ if lz4_version.startswith("v"):
+ lz4_version = lz4_version[1:]
+ if LooseVersion(lz4_version) < LooseVersion('0.19'):
+ raise ValueError(LZ4_NOT_INSTALLED_ERROR)
+
def compressor_file(self, fileobj, compresslevel=None):
"""Returns an instance of a compressor file object."""
- pass
+ self._check_versions()
+ if compresslevel is None:
+ return self.fileobj_factory(fileobj, 'wb')
+ else:
+ return self.fileobj_factory(fileobj, 'wb',
+ compression_level=compresslevel)
def decompressor_file(self, fileobj):
"""Returns an instance of a decompressor file object."""
- pass
+ self._check_versions()
+ return self.fileobj_factory(fileobj, 'rb')
+###############################################################################
+# base file compression/decompression object definition
_MODE_CLOSED = 0
_MODE_READ = 1
_MODE_READ_EOF = 2
@@ -170,9 +251,12 @@ class BinaryZlibFile(io.BufferedIOBase):
and 9 specifying the level of compression: 1 produces the least
compression, and 9 produces the most compression. 3 is the default.
"""
+
wbits = zlib.MAX_WBITS
- def __init__(self, filename, mode='rb', compresslevel=3):
+ def __init__(self, filename, mode="rb", compresslevel=3):
+ # This lock must be recursive, so that BufferedIOBase's
+ # readline(), readlines() and writelines() don't deadlock.
self._lock = RLock()
self._fp = None
self._closefp = False
@@ -180,29 +264,33 @@ class BinaryZlibFile(io.BufferedIOBase):
self._pos = 0
self._size = -1
self.compresslevel = compresslevel
- if not isinstance(compresslevel, int) or not 1 <= compresslevel <= 9:
- raise ValueError(
- "'compresslevel' must be an integer between 1 and 9. You provided 'compresslevel={}'"
- .format(compresslevel))
- if mode == 'rb':
+
+ if not isinstance(compresslevel, int) or not (1 <= compresslevel <= 9):
+ raise ValueError("'compresslevel' must be an integer "
+ "between 1 and 9. You provided 'compresslevel={}'"
+ .format(compresslevel))
+
+ if mode == "rb":
self._mode = _MODE_READ
self._decompressor = zlib.decompressobj(self.wbits)
- self._buffer = b''
+ self._buffer = b""
self._buffer_offset = 0
- elif mode == 'wb':
+ elif mode == "wb":
self._mode = _MODE_WRITE
- self._compressor = zlib.compressobj(self.compresslevel, zlib.
- DEFLATED, self.wbits, zlib.DEF_MEM_LEVEL, 0)
+ self._compressor = zlib.compressobj(self.compresslevel,
+ zlib.DEFLATED, self.wbits,
+ zlib.DEF_MEM_LEVEL, 0)
else:
- raise ValueError('Invalid mode: %r' % (mode,))
+ raise ValueError("Invalid mode: %r" % (mode,))
+
if isinstance(filename, str):
self._fp = io.open(filename, mode)
self._closefp = True
- elif hasattr(filename, 'read') or hasattr(filename, 'write'):
+ elif hasattr(filename, "read") or hasattr(filename, "write"):
self._fp = filename
else:
- raise TypeError('filename must be a str or bytes object, or a file'
- )
+ raise TypeError("filename must be a str or bytes object, "
+ "or a file")
def close(self):
"""Flush and close the file.
@@ -210,28 +298,147 @@ class BinaryZlibFile(io.BufferedIOBase):
May be called more than once without error. Once the file is
closed, any other operation on it will raise a ValueError.
"""
- pass
+ with self._lock:
+ if self._mode == _MODE_CLOSED:
+ return
+ try:
+ if self._mode in (_MODE_READ, _MODE_READ_EOF):
+ self._decompressor = None
+ elif self._mode == _MODE_WRITE:
+ self._fp.write(self._compressor.flush())
+ self._compressor = None
+ finally:
+ try:
+ if self._closefp:
+ self._fp.close()
+ finally:
+ self._fp = None
+ self._closefp = False
+ self._mode = _MODE_CLOSED
+ self._buffer = b""
+ self._buffer_offset = 0
@property
def closed(self):
"""True if this file is closed."""
- pass
+ return self._mode == _MODE_CLOSED
def fileno(self):
"""Return the file descriptor for the underlying file."""
- pass
+ self._check_not_closed()
+ return self._fp.fileno()
def seekable(self):
"""Return whether the file supports seeking."""
- pass
+ return self.readable() and self._fp.seekable()
def readable(self):
"""Return whether the file was opened for reading."""
- pass
+ self._check_not_closed()
+ return self._mode in (_MODE_READ, _MODE_READ_EOF)
def writable(self):
"""Return whether the file was opened for writing."""
- pass
+ self._check_not_closed()
+ return self._mode == _MODE_WRITE
+
+ # Mode-checking helper functions.
+
+ def _check_not_closed(self):
+ if self.closed:
+ fname = getattr(self._fp, 'name', None)
+ msg = "I/O operation on closed file"
+ if fname is not None:
+ msg += " {}".format(fname)
+ msg += "."
+ raise ValueError(msg)
+
+ def _check_can_read(self):
+ if self._mode not in (_MODE_READ, _MODE_READ_EOF):
+ self._check_not_closed()
+ raise io.UnsupportedOperation("File not open for reading")
+
+ def _check_can_write(self):
+ if self._mode != _MODE_WRITE:
+ self._check_not_closed()
+ raise io.UnsupportedOperation("File not open for writing")
+
+ def _check_can_seek(self):
+ if self._mode not in (_MODE_READ, _MODE_READ_EOF):
+ self._check_not_closed()
+ raise io.UnsupportedOperation("Seeking is only supported "
+ "on files open for reading")
+ if not self._fp.seekable():
+ raise io.UnsupportedOperation("The underlying file object "
+ "does not support seeking")
+
+ # Fill the readahead buffer if it is empty. Returns False on EOF.
+ def _fill_buffer(self):
+ if self._mode == _MODE_READ_EOF:
+ return False
+ # Depending on the input data, our call to the decompressor may not
+ # return any data. In this case, try again after reading another block.
+ while self._buffer_offset == len(self._buffer):
+ try:
+ rawblock = (self._decompressor.unused_data or
+ self._fp.read(_BUFFER_SIZE))
+ if not rawblock:
+ raise EOFError
+ except EOFError:
+ # End-of-stream marker and end of file. We're good.
+ self._mode = _MODE_READ_EOF
+ self._size = self._pos
+ return False
+ else:
+ self._buffer = self._decompressor.decompress(rawblock)
+ self._buffer_offset = 0
+ return True
+
+ # Read data until EOF.
+ # If return_data is false, consume the data without returning it.
+ def _read_all(self, return_data=True):
+ # The loop assumes that _buffer_offset is 0. Ensure that this is true.
+ self._buffer = self._buffer[self._buffer_offset:]
+ self._buffer_offset = 0
+
+ blocks = []
+ while self._fill_buffer():
+ if return_data:
+ blocks.append(self._buffer)
+ self._pos += len(self._buffer)
+ self._buffer = b""
+ if return_data:
+ return b"".join(blocks)
+
+ # Read a block of up to n bytes.
+ # If return_data is false, consume the data without returning it.
+ def _read_block(self, n_bytes, return_data=True):
+ # If we have enough data buffered, return immediately.
+ end = self._buffer_offset + n_bytes
+ if end <= len(self._buffer):
+ data = self._buffer[self._buffer_offset: end]
+ self._buffer_offset = end
+ self._pos += len(data)
+ return data if return_data else None
+
+ # The loop assumes that _buffer_offset is 0. Ensure that this is true.
+ self._buffer = self._buffer[self._buffer_offset:]
+ self._buffer_offset = 0
+
+ blocks = []
+ while n_bytes > 0 and self._fill_buffer():
+ if n_bytes < len(self._buffer):
+ data = self._buffer[:n_bytes]
+ self._buffer_offset = n_bytes
+ else:
+ data = self._buffer
+ self._buffer = b""
+ if return_data:
+ blocks.append(data)
+ self._pos += len(data)
+ n_bytes -= len(data)
+ if return_data:
+ return b"".join(blocks)
def read(self, size=-1):
"""Read up to size uncompressed bytes from the file.
@@ -239,14 +446,22 @@ class BinaryZlibFile(io.BufferedIOBase):
If size is negative or omitted, read until EOF is reached.
Returns b'' if the file is already at EOF.
"""
- pass
+ with self._lock:
+ self._check_can_read()
+ if size == 0:
+ return b""
+ elif size < 0:
+ return self._read_all()
+ else:
+ return self._read_block(size)
def readinto(self, b):
"""Read up to len(b) bytes into b.
Returns the number of bytes read (0 for EOF).
"""
- pass
+ with self._lock:
+ return io.BufferedIOBase.readinto(self, b)
def write(self, data):
"""Write a byte string to the file.
@@ -255,7 +470,25 @@ class BinaryZlibFile(io.BufferedIOBase):
always len(data). Note that due to buffering, the file on disk
may not reflect the data written until close() is called.
"""
- pass
+ with self._lock:
+ self._check_can_write()
+ # Convert data type if called by io.BufferedWriter.
+ if isinstance(data, memoryview):
+ data = data.tobytes()
+
+ compressed = self._compressor.compress(data)
+ self._fp.write(compressed)
+ self._pos += len(data)
+ return len(data)
+
+ # Rewind the file to the beginning of the data stream.
+ def _rewind(self):
+ self._fp.seek(0, 0)
+ self._mode = _MODE_READ
+ self._pos = 0
+ self._decompressor = zlib.decompressobj(self.wbits)
+ self._buffer = b""
+ self._buffer_offset = 0
def seek(self, offset, whence=0):
"""Change the file position.
@@ -272,18 +505,45 @@ class BinaryZlibFile(io.BufferedIOBase):
Note that seeking is emulated, so depending on the parameters,
this operation may be extremely slow.
"""
- pass
+ with self._lock:
+ self._check_can_seek()
+
+ # Recalculate offset as an absolute file position.
+ if whence == 0:
+ pass
+ elif whence == 1:
+ offset = self._pos + offset
+ elif whence == 2:
+ # Seeking relative to EOF - we need to know the file's size.
+ if self._size < 0:
+ self._read_all(return_data=False)
+ offset = self._size + offset
+ else:
+ raise ValueError("Invalid value for whence: %s" % (whence,))
+
+ # Make it so that offset is the number of bytes to skip forward.
+ if offset < self._pos:
+ self._rewind()
+ else:
+ offset -= self._pos
+
+ # Read and discard data until we reach the desired position.
+ self._read_block(offset, return_data=False)
+
+ return self._pos
def tell(self):
"""Return the current file position."""
- pass
+ with self._lock:
+ self._check_not_closed()
+ return self._pos
class ZlibCompressorWrapper(CompressorWrapper):
def __init__(self):
- CompressorWrapper.__init__(self, obj=BinaryZlibFile, prefix=
- _ZLIB_PREFIX, extension='.z')
+ CompressorWrapper.__init__(self, obj=BinaryZlibFile,
+ prefix=_ZLIB_PREFIX, extension='.z')
class BinaryGzipFile(BinaryZlibFile):
@@ -299,11 +559,12 @@ class BinaryGzipFile(BinaryZlibFile):
and 9 specifying the level of compression: 1 produces the least
compression, and 9 produces the most compression. 3 is the default.
"""
- wbits = 31
+
+ wbits = 31 # zlib compressor/decompressor wbits value for gzip format.
class GzipCompressorWrapper(CompressorWrapper):
def __init__(self):
- CompressorWrapper.__init__(self, obj=BinaryGzipFile, prefix=
- _GZIP_PREFIX, extension='.gz')
+ CompressorWrapper.__init__(self, obj=BinaryGzipFile,
+ prefix=_GZIP_PREFIX, extension='.gz')
diff --git a/joblib/disk.py b/joblib/disk.py
index b35e507..32fbb89 100644
--- a/joblib/disk.py
+++ b/joblib/disk.py
@@ -1,12 +1,22 @@
"""
Disk management utilities.
"""
+
+# Authors: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
+# Lars Buitinck
+# Copyright (c) 2010 Gael Varoquaux
+# License: BSD Style, 3 clauses.
+
+
import os
import sys
import time
import errno
import shutil
+
from multiprocessing import util
+
+
try:
WindowsError
except NameError:
@@ -15,22 +25,49 @@ except NameError:
def disk_used(path):
""" Return the disk usage in a directory."""
- pass
+ size = 0
+ for file in os.listdir(path) + ['.']:
+ stat = os.stat(os.path.join(path, file))
+ if hasattr(stat, 'st_blocks'):
+ size += stat.st_blocks * 512
+ else:
+ # on some platform st_blocks is not available (e.g., Windows)
+ # approximate by rounding to next multiple of 512
+ size += (stat.st_size // 512 + 1) * 512
+ # We need to convert to int to avoid having longs on some systems (we
+    # don't want longs to avoid problems with SQLite)
+ return int(size / 1024.)
def memstr_to_bytes(text):
""" Convert a memory text to its value in bytes.
"""
- pass
+ kilo = 1024
+ units = dict(K=kilo, M=kilo ** 2, G=kilo ** 3)
+ try:
+ size = int(units[text[-1]] * float(text[:-1]))
+ except (KeyError, ValueError) as e:
+ raise ValueError(
+ "Invalid literal for size give: %s (type %s) should be "
+ "alike '10G', '500M', '50K'." % (text, type(text))) from e
+ return size
def mkdirp(d):
"""Ensure directory d exists (like mkdir -p on Unix)
No guarantee that the directory is writable.
"""
- pass
+ try:
+ os.makedirs(d)
+ except OSError as e:
+ if e.errno != errno.EEXIST:
+ raise
+# if a rmtree operation fails in rm_subdirs, wait for this much time (in secs),
+# then retry up to RM_SUBDIRS_N_RETRY times. If it still fails, raise the
+# exception. this mechanism ensures that the sub-process gc have the time to
+# collect and close the memmaps before we fail.
RM_SUBDIRS_RETRY_TIME = 0.1
RM_SUBDIRS_N_RETRY = 10
@@ -47,9 +84,53 @@ def rm_subdirs(path, onerror=None):
exc_info is a tuple returned by sys.exc_info(). If onerror is None,
an exception is raised.
"""
- pass
+
+ # NOTE this code is adapted from the one in shutil.rmtree, and is
+ # just as fast
+
+ names = []
+ try:
+ names = os.listdir(path)
+ except os.error:
+ if onerror is not None:
+ onerror(os.listdir, path, sys.exc_info())
+ else:
+ raise
+
+ for name in names:
+ fullname = os.path.join(path, name)
+ delete_folder(fullname, onerror=onerror)
def delete_folder(folder_path, onerror=None, allow_non_empty=True):
"""Utility function to cleanup a temporary folder if it still exists."""
- pass
+ if os.path.isdir(folder_path):
+ if onerror is not None:
+ shutil.rmtree(folder_path, False, onerror)
+ else:
+ # allow the rmtree to fail once, wait and re-try.
+ # if the error is raised again, fail
+ err_count = 0
+ while True:
+ files = os.listdir(folder_path)
+ try:
+ if len(files) == 0 or allow_non_empty:
+ shutil.rmtree(
+ folder_path, ignore_errors=False, onerror=None
+ )
+ util.debug(
+ "Successfully deleted {}".format(folder_path))
+ break
+ else:
+ raise OSError(
+ "Expected empty folder {} but got {} "
+ "files.".format(folder_path, len(files))
+ )
+ except (OSError, WindowsError):
+ err_count += 1
+ if err_count > RM_SUBDIRS_N_RETRY:
+ # the folder cannot be deleted right now. It maybe
+ # because some temporary files have not been deleted
+ # yet.
+ raise
+ time.sleep(RM_SUBDIRS_RETRY_TIME)
diff --git a/joblib/executor.py b/joblib/executor.py
index 6eea29c..6837a7d 100644
--- a/joblib/executor.py
+++ b/joblib/executor.py
@@ -4,22 +4,102 @@ This module provides efficient ways of working with data stored in
shared memory with numpy.memmap arrays without inducing any memory
copy between the parent and child processes.
"""
+# Author: Thomas Moreau <thomas.moreau.2010@gmail.com>
+# Copyright: 2017, Thomas Moreau
+# License: BSD 3 clause
+
from ._memmapping_reducer import get_memmapping_reducers
from ._memmapping_reducer import TemporaryResourcesManager
from .externals.loky.reusable_executor import _ReusablePoolExecutor
+
+
_executor_args = None
+def get_memmapping_executor(n_jobs, **kwargs):
+ return MemmappingExecutor.get_memmapping_executor(n_jobs, **kwargs)
+
+
class MemmappingExecutor(_ReusablePoolExecutor):
@classmethod
def get_memmapping_executor(cls, n_jobs, timeout=300, initializer=None,
- initargs=(), env=None, temp_folder=None, context_id=None, **
- backend_args):
+ initargs=(), env=None, temp_folder=None,
+ context_id=None, **backend_args):
"""Factory for ReusableExecutor with automatic memmapping for large
numpy arrays.
"""
- pass
+ global _executor_args
+ # Check if we can reuse the executor here instead of deferring the test
+        # to loky as the reducers are objects that change at each call.
+ executor_args = backend_args.copy()
+ executor_args.update(env if env else {})
+ executor_args.update(dict(
+ timeout=timeout, initializer=initializer, initargs=initargs))
+ reuse = _executor_args is None or _executor_args == executor_args
+ _executor_args = executor_args
+
+ manager = TemporaryResourcesManager(temp_folder)
+
+ # reducers access the temporary folder in which to store temporary
+ # pickles through a call to manager.resolve_temp_folder_name. resolving
+ # the folder name dynamically is useful to use different folders across
+ # calls of a same reusable executor
+ job_reducers, result_reducers = get_memmapping_reducers(
+ unlink_on_gc_collect=True,
+ temp_folder_resolver=manager.resolve_temp_folder_name,
+ **backend_args)
+ _executor, executor_is_reused = super().get_reusable_executor(
+ n_jobs, job_reducers=job_reducers, result_reducers=result_reducers,
+ reuse=reuse, timeout=timeout, initializer=initializer,
+ initargs=initargs, env=env
+ )
+
+ if not executor_is_reused:
+ # Only set a _temp_folder_manager for new executors. Reused
+ # executors already have a _temporary_folder_manager that must not
+ # be re-assigned like that because it is referenced in various
+ # places in the reducing machinery of the executor.
+ _executor._temp_folder_manager = manager
+
+ if context_id is not None:
+ # Only register the specified context once we know which manager
+ # the current executor is using, in order to not register an atexit
+ # finalizer twice for the same folder.
+ _executor._temp_folder_manager.register_new_context(context_id)
+
+ return _executor
+
+ def terminate(self, kill_workers=False):
+
+ self.shutdown(kill_workers=kill_workers)
+
+ # When workers are killed in a brutal manner, they cannot execute the
+ # finalizer of their shared memmaps. The refcount of those memmaps may
+ # be off by an unknown number, so instead of decref'ing them, we force
+ # delete the whole temporary folder, and unregister them. There is no
+ # risk of PermissionError at folder deletion because at this
+ # point, all child processes are dead, so all references to temporary
+ # memmaps are closed. Otherwise, just try to delete as much as possible
+ # with allow_non_empty=True but if we can't, it will be clean up later
+ # on by the resource_tracker.
+ with self._submit_resize_lock:
+ self._temp_folder_manager._clean_temporary_resources(
+ force=kill_workers, allow_non_empty=True
+ )
+
+ @property
+ def _temp_folder(self):
+        # Legacy property in tests. Could be removed if we refactored the
+        # memmapping tests. SHOULD ONLY BE USED IN TESTS!
+        # We cache this property because it is called late in the tests - at
+        # this point, all contexts have been unregistered, and
+        # resolve_temp_folder_name raises an error.
+ if getattr(self, '_cached_temp_folder', None) is not None:
+ return self._cached_temp_folder
+ else:
+ self._cached_temp_folder = self._temp_folder_manager.resolve_temp_folder_name() # noqa
+ return self._cached_temp_folder
class _TestingMemmappingExecutor(MemmappingExecutor):
@@ -27,7 +107,11 @@ class _TestingMemmappingExecutor(MemmappingExecutor):
and Executor. This is only for testing purposes.
"""
-
def apply_async(self, func, args):
"""Schedule a func to be run"""
- pass
+ future = self.submit(func, *args)
+ future.get = future.result
+ return future
+
+ def map(self, f, *args):
+ return list(super().map(f, *args))
diff --git a/joblib/externals/cloudpickle/cloudpickle.py b/joblib/externals/cloudpickle/cloudpickle.py
index 92fb769..eb43a96 100644
--- a/joblib/externals/cloudpickle/cloudpickle.py
+++ b/joblib/externals/cloudpickle/cloudpickle.py
@@ -49,6 +49,7 @@ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
+
import _collections_abc
from collections import ChainMap, OrderedDict
import abc
@@ -72,19 +73,58 @@ import typing
import uuid
import warnings
import weakref
-from types import CellType
+
+# The following import is required to be imported in the cloudpickle
+# namespace to be able to load pickle files generated with older versions of
+# cloudpickle. See: tests/test_backward_compat.py
+from types import CellType # noqa: F401
+
+
+# cloudpickle is meant for inter process communication: we expect all
+# communicating processes to run the same Python version hence we favor
+# communication speed over compatibility:
DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL
+
+# Names of modules whose resources should be treated as dynamic.
_PICKLE_BY_VALUE_MODULES = set()
+
+# Track the provenance of reconstructed dynamic classes to make it possible to
+# reconstruct instances from the matching singleton class definition when
+# appropriate and preserve the usual "isinstance" semantics of Python objects.
_DYNAMIC_CLASS_TRACKER_BY_CLASS = weakref.WeakKeyDictionary()
_DYNAMIC_CLASS_TRACKER_BY_ID = weakref.WeakValueDictionary()
_DYNAMIC_CLASS_TRACKER_LOCK = threading.Lock()
-PYPY = platform.python_implementation() == 'PyPy'
+
+PYPY = platform.python_implementation() == "PyPy"
+
builtin_code_type = None
if PYPY:
+ # builtin-code objects only exist in pypy
builtin_code_type = type(float.__new__.__code__)
+
_extract_code_globals_cache = weakref.WeakKeyDictionary()
+def _get_or_create_tracker_id(class_def):
+ with _DYNAMIC_CLASS_TRACKER_LOCK:
+ class_tracker_id = _DYNAMIC_CLASS_TRACKER_BY_CLASS.get(class_def)
+ if class_tracker_id is None:
+ class_tracker_id = uuid.uuid4().hex
+ _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
+ _DYNAMIC_CLASS_TRACKER_BY_ID[class_tracker_id] = class_def
+ return class_tracker_id
+
+
+def _lookup_class_or_track(class_tracker_id, class_def):
+ if class_tracker_id is not None:
+ with _DYNAMIC_CLASS_TRACKER_LOCK:
+ class_def = _DYNAMIC_CLASS_TRACKER_BY_ID.setdefault(
+ class_tracker_id, class_def
+ )
+ _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
+ return class_def
+
+
def register_pickle_by_value(module):
"""Register a module to make it functions and classes picklable by value.
@@ -104,12 +144,52 @@ def register_pickle_by_value(module):
Note: this feature is considered experimental. See the cloudpickle
README.md file for more details and limitations.
"""
- pass
+ if not isinstance(module, types.ModuleType):
+ raise ValueError(f"Input should be a module object, got {str(module)} instead")
+ # In the future, cloudpickle may need a way to access any module registered
+ # for pickling by value in order to introspect relative imports inside
+ # functions pickled by value. (see
+ # https://github.com/cloudpipe/cloudpickle/pull/417#issuecomment-873684633).
+ # This access can be ensured by checking that module is present in
+ # sys.modules at registering time and assuming that it will still be in
+ # there when accessed during pickling. Another alternative would be to
+ # store a weakref to the module. Even though cloudpickle does not implement
+ # this introspection yet, in order to avoid a possible breaking change
+ # later, we still enforce the presence of module inside sys.modules.
+ if module.__name__ not in sys.modules:
+ raise ValueError(
+ f"{module} was not imported correctly, have you used an "
+ "`import` statement to access it?"
+ )
+ _PICKLE_BY_VALUE_MODULES.add(module.__name__)
def unregister_pickle_by_value(module):
"""Unregister that the input module should be pickled by value."""
- pass
+ if not isinstance(module, types.ModuleType):
+ raise ValueError(f"Input should be a module object, got {str(module)} instead")
+ if module.__name__ not in _PICKLE_BY_VALUE_MODULES:
+ raise ValueError(f"{module} is not registered for pickle by value")
+ else:
+ _PICKLE_BY_VALUE_MODULES.remove(module.__name__)
+
+
+def list_registry_pickle_by_value():
+ return _PICKLE_BY_VALUE_MODULES.copy()
+
+
+def _is_registered_pickle_by_value(module):
+ module_name = module.__name__
+ if module_name in _PICKLE_BY_VALUE_MODULES:
+ return True
+ while True:
+ parent_name = module_name.rsplit(".", 1)[0]
+ if parent_name == module_name:
+ break
+ if parent_name in _PICKLE_BY_VALUE_MODULES:
+ return True
+ module_name = parent_name
+ return False
def _whichmodule(obj, name):
@@ -121,7 +201,28 @@ def _whichmodule(obj, name):
- Errors arising during module introspection are ignored, as those errors
are considered unwanted side effects.
"""
- pass
+ module_name = getattr(obj, "__module__", None)
+
+ if module_name is not None:
+ return module_name
+ # Protect the iteration by using a copy of sys.modules against dynamic
+ # modules that trigger imports of other modules upon calls to getattr or
+ # other threads importing at the same time.
+ for module_name, module in sys.modules.copy().items():
+ # Some modules such as coverage can inject non-module objects inside
+ # sys.modules
+ if (
+ module_name == "__main__"
+ or module is None
+ or not isinstance(module, types.ModuleType)
+ ):
+ continue
+ try:
+ if _getattribute(module, name)[0] is obj:
+ return module_name
+ except Exception:
+ pass
+ return None
def _should_pickle_by_reference(obj, name=None):
@@ -138,12 +239,91 @@ def _should_pickle_by_reference(obj, name=None):
functions and classes or for attributes of modules that have been
explicitly registered to be pickled by value.
"""
- pass
+ if isinstance(obj, types.FunctionType) or issubclass(type(obj), type):
+ module_and_name = _lookup_module_and_qualname(obj, name=name)
+ if module_and_name is None:
+ return False
+ module, name = module_and_name
+ return not _is_registered_pickle_by_value(module)
+
+ elif isinstance(obj, types.ModuleType):
+ # We assume that sys.modules is primarily used as a cache mechanism for
+        # the Python import machinery. Checking if a module has been added to
+        # sys.modules is therefore a cheap and simple heuristic to tell us
+ # whether we can assume that a given module could be imported by name
+ # in another Python process.
+ if _is_registered_pickle_by_value(obj):
+ return False
+ return obj.__name__ in sys.modules
+ else:
+ raise TypeError(
+ "cannot check importability of {} instances".format(type(obj).__name__)
+ )
+
+
+def _lookup_module_and_qualname(obj, name=None):
+ if name is None:
+ name = getattr(obj, "__qualname__", None)
+ if name is None: # pragma: no cover
+ # This used to be needed for Python 2.7 support but is probably not
+ # needed anymore. However we keep the __name__ introspection in case
+ # users of cloudpickle rely on this old behavior for unknown reasons.
+ name = getattr(obj, "__name__", None)
+
+ module_name = _whichmodule(obj, name)
+
+ if module_name is None:
+ # In this case, obj.__module__ is None AND obj was not found in any
+ # imported module. obj is thus treated as dynamic.
+ return None
+
+ if module_name == "__main__":
+ return None
+
+ # Note: if module_name is in sys.modules, the corresponding module is
+ # assumed importable at unpickling time. See #357
+ module = sys.modules.get(module_name, None)
+ if module is None:
+ # The main reason why obj's module would not be imported is that this
+ # module has been dynamically created, using for example
+ # types.ModuleType. The other possibility is that module was removed
+ # from sys.modules after obj was created/imported. But this case is not
+ # supported, as the standard pickle does not support it either.
+ return None
+
+ try:
+ obj2, parent = _getattribute(module, name)
+ except AttributeError:
+ # obj was not found inside the module it points to
+ return None
+ if obj2 is not obj:
+ return None
+ return module, name
def _extract_code_globals(co):
"""Find all globals names read or written to by codeblock co."""
- pass
+ out_names = _extract_code_globals_cache.get(co)
+ if out_names is None:
+ # We use a dict with None values instead of a set to get a
+ # deterministic order and avoid introducing non-deterministic pickle
+ # bytes as a results.
+ out_names = {name: None for name in _walk_global_ops(co)}
+
+ # Declaring a function inside another one using the "def ..." syntax
+        # generates a constant code object corresponding to the one of the
+        # nested function. As the nested function may itself need global
+        # variables, we need to introspect its code, extract its globals (look
+        # for code objects in its co_consts attribute) and add the result to
+ # code_globals
+ if co.co_consts:
+ for const in co.co_consts:
+ if isinstance(const, types.CodeType):
+ out_names.update(_extract_code_globals(const))
+
+ _extract_code_globals_cache[co] = out_names
+
+ return out_names
def _find_imported_submodules(code, top_level_dependencies):
@@ -171,29 +351,82 @@ def _find_imported_submodules(code, top_level_dependencies):
that calling func once depickled does not fail due to concurrent.futures
not being imported
"""
- pass
-
-STORE_GLOBAL = opcode.opmap['STORE_GLOBAL']
-DELETE_GLOBAL = opcode.opmap['DELETE_GLOBAL']
-LOAD_GLOBAL = opcode.opmap['LOAD_GLOBAL']
-GLOBAL_OPS = STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL
+ subimports = []
+ # check if any known dependency is an imported package
+ for x in top_level_dependencies:
+ if (
+ isinstance(x, types.ModuleType)
+ and hasattr(x, "__package__")
+ and x.__package__
+ ):
+ # check if the package has any currently loaded sub-imports
+ prefix = x.__name__ + "."
+ # A concurrent thread could mutate sys.modules,
+ # make sure we iterate over a copy to avoid exceptions
+ for name in list(sys.modules):
+ # Older versions of pytest will add a "None" module to
+ # sys.modules.
+ if name is not None and name.startswith(prefix):
+ # check whether the function can address the sub-module
+ tokens = set(name[len(prefix) :].split("."))
+ if not tokens - set(code.co_names):
+ subimports.append(sys.modules[name])
+ return subimports
+
+
+# relevant opcodes
+STORE_GLOBAL = opcode.opmap["STORE_GLOBAL"]
+DELETE_GLOBAL = opcode.opmap["DELETE_GLOBAL"]
+LOAD_GLOBAL = opcode.opmap["LOAD_GLOBAL"]
+GLOBAL_OPS = (STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL)
HAVE_ARGUMENT = dis.HAVE_ARGUMENT
EXTENDED_ARG = dis.EXTENDED_ARG
+
+
_BUILTIN_TYPE_NAMES = {}
for k, v in types.__dict__.items():
if type(v) is type:
_BUILTIN_TYPE_NAMES[v] = k
+def _builtin_type(name):
+ if name == "ClassType": # pragma: no cover
+ # Backward compat to load pickle files generated with cloudpickle
+ # < 1.3 even if loading pickle files from older versions is not
+ # officially supported.
+ return type
+ return getattr(types, name)
+
+
def _walk_global_ops(code):
"""Yield referenced name for global-referencing instructions in code."""
- pass
+ for instr in dis.get_instructions(code):
+ op = instr.opcode
+ if op in GLOBAL_OPS:
+ yield instr.argval
def _extract_class_dict(cls):
"""Retrieve a copy of the dict of a class without the inherited method."""
- pass
+ clsdict = dict(cls.__dict__) # copy dict proxy to a dict
+ if len(cls.__bases__) == 1:
+ inherited_dict = cls.__bases__[0].__dict__
+ else:
+ inherited_dict = {}
+ for base in reversed(cls.__bases__):
+ inherited_dict.update(base.__dict__)
+ to_remove = []
+ for name, value in clsdict.items():
+ try:
+ base_value = inherited_dict[name]
+ if value is base_value:
+ to_remove.append(name)
+ except KeyError:
+ pass
+ for name in to_remove:
+ clsdict.pop(name)
+ return clsdict
def is_tornado_coroutine(func):
@@ -201,7 +434,43 @@ def is_tornado_coroutine(func):
Running coroutines are not supported.
"""
- pass
+ warnings.warn(
+ "is_tornado_coroutine is deprecated in cloudpickle 3.0 and will be "
+ "removed in cloudpickle 4.0. Use tornado.gen.is_coroutine_function "
+ "directly instead.",
+ category=DeprecationWarning,
+ )
+ if "tornado.gen" not in sys.modules:
+ return False
+ gen = sys.modules["tornado.gen"]
+ if not hasattr(gen, "is_coroutine_function"):
+ # Tornado version is too old
+ return False
+ return gen.is_coroutine_function(func)
+
+
+def subimport(name):
+ # We cannot do simply: `return __import__(name)`: Indeed, if ``name`` is
+ # the name of a submodule, __import__ will return the top-level root module
+ # of this submodule. For instance, __import__('os.path') returns the `os`
+ # module.
+ __import__(name)
+ return sys.modules[name]
+
+
+def dynamic_subimport(name, vars):
+ mod = types.ModuleType(name)
+ mod.__dict__.update(vars)
+ mod.__dict__["__builtins__"] = builtins.__dict__
+ return mod
+
+
+def _get_cell_contents(cell):
+ try:
+ return cell.cell_contents
+ except ValueError:
+ # Handle empty cells explicitly with a sentinel value.
+ return _empty_cell_value
def instance(cls):
@@ -217,7 +486,7 @@ def instance(cls):
instance : cls
A new instance of ``cls``.
"""
- pass
+ return cls()
@instance
@@ -229,8 +498,31 @@ class _empty_cell_value:
return cls.__name__
-def _make_skeleton_class(type_constructor, name, bases, type_kwargs,
- class_tracker_id, extra):
+def _make_function(code, globals, name, argdefs, closure):
+ # Setting __builtins__ in globals is needed for nogil CPython.
+ globals["__builtins__"] = __builtins__
+ return types.FunctionType(code, globals, name, argdefs, closure)
+
+
+def _make_empty_cell():
+ if False:
+ # trick the compiler into creating an empty cell in our lambda
+ cell = None
+ raise AssertionError("this route should not be executed")
+
+ return (lambda: cell).__closure__[0]
+
+
+def _make_cell(value=_empty_cell_value):
+ cell = _make_empty_cell()
+ if value is not _empty_cell_value:
+ cell.cell_contents = value
+ return cell
+
+
+def _make_skeleton_class(
+ type_constructor, name, bases, type_kwargs, class_tracker_id, extra
+):
"""Build dynamic class with an empty __dict__ to be filled once memoized
If class_tracker_id is not None, try to lookup an existing class definition
@@ -241,11 +533,15 @@ def _make_skeleton_class(type_constructor, name, bases, type_kwargs,
The "extra" variable is meant to be a dict (or None) that can be used for
forward compatibility shall the need arise.
"""
- pass
+ skeleton_class = types.new_class(
+ name, bases, {"metaclass": type_constructor}, lambda ns: ns.update(type_kwargs)
+ )
+ return _lookup_class_or_track(class_tracker_id, skeleton_class)
-def _make_skeleton_enum(bases, name, qualname, members, module,
- class_tracker_id, extra):
+def _make_skeleton_enum(
+ bases, name, qualname, members, module, class_tracker_id, extra
+):
"""Build dynamic enum with an empty __dict__ to be filled once memoized
The creation of the enum class is inspired by the code of
@@ -259,22 +555,447 @@ def _make_skeleton_enum(bases, name, qualname, members, module,
The "extra" variable is meant to be a dict (or None) that can be used for
forward compatibility shall the need arise.
"""
- pass
+ # enums always inherit from their base Enum class at the last position in
+ # the list of base classes:
+ enum_base = bases[-1]
+ metacls = enum_base.__class__
+ classdict = metacls.__prepare__(name, bases)
+
+ for member_name, member_value in members.items():
+ classdict[member_name] = member_value
+ enum_class = metacls.__new__(metacls, name, bases, classdict)
+ enum_class.__module__ = module
+ enum_class.__qualname__ = qualname
+
+ return _lookup_class_or_track(class_tracker_id, enum_class)
+
+
+def _make_typevar(name, bound, constraints, covariant, contravariant, class_tracker_id):
+ tv = typing.TypeVar(
+ name,
+ *constraints,
+ bound=bound,
+ covariant=covariant,
+ contravariant=contravariant,
+ )
+ return _lookup_class_or_track(class_tracker_id, tv)
+
+
+def _decompose_typevar(obj):
+ return (
+ obj.__name__,
+ obj.__bound__,
+ obj.__constraints__,
+ obj.__covariant__,
+ obj.__contravariant__,
+ _get_or_create_tracker_id(obj),
+ )
+
+
+def _typevar_reduce(obj):
+ # TypeVar instances require the module information hence why we
+ # are not using the _should_pickle_by_reference directly
+ module_and_name = _lookup_module_and_qualname(obj, name=obj.__name__)
+
+ if module_and_name is None:
+ return (_make_typevar, _decompose_typevar(obj))
+ elif _is_registered_pickle_by_value(module_and_name[0]):
+ return (_make_typevar, _decompose_typevar(obj))
+
+ return (getattr, module_and_name)
+
+
+def _get_bases(typ):
+ if "__orig_bases__" in getattr(typ, "__dict__", {}):
+ # For generic types (see PEP 560)
+ # Note that simply checking `hasattr(typ, '__orig_bases__')` is not
+ # correct. Subclasses of a fully-parameterized generic class does not
+ # have `__orig_bases__` defined, but `hasattr(typ, '__orig_bases__')`
+ # will return True because it's defined in the base class.
+ bases_attr = "__orig_bases__"
+ else:
+ # For regular class objects
+ bases_attr = "__bases__"
+ return getattr(typ, bases_attr)
+
+
+def _make_dict_keys(obj, is_ordered=False):
+ if is_ordered:
+ return OrderedDict.fromkeys(obj).keys()
+ else:
+ return dict.fromkeys(obj).keys()
+
+
+def _make_dict_values(obj, is_ordered=False):
+ if is_ordered:
+ return OrderedDict((i, _) for i, _ in enumerate(obj)).values()
+ else:
+ return {i: _ for i, _ in enumerate(obj)}.values()
+
+
+def _make_dict_items(obj, is_ordered=False):
+ if is_ordered:
+ return OrderedDict(obj).items()
+ else:
+ return obj.items()
+
+
+# COLLECTION OF OBJECTS __getnewargs__-LIKE METHODS
+# -------------------------------------------------
+
+
+def _class_getnewargs(obj):
+ type_kwargs = {}
+ if "__module__" in obj.__dict__:
+ type_kwargs["__module__"] = obj.__module__
+
+ __dict__ = obj.__dict__.get("__dict__", None)
+ if isinstance(__dict__, property):
+ type_kwargs["__dict__"] = __dict__
+
+ return (
+ type(obj),
+ obj.__name__,
+ _get_bases(obj),
+ type_kwargs,
+ _get_or_create_tracker_id(obj),
+ None,
+ )
+
+
+def _enum_getnewargs(obj):
+ members = {e.name: e.value for e in obj}
+ return (
+ obj.__bases__,
+ obj.__name__,
+ obj.__qualname__,
+ members,
+ obj.__module__,
+ _get_or_create_tracker_id(obj),
+ None,
+ )
+
+
+# COLLECTION OF OBJECTS RECONSTRUCTORS
+# ------------------------------------
+def _file_reconstructor(retval):
+ return retval
+
+
+# COLLECTION OF OBJECTS STATE GETTERS
+# -----------------------------------
+
+
+def _function_getstate(func):
+ # - Put func's dynamic attributes (stored in func.__dict__) in state. These
+ # attributes will be restored at unpickling time using
+ # f.__dict__.update(state)
+ # - Put func's members into slotstate. Such attributes will be restored at
+ # unpickling time by iterating over slotstate and calling setattr(func,
+ # slotname, slotvalue)
+ slotstate = {
+ "__name__": func.__name__,
+ "__qualname__": func.__qualname__,
+ "__annotations__": func.__annotations__,
+ "__kwdefaults__": func.__kwdefaults__,
+ "__defaults__": func.__defaults__,
+ "__module__": func.__module__,
+ "__doc__": func.__doc__,
+ "__closure__": func.__closure__,
+ }
+
+ f_globals_ref = _extract_code_globals(func.__code__)
+ f_globals = {k: func.__globals__[k] for k in f_globals_ref if k in func.__globals__}
+
+ if func.__closure__ is not None:
+ closure_values = list(map(_get_cell_contents, func.__closure__))
+ else:
+ closure_values = ()
+
+    # Extract currently-imported submodules used by func. Storing these modules
+    # in the _cloudpickle_submodules attribute of the object's state will
+    # trigger the side effect of importing these modules at unpickling time
+    # (which is necessary for func to work correctly once depickled)
+ slotstate["_cloudpickle_submodules"] = _find_imported_submodules(
+ func.__code__, itertools.chain(f_globals.values(), closure_values)
+ )
+ slotstate["__globals__"] = f_globals
+
+ state = func.__dict__
+ return state, slotstate
+
+
+def _class_getstate(obj):
+ clsdict = _extract_class_dict(obj)
+ clsdict.pop("__weakref__", None)
+
+ if issubclass(type(obj), abc.ABCMeta):
+ # If obj is an instance of an ABCMeta subclass, don't pickle the
+ # cache/negative caches populated during isinstance/issubclass
+ # checks, but pickle the list of registered subclasses of obj.
+ clsdict.pop("_abc_cache", None)
+ clsdict.pop("_abc_negative_cache", None)
+ clsdict.pop("_abc_negative_cache_version", None)
+ registry = clsdict.pop("_abc_registry", None)
+ if registry is None:
+ # The abc caches and registered subclasses of a
+ # class are bundled into the single _abc_impl attribute
+ clsdict.pop("_abc_impl", None)
+ (registry, _, _, _) = abc._get_dump(obj)
+
+ clsdict["_abc_impl"] = [subclass_weakref() for subclass_weakref in registry]
+ else:
+ # In the above if clause, registry is a set of weakrefs -- in
+ # this case, registry is a WeakSet
+ clsdict["_abc_impl"] = [type_ for type_ in registry]
+
+ if "__slots__" in clsdict:
+ # pickle string length optimization: member descriptors of obj are
+ # created automatically from obj's __slots__ attribute, no need to
+ # save them in obj's state
+ if isinstance(obj.__slots__, str):
+ clsdict.pop(obj.__slots__)
+ else:
+ for k in obj.__slots__:
+ clsdict.pop(k, None)
+
+ clsdict.pop("__dict__", None) # unpicklable property object
+
+ return (clsdict, {})
+
+
+def _enum_getstate(obj):
+ clsdict, slotstate = _class_getstate(obj)
+
+ members = {e.name: e.value for e in obj}
+ # Cleanup the clsdict that will be passed to _make_skeleton_enum:
+ # Those attributes are already handled by the metaclass.
+ for attrname in [
+ "_generate_next_value_",
+ "_member_names_",
+ "_member_map_",
+ "_member_type_",
+ "_value2member_map_",
+ ]:
+ clsdict.pop(attrname, None)
+ for member in members:
+ clsdict.pop(member)
+ # Special handling of Enum subclasses
+ return clsdict, slotstate
+
+
+# COLLECTIONS OF OBJECTS REDUCERS
+# -------------------------------
+# A reducer is a function taking a single argument (obj), and that returns a
+# tuple with all the necessary data to re-construct obj. Apart from a few
+# exceptions (list, dict, bytes, int, etc.), a reducer is necessary to
+# correctly pickle an object.
+# While many built-in objects (Exceptions objects, instances of the "object"
+# class, etc), are shipped with their own built-in reducer (invoked using
+# obj.__reduce__), some do not. The following methods were created to "fill
+# these holes".
def _code_reduce(obj):
"""code object reducer."""
- pass
+ # If you are not sure about the order of arguments, take a look at help
+ # of the specific type from types, for example:
+ # >>> from types import CodeType
+ # >>> help(CodeType)
+ if hasattr(obj, "co_exceptiontable"):
+ # Python 3.11 and later: there are some new attributes
+ # related to the enhanced exceptions.
+ args = (
+ obj.co_argcount,
+ obj.co_posonlyargcount,
+ obj.co_kwonlyargcount,
+ obj.co_nlocals,
+ obj.co_stacksize,
+ obj.co_flags,
+ obj.co_code,
+ obj.co_consts,
+ obj.co_names,
+ obj.co_varnames,
+ obj.co_filename,
+ obj.co_name,
+ obj.co_qualname,
+ obj.co_firstlineno,
+ obj.co_linetable,
+ obj.co_exceptiontable,
+ obj.co_freevars,
+ obj.co_cellvars,
+ )
+ elif hasattr(obj, "co_linetable"):
+ # Python 3.10 and later: obj.co_lnotab is deprecated and constructor
+ # expects obj.co_linetable instead.
+ args = (
+ obj.co_argcount,
+ obj.co_posonlyargcount,
+ obj.co_kwonlyargcount,
+ obj.co_nlocals,
+ obj.co_stacksize,
+ obj.co_flags,
+ obj.co_code,
+ obj.co_consts,
+ obj.co_names,
+ obj.co_varnames,
+ obj.co_filename,
+ obj.co_name,
+ obj.co_firstlineno,
+ obj.co_linetable,
+ obj.co_freevars,
+ obj.co_cellvars,
+ )
+ elif hasattr(obj, "co_nmeta"): # pragma: no cover
+ # "nogil" Python: modified attributes from 3.9
+ args = (
+ obj.co_argcount,
+ obj.co_posonlyargcount,
+ obj.co_kwonlyargcount,
+ obj.co_nlocals,
+ obj.co_framesize,
+ obj.co_ndefaultargs,
+ obj.co_nmeta,
+ obj.co_flags,
+ obj.co_code,
+ obj.co_consts,
+ obj.co_varnames,
+ obj.co_filename,
+ obj.co_name,
+ obj.co_firstlineno,
+ obj.co_lnotab,
+ obj.co_exc_handlers,
+ obj.co_jump_table,
+ obj.co_freevars,
+ obj.co_cellvars,
+ obj.co_free2reg,
+ obj.co_cell2reg,
+ )
+ else:
+ # Backward compat for 3.8 and 3.9
+ args = (
+ obj.co_argcount,
+ obj.co_posonlyargcount,
+ obj.co_kwonlyargcount,
+ obj.co_nlocals,
+ obj.co_stacksize,
+ obj.co_flags,
+ obj.co_code,
+ obj.co_consts,
+ obj.co_names,
+ obj.co_varnames,
+ obj.co_filename,
+ obj.co_name,
+ obj.co_firstlineno,
+ obj.co_lnotab,
+ obj.co_freevars,
+ obj.co_cellvars,
+ )
+ return types.CodeType, args
def _cell_reduce(obj):
"""Cell (containing values of a function's free variables) reducer."""
- pass
+ try:
+ obj.cell_contents
+ except ValueError: # cell is empty
+ return _make_empty_cell, ()
+ else:
+ return _make_cell, (obj.cell_contents,)
+
+
+def _classmethod_reduce(obj):
+ orig_func = obj.__func__
+ return type(obj), (orig_func,)
def _file_reduce(obj):
"""Save a file."""
- pass
+ import io
+
+ if not hasattr(obj, "name") or not hasattr(obj, "mode"):
+ raise pickle.PicklingError(
+ "Cannot pickle files that do not map to an actual file"
+ )
+ if obj is sys.stdout:
+ return getattr, (sys, "stdout")
+ if obj is sys.stderr:
+ return getattr, (sys, "stderr")
+ if obj is sys.stdin:
+ raise pickle.PicklingError("Cannot pickle standard input")
+ if obj.closed:
+ raise pickle.PicklingError("Cannot pickle closed files")
+ if hasattr(obj, "isatty") and obj.isatty():
+ raise pickle.PicklingError("Cannot pickle files that map to tty objects")
+ if "r" not in obj.mode and "+" not in obj.mode:
+ raise pickle.PicklingError(
+ "Cannot pickle files that are not opened for reading: %s" % obj.mode
+ )
+
+ name = obj.name
+
+ retval = io.StringIO()
+
+ try:
+ # Read the whole file
+ curloc = obj.tell()
+ obj.seek(0)
+ contents = obj.read()
+ obj.seek(curloc)
+ except OSError as e:
+ raise pickle.PicklingError(
+ "Cannot pickle file %s as it cannot be read" % name
+ ) from e
+ retval.write(contents)
+ retval.seek(curloc)
+
+ retval.name = name
+ return _file_reconstructor, (retval,)
+
+
+def _getset_descriptor_reduce(obj):
+ return getattr, (obj.__objclass__, obj.__name__)
+
+
+def _mappingproxy_reduce(obj):
+ return types.MappingProxyType, (dict(obj),)
+
+
+def _memoryview_reduce(obj):
+ return bytes, (obj.tobytes(),)
+
+
+def _module_reduce(obj):
+ if _should_pickle_by_reference(obj):
+ return subimport, (obj.__name__,)
+ else:
+ # Some external libraries can populate the "__builtins__" entry of a
+ # module's `__dict__` with unpicklable objects (see #316). For that
+ # reason, we do not attempt to pickle the "__builtins__" entry, and
+ # restore a default value for it at unpickling time.
+ state = obj.__dict__.copy()
+ state.pop("__builtins__", None)
+ return dynamic_subimport, (obj.__name__, state)
+
+
+def _method_reduce(obj):
+ return (types.MethodType, (obj.__func__, obj.__self__))
+
+
+def _logger_reduce(obj):
+ return logging.getLogger, (obj.name,)
+
+
+def _root_logger_reduce(obj):
+ return logging.getLogger, ()
+
+
+def _property_reduce(obj):
+ return property, (obj.fget, obj.fset, obj.fdel, obj.__doc__)
+
+
+def _weakset_reduce(obj):
+ return weakref.WeakSet, (list(obj),)
def _dynamic_class_reduce(obj):
    """Save interactively-defined or locally-defined classes by value.

    Used for classes that cannot be serialized as attribute lookups from
    importable modules (e.g. defined in a REPL or inside a function).
    """
    # Enums need a dedicated skeleton factory; every other dynamic class goes
    # through the generic skeleton factory. Both share _class_setstate.
    if Enum is not None and issubclass(obj, Enum):
        factory, newargs, state = (
            _make_skeleton_enum,
            _enum_getnewargs(obj),
            _enum_getstate(obj),
        )
    else:
        factory, newargs, state = (
            _make_skeleton_class,
            _class_getnewargs(obj),
            _class_getstate(obj),
        )
    return factory, newargs, state, None, None, _class_setstate
def _class_reduce(obj):
"""Select the reducer depending on the dynamic nature of the class obj."""
- pass
+ if obj is type(None): # noqa
+ return type, (None,)
+ elif obj is type(Ellipsis):
+ return type, (Ellipsis,)
+ elif obj is type(NotImplemented):
+ return type, (NotImplemented,)
+ elif obj in _BUILTIN_TYPE_NAMES:
+ return _builtin_type, (_BUILTIN_TYPE_NAMES[obj],)
+ elif not _should_pickle_by_reference(obj):
+ return _dynamic_class_reduce(obj)
+ return NotImplemented
+
+
def _dict_keys_reduce(obj):
    """Rebuild a dict_keys view from its keys alone.

    The backing dict is deliberately not shipped: sending the rest of it
    might be unintended and could leak sensitive information.
    """
    return _make_dict_keys, (list(obj),)


def _dict_values_reduce(obj):
    """Rebuild a dict_values view from its values alone (same leak-avoidance
    rationale as _dict_keys_reduce)."""
    return _make_dict_values, (list(obj),)


def _dict_items_reduce(obj):
    """Rebuild a dict_items view; item pairs require the full mapping."""
    return _make_dict_items, (dict(obj),)


def _odict_keys_reduce(obj):
    """OrderedDict variant of _dict_keys_reduce (True flags ordered rebuild)."""
    return _make_dict_keys, (list(obj), True)


def _odict_values_reduce(obj):
    """OrderedDict variant of _dict_values_reduce."""
    return _make_dict_values, (list(obj), True)


def _odict_items_reduce(obj):
    """OrderedDict variant of _dict_items_reduce."""
    return _make_dict_items, (dict(obj), True)


def _dataclass_field_base_reduce(obj):
    """Pickle a dataclasses field-type sentinel by its unique name."""
    return _get_dataclass_field_type_sentinel, (obj.name,)
+
+
+# COLLECTIONS OF OBJECTS STATE SETTERS
+# ------------------------------------
+# state setters are called at unpickling time, once the object is created and
+# it has to be updated to how it was at unpickling time.
def _function_setstate(obj, state):
@@ -299,15 +1093,68 @@ def _function_setstate(obj, state):
cannot rely on the native setstate routine of pickle.load_build, that calls
setattr on items of the slotstate. Instead, we have to modify them inplace.
"""
- pass
+ state, slotstate = state
+ obj.__dict__.update(state)
+
+ obj_globals = slotstate.pop("__globals__")
+ obj_closure = slotstate.pop("__closure__")
+ # _cloudpickle_subimports is a set of submodules that must be loaded for
+ # the pickled function to work correctly at unpickling time. Now that these
+ # submodules are depickled (hence imported), they can be removed from the
+ # object's state (the object state only served as a reference holder to
+ # these submodules)
+ slotstate.pop("_cloudpickle_submodules")
+
+ obj.__globals__.update(obj_globals)
+ obj.__globals__["__builtins__"] = __builtins__
+
+ if obj_closure is not None:
+ for i, cell in enumerate(obj_closure):
+ try:
+ value = cell.cell_contents
+ except ValueError: # cell is empty
+ continue
+ obj.__closure__[i].cell_contents = value
+
+ for k, v in slotstate.items():
+ setattr(obj, k, v)
+
+
+def _class_setstate(obj, state):
+ state, slotstate = state
+ registry = None
+ for attrname, attr in state.items():
+ if attrname == "_abc_impl":
+ registry = attr
+ else:
+ setattr(obj, attrname, attr)
+ if registry is not None:
+ for subclass in registry:
+ obj.register(subclass)
+ return obj
-_DATACLASSE_FIELD_TYPE_SENTINELS = {dataclasses._FIELD.name: dataclasses.
- _FIELD, dataclasses._FIELD_CLASSVAR.name: dataclasses._FIELD_CLASSVAR,
- dataclasses._FIELD_INITVAR.name: dataclasses._FIELD_INITVAR}
+
+# COLLECTION OF DATACLASS UTILITIES
+# ---------------------------------
+# There are some internal sentinel values whose identity must be preserved when
+# unpickling dataclass fields. Each sentinel value has a unique name that we can
+# use to retrieve its identity at unpickling time.
+
+
+_DATACLASSE_FIELD_TYPE_SENTINELS = {
+ dataclasses._FIELD.name: dataclasses._FIELD,
+ dataclasses._FIELD_CLASSVAR.name: dataclasses._FIELD_CLASSVAR,
+ dataclasses._FIELD_INITVAR.name: dataclasses._FIELD_INITVAR,
+}
+
+
+def _get_dataclass_field_type_sentinel(name):
+ return _DATACLASSE_FIELD_TYPE_SENTINELS[name]
class Pickler(pickle.Pickler):
+ # set of reducers defined and used by cloudpickle (private)
_dispatch_table = {}
_dispatch_table[classmethod] = _classmethod_reduce
_dispatch_table[io.TextIOWrapper] = _file_reduce
@@ -335,11 +1182,16 @@ class Pickler(pickle.Pickler):
_dispatch_table[abc.abstractstaticmethod] = _classmethod_reduce
_dispatch_table[abc.abstractproperty] = _property_reduce
_dispatch_table[dataclasses._FIELD_BASE] = _dataclass_field_base_reduce
+
dispatch_table = ChainMap(_dispatch_table, copyreg.dispatch_table)
+ # function reducers are defined as instance methods of cloudpickle.Pickler
+ # objects, as they rely on a cloudpickle.Pickler attribute (globals_ref)
def _dynamic_function_reduce(self, func):
"""Reduce a function that is not pickleable via attribute lookup."""
- pass
+ newargs = self._function_getnewargs(func)
+ state = _function_getstate(func)
+ return (_make_function, newargs, state, None, None, _function_setstate)
def _function_reduce(self, obj):
"""Reducer for function objects.
@@ -350,18 +1202,91 @@ class Pickler(pickle.Pickler):
obj using a custom cloudpickle reducer designed specifically to handle
dynamic functions.
"""
- pass
+ if _should_pickle_by_reference(obj):
+ return NotImplemented
+ else:
+ return self._dynamic_function_reduce(obj)
+
+ def _function_getnewargs(self, func):
+ code = func.__code__
+
+ # base_globals represents the future global namespace of func at
+ # unpickling time. Looking it up and storing it in
+ # cloudpickle.Pickler.globals_ref allow functions sharing the same
+ # globals at pickling time to also share them once unpickled, at one
+ # condition: since globals_ref is an attribute of a cloudpickle.Pickler
+ # instance, and that a new cloudpickle.Pickler is created each time
+ # cloudpickle.dump or cloudpickle.dumps is called, functions also need
+ # to be saved within the same invocation of
+ # cloudpickle.dump/cloudpickle.dumps (for example:
+ # cloudpickle.dumps([f1, f2])). There is no such limitation when using
+ # cloudpickle.Pickler.dump, as long as the multiple invocations are
+ # bound to the same cloudpickle.Pickler instance.
+ base_globals = self.globals_ref.setdefault(id(func.__globals__), {})
+
+ if base_globals == {}:
+ # Add module attributes used to resolve relative imports
+ # instructions inside func.
+ for k in ["__package__", "__name__", "__path__", "__file__"]:
+ if k in func.__globals__:
+ base_globals[k] = func.__globals__[k]
+
+ # Do not bind the free variables before the function is created to
+ # avoid infinite recursion.
+ if func.__closure__ is None:
+ closure = None
+ else:
+ closure = tuple(_make_empty_cell() for _ in range(len(code.co_freevars)))
+
+ return code, base_globals, None, None, closure
+
+ def dump(self, obj):
+ try:
+ return super().dump(obj)
+ except RuntimeError as e:
+ if len(e.args) > 0 and "recursion" in e.args[0]:
+ msg = "Could not pickle object as excessively deep recursion required."
+ raise pickle.PicklingError(msg) from e
+ else:
+ raise
def __init__(self, file, protocol=None, buffer_callback=None):
if protocol is None:
protocol = DEFAULT_PROTOCOL
- super().__init__(file, protocol=protocol, buffer_callback=
- buffer_callback)
+ super().__init__(file, protocol=protocol, buffer_callback=buffer_callback)
+ # map functions __globals__ attribute ids, to ensure that functions
+ # sharing the same global namespace at pickling time also share
+ # their global namespace at unpickling time.
self.globals_ref = {}
self.proto = int(protocol)
+
if not PYPY:
+ # pickle.Pickler is the C implementation of the CPython pickler and
+ # therefore we rely on reduce_override method to customize the pickler
+ # behavior.
+
+ # `cloudpickle.Pickler.dispatch` is only left for backward
+ # compatibility - note that when using protocol 5,
+ # `cloudpickle.Pickler.dispatch` is not an extension of
+ # `pickle._Pickler.dispatch` dictionary, because `cloudpickle.Pickler`
+ # subclasses the C-implemented `pickle.Pickler`, which does not expose
+ # a `dispatch` attribute. Earlier versions of `cloudpickle.Pickler`
+ # used `cloudpickle.Pickler.dispatch` as a class-level attribute
+ # storing all reducers implemented by cloudpickle, but the attribute
+ # name was not a great choice given because it would collide with a
+ # similarly named attribute in the pure-Python `pickle._Pickler`
+ # implementation in the standard library.
dispatch = dispatch_table
+ # Implementation of the reducer_override callback, in order to
+ # efficiently serialize dynamic functions and classes by subclassing
+ # the C-implemented `pickle.Pickler`.
+ # TODO: decorrelate reducer_override (which is tied to CPython's
+ # implementation - would it make sense to backport it to pypy? - and
+ # pickle's protocol 5 which is implementation agnostic. Currently, the
+ # availability of both notions coincide on CPython's pickle, but it may
+ # not be the case anymore when pypy implements protocol 5.
+
def reducer_override(self, obj):
"""Type-agnostic reducing callback for function and classes.
@@ -393,17 +1318,85 @@ class Pickler(pickle.Pickler):
reducers, such as Exceptions. See
https://github.com/cloudpipe/cloudpickle/issues/248
"""
- pass
+ t = type(obj)
+ try:
+ is_anyclass = issubclass(t, type)
+ except TypeError: # t is not a class (old Boost; see SF #502085)
+ is_anyclass = False
+
+ if is_anyclass:
+ return _class_reduce(obj)
+ elif isinstance(obj, types.FunctionType):
+ return self._function_reduce(obj)
+ else:
+ # fallback to save_global, including the Pickler's
+ # dispatch_table
+ return NotImplemented
+
else:
+ # When reducer_override is not available, hack the pure-Python
+ # Pickler's types.FunctionType and type savers. Note: the type saver
+ # must override Pickler.save_global, because pickle.py contains a
+ # hard-coded call to save_global when pickling meta-classes.
dispatch = pickle.Pickler.dispatch.copy()
+ def _save_reduce_pickle5(
+ self,
+ func,
+ args,
+ state=None,
+ listitems=None,
+ dictitems=None,
+ state_setter=None,
+ obj=None,
+ ):
+ save = self.save
+ write = self.write
+ self.save_reduce(
+ func,
+ args,
+ state=None,
+ listitems=listitems,
+ dictitems=dictitems,
+ obj=obj,
+ )
+ # backport of the Python 3.8 state_setter pickle operations
+ save(state_setter)
+ save(obj) # simple BINGET opcode as obj is already memoized.
+ save(state)
+ write(pickle.TUPLE2)
+ # Trigger a state_setter(obj, state) function call.
+ write(pickle.REDUCE)
+ # The purpose of state_setter is to carry-out an
+ # inplace modification of obj. We do not care about what the
+ # method might return, so its output is eventually removed from
+ # the stack.
+ write(pickle.POP)
+
def save_global(self, obj, name=None, pack=struct.pack):
"""Main dispatch method.
The name of this method is somewhat misleading: all types get
dispatched here.
"""
- pass
+ if obj is type(None): # noqa
+ return self.save_reduce(type, (None,), obj=obj)
+ elif obj is type(Ellipsis):
+ return self.save_reduce(type, (Ellipsis,), obj=obj)
+ elif obj is type(NotImplemented):
+ return self.save_reduce(type, (NotImplemented,), obj=obj)
+ elif obj in _BUILTIN_TYPE_NAMES:
+ return self.save_reduce(
+ _builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj
+ )
+
+ if name is not None:
+ super().save_global(obj, name=name)
+ elif not _should_pickle_by_reference(obj, name=name):
+ self._save_reduce_pickle5(*_dynamic_class_reduce(obj), obj=obj)
+ else:
+ super().save_global(obj, name=name)
+
dispatch[type] = save_global
def save_function(self, obj, name=None):
@@ -412,7 +1405,14 @@ class Pickler(pickle.Pickler):
Determines what kind of function obj is (e.g. lambda, defined at
interactive prompt, etc) and handles the pickling appropriately.
"""
- pass
+ if _should_pickle_by_reference(obj, name=name):
+ return super().save_global(obj, name=name)
+ elif PYPY and isinstance(obj.__code__, builtin_code_type):
+ return self.save_pypy_builtin_func(obj)
+ else:
+ return self._save_reduce_pickle5(
+ *self._dynamic_function_reduce(obj), obj=obj
+ )
def save_pypy_builtin_func(self, obj):
"""Save pypy equivalent of builtin functions.
@@ -432,10 +1432,19 @@ class Pickler(pickle.Pickler):
this routing should be removed when cloudpickle supports only PyPy
3.6 and later.
"""
- pass
+ rv = (
+ types.FunctionType,
+ (obj.__code__, {}, obj.__name__, obj.__defaults__, obj.__closure__),
+ obj.__dict__,
+ )
+ self.save_reduce(*rv, obj=obj)
+
dispatch[types.FunctionType] = save_function
+# Shorthands similar to pickle.dump/pickle.dumps
+
+
def dump(obj, file, protocol=None, buffer_callback=None):
    """Serialize obj as bytes streamed into file.

    The protocol defaults to DEFAULT_PROTOCOL. Note that cloudpickle output
    relies on implementation details that can change from one Python version
    to the next.
    """
    pickler = Pickler(file, protocol=protocol, buffer_callback=buffer_callback)
    pickler.dump(obj)
def dumps(obj, protocol=None, buffer_callback=None):
    """Serialize obj and return the pickled payload as a bytes object.

    The protocol defaults to DEFAULT_PROTOCOL. Note that cloudpickle output
    relies on implementation details that can change from one Python version
    to the next.
    """
    with io.BytesIO() as buffer:
        Pickler(
            buffer, protocol=protocol, buffer_callback=buffer_callback
        ).dump(obj)
        return buffer.getvalue()
# Re-export pickle's unpickling functions in this namespace for convenience.
load = pickle.load
loads = pickle.loads

# Backward compat alias.
CloudPickler = Pickler
diff --git a/joblib/externals/loky/_base.py b/joblib/externals/loky/_base.py
index 6d789c8..da0abc1 100644
--- a/joblib/externals/loky/_base.py
+++ b/joblib/externals/loky/_base.py
@@ -1,6 +1,28 @@
+###############################################################################
+# Modification of concurrent.futures.Future
+#
+# author: Thomas Moreau and Olivier Grisel
+#
+# adapted from concurrent/futures/_base.py (17/02/2017)
+# * Do not use yield from
+# * Use old super syntax
+#
+# Copyright 2009 Brian Quinlan. All Rights Reserved.
+# Licensed to PSF under a Contributor Agreement.
+
from concurrent.futures import Future as _BaseFuture
from concurrent.futures._base import LOGGER
# To make loky._base.Future instances awaitable by concurrent.futures.wait,
# derive our custom Future class from _BaseFuture. _invoke_callbacks is the
# only modification made to this class in loky.
# TODO investigate why using `concurrent.futures.Future` directly does not
# always work in our test suite.
class Future(_BaseFuture):
    """concurrent.futures.Future variant whose callback errors are logged."""

    def _invoke_callbacks(self):
        # Run every registered done-callback; anything a callback raises is
        # logged and swallowed so one bad callback cannot break the others.
        for cb in self._done_callbacks:
            try:
                cb(self)
            except BaseException:
                LOGGER.exception(f"exception calling callback for {self!r}")
diff --git a/joblib/externals/loky/backend/_posix_reduction.py b/joblib/externals/loky/backend/_posix_reduction.py
index c819d41..4b800ec 100644
--- a/joblib/externals/loky/backend/_posix_reduction.py
+++ b/joblib/externals/loky/backend/_posix_reduction.py
@@ -1,16 +1,65 @@
+###############################################################################
+# Extra reducers for Unix based system and connections objects
+#
+# author: Thomas Moreau and Olivier Grisel
+#
+# adapted from multiprocessing/reduction.py (17/02/2017)
+# * Add adapted reduction for LokyProcesses and socket/Connection
+#
import os
import socket
import _socket
from multiprocessing.connection import Connection
from multiprocessing.context import get_spawning_popen
+
from .reduction import register
-HAVE_SEND_HANDLE = hasattr(socket, 'CMSG_LEN') and hasattr(socket, 'SCM_RIGHTS'
- ) and hasattr(socket.socket, 'sendmsg')
+
+HAVE_SEND_HANDLE = (
+ hasattr(socket, "CMSG_LEN")
+ and hasattr(socket, "SCM_RIGHTS")
+ and hasattr(socket.socket, "sendmsg")
+)
+
+
+def _mk_inheritable(fd):
+ os.set_inheritable(fd, True)
+ return fd
def DupFd(fd):
    """Return a wrapper for an fd."""
    spawning_popen = get_spawning_popen()
    if spawning_popen is not None:
        # A child process is being spawned: let its Popen carry the fd over.
        return spawning_popen.DupFd(spawning_popen.duplicate_for_child(fd))
    if HAVE_SEND_HANDLE:
        from multiprocessing import resource_sharer

        return resource_sharer.DupFd(fd)
    raise TypeError(
        "Cannot pickle connection object. This object can only be "
        "passed when spawning a new process"
    )
+
+
def _reduce_socket(s):
    """Reduce a socket by duplicating its fd for the receiving process."""
    dup_fd = DupFd(s.fileno())
    return _rebuild_socket, (dup_fd, s.family, s.type, s.proto)
+
+
+def _rebuild_socket(df, family, type, proto):
+ fd = df.detach()
+ return socket.fromfd(fd, family, type, proto)
+
+
def rebuild_connection(df, readable, writable):
    """Recreate a multiprocessing Connection from a duplicated-fd wrapper."""
    return Connection(df.detach(), readable, writable)
+
+
def reduce_connection(conn):
    """Reduce a Connection by duplicating its underlying fd."""
    dup_fd = DupFd(conn.fileno())
    return rebuild_connection, (dup_fd, conn.readable, conn.writable)
register(socket.socket, _reduce_socket)
diff --git a/joblib/externals/loky/backend/_win_reduction.py b/joblib/externals/loky/backend/_win_reduction.py
index 0a4276c..506d0ec 100644
--- a/joblib/externals/loky/backend/_win_reduction.py
+++ b/joblib/externals/loky/backend/_win_reduction.py
@@ -1,7 +1,18 @@
+###############################################################################
+# Extra reducers for Windows system and connections objects
+#
+# author: Thomas Moreau and Olivier Grisel
+#
+# adapted from multiprocessing/reduction.py (17/02/2017)
+# * Add adapted reduction for LokyProcesses and socket/PipeConnection
+#
import socket
from multiprocessing import connection
from multiprocessing.reduction import _reduce_socket
+
from .reduction import register
+
+# register reduction for win32 communication objects
register(socket.socket, _reduce_socket)
register(connection.Connection, connection.reduce_connection)
register(connection.PipeConnection, connection.reduce_pipe_connection)
diff --git a/joblib/externals/loky/backend/context.py b/joblib/externals/loky/backend/context.py
index 1e0d413..d0f5903 100644
--- a/joblib/externals/loky/backend/context.py
+++ b/joblib/externals/loky/backend/context.py
@@ -1,3 +1,14 @@
+###############################################################################
+# Basic context management with LokyContext
+#
+# author: Thomas Moreau and Olivier Grisel
+#
+# adapted from multiprocessing/context.py
+# * Create a context ensuring loky uses only objects that are compatible
+# * Add LokyContext to the list of context of multiprocessing so loky can be
+# used with multiprocessing.set_start_method
+# * Implement a CFS-aware amd physical-core aware cpu_count function.
+#
import os
import sys
import math
@@ -7,20 +18,68 @@ import warnings
import multiprocessing as mp
from multiprocessing import get_context as mp_get_context
from multiprocessing.context import BaseContext
+
+
from .process import LokyProcess, LokyInitMainProcess
+
+# Apparently, on older Python versions, loky cannot work 61 workers on Windows
+# but instead 60: ¯\_(ツ)_/¯
if sys.version_info >= (3, 8):
from concurrent.futures.process import _MAX_WINDOWS_WORKERS
+
if sys.version_info < (3, 10):
_MAX_WINDOWS_WORKERS = _MAX_WINDOWS_WORKERS - 1
else:
+ # compat for versions before 3.8 which do not define this.
_MAX_WINDOWS_WORKERS = 60
-START_METHODS = ['loky', 'loky_init_main', 'spawn']
-if sys.platform != 'win32':
- START_METHODS += ['fork', 'forkserver']
+
+START_METHODS = ["loky", "loky_init_main", "spawn"]
+if sys.platform != "win32":
+ START_METHODS += ["fork", "forkserver"]
+
_DEFAULT_START_METHOD = None
+
+# Cache for the number of physical cores to avoid repeating subprocess calls.
+# It should not change during the lifetime of the program.
physical_cores_cache = None
def get_context(method=None):
    """Return the multiprocessing context named *method*.

    Defaults to the start method configured via set_start_method, then to
    "loky". Raises ValueError for names not in START_METHODS.
    """
    method = method or _DEFAULT_START_METHOD or "loky"
    if method == "fork":
        # If 'fork' is explicitly requested, warn user about potential
        # issues.
        warnings.warn(
            "`fork` start method should not be used with "
            "`loky` as it does not respect POSIX. Try using "
            "`spawn` or `loky` instead.",
            UserWarning,
        )
    try:
        return mp_get_context(method)
    except ValueError:
        raise ValueError(
            f"Unknown context '{method}'. Value should be in "
            f"{START_METHODS}."
        )
+
+
def set_start_method(method, force=False):
    """Set the module-level default start method used by get_context.

    Raises RuntimeError if a method was already set and force is False.
    """
    global _DEFAULT_START_METHOD
    if _DEFAULT_START_METHOD is not None and not force:
        raise RuntimeError("context has already been set")
    assert method is None or method in START_METHODS, (
        f"'{method}' is not a valid start_method. It should be in "
        f"{START_METHODS}"
    )

    _DEFAULT_START_METHOD = method


def get_start_method():
    """Return the module-level default start method (None when unset)."""
    return _DEFAULT_START_METHOD
+
+
def cpu_count(only_physical_cores=False):
    """Return the number of CPUs the current process can use.

    The value is capped by the OS CPU count, the process CPU affinity mask,
    any cgroup CPU bandwidth limit, the LOKY_MAX_CPU_COUNT environment
    variable, and _MAX_WINDOWS_WORKERS on Windows. It is always >= 1.

    If only_physical_cores is True, return the number of physical cores
    instead of logical cores, unless the user limited the CPU count below
    the OS count (in which case that limit is respected).
    """
    # Note: os.cpu_count() is allowed to return None in its docstring
    os_cpu_count = os.cpu_count() or 1
    if sys.platform == "win32":
        # On Windows, attempting to use more than 61 CPUs would result in a
        # OS-level error. See https://bugs.python.org/issue26903. According to
        # https://learn.microsoft.com/en-us/windows/win32/procthread/processor-groups
        # it might be possible to go beyond with a lot of extra work but this
        # does not look easy.
        os_cpu_count = min(os_cpu_count, _MAX_WINDOWS_WORKERS)

    cpu_count_user = _cpu_count_user(os_cpu_count)
    aggregate_cpu_count = max(min(os_cpu_count, cpu_count_user), 1)

    if not only_physical_cores:
        return aggregate_cpu_count

    if cpu_count_user < os_cpu_count:
        # Respect user setting
        return max(cpu_count_user, 1)

    cpu_count_physical, exception = _count_physical_cores()
    if cpu_count_physical != "not found":
        return cpu_count_physical

    # Fallback to default behavior
    if exception is not None:
        # warns only the first time
        warnings.warn(
            "Could not find the number of physical cores for the "
            f"following reason:\n{exception}\n"
            "Returning the number of logical cores instead. You can "
            "silence this warning by setting LOKY_MAX_CPU_COUNT to "
            "the number of cores you want to use."
        )
        traceback.print_tb(exception.__traceback__)

    return aggregate_cpu_count
+
+
+def _cpu_count_cgroup(os_cpu_count):
+ # Cgroup CPU bandwidth limit available in Linux since 2.6 kernel
+ cpu_max_fname = "/sys/fs/cgroup/cpu.max"
+ cfs_quota_fname = "/sys/fs/cgroup/cpu/cpu.cfs_quota_us"
+ cfs_period_fname = "/sys/fs/cgroup/cpu/cpu.cfs_period_us"
+ if os.path.exists(cpu_max_fname):
+ # cgroup v2
+ # https://www.kernel.org/doc/html/latest/admin-guide/cgroup-v2.html
+ with open(cpu_max_fname) as fh:
+ cpu_quota_us, cpu_period_us = fh.read().strip().split()
+ elif os.path.exists(cfs_quota_fname) and os.path.exists(cfs_period_fname):
+ # cgroup v1
+ # https://www.kernel.org/doc/html/latest/scheduler/sched-bwc.html#management
+ with open(cfs_quota_fname) as fh:
+ cpu_quota_us = fh.read().strip()
+ with open(cfs_period_fname) as fh:
+ cpu_period_us = fh.read().strip()
+ else:
+ # No Cgroup CPU bandwidth limit (e.g. non-Linux platform)
+ cpu_quota_us = "max"
+ cpu_period_us = 100_000 # unused, for consistency with default values
+
+ if cpu_quota_us == "max":
+ # No active Cgroup quota on a Cgroup-capable platform
+ return os_cpu_count
+ else:
+ cpu_quota_us = int(cpu_quota_us)
+ cpu_period_us = int(cpu_period_us)
+ if cpu_quota_us > 0 and cpu_period_us > 0:
+ return math.ceil(cpu_quota_us / cpu_period_us)
+ else: # pragma: no cover
+ # Setting a negative cpu_quota_us value is a valid way to disable
+ # cgroup CPU bandwith limits
+ return os_cpu_count
+
+
+def _cpu_count_affinity(os_cpu_count):
+ # Number of available CPUs given affinity settings
+ if hasattr(os, "sched_getaffinity"):
+ try:
+ return len(os.sched_getaffinity(0))
+ except NotImplementedError:
+ pass
+
+ # On PyPy and possibly other platforms, os.sched_getaffinity does not exist
+ # or raises NotImplementedError, let's try with the psutil if installed.
+ try:
+ import psutil
+
+ p = psutil.Process()
+ if hasattr(p, "cpu_affinity"):
+ return len(p.cpu_affinity())
+
+ except ImportError: # pragma: no cover
+ if (
+ sys.platform == "linux"
+ and os.environ.get("LOKY_MAX_CPU_COUNT") is None
+ ):
+ # PyPy does not implement os.sched_getaffinity on Linux which
+ # can cause severe oversubscription problems. Better warn the
+ # user in this particularly pathological case which can wreck
+ # havoc, typically on CI workers.
+ warnings.warn(
+ "Failed to inspect CPU affinity constraints on this system. "
+ "Please install psutil or explictly set LOKY_MAX_CPU_COUNT."
+ )
+
+ # This can happen for platforms that do not implement any kind of CPU
+ # infinity such as macOS-based platforms.
+ return os_cpu_count
def _cpu_count_user(os_cpu_count):
    """Number of user defined available CPUs"""
    limits = (
        _cpu_count_affinity(os_cpu_count),
        _cpu_count_cgroup(os_cpu_count),
        # User defined soft-limit passed as a loky specific environment
        # variable.
        int(os.environ.get("LOKY_MAX_CPU_COUNT", os_cpu_count)),
    )
    return min(limits)
def _count_physical_cores():
    """Return a (count, exception) pair for the number of physical cores.

    count is the string "not found" when detection failed, in which case
    exception carries the reason. The result is cached in
    physical_cores_cache to avoid repeating subprocess calls.
    """
    global physical_cores_cache
    if physical_cores_cache is not None:
        # Already detected during this process' lifetime.
        return physical_cores_cache, None

    exception = None
    try:
        if sys.platform == "linux":
            lscpu = subprocess.run(
                "lscpu --parse=core".split(), capture_output=True, text=True
            )
            core_ids = {
                line
                for line in lscpu.stdout.splitlines()
                if not line.startswith("#")
            }
            cpu_count_physical = len(core_ids)
        elif sys.platform == "win32":
            wmic = subprocess.run(
                "wmic CPU Get NumberOfCores /Format:csv".split(),
                capture_output=True,
                text=True,
            )
            core_counts = [
                line.split(",")[1]
                for line in wmic.stdout.splitlines()
                if line and line != "Node,NumberOfCores"
            ]
            cpu_count_physical = sum(map(int, core_counts))
        elif sys.platform == "darwin":
            sysctl = subprocess.run(
                "sysctl -n hw.physicalcpu".split(),
                capture_output=True,
                text=True,
            )
            cpu_count_physical = int(sysctl.stdout)
        else:
            raise NotImplementedError(f"unsupported platform: {sys.platform}")

        # if cpu_count_physical < 1, we did not find a valid value
        if cpu_count_physical < 1:
            raise ValueError(f"found {cpu_count_physical} physical cores < 1")

    except Exception as e:
        exception = e
        cpu_count_physical = "not found"

    # Put the result in cache
    physical_cores_cache = cpu_count_physical

    return cpu_count_physical, exception
class LokyContext(BaseContext):
"""Context relying on the LokyProcess."""
- _name = 'loky'
+
+ _name = "loky"
Process = LokyProcess
cpu_count = staticmethod(cpu_count)
def Queue(self, maxsize=0, reducers=None):
"""Returns a queue object"""
- pass
+ from .queues import Queue
+
+ return Queue(maxsize, reducers=reducers, ctx=self.get_context())
def SimpleQueue(self, reducers=None):
"""Returns a queue object"""
- pass
- if sys.platform != 'win32':
+ from .queues import SimpleQueue
+
+ return SimpleQueue(reducers=reducers, ctx=self.get_context())
+
+ if sys.platform != "win32":
"""For Unix platform, use our custom implementation of synchronize
ensuring that we use the loky.backend.resource_tracker to clean-up
the semaphores in case of a worker crash.
@@ -87,27 +318,39 @@ class LokyContext(BaseContext):
def Semaphore(self, value=1):
"""Returns a semaphore object"""
- pass
+ from .synchronize import Semaphore
+
+ return Semaphore(value=value)
def BoundedSemaphore(self, value):
"""Returns a bounded semaphore object"""
- pass
+ from .synchronize import BoundedSemaphore
+
+ return BoundedSemaphore(value)
def Lock(self):
"""Returns a lock object"""
- pass
+ from .synchronize import Lock
+
+ return Lock()
def RLock(self):
"""Returns a recurrent lock object"""
- pass
+ from .synchronize import RLock
+
+ return RLock()
def Condition(self, lock=None):
"""Returns a condition object"""
- pass
+ from .synchronize import Condition
+
+ return Condition(lock)
def Event(self):
"""Returns an event object"""
- pass
+ from .synchronize import Event
+
+ return Event()
class LokyInitMainContext(LokyContext):
    """Extra context with LokyProcess, whose workers import the main module.

    For more details, see the end of the following section of python doc
    https://docs.python.org/3/library/multiprocessing.html#multiprocessing-programming
    """

    _name = "loky_init_main"
    Process = LokyInitMainProcess
# Register loky context so it works with multiprocessing.get_context
ctx_loky = LokyContext()
mp.context._concrete_contexts["loky"] = ctx_loky
mp.context._concrete_contexts["loky_init_main"] = LokyInitMainContext()
diff --git a/joblib/externals/loky/backend/fork_exec.py b/joblib/externals/loky/backend/fork_exec.py
index a8af34a..2353c42 100644
--- a/joblib/externals/loky/backend/fork_exec.py
+++ b/joblib/externals/loky/backend/fork_exec.py
@@ -1,7 +1,43 @@
+###############################################################################
+# Launch a subprocess using forkexec and make sure only the needed fd are
+# shared in the two process.
+#
+# author: Thomas Moreau and Olivier Grisel
+#
import os
import sys
def close_fds(keep_fds):  # pragma: no cover
    """Close all the file descriptors except those in keep_fds."""

    # Make sure to keep stdout and stderr open for logging purpose
    protected_fds = {*keep_fds, 1, 2}

    # Enumerate currently-open fds via /proc when available; otherwise fall
    # back to sweeping every fd up to the soft RLIMIT_NOFILE bound.
    try:
        open_fds = {int(fd) for fd in os.listdir("/proc/self/fd")}
    except FileNotFoundError:
        import resource

        max_nfds = resource.getrlimit(resource.RLIMIT_NOFILE)[0]
        open_fds = set(range(max_nfds))

    for fd in open_fds - protected_fds:
        try:
            os.close(fd)
        except OSError:
            # fd was not actually open; ignore.
            pass
+
+
def fork_exec(cmd, keep_fds, env=None):
    """Fork, close all fds except *keep_fds*, and exec *cmd*; return the pid.

    *env* entries are overlaid on a copy of the parent environment for the
    child process.
    """
    # copy the environment variables to set in the child process
    child_env = {**os.environ, **(env or {})}

    pid = os.fork()
    if pid == 0:  # pragma: no cover
        # Child: drop every fd except the requested ones, then exec.
        close_fds(keep_fds)
        os.execve(sys.executable, cmd, child_env)
    else:
        return pid
diff --git a/joblib/externals/loky/backend/popen_loky_posix.py b/joblib/externals/loky/backend/popen_loky_posix.py
index ef2f2b6..74395be 100644
--- a/joblib/externals/loky/backend/popen_loky_posix.py
+++ b/joblib/externals/loky/backend/popen_loky_posix.py
@@ -1,3 +1,8 @@
+###############################################################################
+# Popen for LokyProcess.
+#
+# author: Thomas Moreau and Olivier Grisel
+#
import os
import sys
import signal
@@ -6,18 +11,33 @@ from io import BytesIO
from multiprocessing import util, process
from multiprocessing.connection import wait
from multiprocessing.context import set_spawning_popen
+
from . import reduction, resource_tracker, spawn
-__all__ = ['Popen']
-class _DupFd:
+__all__ = ["Popen"]
+
+
+#
+# Wrapper for an fd used while launching a process
+#
+
+class _DupFd:
def __init__(self, fd):
self.fd = reduction._mk_inheritable(fd)
+ def detach(self):
+ return self.fd
+
+
+#
+# Start child process using subprocess.Popen
+#
+
class Popen:
- method = 'loky'
+ method = "loky"
DupFd = _DupFd
def __init__(self, process_obj):
@@ -27,19 +47,128 @@ class Popen:
self._fds = []
self._launch(process_obj)
+ def duplicate_for_child(self, fd):
+ self._fds.append(fd)
+ return reduction._mk_inheritable(fd)
+
+ def poll(self, flag=os.WNOHANG):
+ if self.returncode is None:
+ while True:
+ try:
+ pid, sts = os.waitpid(self.pid, flag)
+ except OSError:
+ # Child process not yet created. See #1731717
+ # e.errno == errno.ECHILD == 10
+ return None
+ else:
+ break
+ if pid == self.pid:
+ if os.WIFSIGNALED(sts):
+ self.returncode = -os.WTERMSIG(sts)
+ else:
+ assert os.WIFEXITED(sts)
+ self.returncode = os.WEXITSTATUS(sts)
+ return self.returncode
+
+ def wait(self, timeout=None):
+ if self.returncode is None:
+ if timeout is not None:
+ if not wait([self.sentinel], timeout):
+ return None
+ # This shouldn't block if wait() returned successfully.
+ return self.poll(os.WNOHANG if timeout == 0.0 else 0)
+ return self.returncode
+
+ def terminate(self):
+ if self.returncode is None:
+ try:
+ os.kill(self.pid, signal.SIGTERM)
+ except ProcessLookupError:
+ pass
+ except OSError:
+ if self.wait(timeout=0.1) is None:
+ raise
+
+ def _launch(self, process_obj):
+
+ tracker_fd = resource_tracker._resource_tracker.getfd()
-if __name__ == '__main__':
+ fp = BytesIO()
+ set_spawning_popen(self)
+ try:
+ prep_data = spawn.get_preparation_data(
+ process_obj._name,
+ getattr(process_obj, "init_main_module", True),
+ )
+ reduction.dump(prep_data, fp)
+ reduction.dump(process_obj, fp)
+
+ finally:
+ set_spawning_popen(None)
+
+ try:
+ parent_r, child_w = os.pipe()
+ child_r, parent_w = os.pipe()
+ # for fd in self._fds:
+ # _mk_inheritable(fd)
+
+ cmd_python = [sys.executable]
+ cmd_python += ["-m", self.__module__]
+ cmd_python += ["--process-name", str(process_obj.name)]
+ cmd_python += ["--pipe", str(reduction._mk_inheritable(child_r))]
+ reduction._mk_inheritable(child_w)
+ reduction._mk_inheritable(tracker_fd)
+ self._fds += [child_r, child_w, tracker_fd]
+ if sys.version_info >= (3, 8) and os.name == "posix":
+ mp_tracker_fd = prep_data["mp_tracker_args"]["fd"]
+ self.duplicate_for_child(mp_tracker_fd)
+
+ from .fork_exec import fork_exec
+
+ pid = fork_exec(cmd_python, self._fds, env=process_obj.env)
+ util.debug(
+ f"launched python with pid {pid} and cmd:\n{cmd_python}"
+ )
+ self.sentinel = parent_r
+
+ method = "getbuffer"
+ if not hasattr(fp, method):
+ method = "getvalue"
+ with os.fdopen(parent_w, "wb") as f:
+ f.write(getattr(fp, method)())
+ self.pid = pid
+ finally:
+ if parent_r is not None:
+ util.Finalize(self, os.close, (parent_r,))
+ for fd in (child_r, child_w):
+ if fd is not None:
+ os.close(fd)
+
+ @staticmethod
+ def thread_is_spawning():
+ return True
+
+
+if __name__ == "__main__":
import argparse
- parser = argparse.ArgumentParser('Command line parser')
- parser.add_argument('--pipe', type=int, required=True, help=
- 'File handle for the pipe')
- parser.add_argument('--process-name', type=str, default=None, help=
- 'Identifier for debugging purpose')
+
+ parser = argparse.ArgumentParser("Command line parser")
+ parser.add_argument(
+ "--pipe", type=int, required=True, help="File handle for the pipe"
+ )
+ parser.add_argument(
+ "--process-name",
+ type=str,
+ default=None,
+ help="Identifier for debugging purpose",
+ )
+
args = parser.parse_args()
+
info = {}
exitcode = 1
try:
- with os.fdopen(args.pipe, 'rb') as from_parent:
+ with os.fdopen(args.pipe, "rb") as from_parent:
process.current_process()._inheriting = True
try:
prep_data = pickle.load(from_parent)
@@ -47,15 +176,18 @@ if __name__ == '__main__':
process_obj = pickle.load(from_parent)
finally:
del process.current_process()._inheriting
+
exitcode = process_obj._bootstrap()
except Exception:
- print('\n\n' + '-' * 80)
- print(f'{args.process_name} failed with traceback: ')
- print('-' * 80)
+ print("\n\n" + "-" * 80)
+ print(f"{args.process_name} failed with traceback: ")
+ print("-" * 80)
import traceback
+
print(traceback.format_exc())
- print('\n' + '-' * 80)
+ print("\n" + "-" * 80)
finally:
if from_parent is not None:
from_parent.close()
+
sys.exit(exitcode)
diff --git a/joblib/externals/loky/backend/popen_loky_win32.py b/joblib/externals/loky/backend/popen_loky_win32.py
index b174751..4f85f65 100644
--- a/joblib/externals/loky/backend/popen_loky_win32.py
+++ b/joblib/externals/loky/backend/popen_loky_win32.py
@@ -6,10 +6,35 @@ from pickle import load
from multiprocessing import process, util
from multiprocessing.context import set_spawning_popen
from multiprocessing.popen_spawn_win32 import Popen as _Popen
+
from . import reduction, spawn
-__all__ = ['Popen']
-WINENV = hasattr(sys, '_base_executable') and not _path_eq(sys.executable,
- sys._base_executable)
+
+
+__all__ = ["Popen"]
+
+#
+#
+#
+
+
+def _path_eq(p1, p2):
+ return p1 == p2 or os.path.normcase(p1) == os.path.normcase(p2)
+
+
+WINENV = hasattr(sys, "_base_executable") and not _path_eq(
+ sys.executable, sys._base_executable
+)
+
+
+def _close_handles(*handles):
+ for handle in handles:
+ _winapi.CloseHandle(handle)
+
+
+#
+# We define a Popen class similar to the one from subprocess, but
+# whose constructor takes a process object as its argument.
+#
class Popen(_Popen):
@@ -24,34 +49,66 @@ class Popen(_Popen):
We also use the loky preparation data, in particular to handle main_module
inits and the loky resource tracker.
"""
- method = 'loky'
+
+ method = "loky"
def __init__(self, process_obj):
- prep_data = spawn.get_preparation_data(process_obj._name, getattr(
- process_obj, 'init_main_module', True))
+ prep_data = spawn.get_preparation_data(
+ process_obj._name, getattr(process_obj, "init_main_module", True)
+ )
+
+ # read end of pipe will be duplicated by the child process
+ # -- see spawn_main() in spawn.py.
+ #
+ # bpo-33929: Previously, the read end of pipe was "stolen" by the child
+ # process, but it leaked a handle if the child process had been
+ # terminated before it could steal the handle from the parent process.
rhandle, whandle = _winapi.CreatePipe(None, 0)
wfd = msvcrt.open_osfhandle(whandle, 0)
cmd = get_command_line(parent_pid=os.getpid(), pipe_handle=rhandle)
+
python_exe = spawn.get_executable()
+
+ # copy the environment variables to set in the child process
child_env = {**os.environ, **process_obj.env}
+
+ # bpo-35797: When running in a venv, we bypass the redirect
+ # executor and launch our base Python.
if WINENV and _path_eq(python_exe, sys.executable):
cmd[0] = python_exe = sys._base_executable
- child_env['__PYVENV_LAUNCHER__'] = sys.executable
- cmd = ' '.join(f'"{x}"' for x in cmd)
- with open(wfd, 'wb') as to_child:
+ child_env["__PYVENV_LAUNCHER__"] = sys.executable
+
+ cmd = " ".join(f'"{x}"' for x in cmd)
+
+ with open(wfd, "wb") as to_child:
+ # start process
try:
- hp, ht, pid, _ = _winapi.CreateProcess(python_exe, cmd,
- None, None, False, 0, child_env, None, None)
+ hp, ht, pid, _ = _winapi.CreateProcess(
+ python_exe,
+ cmd,
+ None,
+ None,
+ False,
+ 0,
+ child_env,
+ None,
+ None,
+ )
_winapi.CloseHandle(ht)
except BaseException:
_winapi.CloseHandle(rhandle)
raise
+
+ # set attributes of self
self.pid = pid
self.returncode = None
self._handle = hp
self.sentinel = int(hp)
- self.finalizer = util.Finalize(self, _close_handles, (self.
- sentinel, int(rhandle)))
+ self.finalizer = util.Finalize(
+ self, _close_handles, (self.sentinel, int(rhandle))
+ )
+
+ # send information to child
set_spawning_popen(self)
try:
reduction.dump(prep_data, to_child)
@@ -62,14 +119,55 @@ class Popen(_Popen):
def get_command_line(pipe_handle, parent_pid, **kwds):
"""Returns prefix of command line used for spawning a child process."""
- pass
+ if getattr(sys, "frozen", False):
+ return [sys.executable, "--multiprocessing-fork", pipe_handle]
+ else:
+ prog = (
+ "from joblib.externals.loky.backend.popen_loky_win32 import main; "
+ f"main(pipe_handle={pipe_handle}, parent_pid={parent_pid})"
+ )
+ opts = util._args_from_interpreter_flags()
+ return [
+ spawn.get_executable(),
+ *opts,
+ "-c",
+ prog,
+ "--multiprocessing-fork",
+ ]
def is_forking(argv):
"""Return whether commandline indicates we are forking."""
- pass
+ if len(argv) >= 2 and argv[1] == "--multiprocessing-fork":
+ return True
+ else:
+ return False
def main(pipe_handle, parent_pid=None):
"""Run code specified by data received over pipe."""
- pass
+ assert is_forking(sys.argv), "Not forking"
+
+ if parent_pid is not None:
+ source_process = _winapi.OpenProcess(
+ _winapi.SYNCHRONIZE | _winapi.PROCESS_DUP_HANDLE, False, parent_pid
+ )
+ else:
+ source_process = None
+ new_handle = reduction.duplicate(
+ pipe_handle, source_process=source_process
+ )
+ fd = msvcrt.open_osfhandle(new_handle, os.O_RDONLY)
+ parent_sentinel = source_process
+
+ with os.fdopen(fd, "rb", closefd=True) as from_parent:
+ process.current_process()._inheriting = True
+ try:
+ preparation_data = load(from_parent)
+ spawn.prepare(preparation_data, parent_sentinel)
+ self = load(from_parent)
+ finally:
+ del process.current_process()._inheriting
+
+ exitcode = self._bootstrap(parent_sentinel)
+ sys.exit(exitcode)
diff --git a/joblib/externals/loky/backend/process.py b/joblib/externals/loky/backend/process.py
index 26a97f9..3562550 100644
--- a/joblib/externals/loky/backend/process.py
+++ b/joblib/externals/loky/backend/process.py
@@ -1,36 +1,85 @@
+###############################################################################
+# LokyProcess implementation
+#
+# authors: Thomas Moreau and Olivier Grisel
+#
+# based on multiprocessing/process.py (17/02/2017)
+#
import sys
from multiprocessing.context import assert_spawning
from multiprocessing.process import BaseProcess
class LokyProcess(BaseProcess):
- _start_method = 'loky'
+ _start_method = "loky"
- def __init__(self, group=None, target=None, name=None, args=(), kwargs=
- {}, daemon=None, init_main_module=False, env=None):
- super().__init__(group=group, target=target, name=name, args=args,
- kwargs=kwargs, daemon=daemon)
+ def __init__(
+ self,
+ group=None,
+ target=None,
+ name=None,
+ args=(),
+ kwargs={},
+ daemon=None,
+ init_main_module=False,
+ env=None,
+ ):
+ super().__init__(
+ group=group,
+ target=target,
+ name=name,
+ args=args,
+ kwargs=kwargs,
+ daemon=daemon,
+ )
self.env = {} if env is None else env
self.authkey = self.authkey
self.init_main_module = init_main_module
+ @staticmethod
+ def _Popen(process_obj):
+ if sys.platform == "win32":
+ from .popen_loky_win32 import Popen
+ else:
+ from .popen_loky_posix import Popen
+ return Popen(process_obj)
+
class LokyInitMainProcess(LokyProcess):
- _start_method = 'loky_init_main'
+ _start_method = "loky_init_main"
- def __init__(self, group=None, target=None, name=None, args=(), kwargs=
- {}, daemon=None):
- super().__init__(group=group, target=target, name=name, args=args,
- kwargs=kwargs, daemon=daemon, init_main_module=True)
+ def __init__(
+ self,
+ group=None,
+ target=None,
+ name=None,
+ args=(),
+ kwargs={},
+ daemon=None,
+ ):
+ super().__init__(
+ group=group,
+ target=target,
+ name=name,
+ args=args,
+ kwargs=kwargs,
+ daemon=daemon,
+ init_main_module=True,
+ )
-class AuthenticationKey(bytes):
+#
+# We subclass bytes to avoid accidental transmission of auth keys over network
+#
+
+class AuthenticationKey(bytes):
def __reduce__(self):
try:
assert_spawning(self)
except RuntimeError:
raise TypeError(
- 'Pickling an AuthenticationKey object is disallowed for security reasons'
- )
+ "Pickling an AuthenticationKey object is "
+ "disallowed for security reasons"
+ )
return AuthenticationKey, (bytes(self),)
diff --git a/joblib/externals/loky/backend/queues.py b/joblib/externals/loky/backend/queues.py
index 704e2a3..5afd99b 100644
--- a/joblib/externals/loky/backend/queues.py
+++ b/joblib/externals/loky/backend/queues.py
@@ -1,55 +1,236 @@
+###############################################################################
+# Queue and SimpleQueue implementation for loky
+#
+# authors: Thomas Moreau, Olivier Grisel
+#
+# based on multiprocessing/queues.py (16/02/2017)
+# * Add some custom reducers for the Queues/SimpleQueue to tweak the
+# pickling process. (overload Queue._feed/SimpleQueue.put)
+#
import os
import sys
import errno
import weakref
import threading
from multiprocessing import util
-from multiprocessing.queues import Full, Queue as mp_Queue, SimpleQueue as mp_SimpleQueue, _sentinel
+from multiprocessing.queues import (
+ Full,
+ Queue as mp_Queue,
+ SimpleQueue as mp_SimpleQueue,
+ _sentinel,
+)
from multiprocessing.context import assert_spawning
+
from .reduction import dumps
-__all__ = ['Queue', 'SimpleQueue', 'Full']
-class Queue(mp_Queue):
+__all__ = ["Queue", "SimpleQueue", "Full"]
+
+class Queue(mp_Queue):
def __init__(self, maxsize=0, reducers=None, ctx=None):
super().__init__(maxsize=maxsize, ctx=ctx)
self._reducers = reducers
+ # Use custom queue set/get state to be able to reduce the custom reducers
def __getstate__(self):
assert_spawning(self)
- return (self._ignore_epipe, self._maxsize, self._reader, self.
- _writer, self._reducers, self._rlock, self._wlock, self._sem,
- self._opid)
+ return (
+ self._ignore_epipe,
+ self._maxsize,
+ self._reader,
+ self._writer,
+ self._reducers,
+ self._rlock,
+ self._wlock,
+ self._sem,
+ self._opid,
+ )
def __setstate__(self, state):
- (self._ignore_epipe, self._maxsize, self._reader, self._writer,
- self._reducers, self._rlock, self._wlock, self._sem, self._opid
- ) = state
+ (
+ self._ignore_epipe,
+ self._maxsize,
+ self._reader,
+ self._writer,
+ self._reducers,
+ self._rlock,
+ self._wlock,
+ self._sem,
+ self._opid,
+ ) = state
if sys.version_info >= (3, 9):
self._reset()
else:
self._after_fork()
+ # Overload _start_thread to correctly call our custom _feed
+ def _start_thread(self):
+ util.debug("Queue._start_thread()")
+
+ # Start thread which transfers data from buffer to pipe
+ self._buffer.clear()
+ self._thread = threading.Thread(
+ target=Queue._feed,
+ args=(
+ self._buffer,
+ self._notempty,
+ self._send_bytes,
+ self._wlock,
+ self._writer.close,
+ self._reducers,
+ self._ignore_epipe,
+ self._on_queue_feeder_error,
+ self._sem,
+ ),
+ name="QueueFeederThread",
+ )
+ self._thread.daemon = True
+
+ util.debug("doing self._thread.start()")
+ self._thread.start()
+ util.debug("... done self._thread.start()")
+
+ # On process exit we will wait for data to be flushed to pipe.
+ #
+ # However, if this process created the queue then all
+ # processes which use the queue will be descendants of this
+ # process. Therefore waiting for the queue to be flushed
+ # is pointless once all the child processes have been joined.
+ created_by_this_process = self._opid == os.getpid()
+ if not self._joincancelled and not created_by_this_process:
+ self._jointhread = util.Finalize(
+ self._thread,
+ Queue._finalize_join,
+ [weakref.ref(self._thread)],
+ exitpriority=-5,
+ )
+
+ # Send sentinel to the thread queue object when garbage collected
+ self._close = util.Finalize(
+ self,
+ Queue._finalize_close,
+ [self._buffer, self._notempty],
+ exitpriority=10,
+ )
+
+ # Overload the _feed methods to use our custom pickling strategy.
+ @staticmethod
+ def _feed(
+ buffer,
+ notempty,
+ send_bytes,
+ writelock,
+ close,
+ reducers,
+ ignore_epipe,
+ onerror,
+ queue_sem,
+ ):
+ util.debug("starting thread to feed data to pipe")
+ nacquire = notempty.acquire
+ nrelease = notempty.release
+ nwait = notempty.wait
+ bpopleft = buffer.popleft
+ sentinel = _sentinel
+ if sys.platform != "win32":
+ wacquire = writelock.acquire
+ wrelease = writelock.release
+ else:
+ wacquire = None
+
+ while True:
+ try:
+ nacquire()
+ try:
+ if not buffer:
+ nwait()
+ finally:
+ nrelease()
+ try:
+ while True:
+ obj = bpopleft()
+ if obj is sentinel:
+ util.debug("feeder thread got sentinel -- exiting")
+ close()
+ return
+
+ # serialize the data before acquiring the lock
+ obj_ = dumps(obj, reducers=reducers)
+ if wacquire is None:
+ send_bytes(obj_)
+ else:
+ wacquire()
+ try:
+ send_bytes(obj_)
+ finally:
+ wrelease()
+ # Remove references early to avoid leaking memory
+ del obj, obj_
+ except IndexError:
+ pass
+ except BaseException as e:
+ if ignore_epipe and getattr(e, "errno", 0) == errno.EPIPE:
+ return
+ # Since this runs in a daemon thread the resources it uses
+                # may become unusable while the process is cleaning up.
+ # We ignore errors which happen after the process has
+ # started to cleanup.
+ if util.is_exiting():
+ util.info(f"error in queue thread: {e}")
+ return
+ else:
+ queue_sem.release()
+ onerror(e, obj)
+
def _on_queue_feeder_error(self, e, obj):
"""
Private API hook called when feeding data in the background thread
raises an exception. For overriding by concurrent.futures.
"""
- pass
+ import traceback
+ traceback.print_exc()
-class SimpleQueue(mp_SimpleQueue):
+class SimpleQueue(mp_SimpleQueue):
def __init__(self, reducers=None, ctx=None):
super().__init__(ctx=ctx)
+
+        # Add the possibility to use custom reducers
self._reducers = reducers
+ def close(self):
+ self._reader.close()
+ self._writer.close()
+
+ # Use custom queue set/get state to be able to reduce the custom reducers
def __getstate__(self):
assert_spawning(self)
- return (self._reader, self._writer, self._reducers, self._rlock,
- self._wlock)
+ return (
+ self._reader,
+ self._writer,
+ self._reducers,
+ self._rlock,
+ self._wlock,
+ )
def __setstate__(self, state):
- (self._reader, self._writer, self._reducers, self._rlock, self._wlock
- ) = state
+ (
+ self._reader,
+ self._writer,
+ self._reducers,
+ self._rlock,
+ self._wlock,
+ ) = state
+
+ # Overload put to use our customizable reducer
+ def put(self, obj):
+ # serialize the data before acquiring the lock
+ obj = dumps(obj, reducers=self._reducers)
+ if self._wlock is None:
+ # writes to a message oriented win32 pipe are atomic
+ self._writer.send_bytes(obj)
+ else:
+ with self._wlock:
+ self._writer.send_bytes(obj)
diff --git a/joblib/externals/loky/backend/reduction.py b/joblib/externals/loky/backend/reduction.py
index 6770d67..bed32ba 100644
--- a/joblib/externals/loky/backend/reduction.py
+++ b/joblib/externals/loky/backend/reduction.py
@@ -1,45 +1,224 @@
+###############################################################################
+# Customizable Pickler with some basic reducers
+#
+# author: Thomas Moreau
+#
+# adapted from multiprocessing/reduction.py (17/02/2017)
+# * Replace the ForkingPickler with a similar _LokyPickler,
+# * Add CustomizableLokyPickler to allow customizing pickling process
+# on the fly.
+#
import copyreg
import io
import functools
import types
import sys
import os
+
from multiprocessing import util
from pickle import loads, HIGHEST_PROTOCOL
+
+###############################################################################
+# Enable custom pickling in Loky.
+
_dispatch_table = {}
+def register(type_, reduce_function):
+ _dispatch_table[type_] = reduce_function
+
+
+###############################################################################
+# Registers extra pickling routines to improve picklization for loky
+
+
+# make methods picklable
+def _reduce_method(m):
+ if m.__self__ is None:
+ return getattr, (m.__class__, m.__func__.__name__)
+ else:
+ return getattr, (m.__self__, m.__func__.__name__)
+
+
class _C:
- pass
+ def f(self):
+ pass
+
+ @classmethod
+ def h(cls):
+ pass
register(type(_C().f), _reduce_method)
register(type(_C.h), _reduce_method)
-if not hasattr(sys, 'pypy_version_info'):
+
+
+if not hasattr(sys, "pypy_version_info"):
+ # PyPy uses functions instead of method_descriptors and wrapper_descriptors
+ def _reduce_method_descriptor(m):
+ return getattr, (m.__objclass__, m.__name__)
+
register(type(list.append), _reduce_method_descriptor)
register(type(int.__add__), _reduce_method_descriptor)
+
+
+# Make partial functions picklable
+def _reduce_partial(p):
+ return _rebuild_partial, (p.func, p.args, p.keywords or {})
+
+
+def _rebuild_partial(func, args, keywords):
+ return functools.partial(func, *args, **keywords)
+
+
register(functools.partial, _reduce_partial)
-if sys.platform != 'win32':
- from ._posix_reduction import _mk_inheritable
+
+if sys.platform != "win32":
+ from ._posix_reduction import _mk_inheritable # noqa: F401
else:
- from . import _win_reduction
+ from . import _win_reduction # noqa: F401
+
+# global variable to change the pickler behavior
try:
- from joblib.externals import cloudpickle
- DEFAULT_ENV = 'cloudpickle'
+ from joblib.externals import cloudpickle # noqa: F401
+
+ DEFAULT_ENV = "cloudpickle"
except ImportError:
- DEFAULT_ENV = 'pickle'
-ENV_LOKY_PICKLER = os.environ.get('LOKY_PICKLER', DEFAULT_ENV)
+ # If cloudpickle is not present, fallback to pickle
+ DEFAULT_ENV = "pickle"
+
+ENV_LOKY_PICKLER = os.environ.get("LOKY_PICKLER", DEFAULT_ENV)
_LokyPickler = None
_loky_pickler_name = None
+
+
+def set_loky_pickler(loky_pickler=None):
+ global _LokyPickler, _loky_pickler_name
+
+ if loky_pickler is None:
+ loky_pickler = ENV_LOKY_PICKLER
+
+ loky_pickler_cls = None
+
+ # The default loky_pickler is cloudpickle
+ if loky_pickler in ["", None]:
+ loky_pickler = "cloudpickle"
+
+ if loky_pickler == _loky_pickler_name:
+ return
+
+ if loky_pickler == "cloudpickle":
+ from joblib.externals.cloudpickle import CloudPickler as loky_pickler_cls
+ else:
+ try:
+ from importlib import import_module
+
+ module_pickle = import_module(loky_pickler)
+ loky_pickler_cls = module_pickle.Pickler
+ except (ImportError, AttributeError) as e:
+ extra_info = (
+ "\nThis error occurred while setting loky_pickler to"
+ f" '{loky_pickler}', as required by the env variable "
+ "LOKY_PICKLER or the function set_loky_pickler."
+ )
+ e.args = (e.args[0] + extra_info,) + e.args[1:]
+ e.msg = e.args[0]
+ raise e
+
+ util.debug(
+ f"Using '{loky_pickler if loky_pickler else 'cloudpickle'}' for "
+ "serialization."
+ )
+
+ class CustomizablePickler(loky_pickler_cls):
+ _loky_pickler_cls = loky_pickler_cls
+
+ def _set_dispatch_table(self, dispatch_table):
+ for ancestor_class in self._loky_pickler_cls.mro():
+ dt_attribute = getattr(ancestor_class, "dispatch_table", None)
+ if isinstance(dt_attribute, types.MemberDescriptorType):
+ # Ancestor class (typically _pickle.Pickler) has a
+ # member_descriptor for its "dispatch_table" attribute. Use
+ # it to set the dispatch_table as a member instead of a
+ # dynamic attribute in the __dict__ of the instance,
+ # otherwise it will not be taken into account by the C
+ # implementation of the dump method if a subclass defines a
+ # class-level dispatch_table attribute as was done in
+ # cloudpickle 1.6.0:
+ # https://github.com/joblib/loky/pull/260
+ dt_attribute.__set__(self, dispatch_table)
+ break
+
+ # On top of member descriptor set, also use setattr such that code
+ # that directly access self.dispatch_table gets a consistent view
+ # of the same table.
+ self.dispatch_table = dispatch_table
+
+ def __init__(self, writer, reducers=None, protocol=HIGHEST_PROTOCOL):
+ loky_pickler_cls.__init__(self, writer, protocol=protocol)
+ if reducers is None:
+ reducers = {}
+
+ if hasattr(self, "dispatch_table"):
+                # Force a copy that we will update without mutating any
+                # class-level defined dispatch_table.
+ loky_dt = dict(self.dispatch_table)
+ else:
+ # Use standard reducers as bases
+ loky_dt = copyreg.dispatch_table.copy()
+
+ # Register loky specific reducers
+ loky_dt.update(_dispatch_table)
+
+ # Set the new dispatch table, taking care of the fact that we
+ # need to use the member_descriptor when we inherit from a
+ # subclass of the C implementation of the Pickler base class
+ # with an class level dispatch_table attribute.
+ self._set_dispatch_table(loky_dt)
+
+ # Register the reducers
+ for type, reduce_func in reducers.items():
+ self.register(type, reduce_func)
+
+ def register(self, type, reduce_func):
+ """Attach a reducer function to a given type in the dispatch table."""
+ self.dispatch_table[type] = reduce_func
+
+ _LokyPickler = CustomizablePickler
+ _loky_pickler_name = loky_pickler
+
+
+def get_loky_pickler_name():
+ global _loky_pickler_name
+ return _loky_pickler_name
+
+
+def get_loky_pickler():
+ global _LokyPickler
+ return _LokyPickler
+
+
+# Set it to its default value
set_loky_pickler()
def dump(obj, file, reducers=None, protocol=None):
"""Replacement for pickle.dump() using _LokyPickler."""
- pass
+ global _LokyPickler
+ _LokyPickler(file, reducers=reducers, protocol=protocol).dump(obj)
+
+
+def dumps(obj, reducers=None, protocol=None):
+ global _LokyPickler
+ buf = io.BytesIO()
+ dump(obj, buf, reducers=reducers, protocol=protocol)
+ return buf.getbuffer()
-__all__ = ['dump', 'dumps', 'loads', 'register', 'set_loky_pickler']
-if sys.platform == 'win32':
+
+__all__ = ["dump", "dumps", "loads", "register", "set_loky_pickler"]
+
+if sys.platform == "win32":
from multiprocessing.reduction import duplicate
- __all__ += ['duplicate']
+
+ __all__ += ["duplicate"]
diff --git a/joblib/externals/loky/backend/resource_tracker.py b/joblib/externals/loky/backend/resource_tracker.py
index 3550ef5..25204a7 100644
--- a/joblib/externals/loky/backend/resource_tracker.py
+++ b/joblib/externals/loky/backend/resource_tracker.py
@@ -1,3 +1,48 @@
+###############################################################################
+# Server process to keep track of unlinked resources, like folders and
+# semaphores and clean them.
+#
+# author: Thomas Moreau
+#
+# adapted from multiprocessing/semaphore_tracker.py (17/02/2017)
+# * include custom spawnv_passfds to start the process
+# * add some VERBOSE logging
+#
+# TODO: multiprocessing.resource_tracker was contributed to Python 3.8 so
+# once loky drops support for Python 3.7 it might be possible to stop
+# maintaining this loky-specific fork. As a consequence, it might also be
+# possible to stop maintaining the loky.backend.synchronize fork of
+# multiprocessing.synchronize.
+
+#
+# On Unix we run a server process which keeps track of unlinked
+# resources. The server ignores SIGINT and SIGTERM and reads from a
+# pipe. The resource_tracker implements a reference counting scheme: each time
+# a Python process anticipates the shared usage of a resource by another
+# process, it signals the resource_tracker of this shared usage, and in return,
+# the resource_tracker increments the resource's reference count by 1.
+# Similarly, when access to a resource is closed by a Python process, the
+# process notifies the resource_tracker by asking it to decrement the
+# resource's reference count by 1. When the reference count drops to 0, the
+# resource_tracker attempts to clean up the underlying resource.
+
+# Finally, every other process connected to the resource tracker has a copy of
+# the writable end of the pipe used to communicate with it, so the resource
+# tracker gets EOF when all other processes have exited. Then the
+# resource_tracker process unlinks any remaining leaked resources (with
+# reference count above 0)
+
+# For semaphores, this is important because the system only supports a limited
+# number of named semaphores, and they will not be automatically removed till
+# the next reboot. Without this resource tracker process, "killall python"
+# would probably leave unlinked semaphores.
+
+# Note that this behavior differs from CPython's resource_tracker, which only
+# implements a list of shared resources, and not a proper refcounting scheme.
+# Also, CPython's resource tracker will only attempt to clean up those shared
+# resources once all processes connected to the resource tracker have exited.
+
+
import os
import shutil
import sys
@@ -6,49 +51,151 @@ import warnings
import threading
from _multiprocessing import sem_unlink
from multiprocessing import util
+
from . import spawn
-if sys.platform == 'win32':
+
+if sys.platform == "win32":
import _winapi
import msvcrt
from multiprocessing.reduction import duplicate
-__all__ = ['ensure_running', 'register', 'unregister']
-_HAVE_SIGMASK = hasattr(signal, 'pthread_sigmask')
-_IGNORED_SIGNALS = signal.SIGINT, signal.SIGTERM
-_CLEANUP_FUNCS = {'folder': shutil.rmtree, 'file': os.unlink}
-if os.name == 'posix':
- _CLEANUP_FUNCS['semlock'] = sem_unlink
+
+
+__all__ = ["ensure_running", "register", "unregister"]
+
+_HAVE_SIGMASK = hasattr(signal, "pthread_sigmask")
+_IGNORED_SIGNALS = (signal.SIGINT, signal.SIGTERM)
+
+_CLEANUP_FUNCS = {"folder": shutil.rmtree, "file": os.unlink}
+
+if os.name == "posix":
+ _CLEANUP_FUNCS["semlock"] = sem_unlink
+
+
VERBOSE = False
class ResourceTracker:
-
def __init__(self):
self._lock = threading.Lock()
self._fd = None
self._pid = None
+ def getfd(self):
+ self.ensure_running()
+ return self._fd
+
def ensure_running(self):
"""Make sure that resource tracker process is running.
This can be run from any process. Usually a child process will use
the resource created by its parent."""
- pass
+ with self._lock:
+ if self._fd is not None:
+ # resource tracker was launched before, is it still running?
+ if self._check_alive():
+ # => still alive
+ return
+ # => dead, launch it again
+ os.close(self._fd)
+ if os.name == "posix":
+ try:
+ # At this point, the resource_tracker process has been
+ # killed or crashed. Let's remove the process entry
+ # from the process table to avoid zombie processes.
+ os.waitpid(self._pid, 0)
+ except OSError:
+ # The process was terminated or is a child from an
+ # ancestor of the current process.
+ pass
+ self._fd = None
+ self._pid = None
+
+ warnings.warn(
+ "resource_tracker: process died unexpectedly, "
+ "relaunching. Some folders/sempahores might "
+ "leak."
+ )
+
+ fds_to_pass = []
+ try:
+ fds_to_pass.append(sys.stderr.fileno())
+ except Exception:
+ pass
+
+ r, w = os.pipe()
+ if sys.platform == "win32":
+ _r = duplicate(msvcrt.get_osfhandle(r), inheritable=True)
+ os.close(r)
+ r = _r
+
+ cmd = f"from {main.__module__} import main; main({r}, {VERBOSE})"
+ try:
+ fds_to_pass.append(r)
+                # process will outlive us, so no need to wait on pid
+ exe = spawn.get_executable()
+ args = [exe, *util._args_from_interpreter_flags(), "-c", cmd]
+ util.debug(f"launching resource tracker: {args}")
+ # bpo-33613: Register a signal mask that will block the
+ # signals. This signal mask will be inherited by the child
+ # that is going to be spawned and will protect the child from a
+ # race condition that can make the child die before it
+ # registers signal handlers for SIGINT and SIGTERM. The mask is
+ # unregistered after spawning the child.
+ try:
+ if _HAVE_SIGMASK:
+ signal.pthread_sigmask(
+ signal.SIG_BLOCK, _IGNORED_SIGNALS
+ )
+ pid = spawnv_passfds(exe, args, fds_to_pass)
+ finally:
+ if _HAVE_SIGMASK:
+ signal.pthread_sigmask(
+ signal.SIG_UNBLOCK, _IGNORED_SIGNALS
+ )
+ except BaseException:
+ os.close(w)
+ raise
+ else:
+ self._fd = w
+ self._pid = pid
+ finally:
+ if sys.platform == "win32":
+ _winapi.CloseHandle(r)
+ else:
+ os.close(r)
def _check_alive(self):
"""Check for the existence of the resource tracker process."""
- pass
+ try:
+ self._send("PROBE", "", "")
+ except BrokenPipeError:
+ return False
+ else:
+ return True
def register(self, name, rtype):
"""Register a named resource, and increment its refcount."""
- pass
+ self.ensure_running()
+ self._send("REGISTER", name, rtype)
def unregister(self, name, rtype):
"""Unregister a named resource with resource tracker."""
- pass
+ self.ensure_running()
+ self._send("UNREGISTER", name, rtype)
def maybe_unlink(self, name, rtype):
"""Decrement the refcount of a resource, and delete it if it hits 0"""
- pass
+ self.ensure_running()
+ self._send("MAYBE_UNLINK", name, rtype)
+
+ def _send(self, cmd, name, rtype):
+ if len(name) > 512:
+ # posix guarantees that writes to a pipe of less than PIPE_BUF
+ # bytes are atomic, and that PIPE_BUF >= 512
+ raise ValueError("name too long")
+ msg = f"{cmd}:{name}:{rtype}\n".encode("ascii")
+ nbytes = os.write(self._fd, msg)
+ assert nbytes == len(msg)
_resource_tracker = ResourceTracker()
@@ -61,4 +208,171 @@ getfd = _resource_tracker.getfd
def main(fd, verbose=0):
"""Run resource tracker."""
- pass
+ # protect the process from ^C and "killall python" etc
+ if verbose:
+ util.log_to_stderr(level=util.DEBUG)
+
+ signal.signal(signal.SIGINT, signal.SIG_IGN)
+ signal.signal(signal.SIGTERM, signal.SIG_IGN)
+
+ if _HAVE_SIGMASK:
+ signal.pthread_sigmask(signal.SIG_UNBLOCK, _IGNORED_SIGNALS)
+
+ for f in (sys.stdin, sys.stdout):
+ try:
+ f.close()
+ except Exception:
+ pass
+
+ if verbose:
+ util.debug("Main resource tracker is running")
+
+ registry = {rtype: {} for rtype in _CLEANUP_FUNCS.keys()}
+ try:
+ # keep track of registered/unregistered resources
+ if sys.platform == "win32":
+ fd = msvcrt.open_osfhandle(fd, os.O_RDONLY)
+ with open(fd, "rb") as f:
+ while True:
+ line = f.readline()
+ if line == b"": # EOF
+ break
+ try:
+ splitted = line.strip().decode("ascii").split(":")
+ # name can potentially contain separator symbols (for
+ # instance folders on Windows)
+ cmd, name, rtype = (
+ splitted[0],
+ ":".join(splitted[1:-1]),
+ splitted[-1],
+ )
+
+ if cmd == "PROBE":
+ continue
+
+ if rtype not in _CLEANUP_FUNCS:
+ raise ValueError(
+ f"Cannot register {name} for automatic cleanup: "
+ f"unknown resource type ({rtype}). Resource type "
+ "should be one of the following: "
+ f"{list(_CLEANUP_FUNCS.keys())}"
+ )
+
+ if cmd == "REGISTER":
+ if name not in registry[rtype]:
+ registry[rtype][name] = 1
+ else:
+ registry[rtype][name] += 1
+
+ if verbose:
+ util.debug(
+ "[ResourceTracker] incremented refcount of "
+ f"{rtype} {name} "
+ f"(current {registry[rtype][name]})"
+ )
+ elif cmd == "UNREGISTER":
+ del registry[rtype][name]
+ if verbose:
+ util.debug(
+ f"[ResourceTracker] unregister {name} {rtype}: "
+ f"registry({len(registry)})"
+ )
+ elif cmd == "MAYBE_UNLINK":
+ registry[rtype][name] -= 1
+ if verbose:
+ util.debug(
+ "[ResourceTracker] decremented refcount of "
+ f"{rtype} {name} "
+ f"(current {registry[rtype][name]})"
+ )
+
+ if registry[rtype][name] == 0:
+ del registry[rtype][name]
+ try:
+ if verbose:
+ util.debug(
+ f"[ResourceTracker] unlink {name}"
+ )
+ _CLEANUP_FUNCS[rtype](name)
+ except Exception as e:
+ warnings.warn(
+ f"resource_tracker: {name}: {e!r}"
+ )
+
+ else:
+ raise RuntimeError(f"unrecognized command {cmd!r}")
+ except BaseException:
+ try:
+ sys.excepthook(*sys.exc_info())
+ except BaseException:
+ pass
+ finally:
+ # all processes have terminated; cleanup any remaining resources
+ def _unlink_resources(rtype_registry, rtype):
+ if rtype_registry:
+ try:
+ warnings.warn(
+ "resource_tracker: There appear to be "
+ f"{len(rtype_registry)} leaked {rtype} objects to "
+ "clean up at shutdown"
+ )
+ except Exception:
+ pass
+ for name in rtype_registry:
+ # For some reason the process which created and registered this
+ # resource has failed to unregister it. Presumably it has
+ # died. We therefore clean it up.
+ try:
+ _CLEANUP_FUNCS[rtype](name)
+ if verbose:
+ util.debug(f"[ResourceTracker] unlink {name}")
+ except Exception as e:
+ warnings.warn(f"resource_tracker: {name}: {e!r}")
+
+ for rtype, rtype_registry in registry.items():
+ if rtype == "folder":
+ continue
+ else:
+ _unlink_resources(rtype_registry, rtype)
+
+ # The default cleanup routine for folders deletes everything inside
+    # those folders recursively (which can include other resources tracked
+    # by the resource tracker). To limit the risk of the resource tracker
+ # attempting to delete twice a resource (once as part of a tracked
+ # folder, and once as a resource), we delete the folders after all
+ # other resource types.
+ if "folder" in registry:
+ _unlink_resources(registry["folder"], "folder")
+
+ if verbose:
+ util.debug("resource tracker shut down")
+
+
+#
+# Start a program with only specified fds kept open
+#
+
+
+def spawnv_passfds(path, args, passfds):
+ passfds = sorted(passfds)
+ if sys.platform != "win32":
+ errpipe_read, errpipe_write = os.pipe()
+ try:
+ from .reduction import _mk_inheritable
+ from .fork_exec import fork_exec
+
+ _pass = [_mk_inheritable(fd) for fd in passfds]
+ return fork_exec(args, _pass)
+ finally:
+ os.close(errpipe_read)
+ os.close(errpipe_write)
+ else:
+ cmd = " ".join(f'"{x}"' for x in args)
+ try:
+ _, ht, pid, _ = _winapi.CreateProcess(
+ path, cmd, None, None, True, 0, None, None, None
+ )
+ _winapi.CloseHandle(ht)
+ except BaseException:
+ pass
+ return pid
diff --git a/joblib/externals/loky/backend/spawn.py b/joblib/externals/loky/backend/spawn.py
index aadb9e2..d011c39 100644
--- a/joblib/externals/loky/backend/spawn.py
+++ b/joblib/externals/loky/backend/spawn.py
@@ -1,31 +1,250 @@
+###############################################################################
+# Prepares and processes the data to setup the new process environment
+#
+# author: Thomas Moreau and Olivier Grisel
+#
+# adapted from multiprocessing/spawn.py (17/02/2017)
+# * Improve logging data
+#
import os
import sys
import runpy
import textwrap
import types
from multiprocessing import process, util
-if sys.platform != 'win32':
+
+
+if sys.platform != "win32":
WINEXE = False
WINSERVICE = False
else:
import msvcrt
from multiprocessing.reduction import duplicate
- WINEXE = sys.platform == 'win32' and getattr(sys, 'frozen', False)
- WINSERVICE = sys.executable.lower().endswith('pythonservice.exe')
+
+ WINEXE = sys.platform == "win32" and getattr(sys, "frozen", False)
+ WINSERVICE = sys.executable.lower().endswith("pythonservice.exe")
+
if WINSERVICE:
- _python_exe = os.path.join(sys.exec_prefix, 'python.exe')
+ _python_exe = os.path.join(sys.exec_prefix, "python.exe")
else:
_python_exe = sys.executable
+def get_executable():
+ return _python_exe
+
+
+def _check_not_importing_main():
+ if getattr(process.current_process(), "_inheriting", False):
+ raise RuntimeError(
+ textwrap.dedent(
+ """\
+ An attempt has been made to start a new process before the
+ current process has finished its bootstrapping phase.
+
+ This probably means that you are not using fork to start your
+ child processes and you have forgotten to use the proper idiom
+ in the main module:
+
+ if __name__ == '__main__':
+ freeze_support()
+ ...
+
+ The "freeze_support()" line can be omitted if the program
+ is not going to be frozen to produce an executable."""
+ )
+ )
+
+
def get_preparation_data(name, init_main_module=True):
"""Return info about parent needed by child to unpickle process object."""
- pass
+ _check_not_importing_main()
+ d = dict(
+ log_to_stderr=util._log_to_stderr,
+ authkey=bytes(process.current_process().authkey),
+ name=name,
+ sys_argv=sys.argv,
+ orig_dir=process.ORIGINAL_DIR,
+ dir=os.getcwd(),
+ )
+ # Send sys_path and make sure the current directory will not be changed
+ d["sys_path"] = [p if p != "" else process.ORIGINAL_DIR for p in sys.path]
+ # Make sure to pass the information if the multiprocessing logger is active
+ if util._logger is not None:
+ d["log_level"] = util._logger.getEffectiveLevel()
+ if util._logger.handlers:
+ h = util._logger.handlers[0]
+ d["log_fmt"] = h.formatter._fmt
+
+ # Tell the child how to communicate with the resource_tracker
+ from .resource_tracker import _resource_tracker
+
+ _resource_tracker.ensure_running()
+ d["tracker_args"] = {"pid": _resource_tracker._pid}
+ if sys.platform == "win32":
+ d["tracker_args"]["fh"] = msvcrt.get_osfhandle(_resource_tracker._fd)
+ else:
+ d["tracker_args"]["fd"] = _resource_tracker._fd
+
+ if sys.version_info >= (3, 8) and os.name == "posix":
+ # joblib/loky#242: allow loky processes to retrieve the resource
+ # tracker of their parent in case the child processes depickles
+ # shared_memory objects, that are still tracked by multiprocessing's
+ # resource_tracker by default.
+ # XXX: this is a workaround that may be error prone: in the future, it
+ # would be better to have loky subclass multiprocessing's shared_memory
+ # to force registration of shared_memory segments via loky's
+ # resource_tracker.
+ from multiprocessing.resource_tracker import (
+ _resource_tracker as mp_resource_tracker,
+ )
+
+ # multiprocessing's resource_tracker must be running before loky
+    # process is created (otherwise the child won't be able to use it if it
+ # is created later on)
+ mp_resource_tracker.ensure_running()
+ d["mp_tracker_args"] = {
+ "fd": mp_resource_tracker._fd,
+ "pid": mp_resource_tracker._pid,
+ }
+
+ # Figure out whether to initialise main in the subprocess as a module
+ # or through direct execution (or to leave it alone entirely)
+ if init_main_module:
+ main_module = sys.modules["__main__"]
+ try:
+ main_mod_name = getattr(main_module.__spec__, "name", None)
+ except BaseException:
+ main_mod_name = None
+ if main_mod_name is not None:
+ d["init_main_from_name"] = main_mod_name
+ elif sys.platform != "win32" or (not WINEXE and not WINSERVICE):
+ main_path = getattr(main_module, "__file__", None)
+ if main_path is not None:
+ if (
+ not os.path.isabs(main_path)
+ and process.ORIGINAL_DIR is not None
+ ):
+ main_path = os.path.join(process.ORIGINAL_DIR, main_path)
+ d["init_main_from_path"] = os.path.normpath(main_path)
+
+ return d
+
+
+#
+# Prepare current process
+#
old_main_modules = []
def prepare(data, parent_sentinel=None):
"""Try to get current process ready to unpickle process object."""
- pass
+ if "name" in data:
+ process.current_process().name = data["name"]
+
+ if "authkey" in data:
+ process.current_process().authkey = data["authkey"]
+
+ if "log_to_stderr" in data and data["log_to_stderr"]:
+ util.log_to_stderr()
+
+ if "log_level" in data:
+ util.get_logger().setLevel(data["log_level"])
+
+ if "log_fmt" in data:
+ import logging
+
+ util.get_logger().handlers[0].setFormatter(
+ logging.Formatter(data["log_fmt"])
+ )
+
+ if "sys_path" in data:
+ sys.path = data["sys_path"]
+
+ if "sys_argv" in data:
+ sys.argv = data["sys_argv"]
+
+ if "dir" in data:
+ os.chdir(data["dir"])
+
+ if "orig_dir" in data:
+ process.ORIGINAL_DIR = data["orig_dir"]
+
+ if "mp_tracker_args" in data:
+ from multiprocessing.resource_tracker import (
+ _resource_tracker as mp_resource_tracker,
+ )
+
+ mp_resource_tracker._fd = data["mp_tracker_args"]["fd"]
+ mp_resource_tracker._pid = data["mp_tracker_args"]["pid"]
+ if "tracker_args" in data:
+ from .resource_tracker import _resource_tracker
+
+ _resource_tracker._pid = data["tracker_args"]["pid"]
+ if sys.platform == "win32":
+ handle = data["tracker_args"]["fh"]
+ handle = duplicate(handle, source_process=parent_sentinel)
+ _resource_tracker._fd = msvcrt.open_osfhandle(handle, os.O_RDONLY)
+ else:
+ _resource_tracker._fd = data["tracker_args"]["fd"]
+
+ if "init_main_from_name" in data:
+ _fixup_main_from_name(data["init_main_from_name"])
+ elif "init_main_from_path" in data:
+ _fixup_main_from_path(data["init_main_from_path"])
+
+
+# Multiprocessing module helpers to fix up the main module in
+# spawned subprocesses
+def _fixup_main_from_name(mod_name):
+ # __main__.py files for packages, directories, zip archives, etc, run
+ # their "main only" code unconditionally, so we don't even try to
+ # populate anything in __main__, nor do we make any changes to
+ # __main__ attributes
+ current_main = sys.modules["__main__"]
+ if mod_name == "__main__" or mod_name.endswith(".__main__"):
+ return
+
+ # If this process was forked, __main__ may already be populated
+ if getattr(current_main.__spec__, "name", None) == mod_name:
+ return
+
+ # Otherwise, __main__ may contain some non-main code where we need to
+ # support unpickling it properly. We rerun it as __mp_main__ and make
+ # the normal __main__ an alias to that
+ old_main_modules.append(current_main)
+ main_module = types.ModuleType("__mp_main__")
+ main_content = runpy.run_module(
+ mod_name, run_name="__mp_main__", alter_sys=True
+ )
+ main_module.__dict__.update(main_content)
+ sys.modules["__main__"] = sys.modules["__mp_main__"] = main_module
+
+
+def _fixup_main_from_path(main_path):
+ # If this process was forked, __main__ may already be populated
+ current_main = sys.modules["__main__"]
+
+ # Unfortunately, the main ipython launch script historically had no
+ # "if __name__ == '__main__'" guard, so we work around that
+ # by treating it like a __main__.py file
+ # See https://github.com/ipython/ipython/issues/4698
+ main_name = os.path.splitext(os.path.basename(main_path))[0]
+ if main_name == "ipython":
+ return
+
+ # Otherwise, if __file__ already has the setting we expect,
+ # there's nothing more to do
+ if getattr(current_main, "__file__", None) == main_path:
+ return
+
+ # If the parent process has sent a path through rather than a module
+ # name we assume it is an executable script that may contain
+ # non-main code that needs to be executed
+ old_main_modules.append(current_main)
+ main_module = types.ModuleType("__mp_main__")
+ main_content = runpy.run_path(main_path, run_name="__mp_main__")
+ main_module.__dict__.update(main_content)
+ sys.modules["__main__"] = sys.modules["__mp_main__"] = main_module
diff --git a/joblib/externals/loky/backend/synchronize.py b/joblib/externals/loky/backend/synchronize.py
index 3136b2c..18db3e3 100644
--- a/joblib/externals/loky/backend/synchronize.py
+++ b/joblib/externals/loky/backend/synchronize.py
@@ -1,3 +1,17 @@
+###############################################################################
+# Synchronization primitives based on our SemLock implementation
+#
+# author: Thomas Moreau and Olivier Grisel
+#
+# adapted from multiprocessing/synchronize.py (17/02/2017)
+# * Remove ctx argument for compatibility reason
+# * Registers a cleanup function with the loky resource_tracker to remove the
+# semaphore when the process dies instead.
+#
+# TODO: investigate which Python version is required to be able to use
+# multiprocessing.resource_tracker and therefore multiprocessing.synchronize
+# instead of a loky-specific fork.
+
import os
import sys
import tempfile
@@ -6,50 +20,100 @@ import _multiprocessing
from time import time as _time
from multiprocessing import process, util
from multiprocessing.context import assert_spawning
+
from . import resource_tracker
-__all__ = ['Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Condition',
- 'Event']
+
+__all__ = [
+ "Lock",
+ "RLock",
+ "Semaphore",
+ "BoundedSemaphore",
+ "Condition",
+ "Event",
+]
+# Try to import the mp.synchronize module cleanly, if it fails
+# raise ImportError for platforms lacking a working sem_open implementation.
+# See issue 3770
try:
from _multiprocessing import SemLock as _SemLock
from _multiprocessing import sem_unlink
except ImportError:
raise ImportError(
- 'This platform lacks a functioning sem_open implementation, therefore, the required synchronization primitives needed will not function, see issue 3770.'
- )
+ "This platform lacks a functioning sem_open"
+ " implementation, therefore, the required"
+ " synchronization primitives needed will not"
+ " function, see issue 3770."
+ )
+
+#
+# Constants
+#
+
RECURSIVE_MUTEX, SEMAPHORE = range(2)
SEM_VALUE_MAX = _multiprocessing.SemLock.SEM_VALUE_MAX
+#
+# Base class for semaphores and mutexes; wraps `_multiprocessing.SemLock`
+#
+
+
class SemLock:
+
_rand = tempfile._RandomNameSequence()
def __init__(self, kind, value, maxvalue, name=None):
+ # unlink_now is only used on win32 or when we are using fork.
unlink_now = False
if name is None:
+ # Try to find an unused name for the SemLock instance.
for _ in range(100):
try:
- self._semlock = _SemLock(kind, value, maxvalue, SemLock
- ._make_name(), unlink_now)
- except FileExistsError:
+ self._semlock = _SemLock(
+ kind, value, maxvalue, SemLock._make_name(), unlink_now
+ )
+ except FileExistsError: # pragma: no cover
pass
else:
break
- else:
- raise FileExistsError('cannot find name for semaphore')
+ else: # pragma: no cover
+ raise FileExistsError("cannot find name for semaphore")
else:
self._semlock = _SemLock(kind, value, maxvalue, name, unlink_now)
self.name = name
util.debug(
- f'created semlock with handle {self._semlock.handle} and name "{self.name}"'
- )
+ f"created semlock with handle {self._semlock.handle} and name "
+ f'"{self.name}"'
+ )
+
self._make_methods()
def _after_fork(obj):
obj._semlock._after_fork()
+
util.register_after_fork(self, _after_fork)
- resource_tracker.register(self._semlock.name, 'semlock')
- util.Finalize(self, SemLock._cleanup, (self._semlock.name,),
- exitpriority=0)
+
+ # When the object is garbage collected or the
+ # process shuts down we unlink the semaphore name
+ resource_tracker.register(self._semlock.name, "semlock")
+ util.Finalize(
+ self, SemLock._cleanup, (self._semlock.name,), exitpriority=0
+ )
+
+ @staticmethod
+ def _cleanup(name):
+ try:
+ sem_unlink(name)
+ except FileNotFoundError:
+ # Already unlinked, possibly by user code: ignore and make sure to
+ # unregister the semaphore from the resource tracker.
+ pass
+ finally:
+ resource_tracker.unregister(name, "semlock")
+
+ def _make_methods(self):
+ self.acquire = self._semlock.acquire
+ self.release = self._semlock.release
def __enter__(self):
return self._semlock.acquire()
@@ -61,31 +125,49 @@ class SemLock:
assert_spawning(self)
sl = self._semlock
h = sl.handle
- return h, sl.kind, sl.maxvalue, sl.name
+ return (h, sl.kind, sl.maxvalue, sl.name)
def __setstate__(self, state):
self._semlock = _SemLock._rebuild(*state)
util.debug(
f'recreated blocker with handle {state[0]!r} and name "{state[3]}"'
- )
+ )
self._make_methods()
+ @staticmethod
+ def _make_name():
+ # OSX does not support long names for semaphores
+ return f"/loky-{os.getpid()}-{next(SemLock._rand)}"
-class Semaphore(SemLock):
+#
+# Semaphore
+#
+
+
+class Semaphore(SemLock):
def __init__(self, value=1):
SemLock.__init__(self, SEMAPHORE, value, SEM_VALUE_MAX)
+ def get_value(self):
+ if sys.platform == "darwin":
+ raise NotImplementedError("OSX does not implement sem_getvalue")
+ return self._semlock._get_value()
+
def __repr__(self):
try:
value = self._semlock._get_value()
except Exception:
- value = 'unknown'
- return f'<{self.__class__.__name__}(value={value})>'
+ value = "unknown"
+ return f"<{self.__class__.__name__}(value={value})>"
-class BoundedSemaphore(Semaphore):
+#
+# Bounded semaphore
+#
+
+class BoundedSemaphore(Semaphore):
def __init__(self, value=1):
SemLock.__init__(self, SEMAPHORE, value, value)
@@ -93,14 +175,19 @@ class BoundedSemaphore(Semaphore):
try:
value = self._semlock._get_value()
except Exception:
- value = 'unknown'
+ value = "unknown"
return (
- f'<{self.__class__.__name__}(value={value}, maxvalue={self._semlock.maxvalue})>'
- )
+ f"<{self.__class__.__name__}(value={value}, "
+ f"maxvalue={self._semlock.maxvalue})>"
+ )
-class Lock(SemLock):
+#
+# Non-recursive lock
+#
+
+class Lock(SemLock):
def __init__(self):
super().__init__(SEMAPHORE, 1, 1)
@@ -108,21 +195,25 @@ class Lock(SemLock):
try:
if self._semlock._is_mine():
name = process.current_process().name
- if threading.current_thread().name != 'MainThread':
- name = f'{name}|{threading.current_thread().name}'
+ if threading.current_thread().name != "MainThread":
+ name = f"{name}|{threading.current_thread().name}"
elif self._semlock._get_value() == 1:
- name = 'None'
+ name = "None"
elif self._semlock._count() > 0:
- name = 'SomeOtherThread'
+ name = "SomeOtherThread"
else:
- name = 'SomeOtherProcess'
+ name = "SomeOtherProcess"
except Exception:
- name = 'unknown'
- return f'<{self.__class__.__name__}(owner={name})>'
+ name = "unknown"
+ return f"<{self.__class__.__name__}(owner={name})>"
-class RLock(SemLock):
+#
+# Recursive lock
+#
+
+class RLock(SemLock):
def __init__(self):
super().__init__(RECURSIVE_MUTEX, 1, 1)
@@ -130,22 +221,26 @@ class RLock(SemLock):
try:
if self._semlock._is_mine():
name = process.current_process().name
- if threading.current_thread().name != 'MainThread':
- name = f'{name}|{threading.current_thread().name}'
+ if threading.current_thread().name != "MainThread":
+ name = f"{name}|{threading.current_thread().name}"
count = self._semlock._count()
elif self._semlock._get_value() == 1:
- name, count = 'None', 0
+ name, count = "None", 0
elif self._semlock._count() > 0:
- name, count = 'SomeOtherThread', 'nonzero'
+ name, count = "SomeOtherThread", "nonzero"
else:
- name, count = 'SomeOtherProcess', 'nonzero'
+ name, count = "SomeOtherProcess", "nonzero"
except Exception:
- name, count = 'unknown', 'unknown'
- return f'<{self.__class__.__name__}({name}, {count})>'
+ name, count = "unknown", "unknown"
+ return f"<{self.__class__.__name__}({name}, {count})>"
-class Condition:
+#
+# Condition variable
+#
+
+class Condition:
def __init__(self, lock=None):
self._lock = lock or RLock()
self._sleeping_count = Semaphore(0)
@@ -155,12 +250,20 @@ class Condition:
def __getstate__(self):
assert_spawning(self)
- return (self._lock, self._sleeping_count, self._woken_count, self.
- _wait_semaphore)
+ return (
+ self._lock,
+ self._sleeping_count,
+ self._woken_count,
+ self._wait_semaphore,
+ )
def __setstate__(self, state):
- (self._lock, self._sleeping_count, self._woken_count, self.
- _wait_semaphore) = state
+ (
+ self._lock,
+ self._sleeping_count,
+ self._woken_count,
+ self._wait_semaphore,
+ ) = state
self._make_methods()
def __enter__(self):
@@ -169,17 +272,138 @@ class Condition:
def __exit__(self, *args):
return self._lock.__exit__(*args)
+ def _make_methods(self):
+ self.acquire = self._lock.acquire
+ self.release = self._lock.release
+
def __repr__(self):
try:
- num_waiters = self._sleeping_count._semlock._get_value(
- ) - self._woken_count._semlock._get_value()
+ num_waiters = (
+ self._sleeping_count._semlock._get_value()
+ - self._woken_count._semlock._get_value()
+ )
except Exception:
- num_waiters = 'unknown'
- return f'<{self.__class__.__name__}({self._lock}, {num_waiters})>'
+ num_waiters = "unknown"
+ return f"<{self.__class__.__name__}({self._lock}, {num_waiters})>"
+ def wait(self, timeout=None):
+ assert (
+ self._lock._semlock._is_mine()
+ ), "must acquire() condition before using wait()"
-class Event:
+ # indicate that this thread is going to sleep
+ self._sleeping_count.release()
+
+ # release lock
+ count = self._lock._semlock._count()
+ for _ in range(count):
+ self._lock.release()
+
+ try:
+ # wait for notification or timeout
+ return self._wait_semaphore.acquire(True, timeout)
+ finally:
+ # indicate that this thread has woken
+ self._woken_count.release()
+
+ # reacquire lock
+ for _ in range(count):
+ self._lock.acquire()
+
+ def notify(self):
+ assert self._lock._semlock._is_mine(), "lock is not owned"
+ assert not self._wait_semaphore.acquire(False)
+
+ # to take account of timeouts since last notify() we subtract
+ # woken_count from sleeping_count and rezero woken_count
+ while self._woken_count.acquire(False):
+ res = self._sleeping_count.acquire(False)
+ assert res
+
+ if self._sleeping_count.acquire(False): # try grabbing a sleeper
+ self._wait_semaphore.release() # wake up one sleeper
+ self._woken_count.acquire() # wait for the sleeper to wake
+
+ # rezero _wait_semaphore in case a timeout just happened
+ self._wait_semaphore.acquire(False)
+
+ def notify_all(self):
+ assert self._lock._semlock._is_mine(), "lock is not owned"
+ assert not self._wait_semaphore.acquire(False)
+
+ # to take account of timeouts since last notify*() we subtract
+ # woken_count from sleeping_count and rezero woken_count
+ while self._woken_count.acquire(False):
+ res = self._sleeping_count.acquire(False)
+ assert res
+
+ sleepers = 0
+ while self._sleeping_count.acquire(False):
+ self._wait_semaphore.release() # wake up one sleeper
+ sleepers += 1
+
+ if sleepers:
+ for _ in range(sleepers):
+ self._woken_count.acquire() # wait for a sleeper to wake
+
+ # rezero wait_semaphore in case some timeouts just happened
+ while self._wait_semaphore.acquire(False):
+ pass
+
+ def wait_for(self, predicate, timeout=None):
+ result = predicate()
+ if result:
+ return result
+ if timeout is not None:
+ endtime = _time() + timeout
+ else:
+ endtime = None
+ waittime = None
+ while not result:
+ if endtime is not None:
+ waittime = endtime - _time()
+ if waittime <= 0:
+ break
+ self.wait(waittime)
+ result = predicate()
+ return result
+
+#
+# Event
+#
+
+
+class Event:
def __init__(self):
self._cond = Condition(Lock())
self._flag = Semaphore(0)
+
+ def is_set(self):
+ with self._cond:
+ if self._flag.acquire(False):
+ self._flag.release()
+ return True
+ return False
+
+ def set(self):
+ with self._cond:
+ self._flag.acquire(False)
+ self._flag.release()
+ self._cond.notify_all()
+
+ def clear(self):
+ with self._cond:
+ self._flag.acquire(False)
+
+ def wait(self, timeout=None):
+ with self._cond:
+ if self._flag.acquire(False):
+ self._flag.release()
+ else:
+ self._cond.wait(timeout)
+
+ if self._flag.acquire(False):
+ self._flag.release()
+ return True
+ return False
diff --git a/joblib/externals/loky/backend/utils.py b/joblib/externals/loky/backend/utils.py
index 3a17859..aa089f7 100644
--- a/joblib/externals/loky/backend/utils.py
+++ b/joblib/externals/loky/backend/utils.py
@@ -6,6 +6,7 @@ import signal
import warnings
import subprocess
import traceback
+
try:
import psutil
except ImportError:
@@ -14,17 +15,115 @@ except ImportError:
def kill_process_tree(process, use_psutil=True):
"""Terminate process and its descendants with SIGKILL"""
- pass
+ if use_psutil and psutil is not None:
+ _kill_process_tree_with_psutil(process)
+ else:
+ _kill_process_tree_without_psutil(process)
+
+
+def recursive_terminate(process, use_psutil=True):
+ warnings.warn(
+ "recursive_terminate is deprecated in loky 3.2, use kill_process_tree"
+ "instead",
+ DeprecationWarning,
+ )
+ kill_process_tree(process, use_psutil=use_psutil)
+
+
+def _kill_process_tree_with_psutil(process):
+ try:
+ descendants = psutil.Process(process.pid).children(recursive=True)
+ except psutil.NoSuchProcess:
+ return
+
+ # Kill the descendants in reverse order to avoid killing the parents before
+ # the descendant in cases where there are more processes nested.
+ for descendant in descendants[::-1]:
+ try:
+ descendant.kill()
+ except psutil.NoSuchProcess:
+ pass
+
+ try:
+ psutil.Process(process.pid).kill()
+ except psutil.NoSuchProcess:
+ pass
+ process.join()
def _kill_process_tree_without_psutil(process):
"""Terminate a process and its descendants."""
- pass
+ try:
+ if sys.platform == "win32":
+ _windows_taskkill_process_tree(process.pid)
+ else:
+ _posix_recursive_kill(process.pid)
+ except Exception: # pragma: no cover
+ details = traceback.format_exc()
+ warnings.warn(
+ "Failed to kill subprocesses on this platform. Please install"
+ "psutil: https://github.com/giampaolo/psutil\n"
+ f"Details:\n{details}"
+ )
+ # In case we cannot introspect or kill the descendants, we fall back to
+ # only killing the main process.
+ #
+ # Note: on Windows, process.kill() is an alias for process.terminate()
+ # which in turns calls the Win32 API function TerminateProcess().
+ process.kill()
+ process.join()
+
+
+def _windows_taskkill_process_tree(pid):
+    # On windows, the taskkill function with option `/T` terminates a given
+ # process pid and its children.
+ try:
+ subprocess.check_output(
+ ["taskkill", "/F", "/T", "/PID", str(pid)], stderr=None
+ )
+ except subprocess.CalledProcessError as e:
+ # In Windows, taskkill returns 128, 255 for no process found.
+ if e.returncode not in [128, 255]:
+ # Let's raise to let the caller log the error details in a
+ # warning and only kill the root process.
+ raise # pragma: no cover
+
+
+def _kill(pid):
+ # Not all systems (e.g. Windows) have a SIGKILL, but the C specification
+ # mandates a SIGTERM signal. While Windows is handled specifically above,
+ # let's try to be safe for other hypothetic platforms that only have
+ # SIGTERM without SIGKILL.
+ kill_signal = getattr(signal, "SIGKILL", signal.SIGTERM)
+ try:
+ os.kill(pid, kill_signal)
+ except OSError as e:
+ # if OSError is raised with [Errno 3] no such process, the process
+ # is already terminated, else, raise the error and let the top
+ # level function raise a warning and retry to kill the process.
+ if e.errno != errno.ESRCH:
+ raise # pragma: no cover
def _posix_recursive_kill(pid):
"""Recursively kill the descendants of a process before killing it."""
- pass
+ try:
+ children_pids = subprocess.check_output(
+ ["pgrep", "-P", str(pid)], stderr=None, text=True
+ )
+ except subprocess.CalledProcessError as e:
+        # `pgrep -P` returns 1 when no child process has been found
+ if e.returncode == 1:
+ children_pids = ""
+ else:
+ raise # pragma: no cover
+
+ # Decode the result, split the cpid and remove the trailing line
+ for cpid in children_pids.splitlines():
+ cpid = int(cpid)
+ _posix_recursive_kill(cpid)
+
+ _kill(pid)
def get_exitcodes_terminated_worker(processes):
@@ -33,9 +132,50 @@ def get_exitcodes_terminated_worker(processes):
If necessary, wait (up to .25s) for the system to correctly set the
exitcode of one terminated worker.
"""
- pass
+ patience = 5
+
+ # Catch the exitcode of the terminated workers. There should at least be
+ # one. If not, wait a bit for the system to correctly set the exitcode of
+ # the terminated worker.
+ exitcodes = [
+ p.exitcode for p in list(processes.values()) if p.exitcode is not None
+ ]
+ while not exitcodes and patience > 0:
+ patience -= 1
+ exitcodes = [
+ p.exitcode
+ for p in list(processes.values())
+ if p.exitcode is not None
+ ]
+ time.sleep(0.05)
+
+ return _format_exitcodes(exitcodes)
def _format_exitcodes(exitcodes):
"""Format a list of exit code with names of the signals if possible"""
- pass
+ str_exitcodes = [
+ f"{_get_exitcode_name(e)}({e})" for e in exitcodes if e is not None
+ ]
+ return "{" + ", ".join(str_exitcodes) + "}"
+
+
+def _get_exitcode_name(exitcode):
+ if sys.platform == "win32":
+        # The exit codes are unreliable on windows (see bpo-31863).
+ # For this case, return UNKNOWN
+ return "UNKNOWN"
+
+ if exitcode < 0:
+ try:
+ import signal
+
+ return signal.Signals(-exitcode).name
+ except ValueError:
+ return "UNKNOWN"
+ elif exitcode != 255:
+        # The exit codes are unreliable on forkserver where 255 is always returned
+ # (see bpo-30589). For this case, return UNKNOWN
+ return "EXIT"
+
+ return "UNKNOWN"
diff --git a/joblib/externals/loky/cloudpickle_wrapper.py b/joblib/externals/loky/cloudpickle_wrapper.py
index 808ade4..099debc 100644
--- a/joblib/externals/loky/cloudpickle_wrapper.py
+++ b/joblib/externals/loky/cloudpickle_wrapper.py
@@ -1,11 +1,12 @@
import inspect
from functools import partial
from joblib.externals.cloudpickle import dumps, loads
+
+
WRAP_CACHE = {}
class CloudpickledObjectWrapper:
-
def __init__(self, obj, keep_wrapper=False):
self._obj = obj
self._keep_wrapper = keep_wrapper
@@ -14,20 +15,67 @@ class CloudpickledObjectWrapper:
_pickled_object = dumps(self._obj)
if not self._keep_wrapper:
return loads, (_pickled_object,)
+
return _reconstruct_wrapper, (_pickled_object, self._keep_wrapper)
def __getattr__(self, attr):
- if attr not in ['_obj', '_keep_wrapper']:
+        # Ensure that the wrapped object can be used seamlessly as the
+ # previous object.
+ if attr not in ["_obj", "_keep_wrapper"]:
return getattr(self._obj, attr)
return getattr(self, attr)
+# Make sure the wrapped object conserves the callable property
class CallableObjectWrapper(CloudpickledObjectWrapper):
-
def __call__(self, *args, **kwargs):
return self._obj(*args, **kwargs)
+def _wrap_non_picklable_objects(obj, keep_wrapper):
+ if callable(obj):
+ return CallableObjectWrapper(obj, keep_wrapper=keep_wrapper)
+ return CloudpickledObjectWrapper(obj, keep_wrapper=keep_wrapper)
+
+
+def _reconstruct_wrapper(_pickled_object, keep_wrapper):
+ obj = loads(_pickled_object)
+ return _wrap_non_picklable_objects(obj, keep_wrapper)
+
+
+def _wrap_objects_when_needed(obj):
+ # Function to introspect an object and decide if it should be wrapped or
+ # not.
+ need_wrap = "__main__" in getattr(obj, "__module__", "")
+ if isinstance(obj, partial):
+ return partial(
+ _wrap_objects_when_needed(obj.func),
+ *[_wrap_objects_when_needed(a) for a in obj.args],
+ **{
+ k: _wrap_objects_when_needed(v)
+ for k, v in obj.keywords.items()
+ }
+ )
+ if callable(obj):
+ # Need wrap if the object is a function defined in a local scope of
+ # another function.
+ func_code = getattr(obj, "__code__", "")
+ need_wrap |= getattr(func_code, "co_flags", 0) & inspect.CO_NESTED
+
+ # Need wrap if the obj is a lambda expression
+ func_name = getattr(obj, "__name__", "")
+ need_wrap |= "<lambda>" in func_name
+
+ if not need_wrap:
+ return obj
+
+ wrapped_obj = WRAP_CACHE.get(obj)
+ if wrapped_obj is None:
+ wrapped_obj = _wrap_non_picklable_objects(obj, keep_wrapper=False)
+ WRAP_CACHE[obj] = wrapped_obj
+ return wrapped_obj
+
+
def wrap_non_picklable_objects(obj, keep_wrapper=True):
"""Wrapper for non-picklable object to use cloudpickle to serialize them.
@@ -37,4 +85,18 @@ def wrap_non_picklable_objects(obj, keep_wrapper=True):
objects in the main scripts and to implement __reduce__ functions for
complex classes.
"""
- pass
+ # If obj is a class, create a CloudpickledClassWrapper which instantiates
+ # the object internally and wrap it directly in a CloudpickledObjectWrapper
+ if inspect.isclass(obj):
+
+ class CloudpickledClassWrapper(CloudpickledObjectWrapper):
+ def __init__(self, *args, **kwargs):
+ self._obj = obj(*args, **kwargs)
+ self._keep_wrapper = keep_wrapper
+
+ CloudpickledClassWrapper.__name__ = obj.__name__
+ return CloudpickledClassWrapper
+
+ # If obj is an instance of a class, just wrap it in a regular
+ # CloudpickledObjectWrapper
+ return _wrap_non_picklable_objects(obj, keep_wrapper=keep_wrapper)
diff --git a/joblib/externals/loky/initializers.py b/joblib/externals/loky/initializers.py
index 81c0c79..aea0e56 100644
--- a/joblib/externals/loky/initializers.py
+++ b/joblib/externals/loky/initializers.py
@@ -3,7 +3,30 @@ import warnings
def _viztracer_init(init_kwargs):
"""Initialize viztracer's profiler in worker processes"""
- pass
+ from viztracer import VizTracer
+
+ tracer = VizTracer(**init_kwargs)
+ tracer.register_exit()
+ tracer.start()
+
+
+def _make_viztracer_initializer_and_initargs():
+ try:
+ import viztracer
+
+ tracer = viztracer.get_tracer()
+ if tracer is not None and getattr(tracer, "enable", False):
+ # Profiler is active: introspect its configuration to
+ # initialize the workers with the same configuration.
+ return _viztracer_init, (tracer.init_kwargs,)
+ except ImportError:
+ # viztracer is not installed: nothing to do
+ pass
+ except Exception as e:
+ # In case viztracer's API evolve, we do not want to crash loky but
+ # we want to know about it to be able to update loky.
+ warnings.warn(f"Unable to introspect viztracer state: {e}")
+ return None, ()
class _ChainedInitializer:
@@ -26,4 +49,32 @@ def _chain_initializers(initializer_and_args):
If some initializers are None, they are filtered out.
"""
- pass
+ filtered_initializers = []
+ filtered_initargs = []
+ for initializer, initargs in initializer_and_args:
+ if initializer is not None:
+ filtered_initializers.append(initializer)
+ filtered_initargs.append(initargs)
+
+ if not filtered_initializers:
+ return None, ()
+ elif len(filtered_initializers) == 1:
+ return filtered_initializers[0], filtered_initargs[0]
+ else:
+ return _ChainedInitializer(filtered_initializers), filtered_initargs
+
+
+def _prepare_initializer(initializer, initargs):
+ if initializer is not None and not callable(initializer):
+ raise TypeError(
+ f"initializer must be a callable, got: {initializer!r}"
+ )
+
+ # Introspect runtime to determine if we need to propagate the viztracer
+ # profiler information to the workers:
+ return _chain_initializers(
+ [
+ (initializer, initargs),
+ _make_viztracer_initializer_and_initargs(),
+ ]
+ )
diff --git a/joblib/externals/loky/process_executor.py b/joblib/externals/loky/process_executor.py
index c68582c..3040719 100644
--- a/joblib/externals/loky/process_executor.py
+++ b/joblib/externals/loky/process_executor.py
@@ -1,3 +1,17 @@
+###############################################################################
+# Re-implementation of the ProcessPoolExecutor more robust to faults
+#
+# author: Thomas Moreau and Olivier Grisel
+#
+# adapted from concurrent/futures/process_pool_executor.py (17/02/2017)
+# * Add an extra management thread to detect executor_manager_thread failures,
+# * Improve the shutdown process to avoid deadlocks,
+# * Add timeout for workers,
+# * More robust pickling process.
+#
+# Copyright 2009 Brian Quinlan. All Rights Reserved.
+# Licensed to PSF under a Contributor Agreement.
+
"""Implements ProcessPoolExecutor.
The follow diagram and text describe the data-flow through the system:
@@ -37,9 +51,13 @@ Local worker thread:
Process #1..n:
- reads _CallItems from "Call Q", executes the calls, and puts the resulting
- _ResultItems in "Result Q\"
+ _ResultItems in "Result Q"
"""
-__author__ = 'Thomas Moreau (thomas.moreau.2010@gmail.com)'
+
+
+__author__ = "Thomas Moreau (thomas.moreau.2010@gmail.com)"
+
+
import os
import gc
import sys
@@ -58,6 +76,7 @@ from concurrent.futures import Executor
from concurrent.futures._base import LOGGER
from concurrent.futures.process import BrokenProcessPool as _BPPException
from multiprocessing.connection import wait
+
from ._base import Future
from .backend import get_context
from .backend.context import cpu_count, _MAX_WINDOWS_WORKERS
@@ -65,23 +84,58 @@ from .backend.queues import Queue, SimpleQueue
from .backend.reduction import set_loky_pickler, get_loky_pickler_name
from .backend.utils import kill_process_tree, get_exitcodes_terminated_worker
from .initializers import _prepare_initializer
-MAX_DEPTH = int(os.environ.get('LOKY_MAX_DEPTH', 10))
+
+
+# Mechanism to prevent infinite process spawning. When a worker of a
+# ProcessPoolExecutor nested in MAX_DEPTH Executor tries to create a new
+# Executor, a LokyRecursionError is raised
+MAX_DEPTH = int(os.environ.get("LOKY_MAX_DEPTH", 10))
_CURRENT_DEPTH = 0
+
+# Minimum time interval between two consecutive memory leak protection checks.
_MEMORY_LEAK_CHECK_DELAY = 1.0
-_MAX_MEMORY_LEAK_SIZE = int(300000000.0)
+
+# Number of bytes of memory usage allowed over the reference process size.
+_MAX_MEMORY_LEAK_SIZE = int(3e8)
+
+
try:
from psutil import Process
+
_USE_PSUTIL = True
+
+ def _get_memory_usage(pid, force_gc=False):
+ if force_gc:
+ gc.collect()
+
+ mem_size = Process(pid).memory_info().rss
+ mp.util.debug(f"psutil return memory size: {mem_size}")
+ return mem_size
+
except ImportError:
_USE_PSUTIL = False
class _ThreadWakeup:
-
def __init__(self):
self._closed = False
self._reader, self._writer = mp.Pipe(duplex=False)
+ def close(self):
+ if not self._closed:
+ self._closed = True
+ self._writer.close()
+ self._reader.close()
+
+ def wakeup(self):
+ if not self._closed:
+ self._writer.send_bytes(b"")
+
+ def clear(self):
+ if not self._closed:
+ while self._reader.poll():
+ self._reader.recv_bytes()
+
class _ExecutorFlags:
"""necessary references to maintain executor states without preventing gc
@@ -92,17 +146,84 @@ class _ExecutorFlags:
"""
def __init__(self, shutdown_lock):
+
self.shutdown = False
self.broken = None
self.kill_workers = False
self.shutdown_lock = shutdown_lock
+ def flag_as_shutting_down(self, kill_workers=None):
+ with self.shutdown_lock:
+ self.shutdown = True
+ if kill_workers is not None:
+ self.kill_workers = kill_workers
+
+ def flag_as_broken(self, broken):
+ with self.shutdown_lock:
+ self.shutdown = True
+ self.broken = broken
+
+
+# Prior to 3.9, executor_manager_thread is created as daemon thread. This means
+# that it is not joined automatically when the interpreter is shutting down.
+# To work around this problem, an exit handler is installed to tell the
+# thread to exit when the interpreter is shutting down and then waits until
+# it finishes. The thread needs to be daemonized because the atexit hooks are
+# called after all non daemonized threads are joined.
+#
+# Starting 3.9, there exists a specific atexit hook to be called before joining
+# the threads so the executor_manager_thread does not need to be daemonized
+# anymore.
+#
+# The atexit hooks are registered when starting the first ProcessPoolExecutor
+# to avoid import having an effect on the interpreter.
_global_shutdown = False
_global_shutdown_lock = threading.Lock()
_threads_wakeups = weakref.WeakKeyDictionary()
+
+
+def _python_exit():
+ global _global_shutdown
+ _global_shutdown = True
+
+ # Materialize the list of items to avoid error due to iterating over
+ # changing size dictionary.
+ items = list(_threads_wakeups.items())
+ if len(items) > 0:
+ mp.util.debug(
+            f"Interpreter shutting down. Waking up {len(items)} "
+ f"executor_manager_thread:\n{items}"
+ )
+
+ # Wake up the executor_manager_thread's so they can detect the interpreter
+ # is shutting down and exit.
+ for _, (shutdown_lock, thread_wakeup) in items:
+ with shutdown_lock:
+ thread_wakeup.wakeup()
+
+ # Collect the executor_manager_thread's to make sure we exit cleanly.
+ for thread, _ in items:
+        # This lock is to prevent situations where an executor is gc'ed in one
+ # thread while the atexit finalizer is running in another thread. This
+ # can happen when joblib is used in pypy for instance.
+ with _global_shutdown_lock:
+ thread.join()
+
+
+# With the fork context, _thread_wakeups is propagated to children.
+# Clear it after fork to avoid some situation that can cause some
+# freeze when joining the workers.
mp.util.register_after_fork(_threads_wakeups, lambda obj: obj.clear())
+
+
+# Module variable to register the at_exit call
process_pool_executor_at_exit = None
+
+# Controls how many more calls than processes will be queued in the call queue.
+# A smaller number will mean that processes spend more time idle waiting for
+# work while a larger number will make Future.cancel() succeed less frequently
+# (Futures in the call queue cannot be cancelled).
EXTRA_QUEUED_CALLS = 1
@@ -116,14 +237,15 @@ class _RemoteTraceback(Exception):
return self.tb
+# Do not inherit from BaseException to mirror
+# concurrent.futures.process._ExceptionWithTraceback
class _ExceptionWithTraceback:
-
def __init__(self, exc):
- tb = getattr(exc, '__traceback__', None)
+ tb = getattr(exc, "__traceback__", None)
if tb is None:
_, _, tb = sys.exc_info()
tb = traceback.format_exception(type(exc), exc, tb)
- tb = ''.join(tb)
+ tb = "".join(tb)
self.exc = exc
self.tb = tb
@@ -131,8 +253,14 @@ class _ExceptionWithTraceback:
return _rebuild_exc, (self.exc, self.tb)
+def _rebuild_exc(exc, tb):
+ exc.__cause__ = _RemoteTraceback(tb)
+ return exc
+
+
class _WorkItem:
- __slots__ = ['future', 'fn', 'args', 'kwargs']
+
+ __slots__ = ["future", "fn", "args", "kwargs"]
def __init__(self, future, fn, args, kwargs):
self.future = future
@@ -142,7 +270,6 @@ class _WorkItem:
class _ResultItem:
-
def __init__(self, work_id, exception=None, result=None):
self.work_id = work_id
self.exception = exception
@@ -150,12 +277,13 @@ class _ResultItem:
class _CallItem:
-
def __init__(self, work_id, fn, args, kwargs):
self.work_id = work_id
self.fn = fn
self.args = args
self.kwargs = kwargs
+
+ # Store the current loky_pickler so it is correctly set in the worker
self.loky_pickler = get_loky_pickler_name()
def __call__(self):
@@ -164,23 +292,64 @@ class _CallItem:
def __repr__(self):
return (
- f'CallItem({self.work_id}, {self.fn}, {self.args}, {self.kwargs})')
+ f"CallItem({self.work_id}, {self.fn}, {self.args}, {self.kwargs})"
+ )
class _SafeQueue(Queue):
"""Safe Queue set exception to the future object linked to a job"""
- def __init__(self, max_size=0, ctx=None, pending_work_items=None,
- running_work_items=None, thread_wakeup=None, reducers=None):
+ def __init__(
+ self,
+ max_size=0,
+ ctx=None,
+ pending_work_items=None,
+ running_work_items=None,
+ thread_wakeup=None,
+ reducers=None,
+ ):
self.thread_wakeup = thread_wakeup
self.pending_work_items = pending_work_items
self.running_work_items = running_work_items
super().__init__(max_size, reducers=reducers, ctx=ctx)
+ def _on_queue_feeder_error(self, e, obj):
+ if isinstance(obj, _CallItem):
+ # format traceback only works on python3
+ if isinstance(e, struct.error):
+ raised_error = RuntimeError(
+ "The task could not be sent to the workers as it is too "
+ "large for `send_bytes`."
+ )
+ else:
+ raised_error = PicklingError(
+ "Could not pickle the task to send it to the workers."
+ )
+ tb = traceback.format_exception(
+ type(e), e, getattr(e, "__traceback__", None)
+ )
+ raised_error.__cause__ = _RemoteTraceback("".join(tb))
+ work_item = self.pending_work_items.pop(obj.work_id, None)
+ self.running_work_items.remove(obj.work_id)
+ # work_item can be None if another process terminated. In this
+ # case, the executor_manager_thread fails all work_items with
+ # BrokenProcessPool
+ if work_item is not None:
+ work_item.future.set_exception(raised_error)
+ del work_item
+ self.thread_wakeup.wakeup()
+ else:
+ super()._on_queue_feeder_error(e, obj)
+
def _get_chunks(chunksize, *iterables):
"""Iterates over zip()ed iterables in chunks."""
- pass
+ it = zip(*iterables)
+ while True:
+ chunk = tuple(itertools.islice(it, chunksize))
+ if not chunk:
+ return
+ yield chunk
def _process_chunk(fn, chunk):
@@ -192,16 +361,30 @@ def _process_chunk(fn, chunk):
This function is run in a separate process.
"""
- pass
+ return [fn(*args) for args in chunk]
def _sendback_result(result_queue, work_id, result=None, exception=None):
"""Safely send back the given result or exception"""
- pass
-
-
-def _process_worker(call_queue, result_queue, initializer, initargs,
- processes_management_lock, timeout, worker_exit_lock, current_depth):
+ try:
+ result_queue.put(
+ _ResultItem(work_id, result=result, exception=exception)
+ )
+ except BaseException as e:
+ exc = _ExceptionWithTraceback(e)
+ result_queue.put(_ResultItem(work_id, exception=exc))
+
+
+def _process_worker(
+ call_queue,
+ result_queue,
+ initializer,
+ initargs,
+ processes_management_lock,
+ timeout,
+ worker_exit_lock,
+ current_depth,
+):
"""Evaluates calls from call_queue and places the results in result_queue.
This worker is run in a separate process.
@@ -221,7 +404,111 @@ def _process_worker(call_queue, result_queue, initializer, initargs,
workers timeout.
current_depth: Nested parallelism level, to avoid infinite spawning.
"""
- pass
+ if initializer is not None:
+ try:
+ initializer(*initargs)
+ except BaseException:
+ LOGGER.critical("Exception in initializer:", exc_info=True)
+ # The parent will notice that the process stopped and
+ # mark the pool broken
+ return
+
+ # set the global _CURRENT_DEPTH mechanism to limit recursive call
+ global _CURRENT_DEPTH
+ _CURRENT_DEPTH = current_depth
+ _process_reference_size = None
+ _last_memory_leak_check = None
+ pid = os.getpid()
+
+ mp.util.debug(f"Worker started with timeout={timeout}")
+ while True:
+ try:
+ call_item = call_queue.get(block=True, timeout=timeout)
+ if call_item is None:
+ mp.util.info("Shutting down worker on sentinel")
+ except queue.Empty:
+ mp.util.info(f"Shutting down worker after timeout {timeout:0.3f}s")
+ if processes_management_lock.acquire(block=False):
+ processes_management_lock.release()
+ call_item = None
+ else:
+ mp.util.info("Could not acquire processes_management_lock")
+ continue
+ except BaseException:
+ previous_tb = traceback.format_exc()
+ try:
+ result_queue.put(_RemoteTraceback(previous_tb))
+ except BaseException:
+ # If we cannot format correctly the exception, at least print
+ # the traceback.
+ print(previous_tb)
+ mp.util.debug("Exiting with code 1")
+ sys.exit(1)
+ if call_item is None:
+ # Notify queue management thread about worker shutdown
+ result_queue.put(pid)
+ is_clean = worker_exit_lock.acquire(True, timeout=30)
+
+ # Early notify any loky executor running in this worker process
+ # (nested parallelism) that this process is about to shutdown to
+        # avoid a deadlock waiting indefinitely for the worker to finish.
+ _python_exit()
+
+ if is_clean:
+ mp.util.debug("Exited cleanly")
+ else:
+ mp.util.info("Main process did not release worker_exit")
+ return
+ try:
+ r = call_item()
+ except BaseException as e:
+ exc = _ExceptionWithTraceback(e)
+ result_queue.put(_ResultItem(call_item.work_id, exception=exc))
+ else:
+ _sendback_result(result_queue, call_item.work_id, result=r)
+ del r
+
+ # Free the resource as soon as possible, to avoid holding onto
+ # open files or shared memory that is not needed anymore
+ del call_item
+
+ if _USE_PSUTIL:
+ if _process_reference_size is None:
+ # Make reference measurement after the first call
+ _process_reference_size = _get_memory_usage(pid, force_gc=True)
+ _last_memory_leak_check = time()
+ continue
+ if time() - _last_memory_leak_check > _MEMORY_LEAK_CHECK_DELAY:
+ mem_usage = _get_memory_usage(pid)
+ _last_memory_leak_check = time()
+ if mem_usage - _process_reference_size < _MAX_MEMORY_LEAK_SIZE:
+ # Memory usage stays within bounds: everything is fine.
+ continue
+
+ # Check again memory usage; this time take the measurement
+ # after a forced garbage collection to break any reference
+ # cycles.
+ mem_usage = _get_memory_usage(pid, force_gc=True)
+ _last_memory_leak_check = time()
+ if mem_usage - _process_reference_size < _MAX_MEMORY_LEAK_SIZE:
+ # The GC managed to free the memory: everything is fine.
+ continue
+
+ # The process is leaking memory: let the main process
+ # know that we need to start a new worker.
+ mp.util.info("Memory leak detected: shutting down worker")
+ result_queue.put(pid)
+ with worker_exit_lock:
+ mp.util.debug("Exit due to memory leak")
+ return
+ else:
+ # if psutil is not installed, trigger gc.collect events
+ # regularly to limit potential memory leaks due to reference cycles
+ if _last_memory_leak_check is None or (
+ time() - _last_memory_leak_check > _MEMORY_LEAK_CHECK_DELAY
+ ):
+ gc.collect()
+ _last_memory_leak_check = time()
class _ExecutorManagerThread(threading.Thread):
@@ -237,42 +524,468 @@ class _ExecutorManagerThread(threading.Thread):
"""
def __init__(self, executor):
+ # Store references to necessary internals of the executor.
+
+ # A _ThreadWakeup to allow waking up the executor_manager_thread from
+ # the main Thread and avoid deadlocks caused by permanently
+ # locked queues.
self.thread_wakeup = executor._executor_manager_thread_wakeup
self.shutdown_lock = executor._shutdown_lock
- def weakref_cb(_, thread_wakeup=self.thread_wakeup, shutdown_lock=
- self.shutdown_lock):
+ # A weakref.ref to the ProcessPoolExecutor that owns this thread. Used
+ # to determine if the ProcessPoolExecutor has been garbage collected
+ # and that the manager can exit.
+ # When the executor gets garbage collected, the weakref callback
+ # will wake up the queue management thread so that it can terminate
+ # if there is no pending work item.
+ def weakref_cb(
+ _,
+ thread_wakeup=self.thread_wakeup,
+ shutdown_lock=self.shutdown_lock,
+ ):
if mp is not None:
+ # At this point, the multiprocessing module can already be
+ # garbage collected. We only log debug info when still
+ # possible.
mp.util.debug(
- 'Executor collected: triggering callback for QueueManager wakeup'
- )
+ "Executor collected: triggering callback for"
+ " QueueManager wakeup"
+ )
with shutdown_lock:
thread_wakeup.wakeup()
+
self.executor_reference = weakref.ref(executor, weakref_cb)
+
+ # The flags of the executor
self.executor_flags = executor._flags
+
+ # A list of the ctx.Process instances used as workers.
self.processes = executor._processes
+
+ # A ctx.Queue that will be filled with _CallItems derived from
+ # _WorkItems for processing by the process workers.
self.call_queue = executor._call_queue
+
+ # A ctx.SimpleQueue of _ResultItems generated by the process workers.
self.result_queue = executor._result_queue
+
+ # A queue.Queue of work ids e.g. Queue([5, 6, ...]).
self.work_ids_queue = executor._work_ids
+
+ # A dict mapping work ids to _WorkItems e.g.
+ # {5: <_WorkItem...>, 6: <_WorkItem...>, ...}
self.pending_work_items = executor._pending_work_items
+
+ # A list of the work_ids that are currently running
self.running_work_items = executor._running_work_items
+
+ # A lock to avoid concurrent shutdown of workers on timeout and spawn
+ # of new processes or shut down
self.processes_management_lock = executor._processes_management_lock
- super().__init__(name='ExecutorManagerThread')
+
+ super().__init__(name="ExecutorManagerThread")
if sys.version_info < (3, 9):
self.daemon = True
+ def run(self):
+ # Main loop for the executor manager thread.
+
+ while True:
+ self.add_call_item_to_queue()
+
+ result_item, is_broken, bpe = self.wait_result_broken_or_wakeup()
+
+ if is_broken:
+ self.terminate_broken(bpe)
+ return
+ if result_item is not None:
+ self.process_result_item(result_item)
+ # Delete reference to result_item to avoid keeping references
+ # while waiting on new results.
+ del result_item
+
+ if self.is_shutting_down():
+ self.flag_executor_shutting_down()
+
+ # Since no new work items can be added, it is safe to shutdown
+ # this thread if there are no pending work items.
+ if not self.pending_work_items:
+ self.join_executor_internals()
+ return
+
+ def add_call_item_to_queue(self):
+ # Fills call_queue with _WorkItems from pending_work_items.
+ # This function never blocks.
+ while True:
+ if self.call_queue.full():
+ return
+ try:
+ work_id = self.work_ids_queue.get(block=False)
+ except queue.Empty:
+ return
+ else:
+ work_item = self.pending_work_items[work_id]
+
+ if work_item.future.set_running_or_notify_cancel():
+ self.running_work_items += [work_id]
+ self.call_queue.put(
+ _CallItem(
+ work_id,
+ work_item.fn,
+ work_item.args,
+ work_item.kwargs,
+ ),
+ block=True,
+ )
+ else:
+ del self.pending_work_items[work_id]
+ continue
+
+ def wait_result_broken_or_wakeup(self):
+ # Wait for a result to be ready in the result_queue while checking
+ # that all worker processes are still running, or for a wake up
+ # signal send. The wake up signals come either from new tasks being
+ # submitted, from the executor being shutdown/gc-ed, or from the
+ # shutdown of the python interpreter.
+ result_reader = self.result_queue._reader
+ wakeup_reader = self.thread_wakeup._reader
+ readers = [result_reader, wakeup_reader]
+ worker_sentinels = [p.sentinel for p in list(self.processes.values())]
+ ready = wait(readers + worker_sentinels)
+
+ bpe = None
+ is_broken = True
+ result_item = None
+ if result_reader in ready:
+ try:
+ result_item = result_reader.recv()
+ if isinstance(result_item, _RemoteTraceback):
+ bpe = BrokenProcessPool(
+ "A task has failed to un-serialize. Please ensure that"
+ " the arguments of the function are all picklable."
+ )
+ bpe.__cause__ = result_item
+ else:
+ is_broken = False
+ except BaseException as e:
+ bpe = BrokenProcessPool(
+ "A result has failed to un-serialize. Please ensure that "
+ "the objects returned by the function are always "
+ "picklable."
+ )
+ tb = traceback.format_exception(
+ type(e), e, getattr(e, "__traceback__", None)
+ )
+ bpe.__cause__ = _RemoteTraceback("".join(tb))
+
+ elif wakeup_reader in ready:
+ # This is simply a wake-up event that might either trigger putting
+ # more tasks in the queue or trigger the clean up of resources.
+ is_broken = False
+ else:
+ # A worker has terminated and we don't know why, set the state of
+ # the executor as broken
+ exit_codes = ""
+ if sys.platform != "win32":
+ # In Windows, introspecting terminated workers exitcodes seems
+ # unstable, therefore they are not appended in the exception
+ # message.
+ exit_codes = (
+ "\nThe exit codes of the workers are "
+ f"{get_exitcodes_terminated_worker(self.processes)}"
+ )
+ mp.util.debug(
+ "A worker unexpectedly terminated. Workers that "
+ "might have caused the breakage: "
+ + str(
+ {
+ p.name: p.exitcode
+ for p in list(self.processes.values())
+ if p is not None and p.sentinel in ready
+ }
+ )
+ )
+ bpe = TerminatedWorkerError(
+ "A worker process managed by the executor was unexpectedly "
+ "terminated. This could be caused by a segmentation fault "
+ "while calling the function or by an excessive memory usage "
+ "causing the Operating System to kill the worker.\n"
+ f"{exit_codes}"
+ )
+
+ self.thread_wakeup.clear()
+
+ return result_item, is_broken, bpe
+
+ def process_result_item(self, result_item):
+        # Process the received result_item. This can be either the PID of a
+ # worker that exited gracefully or a _ResultItem
+
+ if isinstance(result_item, int):
+ # Clean shutdown of a worker using its PID, either on request
+ # by the executor.shutdown method or by the timeout of the worker
+ # itself: we should not mark the executor as broken.
+ with self.processes_management_lock:
+ p = self.processes.pop(result_item, None)
+
+ # p can be None if the executor is concurrently shutting down.
+ if p is not None:
+ p._worker_exit_lock.release()
+ mp.util.debug(
+ f"joining {p.name} when processing {p.pid} as result_item"
+ )
+ p.join()
+ del p
+
+        # Make sure the executor has the right number of workers, even if a
+        # worker timed out while some jobs were submitted. If some work is
+        # pending or there are fewer processes than running items, we need to
+        # start a new Process and raise a warning.
+ n_pending = len(self.pending_work_items)
+ n_running = len(self.running_work_items)
+ if n_pending - n_running > 0 or n_running > len(self.processes):
+ executor = self.executor_reference()
+ if (
+ executor is not None
+ and len(self.processes) < executor._max_workers
+ ):
+ warnings.warn(
+ "A worker stopped while some jobs were given to the "
+ "executor. This can be caused by a too short worker "
+ "timeout or by a memory leak.",
+ UserWarning,
+ )
+ with executor._processes_management_lock:
+ executor._adjust_process_count()
+ executor = None
+ else:
+ # Received a _ResultItem so mark the future as completed.
+ work_item = self.pending_work_items.pop(result_item.work_id, None)
+ # work_item can be None if another process terminated (see above)
+ if work_item is not None:
+ if result_item.exception:
+ work_item.future.set_exception(result_item.exception)
+ else:
+ work_item.future.set_result(result_item.result)
+ self.running_work_items.remove(result_item.work_id)
+
+ def is_shutting_down(self):
+ # Check whether we should start shutting down the executor.
+ executor = self.executor_reference()
+ # No more work items can be added if:
+ # - The interpreter is shutting down OR
+ # - The executor that owns this thread is not broken AND
+ # * The executor that owns this worker has been collected OR
+ # * The executor that owns this worker has been shutdown.
+ # If the executor is broken, it should be detected in the next loop.
+ return _global_shutdown or (
+ (executor is None or self.executor_flags.shutdown)
+ and not self.executor_flags.broken
+ )
+
+ def terminate_broken(self, bpe):
+ # Terminate the executor because it is in a broken state. The bpe
+ # argument can be used to display more information on the error that
+        # led the executor into becoming broken.
+
+ # Mark the process pool broken so that submits fail right now.
+ self.executor_flags.flag_as_broken(bpe)
+
+ # Mark pending tasks as failed.
+ for work_item in self.pending_work_items.values():
+ work_item.future.set_exception(bpe)
+ # Delete references to object. See issue16284
+ del work_item
+ self.pending_work_items.clear()
+
+ # Terminate remaining workers forcibly: the queues or their
+ # locks may be in a dirty state and block forever.
+ self.kill_workers(reason="broken executor")
+
+ # clean up resources
+ self.join_executor_internals()
+
+ def flag_executor_shutting_down(self):
+ # Flag the executor as shutting down and cancel remaining tasks if
+ # requested as early as possible if it is not gc-ed yet.
+ self.executor_flags.flag_as_shutting_down()
+
+ # Cancel pending work items if requested.
+ if self.executor_flags.kill_workers:
+ while self.pending_work_items:
+ _, work_item = self.pending_work_items.popitem()
+ work_item.future.set_exception(
+ ShutdownExecutorError(
+ "The Executor was shutdown with `kill_workers=True` "
+ "before this job could complete."
+ )
+ )
+ del work_item
+
+        # Kill the remaining workers forcibly to not waste time joining them
+ self.kill_workers(reason="executor shutting down")
+
+ def kill_workers(self, reason=""):
+ # Terminate the remaining workers using SIGKILL. This function also
+ # terminates descendant workers of the children in case there is some
+ # nested parallelism.
+ while self.processes:
+ _, p = self.processes.popitem()
+ mp.util.debug(f"terminate process {p.name}, reason: {reason}")
+ try:
+ kill_process_tree(p)
+ except ProcessLookupError: # pragma: no cover
+ pass
+
+ def shutdown_workers(self):
+ # shutdown all workers in self.processes
+
+ # Create a list to avoid RuntimeError due to concurrent modification of
+ # processes. nb_children_alive is thus an upper bound. Also release the
+ # processes' _worker_exit_lock to accelerate the shutdown procedure, as
+ # there is no need for hand-shake here.
+ with self.processes_management_lock:
+ n_children_to_stop = 0
+ for p in list(self.processes.values()):
+ mp.util.debug(f"releasing worker exit lock on {p.name}")
+ p._worker_exit_lock.release()
+ n_children_to_stop += 1
+
+ mp.util.debug(f"found {n_children_to_stop} processes to stop")
+
+ # Send the right number of sentinels, to make sure all children are
+        # properly terminated. Do it with a mechanism that avoids hanging on
+        # a full queue when all workers have already been shut down.
+ n_sentinels_sent = 0
+ cooldown_time = 0.001
+ while (
+ n_sentinels_sent < n_children_to_stop
+ and self.get_n_children_alive() > 0
+ ):
+ for _ in range(n_children_to_stop - n_sentinels_sent):
+ try:
+ self.call_queue.put_nowait(None)
+ n_sentinels_sent += 1
+ except queue.Full as e:
+ if cooldown_time > 5.0:
+ mp.util.info(
+ "failed to send all sentinels and exit with error."
+ f"\ncall_queue size={self.call_queue._maxsize}; "
+ f" full is {self.call_queue.full()}; "
+ )
+ raise e
+ mp.util.info(
+ "full call_queue prevented to send all sentinels at "
+ "once, waiting..."
+ )
+ sleep(cooldown_time)
+ cooldown_time *= 1.2
+ break
+
+ mp.util.debug(f"sent {n_sentinels_sent} sentinels to the call queue")
+
+ def join_executor_internals(self):
+ self.shutdown_workers()
+
+ # Release the queue's resources as soon as possible. Flag the feeder
+ # thread for clean exit to avoid having the crash detection thread flag
+ # the Executor as broken during the shutdown. This is safe as either:
+ # * We don't need to communicate with the workers anymore
+ # * There is nothing left in the Queue buffer except None sentinels
+ mp.util.debug("closing call_queue")
+ self.call_queue.close()
+ self.call_queue.join_thread()
+
+ # Closing result_queue
+ mp.util.debug("closing result_queue")
+ self.result_queue.close()
+
+ mp.util.debug("closing thread_wakeup")
+ with self.shutdown_lock:
+ self.thread_wakeup.close()
+
+ # If .join() is not called on the created processes then
+ # some ctx.Queue methods may deadlock on macOS.
+ with self.processes_management_lock:
+ mp.util.debug(f"joining {len(self.processes)} processes")
+ n_joined_processes = 0
+ while True:
+ try:
+ pid, p = self.processes.popitem()
+ mp.util.debug(f"joining process {p.name} with pid {pid}")
+ p.join()
+ n_joined_processes += 1
+ except KeyError:
+ break
+
+ mp.util.debug(
+ "executor management thread clean shutdown of "
+ f"{n_joined_processes} workers"
+ )
+
+ def get_n_children_alive(self):
+ # This is an upper bound on the number of children alive.
+ with self.processes_management_lock:
+ return sum(p.is_alive() for p in list(self.processes.values()))
+
_system_limits_checked = False
_system_limited = None
+def _check_system_limits():
+ global _system_limits_checked, _system_limited
+ if _system_limits_checked and _system_limited:
+ raise NotImplementedError(_system_limited)
+ _system_limits_checked = True
+ try:
+ nsems_max = os.sysconf("SC_SEM_NSEMS_MAX")
+ except (AttributeError, ValueError):
+ # sysconf not available or setting not available
+ return
+ if nsems_max == -1:
+ # undetermined limit, assume that limit is determined
+ # by available memory only
+ return
+ if nsems_max >= 256:
+ # minimum number of semaphores available
+ # according to POSIX
+ return
+ _system_limited = (
+ f"system provides too few semaphores ({nsems_max} available, "
+ "256 necessary)"
+ )
+ raise NotImplementedError(_system_limited)
+
+
def _chain_from_iterable_of_lists(iterable):
"""
Specialized implementation of itertools.chain.from_iterable.
Each item in *iterable* should be a list. This function is
careful not to keep references to yielded objects.
"""
- pass
+ for element in iterable:
+ element.reverse()
+ while element:
+ yield element.pop()
+
+
def _check_max_depth(context):
    """Limit the maximal recursion level of nested process pools.

    Raises
    ------
    LokyRecursionError
        If spawning workers at the current nesting depth is not allowed
        for the given multiprocessing *context*.
    """
    global _CURRENT_DEPTH
    if context.get_start_method() == "fork" and _CURRENT_DEPTH > 0:
        # 'fork' cannot safely re-spawn nested pools at all.
        raise LokyRecursionError(
            "Could not spawn extra nested processes at depth superior to "
            "MAX_DEPTH=1. It is not possible to increase this limit when "
            "using the 'fork' start method."
        )

    if 0 < MAX_DEPTH and _CURRENT_DEPTH + 1 > MAX_DEPTH:
        # Fix: user-facing message previously misspelled "intended".
        raise LokyRecursionError(
            "Could not spawn extra nested processes at depth superior to "
            f"MAX_DEPTH={MAX_DEPTH}. If this is intended, you can change "
            "this limit with the LOKY_MAX_DEPTH environment variable."
        )
class LokyRecursionError(RuntimeError):
@@ -295,10 +1008,13 @@ class TerminatedWorkerError(BrokenProcessPool):
"""
# Backward-compatibility alias for code written against loky <= 1.1.4.
# Do not use in new code.
BrokenExecutor = BrokenProcessPool
class ShutdownExecutorError(RuntimeError):
    """
    Raised when a ProcessPoolExecutor is shutdown while a future was in the
    running or pending state.
    """
class ProcessPoolExecutor(Executor):
+
_at_exit = None
- def __init__(self, max_workers=None, job_reducers=None, result_reducers
- =None, timeout=None, context=None, initializer=None, initargs=(),
- env=None):
+ def __init__(
+ self,
+ max_workers=None,
+ job_reducers=None,
+ result_reducers=None,
+ timeout=None,
+ context=None,
+ initializer=None,
+ initargs=(),
+ env=None,
+ ):
"""Initializes a new ProcessPoolExecutor instance.
Args:
@@ -336,30 +1061,47 @@ class ProcessPoolExecutor(Executor):
loaded. Note that this only works with the loky context.
"""
_check_system_limits()
+
if max_workers is None:
self._max_workers = cpu_count()
else:
if max_workers <= 0:
- raise ValueError('max_workers must be greater than 0')
+ raise ValueError("max_workers must be greater than 0")
self._max_workers = max_workers
- if (sys.platform == 'win32' and self._max_workers >
- _MAX_WINDOWS_WORKERS):
+
+ if (
+ sys.platform == "win32"
+ and self._max_workers > _MAX_WINDOWS_WORKERS
+ ):
warnings.warn(
- f'On Windows, max_workers cannot exceed {_MAX_WINDOWS_WORKERS} due to limitations of the operating system.'
- )
+ f"On Windows, max_workers cannot exceed {_MAX_WINDOWS_WORKERS} "
+ "due to limitations of the operating system."
+ )
self._max_workers = _MAX_WINDOWS_WORKERS
+
if context is None:
context = get_context()
self._context = context
self._env = env
- self._initializer, self._initargs = _prepare_initializer(initializer,
- initargs)
+
+ self._initializer, self._initargs = _prepare_initializer(
+ initializer, initargs
+ )
_check_max_depth(self._context)
+
if result_reducers is None:
result_reducers = job_reducers
+
+ # Timeout
self._timeout = timeout
+
+ # Management thread
self._executor_manager_thread = None
+
+ # Map of pids to processes
self._processes = {}
+
+ # Internal variables of the ProcessPoolExecutor
self._processes = {}
self._queue_count = 0
self._pending_work_items = {}
@@ -368,14 +1110,144 @@ class ProcessPoolExecutor(Executor):
self._processes_management_lock = self._context.Lock()
self._executor_manager_thread = None
self._shutdown_lock = threading.Lock()
+
+ # _ThreadWakeup is a communication channel used to interrupt the wait
+ # of the main loop of executor_manager_thread from another thread (e.g.
+ # when calling executor.submit or executor.shutdown). We do not use the
+ # _result_queue to send wakeup signals to the executor_manager_thread
+ # as it could result in a deadlock if a worker process dies with the
+ # _result_queue write lock still acquired.
+ #
+ # _shutdown_lock must be locked to access _ThreadWakeup.wakeup.
self._executor_manager_thread_wakeup = _ThreadWakeup()
+
+ # Flag to hold the state of the Executor. This permits to introspect
+ # the Executor state even once it has been garbage collected.
self._flags = _ExecutorFlags(self._shutdown_lock)
+
+ # Finally setup the queues for interprocess communication
self._setup_queues(job_reducers, result_reducers)
- mp.util.debug('ProcessPoolExecutor is setup')
+
+ mp.util.debug("ProcessPoolExecutor is setup")
+
+ def _setup_queues(self, job_reducers, result_reducers, queue_size=None):
+ # Make the call queue slightly larger than the number of processes to
+ # prevent the worker processes from idling. But don't make it too big
+ # because futures in the call queue cannot be cancelled.
+ if queue_size is None:
+ queue_size = 2 * self._max_workers + EXTRA_QUEUED_CALLS
+ self._call_queue = _SafeQueue(
+ max_size=queue_size,
+ pending_work_items=self._pending_work_items,
+ running_work_items=self._running_work_items,
+ thread_wakeup=self._executor_manager_thread_wakeup,
+ reducers=job_reducers,
+ ctx=self._context,
+ )
+ # Killed worker processes can produce spurious "broken pipe"
+ # tracebacks in the queue's own worker thread. But we detect killed
+ # processes anyway, so silence the tracebacks.
+ self._call_queue._ignore_epipe = True
+
+ self._result_queue = SimpleQueue(
+ reducers=result_reducers, ctx=self._context
+ )
+
+ def _start_executor_manager_thread(self):
+ if self._executor_manager_thread is None:
+ mp.util.debug("_start_executor_manager_thread called")
+
+ # Start the processes so that their sentinels are known.
+ self._executor_manager_thread = _ExecutorManagerThread(self)
+ self._executor_manager_thread.start()
+
+ # register this executor in a mechanism that ensures it will wakeup
+ # when the interpreter is exiting.
+ _threads_wakeups[self._executor_manager_thread] = (
+ self._shutdown_lock,
+ self._executor_manager_thread_wakeup,
+ )
+
+ global process_pool_executor_at_exit
+ if process_pool_executor_at_exit is None:
+ # Ensure that the _python_exit function will be called before
+ # the multiprocessing.Queue._close finalizers which have an
+ # exitpriority of 10.
+
+ if sys.version_info < (3, 9):
+ process_pool_executor_at_exit = mp.util.Finalize(
+ None, _python_exit, exitpriority=20
+ )
+ else:
+ process_pool_executor_at_exit = threading._register_atexit(
+ _python_exit
+ )
+
+ def _adjust_process_count(self):
+ while len(self._processes) < self._max_workers:
+ worker_exit_lock = self._context.BoundedSemaphore(1)
+ args = (
+ self._call_queue,
+ self._result_queue,
+ self._initializer,
+ self._initargs,
+ self._processes_management_lock,
+ self._timeout,
+ worker_exit_lock,
+ _CURRENT_DEPTH + 1,
+ )
+ worker_exit_lock.acquire()
+ try:
+ # Try to spawn the process with some environment variable to
+ # overwrite but it only works with the loky context for now.
+ p = self._context.Process(
+ target=_process_worker, args=args, env=self._env
+ )
+ except TypeError:
+ p = self._context.Process(target=_process_worker, args=args)
+ p._worker_exit_lock = worker_exit_lock
+ p.start()
+ self._processes[p.pid] = p
+ mp.util.debug(
+ f"Adjusted process count to {self._max_workers}: "
+ f"{[(p.name, pid) for pid, p in self._processes.items()]}"
+ )
def _ensure_executor_running(self):
"""ensures all workers and management thread are running"""
- pass
+ with self._processes_management_lock:
+ if len(self._processes) != self._max_workers:
+ self._adjust_process_count()
+ self._start_executor_manager_thread()
+
+ def submit(self, fn, *args, **kwargs):
+ with self._flags.shutdown_lock:
+ if self._flags.broken is not None:
+ raise self._flags.broken
+ if self._flags.shutdown:
+ raise ShutdownExecutorError(
+ "cannot schedule new futures after shutdown"
+ )
+
+ # Cannot submit a new calls once the interpreter is shutting down.
+ # This check avoids spawning new processes at exit.
+ if _global_shutdown:
+ raise RuntimeError(
+ "cannot schedule new futures after " "interpreter shutdown"
+ )
+
+ f = Future()
+ w = _WorkItem(f, fn, args, kwargs)
+
+ self._pending_work_items[self._queue_count] = w
+ self._work_ids.put(self._queue_count)
+ self._queue_count += 1
+ # Wake up queue management thread
+ self._executor_manager_thread_wakeup.wakeup()
+
+ self._ensure_executor_running()
+ return f
+
submit.__doc__ = Executor.submit.__doc__
def map(self, fn, *iterables, **kwargs):
@@ -400,5 +1272,43 @@ class ProcessPoolExecutor(Executor):
before the given timeout.
Exception: If fn(*args) raises for any values.
"""
- pass
+ timeout = kwargs.get("timeout", None)
+ chunksize = kwargs.get("chunksize", 1)
+ if chunksize < 1:
+ raise ValueError("chunksize must be >= 1.")
+
+ results = super().map(
+ partial(_process_chunk, fn),
+ _get_chunks(chunksize, *iterables),
+ timeout=timeout,
+ )
+ return _chain_from_iterable_of_lists(results)
+
+ def shutdown(self, wait=True, kill_workers=False):
+ mp.util.debug(f"shutting down executor {self}")
+
+ self._flags.flag_as_shutting_down(kill_workers)
+ executor_manager_thread = self._executor_manager_thread
+ executor_manager_thread_wakeup = self._executor_manager_thread_wakeup
+
+ if executor_manager_thread_wakeup is not None:
+ # Wake up queue management thread
+ with self._shutdown_lock:
+ self._executor_manager_thread_wakeup.wakeup()
+
+ if executor_manager_thread is not None and wait:
+ # This locks avoids concurrent join if the interpreter
+ # is shutting down.
+ with _global_shutdown_lock:
+ executor_manager_thread.join()
+ _threads_wakeups.pop(executor_manager_thread, None)
+
+ # To reduce the risk of opening too many files, remove references to
+ # objects that use file descriptors.
+ self._executor_manager_thread = None
+ self._executor_manager_thread_wakeup = None
+ self._call_queue = None
+ self._result_queue = None
+ self._processes_management_lock = None
+
shutdown.__doc__ = Executor.shutdown.__doc__
diff --git a/joblib/externals/loky/reusable_executor.py b/joblib/externals/loky/reusable_executor.py
index 5509ecc..ad016fd 100644
--- a/joblib/externals/loky/reusable_executor.py
+++ b/joblib/externals/loky/reusable_executor.py
@@ -1,11 +1,20 @@
+###############################################################################
+# Reusable ProcessPoolExecutor
+#
+# author: Thomas Moreau and Olivier Grisel
+#
import time
import warnings
import threading
import multiprocessing as mp
+
from .process_executor import ProcessPoolExecutor, EXTRA_QUEUED_CALLS
from .backend.context import cpu_count
from .backend import get_context
-__all__ = ['get_reusable_executor']
+
+__all__ = ["get_reusable_executor"]
+
+# Singleton executor and id management
_executor_lock = threading.RLock()
_next_executor_id = 0
_executor = None
@@ -18,12 +27,25 @@ def _get_next_executor_id():
The purpose of this monotonic id is to help debug and test automated
instance creation.
"""
- pass
+ global _next_executor_id
+ with _executor_lock:
+ executor_id = _next_executor_id
+ _next_executor_id += 1
+ return executor_id
def get_reusable_executor(
    max_workers=None,
    context=None,
    timeout=10,
    kill_workers=False,
    reuse="auto",
    job_reducers=None,
    result_reducers=None,
    initializer=None,
    initargs=(),
    env=None,
):
    """Return the current ReusableExectutor instance.

    Start a new instance if it has not been started already or if the
    previous instance cannot be reused (broken, shut down, or configured
    differently from the requested arguments).
    """
    # Delegate the singleton management to the executor class; the second
    # return value (whether the instance was reused) is not exposed here.
    executor, _ = _ReusablePoolExecutor.get_reusable_executor(
        max_workers=max_workers,
        context=context,
        timeout=timeout,
        kill_workers=kill_workers,
        reuse=reuse,
        job_reducers=job_reducers,
        result_reducers=result_reducers,
        initializer=initializer,
        initargs=initargs,
        env=env,
    )
    return executor
class _ReusablePoolExecutor(ProcessPoolExecutor):
    def __init__(
        self,
        submit_resize_lock,
        max_workers=None,
        context=None,
        timeout=None,
        executor_id=0,
        job_reducers=None,
        result_reducers=None,
        initializer=None,
        initargs=(),
        env=None,
    ):
        super().__init__(
            max_workers=max_workers,
            context=context,
            timeout=timeout,
            job_reducers=job_reducers,
            result_reducers=result_reducers,
            initializer=initializer,
            initargs=initargs,
            env=env,
        )
        # Monotonic id used in debug messages to identify this instance.
        self.executor_id = executor_id
        # Lock shared with the module, serializing submit() against _resize().
        self._submit_resize_lock = submit_resize_lock
+ @classmethod
+ def get_reusable_executor(
+ cls,
+ max_workers=None,
+ context=None,
+ timeout=10,
+ kill_workers=False,
+ reuse="auto",
+ job_reducers=None,
+ result_reducers=None,
+ initializer=None,
+ initargs=(),
+ env=None,
+ ):
+ with _executor_lock:
+ global _executor, _executor_kwargs
+ executor = _executor
+
+ if max_workers is None:
+ if reuse is True and executor is not None:
+ max_workers = executor._max_workers
+ else:
+ max_workers = cpu_count()
+ elif max_workers <= 0:
+ raise ValueError(
+ f"max_workers must be greater than 0, got {max_workers}."
+ )
+
+ if isinstance(context, str):
+ context = get_context(context)
+ if context is not None and context.get_start_method() == "fork":
+ raise ValueError(
+ "Cannot use reusable executor with the 'fork' context"
+ )
+
+ kwargs = dict(
+ context=context,
+ timeout=timeout,
+ job_reducers=job_reducers,
+ result_reducers=result_reducers,
+ initializer=initializer,
+ initargs=initargs,
+ env=env,
+ )
+ if executor is None:
+ is_reused = False
+ mp.util.debug(
+ f"Create a executor with max_workers={max_workers}."
+ )
+ executor_id = _get_next_executor_id()
+ _executor_kwargs = kwargs
+ _executor = executor = cls(
+ _executor_lock,
+ max_workers=max_workers,
+ executor_id=executor_id,
+ **kwargs,
+ )
+ else:
+ if reuse == "auto":
+ reuse = kwargs == _executor_kwargs
+ if (
+ executor._flags.broken
+ or executor._flags.shutdown
+ or not reuse
+ ):
+ if executor._flags.broken:
+ reason = "broken"
+ elif executor._flags.shutdown:
+ reason = "shutdown"
+ else:
+ reason = "arguments have changed"
+ mp.util.debug(
+ "Creating a new executor with max_workers="
+ f"{max_workers} as the previous instance cannot be "
+ f"reused ({reason})."
+ )
+ executor.shutdown(wait=True, kill_workers=kill_workers)
+ _executor = executor = _executor_kwargs = None
+ # Recursive call to build a new instance
+ return cls.get_reusable_executor(
+ max_workers=max_workers, **kwargs
+ )
+ else:
+ mp.util.debug(
+ "Reusing existing executor with "
+ f"max_workers={executor._max_workers}."
+ )
+ is_reused = True
+ executor._resize(max_workers)
+
+ return executor, is_reused
+
+ def submit(self, fn, *args, **kwargs):
+ with self._submit_resize_lock:
+ return super().submit(fn, *args, **kwargs)
+
+ def _resize(self, max_workers):
+ with self._submit_resize_lock:
+ if max_workers is None:
+ raise ValueError("Trying to resize with max_workers=None")
+ elif max_workers == self._max_workers:
+ return
+
+ if self._executor_manager_thread is None:
+ # If the executor_manager_thread has not been started
+ # then no processes have been spawned and we can just
+ # update _max_workers and return
+ self._max_workers = max_workers
+ return
+
+ self._wait_job_completion()
+
+ # Some process might have returned due to timeout so check how many
+ # children are still alive. Use the _process_management_lock to
+ # ensure that no process are spawned or timeout during the resize.
+ with self._processes_management_lock:
+ processes = list(self._processes.values())
+ nb_children_alive = sum(p.is_alive() for p in processes)
+ self._max_workers = max_workers
+ for _ in range(max_workers, nb_children_alive):
+ self._call_queue.put(None)
+ while (
+ len(self._processes) > max_workers and not self._flags.broken
+ ):
+ time.sleep(1e-3)
+
+ self._adjust_process_count()
+ processes = list(self._processes.values())
+ while not all(p.is_alive() for p in processes):
+ time.sleep(1e-3)
+
def _wait_job_completion(self):
"""Wait for the cache to be empty before resizing the pool."""
- pass
+ # Issue a warning to the user about the bad effect of this usage.
+ if self._pending_work_items:
+ warnings.warn(
+ "Trying to resize an executor with running jobs: "
+ "waiting for jobs completion before resizing.",
+ UserWarning,
+ )
+ mp.util.debug(
+ f"Executor {self.executor_id} waiting for jobs completion "
+ "before resizing"
+ )
+ # Wait for the completion of the jobs
+ while self._pending_work_items:
+ time.sleep(1e-3)
+
+ def _setup_queues(self, job_reducers, result_reducers):
+ # As this executor can be resized, use a large queue size to avoid
+ # underestimating capacity and introducing overhead
+ queue_size = 2 * cpu_count() + EXTRA_QUEUED_CALLS
+ super()._setup_queues(
+ job_reducers, result_reducers, queue_size=queue_size
+ )
diff --git a/joblib/func_inspect.py b/joblib/func_inspect.py
index 5a263a0..3f80946 100644
--- a/joblib/func_inspect.py
+++ b/joblib/func_inspect.py
@@ -1,16 +1,24 @@
"""
My own variation on function-specific inspect-like features.
"""
+
+# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
+# Copyright (c) 2009 Gael Varoquaux
+# License: BSD Style, 3 clauses.
+
import inspect
import warnings
import re
import os
import collections
+
from itertools import islice
from tokenize import open as open_py_source
+
from .logger import pformat
# Field names mirroring inspect.getfullargspec, used to build FullArgSpec
# namedtuples.
full_argspec_fields = ('args varargs varkw defaults kwonlyargs '
                       'kwonlydefaults annotations')
full_argspec_type = collections.namedtuple('FullArgSpec', full_argspec_fields)
def get_func_code(func):
    """ Attempts to retrieve a reliable function code hash.

        The reason we don't use inspect.getsource is that it caches the
        source, whereas we want this to be modified on the fly when the
        function is modified.

        Returns
        -------
        func_code: string
            The function code
        source_file: string
            The path to the file in which the function is defined.
        first_line: int
            The first line of the code in the source file.

        Notes
        ------
        This function does a bit more magic than inspect, and is thus
        more robust.
    """
    source_file = None
    try:
        code = func.__code__
        source_file = code.co_filename
        if not os.path.exists(source_file):
            # Use inspect for lambda functions and functions defined in an
            # interactive shell, or in doctests
            source_code = ''.join(inspect.getsourcelines(func)[0])
            line_no = 1
            if source_file.startswith('<doctest '):
                source_file, line_no = re.match(
                    r'\<doctest (.*\.rst)\[(.*)\]\>', source_file).groups()
                line_no = int(line_no)
                source_file = '<doctest %s>' % source_file
            return source_code, source_file, line_no
        # Try to retrieve the source code.
        with open_py_source(source_file) as source_file_obj:
            first_line = code.co_firstlineno
            # All the lines after the function definition:
            source_lines = list(islice(source_file_obj, first_line - 1, None))
        return ''.join(inspect.getblock(source_lines)), source_file, first_line
    except:  # noqa: E722
        # If reading the source code fails, fall back on hashing; this is
        # fragile and might change from one session to another.
        if hasattr(func, '__code__'):
            # Python 3.X
            return str(func.__code__.__hash__()), source_file, -1
        else:
            # Weird objects like numpy ufunc don't have __code__.
            # This is fragile, as quite often the id of the object is in the
            # repr, so it might not persist across sessions; however it will
            # work for ufuncs.
            return repr(func), source_file, -1
def _clean_win_chars(string):
"""Windows cannot encode some characters in filename."""
- pass
+ import urllib
+ if hasattr(urllib, 'quote'):
+ quote = urllib.quote
+ else:
+ # In Python 3, quote is elsewhere
+ import urllib.parse
+ quote = urllib.parse.quote
+ for char in ('<', '>', '!', ':', '\\'):
+ string = string.replace(char, quote(char))
+ return string
def get_func_name(func, resolv_alias=True, win_characters=True):
    """ Return the function import path (as a list of module names), and
        a name for the function.

        Parameters
        ----------
        func: callable
            The func to inspect
        resolv_alias: boolean, optional
            If true, possible local aliases are indicated.
        win_characters: boolean, optional
            If true, substitute special characters using urllib.quote
            This is useful in Windows, as it cannot encode some filenames
    """
    if hasattr(func, '__module__'):
        module = func.__module__
    else:
        try:
            module = inspect.getmodule(func)
        except TypeError:
            if hasattr(func, '__class__'):
                module = func.__class__.__module__
            else:
                module = 'unknown'
    if module is None:
        # Happens in doctests, eg
        module = ''
    if module == '__main__':
        try:
            filename = os.path.abspath(inspect.getsourcefile(func))
        except:  # noqa: E722
            filename = None
        if filename is not None:
            # mangling of full path to filename
            parts = filename.split(os.sep)
            if parts[-1].startswith('<ipython-input'):
                # We're in a IPython (or notebook) session. parts[-1] comes
                # from func.__code__.co_filename and is of the form
                # <ipython-input-N-XYZ>, where:
                # - N is the cell number where the function was defined
                # - XYZ is a hash representing the function's code (and name).
                # It will be consistent across sessions and kernel restarts,
                # and will change if the function's code/name changes
                # We remove N so that cache is properly hit if the cell where
                # the func is defined is re-executed.
                # The XYZ hash should avoid collisions between functions with
                # the same name, both within the same notebook but also across
                # notebooks
                splitted = parts[-1].split('-')
                parts[-1] = '-'.join(splitted[:2] + splitted[3:])
            elif len(parts) > 2 and parts[-2].startswith('ipykernel_'):
                # In a notebook session (ipykernel). Filename seems to be 'xyz'
                # of above. parts[-2] has the structure ipykernel_XXXXXX where
                # XXXXXX is a six-digit number identifying the current run (?).
                # If we split it off, the function again has the same
                # identifier across runs.
                parts[-2] = 'ipykernel'
            filename = '-'.join(parts)
            if filename.endswith('.py'):
                filename = filename[:-3]
            module = module + '-' + filename
    module = module.split('.')
    if hasattr(func, 'func_name'):
        name = func.func_name
    elif hasattr(func, '__name__'):
        name = func.__name__
    else:
        name = 'unknown'
    # Hack to detect functions not defined at the module-level
    if resolv_alias:
        # TODO: Maybe add a warning here?
        if hasattr(func, 'func_globals') and name in func.func_globals:
            if not func.func_globals[name] is func:
                name = '%s-alias' % name
    if hasattr(func, '__qualname__') and func.__qualname__ != name:
        # Extend the module name in case of nested functions to avoid
        # (module, name) collisions
        module.extend(func.__qualname__.split(".")[:-1])
    if inspect.ismethod(func):
        # We need to add the name of the class
        if hasattr(func, 'im_class'):
            klass = func.im_class
            module.append(klass.__name__)
    if os.name == 'nt' and win_characters:
        # Windows can't encode certain characters in filenames
        name = _clean_win_chars(name)
        module = [_clean_win_chars(s) for s in module]
    return module, name
def _signature_str(function_name, arg_sig):
"""Helper function to output a function signature"""
- pass
+ return '{}{}'.format(function_name, arg_sig)
def _function_called_str(function_name, args, kwargs):
"""Helper function to output a function call"""
- pass
+ template_str = '{0}({1}, {2})'
+
+ args_str = repr(args)[1:-1]
+ kwargs_str = ', '.join('%s=%s' % (k, v)
+ for k, v in kwargs.items())
+ return template_str.format(function_name, args_str,
+ kwargs_str)
def filter_args(func, ignore_lst, args=(), kwargs=dict()):
    """ Filters the given args and kwargs using a list of arguments to
        ignore, and a function specification.

        Parameters
        ----------
        func: callable
            Function giving the argument specification
        ignore_lst: list of strings
            List of arguments to ignore (either a name of an argument
            in the function spec, or '*', or '**')
        *args: list
            Positional arguments passed to the function.
        **kwargs: dict
            Keyword arguments passed to the function

        Returns
        -------
        filtered_args: list
            List of filtered positional and keyword arguments.
    """
    args = list(args)
    if isinstance(ignore_lst, str):
        # Catch a common mistake
        raise ValueError(
            'ignore_lst must be a list of parameters to ignore '
            '%s (type %s) was given' % (ignore_lst, type(ignore_lst)))
    # Special case for functools.partial objects
    if (not inspect.ismethod(func) and not inspect.isfunction(func)):
        if ignore_lst:
            warnings.warn('Cannot inspect object %s, ignore list will '
                          'not work.' % func, stacklevel=2)
        return {'*': args, '**': kwargs}
    arg_sig = inspect.signature(func)
    arg_names = []
    arg_defaults = []
    arg_kwonlyargs = []
    arg_varargs = None
    arg_varkw = None
    for param in arg_sig.parameters.values():
        if param.kind is param.POSITIONAL_OR_KEYWORD:
            arg_names.append(param.name)
        elif param.kind is param.KEYWORD_ONLY:
            arg_names.append(param.name)
            arg_kwonlyargs.append(param.name)
        elif param.kind is param.VAR_POSITIONAL:
            arg_varargs = param.name
        elif param.kind is param.VAR_KEYWORD:
            arg_varkw = param.name
        if param.default is not param.empty:
            arg_defaults.append(param.default)
    if inspect.ismethod(func):
        # First argument is 'self', it has been removed by Python
        # we need to add it back:
        args = [func.__self__, ] + args
        # func is an instance method, inspect.signature(func) does not
        # include self, we need to fetch it from the class method, i.e
        # func.__func__
        class_method_sig = inspect.signature(func.__func__)
        self_name = next(iter(class_method_sig.parameters))
        arg_names = [self_name] + arg_names
    # XXX: Maybe I need an inspect.isbuiltin to detect C-level methods, such
    # as on ndarrays.

    _, name = get_func_name(func, resolv_alias=False)
    arg_dict = dict()
    arg_position = -1
    for arg_position, arg_name in enumerate(arg_names):
        if arg_position < len(args):
            # Positional argument or keyword argument given as positional
            if arg_name not in arg_kwonlyargs:
                arg_dict[arg_name] = args[arg_position]
            else:
                raise ValueError(
                    "Keyword-only parameter '%s' was passed as "
                    'positional parameter for %s:\n'
                    '     %s was called.'
                    % (arg_name,
                       _signature_str(name, arg_sig),
                       _function_called_str(name, args, kwargs))
                )

        else:
            position = arg_position - len(arg_names)
            if arg_name in kwargs:
                arg_dict[arg_name] = kwargs[arg_name]
            else:
                try:
                    arg_dict[arg_name] = arg_defaults[position]
                except (IndexError, KeyError) as e:
                    # Missing argument
                    raise ValueError(
                        'Wrong number of arguments for %s:\n'
                        '     %s was called.'
                        % (_signature_str(name, arg_sig),
                           _function_called_str(name, args, kwargs))
                    ) from e

    varkwargs = dict()
    for arg_name, arg_value in sorted(kwargs.items()):
        if arg_name in arg_dict:
            arg_dict[arg_name] = arg_value
        elif arg_varkw is not None:
            varkwargs[arg_name] = arg_value
        else:
            raise TypeError("Ignore list for %s() contains an unexpected "
                            "keyword argument '%s'" % (name, arg_name))

    if arg_varkw is not None:
        arg_dict['**'] = varkwargs
    if arg_varargs is not None:
        varargs = args[arg_position + 1:]
        arg_dict['*'] = varargs

    # Now remove the arguments to be ignored
    for item in ignore_lst:
        if item in arg_dict:
            arg_dict.pop(item)
        else:
            raise ValueError("Ignore list: argument '%s' is not defined for "
                             "function %s"
                             % (item,
                                _signature_str(name, arg_sig))
                             )
    # XXX: Return a sorted list of pairs?
    return arg_dict
+
+
def _format_arg(arg):
    """Return a pretty-printed, length-limited representation of *arg*."""
    formatted_arg = pformat(arg, indent=2)
    if len(formatted_arg) > 1500:
        # Truncate very long representations to keep log output readable.
        formatted_arg = '%s...' % formatted_arg[:700]
    return formatted_arg
+
+
def format_signature(func, *args, **kwargs):
    """Return (module_path, signature_string) describing a call to *func*."""
    # XXX: Should this use inspect.formatargvalues/formatargspec?
    module, name = get_func_name(func)
    module = [m for m in module if m]
    if module:
        module.append(name)
        module_path = '.'.join(module)
    else:
        module_path = name
    arg_str = list()
    previous_length = 0
    for arg in args:
        formatted_arg = _format_arg(arg)
        if previous_length > 80:
            # Start a new line after a long argument for readability.
            formatted_arg = '\n%s' % formatted_arg
        previous_length = len(formatted_arg)
        arg_str.append(formatted_arg)
    arg_str.extend(['%s=%s' % (v, _format_arg(i)) for v, i in kwargs.items()])
    arg_str = ', '.join(arg_str)

    signature = '%s(%s)' % (name, arg_str)
    return module_path, signature
def format_call(func, args, kwargs, object_name="Memory"):
    """ Returns a nicely formatted statement displaying the function
        call with the given arguments.
    """
    path, signature = format_signature(func, *args, **kwargs)
    msg = '%s\n[%s] Calling %s...\n%s' % (80 * '_', object_name,
                                          path, signature)
    return msg
    # XXX: Not using logging framework
    # self.debug(msg)
diff --git a/joblib/hashing.py b/joblib/hashing.py
index 7f57b88..6c081f0 100644
--- a/joblib/hashing.py
+++ b/joblib/hashing.py
@@ -2,6 +2,11 @@
Fast cryptographic hash of Python objects, with a special case for fast
hashing of numpy arrays.
"""
+
+# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
+# Copyright (c) 2009 Gael Varoquaux
+# License: BSD Style, 3 clauses.
+
import pickle
import hashlib
import sys
@@ -9,6 +14,8 @@ import types
import struct
import io
import decimal
+
+
Pickler = pickle._Pickler
@@ -16,12 +23,18 @@ class _ConsistentSet(object):
""" Class used to ensure the hash of Sets is preserved
whatever the order of its items.
"""
-
def __init__(self, set_sequence):
+ # Forces order of elements in set to ensure consistent hash.
try:
+ # Trying first to order the set assuming the type of elements is
+ # consistent and orderable.
+ # This fails on python 3 when elements are unorderable
+ # but we keep it in a try as it's faster.
self._sequence = sorted(set_sequence)
except (TypeError, decimal.InvalidOperation):
- self._sequence = sorted(hash(e) for e in set_sequence)
+ # If elements are unorderable, sorting them using their hash.
+ # This is slower but works in any case.
+ self._sequence = sorted((hash(e) for e in set_sequence))
class _MyHash(object):
@@ -38,14 +51,103 @@ class Hasher(Pickler):
def __init__(self, hash_name='md5'):
self.stream = io.BytesIO()
+ # By default we want a pickle protocol that only changes with
+ # the major python version and not the minor one
protocol = 3
Pickler.__init__(self, self.stream, protocol=protocol)
+ # Initialise the hash obj
self._hash = hashlib.new(hash_name)
+
+ def hash(self, obj, return_digest=True):
+ try:
+ self.dump(obj)
+ except pickle.PicklingError as e:
+ e.args += ('PicklingError while hashing %r: %r' % (obj, e),)
+ raise
+ dumps = self.stream.getvalue()
+ self._hash.update(dumps)
+ if return_digest:
+ return self._hash.hexdigest()
+
+ def save(self, obj):
+ if isinstance(obj, (types.MethodType, type({}.pop))):
+ # the Pickler cannot pickle instance methods; here we decompose
+ # them into components that make them uniquely identifiable
+ if hasattr(obj, '__func__'):
+ func_name = obj.__func__.__name__
+ else:
+ func_name = obj.__name__
+ inst = obj.__self__
+ if type(inst) is type(pickle):
+ obj = _MyHash(func_name, inst.__name__)
+ elif inst is None:
+ # type(None) or type(module) do not pickle
+ obj = _MyHash(func_name, inst)
+ else:
+ cls = obj.__self__.__class__
+ obj = _MyHash(func_name, inst, cls)
+ Pickler.save(self, obj)
+
+ def memoize(self, obj):
+ # We want hashing to be sensitive to value instead of reference.
+ # For example we want ['aa', 'aa'] and ['aa', 'aaZ'[:2]]
+ # to hash to the same value and that's why we disable memoization
+ # for strings
+ if isinstance(obj, (bytes, str)):
+ return
+ Pickler.memoize(self, obj)
+
+ # The dispatch table of the pickler is not accessible in Python
+ # 3, as these lines are only bugware for IPython, we skip them.
+ def save_global(self, obj, name=None, pack=struct.pack):
+ # We have to override this method in order to deal with objects
+ # defined interactively in IPython that are not injected in
+ # __main__
+ kwargs = dict(name=name, pack=pack)
+ del kwargs['pack']
+ try:
+ Pickler.save_global(self, obj, **kwargs)
+ except pickle.PicklingError:
+ Pickler.save_global(self, obj, **kwargs)
+ module = getattr(obj, "__module__", None)
+ if module == '__main__':
+ my_name = name
+ if my_name is None:
+ my_name = obj.__name__
+ mod = sys.modules[module]
+ if not hasattr(mod, my_name):
+ # IPython doesn't inject the variables define
+ # interactively in __main__
+ setattr(mod, my_name, obj)
+
dispatch = Pickler.dispatch.copy()
+ # builtin
dispatch[type(len)] = save_global
+ # type
dispatch[type(object)] = save_global
+ # classobj
dispatch[type(Pickler)] = save_global
+ # function
dispatch[type(pickle.dump)] = save_global
+
+ def _batch_setitems(self, items):
+ # forces order of keys in dict to ensure consistent hash.
+ try:
+ # Trying first to compare dict assuming the type of keys is
+ # consistent and orderable.
+ # This fails on python 3 when keys are unorderable
+ # but we keep it in a try as it's faster.
+ Pickler._batch_setitems(self, iter(sorted(items)))
+ except TypeError:
+ # If keys are unorderable, sorting them using their hash. This is
+ # slower but works in any case.
+ Pickler._batch_setitems(self, iter(sorted((hash(k), v)
+ for k, v in items)))
+
+ def save_set(self, set_items):
+ # forces order of items in Set to ensure consistent hash
+ Pickler.save(self, _ConsistentSet(set_items))
+
dispatch[type(set())] = save_set
@@ -65,6 +167,7 @@ class NumpyHasher(Hasher):
"""
self.coerce_mmap = coerce_mmap
Hasher.__init__(self, hash_name=hash_name)
+ # delayed import of numpy, to avoid tight coupling
import numpy as np
self.np = np
if hasattr(np, 'getbuffer'):
@@ -77,7 +180,65 @@ class NumpyHasher(Hasher):
than pickling them. Off course, this is a total abuse of
the Pickler class.
"""
- pass
+ if isinstance(obj, self.np.ndarray) and not obj.dtype.hasobject:
+ # Compute a hash of the object
+ # The update function of the hash requires a c_contiguous buffer.
+ if obj.shape == ():
+ # 0d arrays need to be flattened because viewing them as bytes
+ # raises a ValueError exception.
+ obj_c_contiguous = obj.flatten()
+ elif obj.flags.c_contiguous:
+ obj_c_contiguous = obj
+ elif obj.flags.f_contiguous:
+ obj_c_contiguous = obj.T
+ else:
+ # Cater for non-single-segment arrays: this creates a
+ # copy, and thus alleviates this issue.
+ # XXX: There might be a more efficient way of doing this
+ obj_c_contiguous = obj.flatten()
+
+ # memoryview is not supported for some dtypes, e.g. datetime64, see
+ # https://github.com/numpy/numpy/issues/4983. The
+ # workaround is to view the array as bytes before
+ # taking the memoryview.
+ self._hash.update(
+ self._getbuffer(obj_c_contiguous.view(self.np.uint8)))
+
+ # We store the class, to be able to distinguish between
+ # Objects with the same binary content, but different
+ # classes.
+ if self.coerce_mmap and isinstance(obj, self.np.memmap):
+ # We don't make the difference between memmap and
+ # normal ndarrays, to be able to reload previously
+ # computed results with memmap.
+ klass = self.np.ndarray
+ else:
+ klass = obj.__class__
+ # We also return the dtype and the shape, to distinguish
+ # different views on the same data with different dtypes.
+
+ # The object will be pickled by the pickler hashed at the end.
+ obj = (klass, ('HASHED', obj.dtype, obj.shape, obj.strides))
+ elif isinstance(obj, self.np.dtype):
+ # numpy.dtype consistent hashing is tricky to get right. This comes
+ # from the fact that atomic np.dtype objects are interned:
+ # ``np.dtype('f4') is np.dtype('f4')``. The situation is
+ # complicated by the fact that this interning does not resist a
+ # simple pickle.load/dump roundtrip:
+ # ``pickle.loads(pickle.dumps(np.dtype('f4'))) is not
+ # np.dtype('f4') Because pickle relies on memoization during
+ # pickling, it is easy to
+ # produce different hashes for seemingly identical objects, such as
+ # ``[np.dtype('f4'), np.dtype('f4')]``
+ # and ``[np.dtype('f4'), pickle.loads(pickle.dumps('f4'))]``.
+ # To prevent memoization from interfering with hashing, we isolate
+ # the serialization (and thus the pickle memoization) of each dtype
+ # using each time a different ``pickle.dumps`` call unrelated to
+ # the current Hasher instance.
+ self._hash.update("_HASHED_DTYPE".encode('utf-8'))
+ self._hash.update(pickle.dumps(obj))
+ return
+ Hasher.save(self, obj)
def hash(obj, hash_name='md5', coerce_mmap=False):
@@ -92,4 +253,13 @@ def hash(obj, hash_name='md5', coerce_mmap=False):
coerce_mmap: boolean
Make no difference between np.memmap and np.ndarray
"""
- pass
+ valid_hash_names = ('md5', 'sha1')
+ if hash_name not in valid_hash_names:
+ raise ValueError("Valid options for 'hash_name' are {}. "
+ "Got hash_name={!r} instead."
+ .format(valid_hash_names, hash_name))
+ if 'numpy' in sys.modules:
+ hasher = NumpyHasher(hash_name=hash_name, coerce_mmap=coerce_mmap)
+ else:
+ hasher = Hasher(hash_name=hash_name)
+ return hasher.hash(obj)
diff --git a/joblib/logger.py b/joblib/logger.py
index 4991108..cf9d258 100644
--- a/joblib/logger.py
+++ b/joblib/logger.py
@@ -3,13 +3,20 @@ Helpers for logging.
This module needs much love to become useful.
"""
+
+# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
+# Copyright (c) 2008 Gael Varoquaux
+# License: BSD Style, 3 clauses.
+
from __future__ import print_function
+
import time
import sys
import os
import shutil
import logging
import pprint
+
from .disk import mkdirp
@@ -18,9 +25,41 @@ def _squeeze_time(t):
stat files. This is needed to make results similar to timings under
Unix, for tests
"""
- pass
+ if sys.platform.startswith('win'):
+ return max(0, t - .1)
+ else:
+ return t
+
+
+def format_time(t):
+ t = _squeeze_time(t)
+ return "%.1fs, %.1fmin" % (t, t / 60.)
+
+def short_format_time(t):
+ t = _squeeze_time(t)
+ if t > 60:
+ return "%4.1fmin" % (t / 60.)
+ else:
+ return " %5.1fs" % (t)
+
+def pformat(obj, indent=0, depth=3):
+ if 'numpy' in sys.modules:
+ import numpy as np
+ print_options = np.get_printoptions()
+ np.set_printoptions(precision=6, threshold=64, edgeitems=1)
+ else:
+ print_options = None
+ out = pprint.pformat(obj, depth=depth, indent=indent)
+ if print_options:
+ np.set_printoptions(**print_options)
+ return out
+
+
+###############################################################################
+# class `Logger`
+###############################################################################
class Logger(object):
""" Base class for logging messages.
"""
@@ -37,11 +76,24 @@ class Logger(object):
self.depth = depth
self._name = name if name else 'joblib'
+ def warn(self, msg):
+ logging.getLogger(self._name).warning("[%s]: %s" % (self, msg))
+
+ def info(self, msg):
+ logging.info("[%s]: %s" % (self, msg))
+
+ def debug(self, msg):
+        # XXX: This conflicts with the debug flag used in child classes
+ logging.getLogger(self._name).debug("[%s]: %s" % (self, msg))
+
def format(self, obj, indent=0):
"""Return the formatted representation of the object."""
- pass
+ return pformat(obj, indent=indent, depth=self.depth)
+###############################################################################
+# class `PrintTime`
+###############################################################################
class PrintTime(object):
""" Print and log messages while keeping track of time.
"""
@@ -49,6 +101,7 @@ class PrintTime(object):
def __init__(self, logfile=None, logdir=None):
if logfile is not None and logdir is not None:
raise ValueError('Cannot specify both logfile and logdir')
+ # XXX: Need argument docstring
self.last_time = time.time()
self.start_time = self.last_time
if logdir is not None:
@@ -57,25 +110,30 @@ class PrintTime(object):
if logfile is not None:
mkdirp(os.path.dirname(logfile))
if os.path.exists(logfile):
+ # Rotate the logs
for i in range(1, 9):
try:
- shutil.move(logfile + '.%i' % i, logfile + '.%i' %
- (i + 1))
- except:
- """No reason failing here"""
+ shutil.move(logfile + '.%i' % i,
+ logfile + '.%i' % (i + 1))
+ except: # noqa: E722
+ "No reason failing here"
+ # Use a copy rather than a move, so that a process
+ # monitoring this file does not get lost.
try:
shutil.copy(logfile, logfile + '.1')
- except:
- """No reason failing here"""
+ except: # noqa: E722
+ "No reason failing here"
try:
with open(logfile, 'w') as logfile:
logfile.write('\nLogging joblib python script\n')
logfile.write('\n---%s---\n' % time.ctime(self.last_time))
- except:
+ except: # noqa: E722
""" Multiprocessing writing to files can create race
conditions. Rather fail silently than crash the
computation.
"""
+ # XXX: We actually need a debug flag to disable this
+ # silent failure.
def __call__(self, msg='', total=False):
""" Print the time elapsed between the last call and the current
@@ -83,19 +141,22 @@ class PrintTime(object):
"""
if not total:
time_lapse = time.time() - self.last_time
- full_msg = '%s: %s' % (msg, format_time(time_lapse))
+ full_msg = "%s: %s" % (msg, format_time(time_lapse))
else:
+ # FIXME: Too much logic duplicated
time_lapse = time.time() - self.start_time
- full_msg = '%s: %.2fs, %.1f min' % (msg, time_lapse, time_lapse /
- 60)
+ full_msg = "%s: %.2fs, %.1f min" % (msg, time_lapse,
+ time_lapse / 60)
print(full_msg, file=sys.stderr)
if self.logfile is not None:
try:
with open(self.logfile, 'a') as f:
print(full_msg, file=f)
- except:
+ except: # noqa: E722
""" Multiprocessing writing to files can create race
conditions. Rather fail silently than crash the
calculation.
"""
+ # XXX: We actually need a debug flag to disable this
+ # silent failure.
self.last_time = time.time()
diff --git a/joblib/memory.py b/joblib/memory.py
index 14f8456..b83a855 100644
--- a/joblib/memory.py
+++ b/joblib/memory.py
@@ -3,6 +3,12 @@ A context object for caching a function's return value each time it
is called with the same input arguments.
"""
+
+# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
+# Copyright (c) 2009 Gael Varoquaux
+# License: BSD Style, 3 clauses.
+
+
import asyncio
import datetime
import functools
@@ -18,19 +24,40 @@ import tokenize
import traceback
import warnings
import weakref
+
from . import hashing
-from ._store_backends import CacheWarning
+from ._store_backends import CacheWarning # noqa
from ._store_backends import FileSystemStoreBackend, StoreBackendBase
-from .func_inspect import filter_args, format_call, format_signature, get_func_code, get_func_name
+from .func_inspect import (filter_args, format_call, format_signature,
+ get_func_code, get_func_name)
from .logger import Logger, format_time, pformat
-FIRST_LINE_TEXT = '# first line:'
+
+FIRST_LINE_TEXT = "# first line:"
+
+# TODO: The following object should have a data store object as a sub
+# object, and the interface to persist and query should be separated in
+# the data store.
+#
+# This would enable creating 'Memory' objects with a different logic for
+# pickling that would simply span a MemorizedFunc with the same
+# store (or do we want to copy it to avoid cross-talks?), for instance to
+# implement HDF5 pickling.
+
+# TODO: Same remark for the logger, and probably use the Python logging
+# mechanism.
def extract_first_line(func_code):
""" Extract the first line information from the function code
text if available.
"""
- pass
+ if func_code.startswith(FIRST_LINE_TEXT):
+ func_code = func_code.split('\n')
+ first_line = int(func_code[0][len(FIRST_LINE_TEXT):])
+ func_code = '\n'.join(func_code[1:])
+ else:
+ first_line = -1
+ return func_code, first_line
class JobLibCollisionWarning(UserWarning):
@@ -59,22 +86,72 @@ def register_store_backend(backend_name, backend):
The name of a class that implements the StoreBackendBase interface.
"""
- pass
+ if not isinstance(backend_name, str):
+ raise ValueError("Store backend name should be a string, "
+ "'{0}' given.".format(backend_name))
+ if backend is None or not issubclass(backend, StoreBackendBase):
+ raise ValueError("Store backend should inherit "
+ "StoreBackendBase, "
+ "'{0}' given.".format(backend))
+
+ _STORE_BACKENDS[backend_name] = backend
def _store_backend_factory(backend, location, verbose=0, backend_options=None):
"""Return the correct store object for the given location."""
- pass
+ if backend_options is None:
+ backend_options = {}
+
+ if isinstance(location, pathlib.Path):
+ location = str(location)
+
+ if isinstance(location, StoreBackendBase):
+ return location
+ elif isinstance(location, str):
+ obj = None
+ location = os.path.expanduser(location)
+ # The location is not a local file system, we look in the
+ # registered backends if there's one matching the given backend
+ # name.
+ for backend_key, backend_obj in _STORE_BACKENDS.items():
+ if backend == backend_key:
+ obj = backend_obj()
+
+ # By default, we assume the FileSystemStoreBackend can be used if no
+ # matching backend could be found.
+ if obj is None:
+ raise TypeError('Unknown location {0} or backend {1}'.format(
+ location, backend))
+
+ # The store backend is configured with the extra named parameters,
+ # some of them are specific to the underlying store backend.
+ obj.configure(location, verbose=verbose,
+ backend_options=backend_options)
+ return obj
+ elif location is not None:
+ warnings.warn(
+ "Instantiating a backend using a {} as a location is not "
+ "supported by joblib. Returning None instead.".format(
+ location.__class__.__name__), UserWarning)
+
+ return None
def _build_func_identifier(func):
"""Build a roughly unique identifier for the cached function."""
- pass
+ modules, funcname = get_func_name(func)
+    # We reuse the historical fs-like way of building a function identifier
+ return os.path.join(*modules, funcname)
+# An in-memory store to avoid looking at the disk-based function
+# source code to check if a function definition has changed
_FUNCTION_HASHES = weakref.WeakKeyDictionary()
+###############################################################################
+# class `MemorizedResult`
+###############################################################################
class MemorizedResult(Logger):
"""Object representing a cached value.
@@ -105,33 +182,69 @@ class MemorizedResult(Logger):
timestamp, metadata: string
for internal use only.
"""
-
def __init__(self, location, call_id, backend='local', mmap_mode=None,
- verbose=0, timestamp=None, metadata=None):
+ verbose=0, timestamp=None, metadata=None):
Logger.__init__(self)
self._call_id = call_id
self.store_backend = _store_backend_factory(backend, location,
- verbose=verbose)
+ verbose=verbose)
self.mmap_mode = mmap_mode
+
if metadata is not None:
self.metadata = metadata
else:
self.metadata = self.store_backend.get_metadata(self._call_id)
+
self.duration = self.metadata.get('duration', None)
self.verbose = verbose
self.timestamp = timestamp
+ @property
+ def func(self):
+ return self.func_id
+
+ @property
+ def func_id(self):
+ return self._call_id[0]
+
+ @property
+ def args_id(self):
+ return self._call_id[1]
+
+ @property
+ def argument_hash(self):
+ warnings.warn(
+ "The 'argument_hash' attribute has been deprecated in version "
+ "0.12 and will be removed in version 0.14.\n"
+ "Use `args_id` attribute instead.",
+ DeprecationWarning, stacklevel=2)
+ return self.args_id
+
def get(self):
"""Read value from cache and return it."""
- pass
+ try:
+ return self.store_backend.load_item(
+ self._call_id,
+ timestamp=self.timestamp,
+ metadata=self.metadata,
+ verbose=self.verbose
+ )
+ except ValueError as exc:
+ new_exc = KeyError(
+ "Error while trying to load a MemorizedResult's value. "
+ "It seems that this folder is corrupted : {}".format(
+ os.path.join(self.store_backend.location, *self._call_id)))
+ raise new_exc from exc
def clear(self):
"""Clear value from cache"""
- pass
+ self.store_backend.clear_item(self._call_id)
def __repr__(self):
- return '{}(location="{}", func="{}", args_id="{}")'.format(self.
- __class__.__name__, self.store_backend.location, *self._call_id)
+ return '{}(location="{}", func="{}", args_id="{}")'.format(
+ self.__class__.__name__, self.store_backend.location,
+ *self._call_id
+ )
def __getstate__(self):
state = self.__dict__.copy()
@@ -144,27 +257,42 @@ class NotMemorizedResult(object):
This class is a replacement for MemorizedResult when there is no cache.
"""
- __slots__ = 'value', 'valid'
+ __slots__ = ('value', 'valid')
def __init__(self, value):
self.value = value
self.valid = True
+ def get(self):
+ if self.valid:
+ return self.value
+ else:
+ raise KeyError("No value stored.")
+
+ def clear(self):
+ self.valid = False
+ self.value = None
+
def __repr__(self):
if self.valid:
- return '{class_name}({value})'.format(class_name=self.__class__
- .__name__, value=pformat(self.value))
+ return ('{class_name}({value})'
+ .format(class_name=self.__class__.__name__,
+ value=pformat(self.value)))
else:
return self.__class__.__name__ + ' with no value'
+ # __getstate__ and __setstate__ are required because of __slots__
def __getstate__(self):
- return {'valid': self.valid, 'value': self.value}
+ return {"valid": self.valid, "value": self.value}
def __setstate__(self, state):
- self.valid = state['valid']
- self.value = state['value']
+ self.valid = state["valid"]
+ self.value = state["value"]
+###############################################################################
+# class `NotMemorizedFunc`
+###############################################################################
class NotMemorizedFunc(object):
"""No-op object decorating a function.
@@ -176,21 +304,41 @@ class NotMemorizedFunc(object):
func: callable
Original undecorated function.
"""
-
+    # Should be as light as possible (for speed)
def __init__(self, func):
self.func = func
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
+ def call_and_shelve(self, *args, **kwargs):
+ return NotMemorizedResult(self.func(*args, **kwargs))
+
def __repr__(self):
return '{0}(func={1})'.format(self.__class__.__name__, self.func)
+ def clear(self, warn=True):
+ # Argument "warn" is for compatibility with MemorizedFunc.clear
+ pass
+ def call(self, *args, **kwargs):
+ return self.func(*args, **kwargs), {}
+
+ def check_call_in_cache(self, *args, **kwargs):
+ return False
+
+
+###############################################################################
+# class `AsyncNotMemorizedFunc`
+###############################################################################
class AsyncNotMemorizedFunc(NotMemorizedFunc):
- pass
+ async def call_and_shelve(self, *args, **kwargs):
+ return NotMemorizedResult(await self.func(*args, **kwargs))
+###############################################################################
+# class `MemorizedFunc`
+###############################################################################
class MemorizedFunc(Logger):
"""Callable object decorating a function for caching its return value
each time it is called.
@@ -236,10 +384,13 @@ class MemorizedFunc(Logger):
argument. If it returns True, the cached result is returned, else the
cache for these arguments is cleared and the result is recomputed.
"""
+ # ------------------------------------------------------------------------
+ # Public interface
+ # ------------------------------------------------------------------------
def __init__(self, func, location, backend='local', ignore=None,
- mmap_mode=None, compress=False, verbose=1, timestamp=None,
- cache_validation_callback=None):
+ mmap_mode=None, compress=False, verbose=1, timestamp=None,
+ cache_validation_callback=None):
Logger.__init__(self)
self.mmap_mode = mmap_mode
self.compress = compress
@@ -248,23 +399,34 @@ class MemorizedFunc(Logger):
self.func_id = _build_func_identifier(func)
self.ignore = ignore if ignore is not None else []
self._verbose = verbose
+
+ # retrieve store object from backend type and location.
self.store_backend = _store_backend_factory(backend, location,
- verbose=verbose, backend_options=dict(compress=compress,
- mmap_mode=mmap_mode))
+ verbose=verbose,
+ backend_options=dict(
+ compress=compress,
+ mmap_mode=mmap_mode),
+ )
if self.store_backend is not None:
+ # Create func directory on demand.
self.store_backend.store_cached_func_code([self.func_id])
+
self.timestamp = timestamp if timestamp is not None else time.time()
try:
functools.update_wrapper(self, func)
except Exception:
- pass
+ pass # Objects like ufunc don't like that
if inspect.isfunction(func):
doc = pydoc.TextDoc().document(func)
+ # Remove blank line
doc = doc.replace('\n', '\n\n', 1)
+ # Strip backspace-overprints for compatibility with autodoc
doc = re.sub('\x08.', '', doc)
else:
+ # Pydoc does a poor job on other objects
doc = func.__doc__
self.__doc__ = 'Memoized version of %s' % doc
+
self._func_code_info = None
self._func_code_id = None
@@ -279,7 +441,22 @@ class MemorizedFunc(Logger):
Returns True if the function call is in cache and can be used, and
returns False otherwise.
"""
- pass
+ # Check if the code of the function has changed
+ if not self._check_previous_func_code(stacklevel=4):
+ return False
+
+ # Check if this specific call is in the cache
+ if not self.store_backend.contains_item(call_id):
+ return False
+
+ # Call the user defined cache validation callback
+ metadata = self.store_backend.get_metadata(call_id)
+ if (self.cache_validation_callback is not None and
+ not self.cache_validation_callback(metadata)):
+ self.store_backend.clear_item(call_id)
+ return False
+
+ return True
def _cached_call(self, args, kwargs, shelving):
"""Call wrapped function and cache result, or read cache if available.
@@ -303,7 +480,79 @@ class MemorizedFunc(Logger):
MemorizedResult reference to the value if shelving is true.
metadata: dict containing the metadata associated with the call.
"""
- pass
+ args_id = self._get_args_id(*args, **kwargs)
+ call_id = (self.func_id, args_id)
+ _, func_name = get_func_name(self.func)
+ func_info = self.store_backend.get_cached_func_info([self.func_id])
+ location = func_info['location']
+
+ if self._verbose >= 20:
+ logging.basicConfig(level=logging.INFO)
+ _, signature = format_signature(self.func, *args, **kwargs)
+ self.info(
+ textwrap.dedent(
+ f"""
+ Querying {func_name} with signature
+ {signature}.
+
+ (argument hash {args_id})
+
+ The store location is {location}.
+ """
+ )
+ )
+
+ # Compare the function code with the previous to see if the
+ # function code has changed and check if the results are present in
+ # the cache.
+ if self._is_in_cache_and_valid(call_id):
+ if shelving:
+ return self._get_memorized_result(call_id), {}
+
+ try:
+ start_time = time.time()
+ output = self._load_item(call_id)
+ if self._verbose > 4:
+ self._print_duration(time.time() - start_time,
+ context='cache loaded ')
+ return output, {}
+ except Exception:
+ # XXX: Should use an exception logger
+ _, signature = format_signature(self.func, *args, **kwargs)
+ self.warn('Exception while loading results for '
+ '{}\n {}'.format(signature, traceback.format_exc()))
+
+ if self._verbose > 10:
+ self.warn(
+ f"Computing func {func_name}, argument hash {args_id} "
+ f"in location {location}"
+ )
+
+ # Returns the output but not the metadata
+ return self._call(call_id, args, kwargs, shelving)
+
+ @property
+ def func_code_info(self):
+ # 3-tuple property containing: the function source code, source file,
+ # and first line of the code inside the source file
+ if hasattr(self.func, '__code__'):
+ if self._func_code_id is None:
+ self._func_code_id = id(self.func.__code__)
+ elif id(self.func.__code__) != self._func_code_id:
+ # Be robust to dynamic reassignments of self.func.__code__
+ self._func_code_info = None
+
+ if self._func_code_info is None:
+ # Cache the source code of self.func . Provided that get_func_code
+ # (which should be called once on self) gets called in the process
+ # in which self.func was defined, this caching mechanism prevents
+ # undesired cache clearing when the cached function is called in
+ # an environment where the introspection utilities get_func_code
+ # relies on do not work (typically, in joblib child processes).
+ # See #1035 for more info
+ # TODO (pierreglaser): do the same with get_func_name?
+ self._func_code_info = get_func_code(self.func)
+ return self._func_code_info
def call_and_shelve(self, *args, **kwargs):
"""Call wrapped function, cache result and return a reference.
@@ -320,16 +569,27 @@ class MemorizedFunc(Logger):
class "NotMemorizedResult" is used when there is no cache
activated (e.g. location=None in Memory).
"""
- pass
+ # Return the wrapped output, without the metadata
+ return self._cached_call(args, kwargs, shelving=True)[0]
def __call__(self, *args, **kwargs):
+ # Return the output, without the metadata
return self._cached_call(args, kwargs, shelving=False)[0]
def __getstate__(self):
+ # Make sure self.func's source is introspected prior to being pickled -
+ # code introspection utilities typically do not work inside child
+ # processes
_ = self.func_code_info
+
+ # We don't store the timestamp when pickling, to avoid the hash
+        # depending on it.
state = self.__dict__.copy()
state['timestamp'] = None
+
+ # Invalidate the code id as id(obj) will be different in the child
state['_func_code_id'] = None
+
return state
def check_call_in_cache(self, *args, **kwargs):
@@ -344,31 +604,140 @@ class MemorizedFunc(Logger):
Whether or not the result of the function has been cached
for the input arguments that have been passed.
"""
- pass
+ call_id = (self.func_id, self._get_args_id(*args, **kwargs))
+ return self.store_backend.contains_item(call_id)
+
+ # ------------------------------------------------------------------------
+ # Private interface
+ # ------------------------------------------------------------------------
def _get_args_id(self, *args, **kwargs):
"""Return the input parameter hash of a result."""
- pass
+ return hashing.hash(filter_args(self.func, self.ignore, args, kwargs),
+ coerce_mmap=self.mmap_mode is not None)
def _hash_func(self):
"""Hash a function to key the online cache"""
- pass
+ func_code_h = hash(getattr(self.func, '__code__', None))
+ return id(self.func), hash(self.func), func_code_h
def _write_func_code(self, func_code, first_line):
""" Write the function code and the filename to a file.
"""
- pass
+ # We store the first line because the filename and the function
+ # name is not always enough to identify a function: people
+ # sometimes have several functions named the same way in a
+ # file. This is bad practice, but joblib should be robust to bad
+ # practice.
+ func_code = u'%s %i\n%s' % (FIRST_LINE_TEXT, first_line, func_code)
+ self.store_backend.store_cached_func_code([self.func_id], func_code)
+
+ # Also store in the in-memory store of function hashes
+ is_named_callable = (hasattr(self.func, '__name__') and
+ self.func.__name__ != '<lambda>')
+ if is_named_callable:
+ # Don't do this for lambda functions or strange callable
+ # objects, as it ends up being too fragile
+ func_hash = self._hash_func()
+ try:
+ _FUNCTION_HASHES[self.func] = func_hash
+ except TypeError:
+ # Some callable are not hashable
+ pass
def _check_previous_func_code(self, stacklevel=2):
"""
stacklevel is the depth a which this function is called, to
issue useful warnings to the user.
"""
- pass
+ # First check if our function is in the in-memory store.
+ # Using the in-memory store not only makes things faster, but it
+ # also renders us robust to variations of the files when the
+ # in-memory version of the code does not vary
+ try:
+ if self.func in _FUNCTION_HASHES:
+ # We use as an identifier the id of the function and its
+ # hash. This is more likely to falsely change than have hash
+ # collisions, thus we are on the safe side.
+ func_hash = self._hash_func()
+ if func_hash == _FUNCTION_HASHES[self.func]:
+ return True
+ except TypeError:
+ # Some callables are not hashable
+ pass
+
+ # Here, we go through some effort to be robust to dynamically
+ # changing code and collision. We cannot inspect.getsource
+ # because it is not reliable when using IPython's magic "%run".
+ func_code, source_file, first_line = self.func_code_info
+ try:
+ old_func_code, old_first_line = extract_first_line(
+ self.store_backend.get_cached_func_code([self.func_id]))
+ except (IOError, OSError): # some backend can also raise OSError
+ self._write_func_code(func_code, first_line)
+ return False
+ if old_func_code == func_code:
+ return True
+
+ # We have differing code, is this because we are referring to
+ # different functions, or because the function we are referring to has
+ # changed?
+
+ _, func_name = get_func_name(self.func, resolv_alias=False,
+ win_characters=False)
+ if old_first_line == first_line == -1 or func_name == '<lambda>':
+ if not first_line == -1:
+ func_description = ("{0} ({1}:{2})"
+ .format(func_name, source_file,
+ first_line))
+ else:
+ func_description = func_name
+ warnings.warn(JobLibCollisionWarning(
+ "Cannot detect name collisions for function '{0}'"
+ .format(func_description)), stacklevel=stacklevel)
+
+ # Fetch the code at the old location and compare it. If it is the
+        # same as the code stored, we have a collision: the code in the
+ # file has not changed, but the name we have is pointing to a new
+ # code block.
+ if not old_first_line == first_line and source_file is not None:
+ if os.path.exists(source_file):
+ _, func_name = get_func_name(self.func, resolv_alias=False)
+ num_lines = len(func_code.split('\n'))
+ with tokenize.open(source_file) as f:
+ on_disk_func_code = f.readlines()[
+ old_first_line - 1:old_first_line - 1 + num_lines - 1]
+ on_disk_func_code = ''.join(on_disk_func_code)
+ possible_collision = (on_disk_func_code.rstrip() ==
+ old_func_code.rstrip())
+ else:
+ possible_collision = source_file.startswith('<doctest ')
+ if possible_collision:
+ warnings.warn(JobLibCollisionWarning(
+ 'Possible name collisions between functions '
+ "'%s' (%s:%i) and '%s' (%s:%i)" %
+ (func_name, source_file, old_first_line,
+ func_name, source_file, first_line)),
+ stacklevel=stacklevel)
+
+ # The function has changed, wipe the cache directory.
+ # XXX: Should be using warnings, and giving stacklevel
+ if self._verbose > 10:
+ _, func_name = get_func_name(self.func, resolv_alias=False)
+ self.warn("Function {0} (identified by {1}) has changed"
+ ".".format(func_name, self.func_id))
+ self.clear(warn=True)
+ return False
def clear(self, warn=True):
"""Empty the function's cache."""
- pass
+ func_id = self.func_id
+ if self._verbose > 0 and warn:
+ self.warn("Clearing function cache identified by %s" % func_id)
+ self.store_backend.clear_path([func_id, ])
+
+ func_code, _, first_line = self.func_code_info
+ self._write_func_code(func_code, first_line)
def call(self, *args, **kwargs):
"""Force the execution of the function with the given arguments.
@@ -390,10 +759,40 @@ class MemorizedFunc(Logger):
metadata : dict
The metadata associated with the call.
"""
- pass
+ call_id = (self.func_id, self._get_args_id(*args, **kwargs))
+
+ # Return the output and the metadata
+ return self._call(call_id, args, kwargs)
+
+ def _call(self, call_id, args, kwargs, shelving=False):
+ # Return the output and the metadata
+ self._before_call(args, kwargs)
+ start_time = time.time()
+ output = self.func(*args, **kwargs)
+ return self._after_call(call_id, args, kwargs, shelving,
+ output, start_time)
+
+ def _before_call(self, args, kwargs):
+ if self._verbose > 0:
+ print(format_call(self.func, args, kwargs))
+
+ def _after_call(self, call_id, args, kwargs, shelving, output, start_time):
+ self.store_backend.dump_item(call_id, output, verbose=self._verbose)
+ duration = time.time() - start_time
+ if self._verbose > 0:
+ self._print_duration(duration)
+ metadata = self._persist_input(duration, call_id, args, kwargs)
+ if shelving:
+ return self._get_memorized_result(call_id, metadata), metadata
+
+ if self.mmap_mode is not None:
+ # Memmap the output at the first call to be consistent with
+ # later calls
+ output = self._load_item(call_id, metadata)
+ return output, metadata
def _persist_input(self, duration, call_id, args, kwargs,
- this_duration_limit=0.5):
+ this_duration_limit=0.5):
""" Save a small summary of the call using json format in the
output directory.
@@ -410,22 +809,92 @@ class MemorizedFunc(Logger):
this_duration_limit: float
Max execution time for this function before issuing a warning.
"""
- pass
+ start_time = time.time()
+ argument_dict = filter_args(self.func, self.ignore,
+ args, kwargs)
+
+ input_repr = dict((k, repr(v)) for k, v in argument_dict.items())
+ # This can fail due to race-conditions with multiple
+ # concurrent joblibs removing the file or the directory
+ metadata = {
+ "duration": duration, "input_args": input_repr, "time": start_time,
+ }
+
+ self.store_backend.store_metadata(call_id, metadata)
+
+ this_duration = time.time() - start_time
+ if this_duration > this_duration_limit:
+ # This persistence should be fast. It will not be if repr() takes
+ # time and its output is large, because json.dump will have to
+ # write a large file. This should not be an issue with numpy arrays
+ # for which repr() always output a short representation, but can
+ # be with complex dictionaries. Fixing the problem should be a
+ # matter of replacing repr() above by something smarter.
+ warnings.warn("Persisting input arguments took %.2fs to run."
+ "If this happens often in your code, it can cause "
+ "performance problems "
+ "(results will be correct in all cases). "
+ "The reason for this is probably some large input "
+ "arguments for a wrapped function."
+ % this_duration, stacklevel=5)
+ return metadata
+
+ def _get_memorized_result(self, call_id, metadata=None):
+ return MemorizedResult(self.store_backend, call_id,
+ metadata=metadata, timestamp=self.timestamp,
+ verbose=self._verbose - 1)
+
+ def _load_item(self, call_id, metadata=None):
+ return self.store_backend.load_item(call_id, metadata=metadata,
+ timestamp=self.timestamp,
+ verbose=self._verbose)
+
+ def _print_duration(self, duration, context=''):
+ _, name = get_func_name(self.func)
+ msg = f"{name} {context}- {format_time(duration)}"
+ print(max(0, (80 - len(msg))) * '_' + msg)
+
+ # ------------------------------------------------------------------------
+ # Private `object` interface
+ # ------------------------------------------------------------------------
def __repr__(self):
return '{class_name}(func={func}, location={location})'.format(
- class_name=self.__class__.__name__, func=self.func, location=
- self.store_backend.location)
+ class_name=self.__class__.__name__,
+ func=self.func,
+ location=self.store_backend.location,)
+###############################################################################
+# class `AsyncMemorizedFunc`
+###############################################################################
class AsyncMemorizedFunc(MemorizedFunc):
-
async def __call__(self, *args, **kwargs):
out = self._cached_call(args, kwargs, shelving=False)
out = await out if asyncio.iscoroutine(out) else out
- return out[0]
+ return out[0] # Don't return metadata
+ async def call_and_shelve(self, *args, **kwargs):
+ out = self._cached_call(args, kwargs, shelving=True)
+ out = await out if asyncio.iscoroutine(out) else out
+ return out[0] # Don't return metadata
+ async def call(self, *args, **kwargs):
+ out = super().call(*args, **kwargs)
+ return await out if asyncio.iscoroutine(out) else out
+
+ async def _call(self, call_id, args, kwargs, shelving=False):
+ self._before_call(args, kwargs)
+ start_time = time.time()
+ output = await self.func(*args, **kwargs)
+ return self._after_call(
+ call_id, args, kwargs, shelving, output, start_time
+ )
+
+
+###############################################################################
+# class `Memory`
+###############################################################################
class Memory(Logger):
""" A context object for caching a function's return value each time it
is called with the same input arguments.
@@ -482,35 +951,46 @@ class Memory(Logger):
Contains a dictionary of named parameters used to configure
the store backend.
"""
+ # ------------------------------------------------------------------------
+ # Public interface
+ # ------------------------------------------------------------------------
- def __init__(self, location=None, backend='local', mmap_mode=None,
- compress=False, verbose=1, bytes_limit=None, backend_options=None):
+ def __init__(self, location=None, backend='local',
+ mmap_mode=None, compress=False, verbose=1, bytes_limit=None,
+ backend_options=None):
Logger.__init__(self)
self._verbose = verbose
self.mmap_mode = mmap_mode
self.timestamp = time.time()
if bytes_limit is not None:
warnings.warn(
- 'bytes_limit argument has been deprecated. It will be removed in version 1.5. Please pass its value directly to Memory.reduce_size.'
- , category=DeprecationWarning)
+ "bytes_limit argument has been deprecated. It will be removed "
+ "in version 1.5. Please pass its value directly to "
+ "Memory.reduce_size.",
+ category=DeprecationWarning
+ )
self.bytes_limit = bytes_limit
self.backend = backend
self.compress = compress
if backend_options is None:
backend_options = {}
self.backend_options = backend_options
+
if compress and mmap_mode is not None:
warnings.warn('Compressed results cannot be memmapped',
- stacklevel=2)
+ stacklevel=2)
+
self.location = location
if isinstance(location, str):
location = os.path.join(location, 'joblib')
- self.store_backend = _store_backend_factory(backend, location,
- verbose=self._verbose, backend_options=dict(compress=compress,
- mmap_mode=mmap_mode, **backend_options))
+
+ self.store_backend = _store_backend_factory(
+ backend, location, verbose=self._verbose,
+ backend_options=dict(compress=compress, mmap_mode=mmap_mode,
+ **backend_options))
def cache(self, func=None, ignore=None, verbose=None, mmap_mode=False,
- cache_validation_callback=None):
+ cache_validation_callback=None):
""" Decorates the given function func to only compute its return
value for input arguments not cached on disk.
@@ -543,12 +1023,55 @@ class Memory(Logger):
methods for cache lookup and management. See the
documentation for :class:`joblib.memory.MemorizedFunc`.
"""
- pass
+ if (cache_validation_callback is not None and
+ not callable(cache_validation_callback)):
+ raise ValueError(
+ "cache_validation_callback needs to be callable. "
+ f"Got {cache_validation_callback}."
+ )
+ if func is None:
+ # Partial application, to be able to specify extra keyword
+ # arguments in decorators
+ return functools.partial(
+ self.cache, ignore=ignore,
+ mmap_mode=mmap_mode,
+ verbose=verbose,
+ cache_validation_callback=cache_validation_callback
+ )
+ if self.store_backend is None:
+ cls = (AsyncNotMemorizedFunc
+ if asyncio.iscoroutinefunction(func)
+ else NotMemorizedFunc)
+ return cls(func)
+ if verbose is None:
+ verbose = self._verbose
+ if mmap_mode is False:
+ mmap_mode = self.mmap_mode
+ if isinstance(func, MemorizedFunc):
+ func = func.func
+ cls = (AsyncMemorizedFunc
+ if asyncio.iscoroutinefunction(func)
+ else MemorizedFunc)
+ return cls(
+ func, location=self.store_backend, backend=self.backend,
+ ignore=ignore, mmap_mode=mmap_mode, compress=self.compress,
+ verbose=verbose, timestamp=self.timestamp,
+ cache_validation_callback=cache_validation_callback
+ )
def clear(self, warn=True):
""" Erase the complete cache directory.
"""
- pass
+ if warn:
+ self.warn('Flushing completely the cache')
+ if self.store_backend is not None:
+ self.store_backend.clear()
+
+ # As the cache is completely clear, make sure the _FUNCTION_HASHES
+ # cache is also reset. Else, for a function that is present in this
+        # table, results cached after this clear will have a cache miss
+ # as the function code is not re-written.
+ _FUNCTION_HASHES.clear()
def reduce_size(self, bytes_limit=None, items_limit=None, age_limit=None):
"""Remove cache elements to make the cache fit its limits.
@@ -576,7 +1099,21 @@ class Memory(Logger):
of the cache, any items last accessed more than the given length of
time ago are deleted.
"""
- pass
+ if bytes_limit is None:
+ bytes_limit = self.bytes_limit
+
+ if self.store_backend is None:
+ # No cached results, this function does nothing.
+ return
+
+ if bytes_limit is None and items_limit is None and age_limit is None:
+ # No limitation to impose, returning
+ return
+
+ # Defers the actual limits enforcing to the store backend.
+ self.store_backend.enforce_store_limits(
+ bytes_limit, items_limit, age_limit
+ )
def eval(self, func, *args, **kwargs):
""" Eval function func with arguments `*args` and `**kwargs`,
@@ -587,12 +1124,19 @@ class Memory(Logger):
up to date.
"""
- pass
+ if self.store_backend is None:
+ return func(*args, **kwargs)
+ return self.cache(func)(*args, **kwargs)
+
+ # ------------------------------------------------------------------------
+ # Private `object` interface
+ # ------------------------------------------------------------------------
def __repr__(self):
- return '{class_name}(location={location})'.format(class_name=self.
- __class__.__name__, location=None if self.store_backend is None
- else self.store_backend.location)
+ return '{class_name}(location={location})'.format(
+ class_name=self.__class__.__name__,
+ location=(None if self.store_backend is None
+ else self.store_backend.location))
def __getstate__(self):
""" We don't store the timestamp when pickling, to avoid the hash
@@ -603,8 +1147,12 @@ class Memory(Logger):
return state
-def expires_after(days=0, seconds=0, microseconds=0, milliseconds=0,
- minutes=0, hours=0, weeks=0):
+###############################################################################
+# cache_validation_callback helpers
+###############################################################################
+
+def expires_after(days=0, seconds=0, microseconds=0, milliseconds=0, minutes=0,
+ hours=0, weeks=0):
"""Helper cache_validation_callback to force recompute after a duration.
Parameters
@@ -612,4 +1160,13 @@ def expires_after(days=0, seconds=0, microseconds=0, milliseconds=0,
days, seconds, microseconds, milliseconds, minutes, hours, weeks: numbers
argument passed to a timedelta.
"""
- pass
+ delta = datetime.timedelta(
+ days=days, seconds=seconds, microseconds=microseconds,
+ milliseconds=milliseconds, minutes=minutes, hours=hours, weeks=weeks
+ )
+
+ def cache_validation_callback(metadata):
+ computation_age = time.time() - metadata['time']
+ return computation_age < delta.total_seconds()
+
+ return cache_validation_callback
diff --git a/joblib/numpy_pickle.py b/joblib/numpy_pickle.py
index 6cc4089..bf83bb0 100644
--- a/joblib/numpy_pickle.py
+++ b/joblib/numpy_pickle.py
@@ -1,26 +1,48 @@
"""Utilities for fast persistence of big data, with optional compression."""
+
+# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
+# Copyright (c) 2009 Gael Varoquaux
+# License: BSD Style, 3 clauses.
+
import pickle
import os
import warnings
import io
from pathlib import Path
+
from .compressor import lz4, LZ4_NOT_INSTALLED_ERROR
from .compressor import _COMPRESSORS, register_compressor, BinaryZlibFile
-from .compressor import ZlibCompressorWrapper, GzipCompressorWrapper, BZ2CompressorWrapper, LZMACompressorWrapper, XZCompressorWrapper, LZ4CompressorWrapper
+from .compressor import (ZlibCompressorWrapper, GzipCompressorWrapper,
+ BZ2CompressorWrapper, LZMACompressorWrapper,
+ XZCompressorWrapper, LZ4CompressorWrapper)
from .numpy_pickle_utils import Unpickler, Pickler
from .numpy_pickle_utils import _read_fileobject, _write_fileobject
from .numpy_pickle_utils import _read_bytes, BUFFER_SIZE
from .numpy_pickle_utils import _ensure_native_byte_order
from .numpy_pickle_compat import load_compatibility
from .numpy_pickle_compat import NDArrayWrapper
-from .numpy_pickle_compat import ZNDArrayWrapper
+# For compatibility with old versions of joblib, we need ZNDArrayWrapper
+# to be visible in the current namespace.
+# Explicitly skipping next line from flake8 as it triggers an F401 warning
+# which we don't care about.
+from .numpy_pickle_compat import ZNDArrayWrapper # noqa
from .backports import make_memmap
+
+# Register supported compressors
register_compressor('zlib', ZlibCompressorWrapper())
register_compressor('gzip', GzipCompressorWrapper())
register_compressor('bz2', BZ2CompressorWrapper())
register_compressor('lzma', LZMACompressorWrapper())
register_compressor('xz', XZCompressorWrapper())
register_compressor('lz4', LZ4CompressorWrapper())
+
+
+###############################################################################
+# Utility objects for persistence.
+
+# For convenience, 16 bytes are used to be sure to cover all the possible
+# dtypes' alignments. For reference, see:
+# https://numpy.org/devdocs/dev/alignment.html
NUMPY_ARRAY_ALIGNMENT_BYTES = 16
@@ -55,22 +77,61 @@ class NumpyArrayWrapper(object):
"""
def __init__(self, subclass, shape, order, dtype, allow_mmap=False,
- numpy_array_alignment_bytes=NUMPY_ARRAY_ALIGNMENT_BYTES):
+ numpy_array_alignment_bytes=NUMPY_ARRAY_ALIGNMENT_BYTES):
"""Constructor. Store the useful information for later."""
self.subclass = subclass
self.shape = shape
self.order = order
self.dtype = dtype
self.allow_mmap = allow_mmap
+ # We make numpy_array_alignment_bytes an instance attribute to allow us
+ # to change our mind about the default alignment and still load the old
+ # pickles (with the previous alignment) correctly
self.numpy_array_alignment_bytes = numpy_array_alignment_bytes
+ def safe_get_numpy_array_alignment_bytes(self):
+ # NumpyArrayWrapper instances loaded from joblib <= 1.1 pickles don't
+ # have an numpy_array_alignment_bytes attribute
+ return getattr(self, 'numpy_array_alignment_bytes', None)
+
def write_array(self, array, pickler):
"""Write array bytes to pickler file handle.
This function is an adaptation of the numpy write_array function
available in version 1.10.1 in numpy/lib/format.py.
"""
- pass
+ # Set buffer size to 16 MiB to hide the Python loop overhead.
+ buffersize = max(16 * 1024 ** 2 // array.itemsize, 1)
+ if array.dtype.hasobject:
+ # We contain Python objects so we cannot write out the data
+ # directly. Instead, we will pickle it out with version 2 of the
+ # pickle protocol.
+ pickle.dump(array, pickler.file_handle, protocol=2)
+ else:
+ numpy_array_alignment_bytes = \
+ self.safe_get_numpy_array_alignment_bytes()
+ if numpy_array_alignment_bytes is not None:
+ current_pos = pickler.file_handle.tell()
+ pos_after_padding_byte = current_pos + 1
+ padding_length = numpy_array_alignment_bytes - (
+ pos_after_padding_byte % numpy_array_alignment_bytes)
+ # A single byte is written that contains the padding length in
+ # bytes
+ padding_length_byte = int.to_bytes(
+ padding_length, length=1, byteorder='little')
+ pickler.file_handle.write(padding_length_byte)
+
+ if padding_length != 0:
+ padding = b'\xff' * padding_length
+ pickler.file_handle.write(padding)
+
+ for chunk in pickler.np.nditer(array,
+ flags=['external_loop',
+ 'buffered',
+ 'zerosize_ok'],
+ buffersize=buffersize,
+ order=self.order):
+ pickler.file_handle.write(chunk.tobytes('C'))
def read_array(self, unpickler):
"""Read array from unpickler file handle.
@@ -78,11 +139,97 @@ class NumpyArrayWrapper(object):
This function is an adaptation of the numpy read_array function
available in version 1.10.1 in numpy/lib/format.py.
"""
- pass
+ if len(self.shape) == 0:
+ count = 1
+ else:
+ # joblib issue #859: we cast the elements of self.shape to int64 to
+ # prevent a potential overflow when computing their product.
+ shape_int64 = [unpickler.np.int64(x) for x in self.shape]
+ count = unpickler.np.multiply.reduce(shape_int64)
+ # Now read the actual data.
+ if self.dtype.hasobject:
+ # The array contained Python objects. We need to unpickle the data.
+ array = pickle.load(unpickler.file_handle)
+ else:
+ numpy_array_alignment_bytes = \
+ self.safe_get_numpy_array_alignment_bytes()
+ if numpy_array_alignment_bytes is not None:
+ padding_byte = unpickler.file_handle.read(1)
+ padding_length = int.from_bytes(
+ padding_byte, byteorder='little')
+ if padding_length != 0:
+ unpickler.file_handle.read(padding_length)
+
+ # This is not a real file. We have to read it the
+ # memory-intensive way.
+ # crc32 module fails on reads greater than 2 ** 32 bytes,
+ # breaking large reads from gzip streams. Chunk reads to
+ # BUFFER_SIZE bytes to avoid issue and reduce memory overhead
+ # of the read. In non-chunked case count < max_read_count, so
+ # only one read is performed.
+ max_read_count = BUFFER_SIZE // min(BUFFER_SIZE,
+ self.dtype.itemsize)
+
+ array = unpickler.np.empty(count, dtype=self.dtype)
+ for i in range(0, count, max_read_count):
+ read_count = min(max_read_count, count - i)
+ read_size = int(read_count * self.dtype.itemsize)
+ data = _read_bytes(unpickler.file_handle,
+ read_size, "array data")
+ array[i:i + read_count] = \
+ unpickler.np.frombuffer(data, dtype=self.dtype,
+ count=read_count)
+ del data
+
+ if self.order == 'F':
+ array.shape = self.shape[::-1]
+ array = array.transpose()
+ else:
+ array.shape = self.shape
+
+ # Detect byte order mismatch and swap as needed.
+ return _ensure_native_byte_order(array)
def read_mmap(self, unpickler):
"""Read an array using numpy memmap."""
- pass
+ current_pos = unpickler.file_handle.tell()
+ offset = current_pos
+ numpy_array_alignment_bytes = \
+ self.safe_get_numpy_array_alignment_bytes()
+
+ if numpy_array_alignment_bytes is not None:
+ padding_byte = unpickler.file_handle.read(1)
+ padding_length = int.from_bytes(padding_byte, byteorder='little')
+ # + 1 is for the padding byte
+ offset += padding_length + 1
+
+ if unpickler.mmap_mode == 'w+':
+ unpickler.mmap_mode = 'r+'
+
+ marray = make_memmap(unpickler.filename,
+ dtype=self.dtype,
+ shape=self.shape,
+ order=self.order,
+ mode=unpickler.mmap_mode,
+ offset=offset)
+ # update the offset so that it corresponds to the end of the read array
+ unpickler.file_handle.seek(offset + marray.nbytes)
+
+ if (numpy_array_alignment_bytes is None and
+ current_pos % NUMPY_ARRAY_ALIGNMENT_BYTES != 0):
+ message = (
+ f'The memmapped array {marray} loaded from the file '
+ f'{unpickler.file_handle.name} is not byte aligned. '
+ 'This may cause segmentation faults if this memmapped array '
+ 'is used in some libraries like BLAS or PyTorch. '
+ 'To get rid of this warning, regenerate your pickle file '
+ 'with joblib >= 1.2.0. '
+ 'See https://github.com/joblib/joblib/issues/563 '
+ 'for more details'
+ )
+ warnings.warn(message)
+
+ return _ensure_native_byte_order(marray)
def read(self, unpickler):
"""Read the array corresponding to this wrapper.
@@ -98,7 +245,25 @@ class NumpyArrayWrapper(object):
array: numpy.ndarray
"""
- pass
+ # When requested, only use memmap mode if allowed.
+ if unpickler.mmap_mode is not None and self.allow_mmap:
+ array = self.read_mmap(unpickler)
+ else:
+ array = self.read_array(unpickler)
+
+ # Manage array subclass case
+ if (hasattr(array, '__array_prepare__') and
+ self.subclass not in (unpickler.np.ndarray,
+ unpickler.np.memmap)):
+ # We need to reconstruct another subclass
+ new_array = unpickler.np.core.multiarray._reconstruct(
+ self.subclass, (0,), 'b')
+ return new_array.__array_prepare__(array)
+ else:
+ return array
+
+###############################################################################
+# Pickler classes
class NumpyPickler(Pickler):
@@ -115,14 +280,20 @@ class NumpyPickler(Pickler):
protocol: int, optional
Pickle protocol used. Default is pickle.DEFAULT_PROTOCOL.
"""
+
dispatch = Pickler.dispatch.copy()
def __init__(self, fp, protocol=None):
self.file_handle = fp
self.buffered = isinstance(self.file_handle, BinaryZlibFile)
+
+ # By default we want a pickle protocol that only changes with
+ # the major python version and not the minor one
if protocol is None:
protocol = pickle.DEFAULT_PROTOCOL
+
Pickler.__init__(self, self.file_handle, protocol=protocol)
+ # delayed import of numpy, to avoid tight coupling
try:
import numpy as np
except ImportError:
@@ -131,7 +302,22 @@ class NumpyPickler(Pickler):
def _create_array_wrapper(self, array):
"""Create and returns a numpy array wrapper from a numpy array."""
- pass
+ order = 'F' if (array.flags.f_contiguous and
+ not array.flags.c_contiguous) else 'C'
+ allow_mmap = not self.buffered and not array.dtype.hasobject
+
+ kwargs = {}
+ try:
+ self.file_handle.tell()
+ except io.UnsupportedOperation:
+ kwargs = {'numpy_array_alignment_bytes': None}
+
+ wrapper = NumpyArrayWrapper(type(array),
+ array.shape, order, array.dtype,
+ allow_mmap=allow_mmap,
+ **kwargs)
+
+ return wrapper
def save(self, obj):
"""Subclass the Pickler `save` method.
@@ -143,7 +329,30 @@ class NumpyPickler(Pickler):
after in the file. Warning: the file produced does not follow the
pickle format. As such it can not be read with `pickle.load`.
"""
- pass
+ if self.np is not None and type(obj) in (self.np.ndarray,
+ self.np.matrix,
+ self.np.memmap):
+ if type(obj) is self.np.memmap:
+ # Pickling doesn't work with memmapped arrays
+ obj = self.np.asanyarray(obj)
+
+ # The array wrapper is pickled instead of the real array.
+ wrapper = self._create_array_wrapper(obj)
+ Pickler.save(self, wrapper)
+
+ # A framer was introduced with pickle protocol 4 and we want to
+ # ensure the wrapper object is written before the numpy array
+ # buffer in the pickle file.
+ # See https://www.python.org/dev/peps/pep-3154/#framing to get
+ # more information on the framer behavior.
+ if self.proto >= 4:
+ self.framer.commit_frame(force=True)
+
+ # And then array bytes are written right after the wrapper.
+ wrapper.write_array(obj, self)
+ return
+
+ return Pickler.save(self, obj)
class NumpyUnpickler(Unpickler):
@@ -162,12 +371,17 @@ class NumpyUnpickler(Unpickler):
Reference to numpy module if numpy is installed else None.
"""
+
dispatch = Unpickler.dispatch.copy()
def __init__(self, filename, file_handle, mmap_mode=None):
+ # The next line is for backward compatibility with pickle generated
+ # with joblib versions less than 0.10.
self._dirname = os.path.dirname(filename)
+
self.mmap_mode = mmap_mode
self.file_handle = file_handle
+ # filename is required for numpy mmap mode.
self.filename = filename
self.compat_mode = False
Unpickler.__init__(self, self.file_handle)
@@ -185,10 +399,28 @@ class NumpyUnpickler(Unpickler):
replace them directly in the stack of pickler.
NDArrayWrapper is used for backward compatibility with joblib <= 0.9.
"""
- pass
+ Unpickler.load_build(self)
+
+ # For backward compatibility, we support NDArrayWrapper objects.
+ if isinstance(self.stack[-1], (NDArrayWrapper, NumpyArrayWrapper)):
+ if self.np is None:
+ raise ImportError("Trying to unpickle an ndarray, "
+ "but numpy didn't import correctly")
+ array_wrapper = self.stack.pop()
+ # If any NDArrayWrapper is found, we switch to compatibility mode,
+ # this will be used to raise a DeprecationWarning to the user at
+ # the end of the unpickling.
+ if isinstance(array_wrapper, NDArrayWrapper):
+ self.compat_mode = True
+ self.stack.append(array_wrapper.read(self))
+
+ # Be careful to register our new method.
dispatch[pickle.BUILD[0]] = load_build
+###############################################################################
+# Utility functions
+
def dump(value, filename, compress=0, protocol=None, cache_size=None):
"""Persist an arbitrary Python object into one file.
@@ -236,12 +468,137 @@ def dump(value, filename, compress=0, protocol=None, cache_size=None):
dump and load.
"""
- pass
-
-def _unpickle(fobj, filename='', mmap_mode=None):
+ if Path is not None and isinstance(filename, Path):
+ filename = str(filename)
+
+ is_filename = isinstance(filename, str)
+ is_fileobj = hasattr(filename, "write")
+
+ compress_method = 'zlib' # zlib is the default compression method.
+ if compress is True:
+ # By default, if compress is enabled, we want the default compress
+ # level of the compressor.
+ compress_level = None
+ elif isinstance(compress, tuple):
+ # a 2-tuple was set in compress
+ if len(compress) != 2:
+ raise ValueError(
+ 'Compress argument tuple should contain exactly 2 elements: '
+ '(compress method, compress level), you passed {}'
+ .format(compress))
+ compress_method, compress_level = compress
+ elif isinstance(compress, str):
+ compress_method = compress
+ compress_level = None # Use default compress level
+ compress = (compress_method, compress_level)
+ else:
+ compress_level = compress
+
+ if compress_method == 'lz4' and lz4 is None:
+ raise ValueError(LZ4_NOT_INSTALLED_ERROR)
+
+ if (compress_level is not None and
+ compress_level is not False and
+ compress_level not in range(10)):
+ # Raising an error if a non valid compress level is given.
+ raise ValueError(
+ 'Non valid compress level given: "{}". Possible values are '
+ '{}.'.format(compress_level, list(range(10))))
+
+ if compress_method not in _COMPRESSORS:
+ # Raising an error if an unsupported compression method is given.
+ raise ValueError(
+ 'Non valid compression method given: "{}". Possible values are '
+ '{}.'.format(compress_method, _COMPRESSORS))
+
+ if not is_filename and not is_fileobj:
+ # People keep inverting arguments, and the resulting error is
+ # incomprehensible
+ raise ValueError(
+ 'Second argument should be a filename or a file-like object, '
+ '%s (type %s) was given.'
+ % (filename, type(filename))
+ )
+
+ if is_filename and not isinstance(compress, tuple):
+ # In case no explicit compression was requested using both compression
+ # method and level in a tuple and the filename has an explicit
+ # extension, we select the corresponding compressor.
+
+ # unset the variable to be sure no compression level is set afterwards.
+ compress_method = None
+ for name, compressor in _COMPRESSORS.items():
+ if filename.endswith(compressor.extension):
+ compress_method = name
+
+ if compress_method in _COMPRESSORS and compress_level == 0:
+ # we choose the default compress_level in case it was not given
+ # as an argument (using compress).
+ compress_level = None
+
+ if cache_size is not None:
+ # Cache size is deprecated starting from version 0.10
+ warnings.warn("Please do not set 'cache_size' in joblib.dump, "
+ "this parameter has no effect and will be removed. "
+ "You used 'cache_size={}'".format(cache_size),
+ DeprecationWarning, stacklevel=2)
+
+ if compress_level != 0:
+ with _write_fileobject(filename, compress=(compress_method,
+ compress_level)) as f:
+ NumpyPickler(f, protocol=protocol).dump(value)
+ elif is_filename:
+ with open(filename, 'wb') as f:
+ NumpyPickler(f, protocol=protocol).dump(value)
+ else:
+ NumpyPickler(filename, protocol=protocol).dump(value)
+
+ # If the target container is a file object, nothing is returned.
+ if is_fileobj:
+ return
+
+ # For compatibility, the list of created filenames (e.g with one element
+ # after 0.10.0) is returned by default.
+ return [filename]
+
+
+def _unpickle(fobj, filename="", mmap_mode=None):
"""Internal unpickling function."""
- pass
+ # We are careful to open the file handle early and keep it open to
+ # avoid race-conditions on renames.
+ # That said, if data is stored in companion files, which can be
+ # the case with the old persistence format, moving the directory
+ # will create a race when joblib tries to access the companion
+ # files.
+ unpickler = NumpyUnpickler(filename, fobj, mmap_mode=mmap_mode)
+ obj = None
+ try:
+ obj = unpickler.load()
+ if unpickler.compat_mode:
+ warnings.warn("The file '%s' has been generated with a "
+ "joblib version less than 0.10. "
+ "Please regenerate this pickle file."
+ % filename,
+ DeprecationWarning, stacklevel=3)
+ except UnicodeDecodeError as exc:
+ # More user-friendly error message
+ new_exc = ValueError(
+ 'You may be trying to read with '
+ 'python 3 a joblib pickle generated with python 2. '
+ 'This feature is not supported by joblib.')
+ new_exc.__cause__ = exc
+ raise new_exc
+ return obj
+
+
+def load_temporary_memmap(filename, mmap_mode, unlink_on_gc_collect):
+ from ._memmapping_reducer import JOBLIB_MMAPS, add_maybe_unlink_finalizer
+ obj = load(filename, mmap_mode)
+ JOBLIB_MMAPS.add(obj.filename)
+ if unlink_on_gc_collect:
+ add_maybe_unlink_finalizer(obj)
+ return obj
def load(filename, mmap_mode=None):
@@ -281,4 +638,22 @@ def load(filename, mmap_mode=None):
object might not match the original pickled object. Note that if the
file was saved with compression, the arrays cannot be memmapped.
"""
- pass
+ if Path is not None and isinstance(filename, Path):
+ filename = str(filename)
+
+ if hasattr(filename, "read"):
+ fobj = filename
+ filename = getattr(fobj, 'name', '')
+ with _read_fileobject(fobj, filename, mmap_mode) as fobj:
+ obj = _unpickle(fobj)
+ else:
+ with open(filename, 'rb') as f:
+ with _read_fileobject(f, filename, mmap_mode) as fobj:
+ if isinstance(fobj, str):
+ # if the returned file object is a string, this means we
+                    # try to load a pickle file generated with an old version of
+ # Joblib so we load it with joblib compatibility function.
+ return load_compatibility(fobj)
+
+ obj = _unpickle(fobj, filename, mmap_mode)
+ return obj
diff --git a/joblib/numpy_pickle_compat.py b/joblib/numpy_pickle_compat.py
index 4ccffb6..3261284 100644
--- a/joblib/numpy_pickle_compat.py
+++ b/joblib/numpy_pickle_compat.py
@@ -1,9 +1,12 @@
"""Numpy pickle compatibility functions."""
+
import pickle
import os
import zlib
import inspect
+
from io import BytesIO
+
from .numpy_pickle_utils import _ZFILE_PREFIX
from .numpy_pickle_utils import Unpickler
from .numpy_pickle_utils import _ensure_native_byte_order
@@ -11,7 +14,13 @@ from .numpy_pickle_utils import _ensure_native_byte_order
def hex_str(an_int):
"""Convert an int to an hexadecimal string."""
- pass
+ return '{:#x}'.format(an_int)
+
+
+def asbytes(s):
+ if isinstance(s, bytes):
+ return s
+ return s.encode('latin1')
_MAX_LEN = len(hex_str(2 ** 64))
@@ -25,7 +34,30 @@ def read_zfile(file_handle):
for persistence. Backward compatibility is not guaranteed. Do not
use for external purposes.
"""
- pass
+ file_handle.seek(0)
+ header_length = len(_ZFILE_PREFIX) + _MAX_LEN
+ length = file_handle.read(header_length)
+ length = length[len(_ZFILE_PREFIX):]
+ length = int(length, 16)
+
+ # With python2 and joblib version <= 0.8.4 compressed pickle header is one
+ # character wider so we need to ignore an additional space if present.
+ # Note: the first byte of the zlib data is guaranteed not to be a
+ # space according to
+ # https://tools.ietf.org/html/rfc6713#section-2.1
+ next_byte = file_handle.read(1)
+ if next_byte != b' ':
+ # The zlib compressed data has started and we need to go back
+ # one byte
+ file_handle.seek(header_length)
+
+ # We use the known length of the data to tell Zlib the size of the
+ # buffer to allocate.
+ data = zlib.decompress(file_handle.read(), 15, length)
+ assert len(data) == length, (
+ "Incorrect data length while decompressing %s."
+ "The file could be corrupted." % file_handle)
+ return data
def write_zfile(file_handle, data, compress=1):
@@ -35,7 +67,14 @@ def write_zfile(file_handle, data, compress=1):
for persistence. Backward compatibility is not guaranteed. Do not
use for external purposes.
"""
- pass
+ file_handle.write(_ZFILE_PREFIX)
+ length = hex_str(len(data))
+ # Store the length of the data
+ file_handle.write(asbytes(length.ljust(_MAX_LEN)))
+ file_handle.write(zlib.compress(asbytes(data), compress))
+
+###############################################################################
+# Utility objects for persistence.
class NDArrayWrapper(object):
@@ -53,7 +92,34 @@ class NDArrayWrapper(object):
def read(self, unpickler):
"""Reconstruct the array."""
- pass
+ filename = os.path.join(unpickler._dirname, self.filename)
+ # Load the array from the disk
+ # use getattr instead of self.allow_mmap to ensure backward compat
+ # with NDArrayWrapper instances pickled with joblib < 0.9.0
+ allow_mmap = getattr(self, 'allow_mmap', True)
+ kwargs = {}
+ if allow_mmap:
+ kwargs['mmap_mode'] = unpickler.mmap_mode
+ if "allow_pickle" in inspect.signature(unpickler.np.load).parameters:
+            # Required in numpy 1.16.3 and later to acknowledge the security
+ # risk.
+ kwargs["allow_pickle"] = True
+ array = unpickler.np.load(filename, **kwargs)
+
+ # Detect byte order mismatch and swap as needed.
+ array = _ensure_native_byte_order(array)
+
+ # Reconstruct subclasses. This does not work with old
+ # versions of numpy
+ if (hasattr(array, '__array_prepare__') and
+ self.subclass not in (unpickler.np.ndarray,
+ unpickler.np.memmap)):
+ # We need to reconstruct another subclass
+ new_array = unpickler.np.core.multiarray._reconstruct(
+ self.subclass, (0,), 'b')
+ return new_array.__array_prepare__(array)
+ else:
+ return array
class ZNDArrayWrapper(NDArrayWrapper):
@@ -79,11 +145,20 @@ class ZNDArrayWrapper(NDArrayWrapper):
def read(self, unpickler):
"""Reconstruct the array from the meta-information and the z-file."""
- pass
+        # Here we are simply reproducing the unpickling mechanism for numpy
+ # arrays
+ filename = os.path.join(unpickler._dirname, self.filename)
+ array = unpickler.np.core.multiarray._reconstruct(*self.init_args)
+ with open(filename, 'rb') as f:
+ data = read_zfile(f)
+ state = self.state + (data,)
+ array.__setstate__(state)
+ return array
class ZipNumpyUnpickler(Unpickler):
"""A subclass of the Unpickler to unpickle our numpy pickles."""
+
dispatch = Unpickler.dispatch.copy()
def __init__(self, filename, file_handle, mmap_mode=None):
@@ -99,6 +174,9 @@ class ZipNumpyUnpickler(Unpickler):
np = None
self.np = np
+ def _open_pickle(self, file_handle):
+ return BytesIO(read_zfile(file_handle))
+
def load_build(self):
"""Set the state of a newly created object.
@@ -106,7 +184,15 @@ class ZipNumpyUnpickler(Unpickler):
NDArrayWrapper, by the array we are interested in. We
replace them directly in the stack of pickler.
"""
- pass
+ Unpickler.load_build(self)
+ if isinstance(self.stack[-1], NDArrayWrapper):
+ if self.np is None:
+ raise ImportError("Trying to unpickle an ndarray, "
+ "but numpy didn't import correctly")
+ nd_array_wrapper = self.stack.pop()
+ array = nd_array_wrapper.read(self)
+ self.stack.append(array)
+
dispatch[pickle.BUILD[0]] = load_build
@@ -136,4 +222,23 @@ def load_compatibility(filename):
This function can load numpy array files saved separately during the
dump.
"""
- pass
+ with open(filename, 'rb') as file_handle:
+ # We are careful to open the file handle early and keep it open to
+ # avoid race-conditions on renames. That said, if data is stored in
+ # companion files, moving the directory will create a race when
+ # joblib tries to access the companion files.
+ unpickler = ZipNumpyUnpickler(filename, file_handle=file_handle)
+ try:
+ obj = unpickler.load()
+ except UnicodeDecodeError as exc:
+ # More user-friendly error message
+ new_exc = ValueError(
+ 'You may be trying to read with '
+ 'python 3 a joblib pickle generated with python 2. '
+ 'This feature is not supported by joblib.')
+ new_exc.__cause__ = exc
+ raise new_exc
+ finally:
+ if hasattr(unpickler, 'file_handle'):
+ unpickler.file_handle.close()
+ return obj
diff --git a/joblib/numpy_pickle_utils.py b/joblib/numpy_pickle_utils.py
index e79528e..23cfb34 100644
--- a/joblib/numpy_pickle_utils.py
+++ b/joblib/numpy_pickle_utils.py
@@ -1,33 +1,66 @@
"""Utilities for fast persistence of big data, with optional compression."""
+
+# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
+# Copyright (c) 2009 Gael Varoquaux
+# License: BSD Style, 3 clauses.
+
import pickle
import io
import sys
import warnings
import contextlib
+
from .compressor import _ZFILE_PREFIX
from .compressor import _COMPRESSORS
+
try:
import numpy as np
except ImportError:
np = None
+
Unpickler = pickle._Unpickler
Pickler = pickle._Pickler
xrange = range
+
+
try:
+ # The python standard library can be built without bz2 so we make bz2
+ # usage optional.
+ # see https://github.com/scikit-learn/scikit-learn/issues/7526 for more
+ # details.
import bz2
except ImportError:
bz2 = None
+
+# Buffer size used in io.BufferedReader and io.BufferedWriter
_IO_BUFFER_SIZE = 1024 ** 2
def _is_raw_file(fileobj):
"""Check if fileobj is a raw file object, e.g created with open."""
- pass
+ fileobj = getattr(fileobj, 'raw', fileobj)
+ return isinstance(fileobj, io.FileIO)
+
+
+def _get_prefixes_max_len():
+ # Compute the max prefix len of registered compressors.
+ prefixes = [len(compressor.prefix) for compressor in _COMPRESSORS.values()]
+ prefixes += [len(_ZFILE_PREFIX)]
+ return max(prefixes)
def _is_numpy_array_byte_order_mismatch(array):
"""Check if numpy array is having byte order mismatch"""
- pass
+ return ((sys.byteorder == 'big' and
+ (array.dtype.byteorder == '<' or
+ (array.dtype.byteorder == '|' and array.dtype.fields and
+ all(e[0].byteorder == '<'
+ for e in array.dtype.fields.values())))) or
+ (sys.byteorder == 'little' and
+ (array.dtype.byteorder == '>' or
+ (array.dtype.byteorder == '|' and array.dtype.fields and
+ all(e[0].byteorder == '>'
+ for e in array.dtype.fields.values())))))
def _ensure_native_byte_order(array):
@@ -35,9 +68,13 @@ def _ensure_native_byte_order(array):
Does nothing if array already uses the system byte order.
"""
- pass
+ if _is_numpy_array_byte_order_mismatch(array):
+ array = array.byteswap().view(array.dtype.newbyteorder('='))
+ return array
+###############################################################################
+# Cache file utilities
def _detect_compressor(fileobj):
"""Return the compressor matching fileobj.
@@ -49,17 +86,35 @@ def _detect_compressor(fileobj):
-------
str in {'zlib', 'gzip', 'bz2', 'lzma', 'xz', 'compat', 'not-compressed'}
"""
- pass
+ # Read the magic number in the first bytes of the file.
+ max_prefix_len = _get_prefixes_max_len()
+ if hasattr(fileobj, 'peek'):
+ # Peek allows to read those bytes without moving the cursor in the
+        # file, which avoids consuming data from the stream.
+ first_bytes = fileobj.peek(max_prefix_len)
+ else:
+ # Fallback to seek if the fileobject is not peekable.
+ first_bytes = fileobj.read(max_prefix_len)
+ fileobj.seek(0)
+
+ if first_bytes.startswith(_ZFILE_PREFIX):
+ return "compat"
+ else:
+ for name, compressor in _COMPRESSORS.items():
+ if first_bytes.startswith(compressor.prefix):
+ return name
+
+ return "not-compressed"
def _buffered_read_file(fobj):
"""Return a buffered version of a read file object."""
- pass
+ return io.BufferedReader(fobj, buffer_size=_IO_BUFFER_SIZE)
def _buffered_write_file(fobj):
"""Return a buffered version of a write file object."""
- pass
+ return io.BufferedWriter(fobj, buffer_size=_IO_BUFFER_SIZE)
@contextlib.contextmanager
@@ -90,18 +145,70 @@ def _read_fileobject(fileobj, filename, mmap_mode=None):
a file like object
"""
- pass
-
-
-def _write_fileobject(filename, compress=('zlib', 3)):
+ # Detect if the fileobj contains compressed data.
+ compressor = _detect_compressor(fileobj)
+
+ if compressor == 'compat':
+ # Compatibility with old pickle mode: simply return the input
+ # filename "as-is" and let the compatibility function be called by the
+ # caller.
+ warnings.warn("The file '%s' has been generated with a joblib "
+ "version less than 0.10. "
+ "Please regenerate this pickle file." % filename,
+ DeprecationWarning, stacklevel=2)
+ yield filename
+ else:
+ if compressor in _COMPRESSORS:
+ # based on the compressor detected in the file, we open the
+ # correct decompressor file object, wrapped in a buffer.
+ compressor_wrapper = _COMPRESSORS[compressor]
+ inst = compressor_wrapper.decompressor_file(fileobj)
+ fileobj = _buffered_read_file(inst)
+
+ # Checking if incompatible load parameters with the type of file:
+ # mmap_mode cannot be used with compressed file or in memory buffers
+ # such as io.BytesIO.
+ if mmap_mode is not None:
+ if isinstance(fileobj, io.BytesIO):
+ warnings.warn('In memory persistence is not compatible with '
+ 'mmap_mode "%(mmap_mode)s" flag passed. '
+ 'mmap_mode option will be ignored.'
+ % locals(), stacklevel=2)
+ elif compressor != 'not-compressed':
+ warnings.warn('mmap_mode "%(mmap_mode)s" is not compatible '
+ 'with compressed file %(filename)s. '
+ '"%(mmap_mode)s" flag will be ignored.'
+ % locals(), stacklevel=2)
+ elif not _is_raw_file(fileobj):
+ warnings.warn('"%(fileobj)r" is not a raw file, mmap_mode '
+ '"%(mmap_mode)s" flag will be ignored.'
+ % locals(), stacklevel=2)
+
+ yield fileobj
+
+
+def _write_fileobject(filename, compress=("zlib", 3)):
"""Return the right compressor file object in write mode."""
- pass
+ compressmethod = compress[0]
+ compresslevel = compress[1]
+
+ if compressmethod in _COMPRESSORS.keys():
+ file_instance = _COMPRESSORS[compressmethod].compressor_file(
+ filename, compresslevel=compresslevel)
+ return _buffered_write_file(file_instance)
+ else:
+ file_instance = _COMPRESSORS['zlib'].compressor_file(
+ filename, compresslevel=compresslevel)
+ return _buffered_write_file(file_instance)
-BUFFER_SIZE = 2 ** 18
+# Utility functions/variables from numpy required for writing arrays.
+# We need at least the functions introduced in version 1.9 of numpy. Here,
+# we use the ones from numpy 1.10.2.
+BUFFER_SIZE = 2 ** 18 # size of buffer for reading npz files in bytes
-def _read_bytes(fp, size, error_template='ran out of data'):
+def _read_bytes(fp, size, error_template="ran out of data"):
"""Read from file-like object until size bytes are read.
TODO python2_drop: is it still needed? The docstring mentions python 2.6
@@ -127,4 +234,20 @@ def _read_bytes(fp, size, error_template='ran out of data'):
The data read in bytes.
"""
- pass
+ data = bytes()
+ while True:
+ # io files (default in python3) return None or raise on
+ # would-block, python2 file will truncate, probably nothing can be
+ # done about that. note that regular files can't be non-blocking
+ try:
+ r = fp.read(size - len(data))
+ data += r
+ if len(r) == 0 or len(data) == size:
+ break
+ except io.BlockingIOError:
+ pass
+ if len(data) != size:
+ msg = "EOF: reading %s, expected %d bytes got %d"
+ raise ValueError(msg % (error_template, size, len(data)))
+ else:
+ return data
diff --git a/joblib/parallel.py b/joblib/parallel.py
index f141ee0..fa4fd3c 100644
--- a/joblib/parallel.py
+++ b/joblib/parallel.py
@@ -1,7 +1,12 @@
"""
Helpers for embarrassingly parallel code.
"""
+# Author: Gael Varoquaux < gael dot varoquaux at normalesup dot org >
+# Copyright: 2010, Gael Varoquaux
+# License: BSD 3 clause
+
from __future__ import division
+
import os
import sys
from math import sqrt
@@ -16,41 +21,90 @@ import warnings
import queue
import weakref
from contextlib import nullcontext
+
from multiprocessing import TimeoutError
+
from ._multiprocessing_helpers import mp
+
from .logger import Logger, short_format_time
from .disk import memstr_to_bytes
-from ._parallel_backends import FallbackToBackend, MultiprocessingBackend, ThreadingBackend, SequentialBackend, LokyBackend
+from ._parallel_backends import (FallbackToBackend, MultiprocessingBackend,
+ ThreadingBackend, SequentialBackend,
+ LokyBackend)
from ._utils import eval_expr, _Sentinel
-from ._parallel_backends import AutoBatchingMixin
-from ._parallel_backends import ParallelBackendBase
-IS_PYPY = hasattr(sys, 'pypy_version_info')
-BACKENDS = {'threading': ThreadingBackend, 'sequential': SequentialBackend}
+
+# Make sure that those two classes are part of the public joblib.parallel API
+# so that 3rd party backend implementers can import them from here.
+from ._parallel_backends import AutoBatchingMixin # noqa
+from ._parallel_backends import ParallelBackendBase # noqa
+
+
+IS_PYPY = hasattr(sys, "pypy_version_info")
+
+
+BACKENDS = {
+ 'threading': ThreadingBackend,
+ 'sequential': SequentialBackend,
+}
+# name of the backend used by default by Parallel outside of any context
+# managed by ``parallel_config`` or ``parallel_backend``.
+
+# threading is the only backend that is always everywhere
DEFAULT_BACKEND = 'threading'
+
MAYBE_AVAILABLE_BACKENDS = {'multiprocessing', 'loky'}
+
+# if multiprocessing is available, so is loky, we set it as the default
+# backend
if mp is not None:
BACKENDS['multiprocessing'] = MultiprocessingBackend
from .externals import loky
BACKENDS['loky'] = LokyBackend
DEFAULT_BACKEND = 'loky'
+
+
DEFAULT_THREAD_BACKEND = 'threading'
+
+
+# Thread local value that can be overridden by the ``parallel_config`` context
+# manager
_backend = threading.local()
def _register_dask():
"""Register Dask Backend if called with parallel_config(backend="dask")"""
- pass
+ try:
+ from ._dask import DaskDistributedBackend
+ register_parallel_backend('dask', DaskDistributedBackend)
+ except ImportError as e:
+ msg = ("To use the dask.distributed backend you must install both "
+ "the `dask` and distributed modules.\n\n"
+ "See https://dask.pydata.org/en/latest/install.html for more "
+ "information.")
+ raise ImportError(msg) from e
+
+
+EXTERNAL_BACKENDS = {
+ 'dask': _register_dask,
+}
+
+# Sentinels for the default values of the Parallel constructor and
+# the parallel_config and parallel_backend context managers
+default_parallel_config = {
+ "backend": _Sentinel(default_value=None),
+ "n_jobs": _Sentinel(default_value=None),
+ "verbose": _Sentinel(default_value=0),
+ "temp_folder": _Sentinel(default_value=None),
+ "max_nbytes": _Sentinel(default_value="1M"),
+ "mmap_mode": _Sentinel(default_value="r"),
+ "prefer": _Sentinel(default_value=None),
+ "require": _Sentinel(default_value=None),
+}
-EXTERNAL_BACKENDS = {'dask': _register_dask}
-default_parallel_config = {'backend': _Sentinel(default_value=None),
- 'n_jobs': _Sentinel(default_value=None), 'verbose': _Sentinel(
- default_value=0), 'temp_folder': _Sentinel(default_value=None),
- 'max_nbytes': _Sentinel(default_value='1M'), 'mmap_mode': _Sentinel(
- default_value='r'), 'prefer': _Sentinel(default_value=None), 'require':
- _Sentinel(default_value=None)}
-VALID_BACKEND_HINTS = 'processes', 'threads', None
-VALID_BACKEND_CONSTRAINTS = 'sharedmem', None
+
+VALID_BACKEND_HINTS = ('processes', 'threads', None)
+VALID_BACKEND_CONSTRAINTS = ('sharedmem', None)
def _get_config_param(param, context_config, key):
@@ -59,21 +113,105 @@ def _get_config_param(param, context_config, key):
Explicitly setting it in Parallel has priority over setting in a
parallel_(config/backend) context manager.
"""
- pass
+ if param is not default_parallel_config[key]:
+ # param is explicitly set, return it
+ return param
+ if context_config[key] is not default_parallel_config[key]:
+ # there's a context manager and the key is set, return it
+ return context_config[key]
-def get_active_backend(prefer=default_parallel_config['prefer'], require=
- default_parallel_config['require'], verbose=default_parallel_config[
- 'verbose']):
- """Return the active default backend"""
- pass
+ # Otherwise, we are in the default_parallel_config,
+ # return the default value
+ return param.default_value
-def _get_active_backend(prefer=default_parallel_config['prefer'], require=
- default_parallel_config['require'], verbose=default_parallel_config[
- 'verbose']):
+def get_active_backend(
+ prefer=default_parallel_config["prefer"],
+ require=default_parallel_config["require"],
+ verbose=default_parallel_config["verbose"],
+):
+ """Return the active default backend"""
+ backend, config = _get_active_backend(prefer, require, verbose)
+ n_jobs = _get_config_param(
+ default_parallel_config['n_jobs'], config, "n_jobs"
+ )
+ return backend, n_jobs
+
+
+def _get_active_backend(
+ prefer=default_parallel_config["prefer"],
+ require=default_parallel_config["require"],
+ verbose=default_parallel_config["verbose"],
+):
"""Return the active default backend"""
- pass
+
+ backend_config = getattr(_backend, "config", default_parallel_config)
+
+ backend = _get_config_param(
+ default_parallel_config['backend'], backend_config, "backend"
+ )
+ prefer = _get_config_param(prefer, backend_config, "prefer")
+ require = _get_config_param(require, backend_config, "require")
+ verbose = _get_config_param(verbose, backend_config, "verbose")
+
+ if prefer not in VALID_BACKEND_HINTS:
+ raise ValueError(
+ f"prefer={prefer} is not a valid backend hint, "
+ f"expected one of {VALID_BACKEND_HINTS}"
+ )
+ if require not in VALID_BACKEND_CONSTRAINTS:
+ raise ValueError(
+ f"require={require} is not a valid backend constraint, "
+ f"expected one of {VALID_BACKEND_CONSTRAINTS}"
+ )
+ if prefer == 'processes' and require == 'sharedmem':
+ raise ValueError(
+ "prefer == 'processes' and require == 'sharedmem'"
+ " are inconsistent settings"
+ )
+
+ explicit_backend = True
+ if backend is None:
+
+ # We are either outside of the scope of any parallel_(config/backend)
+ # context manager or the context manager did not set a backend.
+ # create the default backend instance now.
+ backend = BACKENDS[DEFAULT_BACKEND](nesting_level=0)
+ explicit_backend = False
+
+ # Try to use the backend set by the user with the context manager.
+
+ nesting_level = backend.nesting_level
+ uses_threads = getattr(backend, 'uses_threads', False)
+ supports_sharedmem = getattr(backend, 'supports_sharedmem', False)
+ # Force to use thread-based backend if the provided backend does not
+ # match the shared memory constraint or if the backend is not explicitly
+ # given and threads are preferred.
+ force_threads = (require == 'sharedmem' and not supports_sharedmem)
+ force_threads |= (
+ not explicit_backend and prefer == 'threads' and not uses_threads
+ )
+ if force_threads:
+ # This backend does not match the shared memory constraint:
+        # fallback to the default thread-based backend.
+ sharedmem_backend = BACKENDS[DEFAULT_THREAD_BACKEND](
+ nesting_level=nesting_level
+ )
+ # Warn the user if we forced the backend to thread-based, while the
+ # user explicitly specified a non-thread-based backend.
+ if verbose >= 10 and explicit_backend:
+ print(
+ f"Using {sharedmem_backend.__class__.__name__} as "
+ f"joblib backend instead of {backend.__class__.__name__} "
+ "as the latter does not provide shared memory semantics."
+ )
+ # Force to n_jobs=1 by default
+ thread_config = backend_config.copy()
+ thread_config['n_jobs'] = 1
+ return sharedmem_backend, thread_config
+
+ return backend, backend_config
class parallel_config:
@@ -215,26 +353,98 @@ class parallel_config:
[-1, -2, -3, -4, -5]
"""
-
- def __init__(self, backend=default_parallel_config['backend'], *,
- n_jobs=default_parallel_config['n_jobs'], verbose=
- default_parallel_config['verbose'], temp_folder=
- default_parallel_config['temp_folder'], max_nbytes=
- default_parallel_config['max_nbytes'], mmap_mode=
- default_parallel_config['mmap_mode'], prefer=
- default_parallel_config['prefer'], require=default_parallel_config[
- 'require'], inner_max_num_threads=None, **backend_params):
- self.old_parallel_config = getattr(_backend, 'config',
- default_parallel_config)
- backend = self._check_backend(backend, inner_max_num_threads, **
- backend_params)
- new_config = {'n_jobs': n_jobs, 'verbose': verbose, 'temp_folder':
- temp_folder, 'max_nbytes': max_nbytes, 'mmap_mode': mmap_mode,
- 'prefer': prefer, 'require': require, 'backend': backend}
+ def __init__(
+ self,
+ backend=default_parallel_config["backend"],
+ *,
+ n_jobs=default_parallel_config["n_jobs"],
+ verbose=default_parallel_config["verbose"],
+ temp_folder=default_parallel_config["temp_folder"],
+ max_nbytes=default_parallel_config["max_nbytes"],
+ mmap_mode=default_parallel_config["mmap_mode"],
+ prefer=default_parallel_config["prefer"],
+ require=default_parallel_config["require"],
+ inner_max_num_threads=None,
+ **backend_params
+ ):
+ # Save the parallel info and set the active parallel config
+ self.old_parallel_config = getattr(
+ _backend, "config", default_parallel_config
+ )
+
+ backend = self._check_backend(
+ backend, inner_max_num_threads, **backend_params
+ )
+
+ new_config = {
+ "n_jobs": n_jobs,
+ "verbose": verbose,
+ "temp_folder": temp_folder,
+ "max_nbytes": max_nbytes,
+ "mmap_mode": mmap_mode,
+ "prefer": prefer,
+ "require": require,
+ "backend": backend
+ }
self.parallel_config = self.old_parallel_config.copy()
- self.parallel_config.update({k: v for k, v in new_config.items() if
- not isinstance(v, _Sentinel)})
- setattr(_backend, 'config', self.parallel_config)
+ self.parallel_config.update({
+ k: v for k, v in new_config.items()
+ if not isinstance(v, _Sentinel)
+ })
+
+ setattr(_backend, "config", self.parallel_config)
+
+ def _check_backend(self, backend, inner_max_num_threads, **backend_params):
+ if backend is default_parallel_config['backend']:
+ if inner_max_num_threads is not None or len(backend_params) > 0:
+ raise ValueError(
+ "inner_max_num_threads and other constructor "
+ "parameters backend_params are only supported "
+ "when backend is not None."
+ )
+ return backend
+
+ if isinstance(backend, str):
+ # Handle non-registered or missing backends
+ if backend not in BACKENDS:
+ if backend in EXTERNAL_BACKENDS:
+ register = EXTERNAL_BACKENDS[backend]
+ register()
+ elif backend in MAYBE_AVAILABLE_BACKENDS:
+ warnings.warn(
+ f"joblib backend '{backend}' is not available on "
+ f"your system, falling back to {DEFAULT_BACKEND}.",
+ UserWarning,
+ stacklevel=2
+ )
+ BACKENDS[backend] = BACKENDS[DEFAULT_BACKEND]
+ else:
+ raise ValueError(
+ f"Invalid backend: {backend}, expected one of "
+ f"{sorted(BACKENDS.keys())}"
+ )
+
+ backend = BACKENDS[backend](**backend_params)
+
+ if inner_max_num_threads is not None:
+ msg = (
+ f"{backend.__class__.__name__} does not accept setting the "
+ "inner_max_num_threads argument."
+ )
+ assert backend.supports_inner_max_num_threads, msg
+ backend.inner_max_num_threads = inner_max_num_threads
+
+ # If the nesting_level of the backend is not set previously, use the
+ # nesting level from the previous active_backend to set it
+ if backend.nesting_level is None:
+ parent_backend = self.old_parallel_config['backend']
+ if parent_backend is default_parallel_config['backend']:
+ nesting_level = 0
+ else:
+ nesting_level = parent_backend.nesting_level
+ backend.nesting_level = nesting_level
+
+ return backend
def __enter__(self):
return self.parallel_config
@@ -242,6 +452,9 @@ class parallel_config:
def __exit__(self, type, value, traceback):
self.unregister()
+ def unregister(self):
+ setattr(_backend, "config", self.old_parallel_config)
+
class parallel_backend(parallel_config):
"""Change the default backend used by Parallel inside a with block.
@@ -324,23 +537,37 @@ class parallel_backend(parallel_config):
joblib.parallel_config: context manager to change the backend
configuration.
"""
+ def __init__(self, backend, n_jobs=-1, inner_max_num_threads=None,
+ **backend_params):
+
+ super().__init__(
+ backend=backend,
+ n_jobs=n_jobs,
+ inner_max_num_threads=inner_max_num_threads,
+ **backend_params
+ )
- def __init__(self, backend, n_jobs=-1, inner_max_num_threads=None, **
- backend_params):
- super().__init__(backend=backend, n_jobs=n_jobs,
- inner_max_num_threads=inner_max_num_threads, **backend_params)
if self.old_parallel_config is None:
self.old_backend_and_jobs = None
else:
- self.old_backend_and_jobs = self.old_parallel_config['backend'
- ], self.old_parallel_config['n_jobs']
- self.new_backend_and_jobs = self.parallel_config['backend'
- ], self.parallel_config['n_jobs']
+ self.old_backend_and_jobs = (
+ self.old_parallel_config["backend"],
+ self.old_parallel_config["n_jobs"],
+ )
+ self.new_backend_and_jobs = (
+ self.parallel_config["backend"],
+ self.parallel_config["n_jobs"],
+ )
def __enter__(self):
return self.new_backend_and_jobs
+# Under Linux or OS X the default start method of multiprocessing
+# can cause third party libraries to crash. Under Python 3.4+ it is possible
+# to set an environment variable to switch the default start method from
+# 'fork' to 'forkserver' or 'spawn' to avoid this issue albeit at the cost
+# of causing semantic changes and some additional pool instantiation overhead.
DEFAULT_MP_CONTEXT = None
if hasattr(mp, 'get_context'):
method = os.environ.get('JOBLIB_START_METHOD', '').strip() or None
@@ -351,36 +578,49 @@ if hasattr(mp, 'get_context'):
class BatchedCalls(object):
"""Wrap a sequence of (func, args, kwargs) tuples as a single callable"""
- def __init__(self, iterator_slice, backend_and_jobs, reducer_callback=
- None, pickle_cache=None):
+ def __init__(self, iterator_slice, backend_and_jobs, reducer_callback=None,
+ pickle_cache=None):
self.items = list(iterator_slice)
self._size = len(self.items)
self._reducer_callback = reducer_callback
if isinstance(backend_and_jobs, tuple):
self._backend, self._n_jobs = backend_and_jobs
else:
+ # this is for backward compatibility purposes. Before 0.12.6,
+ # nested backends were returned without n_jobs indications.
self._backend, self._n_jobs = backend_and_jobs, None
self._pickle_cache = pickle_cache if pickle_cache is not None else {}
def __call__(self):
+        # Set the default nested backend to self._backend but do not
+        # change the default number of processes to -1
with parallel_config(backend=self._backend, n_jobs=self._n_jobs):
- return [func(*args, **kwargs) for func, args, kwargs in self.items]
+ return [func(*args, **kwargs)
+ for func, args, kwargs in self.items]
def __reduce__(self):
if self._reducer_callback is not None:
self._reducer_callback()
- return BatchedCalls, (self.items, (self._backend, self._n_jobs),
- None, self._pickle_cache)
+ # no need to pickle the callback.
+ return (
+ BatchedCalls,
+ (self.items, (self._backend, self._n_jobs), None,
+ self._pickle_cache)
+ )
def __len__(self):
return self._size
-TASK_DONE = 'Done'
-TASK_ERROR = 'Error'
-TASK_PENDING = 'Pending'
+# Possible exit status for a task
+TASK_DONE = "Done"
+TASK_ERROR = "Error"
+TASK_PENDING = "Pending"
+###############################################################################
+# CPU count that works also when multiprocessing has been disabled via
+# the JOBLIB_MULTIPROCESSING environment variable
def cpu_count(only_physical_cores=False):
"""Return the number of CPUs.
@@ -392,8 +632,14 @@ def cpu_count(only_physical_cores=False):
If only_physical_cores is True, do not take hyperthreading / SMT logical
cores into account.
"""
- pass
+ if mp is None:
+ return 1
+
+ return loky.cpu_count(only_physical_cores=only_physical_cores)
+
+###############################################################################
+# For verbosity
def _verbosity_filter(index, verbose):
""" Returns False for indices increasingly apart, the distance
@@ -401,14 +647,32 @@ def _verbosity_filter(index, verbose):
We use a lag increasing as the square of index
"""
- pass
-
-
+ if not verbose:
+ return True
+ elif verbose > 10:
+ return False
+ if index == 0:
+ return False
+ verbose = .5 * (11 - verbose) ** 2
+ scale = sqrt(index / verbose)
+ next_scale = sqrt((index + 1) / verbose)
+ return (int(next_scale) == int(scale))
+
+
+###############################################################################
def delayed(function):
"""Decorator used to capture the arguments of a function."""
- pass
+
+ def delayed_function(*args, **kwargs):
+ return function, args, kwargs
+ try:
+ delayed_function = functools.wraps(function)(delayed_function)
+ except AttributeError:
+ " functools.wraps fails on some callable objects "
+ return delayed_function
+###############################################################################
class BatchCompletionCallBack(object):
"""Callback to keep track of completed results and schedule the next tasks.
@@ -424,20 +688,35 @@ class BatchCompletionCallBack(object):
failure.
"""
+ ##########################################################################
+ # METHODS CALLED BY THE MAIN THREAD #
+ ##########################################################################
def __init__(self, dispatch_timestamp, batch_size, parallel):
self.dispatch_timestamp = dispatch_timestamp
self.batch_size = batch_size
self.parallel = parallel
self.parallel_call_id = parallel._call_id
+
+ # Internals to keep track of the status and outcome of the task.
+
+ # Used to hold a reference to the future-like object returned by the
+ # backend after launching this task
+ # This will be set later when calling `register_job`, as it is only
+ # created once the task has been submitted.
self.job = None
+
if not parallel._backend.supports_retrieve_callback:
+ # The status is only used for asynchronous result retrieval in the
+ # callback.
self.status = None
else:
+ # The initial status for the job is TASK_PENDING.
+ # Once it is done, it will be either TASK_DONE, or TASK_ERROR.
self.status = TASK_PENDING
def register_job(self, job):
"""Register the object returned by `apply_async`."""
- pass
+ self.job = job
def get_result(self, timeout):
"""Returns the raw result of the task that was submitted.
@@ -456,7 +735,35 @@ class BatchCompletionCallBack(object):
return it or raise. It will block at most `self.timeout` seconds
waiting for retrieval to complete, after that it raises a TimeoutError.
"""
- pass
+
+ backend = self.parallel._backend
+
+ if backend.supports_retrieve_callback:
+ # We assume that the result has already been retrieved by the
+ # callback thread, and is stored internally. It's just waiting to
+ # be returned.
+ return self._return_or_raise()
+
+ # For other backends, the main thread needs to run the retrieval step.
+ try:
+ if backend.supports_timeout:
+ result = self.job.get(timeout=timeout)
+ else:
+ result = self.job.get()
+ outcome = dict(result=result, status=TASK_DONE)
+ except BaseException as e:
+ outcome = dict(result=e, status=TASK_ERROR)
+ self._register_outcome(outcome)
+
+ return self._return_or_raise()
+
+ def _return_or_raise(self):
+ try:
+ if self.status == TASK_ERROR:
+ raise self._result
+ return self._result
+ finally:
+ del self._result
def get_status(self, timeout):
"""Get the status of the task.
@@ -464,27 +771,76 @@ class BatchCompletionCallBack(object):
This function also checks if the timeout has been reached and register
the TimeoutError outcome when it is the case.
"""
- pass
+ if timeout is None or self.status != TASK_PENDING:
+ return self.status
+
+        # The computation is running and the status is pending.
+        # Check that we did not wait for this job more than `timeout`.
+ now = time.time()
+ if not hasattr(self, "_completion_timeout_counter"):
+ self._completion_timeout_counter = now
+
+ if (now - self._completion_timeout_counter) > timeout:
+ outcome = dict(result=TimeoutError(), status=TASK_ERROR)
+ self._register_outcome(outcome)
+
+ return self.status
+ ##########################################################################
+ # METHODS CALLED BY CALLBACK THREADS #
+ ##########################################################################
def __call__(self, out):
"""Function called by the callback thread after a job is completed."""
+
+ # If the backend doesn't support callback retrievals, the next batch of
+ # tasks is dispatched regardless. The result will be retrieved by the
+ # main thread when calling `get_result`.
if not self.parallel._backend.supports_retrieve_callback:
self._dispatch_new()
return
+
+ # If the backend supports retrieving the result in the callback, it
+ # registers the task outcome (TASK_ERROR or TASK_DONE), and schedules
+ # the next batch if needed.
with self.parallel._lock:
+ # Edge case where while the task was processing, the `parallel`
+ # instance has been reset and a new call has been issued, but the
+ # worker managed to complete the task and trigger this callback
+ # call just before being aborted by the reset.
if self.parallel._call_id != self.parallel_call_id:
return
+
+ # When aborting, stop as fast as possible and do not retrieve the
+ # result as it won't be returned by the Parallel call.
if self.parallel._aborting:
return
+
+ # Retrieves the result of the task in the main process and dispatch
+ # a new batch if needed.
job_succeeded = self._retrieve_result(out)
+
if not self.parallel.return_ordered:
+ # Append the job to the queue in the order of completion
+ # instead of submission.
self.parallel._jobs.append(self)
+
if job_succeeded:
self._dispatch_new()
def _dispatch_new(self):
"""Schedule the next batch of tasks to be processed."""
- pass
+
+        # This step ensures that auto-batching works as expected.
+ this_batch_duration = time.time() - self.dispatch_timestamp
+ self.parallel._backend.batch_completed(self.batch_size,
+ this_batch_duration)
+
+ # Schedule the next batch of tasks.
+ with self.parallel._lock:
+ self.parallel.n_completed_tasks += self.batch_size
+ self.parallel.print_progress()
+ if self.parallel._original_iterator is not None:
+ self.parallel.dispatch_next()
def _retrieve_result(self, out):
"""Fetch and register the outcome of a task.
@@ -493,16 +849,48 @@ class BatchCompletionCallBack(object):
This function is only called by backends that support retrieving
the task result in the callback thread.
"""
- pass
-
+ try:
+ result = self.parallel._backend.retrieve_result_callback(out)
+ outcome = dict(status=TASK_DONE, result=result)
+ except BaseException as e:
+ # Avoid keeping references to parallel in the error.
+ e.__traceback__ = None
+ outcome = dict(result=e, status=TASK_ERROR)
+
+ self._register_outcome(outcome)
+ return outcome['status'] != TASK_ERROR
+
+ ##########################################################################
+ # This method can be called either in the main thread #
+ # or in the callback thread. #
+ ##########################################################################
def _register_outcome(self, outcome):
"""Register the outcome of a task.
This method can be called only once, future calls will be ignored.
"""
- pass
+ # Covers the edge case where the main thread tries to register a
+ # `TimeoutError` while the callback thread tries to register a result
+ # at the same time.
+ with self.parallel._lock:
+ if self.status not in (TASK_PENDING, None):
+ return
+ self.status = outcome["status"]
+
+ self._result = outcome["result"]
+
+ # Once the result and the status are extracted, the last reference to
+ # the job can be deleted.
+ self.job = None
+
+        # As soon as an error has been spotted, early stopping flags are sent to
+ # the `parallel` instance.
+ if self.status == TASK_ERROR:
+ self.parallel._exception = True
+ self.parallel._aborting = True
+###############################################################################
def register_parallel_backend(name, factory, make_default=False):
"""Register a new Parallel backend factory.
@@ -518,7 +906,10 @@ def register_parallel_backend(name, factory, make_default=False):
.. versionadded:: 0.10
"""
- pass
+ BACKENDS[name] = factory
+ if make_default:
+ global DEFAULT_BACKEND
+ DEFAULT_BACKEND = name
def effective_n_jobs(n_jobs=-1):
@@ -542,11 +933,18 @@ def effective_n_jobs(n_jobs=-1):
.. versionadded:: 0.10
"""
- pass
+ if n_jobs == 1:
+ return 1
+ backend, backend_n_jobs = get_active_backend()
+ if n_jobs is None:
+ n_jobs = backend_n_jobs
+ return backend.effective_n_jobs(n_jobs=n_jobs)
+
+###############################################################################
class Parallel(Logger):
- """ Helper class for readable parallel mapping.
+ ''' Helper class for readable parallel mapping.
Read more in the :ref:`User Guide <parallel>`.
@@ -795,95 +1193,148 @@ class Parallel(Logger):
[Parallel(n_jobs=2)]: Done 6 out of 6 | elapsed: 0.0s remaining: 0.0s
[Parallel(n_jobs=2)]: Done 6 out of 6 | elapsed: 0.0s finished
- """
-
- def __init__(self, n_jobs=default_parallel_config['n_jobs'], backend=
- default_parallel_config['backend'], return_as='list', verbose=
- default_parallel_config['verbose'], timeout=None, pre_dispatch=
- '2 * n_jobs', batch_size='auto', temp_folder=
- default_parallel_config['temp_folder'], max_nbytes=
- default_parallel_config['max_nbytes'], mmap_mode=
- default_parallel_config['mmap_mode'], prefer=
- default_parallel_config['prefer'], require=default_parallel_config[
- 'require']):
+ ''' # noqa: E501
+ def __init__(
+ self,
+ n_jobs=default_parallel_config["n_jobs"],
+ backend=default_parallel_config['backend'],
+ return_as="list",
+ verbose=default_parallel_config["verbose"],
+ timeout=None,
+ pre_dispatch='2 * n_jobs',
+ batch_size='auto',
+ temp_folder=default_parallel_config["temp_folder"],
+ max_nbytes=default_parallel_config["max_nbytes"],
+ mmap_mode=default_parallel_config["mmap_mode"],
+ prefer=default_parallel_config["prefer"],
+ require=default_parallel_config["require"],
+ ):
+ # Initiate parent Logger class state
super().__init__()
+
+ # Interpret n_jobs=None as 'unset'
if n_jobs is None:
- n_jobs = default_parallel_config['n_jobs']
- active_backend, context_config = _get_active_backend(prefer=prefer,
- require=require, verbose=verbose)
+ n_jobs = default_parallel_config["n_jobs"]
+
+ active_backend, context_config = _get_active_backend(
+ prefer=prefer, require=require, verbose=verbose
+ )
+
nesting_level = active_backend.nesting_level
- self.verbose = _get_config_param(verbose, context_config, 'verbose')
+
+ self.verbose = _get_config_param(verbose, context_config, "verbose")
self.timeout = timeout
self.pre_dispatch = pre_dispatch
- if return_as not in {'list', 'generator', 'generator_unordered'}:
+
+ if return_as not in {"list", "generator", "generator_unordered"}:
raise ValueError(
- f'Expected `return_as` parameter to be a string equal to "list","generator" or "generator_unordered", but got {return_as} instead.'
- )
+ 'Expected `return_as` parameter to be a string equal to "list"'
+ f',"generator" or "generator_unordered", but got {return_as} '
+ "instead."
+ )
self.return_as = return_as
- self.return_generator = return_as != 'list'
- self.return_ordered = return_as != 'generator_unordered'
- self._backend_args = {k: _get_config_param(param, context_config, k
- ) for param, k in [(max_nbytes, 'max_nbytes'), (temp_folder,
- 'temp_folder'), (mmap_mode, 'mmap_mode'), (prefer, 'prefer'), (
- require, 'require'), (verbose, 'verbose')]}
- if isinstance(self._backend_args['max_nbytes'], str):
- self._backend_args['max_nbytes'] = memstr_to_bytes(self.
- _backend_args['max_nbytes'])
- self._backend_args['verbose'] = max(0, self._backend_args['verbose'
- ] - 50)
+ self.return_generator = return_as != "list"
+ self.return_ordered = return_as != "generator_unordered"
+
+ # Check if we are under a parallel_config or parallel_backend
+ # context manager and use the config from the context manager
+ # for arguments that are not explicitly set.
+ self._backend_args = {
+ k: _get_config_param(param, context_config, k) for param, k in [
+ (max_nbytes, "max_nbytes"),
+ (temp_folder, "temp_folder"),
+ (mmap_mode, "mmap_mode"),
+ (prefer, "prefer"),
+ (require, "require"),
+ (verbose, "verbose"),
+ ]
+ }
+
+ if isinstance(self._backend_args["max_nbytes"], str):
+ self._backend_args["max_nbytes"] = memstr_to_bytes(
+ self._backend_args["max_nbytes"]
+ )
+ self._backend_args["verbose"] = max(
+ 0, self._backend_args["verbose"] - 50
+ )
+
if DEFAULT_MP_CONTEXT is not None:
self._backend_args['context'] = DEFAULT_MP_CONTEXT
- elif hasattr(mp, 'get_context'):
+ elif hasattr(mp, "get_context"):
self._backend_args['context'] = mp.get_context()
+
if backend is default_parallel_config['backend'] or backend is None:
backend = active_backend
+
elif isinstance(backend, ParallelBackendBase):
+ # Use provided backend as is, with the current nesting_level if it
+ # is not set yet.
if backend.nesting_level is None:
backend.nesting_level = nesting_level
+
elif hasattr(backend, 'Pool') and hasattr(backend, 'Lock'):
+ # Make it possible to pass a custom multiprocessing context as
+ # backend to change the start method to forkserver or spawn or
+ # preload modules on the forkserver helper process.
self._backend_args['context'] = backend
backend = MultiprocessingBackend(nesting_level=nesting_level)
+
elif backend not in BACKENDS and backend in MAYBE_AVAILABLE_BACKENDS:
warnings.warn(
- f"joblib backend '{backend}' is not available on your system, falling back to {DEFAULT_BACKEND}."
- , UserWarning, stacklevel=2)
+ f"joblib backend '{backend}' is not available on "
+ f"your system, falling back to {DEFAULT_BACKEND}.",
+ UserWarning,
+ stacklevel=2)
BACKENDS[backend] = BACKENDS[DEFAULT_BACKEND]
backend = BACKENDS[DEFAULT_BACKEND](nesting_level=nesting_level)
+
else:
try:
backend_factory = BACKENDS[backend]
except KeyError as e:
- raise ValueError('Invalid backend: %s, expected one of %r' %
- (backend, sorted(BACKENDS.keys()))) from e
+ raise ValueError("Invalid backend: %s, expected one of %r"
+ % (backend, sorted(BACKENDS.keys()))) from e
backend = backend_factory(nesting_level=nesting_level)
- n_jobs = _get_config_param(n_jobs, context_config, 'n_jobs')
+
+ n_jobs = _get_config_param(n_jobs, context_config, "n_jobs")
if n_jobs is None:
+ # No specific context override and no specific value request:
+ # default to the default of the backend.
n_jobs = backend.default_n_jobs
try:
n_jobs = int(n_jobs)
except ValueError:
- raise ValueError('n_jobs could not be converted to int')
+ raise ValueError("n_jobs could not be converted to int")
self.n_jobs = n_jobs
- if require == 'sharedmem' and not getattr(backend,
- 'supports_sharedmem', False):
- raise ValueError('Backend %s does not support shared memory' %
- backend)
- if batch_size == 'auto' or isinstance(batch_size, Integral
- ) and batch_size > 0:
+
+ if (require == 'sharedmem' and
+ not getattr(backend, 'supports_sharedmem', False)):
+ raise ValueError("Backend %s does not support shared memory"
+ % backend)
+
+ if (batch_size == 'auto' or isinstance(batch_size, Integral) and
+ batch_size > 0):
self.batch_size = batch_size
else:
raise ValueError(
- "batch_size must be 'auto' or a positive integer, got: %r" %
- batch_size)
+ "batch_size must be 'auto' or a positive integer, got: %r"
+ % batch_size)
+
if not isinstance(backend, SequentialBackend):
if self.return_generator and not backend.supports_return_generator:
- raise ValueError('Backend {} does not support return_as={}'
- .format(backend, return_as))
+ raise ValueError(
+ "Backend {} does not support "
+ "return_as={}".format(backend, return_as)
+ )
+ # This lock is used to coordinate the main thread of this process
+            # with the async callback thread of our pool.
self._lock = threading.RLock()
self._jobs = collections.deque()
self._pending_outputs = list()
self._ready_batches = queue.Queue()
self._reducer_callback = None
+
+ # Internal variables
self._backend = backend
self._running = False
self._managed_backend = False
@@ -904,7 +1355,35 @@ class Parallel(Logger):
def _initialize_backend(self):
"""Build a process or thread pool and return the number of workers"""
- pass
+ try:
+ n_jobs = self._backend.configure(n_jobs=self.n_jobs, parallel=self,
+ **self._backend_args)
+ if self.timeout is not None and not self._backend.supports_timeout:
+ warnings.warn(
+ 'The backend class {!r} does not support timeout. '
+ "You have set 'timeout={}' in Parallel but "
+ "the 'timeout' parameter will not be used.".format(
+ self._backend.__class__.__name__,
+ self.timeout))
+
+ except FallbackToBackend as e:
+ # Recursively initialize the backend in case of requested fallback.
+ self._backend = e.backend
+ n_jobs = self._initialize_backend()
+
+ return n_jobs
+
+ def _effective_n_jobs(self):
+ if self._backend:
+ return self._backend.effective_n_jobs(self.n_jobs)
+ return 1
+
+ def _terminate_and_reset(self):
+ if hasattr(self._backend, 'stop_call') and self._calling:
+ self._backend.stop_call()
+ self._calling = False
+ if not self._managed_backend:
+ self._backend.terminate()
def _dispatch(self, batch):
"""Queue the batch for computing, with or without multiprocessing
@@ -913,7 +1392,31 @@ class Parallel(Logger):
indirectly via dispatch_one_batch.
"""
- pass
+ # If job.get() catches an exception, it closes the queue:
+ if self._aborting:
+ return
+
+ batch_size = len(batch)
+
+ self.n_dispatched_tasks += batch_size
+ self.n_dispatched_batches += 1
+
+ dispatch_timestamp = time.time()
+
+ batch_tracker = BatchCompletionCallBack(
+ dispatch_timestamp, batch_size, self
+ )
+
+ if self.return_ordered:
+ self._jobs.append(batch_tracker)
+
+ # If return_ordered is False, the batch_tracker is not stored in the
+ # jobs queue at the time of submission. Instead, it will be appended to
+ # the queue by itself as soon as the callback is triggered to be able
+ # to return the results in the order of completion.
+
+ job = self._backend.apply_async(batch, callback=batch_tracker)
+ batch_tracker.register_job(job)
def dispatch_next(self):
"""Dispatch more data for parallel processing
@@ -923,7 +1426,9 @@ class Parallel(Logger):
against concurrent consumption of the unprotected iterator.
"""
- pass
+ if not self.dispatch_one_batch(self._original_iterator):
+ self._iterating = False
+ self._original_iterator = None
def dispatch_one_batch(self, iterator):
"""Prefetch the tasks for the next batch and dispatch them.
@@ -935,41 +1440,381 @@ class Parallel(Logger):
lock so calling this function should be thread safe.
"""
- pass
+
+ if self._aborting:
+ return False
+
+ batch_size = self._get_batch_size()
+
+ with self._lock:
+ # to ensure an even distribution of the workload between workers,
+ # we look ahead in the original iterators more than batch_size
+ # tasks - However, we keep consuming only one batch at each
+ # dispatch_one_batch call. The extra tasks are stored in a local
+ # queue, _ready_batches, that is looked-up prior to re-consuming
+            # tasks from the original iterator.
+ try:
+ tasks = self._ready_batches.get(block=False)
+ except queue.Empty:
+ # slice the iterator n_jobs * batchsize items at a time. If the
+ # slice returns less than that, then the current batchsize puts
+ # too much weight on a subset of workers, while other may end
+ # up starving. So in this case, re-scale the batch size
+ # accordingly to distribute evenly the last items between all
+ # workers.
+ n_jobs = self._cached_effective_n_jobs
+ big_batch_size = batch_size * n_jobs
+
+ try:
+ islice = list(itertools.islice(iterator, big_batch_size))
+ except Exception as e:
+ # Handle the fact that the generator of task raised an
+ # exception. As this part of the code can be executed in
+ # a thread internal to the backend, register a task with
+ # an error that will be raised in the user's thread.
+ if isinstance(e.__context__, queue.Empty):
+ # Suppress the cause of the exception if it is
+ # queue.Empty to avoid cluttered traceback. Only do it
+ # if the __context__ is really empty to avoid messing
+ # with causes of the original error.
+ e.__cause__ = None
+ batch_tracker = BatchCompletionCallBack(
+ 0, batch_size, self
+ )
+ self._jobs.append(batch_tracker)
+ batch_tracker._register_outcome(dict(
+ result=e, status=TASK_ERROR
+ ))
+ return True
+
+ if len(islice) == 0:
+ return False
+ elif (iterator is self._original_iterator and
+ len(islice) < big_batch_size):
+ # We reached the end of the original iterator (unless
+ # iterator is the ``pre_dispatch``-long initial slice of
+ # the original iterator) -- decrease the batch size to
+ # account for potential variance in the batches running
+ # time.
+ final_batch_size = max(1, len(islice) // (10 * n_jobs))
+ else:
+ final_batch_size = max(1, len(islice) // n_jobs)
+
+ # enqueue n_jobs batches in a local queue
+ for i in range(0, len(islice), final_batch_size):
+ tasks = BatchedCalls(islice[i:i + final_batch_size],
+ self._backend.get_nested_backend(),
+ self._reducer_callback,
+ self._pickle_cache)
+ self._ready_batches.put(tasks)
+
+ # finally, get one task.
+ tasks = self._ready_batches.get(block=False)
+ if len(tasks) == 0:
+ # No more tasks available in the iterator: tell caller to stop.
+ return False
+ else:
+ self._dispatch(tasks)
+ return True
def _get_batch_size(self):
"""Returns the effective batch size for dispatch"""
- pass
+ if self.batch_size == 'auto':
+ return self._backend.compute_batch_size()
+ else:
+ # Fixed batch size strategy
+ return self.batch_size
def _print(self, msg):
"""Display the message on stout or stderr depending on verbosity"""
- pass
+ # XXX: Not using the logger framework: need to
+ # learn to use logger better.
+ if not self.verbose:
+ return
+ if self.verbose < 50:
+ writer = sys.stderr.write
+ else:
+ writer = sys.stdout.write
+ writer(f"[{self}]: {msg}\n")
def _is_completed(self):
"""Check if all tasks have been completed"""
- pass
+ return self.n_completed_tasks == self.n_dispatched_tasks and not (
+ self._iterating or self._aborting
+ )
def print_progress(self):
"""Display the process of the parallel execution only a fraction
of time, controlled by self.verbose.
"""
- pass
+
+ if not self.verbose:
+ return
+
+ elapsed_time = time.time() - self._start_time
+
+ if self._is_completed():
+ # Make sure that we get a last message telling us we are done
+ self._print(
+ f"Done {self.n_completed_tasks:3d} out of "
+ f"{self.n_completed_tasks:3d} | elapsed: "
+ f"{short_format_time(elapsed_time)} finished"
+ )
+ return
+
+ # Original job iterator becomes None once it has been fully
+ # consumed: at this point we know the total number of jobs and we are
+ # able to display an estimation of the remaining time based on already
+ # completed jobs. Otherwise, we simply display the number of completed
+ # tasks.
+ elif self._original_iterator is not None:
+ if _verbosity_filter(self.n_dispatched_batches, self.verbose):
+ return
+ self._print(
+ f"Done {self.n_completed_tasks:3d} tasks | elapsed: "
+ f"{short_format_time(elapsed_time)}"
+ )
+ else:
+ index = self.n_completed_tasks
+ # We are finished dispatching
+ total_tasks = self.n_dispatched_tasks
+ # We always display the first loop
+ if not index == 0:
+ # Display depending on the number of remaining items
+ # A message as soon as we finish dispatching, cursor is 0
+ cursor = (total_tasks - index + 1 -
+ self._pre_dispatch_amount)
+ frequency = (total_tasks // self.verbose) + 1
+ is_last_item = (index + 1 == total_tasks)
+ if (is_last_item or cursor % frequency):
+ return
+ remaining_time = (elapsed_time / index) * \
+ (self.n_dispatched_tasks - index * 1.0)
+ # only display status if remaining time is greater or equal to 0
+ self._print(
+ f"Done {index:3d} out of {total_tasks:3d} | elapsed: "
+ f"{short_format_time(elapsed_time)} remaining: "
+ f"{short_format_time(remaining_time)}"
+ )
+
+ def _abort(self):
+ # Stop dispatching new jobs in the async callback thread
+ self._aborting = True
+
+ # If the backend allows it, cancel or kill remaining running
+ # tasks without waiting for the results as we will raise
+ # the exception we got back to the caller instead of returning
+ # any result.
+ backend = self._backend
+ if (not self._aborted and hasattr(backend, 'abort_everything')):
+ # If the backend is managed externally we need to make sure
+ # to leave it in a working state to allow for future jobs
+ # scheduling.
+ ensure_ready = self._managed_backend
+ backend.abort_everything(ensure_ready=ensure_ready)
+ self._aborted = True
+
+ def _start(self, iterator, pre_dispatch):
+ # Only set self._iterating to True if at least a batch
+ # was dispatched. In particular this covers the edge
+ # case of Parallel used with an exhausted iterator. If
+ # self._original_iterator is None, then this means either
+ # that pre_dispatch == "all", n_jobs == 1 or that the first batch
+ # was very quick and its callback already dispatched all the
+ # remaining jobs.
+ self._iterating = False
+ if self.dispatch_one_batch(iterator):
+ self._iterating = self._original_iterator is not None
+
+ while self.dispatch_one_batch(iterator):
+ pass
+
+ if pre_dispatch == "all":
+ # The iterable was consumed all at once by the above for loop.
+ # No need to wait for async callbacks to trigger to
+ # consumption.
+ self._iterating = False
def _get_outputs(self, iterator, pre_dispatch):
"""Iterator returning the tasks' output as soon as they are ready."""
- pass
+ dispatch_thread_id = threading.get_ident()
+ detach_generator_exit = False
+ try:
+ self._start(iterator, pre_dispatch)
+ # first yield returns None, for internal use only. This ensures
+ # that we enter the try/except block and start dispatching the
+ # tasks.
+ yield
+
+ with self._backend.retrieval_context():
+ yield from self._retrieve()
+
+ except GeneratorExit:
+ # The generator has been garbage collected before being fully
+ # consumed. This aborts the remaining tasks if possible and warn
+ # the user if necessary.
+ self._exception = True
+
+ # In some interpreters such as PyPy, GeneratorExit can be raised in
+ # a different thread than the one used to start the dispatch of the
+ # parallel tasks. This can lead to hang when a thread attempts to
+ # join itself. As workaround, we detach the execution of the
+ # aborting code to a dedicated thread. We then need to make sure
+ # the rest of the function does not call `_terminate_and_reset`
+ # in finally.
+ if dispatch_thread_id != threading.get_ident():
+ if not IS_PYPY:
+ warnings.warn(
+ "A generator produced by joblib.Parallel has been "
+ "gc'ed in an unexpected thread. This behavior should "
+ "not cause major -issues but to make sure, please "
+ "report this warning and your use case at "
+ "https://github.com/joblib/joblib/issues so it can "
+ "be investigated."
+ )
+
+ detach_generator_exit = True
+ _parallel = self
+
+ class _GeneratorExitThread(threading.Thread):
+ def run(self):
+ _parallel._abort()
+ if _parallel.return_generator:
+ _parallel._warn_exit_early()
+ _parallel._terminate_and_reset()
+
+ _GeneratorExitThread(
+ name="GeneratorExitThread"
+ ).start()
+ return
+
+ # Otherwise, we are in the thread that started the dispatch: we can
+ # safely abort the execution and warn the user.
+ self._abort()
+ if self.return_generator:
+ self._warn_exit_early()
+
+ raise
+
+ # Note: we catch any BaseException instead of just Exception instances
+ # to also include KeyboardInterrupt
+ except BaseException:
+ self._exception = True
+ self._abort()
+ raise
+ finally:
+ # Store the unconsumed tasks and terminate the workers if necessary
+ _remaining_outputs = ([] if self._exception else self._jobs)
+ self._jobs = collections.deque()
+ self._running = False
+ if not detach_generator_exit:
+ self._terminate_and_reset()
+
+ while len(_remaining_outputs) > 0:
+ batched_results = _remaining_outputs.popleft()
+ batched_results = batched_results.get_result(self.timeout)
+ for result in batched_results:
+ yield result
def _wait_retrieval(self):
"""Return True if we need to continue retrieving some tasks."""
- pass
+
+ # If the input load is still being iterated over, it means that tasks
+ # are still on the dispatch waitlist and their results will need to
+ # be retrieved later on.
+ if self._iterating:
+ return True
+
+ # If some of the dispatched tasks are still being processed by the
+ # workers, wait for the compute to finish before starting retrieval
+ if self.n_completed_tasks < self.n_dispatched_tasks:
+ return True
+
+        # For backends that do not support retrieving asynchronously the
+ # result to the main process, all results must be carefully retrieved
+ # in the _retrieve loop in the main thread while the backend is alive.
+ # For other backends, the actual retrieval is done asynchronously in
+ # the callback thread, and we can terminate the backend before the
+ # `self._jobs` result list has been emptied. The remaining results
+ # will be collected in the `finally` step of the generator.
+ if not self._backend.supports_retrieve_callback:
+ if len(self._jobs) > 0:
+ return True
+
+ return False
+
+ def _retrieve(self):
+ while self._wait_retrieval():
+
+ # If the callback thread of a worker has signaled that its task
+ # triggered an exception, or if the retrieval loop has raised an
+ # exception (e.g. `GeneratorExit`), exit the loop and surface the
+ # worker traceback.
+ if self._aborting:
+ self._raise_error_fast()
+ break
+
+ # If the next job is not ready for retrieval yet, we just wait for
+ # async callbacks to progress.
+ if ((len(self._jobs) == 0) or
+ (self._jobs[0].get_status(
+ timeout=self.timeout) == TASK_PENDING)):
+ time.sleep(0.01)
+ continue
+
+ # We need to be careful: the job list can be filling up as
+ # we empty it and Python list are not thread-safe by
+ # default hence the use of the lock
+ with self._lock:
+ batched_results = self._jobs.popleft()
+
+ # Flatten the batched results to output one output at a time
+ batched_results = batched_results.get_result(self.timeout)
+ for result in batched_results:
+ self._nb_consumed += 1
+ yield result
def _raise_error_fast(self):
"""If we are aborting, raise if a job caused an error."""
- pass
+
+ # Find the first job whose status is TASK_ERROR if it exists.
+ with self._lock:
+ error_job = next((job for job in self._jobs
+ if job.status == TASK_ERROR), None)
+
+ # If this error job exists, immediately raise the error by
+        # calling get_result. This job might not exist if abort has been
+ # called directly or if the generator is gc'ed.
+ if error_job is not None:
+ error_job.get_result(self.timeout)
def _warn_exit_early(self):
"""Warn the user if the generator is gc'ed before being consumned."""
- pass
+ ready_outputs = self.n_completed_tasks - self._nb_consumed
+ is_completed = self._is_completed()
+ msg = ""
+ if ready_outputs:
+ msg += (
+ f"{ready_outputs} tasks have been successfully executed "
+ " but not used."
+ )
+ if not is_completed:
+ msg += " Additionally, "
+
+ if not is_completed:
+ msg += (
+ f"{self.n_dispatched_tasks - self.n_completed_tasks} tasks "
+ "which were still being processed by the workers have been "
+ "cancelled."
+ )
+
+ if msg:
+ msg += (
+ " You could benefit from adjusting the input task "
+ "iterator to limit unnecessary computation time."
+ )
+
+ warnings.warn(msg)
def _get_sequential_output(self, iterable):
"""Separate loop for sequential output.
@@ -977,58 +1822,188 @@ class Parallel(Logger):
This simplifies the traceback in case of errors and reduces the
overhead of calling sequential tasks with `joblib`.
"""
- pass
+ try:
+ self._iterating = True
+ self._original_iterator = iterable
+ batch_size = self._get_batch_size()
+
+ if batch_size != 1:
+ it = iter(iterable)
+ iterable_batched = iter(
+ lambda: tuple(itertools.islice(it, batch_size)), ()
+ )
+ iterable = (
+ task for batch in iterable_batched for task in batch
+ )
+
+ # first yield returns None, for internal use only. This ensures
+ # that we enter the try/except block and setup the generator.
+ yield None
+
+ # Sequentially call the tasks and yield the results.
+ for func, args, kwargs in iterable:
+ self.n_dispatched_batches += 1
+ self.n_dispatched_tasks += 1
+ res = func(*args, **kwargs)
+ self.n_completed_tasks += 1
+ self.print_progress()
+ yield res
+ self._nb_consumed += 1
+ except BaseException:
+ self._exception = True
+ self._aborting = True
+ self._aborted = True
+ raise
+ finally:
+ self.print_progress()
+ self._running = False
+ self._iterating = False
+ self._original_iterator = None
def _reset_run_tracking(self):
"""Reset the counters and flags used to track the execution."""
- pass
+
+        # Make sure the parallel instance was not previously running in a
+ # thread-safe way.
+ with getattr(self, '_lock', nullcontext()):
+ if self._running:
+ msg = 'This Parallel instance is already running !'
+ if self.return_generator is True:
+ msg += (
+ " Before submitting new tasks, you must wait for the "
+ "completion of all the previous tasks, or clean all "
+ "references to the output generator."
+ )
+ raise RuntimeError(msg)
+ self._running = True
+
+ # Counter to keep track of the task dispatched and completed.
+ self.n_dispatched_batches = 0
+ self.n_dispatched_tasks = 0
+ self.n_completed_tasks = 0
+
+ # Following count is incremented by one each time the user iterates
+ # on the output generator, it is used to prepare an informative
+ # warning message in case the generator is deleted before all the
+ # dispatched tasks have been consumed.
+ self._nb_consumed = 0
+
+ # Following flags are used to synchronize the threads in case one of
+ # the tasks error-out to ensure that all workers abort fast and that
+ # the backend terminates properly.
+
+ # Set to True as soon as a worker signals that a task errors-out
+ self._exception = False
+ # Set to True in case of early termination following an incident
+ self._aborting = False
+ # Set to True after abortion is complete
+ self._aborted = False
def __call__(self, iterable):
"""Main function to dispatch parallel tasks."""
+
self._reset_run_tracking()
self._start_time = time.time()
+
if not self._managed_backend:
n_jobs = self._initialize_backend()
else:
n_jobs = self._effective_n_jobs()
+
if n_jobs == 1:
+ # If n_jobs==1, run the computation sequentially and return
+ # immediately to avoid overheads.
output = self._get_sequential_output(iterable)
next(output)
return output if self.return_generator else list(output)
+
+ # Let's create an ID that uniquely identifies the current call. If the
+ # call is interrupted early and that the same instance is immediately
+ # re-used, this id will be used to prevent workers that were
+ # concurrently finalizing a task from the previous call to run the
+ # callback.
with self._lock:
self._call_id = uuid4().hex
+
+ # self._effective_n_jobs should be called in the Parallel.__call__
+ # thread only -- store its value in an attribute for further queries.
self._cached_effective_n_jobs = n_jobs
+
if isinstance(self._backend, LokyBackend):
+ # For the loky backend, we add a callback executed when reducing
+ # BatchCalls, that makes the loky executor use a temporary folder
+ # specific to this Parallel object when pickling temporary memmaps.
+ # This callback is necessary to ensure that several Parallel
+ # objects using the same reusable executor don't use the same
+ # temporary resources.
def _batched_calls_reducer_callback():
- self._backend._workers._temp_folder_manager.set_current_context(
- self._id)
+ # Relevant implementation detail: the following lines, called
+ # when reducing BatchedCalls, are called in a thread-safe
+ # situation, meaning that the context of the temporary folder
+ # manager will not be changed in between the callback execution
+ # and the end of the BatchedCalls pickling. The reason is that
+ # pickling (the only place where set_current_context is used)
+ # is done from a single thread (the queue_feeder_thread).
+ self._backend._workers._temp_folder_manager.set_current_context( # noqa
+ self._id
+ )
self._reducer_callback = _batched_calls_reducer_callback
+
+ # self._effective_n_jobs should be called in the Parallel.__call__
+ # thread only -- store its value in an attribute for further queries.
self._cached_effective_n_jobs = n_jobs
+
backend_name = self._backend.__class__.__name__
if n_jobs == 0:
- raise RuntimeError('%s has no active worker.' % backend_name)
+ raise RuntimeError("%s has no active worker." % backend_name)
+
self._print(
- f'Using backend {backend_name} with {n_jobs} concurrent workers.')
+ f"Using backend {backend_name} with {n_jobs} concurrent workers."
+ )
if hasattr(self._backend, 'start_call'):
self._backend.start_call()
+
+ # Following flag prevents double calls to `backend.stop_call`.
self._calling = True
+
iterator = iter(iterable)
pre_dispatch = self.pre_dispatch
+
if pre_dispatch == 'all':
+ # prevent further dispatch via multiprocessing callback thread
self._original_iterator = None
self._pre_dispatch_amount = 0
else:
self._original_iterator = iterator
if hasattr(pre_dispatch, 'endswith'):
- pre_dispatch = eval_expr(pre_dispatch.replace('n_jobs', str
- (n_jobs)))
+ pre_dispatch = eval_expr(
+ pre_dispatch.replace("n_jobs", str(n_jobs))
+ )
self._pre_dispatch_amount = pre_dispatch = int(pre_dispatch)
+
+ # The main thread will consume the first pre_dispatch items and
+ # the remaining items will later be lazily dispatched by async
+ # callbacks upon task completions.
+
+ # TODO: this iterator should be batch_size * n_jobs
iterator = itertools.islice(iterator, self._pre_dispatch_amount)
+
+ # Use a caching dict for callables that are pickled with cloudpickle to
+ # improve performances. This cache is used only in the case of
+ # functions that are defined in the __main__ module, functions that
+ # are defined locally (inside another function) and lambda expressions.
self._pickle_cache = dict()
+
output = self._get_outputs(iterator, pre_dispatch)
self._call_ref = weakref.ref(output)
+
+ # The first item from the output is blank, but it makes the interpreter
+ # progress until it enters the Try/Except block of the generator and
+ # reaches the first `yield` statement. This starts the asynchronous
+ # dispatch of the tasks to the workers.
next(output)
+
return output if self.return_generator else list(output)
def __repr__(self):
diff --git a/joblib/pool.py b/joblib/pool.py
index a5c2643..c0c3549 100644
--- a/joblib/pool.py
+++ b/joblib/pool.py
@@ -9,27 +9,42 @@ available as it implements subclasses of multiprocessing Pool
that uses a custom alternative to SimpleQueue.
"""
+# Author: Olivier Grisel <olivier.grisel@ensta.org>
+# Copyright: 2012, Olivier Grisel
+# License: BSD 3 clause
+
import copyreg
import sys
import warnings
from time import sleep
+
try:
WindowsError
except NameError:
WindowsError = type(None)
+
from pickle import Pickler
+
from pickle import HIGHEST_PROTOCOL
from io import BytesIO
+
from ._memmapping_reducer import get_memmapping_reducers
from ._memmapping_reducer import TemporaryResourcesManager
from ._multiprocessing_helpers import mp, assert_spawning
+
+# We need the class definition to derive from it, not the multiprocessing.Pool
+# factory function
from multiprocessing.pool import Pool
+
try:
import numpy as np
except ImportError:
np = None
+###############################################################################
+# Enable custom pickling in Pool queues
+
class CustomizablePickler(Pickler):
"""Pickler that accepts custom reducers.
@@ -48,20 +63,38 @@ class CustomizablePickler(Pickler):
"""
+    # We override the pure Python pickler as it's the only way to be able to
+ # customize the dispatch table without side effects in Python 2.7
+ # to 3.2. For Python 3.3+ leverage the new dispatch_table
+ # feature from https://bugs.python.org/issue14166 that makes it possible
+ # to use the C implementation of the Pickler which is faster.
+
def __init__(self, writer, reducers=None, protocol=HIGHEST_PROTOCOL):
Pickler.__init__(self, writer, protocol=protocol)
if reducers is None:
reducers = {}
if hasattr(Pickler, 'dispatch'):
+ # Make the dispatch registry an instance level attribute instead of
+ # a reference to the class dictionary under Python 2
self.dispatch = Pickler.dispatch.copy()
else:
+ # Under Python 3 initialize the dispatch table with a copy of the
+ # default registry
self.dispatch_table = copyreg.dispatch_table.copy()
for type, reduce_func in reducers.items():
self.register(type, reduce_func)
def register(self, type, reduce_func):
"""Attach a reducer function to a given type in the dispatch table."""
- pass
+ if hasattr(Pickler, 'dispatch'):
+ # Python 2 pickler dispatching is not explicitly customizable.
+ # Let us use a closure to workaround this limitation.
+ def dispatcher(self, obj):
+ reduced = reduce_func(obj)
+ self.save_reduce(obj=obj, *reduced)
+ self.dispatch[type] = dispatcher
+ else:
+ self.dispatch_table[type] = reduce_func
class CustomizablePicklingQueue(object):
@@ -93,14 +126,54 @@ class CustomizablePicklingQueue(object):
def __getstate__(self):
assert_spawning(self)
- return (self._reader, self._writer, self._rlock, self._wlock, self.
- _reducers)
+ return (self._reader, self._writer, self._rlock, self._wlock,
+ self._reducers)
def __setstate__(self, state):
- (self._reader, self._writer, self._rlock, self._wlock, self._reducers
- ) = state
+ (self._reader, self._writer, self._rlock, self._wlock,
+ self._reducers) = state
self._make_methods()
+ def empty(self):
+ return not self._reader.poll()
+
+ def _make_methods(self):
+ self._recv = recv = self._reader.recv
+ racquire, rrelease = self._rlock.acquire, self._rlock.release
+
+ def get():
+ racquire()
+ try:
+ return recv()
+ finally:
+ rrelease()
+
+ self.get = get
+
+ if self._reducers:
+ def send(obj):
+ buffer = BytesIO()
+ CustomizablePickler(buffer, self._reducers).dump(obj)
+ self._writer.send_bytes(buffer.getvalue())
+ self._send = send
+ else:
+ self._send = send = self._writer.send
+ if self._wlock is None:
+ # writes to a message oriented win32 pipe are atomic
+ self.put = send
+ else:
+ wlock_acquire, wlock_release = (
+ self._wlock.acquire, self._wlock.release)
+
+ def put(obj):
+ wlock_acquire()
+ try:
+ return send(obj)
+ finally:
+ wlock_release()
+
+ self.put = put
+
class PicklingPool(Pool):
"""Pool implementation with customizable pickling reducers.
@@ -120,7 +193,7 @@ class PicklingPool(Pool):
"""
def __init__(self, processes=None, forward_reducers=None,
- backward_reducers=None, **kwargs):
+ backward_reducers=None, **kwargs):
if forward_reducers is None:
forward_reducers = dict()
if backward_reducers is None:
@@ -131,6 +204,15 @@ class PicklingPool(Pool):
poolargs.update(kwargs)
super(PicklingPool, self).__init__(**poolargs)
+ def _setup_queues(self):
+ context = getattr(self, '_ctx', mp)
+ self._inqueue = CustomizablePicklingQueue(context,
+ self._forward_reducers)
+ self._outqueue = CustomizablePicklingQueue(context,
+ self._backward_reducers)
+ self._quick_put = self._inqueue._send
+ self._quick_get = self._outqueue._recv
+
class MemmappingPool(PicklingPool):
"""Process pool that shares large arrays to avoid memory copy.
@@ -209,21 +291,64 @@ class MemmappingPool(PicklingPool):
"""
- def __init__(self, processes=None, temp_folder=None, max_nbytes=
- 1000000.0, mmap_mode='r', forward_reducers=None, backward_reducers=
- None, verbose=0, context_id=None, prewarm=False, **kwargs):
+ def __init__(self, processes=None, temp_folder=None, max_nbytes=1e6,
+ mmap_mode='r', forward_reducers=None, backward_reducers=None,
+ verbose=0, context_id=None, prewarm=False, **kwargs):
+
if context_id is not None:
- warnings.warn(
- 'context_id is deprecated and ignored in joblib 0.9.4 and will be removed in 0.11'
- , DeprecationWarning)
+ warnings.warn('context_id is deprecated and ignored in joblib'
+ ' 0.9.4 and will be removed in 0.11',
+ DeprecationWarning)
+
manager = TemporaryResourcesManager(temp_folder)
self._temp_folder_manager = manager
- forward_reducers, backward_reducers = get_memmapping_reducers(
- temp_folder_resolver=manager.resolve_temp_folder_name,
- max_nbytes=max_nbytes, mmap_mode=mmap_mode, forward_reducers=
- forward_reducers, backward_reducers=backward_reducers, verbose=
- verbose, unlink_on_gc_collect=False, prewarm=prewarm)
- poolargs = dict(processes=processes, forward_reducers=
- forward_reducers, backward_reducers=backward_reducers)
+
+ # The usage of a temp_folder_resolver over a simple temp_folder is
+ # superfluous for multiprocessing pools, as they don't get reused, see
+ # get_memmapping_executor for more details. We still use it for code
+ # simplicity.
+ forward_reducers, backward_reducers = \
+ get_memmapping_reducers(
+ temp_folder_resolver=manager.resolve_temp_folder_name,
+ max_nbytes=max_nbytes, mmap_mode=mmap_mode,
+ forward_reducers=forward_reducers,
+ backward_reducers=backward_reducers, verbose=verbose,
+ unlink_on_gc_collect=False, prewarm=prewarm)
+
+ poolargs = dict(
+ processes=processes,
+ forward_reducers=forward_reducers,
+ backward_reducers=backward_reducers)
poolargs.update(kwargs)
super(MemmappingPool, self).__init__(**poolargs)
+
+ def terminate(self):
+ n_retries = 10
+ for i in range(n_retries):
+ try:
+ super(MemmappingPool, self).terminate()
+ break
+ except OSError as e:
+ if isinstance(e, WindowsError):
+ # Workaround occasional "[Error 5] Access is denied" issue
+ # when trying to terminate a process under windows.
+ sleep(0.1)
+ if i + 1 == n_retries:
+ warnings.warn("Failed to terminate worker processes in"
+ " multiprocessing pool: %r" % e)
+
+ # Clean up the temporary resources as the workers should now be off.
+ self._temp_folder_manager._clean_temporary_resources()
+
+ @property
+ def _temp_folder(self):
+ # Legacy property in tests. could be removed if we refactored the
+ # memmapping tests. SHOULD ONLY BE USED IN TESTS!
+ # We cache this property because it is called late in the tests - at
+ # this point, all context have been unregistered, and
+ # resolve_temp_folder_name raises an error.
+ if getattr(self, '_cached_temp_folder', None) is not None:
+ return self._cached_temp_folder
+ else:
+ self._cached_temp_folder = self._temp_folder_manager.resolve_temp_folder_name() # noqa
+ return self._cached_temp_folder
diff --git a/joblib/testing.py b/joblib/testing.py
index e20431c..caab7d2 100644
--- a/joblib/testing.py
+++ b/joblib/testing.py
@@ -1,14 +1,18 @@
"""
Helper for testing.
"""
+
import sys
import warnings
import os.path
import re
import subprocess
import threading
+
import pytest
import _pytest
+
+
raises = pytest.raises
warns = pytest.warns
SkipTest = _pytest.runner.Skipped
@@ -23,11 +27,17 @@ param = pytest.param
def warnings_to_stdout():
    """ Redirect all warnings to stdout.

    Installs a ``warnings.showwarning`` hook that forwards each warning
    to the previously installed hook, but writing to ``sys.stdout``
    (looked up at emission time) and shortening the file path to its
    basename for more compact output.
    """
    showwarning_orig = warnings.showwarning

    def showwarning(msg, cat, fname, lno, file=None, line=None):
        # Forward the real line number ``lno``.  The previous
        # implementation forwarded ``line`` (the source-text argument,
        # default 0) in the ``lineno`` position, so every warning was
        # reported as coming from line 0.
        showwarning_orig(msg, cat, os.path.basename(fname), lno,
                         sys.stdout, line)

    warnings.showwarning = showwarning
def check_subprocess_call(cmd, timeout=5, stdout_regex=None,
                          stderr_regex=None):
    """Runs a command in a subprocess with timeout in seconds.

    A SIGTERM is sent after `timeout` and, if the process still does not
    terminate, a SIGKILL follows after `2 * timeout`.

    Also checks returncode is zero, stdout if stdout_regex is set, and
    stderr if stderr_regex is set.
    """
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE)

    def _watchdog(action):  # pragma: no cover
        """
        Best-effort cleanup of a leftover process spawned during test
        execution: ideally this should not be needed but can help avoid
        clogging the CI workers in case of deadlocks.
        """
        warnings.warn(f"Timeout running {cmd}")
        action()

    timers = []
    if timeout is not None:
        timers = [
            threading.Timer(timeout, _watchdog, args=(proc.terminate,)),
            threading.Timer(2 * timeout, _watchdog, args=(proc.kill,)),
        ]
        for timer in timers:
            timer.start()

    try:
        stdout, stderr = proc.communicate()
        stdout, stderr = stdout.decode(), stderr.decode()

        if proc.returncode != 0:
            raise ValueError(
                'Non-zero return code: {}.\nStdout:\n{}\n'
                'Stderr:\n{}'.format(proc.returncode, stdout, stderr))

        if stdout_regex is not None and not re.search(stdout_regex, stdout):
            raise ValueError(
                "Unexpected stdout: {!r} does not match:\n{!r}".format(
                    stdout_regex, stdout))
        if stderr_regex is not None and not re.search(stderr_regex, stderr):
            raise ValueError(
                "Unexpected stderr: {!r} does not match:\n{!r}".format(
                    stderr_regex, stderr))
    finally:
        for timer in timers:
            timer.cancel()