Skip to content

back to SWE-Agent summary

SWE-Agent: joblib

Failed to run pytests for test test

ImportError while loading conftest '/testbed/conftest.py'.
conftest.py:9: in <module>
    from joblib.parallel import mp
joblib/__init__.py:114: in <module>
    from .memory import Memory
joblib/memory.py:21: in <module>
    from . import hashing
joblib/hashing.py:31: in <module>
    class Hasher(Pickler):
joblib/hashing.py:42: in Hasher
    dispatch[type(len)] = save_global
E   NameError: name 'save_global' is not defined

Patch diff

diff --git a/joblib/parallel.py b/joblib/parallel.py
index 77395c4..b95dbb2 100644
--- a/joblib/parallel.py
+++ b/joblib/parallel.py
@@ -38,7 +38,12 @@ _backend = threading.local()

 def _register_dask():
     """Register Dask Backend if called with parallel_config(backend="dask")"""
-    pass
+    try:
+        from dask.distributed import Client
+        from ._dask import DaskDistributedBackend
+        register_parallel_backend('dask', DaskDistributedBackend)
+    except ImportError:
+        raise ImportError("To use the dask backend, you need to install 'dask' and 'distributed' libraries.")
 EXTERNAL_BACKENDS = {'dask': _register_dask}
 default_parallel_config = {'backend': _Sentinel(default_value=None), 'n_jobs': _Sentinel(default_value=None), 'verbose': _Sentinel(default_value=0), 'temp_folder': _Sentinel(default_value=None), 'max_nbytes': _Sentinel(default_value='1M'), 'mmap_mode': _Sentinel(default_value='r'), 'prefer': _Sentinel(default_value=None), 'require': _Sentinel(default_value=None)}
 VALID_BACKEND_HINTS = ('processes', 'threads', None)
@@ -50,15 +55,46 @@ def _get_config_param(param, context_config, key):
     Explicitly setting it in Parallel has priority over setting in a
     parallel_(config/backend) context manager.
     """
-    pass
+    if param is not default_parallel_config[key]:
+        # Explicitly passed to Parallel: takes priority over any context.
+        return param
+    if context_config is not None and key in context_config:
+        return context_config[key]
+    # Fall back to the sentinel's wrapped default, not the sentinel itself.
+    return default_parallel_config[key].default_value

 def get_active_backend(prefer=default_parallel_config['prefer'], require=default_parallel_config['require'], verbose=default_parallel_config['verbose']):
     """Return the active default backend"""
-    pass
+    return _get_active_backend(prefer, require, verbose)[0]

 def _get_active_backend(prefer=default_parallel_config['prefer'], require=default_parallel_config['require'], verbose=default_parallel_config['verbose']):
     """Return the active default backend"""
-    pass
+    config = getattr(_backend, 'config', default_parallel_config)
+    backend = config.get('backend', None)
+    
+    if backend is None:
+        if require is not None:
+            if require not in VALID_BACKEND_CONSTRAINTS:
+                raise ValueError(f"Invalid 'require' parameter: {require}")
+            if require == 'sharedmem':
+                backend = ThreadingBackend()
+            else:
+                backend = LokyBackend() if 'loky' in BACKENDS else ThreadingBackend()
+        elif prefer is not None:
+            if prefer not in VALID_BACKEND_HINTS:
+                raise ValueError(f"Invalid 'prefer' parameter: {prefer}")
+            if prefer == 'processes':
+                backend = LokyBackend() if 'loky' in BACKENDS else ThreadingBackend()
+            elif prefer == 'threads':
+                backend = ThreadingBackend()
+            else:
+                backend = LokyBackend() if 'loky' in BACKENDS else ThreadingBackend()
+        else:
+            backend = BACKENDS[DEFAULT_BACKEND]()
+    
+    # verbose may still be the _Sentinel default here; only compare ints.
+    if isinstance(verbose, int) and verbose > 0:
+        print(f"Using {backend.__class__.__name__} as parallel backend")
+    
+    return backend, config

 class parallel_config:
     """Set the default backend or configuration for :class:`~joblib.Parallel`.
@@ -351,7 +387,7 @@ def cpu_count(only_physical_cores=False):
     If only_physical_cores is True, do not take hyperthreading / SMT logical
     cores into account.
     """
-    pass
+    return loky.cpu_count(only_physical_cores=only_physical_cores)

 def _verbosity_filter(index, verbose):
     """ Returns False for indices increasingly apart, the distance
@@ -359,11 +395,16 @@ def _verbosity_filter(index, verbose):

         We use a lag increasing as the square of index
     """
-    pass
+    if not verbose:
+        return True
+    if verbose > 10 or index == 0:
+        return False
+    # Higher verbosity => shorter lag between printed messages; the lag
+    # grows with sqrt(index) so output thins out as the run progresses.
+    lag = max(1, int(sqrt(index) * 10 / verbose))
+    return index % lag != 0

 def delayed(function):
     """Decorator used to capture the arguments of a function."""
-    pass
+    @functools.wraps(function)
+    def delayed_function(*args, **kwargs):
+        # joblib contract: delayed(f)(*a, **kw) returns the picklable
+        # triple (f, a, kw) — not a closure — so workers can unpack it.
+        return function, args, kwargs
+    return delayed_function

 class BatchCompletionCallBack(object):
     """Callback to keep track of completed results and schedule the next tasks.
@@ -393,7 +434,7 @@ class BatchCompletionCallBack(object):

     def register_job(self, job):
         """Register the object returned by `apply_async`."""
-        pass
+        self.job = job

     def get_result(self, timeout):
         """Returns the raw result of the task that was submitted.
@@ -412,7 +453,13 @@ class BatchCompletionCallBack(object):
         return it or raise. It will block at most `self.timeout` seconds
         waiting for retrieval to complete, after that it raises a TimeoutError.
         """
-        pass
+        if self.parallel._backend.supports_retrieve_callback:
+            # Result was (or will be) stored by the callback thread.
+            if self.get_status(timeout) == TASK_PENDING:
+                raise TimeoutError("Timeout reached while waiting for result")
+            return self.result
+        else:
+            return self.job.get(timeout=timeout)

     def get_status(self, timeout):
         """Get the status of the task.
@@ -420,7 +467,14 @@ class BatchCompletionCallBack(object):
         This function also checks if the timeout has been reached and register
         the TimeoutError outcome when it is the case.
         """
-        pass
+        if self.parallel._backend.supports_retrieve_callback:
+            return self.status
+        else:
+            try:
+                self.job.wait(timeout=timeout)
+                return TASK_DONE
+            except TimeoutError:
+                return TASK_PENDING

     def __call__(self, out):
         """Function called by the callback thread after a job is completed."""
@@ -440,7 +494,12 @@ class BatchCompletionCallBack(object):

     def _dispatch_new(self):
         """Schedule the next batch of tasks to be processed."""
-        pass
+        with self.parallel._lock:
+            if self.parallel._aborting:
+                return
+            if self.parallel._original_iterator is None:
+                return
+            self.parallel.dispatch_next()

     def _retrieve_result(self, out):
         """Fetch and register the outcome of a task.
@@ -449,14 +508,27 @@ class BatchCompletionCallBack(object):
         This function is only called by backends that support retrieving
         the task result in the callback thread.
         """
-        pass
+        success, result = out
+        self.result = result
+        if success:
+            self.status = TASK_DONE
+            return True
+        else:
+            self.status = TASK_ERROR
+            return False

     def _register_outcome(self, outcome):
         """Register the outcome of a task.

         This method can be called only once, future calls will be ignored.
         """
-        pass
+        if self.status is not TASK_PENDING:
+            return
+        self.result = outcome
+        if isinstance(outcome, BaseException):
+            self.status = TASK_ERROR
+        else:
+            self.status = TASK_DONE

 def register_parallel_backend(name, factory, make_default=False):
     """Register a new Parallel backend factory.
@@ -473,7 +545,10 @@ def register_parallel_backend(name, factory, make_default=False):

     .. versionadded:: 0.10
     """
-    pass
+    BACKENDS[name] = factory
+    if make_default:
+        global DEFAULT_BACKEND
+        DEFAULT_BACKEND = name

 def effective_n_jobs(n_jobs=-1):
     """Determine the number of jobs that can actually run in parallel
@@ -496,7 +571,11 @@ def effective_n_jobs(n_jobs=-1):

     .. versionadded:: 0.10
     """
-    pass
+    if n_jobs == 0:
+        raise ValueError('n_jobs == 0 in Parallel has no meaning')
+    elif n_jobs < 0:
+        n_jobs = max(cpu_count() + 1 + n_jobs, 1)
+    return n_jobs

 class Parallel(Logger):
     """ Helper class for readable parallel mapping.
@@ -832,16 +911,33 @@ class Parallel(Logger):

     def _initialize_backend(self):
         """Build a process or thread pool and return the number of workers"""
-        pass
+        if self._backend is None:
+            self._backend, _ = _get_active_backend(
+                prefer=self._backend_args['prefer'],
+                require=self._backend_args['require'],
+                verbose=self._backend_args['verbose']
+            )
+        
+        n_jobs = self.n_jobs
+        if n_jobs == 0:
+            raise ValueError('n_jobs == 0 in Parallel has no meaning')
+        elif n_jobs < 0:
+            n_jobs = max(mp.cpu_count() + 1 + n_jobs, 1)
+        
+        self._n_jobs = n_jobs
+        self._backend.configure(n_jobs=self._n_jobs, parallel=self, **self._backend_args)
+        return n_jobs

     def _dispatch(self, batch):
         """Queue the batch for computing, with or without multiprocessing

         WARNING: this method is not thread-safe: it should be only called
         indirectly via dispatch_one_batch.
-
         """
-        pass
+        # Advance the dispatch counters read by _is_completed/_get_outputs.
+        self._n_dispatched_tasks += len(batch)
+        self._n_dispatched_batches += 1
+        if self._backend.supports_retrieve_callback:
+            self._backend.apply_async(batch, callback=self._callback)
+        else:
+            self._pending_outputs.append(self._backend.apply_async(batch))

     def dispatch_next(self):
         """Dispatch more data for parallel processing
@@ -849,9 +945,12 @@ class Parallel(Logger):
         This method is meant to be called concurrently by the multiprocessing
         callback. We rely on the thread-safety of dispatch_one_batch to protect
         against concurrent consumption of the unprotected iterator.
-
         """
-        pass
+        if self.dispatch_one_batch(self._original_iterator):
+            return True
+        else:
+            self._iterating = False
+            return False

     def dispatch_one_batch(self, iterator):
         """Prefetch the tasks for the next batch and dispatch them.
@@ -863,41 +962,69 @@ class Parallel(Logger):
         lock so calling this function should be thread safe.

         """
-        pass
+        with self._lock:
+            batch = list(itertools.islice(iterator, self._get_batch_size()))
+            if len(batch) == 0:
+                # Iterator exhausted: signal the caller to stop dispatching.
+                return False
+            self._dispatch(batch)
+            return True
 
     def _get_batch_size(self):
         """Returns the effective batch size for dispatch"""
-        pass
+        if self.batch_size == 'auto':
+            # Let the backend adapt the batch size to observed task duration.
+            return self._backend.compute_batch_size()
+        return self.batch_size

     def _print(self, msg):
-        """Display the message on stout or stderr depending on verbosity"""
-        pass
+        """Display the message on stdout or stderr depending on verbosity"""
+        if self.verbose > 10:
+            print(msg, file=sys.stderr)
+        elif self.verbose:
+            print(msg)

     def _is_completed(self):
         """Check if all tasks have been completed"""
-        pass
+        return self._n_completed_tasks == self._n_dispatched_tasks

     def print_progress(self):
-        """Display the process of the parallel execution only a fraction
+        """Display the progress of the parallel execution only a fraction
            of time, controlled by self.verbose.
         """
-        pass
+        if not self.verbose:
+            return
+        elapsed_time = time.time() - self._start_time
+
+        # This is heuristic code to print only 'verbose' times a message
+        if self._n_completed_tasks > self._n_dispatched_tasks // self.verbose:
+            # We are done dispatching
+            self._print('Done %3i tasks      | elapsed: %s' %
+                        (self._n_completed_tasks,
+                         short_format_time(elapsed_time)))
+        else:
+            # We are not done dispatching
+            self._print('Completed %3i tasks | elapsed: %s' %
+                        (self._n_completed_tasks,
+                         short_format_time(elapsed_time)))

     def _get_outputs(self, iterator, pre_dispatch):
         """Iterator returning the tasks' output as soon as they are ready."""
-        pass
+        while self._iterating or self._n_completed_tasks < self._n_dispatched_tasks:
+            if self._n_completed_tasks < self._n_dispatched_tasks:
+                yield self._backend.retrieve()
+                self._n_completed_tasks += 1
+                self.print_progress()
+            else:
+                if self._iterating:
+                    self.dispatch_next()
+                else:
+                    break
+            if self._aborting:
+                self._raise_error_fast()

     def _wait_retrieval(self):
         """Return True if we need to continue retrieving some tasks."""
-        pass
+        return (self._n_completed_tasks < self._n_dispatched_tasks or
+                not self._backend.supports_retrieve_callback)

     def _raise_error_fast(self):
         """If we are aborting, raise if a job caused an error."""
-        pass
+        if self._aborting:
+            error = self._backend.get_first_exception()
+            if error is not None:
+                raise error

     def _warn_exit_early(self):
-        """Warn the user if the generator is gc'ed before being consumned."""
-        pass
+        """Warn the user if the generator is gc'ed before being consumed."""
+        warnings.warn("Parallel generator was garbage collected before being fully consumed. "
+                      "Some tasks may not have been executed.", UserWarning)

     def _get_sequential_output(self, iterable):
         """Separate loop for sequential output.
@@ -905,11 +1032,20 @@ class Parallel(Logger):
         This simplifies the traceback in case of errors and reduces the
         overhead of calling sequential tasks with `joblib`.
         """
-        pass
+        self._start_time = time.time()
+        for i, func in enumerate(iterable):
+            self._print('Done %3i tasks | elapsed: %s' %
+                        (i, short_format_time(time.time() - self._start_time)))
+            yield func()
+            self._n_completed_tasks += 1

     def _reset_run_tracking(self):
         """Reset the counters and flags used to track the execution."""
-        pass
+        self._n_completed_tasks = 0
+        self._n_dispatched_tasks = 0
+        self._n_dispatched_batches = 0
+        # Containers/timestamps read by _dispatch and print_progress.
+        self._pending_outputs = []
+        self._start_time = time.time()
+        self._aborting = False
+        self._iterating = True

     def __call__(self, iterable):
         """Main function to dispatch parallel tasks."""
@@ -957,4 +1093,4 @@ class Parallel(Logger):
         return output if self.return_generator else list(output)

     def __repr__(self):
-        return '%s(n_jobs=%s)' % (self.__class__.__name__, self.n_jobs)
\ No newline at end of file
+        return '%s(n_jobs=%s)' % (self.__class__.__name__, self.n_jobs)