diff --git a/joblib/backports.py b/joblib/backports.py
index ed11ccb..63a9151 100644
--- a/joblib/backports.py
+++ b/joblib/backports.py
@@ -74,6 +74,33 @@ class LooseVersion(Version):
def __repr__(self):
return "LooseVersion ('%s')" % str(self)
+def _concurrency_safe_rename_win32(src, dst):
+ """Renames ``src`` into ``dst`` overwriting ``dst`` if it exists.
+
+ On Windows os.replace can yield permission errors if executed by two
+ different processes.
+ """
+ max_retries = 10
+ retry_delay = 0.001 # Initial delay set to 1ms
+ for i in range(max_retries):
+ try:
+ replace(src, dst)
+ break
+ except OSError as e:
+ if e.winerror not in access_denied_errors:
+ raise
+ if i == max_retries - 1:
+ raise
+ time.sleep(retry_delay)
+ retry_delay *= 1.5
+
+if os.name == 'nt':
+ access_denied_errors = (5, 13)
+ from os import replace
+ concurrency_safe_rename = _concurrency_safe_rename_win32
+else:
+ from os import replace as concurrency_safe_rename
+
try:
import numpy as np
@@ -90,18 +117,12 @@ try:
newly-created memmap that sends a maybe_unlink request for the
memmaped file to resource_tracker.
"""
- pass
+        mm = np.memmap(filename, dtype=dtype, mode=mode,
+                       offset=offset, shape=shape, order=order)
+        if unlink_on_gc_collect:
+            # Register the finalizer on the memmap that is actually returned,
+            # so the maybe_unlink request is only sent once it is collected.
+            util.finalize(mm,
+                          util.get_context().resource_tracker.maybe_unlink,
+                          [filename])
+        return mm
except ImportError:
-if os.name == 'nt':
- access_denied_errors = (5, 13)
- from os import replace
-
- def concurrency_safe_rename(src, dst):
- """Renames ``src`` into ``dst`` overwriting ``dst`` if it exists.
-
- On Windows os.replace can yield permission errors if executed by two
- different processes.
- """
- pass
-else:
- from os import replace as concurrency_safe_rename
\ No newline at end of file
+ make_memmap = None
\ No newline at end of file
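
For reference, a minimal usage sketch of the rename helper introduced above (illustrative only, assuming the module keeps exposing it as `concurrency_safe_rename`):

    import os, tempfile
    from joblib.backports import concurrency_safe_rename

    d = tempfile.mkdtemp()
    src, dst = os.path.join(d, "out.tmp"), os.path.join(d, "out.pkl")
    with open(src, "wb") as f:
        f.write(b"payload")
    # On Windows this retries with exponential backoff on access-denied
    # errors; on POSIX it is simply os.replace.
    concurrency_safe_rename(src, dst)
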
diff --git a/joblib/externals/cloudpickle/cloudpickle.py b/joblib/externals/cloudpickle/cloudpickle.py
index 2c5102a..7d05776 100644
--- a/joblib/externals/cloudpickle/cloudpickle.py
+++ b/joblib/externals/cloudpickle/cloudpickle.py
@@ -118,7 +118,24 @@ def _whichmodule(obj, name):
- Errors arising during module introspection are ignored, as those errors
are considered unwanted side effects.
"""
- pass
+ if isinstance(obj, type) and obj.__module__ == '__main__':
+ return obj.__module__
+
+ module_name = getattr(obj, '__module__', None)
+ if module_name is not None:
+ return module_name
+
+ # Protect the iteration by using a copy of sys.modules against dynamic
+ # modules that trigger imports of other modules upon calls to getattr
+ for module_name, module in list(sys.modules.items()):
+ if module_name == '__main__' or module is None:
+ continue
+ try:
+ if _getattribute(module, name)[0] is obj:
+ return module_name
+ except Exception:
+ pass
+ return None
def _should_pickle_by_reference(obj, name=None):
"""Test whether an function or a class should be pickled by reference
@@ -134,11 +151,58 @@ def _should_pickle_by_reference(obj, name=None):
functions and classes or for attributes of modules that have been
explicitly registered to be pickled by value.
"""
- pass
+ if name is None:
+ name = getattr(obj, '__name__', None)
+ if name is None:
+ return False
+
+ module_name = _whichmodule(obj, name)
+ if module_name is None:
+ return False
+
+ if module_name == "__main__":
+ return False
+
+ module = sys.modules.get(module_name, None)
+ if module is None:
+ return False
+
+ if module_name in _PICKLE_BY_VALUE_MODULES:
+ return False
+
+    if not hasattr(module, "__file__"):
+        # Module is not a regular Python module with source code, for instance
+        # it could live in a zip file as this is the case for stdlib modules in
+        # the Windows binary distribution of Python.
+        return True
+
+    # Dynamically created modules (with no backing file) cannot be re-imported
+    # by name in the unpickling process.
+    if module.__file__ is None:
+        return False
+
+    # Finally, make sure the object is actually reachable in its module under
+    # the given name; otherwise it cannot be pickled by reference.
+    try:
+        obj2, _ = _getattribute(module, name)
+    except AttributeError:
+        return False
+    return obj2 is obj
def _extract_code_globals(co):
"""Find all globals names read or written to by codeblock co."""
- pass
+ if co in _extract_code_globals_cache:
+ return _extract_code_globals_cache[co]
+
+    out_names = set()
+    for instr in _walk_global_ops(co):
+        # _walk_global_ops only yields LOAD_GLOBAL, STORE_GLOBAL and
+        # DELETE_GLOBAL instructions; record the global name they reference.
+        # `argval` is used rather than indexing co_names with `arg`, because
+        # on recent Python versions the raw oparg also encodes extra flags.
+        out_names.add(instr.argval)
+
+ # Add the names of the global variables used in nested functions
+ if co.co_consts:
+ for const in co.co_consts:
+ if isinstance(const, types.CodeType):
+ out_names.update(_extract_code_globals(const))
+
+ _extract_code_globals_cache[co] = out_names
+ return out_names
def _find_imported_submodules(code, top_level_dependencies):
"""Find currently imported submodules used by a function.
@@ -165,7 +229,34 @@ def _find_imported_submodules(code, top_level_dependencies):
that calling func once depickled does not fail due to concurrent.futures
not being imported
"""
- pass
+    submodules = []
+    # Materialize the dependencies so that they can be iterated once per name,
+    # even if a generator was passed in.
+    top_level_dependencies = list(top_level_dependencies)
+    for name in code.co_names:
+        for module_name, module in list(sys.modules.items()):
+            if module_name == '__main__' or module is None:
+                continue
+
+            # Only consider modules that are top-level dependencies of the
+            # function, or submodules of such dependencies.
+            is_dependency = False
+            for dependency in top_level_dependencies:
+                if not isinstance(dependency, types.ModuleType):
+                    continue
+                dependency_name = dependency.__name__
+                if (module_name == dependency_name
+                        or module_name.startswith(dependency_name + '.')):
+                    is_dependency = True
+                    break
+            if not is_dependency:
+                continue
+
+            # Only record the attribute if it is itself an imported submodule,
+            # e.g. ``concurrent.futures`` when the code references ``futures``.
+            submodule = getattr(module, name, None)
+            if isinstance(submodule, types.ModuleType):
+                submodules.append(submodule)
+                break
+
+    # Also collect submodules referenced by nested code objects.
+    for const in code.co_consts:
+        if isinstance(const, types.CodeType):
+            submodules.extend(
+                _find_imported_submodules(const, top_level_dependencies))
+
+    return submodules
STORE_GLOBAL = opcode.opmap['STORE_GLOBAL']
DELETE_GLOBAL = opcode.opmap['DELETE_GLOBAL']
LOAD_GLOBAL = opcode.opmap['LOAD_GLOBAL']
@@ -179,18 +270,27 @@ for k, v in types.__dict__.items():
def _walk_global_ops(code):
"""Yield referenced name for global-referencing instructions in code."""
- pass
+ for instr in dis.get_instructions(code):
+ op = instr.opcode
+ if op in GLOBAL_OPS:
+ yield instr
def _extract_class_dict(cls):
"""Retrieve a copy of the dict of a class without the inherited method."""
- pass
+ clsdict = dict(cls.__dict__)
+ if len(cls.__bases__) == 1:
+ inherited_dict = cls.__bases__[0].__dict__
+ for name, value in inherited_dict.items():
+ if name in clsdict and clsdict[name] is value:
+ clsdict.pop(name)
+ return clsdict
def is_tornado_coroutine(func):
"""Return whether `func` is a Tornado coroutine function.
Running coroutines are not supported.
"""
- pass
+    # `_is_coroutine` is asyncio's marker, not tornado's: ask tornado.gen
+    # itself (if it has been imported) whether func is one of its coroutines.
+    gen = sys.modules.get('tornado.gen', None)
+    if gen is None or not hasattr(gen, 'is_coroutine_function'):
+        return False
+    return gen.is_coroutine_function(func)
def instance(cls):
"""Create a new instance of a class.
@@ -205,7 +305,7 @@ def instance(cls):
instance : cls
A new instance of ``cls``.
"""
- pass
+ return cls()
@instance
class _empty_cell_value:
@@ -226,7 +326,28 @@ def _make_skeleton_class(type_constructor, name, bases, type_kwargs, class_track
The "extra" variable is meant to be a dict (or None) that can be used for
forward compatibility shall the need arise.
"""
- pass
+    if class_tracker_id is not None:
+        if class_tracker_id in _DYNAMIC_CLASS_TRACKER_BY_ID:
+            return _DYNAMIC_CLASS_TRACKER_BY_ID[class_tracker_id]
+
+    # Build the skeleton class with its original metaclass (type_constructor)
+    # and class keyword arguments; its __dict__ is filled in later, once the
+    # class has been memoized by the pickler.
+    skeleton_class = types.new_class(
+        name, bases, {'metaclass': type_constructor},
+        lambda ns: ns.update(type_kwargs)
+    )
+
+    if class_tracker_id is not None:
+        _DYNAMIC_CLASS_TRACKER_BY_ID[class_tracker_id] = skeleton_class
+        _DYNAMIC_CLASS_TRACKER_BY_CLASS[skeleton_class] = class_tracker_id
+
+    return skeleton_class
def _make_skeleton_enum(bases, name, qualname, members, module, class_tracker_id, extra):
"""Build dynamic enum with an empty __dict__ to be filled once memoized
@@ -242,19 +363,70 @@ def _make_skeleton_enum(bases, name, qualname, members, module, class_tracker_id
The "extra" variable is meant to be a dict (or None) that can be used for
forward compatibility shall the need arise.
"""
- pass
+    if class_tracker_id is not None:
+        if class_tracker_id in _DYNAMIC_CLASS_TRACKER_BY_ID:
+            return _DYNAMIC_CLASS_TRACKER_BY_ID[class_tracker_id]
+
+    # Enums always inherit from their base Enum class at the last position in
+    # the list of base classes.
+    enum_base = bases[-1]
+    metacls = enum_base.__class__
+    classdict = metacls.__prepare__(name, bases)
+
+    # Let the Enum metaclass machinery build the members from their values.
+    for member_name, member_value in dict(members).items():
+        classdict[member_name] = member_value
+    enum_class = metacls.__new__(metacls, name, bases, classdict)
+    enum_class.__module__ = module
+    enum_class.__qualname__ = qualname
+
+    if class_tracker_id is not None:
+        _DYNAMIC_CLASS_TRACKER_BY_ID[class_tracker_id] = enum_class
+        _DYNAMIC_CLASS_TRACKER_BY_CLASS[enum_class] = class_tracker_id
+
+    return enum_class
def _code_reduce(obj):
"""code object reducer."""
- pass
+    if hasattr(obj, "co_exceptiontable"):
+        # Python 3.11+: the constructor also expects the qualified name, the
+        # line table and the exception table.
+        args = (
+            obj.co_argcount, obj.co_posonlyargcount,
+            obj.co_kwonlyargcount, obj.co_nlocals,
+            obj.co_stacksize, obj.co_flags, obj.co_code,
+            obj.co_consts, obj.co_names, obj.co_varnames,
+            obj.co_filename, obj.co_name, obj.co_qualname,
+            obj.co_firstlineno, obj.co_linetable,
+            obj.co_exceptiontable, obj.co_freevars, obj.co_cellvars,
+        )
+    elif hasattr(obj, "co_posonlyargcount"):
+        # Python 3.8 - 3.10
+        args = (
+            obj.co_argcount, obj.co_posonlyargcount,
+            obj.co_kwonlyargcount, obj.co_nlocals,
+            obj.co_stacksize, obj.co_flags, obj.co_code,
+            obj.co_consts, obj.co_names, obj.co_varnames,
+            obj.co_filename, obj.co_name, obj.co_firstlineno,
+            obj.co_lnotab, obj.co_freevars, obj.co_cellvars,
+        )
+    else:
+        # Python < 3.8
+        args = (
+            obj.co_argcount, obj.co_kwonlyargcount,
+            obj.co_nlocals, obj.co_stacksize, obj.co_flags,
+            obj.co_code, obj.co_consts, obj.co_names,
+            obj.co_varnames, obj.co_filename,
+            obj.co_name, obj.co_firstlineno, obj.co_lnotab,
+            obj.co_freevars, obj.co_cellvars,
+        )
+    return types.CodeType, args
def _cell_reduce(obj):
"""Cell (containing values of a function's free variables) reducer."""
- pass
+    # An empty cell raises ValueError on access; a cell containing None is a
+    # perfectly valid, non-empty cell and must be distinguished from it.
+    try:
+        obj.cell_contents
+    except ValueError:  # cell is empty
+        return _make_empty_cell, ()
+    return _make_cell, (obj.cell_contents,)
def _file_reduce(obj):
"""Save a file."""
- pass
+ import io
+
+ if obj.closed:
+ raise pickle.PicklingError("Cannot pickle closed files")
+
+ if obj.mode == 'r':
+ return io.StringIO, (obj.read(),)
+ else:
+ raise pickle.PicklingError("Cannot pickle files in write mode")
def _dynamic_class_reduce(obj):
"""Save a class that can't be referenced as a module attribute.
@@ -263,11 +435,53 @@ def _dynamic_class_reduce(obj):
functions, or that otherwise can't be serialized as attribute lookups
from importable modules.
"""
- pass
+    # Collect the metaclass, name, bases and definition keyword arguments
+    # needed to rebuild a skeleton of the class.
+    type_constructor = type(obj)
+    name = obj.__name__
+    bases = obj.__bases__
+    qualname = getattr(obj, "__qualname__", None)
+    class_tracker_id = _DYNAMIC_CLASS_TRACKER_BY_CLASS.get(obj)
+
+    type_kwargs = {
+        "__module__": obj.__module__,
+        "__qualname__": qualname,
+    }
+
+    # The class' __dict__ (methods, attributes, ...) is restored separately
+    # through _class_getstate/_class_setstate once the skeleton class has been
+    # memoized, so that classes whose methods reference the class itself can
+    # still be pickled.
+    return (
+        _make_skeleton_class,
+        (type_constructor, name, bases, type_kwargs, class_tracker_id, None),
+        _class_getstate(obj),
+        None,
+        None,
+        _class_setstate,
+    )
def _class_reduce(obj):
"""Select the reducer depending on the dynamic nature of the class obj."""
- pass
+ if obj is type(None):
+ return type, (None,)
+ elif obj is type(Ellipsis):
+ return type, (Ellipsis,)
+ elif obj is type(NotImplemented):
+ return type, (NotImplemented,)
+    elif obj in _BUILTIN_TYPE_NAMES:
+        # Non-importable builtin types (e.g. FunctionType) are rebuilt from
+        # their registered name via the _builtin_type helper.
+        return _builtin_type, (_BUILTIN_TYPE_NAMES[obj],)
+ elif not _should_pickle_by_reference(obj):
+ return _dynamic_class_reduce(obj)
+ return NotImplemented
def _function_setstate(obj, state):
"""Update the state of a dynamic function.
@@ -276,7 +490,12 @@ def _function_setstate(obj, state):
cannot rely on the native setstate routine of pickle.load_build, that calls
setattr on items of the slotstate. Instead, we have to modify them inplace.
"""
- pass
+    state, slotstate = state
+    obj.__dict__.update(state)
+
+    # The function's globals must be updated in place (other functions may
+    # share the same globals dict), not replaced by the whole slotstate.
+    obj_globals = slotstate.pop("__globals__", {})
+    slotstate.pop("__closure__", None)
+    # _cloudpickle_submodules only forces the submodules used by the function
+    # to be pickled (hence re-imported at load time); it is not an attribute.
+    slotstate.pop("_cloudpickle_submodules", None)
+
+    obj.__globals__.update(obj_globals)
+    obj.__globals__["__builtins__"] = __builtins__
+
+    # The remaining slots (__defaults__, __kwdefaults__, __doc__, ...) are
+    # regular function attributes that can be set directly.
+    for k, v in slotstate.items():
+        setattr(obj, k, v)
_DATACLASSE_FIELD_TYPE_SENTINELS = {dataclasses._FIELD.name: dataclasses._FIELD, dataclasses._FIELD_CLASSVAR.name: dataclasses._FIELD_CLASSVAR, dataclasses._FIELD_INITVAR.name: dataclasses._FIELD_INITVAR}
class Pickler(pickle.Pickler):
@@ -311,7 +530,21 @@ class Pickler(pickle.Pickler):
def _dynamic_function_reduce(self, func):
"""Reduce a function that is not pickleable via attribute lookup."""
- pass
+        newargs = self._function_getnewargs(func)
+        state = _function_getstate(func)
+        # A dynamic function is rebuilt from a skeleton (_make_function) and
+        # its state is then restored in place by _function_setstate.
+        return (_make_function, newargs, state, None, None,
+                _function_setstate)
def _function_reduce(self, obj):
"""Reducer for function objects.
@@ -322,7 +555,13 @@ class Pickler(pickle.Pickler):
obj using a custom cloudpickle reducer designed specifically to handle
dynamic functions.
"""
- pass
+        if _should_pickle_by_reference(obj):
+            # Importable by name: let the standard pickling-by-reference
+            # machinery handle it.
+            return NotImplemented
+        return self._dynamic_function_reduce(obj)
def __init__(self, file, protocol=None, buffer_callback=None):
if protocol is None:
@@ -364,7 +603,12 @@ class Pickler(pickle.Pickler):
reducers, such as Exceptions. See
https://github.com/cloudpipe/cloudpickle/issues/248
"""
- pass
+ if isinstance(obj, type):
+ return _class_reduce(obj)
+ elif isinstance(obj, types.FunctionType):
+ return self._function_reduce(obj)
+ else:
+ return NotImplemented
else:
dispatch = pickle.Pickler.dispatch.copy()
@@ -374,7 +618,12 @@ class Pickler(pickle.Pickler):
The name of this method is somewhat misleading: all types get
dispatched here.
"""
- pass
+            if isinstance(obj, type) and not _should_pickle_by_reference(obj, name=name):
+                # Dynamic class: unpack the reduce tuple built by
+                # _dynamic_class_reduce into save_reduce-style arguments.
+                return self._save_reduce_pickle5(
+                    *_dynamic_class_reduce(obj), obj=obj)
+            else:
+                # pickle.Pickler.save_global does not accept the `pack` argument.
+                return super().save_global(obj, name=name)
dispatch[type] = save_global
def save_function(self, obj, name=None):
@@ -383,7 +632,10 @@ class Pickler(pickle.Pickler):
Determines what kind of function obj is (e.g. lambda, defined at
interactive prompt, etc) and handles the pickling appropriately.
"""
- pass
+            if _should_pickle_by_reference(obj, name=name):
+                return super().save_global(obj, name=name)
+            else:
+                # pickle.Pickler has no save_function: dynamic functions are
+                # saved through the cloudpickle reduce tuple instead.
+                return self._save_reduce_pickle5(
+                    *self._dynamic_function_reduce(obj), obj=obj)
def save_pypy_builtin_func(self, obj):
"""Save pypy equivalent of builtin functions.
diff --git a/joblib/externals/loky/backend/_posix_reduction.py b/joblib/externals/loky/backend/_posix_reduction.py
index c0cc3ed..0e94840 100644
--- a/joblib/externals/loky/backend/_posix_reduction.py
+++ b/joblib/externals/loky/backend/_posix_reduction.py
@@ -6,9 +6,47 @@ from multiprocessing.context import get_spawning_popen
from .reduction import register
HAVE_SEND_HANDLE = hasattr(socket, 'CMSG_LEN') and hasattr(socket, 'SCM_RIGHTS') and hasattr(socket.socket, 'sendmsg')
+def _mk_inheritable(fd):
+ """Make a file descriptor inheritable by child processes."""
+ os.set_inheritable(fd, True)
+ return fd
+
def DupFd(fd):
"""Return a wrapper for an fd."""
- pass
+ popen = get_spawning_popen()
+ if popen is not None:
+ return popen.DupFd(fd)
+ else:
+ return _mk_inheritable(os.dup(fd))
+
+def _reduce_socket(s):
+    """Reduce a socket object to a picklable form."""
+    df = DupFd(s.fileno())
+    return _rebuild_socket, (df, s.family, s.type, s.proto)
+
+def _rebuild_socket(df, family, type, proto):
+    """Rebuild a socket object from a duplicated file descriptor."""
+    # DupFd may hand back either a plain inheritable fd or a wrapper object
+    # exposing detach(), depending on the spawning context.
+    fd = df if isinstance(df, int) else df.detach()
+    return socket.socket(family, type, proto, fileno=fd)
+
+def rebuild_connection(df, readable, writable):
+ """Rebuild a connection object."""
+ fd = df
+ if not isinstance(fd, int):
+ fd = fd.detach()
+ conn = Connection(fd, readable, writable)
+ conn._inheritable = True
+ return conn
+
+def reduce_connection(conn):
+ """Reduce a connection object."""
+ df = DupFd(conn.fileno())
+ return rebuild_connection, (df, conn.readable, conn.writable)
register(socket.socket, _reduce_socket)
register(_socket.socket, _reduce_socket)
register(Connection, reduce_connection)
\ No newline at end of file
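
A small sanity-check sketch for the connection reducer above (it assumes the in-process fallback of `DupFd`, which simply duplicates the descriptor when no process is currently being spawned):

    from multiprocessing import Pipe
    from joblib.externals.loky.backend._posix_reduction import reduce_connection

    r, w = Pipe(duplex=False)
    rebuild, args = reduce_connection(w)   # (rebuild_connection, (fd, readable, writable))
    w2 = rebuild(*args)                    # clone backed by a dup'ed descriptor
    w2.send("ping")
    assert r.recv() == "ping"
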
diff --git a/joblib/externals/loky/backend/context.py b/joblib/externals/loky/backend/context.py
index 31567b1..0543686 100644
--- a/joblib/externals/loky/backend/context.py
+++ b/joblib/externals/loky/backend/context.py
@@ -46,11 +46,43 @@ def cpu_count(only_physical_cores=False):
It is also always larger or equal to 1.
"""
- pass
+ # Get the number of logical cores
+ try:
+ os_cpu_count = mp.cpu_count()
+ except NotImplementedError:
+ os_cpu_count = 1
+
+ if sys.platform == 'win32':
+ os_cpu_count = min(os_cpu_count, _MAX_WINDOWS_WORKERS)
+
+ cpu_count_user = _cpu_count_user(os_cpu_count)
+ if cpu_count_user is not None:
+ return cpu_count_user
+
+ if only_physical_cores:
+ physical_cores, exception = _count_physical_cores()
+ if physical_cores != "not found":
+ return max(1, physical_cores)
+
+ return max(1, os_cpu_count)
def _cpu_count_user(os_cpu_count):
"""Number of user defined available CPUs"""
- pass
+ cpu_count_user = os.environ.get('LOKY_MAX_CPU_COUNT', None)
+ if cpu_count_user is not None:
+ if cpu_count_user.strip() == '':
+ return None
+ try:
+ cpu_count_user = float(cpu_count_user)
+            if cpu_count_user > 0:
+                # Never return less than one CPU (e.g. for 0 < value < 1).
+                return max(1, int(min(cpu_count_user, os_cpu_count)))
+            else:
+                # Non-positive values are ignored.
+                return None
+ except ValueError:
+ warnings.warn("LOKY_MAX_CPU_COUNT should be an integer or a float."
+ " Got '{}'. Using {} CPUs."
+ .format(cpu_count_user, os_cpu_count))
+ return None
def _count_physical_cores():
"""Return a tuple (number of physical cores, exception)
@@ -60,7 +92,60 @@ def _count_physical_cores():
The number of physical cores is cached to avoid repeating subprocess calls.
"""
- pass
+ global physical_cores_cache
+ if physical_cores_cache is not None:
+ return physical_cores_cache
+
+ if sys.platform == 'linux':
+ try:
+ # Try to get the number of physical cores from /proc/cpuinfo
+ with open('/proc/cpuinfo', 'rb') as f:
+ cpuinfo = f.read().decode('ascii')
+ cores = set()
+ for line in cpuinfo.split('\n'):
+ if line.startswith('physical id'):
+ phys_id = line.split(':')[1].strip()
+ elif line.startswith('cpu cores'):
+ nb_cores = int(line.split(':')[1].strip())
+ cores.add((phys_id, nb_cores))
+ if cores:
+ physical_cores_cache = (sum(nb_cores for _, nb_cores in cores), None)
+ return physical_cores_cache
+ except Exception as e:
+ physical_cores_cache = ("not found", e)
+ return physical_cores_cache
+
+ elif sys.platform == 'win32':
+ try:
+ # Try to get the number of physical cores from wmic
+ cmd = ['wmic', 'cpu', 'get', 'NumberOfCores']
+ p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ stdout, stderr = p.communicate()
+ if p.returncode == 0:
+ stdout = stdout.decode('ascii')
+ cores = [int(l) for l in stdout.split('\n')[1:] if l.strip()]
+ if cores:
+ physical_cores_cache = (sum(cores), None)
+ return physical_cores_cache
+ except Exception as e:
+ physical_cores_cache = ("not found", e)
+ return physical_cores_cache
+
+ elif sys.platform == 'darwin':
+ try:
+ # Try to get the number of physical cores from sysctl
+ cmd = ['sysctl', '-n', 'hw.physicalcpu']
+ p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ stdout, stderr = p.communicate()
+ if p.returncode == 0:
+ physical_cores_cache = (int(stdout.strip()), None)
+ return physical_cores_cache
+ except Exception as e:
+ physical_cores_cache = ("not found", e)
+ return physical_cores_cache
+
+ physical_cores_cache = ("not found", "unknown platform")
+ return physical_cores_cache
class LokyContext(BaseContext):
"""Context relying on the LokyProcess."""
@@ -70,37 +155,45 @@ class LokyContext(BaseContext):
def Queue(self, maxsize=0, reducers=None):
"""Returns a queue object"""
- pass
+ from .queues import Queue
+ return Queue(maxsize, reducers=reducers, ctx=self.get_context())
def SimpleQueue(self, reducers=None):
"""Returns a queue object"""
- pass
+ from .queues import SimpleQueue
+ return SimpleQueue(reducers=reducers, ctx=self.get_context())
if sys.platform != 'win32':
'For Unix platform, use our custom implementation of synchronize\n ensuring that we use the loky.backend.resource_tracker to clean-up\n the semaphores in case of a worker crash.\n '
def Semaphore(self, value=1):
"""Returns a semaphore object"""
- pass
+ from .synchronize import Semaphore
+            return Semaphore(value=value)
def BoundedSemaphore(self, value):
"""Returns a bounded semaphore object"""
- pass
+ from .synchronize import BoundedSemaphore
+            return BoundedSemaphore(value)
def Lock(self):
"""Returns a lock object"""
- pass
+ from .synchronize import Lock
+            return Lock()
def RLock(self):
"""Returns a recurrent lock object"""
- pass
+ from .synchronize import RLock
+            return RLock()
def Condition(self, lock=None):
"""Returns a condition object"""
- pass
+ from .synchronize import Condition
+            return Condition(lock)
def Event(self):
"""Returns an event object"""
- pass
+ from .synchronize import Event
+            return Event()
class LokyInitMainContext(LokyContext):
"""Extra context with LokyProcess, which does load the main module
@@ -118,6 +211,30 @@ class LokyInitMainContext(LokyContext):
"""
_name = 'loky_init_main'
Process = LokyInitMainProcess
+def get_context(method=None):
+ """Returns a BaseContext or instance of subclass of BaseContext.
+
+ method parameter can be 'fork', 'spawn', 'forkserver', 'loky' or None.
+ If None, the default context is returned.
+ """
+    global _DEFAULT_START_METHOD
+    if method is None:
+        # Fall back on the module-level default start method ('loky' unless it
+        # was overridden); the global statement is required because the value
+        # is assigned here.
+        if _DEFAULT_START_METHOD is None:
+            _DEFAULT_START_METHOD = 'loky'
+        method = _DEFAULT_START_METHOD
+
+ if method not in START_METHODS:
+ raise ValueError(
+ "Method '{}' not in available methods {}".format(
+ method, START_METHODS))
+
+ if method == 'loky':
+ return ctx_loky
+ elif method == 'loky_init_main':
+ return mp.context._concrete_contexts['loky_init_main']
+ else:
+ return mp_get_context(method)
+
ctx_loky = LokyContext()
mp.context._concrete_contexts['loky'] = ctx_loky
mp.context._concrete_contexts['loky_init_main'] = LokyInitMainContext()
\ No newline at end of file
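
Roughly, the CPU-count helpers above are meant to be used as follows (sketch; the actual numbers depend on the machine):

    import os
    from joblib.externals.loky.backend.context import cpu_count

    print(cpu_count())                          # logical cores (capped on Windows)
    print(cpu_count(only_physical_cores=True))  # physical cores when detectable

    os.environ["LOKY_MAX_CPU_COUNT"] = "2"      # user-imposed upper bound
    print(cpu_count())                          # -> at most 2
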
diff --git a/joblib/externals/loky/backend/reduction.py b/joblib/externals/loky/backend/reduction.py
index 4b3cb7a..c41eb8f 100644
--- a/joblib/externals/loky/backend/reduction.py
+++ b/joblib/externals/loky/backend/reduction.py
@@ -6,10 +6,35 @@ import sys
import os
from multiprocessing import util
from pickle import loads, HIGHEST_PROTOCOL
+from multiprocessing.reduction import register
+
_dispatch_table = {}
+def _reduce_method(m):
+ """Helper function for pickling methods."""
+ if m.__self__ is None:
+ return getattr, (m.__self__.__class__, m.__func__.__name__)
+ else:
+ return getattr, (m.__self__, m.__func__.__name__)
+
+def _reduce_method_descriptor(m):
+ """Helper function for pickling method descriptors."""
+ return getattr, (m.__objclass__, m.__name__)
+
+def _reduce_partial(p):
+ """Helper function for pickling partial functions."""
+ return _rebuild_partial, (p.func, p.args, p.keywords or {})
+
+def _rebuild_partial(func, args, keywords):
+ """Helper function for rebuilding partial functions."""
+ return functools.partial(func, *args, **keywords)
+
class _C:
- pass
+ def f(self):
+ pass
+ @classmethod
+ def h(cls):
+ pass
register(type(_C().f), _reduce_method)
register(type(_C.h), _reduce_method)
if not hasattr(sys, 'pypy_version_info'):
@@ -30,9 +55,47 @@ _LokyPickler = None
_loky_pickler_name = None
set_loky_pickler()
+def set_loky_pickler(loky_pickler=None):
+ """Select the pickler to use in loky.
+
+ Parameters
+ ----------
+ loky_pickler: str in {'pickle', 'cloudpickle', None}, default=None
+ If None, use the value of the environment variable LOKY_PICKLER.
+ If 'pickle', use the standard pickle module.
+ If 'cloudpickle', use the cloudpickle module.
+ """
+ global _LokyPickler, _loky_pickler_name
+
+    if loky_pickler is None:
+        loky_pickler = ENV_LOKY_PICKLER
+
+    if loky_pickler == _loky_pickler_name:
+        return
+
+    if loky_pickler in (None, '', 'cloudpickle'):
+        # Default to cloudpickle, which can also serialize interactively
+        # defined functions and classes.
+        from joblib.externals.cloudpickle import CloudPickler
+        _LokyPickler = CloudPickler
+    elif loky_pickler == 'pickle':
+        from pickle import Pickler
+        _LokyPickler = Pickler
+    else:
+        raise ValueError(
+            "Invalid value for LOKY_PICKLER: '{}'. Supported values are "
+            "'pickle' and 'cloudpickle'".format(loky_pickler))
+    _loky_pickler_name = loky_pickler
+
def dump(obj, file, reducers=None, protocol=None):
"""Replacement for pickle.dump() using _LokyPickler."""
- pass
+ if protocol is None:
+ protocol = HIGHEST_PROTOCOL
+ _LokyPickler(file, protocol=protocol).dump(obj)
+
+def dumps(obj, reducers=None, protocol=None):
+ """Replacement for pickle.dumps() using _LokyPickler."""
+ buf = io.BytesIO()
+ dump(obj, buf, reducers=reducers, protocol=protocol)
+ return buf.getbuffer()
__all__ = ['dump', 'dumps', 'loads', 'register', 'set_loky_pickler']
if sys.platform == 'win32':
from multiprocessing.reduction import duplicate
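
A quick sketch of how the serialization entry points above fit together (assuming `loads` keeps being re-exported from the standard pickle module, as listed in `__all__`):

    from joblib.externals.loky.backend.reduction import (
        set_loky_pickler, dumps, loads)

    set_loky_pickler("cloudpickle")    # also the default when LOKY_PICKLER is unset
    payload = dumps(lambda x: x * 2)   # lambdas need cloudpickle
    assert loads(payload)(21) == 42
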
diff --git a/joblib/func_inspect.py b/joblib/func_inspect.py
index 239018a..4c125e0 100644
--- a/joblib/func_inspect.py
+++ b/joblib/func_inspect.py
@@ -33,11 +33,36 @@ def get_func_code(func):
This function does a bit more magic than inspect, and is thus
more robust.
"""
- pass
+    source_file = None
+    try:
+        source_file = inspect.getsourcefile(func) or inspect.getfile(func)
+    except TypeError:
+        source_file = None
+
+    try:
+        # Retrieve only the source of the function itself (not the whole
+        # module) together with the line it starts on.
+        source_lines, first_line = inspect.getsourcelines(func)
+        return ''.join(source_lines), source_file, first_line
+    except (OSError, TypeError):
+        # The source cannot be retrieved (e.g. functions defined in the REPL
+        # or in compiled extensions): fall back on a stable representation.
+        if hasattr(func, '__code__'):
+            return str(func.__code__.__hash__()), source_file, -1
+        return repr(func), source_file, -1
def _clean_win_chars(string):
"""Windows cannot encode some characters in filename."""
- pass
+ import urllib.parse
+ if os.name == 'nt':
+ return urllib.parse.quote(string, safe='')
+ return string
def get_func_name(func, resolv_alias=True, win_characters=True):
""" Return the function import path (as a list of module names), and
@@ -53,15 +78,62 @@ def get_func_name(func, resolv_alias=True, win_characters=True):
If true, substitute special characters using urllib.quote
This is useful in Windows, as it cannot encode some filenames
"""
- pass
+ if hasattr(func, '__module__'):
+ module = func.__module__
+ else:
+ try:
+ module = inspect.getmodule(func)
+ if module is not None:
+ module = module.__name__
+ except:
+ module = None
+ if module is None:
+ module = ''
+
+ module_parts = module.split('.')
+
+ if hasattr(func, '__name__'):
+ name = func.__name__
+ else:
+ name = 'unknown'
+ if hasattr(func, '__class__'):
+ name = func.__class__.__name__
+
+ if win_characters:
+ name = _clean_win_chars(name)
+
+    if resolv_alias:
+        # If the name found does not point back to the function in its own
+        # globals, the function is being accessed through an alias: mark it so
+        # that aliases of two different functions do not collide.
+        if hasattr(func, '__globals__') and name in func.__globals__:
+            if func.__globals__[name] is not func:
+                name = '%s-alias' % name
+
+    return module_parts, name
def _signature_str(function_name, arg_sig):
"""Helper function to output a function signature"""
- pass
+ args = []
+ if arg_sig.args:
+ args.extend(arg_sig.args)
+ if arg_sig.varargs:
+ args.append('*' + arg_sig.varargs)
+ if arg_sig.varkw:
+ args.append('**' + arg_sig.varkw)
+ return '%s(%s)' % (function_name, ', '.join(args))
def _function_called_str(function_name, args, kwargs):
"""Helper function to output a function call"""
- pass
+ parts = []
+ if args:
+ parts.extend(repr(arg) for arg in args)
+ if kwargs:
+ parts.extend('%s=%r' % (k, v) for k, v in sorted(kwargs.items()))
+ return '%s(%s)' % (function_name, ', '.join(parts))
def filter_args(func, ignore_lst, args=(), kwargs=dict()):
""" Filters the given args and kwargs using a list of arguments to
@@ -84,10 +156,46 @@ def filter_args(func, ignore_lst, args=(), kwargs=dict()):
filtered_args: list
List of filtered positional and keyword arguments.
"""
- pass
+ arg_spec = inspect.getfullargspec(func)
+ arg_names = list(arg_spec.args)
+ output_args = list()
+
+ # Filter positional arguments
+ if '*' not in ignore_lst:
+ for arg_name, arg in zip(arg_names, args):
+ if arg_name not in ignore_lst:
+ output_args.append(arg)
+
+ # Filter keyword arguments
+ if '**' not in ignore_lst:
+ for arg_name in arg_names[len(args):]:
+ if arg_name in kwargs:
+ if arg_name not in ignore_lst:
+ output_args.append(kwargs[arg_name])
+            elif arg_name not in ignore_lst:
+                # Fall back on the parameter's default value, if it has one.
+                defaults = arg_spec.defaults or ()
+                default_index = (arg_names.index(arg_name)
+                                 - (len(arg_names) - len(defaults)))
+                if default_index >= 0:
+                    output_args.append(defaults[default_index])
+
+ return output_args
def format_call(func, args, kwargs, object_name='Memory'):
""" Returns a nicely formatted statement displaying the function
call with the given arguments.
"""
- pass
\ No newline at end of file
+ path, name = get_func_name(func)
+ path = [object_name] + list(path)
+ module_path = '.'.join(path)
+
+ arg_str = _function_called_str(name, args, kwargs)
+ return '%s.%s' % (module_path, arg_str)
+
+def format_signature(func):
+ """Return a formatted signature for the function."""
+ arg_spec = inspect.getfullargspec(func)
+ path, name = get_func_name(func)
+ module_path = '.'.join(path)
+
+ signature = _signature_str(name, arg_spec)
+ return '%s.%s' % (module_path, signature)
\ No newline at end of file
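
For illustration, with the list-returning semantics implemented above (a sketch, not part of the patch):

    from joblib.func_inspect import filter_args, format_call

    def f(x, y, z=3):
        return x + y + z

    # 'y' is excluded from the cache identity; the default of z is kept.
    print(filter_args(f, ["y"], args=(1, 2)))   # -> [1, 3]
    print(format_call(f, (1, 2), {}))           # -> e.g. "Memory.__main__.f(1, 2)"
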
diff --git a/joblib/hashing.py b/joblib/hashing.py
index 410eaeb..938c489 100644
--- a/joblib/hashing.py
+++ b/joblib/hashing.py
@@ -38,6 +38,15 @@ class Hasher(Pickler):
protocol = 3
Pickler.__init__(self, self.stream, protocol=protocol)
self._hash = hashlib.new(hash_name)
+
+ def save_global(self, obj, name=None, pack=struct.pack):
+ """Save a global object"""
+ self._hash.update(str(obj).encode('utf-8'))
+
+    def save_set(self, obj, pack=struct.pack):
+        """Save a set object with a deterministic item order."""
+        # str() of a _ConsistentSet would include its memory address; pickle
+        # the wrapper instead so the hash only depends on the sorted items.
+        Pickler.save(self, _ConsistentSet(obj))
+
dispatch = Pickler.dispatch.copy()
dispatch[type(len)] = save_global
dispatch[type(object)] = save_global
@@ -73,7 +82,24 @@ class NumpyHasher(Hasher):
than pickling them. Off course, this is a total abuse of
the Pickler class.
"""
- pass
+ if isinstance(obj, type):
+ return Hasher.save_global(self, obj)
+ if isinstance(obj, self.np.ndarray) and not obj.dtype.hasobject:
+ # Compute a hash of the object
+ try:
+ self._hash.update(self._getbuffer(obj))
+ except (TypeError, BufferError):
+                # Cater for non-single-segment arrays: this creates a
+                # copy, and thus alleviates this issue.
+                # XXX: There might be a more efficient way of doing this
+ self._hash.update(self._getbuffer(obj.flatten()))
+
+ # We also hash the dtype and the shape to distinguish
+ # different views of the same data with different dtypes.
+ self._hash.update(str(obj.dtype).encode('utf-8'))
+ self._hash.update(str(obj.shape).encode('utf-8'))
+ return
+ return Hasher.save(self, obj)
def hash(obj, hash_name='md5', coerce_mmap=False):
""" Quick calculation of a hash to identify uniquely Python objects
@@ -87,4 +113,10 @@ def hash(obj, hash_name='md5', coerce_mmap=False):
coerce_mmap: boolean
Make no difference between np.memmap and np.ndarray
"""
- pass
\ No newline at end of file
+    try:
+        import numpy as np  # noqa: F401 -- only used to detect availability
+        hasher = NumpyHasher(hash_name=hash_name, coerce_mmap=coerce_mmap)
+    except ImportError:
+        hasher = Hasher(hash_name=hash_name)
+    hasher.save(obj)
+    # The digest must also cover the pickle stream, which receives everything
+    # that is not hashed directly by the specialized save methods above.
+    hasher._hash.update(hasher.stream.getvalue())
+    return hasher._hash.hexdigest()
\ No newline at end of file
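
A short sketch of the expected behaviour of the hashing entry point above (assumes numpy is installed; the point is which digests match, not their values):

    import numpy as np
    from joblib.hashing import hash as joblib_hash

    a, b = np.arange(10), np.arange(10)
    assert joblib_hash(a) == joblib_hash(b)        # same dtype, shape and content
    assert joblib_hash(a) != joblib_hash(a[::2])   # different shape and content
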
diff --git a/joblib/logger.py b/joblib/logger.py
index 3de4188..f2107a8 100644
--- a/joblib/logger.py
+++ b/joblib/logger.py
@@ -9,15 +9,38 @@ import sys
import os
import shutil
import logging
-import pprint
+import pprint as _pprint
from .disk import mkdirp
+def pformat(obj, depth=3):
+ """Return a formatted representation of the object."""
+ return _pprint.pformat(obj, depth=depth)
+
def _squeeze_time(t):
"""Remove .1s to the time under Windows: this is the time it take to
stat files. This is needed to make results similar to timings under
Unix, for tests
"""
- pass
+ if sys.platform.startswith('win'):
+ return max(0, t - .1)
+ else:
+ return t
+
+def format_time(t):
+ """Format time in seconds for human-readable output"""
+ t = _squeeze_time(t)
+ if t > 60:
+ return "%.1f min" % (t / 60.)
+ else:
+ return "%.2f s" % t
+
+def short_format_time(t):
+ """Format time in seconds for short human-readable output"""
+ t = _squeeze_time(t)
+ if t > 60:
+ return "%.1fm" % (t / 60.)
+ else:
+ return "%.1fs" % t
class Logger(object):
""" Base class for logging messages.
@@ -37,7 +60,11 @@ class Logger(object):
def format(self, obj, indent=0):
"""Return the formatted representation of the object."""
- pass
+ if indent == 0:
+ prefix = ''
+ else:
+ prefix = ' ' * indent
+        return prefix + pformat(obj, depth=self.depth)
class PrintTime(object):
""" Print and log messages while keeping track of time.
diff --git a/joblib/numpy_pickle_compat.py b/joblib/numpy_pickle_compat.py
index 65e3046..d46bbc9 100644
--- a/joblib/numpy_pickle_compat.py
+++ b/joblib/numpy_pickle_compat.py
@@ -10,7 +10,7 @@ from .numpy_pickle_utils import _ensure_native_byte_order
def hex_str(an_int):
"""Convert an int to an hexadecimal string."""
- pass
+ return hex(an_int)[2:]
_MAX_LEN = len(hex_str(2 ** 64))
_CHUNK_SIZE = 64 * 1024
@@ -21,7 +21,26 @@ def read_zfile(file_handle):
for persistence. Backward compatibility is not guaranteed. Do not
use for external purposes.
"""
- pass
+ file_handle.seek(0)
+ header = file_handle.read(len(_ZFILE_PREFIX))
+ if header != _ZFILE_PREFIX:
+ raise ValueError("Unknown file type")
+
+ length = file_handle.read(_MAX_LEN)
+ length = int(length, 16)
+
+ # Decompress small files in memory
+ data = BytesIO()
+ chunk = file_handle.read(_CHUNK_SIZE)
+ decompressor = zlib.decompressobj()
+ while chunk:
+ data.write(decompressor.decompress(chunk))
+ chunk = file_handle.read(_CHUNK_SIZE)
+ data.write(decompressor.flush())
+ data = data.getvalue()
+ if len(data) != length:
+ raise ValueError("File corrupted")
+ return data
def write_zfile(file_handle, data, compress=1):
"""Write the data in the given file as a Z-file.
@@ -30,7 +49,18 @@ def write_zfile(file_handle, data, compress=1):
for persistence. Backward compatibility is not guaranteed. Do not
use for external purposes.
"""
- pass
+ file_handle.write(_ZFILE_PREFIX)
+ length = hex_str(len(data))
+ # Add padding to length to make it fixed width
+ file_handle.write(length.zfill(_MAX_LEN).encode('ascii'))
+ compressor = zlib.compressobj(compress)
+ chunk = data[0:_CHUNK_SIZE]
+ pos = _CHUNK_SIZE
+ while chunk:
+ file_handle.write(compressor.compress(chunk))
+ chunk = data[pos:pos + _CHUNK_SIZE]
+ pos += _CHUNK_SIZE
+ file_handle.write(compressor.flush())
class NDArrayWrapper(object):
"""An object to be persisted instead of numpy arrays.
@@ -47,7 +77,23 @@ class NDArrayWrapper(object):
def read(self, unpickler):
"""Reconstruct the array."""
- pass
+ filename = os.path.join(unpickler._dirname, self.filename)
+ # Load the array from the disk
+ np = unpickler.np
+ if np is None:
+ raise ImportError("Trying to unpickle an ndarray, "
+ "but numpy is not available")
+ array = _ensure_native_byte_order(np.load(filename, mmap_mode=unpickler.mmap_mode if self.allow_mmap else None))
+ # Reconstruct subclasses. This does not work with old
+ # versions of numpy
+ if (not np.issubdtype(array.dtype, np.dtype('O')) and
+ self.subclass not in (type(None), type(array))):
+ new_array = np.ndarray.__new__(self.subclass, array.shape,
+ array.dtype, buffer=array)
+ # Preserve side effects of viewing arrays
+ new_array.__array_finalize__(array)
+ array = new_array
+ return array
class ZNDArrayWrapper(NDArrayWrapper):
"""An object to be persisted instead of numpy arrays.
@@ -72,7 +118,23 @@ class ZNDArrayWrapper(NDArrayWrapper):
def read(self, unpickler):
"""Reconstruct the array from the meta-information and the z-file."""
- pass
+        # Get the filename of the z-file containing the raw array data
+        filename = os.path.join(unpickler._dirname, self.filename)
+
+        np = unpickler.np
+        if np is None:
+            raise ImportError("Trying to unpickle an ndarray, "
+                              "but numpy is not available")
+
+        # Reproduce numpy's own unpickling steps: rebuild an empty array from
+        # the stored constructor arguments, then restore its state with the
+        # bytes read from the z-file appended to the stored state tuple.
+        array = np.core.multiarray._reconstruct(*self.init_args)
+        with open(filename, 'rb') as f:
+            data = read_zfile(f)
+        state = self.state + (data,)
+        array.__setstate__(state)
+        return array
class ZipNumpyUnpickler(Unpickler):
"""A subclass of the Unpickler to unpickle our numpy pickles."""
@@ -98,7 +160,28 @@ class ZipNumpyUnpickler(Unpickler):
NDArrayWrapper, by the array we are interested in. We
replace them directly in the stack of pickler.
"""
- pass
+        # Let the standard Unpickler apply the state first, then check whether
+        # the object on top of the stack is one of our array place-holders
+        # and, if so, replace it by the actual array it describes.
+        Unpickler.load_build(self)
+        if isinstance(self.stack[-1], NDArrayWrapper):
+            if self.np is None:
+                raise ImportError("Trying to unpickle an ndarray, "
+                                  "but numpy is not available")
+            nd_array_wrapper = self.stack.pop()
+            array = nd_array_wrapper.read(self)
+            self.stack.append(array)
dispatch[pickle.BUILD[0]] = load_build
def load_compatibility(filename):
@@ -127,4 +210,10 @@ def load_compatibility(filename):
This function can load numpy array files saved separately during the
dump.
"""
- pass
\ No newline at end of file
+    with open(filename, 'rb') as file_handle:
+        # We are careful to open the file handle early and keep it open for
+        # the whole load to avoid race-conditions on renames of the file.
+        unpickler = ZipNumpyUnpickler(filename, file_handle)
+        obj = unpickler.load()
+    return obj
\ No newline at end of file
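
Finally, a small round-trip sketch for the legacy z-file helpers above (illustrative only; the format is kept for backward compatibility with old joblib pickles):

    from io import BytesIO
    from joblib.numpy_pickle_compat import read_zfile, write_zfile

    buf = BytesIO()
    write_zfile(buf, b"some raw bytes", compress=1)
    buf.seek(0)
    assert read_zfile(buf) == b"some raw bytes"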