Claude Sonnet 3.5 - Fill-in: joblib
Failed to run pytest for the test suite
ImportError while loading conftest '/testbed/conftest.py'.
conftest.py:9: in <module>
from joblib.parallel import mp
joblib/__init__.py:114: in <module>
from .memory import Memory
joblib/memory.py:21: in <module>
from . import hashing
joblib/hashing.py:34: in <module>
class Hasher(Pickler):
joblib/hashing.py:45: in Hasher
dispatch[type(len)] = save_global
E NameError: name 'save_global' is not defined
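
For context on the NameError above: joblib's Hasher subclasses the pure-Python pickler and fills its class-level dispatch table with a save_global method, so that method must be defined in the class body before the dispatch[...] = save_global assignments run. The sketch below is an illustrative stand-in, not joblib's actual Hasher; it assumes the pure-Python pickle._Pickler, which exposes a dispatch table.

import io
import pickle

Pickler = pickle._Pickler  # the pure-Python pickler exposes `dispatch`


class Hasher(Pickler):
    # Defined before the dispatch assignments below; moving this method out
    # of the class body reproduces the NameError from the traceback.
    def save_global(self, obj, name=None):
        # Hash builtins, functions and classes by qualified name instead of
        # trying to pickle their full state.
        module = getattr(obj, '__module__', '')
        qualname = getattr(obj, '__qualname__', getattr(obj, '__name__', ''))
        self.save(f'{module}.{qualname}')

    dispatch = Pickler.dispatch.copy()
    dispatch[type(len)] = save_global     # builtin functions
    dispatch[type(object)] = save_global  # classes / types


buf = io.BytesIO()
Hasher(buf, protocol=pickle.HIGHEST_PROTOCOL).dump(len)
print(len(buf.getvalue()) > 0)
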
Patch diff
diff --git a/joblib/_dask.py b/joblib/_dask.py
index 726f453..a8b7f08 100644
--- a/joblib/_dask.py
+++ b/joblib/_dask.py
@@ -64,7 +64,15 @@ class _WeakKeyDictionary:
def _make_tasks_summary(tasks):
"""Summarize of list of (func, args, kwargs) function calls"""
- pass
+ num_tasks = len(tasks)
+ if num_tasks == 0:
+ return 0, False, ''
+
+ first_func = tasks[0][0]
+ mixed = any(task[0] != first_func for task in tasks[1:])
+    # Do not shadow the imported `funcname` helper, otherwise calling it
+    # raises UnboundLocalError.
+    first_funcname = funcname(first_func)
+
+    return num_tasks, mixed, first_funcname
class Batch:
@@ -143,7 +151,11 @@ or
joblib.Parallel will never access those results
"""
- pass
+ self.client.cancel(list(self._results.keys()))
+ self._results.clear()
+ self._callbacks.clear()
+ if ensure_ready:
+ self.client.restart()
@contextlib.contextmanager
def retrieval_context(self):
@@ -152,4 +164,10 @@ or
This removes thread from the worker's thread pool (using 'secede').
Seceding avoids deadlock in nested parallelism settings.
"""
- pass
+ if hasattr(thread_state, 'on_worker'):
+ secede()
+ try:
+ yield
+ finally:
+ if hasattr(thread_state, 'on_worker'):
+ rejoin()
diff --git a/joblib/_memmapping_reducer.py b/joblib/_memmapping_reducer.py
index 5012683..2ea8325 100644
--- a/joblib/_memmapping_reducer.py
+++ b/joblib/_memmapping_reducer.py
@@ -45,7 +45,15 @@ def unlink_file(filename):
it takes for the last reference of the memmap to be closed, yielding (on
Windows) a PermissionError in the resource_tracker loop.
"""
- pass
+ max_retries = 10
+ for retry in range(max_retries):
+ try:
+ os.unlink(filename)
+ return
+ except PermissionError:
+ if retry == max_retries - 1:
+ raise
+ time.sleep(0.1 * (2 ** retry)) # Exponential backoff
resource_tracker._CLEANUP_FUNCS['file'] = unlink_file
@@ -68,7 +76,15 @@ class _WeakArrayKeyMap:
def _get_backing_memmap(a):
"""Recursively look up the original np.memmap instance base if any."""
- pass
+    if isinstance(a, np.memmap):
+        return a
+    # Walk the ``base`` attribute chain: a view keeps a reference to the
+    # array it was derived from, so the original memmap (if any) is found by
+    # recursing on ``base`` rather than on the raw data pointer.
+    base = getattr(a, 'base', None)
+    if base is not None:
+        return _get_backing_memmap(base)
+    return None
def _get_temp_dir(pool_folder_name, temp_folder=None):
@@ -101,18 +117,53 @@ def _get_temp_dir(pool_folder_name, temp_folder=None):
whether the temporary folder is written to the system shared memory
folder or some other temporary folder.
"""
- pass
+ if temp_folder is None:
+ temp_folder = os.environ.get('JOBLIB_TEMP_FOLDER', None)
+
+ if temp_folder is None:
+ if os.path.exists(SYSTEM_SHARED_MEM_FS):
+            try:
+                # Measure the space available on the shared memory
+                # filesystem itself, not the size of the directory entry.
+                shm_stats = os.statvfs(SYSTEM_SHARED_MEM_FS)
+                available_nbytes = shm_stats.f_bsize * shm_stats.f_bavail
+                if available_nbytes > SYSTEM_SHARED_MEM_FS_MIN_SIZE:
+                    temp_folder = SYSTEM_SHARED_MEM_FS
+                else:
+                    warnings.warn("The filesystem at %s is too small for"
+                                  " joblib memmapping" % SYSTEM_SHARED_MEM_FS)
+            except OSError:
+                temp_folder = None
+
+ if temp_folder is None:
+ temp_folder = tempfile.gettempdir()
+
+ pool_folder = os.path.join(temp_folder, 'joblib_memmapping', pool_folder_name)
+
+ return pool_folder, temp_folder == SYSTEM_SHARED_MEM_FS
def has_shareable_memory(a):
"""Return True if a is backed by some mmap buffer directly or not."""
- pass
+ if isinstance(a, np.memmap):
+ return True
+
+ base = _get_backing_memmap(a)
+ if base is not None:
+ return True
+
+ return False
def _strided_from_memmap(filename, dtype, mode, offset, order, shape,
strides, total_buffer_len, unlink_on_gc_collect):
"""Reconstruct an array view on a memory mapped file."""
- pass
+ if mode == 'w+':
+ mode = 'r+'
+
+ mm = make_memmap(filename, dtype=dtype, shape=(total_buffer_len,),
+ mode=mode, offset=offset, order=order)
+
+ if unlink_on_gc_collect:
+ resource_tracker.register(filename, 'file')
+
+ array = as_strided(mm, shape=shape, strides=strides)
+ return array
def _reduce_memmap_backed(a, m):
@@ -122,12 +173,34 @@ def _reduce_memmap_backed(a, m):
m is expected to be an instance of np.memmap on the top of the ``base``
attribute ancestry of a. ``m.base`` should be the real python mmap object.
"""
- pass
+ if not isinstance(m, np.memmap):
+ raise ValueError("m is not a memmap instance")
+
+ offset = m.offset
+ mode = m.mode
+
+ if isinstance(a, np.memmap):
+ # if a is already a memmap instance, we need to respect its view
+ # on the original memmap
+ offset += a.offset
+ mode = a.mode
+
+ order = 'C' if m.flags['C_CONTIGUOUS'] else 'F'
+
+ return (_strided_from_memmap,
+ (m.filename, a.dtype, mode, offset, order, a.shape, a.strides,
+ m._mmap.size(), True))
def reduce_array_memmap_backward(a):
"""reduce a np.array or a np.memmap from a child process"""
- pass
+ m = _get_backing_memmap(a)
+ if m is not None:
+ return _reduce_memmap_backed(a, m)
+ else:
+ # This is a regular in-memory array that can be pickled without
+ # problem
+ return (loads, (dumps(a, protocol=HIGHEST_PROTOCOL),))
class ArrayMemmapForwardReducer(object):
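
The two fills above boil down to one idea: instead of pickling array data, ship (filename, dtype, mode, offset, shape, strides, buffer length) and rebuild a strided view on the memory-mapped file in the worker. The standalone sketch below (independent of joblib; file name and shapes are made up) illustrates that round trip for a view starting at byte offset 0; joblib additionally tracks the byte offset of the view inside the file.

import tempfile

import numpy as np
from numpy.lib.stride_tricks import as_strided

with tempfile.NamedTemporaryFile(suffix='.mmap', delete=False) as f:
    filename = f.name

parent = np.memmap(filename, dtype='float64', mode='w+', shape=(4, 6))
parent[:] = np.arange(24).reshape(4, 6)
view = parent[:, :3]  # non-contiguous view starting at byte offset 0

# Metadata a parent process would send instead of the data itself.
meta = (filename, view.dtype, 'r', 0, view.shape, view.strides, parent.size)


def rebuild(filename, dtype, mode, offset, shape, strides, total_buffer_len):
    # Map the whole flat buffer, then re-create the strided view on top.
    buf = np.memmap(filename, dtype=dtype, mode=mode, offset=offset,
                    shape=(total_buffer_len,))
    return as_strided(buf, shape=shape, strides=strides)


rebuilt = rebuild(*meta)
print(np.array_equal(rebuilt, view))  # True
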
diff --git a/joblib/_parallel_backends.py b/joblib/_parallel_backends.py
index 87fe642..50f5b75 100644
--- a/joblib/_parallel_backends.py
+++ b/joblib/_parallel_backends.py
@@ -51,12 +51,17 @@ class ParallelBackendBase(metaclass=ABCMeta):
scheduling overhead and better use of CPU cache prefetching heuristics)
as long as all the workers have enough work to do.
"""
- pass
+ if n_jobs == -1:
+ return cpu_count()
+ return max(1, min(n_jobs, cpu_count()))
@abstractmethod
def apply_async(self, func, callback=None):
"""Schedule a func to be run"""
- pass
+ result = func()
+ if callback is not None:
+ callback(result)
+ return result
def retrieve_result_callback(self, out):
"""Called within the callback function passed in apply_async.
@@ -74,7 +79,9 @@ class ParallelBackendBase(metaclass=ABCMeta):
This makes it possible to reuse an existing backend instance for
successive independent calls to Parallel with different parameters.
"""
- pass
+ self.parallel = parallel
+ self.n_jobs = self.effective_n_jobs(n_jobs)
+ return self.n_jobs
def start_call(self):
"""Call-back method called at the beginning of a Parallel call"""
@@ -90,11 +97,30 @@ class ParallelBackendBase(metaclass=ABCMeta):
def compute_batch_size(self):
"""Determine the optimal batch size"""
- pass
+ if self._effective_batch_size == 1:
+ return 1
+ if self._smoothed_batch_duration == 0.0:
+ return self._effective_batch_size
+ optimal_duration = (self.MIN_IDEAL_BATCH_DURATION +
+ self.MAX_IDEAL_BATCH_DURATION) / 2
+ return int(self._effective_batch_size *
+ optimal_duration / self._smoothed_batch_duration)
def batch_completed(self, batch_size, duration):
"""Callback indicate how long it took to run a batch"""
- pass
+ if (self._smoothed_batch_duration == 0.0 or
+ duration < self._smoothed_batch_duration):
+ self._smoothed_batch_duration = duration
+ else:
+ sd = self._smoothed_batch_duration
+ self._smoothed_batch_duration = 0.8 * sd + 0.2 * duration
+
+ ideal_duration = (self.MIN_IDEAL_BATCH_DURATION +
+ self.MAX_IDEAL_BATCH_DURATION) / 2
+ ratio = self._smoothed_batch_duration / ideal_duration
+ self._effective_batch_size = int(self._effective_batch_size / ratio)
+ # Clip to avoid crazy sizes
+ self._effective_batch_size = max(1, min(self._effective_batch_size, 1000))
def get_exceptions(self):
"""List of exception types to be captured."""
@@ -172,11 +198,14 @@ class SequentialBackend(ParallelBackendBase):
def effective_n_jobs(self, n_jobs):
"""Determine the number of jobs which are going to run in parallel"""
- pass
+ return 1 # SequentialBackend always runs 1 job
def apply_async(self, func, callback=None):
"""Schedule a func to be run"""
- pass
+ result = func()
+ if callback is not None:
+ callback(result)
+ return result
class PoolManagerMixin(object):
@@ -185,11 +214,16 @@ class PoolManagerMixin(object):
def effective_n_jobs(self, n_jobs):
"""Determine the number of jobs which are going to run in parallel"""
- pass
+ if n_jobs == -1:
+ n_jobs = cpu_count()
+ return max(1, min(n_jobs, cpu_count()))
def terminate(self):
"""Shutdown the process or thread pool"""
- pass
+ if self._pool is not None:
+ self._pool.close()
+ self._pool.terminate()
+ self._pool = None
def _get_pool(self):
"""Used by apply_async to make it possible to implement lazy init"""
diff --git a/joblib/_store_backends.py b/joblib/_store_backends.py
index 0ce3682..11832d7 100644
--- a/joblib/_store_backends.py
+++ b/joblib/_store_backends.py
@@ -27,7 +27,15 @@ class CacheWarning(Warning):
def concurrency_safe_write(object_to_write, filename, write_func):
"""Writes an object into a unique file in a concurrency-safe way."""
- pass
+    import uuid
+
+    # Use a temporary name that is unique per writer so concurrent writers
+    # cannot clash before the rename below.
+    temporary_filename = f"{filename}.tmp-{os.getpid()}-{uuid.uuid4().hex}"
+ try:
+ with open(temporary_filename, 'wb') as f:
+ write_func(object_to_write, f)
+ concurrency_safe_rename(temporary_filename, filename)
+ except:
+ if os.path.exists(temporary_filename):
+ os.unlink(temporary_filename)
+ raise
class StoreBackendBase(metaclass=ABCMeta):
@@ -153,73 +161,177 @@ class StoreBackendMixin(object):
def load_item(self, call_id, verbose=1, timestamp=None, metadata=None):
"""Load an item from the store given its id as a list of str."""
- pass
+ path = os.path.join(*call_id)
+ full_path = os.path.join(self.location, path)
+
+ if not self._item_exists(full_path):
+ raise KeyError(f"No item with path {path} in the store")
+
+ with self._open_item(full_path, 'rb') as f:
+ item = numpy_pickle.load(f)
+
+ if verbose > 1:
+ print(f"[Memory] Loading {path} from {self.location}")
+
+ return item
def dump_item(self, call_id, item, verbose=1):
"""Dump an item in the store at the id given as a list of str."""
- pass
+ path = os.path.join(*call_id)
+ full_path = os.path.join(self.location, path)
+
+ try:
+ mkdirp(os.path.dirname(full_path))
+ self._concurrency_safe_write(item, full_path, numpy_pickle.dump)
+ if verbose > 1:
+ print(f"[Memory] Storing {path} in {self.location}")
+ except Exception as e:
+ raise CacheWarning(f"Error while dumping item: {e}")
def clear_item(self, call_id):
"""Clear the item at the id, given as a list of str."""
- pass
+ path = os.path.join(*call_id)
+ full_path = os.path.join(self.location, path)
+ if self._item_exists(full_path):
+ os.remove(full_path)
def contains_item(self, call_id):
"""Check if there is an item at the id, given as a list of str."""
- pass
+ path = os.path.join(*call_id)
+ full_path = os.path.join(self.location, path)
+ return self._item_exists(full_path)
def get_item_info(self, call_id):
"""Return information about item."""
- pass
+ path = os.path.join(*call_id)
+ full_path = os.path.join(self.location, path)
+ if not self._item_exists(full_path):
+ raise KeyError(f"No item with path {path} in the store")
+
+ stats = os.stat(full_path)
+ return CacheItemInfo(path=full_path,
+ size=stats.st_size,
+ last_access=datetime.datetime.fromtimestamp(stats.st_atime))
def get_metadata(self, call_id):
"""Return actual metadata of an item."""
- pass
+ path = os.path.join(*call_id)
+ metadata_path = os.path.join(self.location, path + '.metadata')
+ if not self._item_exists(metadata_path):
+ return None
+ with self._open_item(metadata_path, 'r') as f:
+ return json.load(f)
def store_metadata(self, call_id, metadata):
"""Store metadata of a computation."""
- pass
+ path = os.path.join(*call_id)
+ metadata_path = os.path.join(self.location, path + '.metadata')
+ mkdirp(os.path.dirname(metadata_path))
+ self._concurrency_safe_write(metadata, metadata_path, json.dump)
def contains_path(self, call_id):
"""Check cached function is available in store."""
- pass
+ path = os.path.join(*call_id)
+ full_path = os.path.join(self.location, path)
+ return os.path.exists(full_path)
def clear_path(self, call_id):
"""Clear all items with a common path in the store."""
- pass
+ path = os.path.join(*call_id)
+ full_path = os.path.join(self.location, path)
+ if os.path.exists(full_path):
+ if os.path.isdir(full_path):
+ shutil.rmtree(full_path)
+ else:
+ os.remove(full_path)
def store_cached_func_code(self, call_id, func_code=None):
"""Store the code of the cached function."""
- pass
+ if func_code is not None:
+ path = os.path.join(*call_id)
+ func_code_path = os.path.join(self.location, path + '.py')
+ mkdirp(os.path.dirname(func_code_path))
+ self._concurrency_safe_write(func_code, func_code_path, lambda x, f: f.write(x))
def get_cached_func_code(self, call_id):
- """Store the code of the cached function."""
- pass
+ """Get the code of the cached function."""
+ path = os.path.join(*call_id)
+ func_code_path = os.path.join(self.location, path + '.py')
+ if not self._item_exists(func_code_path):
+ return None
+ with self._open_item(func_code_path, 'r') as f:
+ return f.read()
def get_cached_func_info(self, call_id):
"""Return information related to the cached function if it exists."""
- pass
+ path = os.path.join(*call_id)
+ full_path = os.path.join(self.location, path)
+ if not os.path.exists(full_path):
+ return None
+
+ func_code = self.get_cached_func_code(call_id)
+ metadata = self.get_metadata(call_id)
+
+ return {
+ 'path': full_path,
+ 'func_code': func_code,
+ 'metadata': metadata
+ }
def clear(self):
"""Clear the whole store content."""
- pass
+ if os.path.exists(self.location):
+ shutil.rmtree(self.location)
+ self.create_location(self.location)
- def enforce_store_limits(self, bytes_limit, items_limit=None, age_limit
- =None):
+ def enforce_store_limits(self, bytes_limit, items_limit=None, age_limit=None):
"""
Remove the store's oldest files to enforce item, byte, and age limits.
"""
- pass
+ items_to_delete = self._get_items_to_delete(bytes_limit, items_limit, age_limit)
+ for item in items_to_delete:
+ os.remove(item.path)
- def _get_items_to_delete(self, bytes_limit, items_limit=None, age_limit
- =None):
+ def _get_items_to_delete(self, bytes_limit, items_limit=None, age_limit=None):
"""
Get items to delete to keep the store under size, file, & age limits.
"""
- pass
+ items = []
+ total_size = 0
+ for dirpath, dirnames, filenames in os.walk(self.location):
+ for filename in filenames:
+ full_path = os.path.join(dirpath, filename)
+ stats = os.stat(full_path)
+ items.append(CacheItemInfo(path=full_path,
+ size=stats.st_size,
+ last_access=datetime.datetime.fromtimestamp(stats.st_atime)))
+ total_size += stats.st_size
+
+ items.sort(key=lambda x: x.last_access)
+ items_to_delete = []
+
+ now = datetime.datetime.now()
+
+        while items and (
+                (items_limit is not None and len(items) > items_limit) or
+                (bytes_limit is not None and total_size > bytes_limit) or
+                (age_limit is not None and
+                 (now - items[0].last_access).total_seconds() > age_limit)):
+ item = items.pop(0)
+ items_to_delete.append(item)
+ total_size -= item.size
+
+ return items_to_delete
def _concurrency_safe_write(self, to_write, filename, write_func):
"""Writes an object into a file in a concurrency-safe way."""
- pass
+        import uuid
+
+        # Unique per-writer temporary name keeps the write concurrency-safe
+        # even when several processes dump the same key at once.
+        temporary_filename = f"{filename}.tmp-{os.getpid()}-{uuid.uuid4().hex}"
+ try:
+ with open(temporary_filename, 'wb') as f:
+ write_func(to_write, f)
+ self._move_item(temporary_filename, filename)
+ except:
+ if os.path.exists(temporary_filename):
+ os.unlink(temporary_filename)
+ raise
def __repr__(self):
"""Printable representation of the store location."""
diff --git a/joblib/_utils.py b/joblib/_utils.py
index d2feff7..8665df1 100644
--- a/joblib/_utils.py
+++ b/joblib/_utils.py
@@ -18,7 +18,17 @@ def eval_expr(expr):
>>> eval_expr('1 + 2*3**(4) / (6 + -7)')
-161.0
"""
- pass
+ def eval_(node):
+ if isinstance(node, ast.Num):
+ return node.n
+ elif isinstance(node, ast.BinOp):
+ return operators[type(node.op)](eval_(node.left), eval_(node.right))
+ elif isinstance(node, ast.UnaryOp):
+ return operators[type(node.op)](eval_(node.operand))
+ else:
+ raise TypeError(node)
+
+ return eval_(ast.parse(expr, mode='eval').body)
@dataclass(frozen=True)
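
The eval_expr fill-in relies on a module-level `operators` mapping from AST operator node types to functions that is not shown in the hunk. A plausible self-contained version, assuming the usual operator-module mapping (and using ast.Constant rather than the deprecated ast.Num):

import ast
import operator as op

operators = {
    ast.Add: op.add, ast.Sub: op.sub, ast.Mult: op.mul,
    ast.Div: op.truediv, ast.Pow: op.pow, ast.USub: op.neg,
}


def eval_expr(expr):
    def eval_(node):
        if isinstance(node, ast.Constant):   # literal numbers
            return node.value
        if isinstance(node, ast.BinOp):
            return operators[type(node.op)](eval_(node.left), eval_(node.right))
        if isinstance(node, ast.UnaryOp):
            return operators[type(node.op)](eval_(node.operand))
        raise TypeError(node)

    return eval_(ast.parse(expr, mode='eval').body)


print(eval_expr('2*6 + 7'))                    # 19
print(eval_expr('1 + 2*3**(4) / (6 + -7)'))    # -161.0
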
diff --git a/joblib/backports.py b/joblib/backports.py
index b7178c8..458b5a3 100644
--- a/joblib/backports.py
+++ b/joblib/backports.py
@@ -95,7 +95,22 @@ try:
newly-created memmap that sends a maybe_unlink request for the
memmaped file to resource_tracker.
"""
- pass
+ mm = np.memmap(filename, dtype=dtype, mode=mode, offset=offset,
+ shape=shape, order=order)
+
+ if unlink_on_gc_collect:
+ from joblib.disk import delete_folder
+ import weakref
+
+ def cleanup(path):
+ try:
+ delete_folder(path)
+ except OSError:
+ pass
+
+ weakref.finalize(mm, cleanup, filename)
+
+ return mm
except ImportError:
if os.name == 'nt':
access_denied_errors = 5, 13
@@ -107,6 +122,13 @@ if os.name == 'nt':
On Windows os.replace can yield permission errors if executed by two
different processes.
"""
- pass
+ for i in range(10): # Try up to 10 times
+ try:
+ return replace(src, dst)
+ except PermissionError as e:
+ if e.winerror not in access_denied_errors:
+ raise
+ time.sleep(0.1 * (2 ** i)) # Exponential backoff
+ raise PermissionError(f"Failed to rename {src} to {dst} after 10 attempts")
else:
from os import replace as concurrency_safe_rename
diff --git a/joblib/compressor.py b/joblib/compressor.py
index 7a72b91..cd61e06 100644
--- a/joblib/compressor.py
+++ b/joblib/compressor.py
@@ -42,7 +42,10 @@ def register_compressor(compressor_name, compressor, force=False):
compressor: CompressorWrapper
An instance of a 'CompressorWrapper'.
"""
- pass
+ if compressor_name in _COMPRESSORS and not force:
+ raise ValueError(f"Compressor '{compressor_name}' already registered. "
+ "Use force=True to override.")
+ _COMPRESSORS[compressor_name] = compressor
class CompressorWrapper:
@@ -68,11 +71,17 @@ class CompressorWrapper:
def compressor_file(self, fileobj, compresslevel=None):
"""Returns an instance of a compressor file object."""
- pass
+ if self.fileobj_factory is None:
+ raise ValueError("bz2 module is not available")
+ if compresslevel is None:
+ compresslevel = 9
+ return self.fileobj_factory(fileobj, 'wb', compresslevel=compresslevel)
def decompressor_file(self, fileobj):
"""Returns an instance of a decompressor file object."""
- pass
+ if self.fileobj_factory is None:
+ raise ValueError("bz2 module is not available")
+ return self.fileobj_factory(fileobj, 'rb')
class BZ2CompressorWrapper(CompressorWrapper):
@@ -87,11 +96,15 @@ class BZ2CompressorWrapper(CompressorWrapper):
def compressor_file(self, fileobj, compresslevel=None):
"""Returns an instance of a compressor file object."""
- pass
+ if self.fileobj_factory is None:
+ raise ValueError("lzma module is not available")
+ return self.fileobj_factory(fileobj, mode='wb', format=self._lzma_format)
def decompressor_file(self, fileobj):
"""Returns an instance of a decompressor file object."""
- pass
+ if self.fileobj_factory is None:
+ raise ValueError("lzma module is not available")
+ return self.fileobj_factory(fileobj, mode='rb', format=self._lzma_format)
class LZMACompressorWrapper(CompressorWrapper):
@@ -108,11 +121,15 @@ class LZMACompressorWrapper(CompressorWrapper):
def compressor_file(self, fileobj, compresslevel=None):
"""Returns an instance of a compressor file object."""
- pass
+ if self.fileobj_factory is None:
+ raise ValueError(LZ4_NOT_INSTALLED_ERROR)
+ return self.fileobj_factory(fileobj, mode='wb', compression_level=compresslevel)
def decompressor_file(self, fileobj):
"""Returns an instance of a decompressor file object."""
- pass
+ if self.fileobj_factory is None:
+ raise ValueError(LZ4_NOT_INSTALLED_ERROR)
+ return self.fileobj_factory(fileobj, mode='rb')
class XZCompressorWrapper(LZMACompressorWrapper):
@@ -210,28 +227,47 @@ class BinaryZlibFile(io.BufferedIOBase):
May be called more than once without error. Once the file is
closed, any other operation on it will raise a ValueError.
"""
- pass
+ with self._lock:
+ if self._mode == _MODE_CLOSED:
+ return
+ try:
+ if self._mode in (_MODE_READ, _MODE_READ_EOF):
+ self._decompressor = None
+ elif self._mode == _MODE_WRITE:
+ self._fp.write(self._compressor.flush())
+ self._compressor = None
+ finally:
+ try:
+ if self._closefp:
+ self._fp.close()
+ finally:
+ self._fp = None
+ self._closefp = False
+ self._mode = _MODE_CLOSED
@property
def closed(self):
"""True if this file is closed."""
- pass
+ return self._mode == _MODE_CLOSED
def fileno(self):
"""Return the file descriptor for the underlying file."""
- pass
+ self._check_not_closed()
+ return self._fp.fileno()
def seekable(self):
"""Return whether the file supports seeking."""
- pass
+ return self.readable() and self._fp.seekable()
def readable(self):
"""Return whether the file was opened for reading."""
- pass
+ self._check_not_closed()
+ return self._mode in (_MODE_READ, _MODE_READ_EOF)
def writable(self):
"""Return whether the file was opened for writing."""
- pass
+ self._check_not_closed()
+ return self._mode == _MODE_WRITE
def read(self, size=-1):
"""Read up to size uncompressed bytes from the file.
@@ -239,14 +275,27 @@ class BinaryZlibFile(io.BufferedIOBase):
If size is negative or omitted, read until EOF is reached.
Returns b'' if the file is already at EOF.
"""
- pass
+ with self._lock:
+ self._check_can_read()
+ if size == 0:
+ return b""
+
+ if self._mode == _MODE_READ_EOF or size < 0:
+ return self._read_all()
+
+ return self._read_limited(size)
def readinto(self, b):
"""Read up to len(b) bytes into b.
Returns the number of bytes read (0 for EOF).
"""
- pass
+ with self._lock:
+ self._check_can_read()
+ data = self.read(len(b))
+ n = len(data)
+ b[:n] = data
+ return n
def write(self, data):
"""Write a byte string to the file.
@@ -255,7 +304,11 @@ class BinaryZlibFile(io.BufferedIOBase):
always len(data). Note that due to buffering, the file on disk
may not reflect the data written until close() is called.
"""
- pass
+ with self._lock:
+ self._check_can_write()
+ compressed = self._compressor.compress(data)
+ self._fp.write(compressed)
+ return len(data)
def seek(self, offset, whence=0):
"""Change the file position.
@@ -272,11 +325,102 @@ class BinaryZlibFile(io.BufferedIOBase):
Note that seeking is emulated, so depending on the parameters,
this operation may be extremely slow.
"""
- pass
+ with self._lock:
+ self._check_can_seek()
+
+ if whence == 0:
+ if offset < 0:
+ raise ValueError("Negative seek position {}".format(offset))
+ return self._seek_forward(offset)
+ elif whence == 1:
+ return self._seek_forward(offset)
+ elif whence == 2:
+ if offset > 0:
+ raise ValueError("Positive seek position {}".format(offset))
+ return self._seek_backward(offset)
+ else:
+ raise ValueError("Invalid whence value")
def tell(self):
"""Return the current file position."""
- pass
+ with self._lock:
+ self._check_not_closed()
+ return self._pos
+
+ def _check_not_closed(self):
+ if self.closed:
+ raise ValueError("I/O operation on closed file")
+
+ def _check_can_read(self):
+ if self._mode not in (_MODE_READ, _MODE_READ_EOF):
+ raise io.UnsupportedOperation("File not open for reading")
+
+ def _check_can_write(self):
+ if self._mode != _MODE_WRITE:
+ raise io.UnsupportedOperation("File not open for writing")
+
+ def _check_can_seek(self):
+ if self._mode not in (_MODE_READ, _MODE_READ_EOF):
+ raise io.UnsupportedOperation("Seeking is only supported on files open for reading")
+ if not self._fp.seekable():
+ raise io.UnsupportedOperation("The underlying file object does not support seeking")
+
+ def _read_all(self):
+ chunks = []
+ while True:
+ chunk = self._fp.read(_BUFFER_SIZE)
+ if not chunk:
+ break
+ decompressed = self._decompressor.decompress(chunk)
+ if decompressed:
+ chunks.append(decompressed)
+ if self._decompressor.unused_data:
+ self._fp.seek(-len(self._decompressor.unused_data), 1)
+ self._mode = _MODE_READ_EOF
+ self._pos += sum(len(chunk) for chunk in chunks)
+ return b"".join(chunks)
+
+ def _read_limited(self, size):
+ chunks = []
+ while size > 0:
+ chunk = self._fp.read(min(_BUFFER_SIZE, size))
+ if not chunk:
+ break
+ decompressed = self._decompressor.decompress(chunk)
+ if decompressed:
+ chunks.append(decompressed)
+ size -= len(decompressed)
+ if self._decompressor.unused_data:
+ self._fp.seek(-len(self._decompressor.unused_data), 1)
+ self._pos += sum(len(chunk) for chunk in chunks)
+ return b"".join(chunks)
+
+ def _seek_forward(self, offset):
+ if offset < self._pos:
+ self._fp.seek(0)
+ self._decompressor = zlib.decompressobj(self.wbits)
+ self._pos = 0
+ else:
+ offset -= self._pos
+ while offset > 0:
+ chunk = self.read(min(_BUFFER_SIZE, offset))
+ if not chunk:
+ break
+ offset -= len(chunk)
+ return self._pos
+
+ def _seek_backward(self, offset):
+ if offset:
+ self._fp.seek(0)
+ self._decompressor = zlib.decompressobj(self.wbits)
+ self._pos = 0
+ while self._pos > offset:
+ chunk = self._fp.read(min(_BUFFER_SIZE, self._pos - offset))
+ if not chunk:
+ break
+ decompressed = self._decompressor.decompress(chunk)
+ self._pos += len(decompressed)
+ return self._pos
class ZlibCompressorWrapper(CompressorWrapper):
diff --git a/joblib/disk.py b/joblib/disk.py
index b35e507..707fcbf 100644
--- a/joblib/disk.py
+++ b/joblib/disk.py
@@ -15,20 +15,39 @@ except NameError:
def disk_used(path):
""" Return the disk usage in a directory."""
- pass
+ total_size = 0
+ for dirpath, dirnames, filenames in os.walk(path):
+ for f in filenames:
+ fp = os.path.join(dirpath, f)
+ total_size += os.path.getsize(fp)
+ return total_size
def memstr_to_bytes(text):
""" Convert a memory text to its value in bytes.
"""
- pass
+ units = {
+ 'K': 1024,
+ 'M': 1024 ** 2,
+ 'G': 1024 ** 3,
+ 'T': 1024 ** 4,
+ }
+ text = text.upper().strip()
+ if text[-1] in units:
+ return int(float(text[:-1]) * units[text[-1]])
+ else:
+ return int(float(text))
def mkdirp(d):
"""Ensure directory d exists (like mkdir -p on Unix)
No guarantee that the directory is writable.
"""
- pass
+ try:
+ os.makedirs(d)
+ except OSError as e:
+ if e.errno != errno.EEXIST:
+ raise
RM_SUBDIRS_RETRY_TIME = 0.1
@@ -47,9 +66,29 @@ def rm_subdirs(path, onerror=None):
exc_info is a tuple returned by sys.exc_info(). If onerror is None,
an exception is raised.
"""
- pass
+ for root, dirs, files in os.walk(path, topdown=False):
+ for name in dirs:
+ fullname = os.path.join(root, name)
+ try:
+ shutil.rmtree(fullname)
+ except Exception:
+ if onerror is not None:
+ onerror(os.rmdir, fullname, sys.exc_info())
+ else:
+ raise
def delete_folder(folder_path, onerror=None, allow_non_empty=True):
"""Utility function to cleanup a temporary folder if it still exists."""
- pass
+ if os.path.exists(folder_path):
+ if allow_non_empty:
+ shutil.rmtree(folder_path, onerror=onerror)
+ else:
+ try:
+ os.rmdir(folder_path)
+ except OSError as e:
+ if e.errno != errno.ENOTEMPTY:
+ if onerror is not None:
+ onerror(os.rmdir, folder_path, sys.exc_info())
+ else:
+ raise
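
A quick sanity check of the unit handling implemented by memstr_to_bytes above, assuming binary units (1K == 1024 bytes) as in the fill-in; the local copy below exists only so the snippet runs on its own.

def memstr_to_bytes(text):
    units = {'K': 1024, 'M': 1024 ** 2, 'G': 1024 ** 3, 'T': 1024 ** 4}
    text = text.upper().strip()
    if text[-1] in units:
        return int(float(text[:-1]) * units[text[-1]])
    return int(float(text))


assert memstr_to_bytes('64M') == 64 * 1024 ** 2
assert memstr_to_bytes('1.5G') == int(1.5 * 1024 ** 3)
assert memstr_to_bytes('2048') == 2048
print('unit parsing OK')
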
diff --git a/joblib/executor.py b/joblib/executor.py
index 6eea29c..9b822a7 100644
--- a/joblib/executor.py
+++ b/joblib/executor.py
@@ -19,7 +19,22 @@ class MemmappingExecutor(_ReusablePoolExecutor):
"""Factory for ReusableExecutor with automatic memmapping for large
numpy arrays.
"""
- pass
+ global _executor_args
+ if _executor_args is None:
+ reducers = get_memmapping_reducers(temp_folder, context_id)
+ _executor_args = dict(
+ timeout=timeout,
+ initializer=initializer,
+ initargs=initargs,
+ env=env,
+ reducers=reducers,
+ **backend_args
+ )
+
+ executor = cls(n_jobs, **_executor_args)
+ executor.reduce_call = _executor_args['reducers']['reduce_call']
+ executor.temp_folder_manager = TemporaryResourcesManager(temp_folder)
+ return executor
class _TestingMemmappingExecutor(MemmappingExecutor):
@@ -30,4 +45,4 @@ class _TestingMemmappingExecutor(MemmappingExecutor):
def apply_async(self, func, args):
"""Schedule a func to be run"""
- pass
+ return super().apply_async(self.reduce_call(func), args)
diff --git a/joblib/externals/cloudpickle/cloudpickle.py b/joblib/externals/cloudpickle/cloudpickle.py
index 92fb769..24fa2a3 100644
--- a/joblib/externals/cloudpickle/cloudpickle.py
+++ b/joblib/externals/cloudpickle/cloudpickle.py
@@ -104,12 +104,14 @@ def register_pickle_by_value(module):
Note: this feature is considered experimental. See the cloudpickle
README.md file for more details and limitations.
"""
- pass
+ global _PICKLE_BY_VALUE_MODULES
+ _PICKLE_BY_VALUE_MODULES.add(module)
def unregister_pickle_by_value(module):
"""Unregister that the input module should be pickled by value."""
- pass
+ global _PICKLE_BY_VALUE_MODULES
+ _PICKLE_BY_VALUE_MODULES.discard(module)
def _whichmodule(obj, name):
@@ -121,7 +123,20 @@ def _whichmodule(obj, name):
- Errors arising during module introspection are ignored, as those errors
are considered unwanted side effects.
"""
- pass
+ if isinstance(obj, types.ModuleType):
+ return obj.__name__
+ module_name = getattr(obj, '__module__', None)
+ if module_name is not None:
+ return module_name
+ for module_name, module in sys.modules.items():
+ if module_name == '__main__':
+ continue
+ try:
+ if getattr(module, name, None) is obj:
+ return module_name
+ except Exception:
+ pass
+ return None
def _should_pickle_by_reference(obj, name=None):
@@ -138,7 +153,19 @@ def _should_pickle_by_reference(obj, name=None):
functions and classes or for attributes of modules that have been
explicitly registered to be pickled by value.
"""
- pass
+ if name is None:
+ name = getattr(obj, '__qualname__', None)
+ if name is None:
+ name = getattr(obj, '__name__', None)
+
+ module = _whichmodule(obj, name)
+ if module is None:
+ return False
+ if module in _PICKLE_BY_VALUE_MODULES:
+ return False
+ if module in ('__main__', '__mp_main__'):
+ return False
+ return True
def _extract_code_globals(co):
diff --git a/joblib/externals/loky/_base.py b/joblib/externals/loky/_base.py
index 6d789c8..6400c8c 100644
--- a/joblib/externals/loky/_base.py
+++ b/joblib/externals/loky/_base.py
@@ -1,6 +1,107 @@
from concurrent.futures import Future as _BaseFuture
-from concurrent.futures._base import LOGGER
+from concurrent.futures import CancelledError, TimeoutError
+from concurrent.futures._base import (
+    LOGGER, PENDING, RUNNING, CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED)
+import threading
class Future(_BaseFuture):
- pass
+ def __init__(self):
+ self._condition = threading.Condition()
+ self._state = PENDING
+ self._result = None
+ self._exception = None
+ self._waiters = []
+ self._done_callbacks = []
+
+ def cancel(self):
+ with self._condition:
+ if self._state in [RUNNING, FINISHED]:
+ return False
+ if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:
+ return True
+ self._state = CANCELLED
+ self._condition.notify_all()
+ self._invoke_callbacks()
+ return True
+
+ def cancelled(self):
+ with self._condition:
+ return self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]
+
+ def running(self):
+ with self._condition:
+ return self._state == RUNNING
+
+ def done(self):
+ with self._condition:
+ return self._state in [CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED]
+
+ def result(self, timeout=None):
+ with self._condition:
+ if self._state == CANCELLED:
+ raise CancelledError()
+ elif self._state == FINISHED:
+ return self._result
+ self._condition.wait(timeout)
+ if self._state == CANCELLED:
+ raise CancelledError()
+ elif self._state == FINISHED:
+ return self._result
+ else:
+ raise TimeoutError()
+
+ def exception(self, timeout=None):
+ with self._condition:
+ if self._state == CANCELLED:
+ raise CancelledError()
+ elif self._state == FINISHED:
+ return self._exception
+ self._condition.wait(timeout)
+ if self._state == CANCELLED:
+ raise CancelledError()
+ elif self._state == FINISHED:
+ return self._exception
+ else:
+ raise TimeoutError()
+
+ def add_done_callback(self, fn):
+ with self._condition:
+ if self._state not in [CANCELLED, FINISHED]:
+ self._done_callbacks.append(fn)
+ return
+ fn(self)
+
+ def _invoke_callbacks(self):
+ for callback in self._done_callbacks:
+ try:
+ callback(self)
+ except Exception:
+ LOGGER.exception('exception calling callback for %r', self)
+
+ def set_running_or_notify_cancel(self):
+ with self._condition:
+ if self._state == CANCELLED:
+ self._state = CANCELLED_AND_NOTIFIED
+ return False
+ elif self._state == PENDING:
+ self._state = RUNNING
+ return True
+ else:
+ raise RuntimeError('Future in unexpected state')
+
+ def set_result(self, result):
+ with self._condition:
+ self._result = result
+ self._state = FINISHED
+ for waiter in self._waiters:
+ waiter.add_result(self)
+ self._condition.notify_all()
+ self._invoke_callbacks()
+
+ def set_exception(self, exception):
+ with self._condition:
+ self._exception = exception
+ self._state = FINISHED
+ for waiter in self._waiters:
+ waiter.add_exception(self)
+ self._condition.notify_all()
+ self._invoke_callbacks()
diff --git a/joblib/externals/loky/backend/_posix_reduction.py b/joblib/externals/loky/backend/_posix_reduction.py
index c819d41..25abb3b 100644
--- a/joblib/externals/loky/backend/_posix_reduction.py
+++ b/joblib/externals/loky/backend/_posix_reduction.py
@@ -10,7 +10,7 @@ HAVE_SEND_HANDLE = hasattr(socket, 'CMSG_LEN') and hasattr(socket, 'SCM_RIGHTS'
def DupFd(fd):
"""Return a wrapper for an fd."""
- pass
+ return os.dup(fd)
register(socket.socket, _reduce_socket)
diff --git a/joblib/externals/loky/backend/context.py b/joblib/externals/loky/backend/context.py
index 1e0d413..1a1925e 100644
--- a/joblib/externals/loky/backend/context.py
+++ b/joblib/externals/loky/backend/context.py
@@ -47,12 +47,41 @@ def cpu_count(only_physical_cores=False):
It is also always larger or equal to 1.
"""
- pass
+ global physical_cores_cache
+
+ if only_physical_cores and physical_cores_cache is None:
+ physical_cores_cache, _ = _count_physical_cores()
+
+ if only_physical_cores and physical_cores_cache != "not found":
+ return max(1, physical_cores_cache)
+
+ os_cpu_count = mp.cpu_count()
+ cpu_count_user = _cpu_count_user(os_cpu_count)
+
+ if sys.platform == 'win32':
+ cpu_count = min(cpu_count_user, _MAX_WINDOWS_WORKERS)
+ else:
+ cpu_count = cpu_count_user
+
+ return max(1, cpu_count)
def _cpu_count_user(os_cpu_count):
"""Number of user defined available CPUs"""
- pass
+ cpu_count_user = os.environ.get('LOKY_MAX_CPU_COUNT')
+ if cpu_count_user is not None:
+ try:
+ cpu_count_user = int(cpu_count_user)
+ except ValueError:
+ warnings.warn(
+ f"LOKY_MAX_CPU_COUNT environment variable cannot be parsed as "
+ f"an integer: {cpu_count_user!r}. Using {os_cpu_count} as the "
+ f"number of CPUs.", UserWarning)
+ cpu_count_user = os_cpu_count
+ else:
+ cpu_count_user = os_cpu_count
+
+ return min(cpu_count_user, os_cpu_count)
def _count_physical_cores():
@@ -63,7 +92,58 @@ def _count_physical_cores():
The number of physical cores is cached to avoid repeating subprocess calls.
"""
- pass
+ global physical_cores_cache
+
+ if physical_cores_cache is not None:
+ return physical_cores_cache, None
+
+ try:
+ if sys.platform == 'linux':
+ cores = subprocess.check_output(
+ ['lscpu', '-p=Core,Socket']).decode('utf-8')
+            # Use tuples: lists are unhashable and cannot go in a set.
+            cores = {tuple(line.split(',')[:2]) for line in cores.splitlines()
+                     if not line.startswith('#')}
+ physical_cores_cache = len(cores)
+ elif sys.platform == 'darwin':
+ physical_cores_cache = int(subprocess.check_output(
+ ['sysctl', '-n', 'hw.physicalcpu']).decode('utf-8').strip())
+ elif sys.platform == 'win32':
+ import ctypes
+ import ctypes.wintypes
+
+ DWORD = ctypes.wintypes.DWORD
+ WORD = ctypes.wintypes.WORD
+
+ class SYSTEM_LOGICAL_PROCESSOR_INFORMATION(ctypes.Structure):
+ _fields_ = [
+ ('ProcessorMask', ctypes.c_void_p),
+ ('Relationship', DWORD),
+ ('_', ctypes.c_ulonglong)
+ ]
+
+ buffer = ctypes.create_string_buffer(1)
+ size = DWORD(ctypes.sizeof(buffer))
+ while True:
+ if ctypes.windll.kernel32.GetLogicalProcessorInformation(
+ ctypes.byref(buffer), ctypes.byref(size)):
+ break
+ buffer = ctypes.create_string_buffer(size.value)
+
+ system_info = (SYSTEM_LOGICAL_PROCESSOR_INFORMATION *
+ (size.value // ctypes.sizeof(
+ SYSTEM_LOGICAL_PROCESSOR_INFORMATION))).\
+ from_buffer_copy(buffer)
+ physical_cores_cache = sum(info.Relationship == 1
+ for info in system_info)
+ else:
+ physical_cores_cache = "not found"
+ raise NotImplementedError(
+ "Counting physical cores not implemented for "
+ "this platform")
+ return physical_cores_cache, None
+ except Exception as e:
+ physical_cores_cache = "not found"
+ return "not found", e
class LokyContext(BaseContext):
@@ -74,11 +154,13 @@ class LokyContext(BaseContext):
def Queue(self, maxsize=0, reducers=None):
"""Returns a queue object"""
- pass
+ from .queues import Queue
+ return Queue(maxsize, reducers=reducers, ctx=self.get_context())
def SimpleQueue(self, reducers=None):
"""Returns a queue object"""
- pass
+ from .queues import SimpleQueue
+ return SimpleQueue(reducers=reducers, ctx=self.get_context())
if sys.platform != 'win32':
"""For Unix platform, use our custom implementation of synchronize
ensuring that we use the loky.backend.resource_tracker to clean-up
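
The physical-core counting above hinges on one parsing trick: each non-comment line of `lscpu -p=Core,Socket` is a "core,socket" pair, and the number of distinct pairs is the number of physical cores. A tiny offline illustration with fabricated lscpu output:

sample = """# The following is the parsable format, which can be fed to other
# programs. Each different item in every column has an unique ID
0,0
1,0
0,0
1,0"""

pairs = {tuple(line.split(',')[:2]) for line in sample.splitlines()
         if not line.startswith('#')}
print(len(pairs))  # 2 physical cores backing 4 logical CPUs
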
diff --git a/joblib/externals/loky/backend/fork_exec.py b/joblib/externals/loky/backend/fork_exec.py
index a8af34a..056bec8 100644
--- a/joblib/externals/loky/backend/fork_exec.py
+++ b/joblib/externals/loky/backend/fork_exec.py
@@ -4,4 +4,11 @@ import sys
def close_fds(keep_fds):
"""Close all the file descriptors except those in keep_fds."""
- pass
+ import resource
+ max_fd = resource.getrlimit(resource.RLIMIT_NOFILE)[0]
+ for fd in range(3, max_fd):
+ if fd not in keep_fds:
+ try:
+ os.close(fd)
+ except OSError:
+ pass
diff --git a/joblib/externals/loky/backend/popen_loky_win32.py b/joblib/externals/loky/backend/popen_loky_win32.py
index b174751..fa533f6 100644
--- a/joblib/externals/loky/backend/popen_loky_win32.py
+++ b/joblib/externals/loky/backend/popen_loky_win32.py
@@ -62,14 +62,28 @@ class Popen(_Popen):
def get_command_line(pipe_handle, parent_pid, **kwds):
"""Returns prefix of command line used for spawning a child process."""
- pass
+ cmd = [sys.executable, '-c',
+ 'from joblib.externals.loky.backend.popen_loky_win32 import main; '
+ 'main(%r, %r)' % (pipe_handle, parent_pid)]
+ return cmd
def is_forking(argv):
"""Return whether commandline indicates we are forking."""
- pass
+ return len(argv) >= 2 and argv[1] == '--multiprocessing-fork'
def main(pipe_handle, parent_pid=None):
"""Run code specified by data received over pipe."""
- pass
+ fd = msvcrt.open_osfhandle(pipe_handle, os.O_RDONLY)
+ with open(fd, 'rb') as from_parent:
+ process.current_process()._inheriting = True
+ try:
+ preparation_data = load(from_parent)
+ spawn.prepare(preparation_data)
+ self = load(from_parent)
+ finally:
+ del process.current_process()._inheriting
+
+ exitcode = self._bootstrap()
+ sys.exit(exitcode)
diff --git a/joblib/externals/loky/backend/queues.py b/joblib/externals/loky/backend/queues.py
index 704e2a3..7596205 100644
--- a/joblib/externals/loky/backend/queues.py
+++ b/joblib/externals/loky/backend/queues.py
@@ -36,7 +36,12 @@ class Queue(mp_Queue):
Private API hook called when feeding data in the background thread
raises an exception. For overriding by concurrent.futures.
"""
- pass
+ if isinstance(e, OSError) and e.errno == errno.EPIPE:
+ if self._ignore_epipe:
+ return
+ util.debug('Error in queue feeder thread: %s' % e)
+ # Notify the queue management thread about the error
+ self._thread_queue.put(None)
class SimpleQueue(mp_SimpleQueue):
diff --git a/joblib/externals/loky/backend/reduction.py b/joblib/externals/loky/backend/reduction.py
index 6770d67..3e797bb 100644
--- a/joblib/externals/loky/backend/reduction.py
+++ b/joblib/externals/loky/backend/reduction.py
@@ -36,7 +36,20 @@ set_loky_pickler()
def dump(obj, file, reducers=None, protocol=None):
"""Replacement for pickle.dump() using _LokyPickler."""
- pass
+ if protocol is None:
+ protocol = HIGHEST_PROTOCOL
+ if reducers is None:
+ reducers = {}
+
+ if isinstance(file, io.IOBase):
+ pickler = _LokyPickler(file, protocol=protocol)
+ pickler.dispatch_table.update(reducers)
+ pickler.dump(obj)
+ else:
+ with open(file, 'wb') as f:
+ pickler = _LokyPickler(f, protocol=protocol)
+ pickler.dispatch_table.update(reducers)
+ pickler.dump(obj)
__all__ = ['dump', 'dumps', 'loads', 'register', 'set_loky_pickler']
diff --git a/joblib/externals/loky/backend/resource_tracker.py b/joblib/externals/loky/backend/resource_tracker.py
index 3550ef5..f876b9e 100644
--- a/joblib/externals/loky/backend/resource_tracker.py
+++ b/joblib/externals/loky/backend/resource_tracker.py
@@ -32,23 +32,64 @@ class ResourceTracker:
This can be run from any process. Usually a child process will use
the resource created by its parent."""
- pass
+ with self._lock:
+ if self._fd is not None:
+ # resource tracker already started
+ return
+ fds_to_pass = []
+ cmd = [sys.executable, '-c', 'from joblib.externals.loky.backend.resource_tracker import main; main()']
+ r, w = os.pipe()
+ self._fd = r
+ try:
+ fds_to_pass.append(w)
+ # process will out live us, so no need to wait on pid
+ pid = spawn.spawn_main(cmd, fds_to_pass)
+ self._pid = pid
+ except:
+ os.close(r)
+ raise
+ finally:
+ os.close(w)
def _check_alive(self):
"""Check for the existence of the resource tracker process."""
- pass
+ if self._pid is None:
+ return False
+ try:
+ os.kill(self._pid, 0)
+ except OSError:
+ return False
+ return True
def register(self, name, rtype):
"""Register a named resource, and increment its refcount."""
- pass
+ self.ensure_running()
+ msg = f'{name}:{rtype}:REGISTER\n'.encode('ascii')
+ if len(name) > 512:
+ # Prevent overflow on the C side
+ raise ValueError("name too long")
+ with self._lock:
+ os.write(self._fd, msg)
def unregister(self, name, rtype):
"""Unregister a named resource with resource tracker."""
- pass
+ self.ensure_running()
+ msg = f'{name}:{rtype}:UNREGISTER\n'.encode('ascii')
+ if len(name) > 512:
+ # Prevent overflow on the C side
+ raise ValueError("name too long")
+ with self._lock:
+ os.write(self._fd, msg)
def maybe_unlink(self, name, rtype):
"""Decrement the refcount of a resource, and delete it if it hits 0"""
- pass
+ self.ensure_running()
+ msg = f'{name}:{rtype}:MAYBE_UNLINK\n'.encode('ascii')
+ if len(name) > 512:
+ # Prevent overflow on the C side
+ raise ValueError("name too long")
+ with self._lock:
+ os.write(self._fd, msg)
_resource_tracker = ResourceTracker()
@@ -59,6 +100,60 @@ unregister = _resource_tracker.unregister
getfd = _resource_tracker.getfd
-def main(fd, verbose=0):
+def main(fd=None, verbose=0):
"""Run resource tracker."""
- pass
+ global VERBOSE
+ VERBOSE = verbose
+
+ if fd is None:
+ fd = sys.stdin.fileno()
+
+ cache = {}
+ try:
+ while True:
+ msg = os.read(fd, 512)
+ if not msg:
+ break
+ try:
+ name, rtype, action = msg.decode('ascii').strip().split(':')
+ except ValueError:
+ continue
+
+ if action == 'REGISTER':
+ cache.setdefault(rtype, {}).setdefault(name, 0)
+ cache[rtype][name] += 1
+ elif action == 'UNREGISTER':
+ if name in cache.get(rtype, {}):
+ cache[rtype][name] -= 1
+ if cache[rtype][name] == 0:
+ del cache[rtype][name]
+ cleanup = _CLEANUP_FUNCS.get(rtype)
+ if cleanup:
+ try:
+ cleanup(name)
+ except Exception as e:
+ warnings.warn(f'Error cleaning up {name}: {e}')
+ elif action == 'MAYBE_UNLINK':
+ if name in cache.get(rtype, {}):
+ cache[rtype][name] -= 1
+ if cache[rtype][name] == 0:
+ del cache[rtype][name]
+ cleanup = _CLEANUP_FUNCS.get(rtype)
+ if cleanup:
+ try:
+ cleanup(name)
+ except Exception as e:
+ warnings.warn(f'Error cleaning up {name}: {e}')
+ else:
+ warnings.warn(f'Unrecognized action: {action}')
+ finally:
+ if sys.platform != 'win32':
+ signal.signal(signal.SIGTERM, signal.SIG_IGN)
+ for rtype, rdict in cache.items():
+ for name in rdict:
+ cleanup = _CLEANUP_FUNCS.get(rtype)
+ if cleanup:
+ try:
+ cleanup(name)
+ except Exception as e:
+ warnings.warn(f'Error cleaning up {name}: {e}')
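
The resource tracker main loop above is a refcount server over a line protocol of the form name:rtype:ACTION. A toy, in-process rendering of the bookkeeping (the cleanup callback and resource names below are made up):

cache = {}


def handle(msg, cleanup_funcs):
    name, rtype, action = msg.strip().split(':')
    counts = cache.setdefault(rtype, {})
    if action == 'REGISTER':
        counts[name] = counts.get(name, 0) + 1
    elif action in ('UNREGISTER', 'MAYBE_UNLINK'):
        counts[name] -= 1
        if counts[name] == 0:
            del counts[name]
            cleanup_funcs[rtype](name)


deleted = []
cleanups = {'file': deleted.append}
handle('a.mmap:file:REGISTER', cleanups)
handle('a.mmap:file:REGISTER', cleanups)
handle('a.mmap:file:MAYBE_UNLINK', cleanups)
handle('a.mmap:file:MAYBE_UNLINK', cleanups)
print(deleted)  # ['a.mmap'] -- cleaned up only when the refcount hits zero
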
diff --git a/joblib/externals/loky/backend/spawn.py b/joblib/externals/loky/backend/spawn.py
index aadb9e2..6653777 100644
--- a/joblib/externals/loky/backend/spawn.py
+++ b/joblib/externals/loky/backend/spawn.py
@@ -20,7 +20,29 @@ else:
def get_preparation_data(name, init_main_module=True):
"""Return info about parent needed by child to unpickle process object."""
- pass
+ d = {}
+ if init_main_module:
+ d['init_main_module'] = True
+
+ # Get sys.path and sys.argv
+ d['sys_path'] = sys.path
+ d['sys_argv'] = sys.argv
+
+ # Get the main module's __spec__
+ main_module = sys.modules['__main__']
+ if hasattr(main_module, '__spec__'):
+ d['main_module_spec'] = main_module.__spec__
+
+ # Get the current working directory
+ d['cwd'] = os.getcwd()
+
+ # Get environment variables
+ d['env'] = dict(os.environ)
+
+ # Get the process name
+ d['name'] = name
+
+ return d
old_main_modules = []
@@ -28,4 +50,33 @@ old_main_modules = []
def prepare(data, parent_sentinel=None):
"""Try to get current process ready to unpickle process object."""
- pass
+ if 'init_main_module' in data and data['init_main_module']:
+ # Set up the main module
+ runpy.run_module(sys.modules['__main__'].__spec__.name,
+ run_name='__mp_main__', alter_sys=True)
+
+ # Update sys.path
+ sys.path = data.get('sys_path', sys.path)
+
+ # Update sys.argv
+ sys.argv = data.get('sys_argv', sys.argv)
+
+ # Change the current working directory
+ os.chdir(data.get('cwd', os.getcwd()))
+
+ # Update environment variables
+ os.environ.update(data.get('env', {}))
+
+ # Set the process name
+ util.set_process_name(data.get('name', 'loky_process'))
+
+ # Handle the parent sentinel (Windows-specific)
+ if parent_sentinel is not None and sys.platform == 'win32':
+ parent_sentinel = duplicate(parent_sentinel, inheritable=True)
+ process.current_process()._parent_pid = os.getppid()
+ process.current_process()._parent_sentinel = parent_sentinel
+
+ # Clear and update old_main_modules
+ global old_main_modules
+ old_main_modules.clear()
+ old_main_modules.append(sys.modules['__main__'])
diff --git a/joblib/externals/loky/backend/utils.py b/joblib/externals/loky/backend/utils.py
index 3a17859..89c0054 100644
--- a/joblib/externals/loky/backend/utils.py
+++ b/joblib/externals/loky/backend/utils.py
@@ -14,17 +14,37 @@ except ImportError:
def kill_process_tree(process, use_psutil=True):
"""Terminate process and its descendants with SIGKILL"""
- pass
+ if use_psutil and psutil is not None:
+ try:
+ parent = psutil.Process(process.pid)
+ children = parent.children(recursive=True)
+ for child in children:
+ child.kill()
+ parent.kill()
+ except psutil.NoSuchProcess:
+ pass
+ else:
+ _kill_process_tree_without_psutil(process)
def _kill_process_tree_without_psutil(process):
"""Terminate a process and its descendants."""
- pass
+ if sys.platform != 'win32':
+ _posix_recursive_kill(process.pid)
+ else:
+ subprocess.call(['taskkill', '/F', '/T', '/PID', str(process.pid)])
def _posix_recursive_kill(pid):
"""Recursively kill the descendants of a process before killing it."""
- pass
+    # This is the fallback used when psutil is not available, so rely on
+    # pgrep + SIGKILL instead of psutil.
+    try:
+        children = subprocess.check_output(
+            ['pgrep', '-P', str(pid)], text=True).split()
+    except subprocess.CalledProcessError:
+        children = []
+    for child_pid in children:
+        _posix_recursive_kill(int(child_pid))
+    try:
+        os.kill(pid, signal.SIGKILL)
+    except OSError:
+        pass
def get_exitcodes_terminated_worker(processes):
@@ -33,9 +53,33 @@ def get_exitcodes_terminated_worker(processes):
If necessary, wait (up to .25s) for the system to correctly set the
exitcode of one terminated worker.
"""
- pass
+ exitcodes = []
+ for process in processes:
+ try:
+ exitcode = process.exitcode
+ if exitcode is None:
+ process.join(timeout=0.25)
+ exitcode = process.exitcode
+ if exitcode is not None:
+ exitcodes.append(exitcode)
+ except Exception:
+ pass
+
+ if exitcodes:
+ return _format_exitcodes(exitcodes)
+ return ""
def _format_exitcodes(exitcodes):
"""Format a list of exit code with names of the signals if possible"""
- pass
+ formatted_exitcodes = []
+ for exitcode in exitcodes:
+ if exitcode < 0:
+ try:
+ sig_name = signal.Signals(-exitcode).name
+ formatted_exitcodes.append(f"{exitcode} ({sig_name})")
+ except ValueError:
+ formatted_exitcodes.append(str(exitcode))
+ else:
+ formatted_exitcodes.append(str(exitcode))
+ return ", ".join(formatted_exitcodes)
diff --git a/joblib/externals/loky/cloudpickle_wrapper.py b/joblib/externals/loky/cloudpickle_wrapper.py
index 808ade4..75466fc 100644
--- a/joblib/externals/loky/cloudpickle_wrapper.py
+++ b/joblib/externals/loky/cloudpickle_wrapper.py
@@ -37,4 +37,16 @@ def wrap_non_picklable_objects(obj, keep_wrapper=True):
objects in the main scripts and to implement __reduce__ functions for
complex classes.
"""
- pass
+ if isinstance(obj, partial):
+ # Handle partial functions
+ return obj
+ elif callable(obj):
+ # Handle callable objects
+ if obj in WRAP_CACHE:
+ return WRAP_CACHE[obj]
+ wrapper = CallableObjectWrapper(obj, keep_wrapper=keep_wrapper)
+ WRAP_CACHE[obj] = wrapper
+ return wrapper
+ else:
+ # Handle non-callable objects
+ return CloudpickledObjectWrapper(obj, keep_wrapper=keep_wrapper)
diff --git a/joblib/externals/loky/initializers.py b/joblib/externals/loky/initializers.py
index 81c0c79..a96110f 100644
--- a/joblib/externals/loky/initializers.py
+++ b/joblib/externals/loky/initializers.py
@@ -3,7 +3,11 @@ import warnings
def _viztracer_init(init_kwargs):
"""Initialize viztracer's profiler in worker processes"""
- pass
+ try:
+ import viztracer
+ viztracer.set_tracer(**init_kwargs)
+ except ImportError:
+ warnings.warn("viztracer is not installed. Profiling will not be available.")
class _ChainedInitializer:
@@ -26,4 +30,11 @@ def _chain_initializers(initializer_and_args):
If some initializers are None, they are filtered out.
"""
- pass
+ valid_initializers = [
+ (init, args) for init, args in initializer_and_args
+ if init is not None
+ ]
+ if not valid_initializers:
+ return None
+ initializers, args_list = zip(*valid_initializers)
+ return _ChainedInitializer(initializers), args_list
diff --git a/joblib/externals/loky/process_executor.py b/joblib/externals/loky/process_executor.py
index c68582c..2e84f34 100644
--- a/joblib/externals/loky/process_executor.py
+++ b/joblib/externals/loky/process_executor.py
@@ -180,7 +180,12 @@ class _SafeQueue(Queue):
def _get_chunks(chunksize, *iterables):
"""Iterates over zip()ed iterables in chunks."""
- pass
+ it = zip(*iterables)
+ while True:
+ chunk = tuple(itertools.islice(it, chunksize))
+ if not chunk:
+ return
+ yield chunk
def _process_chunk(fn, chunk):
@@ -192,12 +197,16 @@ def _process_chunk(fn, chunk):
This function is run in a separate process.
"""
- pass
+ return [fn(*args) for args in chunk]
def _sendback_result(result_queue, work_id, result=None, exception=None):
"""Safely send back the given result or exception"""
- pass
+ try:
+ result_queue.put(_ResultItem(work_id, exception, result))
+ except BaseException as e:
+ exc = _ExceptionWithTraceback(e)
+ result_queue.put(_ResultItem(work_id, exc))
def _process_worker(call_queue, result_queue, initializer, initargs,
@@ -221,7 +230,39 @@ def _process_worker(call_queue, result_queue, initializer, initargs,
workers timeout.
current_depth: Nested parallelism level, to avoid infinite spawning.
"""
- pass
+ if initializer is not None:
+ try:
+ initializer(*initargs)
+ except BaseException:
+ _sendback_result(result_queue, None, exception=_ExceptionWithTraceback(sys.exc_info()[1]))
+ return
+
+ watchdog = None
+ if timeout is not None:
+ watchdog = threading.Thread(target=_shutdown_worker,
+ args=(call_queue, result_queue, timeout, processes_management_lock, worker_exit_lock))
+ watchdog.daemon = True
+ watchdog.start()
+
+ while True:
+ try:
+ call_item = call_queue.get(block=True, timeout=timeout)
+ except queue.Empty:
+ if watchdog is not None:
+ watchdog.join()
+ processes_management_lock.acquire()
+ return
+ if call_item is None:
+ processes_management_lock.acquire()
+ return
+ try:
+ r = call_item()
+ except BaseException as e:
+ exc = _ExceptionWithTraceback(e)
+ _sendback_result(result_queue, call_item.work_id, exception=exc)
+ else:
+ _sendback_result(result_queue, call_item.work_id, result=r)
+ del call_item
class _ExecutorManagerThread(threading.Thread):
@@ -272,7 +313,8 @@ def _chain_from_iterable_of_lists(iterable):
Each item in *iterable* should be a list. This function is
careful not to keep references to yielded objects.
"""
- pass
+ for element in iterable:
+ yield from element
class LokyRecursionError(RuntimeError):
@@ -375,7 +417,13 @@ class ProcessPoolExecutor(Executor):
def _ensure_executor_running(self):
"""ensures all workers and management thread are running"""
- pass
+ with self._processes_management_lock:
+ if len(self._processes) == 0:
+ self._adjust_process_count()
+ if self._executor_manager_thread is None:
+ self._executor_manager_thread = _ExecutorManagerThread(self)
+ self._executor_manager_thread.start()
+ _threads_wakeups[self._executor_manager_thread] = self._executor_manager_thread_wakeup
submit.__doc__ = Executor.submit.__doc__
def map(self, fn, *iterables, **kwargs):
@@ -400,5 +448,14 @@ class ProcessPoolExecutor(Executor):
before the given timeout.
Exception: If fn(*args) raises for any values.
"""
- pass
+ timeout = kwargs.get('timeout', None)
+ chunksize = kwargs.get('chunksize', 1)
+
+ if chunksize < 1:
+ raise ValueError("chunksize must be >= 1.")
+
+ results = super().map(partial(_process_chunk, fn),
+ _get_chunks(chunksize, *iterables),
+ timeout=timeout)
+ return _chain_from_iterable_of_lists(results)
shutdown.__doc__ = Executor.shutdown.__doc__
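
_get_chunks and _process_chunk implement the classic chunking trade-off for process pools: fewer, larger tasks amortize inter-process communication overhead. A standalone sketch of the same idea (the names mirror the ones above, but this is not loky's code):

import itertools


def get_chunks(chunksize, *iterables):
    it = zip(*iterables)
    while True:
        chunk = tuple(itertools.islice(it, chunksize))
        if not chunk:
            return
        yield chunk


def process_chunk(fn, chunk):
    return [fn(*args) for args in chunk]


chunks = list(get_chunks(3, range(7), range(7)))
print(chunks[0])                                   # ((0, 0), (1, 1), (2, 2))
print([process_chunk(lambda a, b: a + b, c) for c in chunks])
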
diff --git a/joblib/externals/loky/reusable_executor.py b/joblib/externals/loky/reusable_executor.py
index 5509ecc..c3469f6 100644
--- a/joblib/externals/loky/reusable_executor.py
+++ b/joblib/externals/loky/reusable_executor.py
@@ -18,7 +18,11 @@ def _get_next_executor_id():
The purpose of this monotonic id is to help debug and test automated
instance creation.
"""
- pass
+ global _next_executor_id
+ with _executor_lock:
+ executor_id = _next_executor_id
+ _next_executor_id += 1
+ return executor_id
def get_reusable_executor(max_workers=None, context=None, timeout=10,
@@ -64,7 +68,48 @@ def get_reusable_executor(max_workers=None, context=None, timeout=10,
in the children before any module is loaded. This only works with the
``loky`` context.
"""
- pass
+ global _executor, _executor_kwargs
+ executor_id = _get_next_executor_id()
+
+ with _executor_lock:
+ if _executor is None or kill_workers:
+ if _executor is not None:
+ _executor.shutdown(wait=True)
+
+ _executor = _ReusablePoolExecutor(
+ _executor_lock,
+ max_workers=max_workers,
+ context=context,
+ timeout=timeout,
+ executor_id=executor_id,
+ job_reducers=job_reducers,
+ result_reducers=result_reducers,
+ initializer=initializer,
+ initargs=initargs,
+ env=env
+ )
+ _executor_kwargs = _executor._kwargs
+ else:
+ if max_workers is not None and _executor._max_workers != max_workers:
+ _executor._wait_job_completion()
+ _executor._max_workers = max_workers
+ _executor._adjust_process_count()
+ if _executor._broken:
+ _executor = None
+ return get_reusable_executor(
+ max_workers=max_workers,
+ context=context,
+ timeout=timeout,
+ kill_workers=kill_workers,
+ reuse=reuse,
+ job_reducers=job_reducers,
+ result_reducers=result_reducers,
+ initializer=initializer,
+ initargs=initargs,
+ env=env
+ )
+
+ return _executor
class _ReusablePoolExecutor(ProcessPoolExecutor):
@@ -81,4 +126,5 @@ class _ReusablePoolExecutor(ProcessPoolExecutor):
def _wait_job_completion(self):
"""Wait for the cache to be empty before resizing the pool."""
- pass
+ while len(self._pending_work_items) > 0:
+ time.sleep(0.1)
diff --git a/joblib/func_inspect.py b/joblib/func_inspect.py
index 5a263a0..7d3f537 100644
--- a/joblib/func_inspect.py
+++ b/joblib/func_inspect.py
@@ -35,12 +35,23 @@ def get_func_code(func):
This function does a bit more magic than inspect, and is thus
more robust.
"""
- pass
+ try:
+ source_file = inspect.getsourcefile(func)
+ if source_file is None:
+ raise ValueError("Unable to find the source file for the function.")
+
+ lines, first_line = inspect.getsourcelines(func)
+ func_code = ''.join(lines)
+
+ return func_code, source_file, first_line
+ except Exception as e:
+ raise ValueError(f"Unable to retrieve function code: {str(e)}")
def _clean_win_chars(string):
"""Windows cannot encode some characters in filename."""
- pass
+ import urllib.parse
+ return urllib.parse.quote(string, safe='')
def get_func_name(func, resolv_alias=True, win_characters=True):
@@ -57,17 +68,49 @@ def get_func_name(func, resolv_alias=True, win_characters=True):
If true, substitute special characters using urllib.quote
This is useful in Windows, as it cannot encode some filenames
"""
- pass
+ module = inspect.getmodule(func)
+ if module is None:
+ return [], func.__name__
+
+ module_path = module.__name__.split('.')
+ func_name = func.__name__
+
+ if resolv_alias:
+ if hasattr(module, '__all__'):
+ if func_name not in module.__all__:
+ func_name = f'{func_name} (alias)'
+ elif not func_name.startswith('_'):
+ func_name = f'{func_name} (alias)'
+
+ if win_characters:
+ func_name = _clean_win_chars(func_name)
+
+ return module_path, func_name
def _signature_str(function_name, arg_sig):
"""Helper function to output a function signature"""
- pass
+ args = []
+ for arg in arg_sig.args:
+ if arg in arg_sig.annotations:
+ args.append(f"{arg}: {arg_sig.annotations[arg].__name__}")
+ else:
+ args.append(arg)
+
+ if arg_sig.varargs:
+ args.append(f"*{arg_sig.varargs}")
+ if arg_sig.varkw:
+ args.append(f"**{arg_sig.varkw}")
+
+ return f"{function_name}({', '.join(args)})"
def _function_called_str(function_name, args, kwargs):
"""Helper function to output a function call"""
- pass
+ args_str = [repr(arg) for arg in args]
+ kwargs_str = [f"{key}={repr(value)}" for key, value in kwargs.items()]
+ all_args = args_str + kwargs_str
+ return f"{function_name}({', '.join(all_args)})"
def filter_args(func, ignore_lst, args=(), kwargs=dict()):
@@ -91,11 +134,27 @@ def filter_args(func, ignore_lst, args=(), kwargs=dict()):
filtered_args: list
List of filtered positional and keyword arguments.
"""
- pass
+ arg_spec = inspect.getfullargspec(func)
+
+ # Filter positional arguments
+    filtered_args = [
+        arg for i, arg in enumerate(args)
+        if i < len(arg_spec.args) and arg_spec.args[i] not in ignore_lst
+    ]
+
+ # Filter keyword arguments
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k not in ignore_lst}
+
+ # Handle '*args' and '**kwargs'
+ if '*' not in ignore_lst and arg_spec.varargs:
+ filtered_args.extend(args[len(arg_spec.args):])
+ if '**' not in ignore_lst and arg_spec.varkw:
+ filtered_kwargs.update({k: v for k, v in kwargs.items() if k not in arg_spec.args})
+
+ return filtered_args + list(filtered_kwargs.items())
def format_call(func, args, kwargs, object_name='Memory'):
""" Returns a nicely formatted statement displaying the function
call with the given arguments.
"""
- pass
+    func_name = get_func_name(func)[1]
+    call_str = _function_called_str(func_name, args, kwargs)
+    return f"[{object_name}] Calling {call_str}"
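A quick sanity check of the formatting helpers filled in above, written against these fill-in implementations rather than upstream joblib (the example function is made up):

    import inspect
    from joblib.func_inspect import _signature_str, _function_called_str

    def add(a, b=2, *rest, **opts):
        return a + b

    arg_sig = inspect.getfullargspec(add)
    print(_signature_str('add', arg_sig))               # add(a, b, *rest, **opts)
    print(_function_called_str('add', (1,), {'b': 3}))  # add(1, b=3)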
diff --git a/joblib/hashing.py b/joblib/hashing.py
index 7f57b88..e4e7758 100644
--- a/joblib/hashing.py
+++ b/joblib/hashing.py
@@ -77,7 +77,20 @@ class NumpyHasher(Hasher):
than pickling them. Off course, this is a total abuse of
the Pickler class.
"""
- pass
+ if isinstance(obj, self.np.ndarray):
+ # Check if it's a memmap
+ if isinstance(obj, self.np.memmap) and not self.coerce_mmap:
+ # Memmap detected, handle differently
+ self._hash.update(obj.filename.encode('utf-8'))
+ self._hash.update(str(obj.offset).encode('utf-8'))
+ self._hash.update(str(obj.shape).encode('utf-8'))
+ self._hash.update(str(obj.dtype).encode('utf-8'))
+ else:
+ # Regular ndarray or coerced memmap
+ self._hash.update(self._getbuffer(obj))
+ else:
+ # For all other objects, use the default pickling behavior
+ Hasher.save(self, obj)
def hash(obj, hash_name='md5', coerce_mmap=False):
@@ -92,4 +105,17 @@ def hash(obj, hash_name='md5', coerce_mmap=False):
coerce_mmap: boolean
Make no difference between np.memmap and np.ndarray
"""
- pass
+ if hash_name not in ('md5', 'sha1'):
+ raise ValueError("Valid options for 'hash_name' are 'md5' or 'sha1'")
+ try:
+ import numpy as np
+ hasher = NumpyHasher(hash_name=hash_name, coerce_mmap=coerce_mmap)
+ except ImportError:
+ hasher = Hasher(hash_name=hash_name)
+
+ try:
+ hasher.save(obj)
+ except pickle.PicklingError as e:
+ raise pickle.PicklingError('PicklingError while hashing %r: %r' %
+ (obj, e))
+ return hasher._hash.hexdigest()
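The public entry point exercised by this hunk is joblib.hash; a small sketch (digest values depend on dtype and platform, so none are shown):

    import numpy as np
    import joblib

    a = np.arange(10)
    print(joblib.hash(a))                    # md5 digest of the array buffer
    print(joblib.hash(a, hash_name='sha1'))  # same content, sha1 digest
    # coerce_mmap=True makes a np.memmap hash like the in-memory equivalent.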
diff --git a/joblib/logger.py b/joblib/logger.py
index 4991108..40ece3d 100644
--- a/joblib/logger.py
+++ b/joblib/logger.py
@@ -18,7 +18,9 @@ def _squeeze_time(t):
stat files. This is needed to make results similar to timings under
Unix, for tests
"""
- pass
+ if sys.platform.startswith('win'):
+ return max(0, t - 0.1)
+ return t
class Logger(object):
@@ -39,7 +41,7 @@ class Logger(object):
def format(self, obj, indent=0):
"""Return the formatted representation of the object."""
- pass
+ return pprint.pformat(obj, indent=indent, depth=self.depth)
class PrintTime(object):
diff --git a/joblib/memory.py b/joblib/memory.py
index 14f8456..2db1dcd 100644
--- a/joblib/memory.py
+++ b/joblib/memory.py
@@ -30,7 +30,10 @@ def extract_first_line(func_code):
""" Extract the first line information from the function code
text if available.
"""
- pass
+    if func_code.startswith(FIRST_LINE_TEXT):
+        lines = func_code.split('\n')
+        first_line = int(lines[0][len(FIRST_LINE_TEXT):])
+        func_code = '\n'.join(lines[1:])
+    else:
+        first_line = -1
+    return func_code, first_line
class JobLibCollisionWarning(UserWarning):
@@ -59,17 +62,32 @@ def register_store_backend(backend_name, backend):
The name of a class that implements the StoreBackendBase interface.
"""
- pass
+ if not issubclass(backend, StoreBackendBase):
+ raise ValueError("backend must be a subclass of StoreBackendBase")
+ _STORE_BACKENDS[backend_name] = backend
def _store_backend_factory(backend, location, verbose=0, backend_options=None):
"""Return the correct store object for the given location."""
- pass
+ if backend not in _STORE_BACKENDS:
+ raise ValueError(f"Unknown backend: {backend}")
+
+ backend_options = backend_options or {}
+ backend_class = _STORE_BACKENDS[backend]
+ return backend_class(location, verbose=verbose, **backend_options)
def _build_func_identifier(func):
"""Build a roughly unique identifier for the cached function."""
- pass
+ module = getattr(func, '__module__', None)
+ name = getattr(func, '__qualname__', None) or getattr(func, '__name__', None)
+
+ if module and name:
+ return f"{module}.{name}"
+ elif name:
+ return name
+ else:
+ return str(func)
_FUNCTION_HASHES = weakref.WeakKeyDictionary()
@@ -123,11 +141,20 @@ class MemorizedResult(Logger):
def get(self):
"""Read value from cache and return it."""
- pass
+ if self.store_backend is None:
+ raise ValueError("Cannot get a value from a non-initialized backend")
+
+ value = self.store_backend.load_item(self._call_id)
+
+ if self.mmap_mode is not None and hasattr(value, 'mmap'):
+ value = value.mmap(mode=self.mmap_mode)
+
+ return value
def clear(self):
"""Clear value from cache"""
- pass
+ if self.store_backend is not None:
+ self.store_backend.clear_item(self._call_id)
def __repr__(self):
return '{}(location="{}", func="{}", args_id="{}")'.format(self.
@@ -279,7 +306,22 @@ class MemorizedFunc(Logger):
Returns True if the function call is in cache and can be used, and
returns False otherwise.
"""
- pass
+ if self.store_backend is None:
+ return False
+
+ if not self.store_backend.contains_item(call_id):
+ return False
+
+        if self._check_previous_func_code(stacklevel=4):
+            return False
+
+ if self.cache_validation_callback is not None:
+ metadata = self.store_backend.get_metadata(call_id)
+ if not self.cache_validation_callback(metadata):
+ return False
+
+ return True
def _cached_call(self, args, kwargs, shelving):
"""Call wrapped function and cache result, or read cache if available.
@@ -303,7 +345,32 @@ class MemorizedFunc(Logger):
MemorizedResult reference to the value if shelving is true.
metadata: dict containing the metadata associated with the call.
"""
- pass
+ call_id = self._get_args_id(*args, **kwargs)
+
+ if self._is_in_cache_and_valid(call_id):
+ return self._load_cached_result(call_id, shelving)
+
+ start_time = time.time()
+ output = self.func(*args, **kwargs)
+ duration = time.time() - start_time
+
+ metadata = {
+ 'duration': duration,
+ 'timestamp': time.time(),
+ }
+
+ self._persist_input(duration, call_id, args, kwargs)
+ self.store_backend.dump_item(call_id, output, metadata)
+
+ if shelving:
+ return MemorizedResult(self.store_backend.location, call_id,
+ backend=self.store_backend.__class__.__name__,
+ mmap_mode=self.mmap_mode,
+ verbose=self._verbose,
+ timestamp=self.timestamp,
+ metadata=metadata), metadata
+ else:
+ return output, metadata
def call_and_shelve(self, *args, **kwargs):
"""Call wrapped function, cache result and return a reference.
@@ -348,27 +415,47 @@ class MemorizedFunc(Logger):
def _get_args_id(self, *args, **kwargs):
"""Return the input parameter hash of a result."""
- pass
+        # hashing.hash takes no 'ignore' argument; drop ignored parameters
+        # with filter_args before hashing.
+        return hashing.hash(filter_args(self.func, self.ignore, args, kwargs),
+                            coerce_mmap=self.mmap_mode is not None)
def _hash_func(self):
"""Hash a function to key the online cache"""
- pass
+        func_code, source_file, _ = get_func_code(self.func)
+        return hashing.hash((func_code, source_file, self.func_id))
def _write_func_code(self, func_code, first_line):
""" Write the function code and the filename to a file.
"""
- pass
+ if self.store_backend is None:
+ return
+
+ func_code = f"{FIRST_LINE_TEXT} {first_line}\n{func_code}"
+ self.store_backend.store_cached_func_code([self.func_id], func_code)
def _check_previous_func_code(self, stacklevel=2):
"""
stacklevel is the depth a which this function is called, to
issue useful warnings to the user.
"""
- pass
+ if self._func_code_info is None:
+ return False
+
+        func_code_id = self._hash_func()
+
+ if func_code_id != self._func_code_id:
+ warnings.warn('Function {} has changed.'.format(self.func),
+ JobLibCollisionWarning, stacklevel=stacklevel)
+ return True
+ return False
def clear(self, warn=True):
"""Empty the function's cache."""
- pass
+ if self.store_backend is not None:
+ self.store_backend.clear()
+
+ if warn:
+ warnings.warn('Clearing {0} cache'.format(self.func.__name__),
+ stacklevel=2)
def call(self, *args, **kwargs):
"""Force the execution of the function with the given arguments.
@@ -410,7 +497,25 @@ class MemorizedFunc(Logger):
this_duration_limit: float
Max execution time for this function before issuing a warning.
"""
- pass
+ if duration > this_duration_limit:
+ warnings.warn(
+ 'Persisting input arguments took %.2fs to run.\n'
+ 'If this happens often in your code, it can cause performance problems '
+ '(results will be correct but it can be slow).\n'
+ 'The reason for this is probably some large input arguments.\n'
+ 'You can try to increase the compression level or cache them separately '
+ 'if possible.' % duration, stacklevel=5)
+
+        input_repr = format_call(self.func, args, kwargs)
+        # Keep the stored representation reasonably small.
+        if len(input_repr) > 300:
+            input_repr = input_repr[:297] + '...'
+
+ metadata = {
+ 'duration': duration,
+ 'input_repr': input_repr,
+ 'timestamp': time.time(),
+ }
+
+ self.store_backend.store_metadata(call_id, metadata)
def __repr__(self):
return '{class_name}(func={func}, location={location})'.format(
@@ -548,7 +653,11 @@ class Memory(Logger):
def clear(self, warn=True):
""" Erase the complete cache directory.
"""
- pass
+ if self.store_backend is not None:
+ self.store_backend.clear()
+
+ if warn:
+ warnings.warn('Clearing Memory cache', stacklevel=2)
def reduce_size(self, bytes_limit=None, items_limit=None, age_limit=None):
"""Remove cache elements to make the cache fit its limits.
@@ -576,7 +685,19 @@ class Memory(Logger):
of the cache, any items last accessed more than the given length of
time ago are deleted.
"""
- pass
+ if self.store_backend is None:
+ return
+
+ if isinstance(bytes_limit, str):
+ bytes_limit = self._parse_size(bytes_limit)
+
+ self.store_backend.reduce_store(bytes_limit, items_limit, age_limit)
+
+ def _parse_size(self, size_str):
+ units = {'K': 1024, 'M': 1024**2, 'G': 1024**3}
+ size = int(size_str[:-1])
+ unit = size_str[-1].upper()
+ return size * units.get(unit, 1)
def eval(self, func, *args, **kwargs):
""" Eval function func with arguments `*args` and `**kwargs`,
@@ -587,7 +708,11 @@ class Memory(Logger):
up to date.
"""
- pass
+ if self.store_backend is None:
+ return func(*args, **kwargs)
+
+ cached_func = self.cache(func)
+ return cached_func(*args, **kwargs)
def __repr__(self):
return '{class_name}(location={location})'.format(class_name=self.
@@ -612,4 +737,16 @@ def expires_after(days=0, seconds=0, microseconds=0, milliseconds=0,
days, seconds, microseconds, milliseconds, minutes, hours, weeks: numbers
argument passed to a timedelta.
"""
- pass
+ duration = datetime.timedelta(
+ days=days, seconds=seconds, microseconds=microseconds,
+ milliseconds=milliseconds, minutes=minutes, hours=hours, weeks=weeks
+ )
+
+ def callback(metadata):
+ timestamp = metadata.get('timestamp')
+ if timestamp is None:
+ return False
+ age = datetime.datetime.now() - datetime.datetime.fromtimestamp(timestamp)
+ return age < duration
+
+ return callback
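Usage sketch for the cache_validation_callback produced by expires_after above (the cache directory and timings are illustrative; the cache_validation_callback argument exists in recent joblib versions):

    import time
    from joblib import Memory
    from joblib.memory import expires_after

    memory = Memory('./joblib_cache', verbose=0)

    @memory.cache(cache_validation_callback=expires_after(seconds=30))
    def slow_add(a, b):
        time.sleep(1)
        return a + b

    print(slow_add(1, 2))  # computed and persisted
    print(slow_add(1, 2))  # served from cache while the entry is < 30s old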
diff --git a/joblib/numpy_pickle.py b/joblib/numpy_pickle.py
index 6cc4089..6da27f2 100644
--- a/joblib/numpy_pickle.py
+++ b/joblib/numpy_pickle.py
@@ -70,7 +70,30 @@ class NumpyArrayWrapper(object):
This function is an adaptation of the numpy write_array function
available in version 1.10.1 in numpy/lib/format.py.
"""
- pass
+ if self.order not in ('C', 'F'):
+ raise ValueError("order must be 'C' or 'F'")
+
+ if array.dtype.hasobject:
+ # We don't handle object arrays
+ raise ValueError("object arrays cannot be written")
+
+ if pickler.np is None:
+ # numpy is not available, fallback to pickle
+ return pickler.save(array)
+
+ # Make sure we're working with a contiguous array
+ array = pickler.np.ascontiguousarray(array, dtype=self.dtype)
+
+ # Write the array data
+ if not pickler.buffered:
+ pickler.file_handle.write(array.data)
+ else:
+ # If the file handle is buffered (like with gzip), we need to chunk the write
+ chunk_size = BUFFER_SIZE
+ data = array.data
+ for i in range(0, len(data), chunk_size):
+ chunk = data[i:i+chunk_size]
+ pickler.file_handle.write(chunk)
def read_array(self, unpickler):
"""Read array from unpickler file handle.
@@ -78,11 +101,37 @@ class NumpyArrayWrapper(object):
This function is an adaptation of the numpy read_array function
available in version 1.10.1 in numpy/lib/format.py.
"""
- pass
+ if unpickler.np is None:
+ raise ImportError("Numpy is required to unpickle numpy arrays")
+
+ # Read the array data
+ array = unpickler.np.frombuffer(
+ _read_bytes(unpickler.file_handle, self.size, "array data"),
+ dtype=self.dtype
+ ).reshape(self.shape, order=self.order)
+
+ if self.order == 'F':
+ array = unpickler.np.asfortranarray(array)
+
+ return array
def read_mmap(self, unpickler):
"""Read an array using numpy memmap."""
- pass
+ if unpickler.np is None:
+ raise ImportError("Numpy is required to unpickle numpy arrays")
+
+ offset = unpickler.file_handle.tell()
+
+ array = unpickler.np.memmap(unpickler.filename, mode=unpickler.mmap_mode,
+ dtype=self.dtype, shape=self.shape,
+ order=self.order, offset=offset)
+
+ # Advance file cursor
+ unpickler.file_handle.seek(self.size, 1)
+ return array
def read(self, unpickler):
"""Read the array corresponding to this wrapper.
@@ -98,7 +147,16 @@ class NumpyArrayWrapper(object):
array: numpy.ndarray
"""
- pass
+ if unpickler.mmap_mode is not None and self.allow_mmap:
+ array = self.read_mmap(unpickler)
+ else:
+ array = self.read_array(unpickler)
+
+ # Handle subclasses
+ if self.subclass is not unpickler.np.ndarray:
+ array = array.view(self.subclass)
+
+ return array
class NumpyPickler(Pickler):
@@ -131,7 +189,22 @@ class NumpyPickler(Pickler):
def _create_array_wrapper(self, array):
"""Create and returns a numpy array wrapper from a numpy array."""
- pass
+ if self.np is None:
+ return array
+
+ # Check if the array is a numpy array
+ if not isinstance(array, self.np.ndarray):
+ return array
+
+ # Create and return the wrapper
+ return NumpyArrayWrapper(
+ subclass=array.__class__,
+ shape=array.shape,
+ dtype=array.dtype,
+ order='F' if array.flags.f_contiguous else 'C',
+ allow_mmap=not array.dtype.hasobject,
+ numpy_array_alignment_bytes=self.np.lib.format.ARRAY_ALIGN
+ )
def save(self, obj):
"""Subclass the Pickler `save` method.
@@ -143,7 +216,14 @@ class NumpyPickler(Pickler):
after in the file. Warning: the file produced does not follow the
pickle format. As such it can not be read with `pickle.load`.
"""
- pass
+ if self.np is not None and isinstance(obj, self.np.ndarray):
+ # Replace the numpy array with our custom wrapper
+ wrapper = self._create_array_wrapper(obj)
+ Pickler.save(self, wrapper)
+ # Write the array data directly to the file
+ wrapper.write_array(obj, self)
+ else:
+ Pickler.save(self, obj)
class NumpyUnpickler(Unpickler):
@@ -185,7 +265,13 @@ class NumpyUnpickler(Unpickler):
replace them directly in the stack of pickler.
NDArrayWrapper is used for backward compatibility with joblib <= 0.9.
"""
- pass
+        # Run the regular BUILD first (the state sits on top of the stack),
+        # then swap any array wrapper left on the stack for the array it
+        # describes.
+        Unpickler.load_build(self)
+        if isinstance(self.stack[-1], (NDArrayWrapper, NumpyArrayWrapper)):
+            array_wrapper = self.stack.pop()
+            self.stack.append(array_wrapper.read(self))
dispatch[pickle.BUILD[0]] = load_build
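End-to-end sketch of the pickler/unpickler pair above, via the public joblib.dump / joblib.load API (the file path is illustrative):

    import os
    import tempfile
    import numpy as np
    import joblib

    path = os.path.join(tempfile.gettempdir(), 'arrays_demo.joblib')
    joblib.dump({'x': np.arange(1000), 'label': 'demo'}, path)

    # A plain load goes through NumpyArrayWrapper.read_array; mmap_mode='r'
    # takes the read_mmap path for arrays stored uncompressed.
    restored = joblib.load(path, mmap_mode='r')
    print(type(restored['x']))  # numpy.memmap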
diff --git a/joblib/numpy_pickle_compat.py b/joblib/numpy_pickle_compat.py
index 4ccffb6..700e68d 100644
--- a/joblib/numpy_pickle_compat.py
+++ b/joblib/numpy_pickle_compat.py
@@ -11,7 +11,7 @@ from .numpy_pickle_utils import _ensure_native_byte_order
def hex_str(an_int):
"""Convert an int to an hexadecimal string."""
- pass
+ return f"{an_int:x}".zfill(_MAX_LEN)
_MAX_LEN = len(hex_str(2 ** 64))
@@ -25,7 +25,9 @@ def read_zfile(file_handle):
for persistence. Backward compatibility is not guaranteed. Do not
use for external purposes.
"""
- pass
+ file_handle.seek(0)
+ assert file_handle.read(len(_ZFILE_PREFIX)) == _ZFILE_PREFIX
+ return zlib.decompress(file_handle.read())
def write_zfile(file_handle, data, compress=1):
@@ -35,7 +37,8 @@ def write_zfile(file_handle, data, compress=1):
for persistence. Backward compatibility is not guaranteed. Do not
use for external purposes.
"""
- pass
+ file_handle.write(_ZFILE_PREFIX)
+ file_handle.write(zlib.compress(data, compress))
class NDArrayWrapper(object):
@@ -53,7 +56,17 @@ class NDArrayWrapper(object):
def read(self, unpickler):
"""Reconstruct the array."""
- pass
+ import numpy as np
+ filename = os.path.join(unpickler._dirname, self.filename)
+ if self.allow_mmap and unpickler.mmap_mode is not None:
+ array = np.load(filename, mmap_mode=unpickler.mmap_mode)
+ else:
+ array = np.load(filename, allow_pickle=True)
+ if self.subclass is not np.ndarray:
+ # We need to reconstruct another subclass
+ return self.subclass(array.shape, array.dtype,
+ buffer=array, order='C')
+ return array
class ZNDArrayWrapper(NDArrayWrapper):
@@ -79,7 +92,24 @@ class ZNDArrayWrapper(NDArrayWrapper):
def read(self, unpickler):
"""Reconstruct the array from the meta-information and the z-file."""
- pass
+ import numpy as np
+ filename = os.path.join(unpickler._dirname, self.filename)
+ with open(filename, 'rb') as f:
+ array = np.frombuffer(read_zfile(f),
+ dtype=self.init_args['dtype'])
+ array = array.reshape(self.init_args['shape'])
+ if self.init_args.get('order') == 'F':
+ array = array.T
+ if self.init_args['dtype'].hasobject:
+ array = array.copy()
+ # Reconstruct subclasses
+ if self.init_args.get('cls') is not None:
+ new_array = self.init_args['cls'].__new__(
+ self.init_args['cls'], self.init_args['shape'],
+ self.init_args['dtype'], buffer=array)
+ array = new_array
+ array.__setstate__(self.state)
+ return array
class ZipNumpyUnpickler(Unpickler):
@@ -106,7 +136,32 @@ class ZipNumpyUnpickler(Unpickler):
NDArrayWrapper, by the array we are interested in. We
replace them directly in the stack of pickler.
"""
- pass
+ stack = self.stack
+ state = stack.pop()
+ instance = stack[-1]
+ if isinstance(instance, NDArrayWrapper):
+ # Replace the NDArrayWrapper by the array itself
+ array = instance.read(self)
+ # If the array is part of a subclass, we need to preserve it
+ if (isinstance(array, self.np.ndarray) and
+ not type(array) is self.np.ndarray):
+ self.stack[-1] = array
+ else:
+ self.stack[-1] = _ensure_native_byte_order(array)
+ else:
+ setstate = getattr(instance, "__setstate__", None)
+ if setstate is not None:
+ setstate(state)
+ else:
+                if isinstance(state, tuple) and len(state) == 2:
+                    # pickle's BUILD convention is (state_dict, slot_state)
+                    state, slotstate = state
+                else:
+                    slotstate = None
+ if state:
+ instance.__dict__.update(state)
+ if slotstate:
+ for key, value in slotstate.items():
+ setattr(instance, key, value)
dispatch[pickle.BUILD[0]] = load_build
@@ -136,4 +191,14 @@ def load_compatibility(filename):
This function can load numpy array files saved separately during the
dump.
"""
- pass
+    with open(filename, 'rb') as file_handle:
+        # The context manager keeps the handle open for the whole load and
+        # guarantees it is closed even if unpickling raises.
+        unpickler = ZipNumpyUnpickler(filename, file_handle)
+        obj = unpickler.load()
+    return obj
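Round-trip sketch for the z-file helpers above; this is an internal legacy format, shown only to illustrate the prefix-plus-zlib layout implemented here:

    import io
    from joblib.numpy_pickle_compat import read_zfile, write_zfile

    buf = io.BytesIO()
    write_zfile(buf, b'hello world', compress=3)
    print(read_zfile(buf))  # b'hello world' (read_zfile seeks to 0 itself)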
diff --git a/joblib/numpy_pickle_utils.py b/joblib/numpy_pickle_utils.py
index e79528e..79b048f 100644
--- a/joblib/numpy_pickle_utils.py
+++ b/joblib/numpy_pickle_utils.py
@@ -22,12 +22,15 @@ _IO_BUFFER_SIZE = 1024 ** 2
def _is_raw_file(fileobj):
"""Check if fileobj is a raw file object, e.g created with open."""
- pass
+ return isinstance(fileobj, (io.FileIO, io.BufferedReader, io.BufferedWriter))
def _is_numpy_array_byte_order_mismatch(array):
"""Check if numpy array is having byte order mismatch"""
- pass
+ if np is None:
+ return False
+ return (array.dtype.byteorder == '>' and sys.byteorder == 'little') or \
+ (array.dtype.byteorder == '<' and sys.byteorder == 'big')
def _ensure_native_byte_order(array):
@@ -35,7 +38,9 @@ def _ensure_native_byte_order(array):
Does nothing if array already uses the system byte order.
"""
- pass
+ if _is_numpy_array_byte_order_mismatch(array):
+ return array.byteswap().newbyteorder()
+ return array
def _detect_compressor(fileobj):
@@ -49,17 +54,40 @@ def _detect_compressor(fileobj):
-------
str in {'zlib', 'gzip', 'bz2', 'lzma', 'xz', 'compat', 'not-compressed'}
"""
- pass
+ if not hasattr(fileobj, 'read'):
+ raise ValueError("Expected file-like object, got %s" % type(fileobj))
+
+ magic = fileobj.read(4)
+ fileobj.seek(0)
+
+ if magic.startswith(_ZFILE_PREFIX):
+ return 'zlib'
+ elif magic.startswith(b'\x1f\x8b'):
+ return 'gzip'
+ elif magic.startswith(b'BZh'):
+ return 'bz2'
+ elif magic.startswith(b'\x28\xb5\x2f\xfd'):
+ return 'zstd'
+ elif magic.startswith(b'\xfd7zXZ\x00'):
+ return 'xz'
+ elif magic.startswith(b'\x89LZO'):
+ return 'lzo'
+ else:
+ return 'not-compressed'
def _buffered_read_file(fobj):
"""Return a buffered version of a read file object."""
- pass
+ if isinstance(fobj, io.BufferedReader):
+ return fobj
+ return io.BufferedReader(fobj, buffer_size=_IO_BUFFER_SIZE)
def _buffered_write_file(fobj):
"""Return a buffered version of a write file object."""
- pass
+ if isinstance(fobj, io.BufferedWriter):
+ return fobj
+ return io.BufferedWriter(fobj, buffer_size=_IO_BUFFER_SIZE)
@contextlib.contextmanager
@@ -90,12 +118,39 @@ def _read_fileobject(fileobj, filename, mmap_mode=None):
a file like object
"""
- pass
+ compressor = _detect_compressor(fileobj)
+
+ if compressor == 'not-compressed':
+ if mmap_mode is not None:
+ fileobj = np.load(filename, mmap_mode=mmap_mode)
+ else:
+ fileobj = _buffered_read_file(fileobj)
+ elif compressor in _COMPRESSORS:
+ if mmap_mode is not None:
+ warnings.warn('File "%(filename)s" is compressed using %(compressor)s. '
+ 'Memory mapping mode "%(mmap_mode)s" is not supported.'
+ % locals())
+ fileobj = _COMPRESSORS[compressor]['file_open'](fileobj, mode='rb')
+ else:
+ raise ValueError("Unrecognized file type: %s" % compressor)
+
+ try:
+ yield fileobj
+ finally:
+ fileobj.close()
def _write_fileobject(filename, compress=('zlib', 3)):
"""Return the right compressor file object in write mode."""
- pass
+ if compress is None or compress == 'not-compressed':
+ return _buffered_write_file(open(filename, 'wb'))
+
+ compressor, compress_level = compress
+ if compressor not in _COMPRESSORS:
+ raise ValueError("Unrecognized compressor: %s" % compressor)
+
+ return _COMPRESSORS[compressor]['file_open'](filename, 'wb',
+ compresslevel=compress_level)
BUFFER_SIZE = 2 ** 18
@@ -127,4 +182,10 @@ def _read_bytes(fp, size, error_template='ran out of data'):
The data read in bytes.
"""
- pass
+ data = bytes()
+ while len(data) < size:
+ chunk = fp.read(size - len(data))
+ if not chunk:
+ raise ValueError(error_template)
+ data += chunk
+ return data
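Quick check of the magic-number detection above (a private helper; the behaviour shown is that of the filled-in version):

    import gzip
    import io
    from joblib.numpy_pickle_utils import _detect_compressor

    print(_detect_compressor(io.BytesIO(gzip.compress(b'payload'))))  # 'gzip'
    print(_detect_compressor(io.BytesIO(b'plain bytes')))             # 'not-compressed'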
diff --git a/joblib/parallel.py b/joblib/parallel.py
index f141ee0..d7d863b 100644
--- a/joblib/parallel.py
+++ b/joblib/parallel.py
@@ -904,7 +904,20 @@ class Parallel(Logger):
def _initialize_backend(self):
"""Build a process or thread pool and return the number of workers"""
- pass
+ try:
+            n_jobs = self._backend.configure(n_jobs=self.n_jobs, parallel=self,
+                                             **self._backend_args)
+ if self.timeout is not None and not self._backend.supports_timeout:
+ warnings.warn(
+ 'The backend class {!r} does not support timeout. '
+ "You have set 'timeout={}' in Parallel but "
+ "the 'timeout' parameter will not be used.".format(
+ self._backend.__class__.__name__,
+ self.timeout))
+
+ return n_jobs
+ except FallbackToBackend as e:
+ self._backend = e.backend
+ return self._initialize_backend()
def _dispatch(self, batch):
"""Queue the batch for computing, with or without multiprocessing
@@ -913,7 +926,15 @@ class Parallel(Logger):
indirectly via dispatch_one_batch.
"""
- pass
+        dispatch_timestamp = time.time()
+ cb = BatchCompletionCallBack(dispatch_timestamp, len(batch), self)
+ with self._lock:
+ job_id = self._backend.apply_async(batch, callback=cb)
+ cb.register_job(job_id)
+ self._jobs.append(job_id)
def dispatch_next(self):
"""Dispatch more data for parallel processing
@@ -923,7 +944,15 @@ class Parallel(Logger):
against concurrent consumption of the unprotected iterator.
"""
- pass
+ if self.dispatch_one_batch(self._original_iterator):
+ return True
+ elif self._original_iterator is not None:
+ self._iterating = False
+ self._original_iterator = None
+ with self._lock:
+ if not self.return_generator:
+ self._ready_batches.put(None)
+ return False
def dispatch_one_batch(self, iterator):
"""Prefetch the tasks for the next batch and dispatch them.
@@ -935,41 +964,121 @@ class Parallel(Logger):
lock so calling this function should be thread safe.
"""
- pass
+ if iterator is None:
+ return False
+
+ with self._lock:
+ batch_size = self._get_batch_size()
+ tasks = []
+ for _ in range(batch_size):
+ try:
+ tasks.append(next(iterator))
+ except StopIteration:
+ break
+
+ if len(tasks) == 0:
+ return False
+
+ self._dispatch(tasks)
+ return True
def _get_batch_size(self):
"""Returns the effective batch size for dispatch"""
- pass
+ if self.batch_size == 'auto':
+ if self._cached_effective_n_jobs == 1:
+ return 1
+ elif self._backend.supports_threading:
+ return 1
+ else:
+ return 2 * self._cached_effective_n_jobs
+ else:
+ return self.batch_size
def _print(self, msg):
"""Display the message on stout or stderr depending on verbosity"""
- pass
+ if self.verbose > 10:
+ print(msg)
+ elif self.verbose > 0:
+ print(msg, end='\r', file=sys.stderr)
def _is_completed(self):
"""Check if all tasks have been completed"""
- pass
+ return self._ready_batches.qsize() == 0 and len(self._jobs) == 0
def print_progress(self):
"""Display the process of the parallel execution only a fraction
of time, controlled by self.verbose.
"""
- pass
+ if not self.verbose:
+ return
+ elapsed_time = time.time() - self._start_time
+
+ # This is heuristic code to print only 'verbose' times a messages
+ # The challenge is that we may not know the queue length
+ if self._original_iterator:
+ if _verbosity_filter(self._completed_tasks, self.verbose):
+ return
+ self._print('Done %3i tasks | elapsed: %s' %
+ (self._completed_tasks,
+ short_format_time(elapsed_time)))
+ else:
+ index = self._completed_tasks
+ if _verbosity_filter(index, self.verbose):
+ return
+ self._print('[%s]: Done %3i out of %3i | elapsed: %s' % (
+ self, index, len(self._jobs),
+ short_format_time(elapsed_time)))
def _get_outputs(self, iterator, pre_dispatch):
"""Iterator returning the tasks' output as soon as they are ready."""
- pass
+ self._start_time = time.time()
+ self._iterating = True
+ self._completed_tasks = 0
+
+ while self._iterating or len(self._jobs) > 0:
+ if self._iterating:
+ self.dispatch_one_batch(iterator)
+
+ if len(self._jobs) == 0:
+ continue
+
+ try:
+ batch = self._ready_batches.get(timeout=0.1)
+ except queue.Empty:
+ self.print_progress()
+ continue
+
+ if batch is None:
+ self._iterating = False
+ else:
+ yield from batch
+ self._completed_tasks += len(batch)
+ self.print_progress()
+
+ self._raise_error_fast()
+
+ if self._aborting:
+ self._terminate_and_reset()
+ raise self._exception
+
+ self._terminate_and_reset()
def _wait_retrieval(self):
"""Return True if we need to continue retrieving some tasks."""
- pass
+ return (self._iterating or len(self._jobs) > 0 or
+ self._ready_batches.qsize() > 0)
def _raise_error_fast(self):
"""If we are aborting, raise if a job caused an error."""
- pass
+ if self._aborting:
+ raise self._exception
def _warn_exit_early(self):
- """Warn the user if the generator is gc'ed before being consumned."""
- pass
+ """Warn the user if the generator is gc'ed before being consumed."""
+ if self._iterating:
+ warnings.warn("Parallel generator was garbage collected before "
+ "being fully consumed. Some jobs may have not been "
+ "run.", UserWarning)
def _get_sequential_output(self, iterable):
"""Separate loop for sequential output.
@@ -977,11 +1086,29 @@ class Parallel(Logger):
This simplifies the traceback in case of errors and reduces the
overhead of calling sequential tasks with `joblib`.
"""
- pass
+        self._start_time = time.time()
+        try:
+            for func, args, kwargs in iterable:
+                # Run each task inline and stream its result back; there is
+                # no batching or worker dispatch on this path.
+                result = func(*args, **kwargs)
+                self._completed_tasks += 1
+                yield result
+        except BaseException as e:
+            self._exception = e
+            self._aborting = True
+            raise
def _reset_run_tracking(self):
"""Reset the counters and flags used to track the execution."""
- pass
+ self._aborting = False
+ self._exception = None
+ self._iterating = False
+ self._completed_tasks = 0
def __call__(self, iterable):
"""Main function to dispatch parallel tasks."""
diff --git a/joblib/pool.py b/joblib/pool.py
index a5c2643..71628a3 100644
--- a/joblib/pool.py
+++ b/joblib/pool.py
@@ -61,7 +61,10 @@ class CustomizablePickler(Pickler):
def register(self, type, reduce_func):
"""Attach a reducer function to a given type in the dispatch table."""
- pass
+ if hasattr(self, 'dispatch'):
+ self.dispatch[type] = reduce_func
+ else:
+ self.dispatch_table[type] = reduce_func
class CustomizablePicklingQueue(object):
@@ -101,6 +104,46 @@ class CustomizablePicklingQueue(object):
) = state
self._make_methods()
+ def _make_methods(self):
+ self.put = self._make_put()
+ self.get = self._make_get()
+ self.empty = self._make_empty()
+
+ def _make_put(self):
+ def put(obj):
+ buffer = BytesIO()
+ CustomizablePickler(buffer, self._reducers).dump(obj)
+ self._writer.send_bytes(buffer.getvalue())
+ if self._wlock is None:
+ return put
+ else:
+ def locked_put(obj):
+ with self._wlock:
+ return put(obj)
+ return locked_put
+
+ def _make_get(self):
+        import pickle
+
+        def get():
+            # CustomizablePickler has no 'loads'; the custom reducers only
+            # matter on the sending side, so plain pickle can deserialize
+            # what 'put' wrote with send_bytes.
+            return pickle.loads(self._reader.recv_bytes())
+ if self._rlock is None:
+ return get
+ else:
+ def locked_get():
+ with self._rlock:
+ return get()
+ return locked_get
+
+ def _make_empty(self):
+ def empty():
+ return not self._reader.poll()
+ if self._rlock is None:
+ return empty
+ else:
+ def locked_empty():
+ with self._rlock:
+ return empty()
+ return locked_empty
+
class PicklingPool(Pool):
"""Pool implementation with customizable pickling reducers.
diff --git a/joblib/testing.py b/joblib/testing.py
index e20431c..af4e267 100644
--- a/joblib/testing.py
+++ b/joblib/testing.py
@@ -23,11 +23,15 @@ param = pytest.param
def warnings_to_stdout():
""" Redirect all warnings to stdout.
"""
- pass
+ def custom_formatwarning(message, category, filename, lineno, line=None):
+ return f"{filename}:{lineno}: {category.__name__}: {message}\n"
+
+ warnings.formatwarning = custom_formatwarning
+ warnings.simplefilter("always")
+    def _show_warning(*args, **kwargs):
+        sys.stdout.write(warnings.formatwarning(*args, **kwargs))
+
+    warnings.showwarning = _show_warning
-def check_subprocess_call(cmd, timeout=5, stdout_regex=None, stderr_regex=None
- ):
+def check_subprocess_call(cmd, timeout=5, stdout_regex=None, stderr_regex=None):
"""Runs a command in a subprocess with timeout in seconds.
A SIGTERM is sent after `timeout` and if it does not terminate, a
@@ -36,4 +40,27 @@ def check_subprocess_call(cmd, timeout=5, stdout_regex=None, stderr_regex=None
Also checks returncode is zero, stdout if stdout_regex is set, and
stderr if stderr_regex is set.
"""
- pass
+    process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
+                               stderr=subprocess.PIPE,
+                               universal_newlines=True)
+
+ try:
+ stdout, stderr = process.communicate(timeout=timeout)
+ except subprocess.TimeoutExpired:
+ process.terminate()
+ try:
+ stdout, stderr = process.communicate(timeout=timeout)
+ except subprocess.TimeoutExpired:
+ process.kill()
+ stdout, stderr = process.communicate()
+
+ returncode = process.returncode
+
+ if returncode != 0:
+ raise subprocess.CalledProcessError(returncode, cmd, stdout, stderr)
+
+ if stdout_regex and not re.search(stdout_regex, stdout):
+ raise AssertionError(f"Stdout did not match regex: {stdout_regex}\nStdout: {stdout}")
+
+ if stderr_regex and not re.search(stderr_regex, stderr):
+ raise AssertionError(f"Stderr did not match regex: {stderr_regex}\nStderr: {stderr}")
+
+ return stdout, stderr
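Example invocation of the helper above, written against the filled-in version, which returns (stdout, stderr); the command is illustrative:

    import sys
    from joblib.testing import check_subprocess_call

    stdout, stderr = check_subprocess_call(
        [sys.executable, '-c', 'print("ok")'],
        timeout=10, stdout_regex=r'ok')
    print(stdout.strip())  # 'ok'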