diff --git a/imblearn/pipeline.py b/imblearn/pipeline.py
index 9bf26e5..f18c970 100644
--- a/imblearn/pipeline.py
+++ b/imblearn/pipeline.py
@@ -129,6 +129,14 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
"""
_parameter_constraints: dict = {'steps': 'no_validation', 'memory': [None, str, HasMethods(['cache'])], 'verbose': ['boolean']}
+ def _can_fit_transform(self):
+ """Check if the pipeline can fit_transform."""
+ final = self._final_estimator
+ return final == "passthrough" or hasattr(final, "fit_transform") or (hasattr(final, "fit") and hasattr(final, "transform"))
+
+ def _can_fit_resample(self):
+ """Check if the pipeline can fit_resample."""
+ final = self._final_estimator
+ return final == "passthrough" or hasattr(final, "fit_resample")
+
def _iter(self, with_final=True, filter_passthrough=True, filter_resample=True):
"""Generate (idx, (name, trans)) tuples from self.steps.
@@ -136,7 +144,11 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
transformers are filtered out. When `filter_resample` is `True`,
estimators with a `fit_resample` method are filtered out.
"""
- pass
+ stop = len(self.steps) if with_final else -1
+ for idx, (name, trans) in enumerate(self.steps[:stop]):
+ if not filter_passthrough or trans not in ('passthrough', None):
+ if not filter_resample or not hasattr(trans, 'fit_resample'):
+ yield idx, name, trans
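+
+ # Illustration (hypothetical step names): with steps
+ # [("rus", RandomUnderSampler()), ("scale", StandardScaler()), ("clf", SVC())],
+ # the defaults yield (1, "scale", scale) and (2, "clf", clf); passing
+ # filter_resample=False would also yield (0, "rus", rus).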
@_fit_context(prefer_skip_nested_validation=False)
def fit(self, X, y=None, **params):
@@ -183,7 +195,46 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
self : Pipeline
This estimator.
"""
- pass
+ if _routing_enabled():
+ routed_params = process_routing(self, "fit", **params)
+ else:
+ routed_params = Bunch()
+ routed_params.fit = self._check_method_params(method="fit", **params)
+ routed_params.transform = self._check_method_params(method="transform", **params)
+
+ Xt = X
+ yt = y
+ self._memory = check_memory(self.memory)
+
+ for step_idx, name, transformer in self._iter(with_final=False, filter_passthrough=False, filter_resample=False):
+ if transformer is None or transformer == "passthrough":
+ continue
+
+ cloned_transformer = clone(transformer)
+ if hasattr(cloned_transformer, "fit_resample"):
+ # Samplers resample X and y jointly; they act at fit time only.
+ Xt, yt = cloned_transformer.fit_resample(Xt, yt, **routed_params.fit.get(name, {}))
+ elif hasattr(cloned_transformer, "fit_transform"):
+ Xt = cloned_transformer.fit_transform(Xt, yt, **routed_params.fit.get(name, {}))
+ else:
+ cloned_transformer.fit(Xt, yt, **routed_params.fit.get(name, {}))
+ if hasattr(cloned_transformer, "transform"):
+ Xt = cloned_transformer.transform(Xt, **routed_params.transform.get(name, {}))
+
+ self.steps[step_idx] = (name, cloned_transformer)
+
+ if self._final_estimator != "passthrough":
+ fit_params_last_step = routed_params.fit.get(self.steps[-1][0], {})
+ self._final_estimator.fit(Xt, yt, **fit_params_last_step)
+
+ return self
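+
+ # Usage sketch (illustrative estimator names): the final classifier is
+ # fitted on the resampled data produced by the intermediate sampler:
+ # pipe = Pipeline([("smote", SMOTE()), ("clf", LogisticRegression())])
+ # pipe.fit(X, y)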
@available_if(_can_fit_transform)
@_fit_context(prefer_skip_nested_validation=False)
@@ -230,7 +281,50 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
Xt : array-like of shape (n_samples, n_transformed_features)
Transformed samples.
"""
- pass
+ if _routing_enabled():
+ routed_params = process_routing(self, "fit", **params)
+ else:
+ routed_params = Bunch()
+ routed_params.fit = self._check_method_params(method="fit", **params)
+ routed_params.transform = self._check_method_params(method="transform", **params)
+
+ Xt = X
+ yt = y
+ self._memory = check_memory(self.memory)
+
+ for step_idx, name, transformer in self._iter(with_final=False, filter_passthrough=False, filter_resample=False):
+ if transformer is None or transformer == "passthrough":
+ continue
+
+ cloned_transformer = clone(transformer)
+ if hasattr(cloned_transformer, "fit_resample"):
+ Xt, yt = cloned_transformer.fit_resample(Xt, yt, **routed_params.fit.get(name, {}))
+ elif hasattr(cloned_transformer, "fit_transform"):
+ Xt = cloned_transformer.fit_transform(Xt, yt, **routed_params.fit.get(name, {}))
+ else:
+ cloned_transformer.fit(Xt, yt, **routed_params.fit.get(name, {}))
+ if hasattr(cloned_transformer, "transform"):
+ Xt = cloned_transformer.transform(Xt, **routed_params.transform.get(name, {}))
+
+ self.steps[step_idx] = (name, cloned_transformer)
+
+ if self._final_estimator == "passthrough":
+ return Xt
+
+ fit_params_last_step = routed_params.fit.get(self.steps[-1][0], {})
+ if hasattr(self._final_estimator, "fit_transform"):
+ return self._final_estimator.fit_transform(Xt, yt, **fit_params_last_step)
+ else:
+ self._final_estimator.fit(Xt, yt, **fit_params_last_step)
+ return self._final_estimator.transform(Xt, **routed_params.transform.get(self.steps[-1][0], {}))
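+
+ # Note: samplers change the number of rows during fit, so the returned Xt
+ # may have a different n_samples than the input X.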
@available_if(pipeline._final_estimator_has('predict'))
def predict(self, X, **params):
@@ -279,7 +373,22 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
y_pred : ndarray
Result of calling `predict` on the final estimator.
"""
- pass
+ if _routing_enabled():
+ routed_params = process_routing(self, "predict", **params)
+ else:
+ routed_params = Bunch()
+ routed_params.predict = self._check_method_params(method="predict", **params)
+ routed_params.transform = self._check_method_params(method="transform", **params)
+
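+ # Samplers are excluded here by `_iter`'s default filter_resample=True:
+ # resampling happens only at fit time, never at prediction time.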
+ Xt = X
+ for _, name, transform in self._iter(with_final=False, filter_passthrough=False):
+ if transform is None or transform == "passthrough":
+ continue
+ if hasattr(transform, "transform"):
+ Xt = transform.transform(Xt, **routed_params.transform.get(name, {}))
+
+ predict_params = routed_params.predict.get(self.steps[-1][0], {})
+ return self.steps[-1][1].predict(Xt, **predict_params)
@available_if(_can_fit_resample)
@_fit_context(prefer_skip_nested_validation=False)
@@ -329,7 +438,47 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
yt : array-like of shape (n_samples,)
The resampled target.
"""
- pass
+ if _routing_enabled():
+ routed_params = process_routing(self, "fit", **params)
+ else:
+ routed_params = Bunch()
+ routed_params.fit = self._check_method_params(method="fit", **params)
+ routed_params.transform = self._check_method_params(method="transform", **params)
+
+ Xt = X
+ yt = y
+ self._memory = check_memory(self.memory)
+
+ for step_idx, name, transformer in self._iter(with_final=False, filter_passthrough=False, filter_resample=False):
+ if transformer is None or transformer == "passthrough":
+ continue
+
+ cloned_transformer = clone(transformer)
+ if hasattr(cloned_transformer, "fit_resample"):
+ Xt, yt = cloned_transformer.fit_resample(Xt, yt, **routed_params.fit.get(name, {}))
+ elif hasattr(cloned_transformer, "fit_transform"):
+ Xt = cloned_transformer.fit_transform(Xt, yt, **routed_params.fit.get(name, {}))
+ else:
+ cloned_transformer.fit(Xt, yt, **routed_params.fit.get(name, {}))
+ if hasattr(cloned_transformer, "transform"):
+ Xt = cloned_transformer.transform(Xt, **routed_params.transform.get(name, {}))
+
+ self.steps[step_idx] = (name, cloned_transformer)
+
+ if self._final_estimator == "passthrough":
+ return Xt, yt
+
+ fit_params_last_step = routed_params.fit.get(self.steps[-1][0], {})
+ Xt, yt = self._final_estimator.fit_resample(Xt, yt, **fit_params_last_step)
+ return Xt, yt
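+
+ # Usage sketch (illustrative step names), with a sampler as the final step:
+ # pipe = Pipeline([("scale", StandardScaler()), ("rus", RandomUnderSampler())])
+ # X_res, y_res = pipe.fit_resample(X, y)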
@available_if(pipeline._final_estimator_has('fit_predict'))
@_fit_context(prefer_skip_nested_validation=False)
@@ -382,7 +531,43 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
y_pred : ndarray of shape (n_samples,)
The predicted target.
"""
- pass
+ if _routing_enabled():
+ routed_params = process_routing(self, "fit", **params)
+ else:
+ routed_params = Bunch()
+ routed_params.fit = self._check_method_params(method="fit", **params)
+ routed_params.transform = self._check_method_params(method="transform", **params)
+
+ Xt = X
+ yt = y
+ self._memory = check_memory(self.memory)
+
+ for step_idx, name, transformer in self._iter(with_final=False, filter_passthrough=False, filter_resample=False):
+ if transformer is None or transformer == "passthrough":
+ continue
+
+ cloned_transformer = clone(transformer)
+ if hasattr(cloned_transformer, "fit_resample"):
+ Xt, yt = cloned_transformer.fit_resample(Xt, yt, **routed_params.fit.get(name, {}))
+ elif hasattr(cloned_transformer, "fit_transform"):
+ Xt = cloned_transformer.fit_transform(Xt, yt, **routed_params.fit.get(name, {}))
+ else:
+ cloned_transformer.fit(Xt, yt, **routed_params.fit.get(name, {}))
+ if hasattr(cloned_transformer, "transform"):
+ Xt = cloned_transformer.transform(Xt, **routed_params.transform.get(name, {}))
+
+ self.steps[step_idx] = (name, cloned_transformer)
+
+ fit_params_last_step = routed_params.fit.get(self.steps[-1][0], {})
+ return self.steps[-1][1].fit_predict(Xt, yt, **fit_params_last_step)
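+
+ # Note: the predictions returned here correspond to the resampled Xt, so
+ # their length can differ from len(X) when the pipeline contains samplers.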
@available_if(pipeline._final_estimator_has('predict_proba'))
def predict_proba(self, X, **params):
@@ -426,7 +611,22 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
y_proba : ndarray of shape (n_samples, n_classes)
Result of calling `predict_proba` on the final estimator.
"""
- pass
+ if _routing_enabled():
+ routed_params = process_routing(self, "predict_proba", **params)
+ else:
+ routed_params = Bunch()
+ routed_params.predict_proba = self._check_method_params(method="predict_proba", **params)
+ routed_params.transform = self._check_method_params(method="transform", **params)
+
+ Xt = X
+ for _, name, transform in self._iter(with_final=False, filter_passthrough=False):
+ if transform is None or transform == "passthrough":
+ continue
+ if hasattr(transform, "transform"):
+ Xt = transform.transform(Xt, **routed_params.transform.get(name, {}))
+
+ predict_proba_params = routed_params.predict_proba.get(self.steps[-1][0], {})
+ return self.steps[-1][1].predict_proba(Xt, **predict_proba_params)
@available_if(pipeline._final_estimator_has('decision_function'))
def decision_function(self, X, **params):
@@ -458,7 +658,22 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
y_score : ndarray of shape (n_samples, n_classes)
Result of calling `decision_function` on the final estimator.
"""
- pass
+ if _routing_enabled():
+ routed_params = process_routing(self, "decision_function", **params)
+ else:
+ routed_params = Bunch()
+ routed_params.decision_function = self._check_method_params(method="decision_function", **params)
+ routed_params.transform = self._check_method_params(method="transform", **params)
+
+ Xt = X
+ for _, name, transform in self._iter(with_final=False, filter_passthrough=False):
+ if transform is None or transform == "passthrough":
+ continue
+ if hasattr(transform, "transform"):
+ Xt = transform.transform(Xt, **routed_params.transform.get(name, {}))
+
+ decision_function_params = routed_params.decision_function.get(self.steps[-1][0], {})
+ return self.steps[-1][1].decision_function(Xt, **decision_function_params)
@available_if(pipeline._final_estimator_has('score_samples'))
def score_samples(self, X):
@@ -480,7 +695,13 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
y_score : ndarray of shape (n_samples,)
Result of calling `score_samples` on the final estimator.
"""
- pass
+ Xt = X
+ for _, _, transform in self._iter(with_final=False, filter_passthrough=False):
+ if transform is None or transform == "passthrough":
+ continue
+ if hasattr(transform, "transform"):
+ Xt = transform.transform(Xt)
+ return self.steps[-1][1].score_samples(Xt)
@available_if(pipeline._final_estimator_has('predict_log_proba'))
def predict_log_proba(self, X, **params):
@@ -524,7 +745,22 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
y_log_proba : ndarray of shape (n_samples, n_classes)
Result of calling `predict_log_proba` on the final estimator.
"""
- pass
+ if _routing_enabled():
+ routed_params = process_routing(self, "predict_log_proba", **params)
+ else:
+ routed_params = Bunch()
+ routed_params.predict_log_proba = self._check_method_params(method="predict_log_proba", **params)
+ routed_params.transform = self._check_method_params(method="transform", **params)
+
+ Xt = X
+ for _, name, transform in self._iter(with_final=False, filter_passthrough=False):
+ if transform is None or transform == "passthrough":
+ continue
+ if hasattr(transform, "transform"):
+ Xt = transform.transform(Xt, **routed_params.transform.get(name, {}))
+
+ predict_log_proba_params = routed_params.predict_log_proba.get(self.steps[-1][0], {})
+ return self.steps[-1][1].predict_log_proba(Xt, **predict_log_proba_params)
@available_if(_can_transform)
def transform(self, X, **params):
@@ -559,7 +795,19 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
Xt : ndarray of shape (n_samples, n_transformed_features)
Transformed data.
"""
- pass
+ if _routing_enabled():
+ routed_params = process_routing(self, "transform", **params)
+ else:
+ routed_params = Bunch()
+ routed_params.transform = self._check_method_params(method="transform", **params)
+
+ Xt = X
+ for _, name, transform in self._iter(with_final=True, filter_passthrough=False):
+ if transform is None or transform == "passthrough":
+ continue
+ if hasattr(transform, "transform"):
+ Xt = transform.transform(Xt, **routed_params.transform.get(name, {}))
+ return Xt
@available_if(_can_inverse_transform)
def inverse_transform(self, Xt, **params):
@@ -591,7 +839,18 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
Inverse transformed data, that is, data in the original feature
space.
"""
- pass
+ if _routing_enabled():
+ routed_params = process_routing(self, "inverse_transform", **params)
+ else:
+ routed_params = Bunch()
+ routed_params.inverse_transform = self._check_method_params(method="inverse_transform", **params)
+
+ for _, name, transform in reversed(list(self._iter(with_final=True, filter_passthrough=False))):
+ if transform is None or transform == "passthrough":
+ continue
+ if hasattr(transform, "inverse_transform"):
+ Xt = transform.inverse_transform(Xt, **routed_params.inverse_transform.get(name, {}))
+ return Xt
@available_if(pipeline._final_estimator_has('score'))
def score(self, X, y=None, sample_weight=None, **params):
@@ -630,7 +889,24 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
score : float
Result of calling `score` on the final estimator.
"""
- pass
+ if _routing_enabled():
+ routed_params = process_routing(self, "score", **params)
+ else:
+ routed_params = Bunch()
+ routed_params.score = self._check_method_params(method="score", **params)
+ routed_params.transform = self._check_method_params(method="transform", **params)
+
+ Xt = X
+ for _, name, transform in self._iter(with_final=False, filter_passthrough=False):
+ if transform is None or transform == "passthrough":
+ continue
+ if hasattr(transform, "transform"):
+ Xt = transform.transform(Xt, **routed_params.transform.get(name, {}))
+
+ score_params = routed_params.score.get(self.steps[-1][0], {})
+ if sample_weight is not None:
+ score_params = {**score_params, "sample_weight": sample_weight}
+ return self.steps[-1][1].score(Xt, y, **score_params)
def get_metadata_routing(self):
"""Get metadata routing of this object.
diff --git a/imblearn/utils/_metadata_requests.py b/imblearn/utils/_metadata_requests.py
index 494f911..46f897f 100644
--- a/imblearn/utils/_metadata_requests.py
+++ b/imblearn/utils/_metadata_requests.py
@@ -387,7 +387,9 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
common = existing & upcoming
conflicts = [key for key in common if requests[key] != mmr._requests[key]]
if conflicts:
- raise ValueError(f'Conflicting metadata requests for {', '.join(conflicts)} while composing the requests for {name}. Metadata with the same name for methods {', '.join(COMPOSITE_METHODS[name])} should have the same request value.')
+ conflicts_str = ', '.join(conflicts)
+ methods_str = ', '.join(COMPOSITE_METHODS[name])
+ raise ValueError(f'Conflicting metadata requests for {conflicts_str} while composing the requests for {name}. Metadata with the same name for methods {methods_str} should have the same request value.')
requests.update(mmr._requests)
return MethodMetadataRequest(owner=self.owner, method=name, requests=requests)
@@ -882,6 +884,7 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
.. versionadded:: 1.3
"""
if TYPE_CHECKING:
+ pass
def __init_subclass__(cls, **kwargs):
"""Set the ``set_{method}_request`` methods.
diff --git a/imblearn/utils/_param_validation.py b/imblearn/utils/_param_validation.py
index 25f6aa4..5ecbde1 100644
--- a/imblearn/utils/_param_validation.py
+++ b/imblearn/utils/_param_validation.py
@@ -59,7 +59,39 @@ if sklearn_version < parse_version('1.4'):
caller_name : str
The name of the estimator or function or method that called this function.
"""
- pass
+ if parameter_constraints == "no_validation":
+ return
+
+ for param_name, param_val in params.items():
+ # Each parameter can be validated against a list of constraints.
+ # We keep track of the exceptions raised by each validator so we can
+ # give a meaningful message if all of them failed.
+ if param_name not in parameter_constraints:
+ continue
+
+ constraints = parameter_constraints[param_name]
+ if constraints == "no_validation":
+ # Per-parameter opt-out, e.g. `steps` in `_parameter_constraints`.
+ continue
+ constraints = [make_constraint(constraint) for constraint in constraints]
+
+ exceptions_raised = []
+ for constraint in constraints:
+ try:
+ if constraint.is_satisfied_by(param_val):
+ # We found a valid constraint
+ break
+ except Exception as e:
+ exceptions_raised.append(e)
+ else:
+ # No constraint accepted the value; build a readable error message.
+ if len(constraints) == 1:
+ constraints_str = str(constraints[0])
+ else:
+ constraints_str = f"{', '.join(str(c) for c in constraints[:-1])} or {constraints[-1]}"
+ error_msg = (
+ f"The {param_name!r} parameter of {caller_name} must be "
+ f"{constraints_str}. Got {param_val!r} instead."
+ )
+ if exceptions_raised:
+ error_msg += (
+ f" The following errors were raised: {exceptions_raised}"
+ )
+ raise InvalidParameterError(error_msg)
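+
+ # Example (illustrative values): this raises InvalidParameterError because
+ # -1 satisfies neither a non-negative int nor a bool:
+ # validate_parameter_constraints({"verbose": ["verbose"]}, {"verbose": -1}, caller_name="Example")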
def make_constraint(constraint):
"""Convert the constraint into the appropriate Constraint object.
@@ -74,7 +106,35 @@ if sklearn_version < parse_version('1.4'):
constraint : instance of _Constraint
The converted constraint.
"""
- pass
+ if isinstance(constraint, _Constraint):
+ return constraint
+ elif isinstance(constraint, Hidden):
+ # Unwrap hidden constraints so they still validate values while staying
+ # out of user-facing error messages.
+ constraint = make_constraint(constraint.constraint)
+ constraint.hidden = True
+ return constraint
+ elif isinstance(constraint, type):
+ return _InstancesOf(constraint)
+ elif constraint is None:
+ return _NoneConstraint()
+ elif constraint == "array-like":
+ return _ArrayLikes()
+ elif constraint == "sparse matrix":
+ return _SparseMatrices()
+ elif constraint == "random_state":
+ return _RandomStates()
+ elif constraint == "boolean":
+ return _Booleans()
+ elif constraint == "verbose":
+ return _VerboseHelper()
+ elif constraint == "cv_object":
+ return _CVObjects()
+ elif constraint == "nan":
+ return _NanConstraint()
+ elif callable(constraint):
+ return _Callables()
+ else:
+ raise ValueError(
+ f"Unknown constraint type: {constraint}. "
+ "Valid constraints are: instance of _Constraint, type, None, "
+ "'array-like', 'sparse matrix', 'random_state', 'boolean', "
+ "'verbose', 'cv_object', 'nan', or callable."
+ )
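+
+ # For example: make_constraint("boolean") -> _Booleans(),
+ # make_constraint(None) -> _NoneConstraint(), make_constraint(int) -> _InstancesOf(int).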
def validate_params(parameter_constraints, *, prefer_skip_nested_validation):
"""Decorator to validate types and values of functions and methods.
@@ -107,7 +167,33 @@ if sklearn_version < parse_version('1.4'):
decorated_function : function or method
The decorated function.
"""
- pass
+ def decorator(func):
+ # Get the signature of the function to be decorated
+ sig = signature(func)
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ # Honour the global opt-out before doing any work.
+ if get_config()["skip_parameter_validation"]:
+ return func(*args, **kwargs)
+
+ # Map *args and **kwargs onto the signature to get a
+ # parameter-name -> value dict.
+ params = sig.bind(*args, **kwargs)
+ params.apply_defaults()
+ params = params.arguments
+
+ validate_parameter_constraints(
+ parameter_constraints, params, func.__qualname__
+ )
+
+ # Estimators called inside `func` skip their own validation when
+ # `prefer_skip_nested_validation` is True.
+ with config_context(skip_parameter_validation=prefer_skip_nested_validation):
+ return func(*args, **kwargs)
+
+ return wrapper
+
+ return decorator
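+
+ # Usage sketch (hypothetical function):
+ # @validate_params({"n_bins": [Interval(Integral, 1, None, closed="left")]},
+ # prefer_skip_nested_validation=True)
+ # def discretize(X, n_bins=5):
+ # ...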
class RealNotInt(Real):
"""A type that represents reals that are not instances of int.
@@ -120,7 +206,11 @@ if sklearn_version < parse_version('1.4'):
def _type_name(t):
"""Convert type into human readable string."""
- pass
+ module = t.__module__
+ qualname = t.__qualname__
+ if module == "builtins":
+ return qualname
+ return f"{module}.{qualname}"
class _Constraint(ABC):
"""Base class for the constraint objects."""
@@ -142,7 +232,7 @@ if sklearn_version < parse_version('1.4'):
is_satisfied : bool
Whether or not the constraint is satisfied by this value.
"""
- pass
+ raise NotImplementedError
@abstractmethod
def __str__(self):
@@ -161,24 +251,40 @@ if sklearn_version < parse_version('1.4'):
super().__init__()
self.type = type
+ def is_satisfied_by(self, val):
+ return isinstance(val, self.type)
+
def __str__(self):
return f'an instance of {_type_name(self.type)!r}'
class _NoneConstraint(_Constraint):
"""Constraint representing the None singleton."""
+ def is_satisfied_by(self, val):
+ return val is None
+
def __str__(self):
return 'None'
class _NanConstraint(_Constraint):
"""Constraint representing the indicator `np.nan`."""
+ def is_satisfied_by(self, val):
+ return isinstance(val, Real) and math.isnan(val)
+
def __str__(self):
return 'numpy.nan'
class _PandasNAConstraint(_Constraint):
"""Constraint representing the indicator `pd.NA`."""
+ def is_satisfied_by(self, val):
+ try:
+ import pandas as pd
+ return val is pd.NA
+ except ImportError:
+ return False
+
def __str__(self):
return 'pandas.NA'
@@ -207,10 +313,12 @@ if sklearn_version < parse_version('1.4'):
def _mark_if_deprecated(self, option):
"""Add a deprecated mark to an option if needed."""
- pass
+ if option in self.deprecated:
+ return f"{option!r} (deprecated)"
+ return f"{option!r}"
def __str__(self):
- options_str = f'{', '.join([self._mark_if_deprecated(o) for o in self.options])}'
+ options_str = ', '.join([self._mark_if_deprecated(o) for o in self.options])
return f'a {_type_name(self.type)} among {{{options_str}}}'
class StrOptions(Options):
@@ -301,18 +409,27 @@ if sklearn_version < parse_version('1.4'):
class _ArrayLikes(_Constraint):
"""Constraint representing array-likes"""
+ def is_satisfied_by(self, val):
+ return _is_arraylike_not_scalar(val)
+
def __str__(self):
return 'an array-like'
class _SparseMatrices(_Constraint):
"""Constraint representing sparse matrices."""
+ def is_satisfied_by(self, val):
+ return issparse(val)
+
def __str__(self):
return 'a sparse matrix'
class _Callables(_Constraint):
"""Constraint representing callables."""
+ def is_satisfied_by(self, val):
+ return callable(val)
+
def __str__(self):
return 'a callable'
@@ -327,8 +444,12 @@ if sklearn_version < parse_version('1.4'):
super().__init__()
self._constraints = [Interval(Integral, 0, 2 ** 32 - 1, closed='both'), _InstancesOf(np.random.RandomState), _NoneConstraint()]
+ def is_satisfied_by(self, val):
+ return any(c.is_satisfied_by(val) for c in self._constraints)
+
def __str__(self):
- return f'{', '.join([str(c) for c in self._constraints[:-1]])} or {self._constraints[-1]}'
+ constraints_str = ', '.join(str(c) for c in self._constraints[:-1])
+ return f'{constraints_str} or {self._constraints[-1]}'
class _Booleans(_Constraint):
"""Constraint representing boolean likes.
@@ -341,8 +462,12 @@ if sklearn_version < parse_version('1.4'):
super().__init__()
self._constraints = [_InstancesOf(bool), _InstancesOf(np.bool_)]
+ def is_satisfied_by(self, val):
+ return any(c.is_satisfied_by(val) for c in self._constraints)
+
def __str__(self):
- return f'{', '.join([str(c) for c in self._constraints[:-1]])} or {self._constraints[-1]}'
+ constraints_str = ', '.join(str(c) for c in self._constraints[:-1])
+ return f'{constraints_str} or {self._constraints[-1]}'
class _VerboseHelper(_Constraint):
"""Helper constraint for the verbose parameter.
@@ -355,8 +480,12 @@ if sklearn_version < parse_version('1.4'):
super().__init__()
self._constraints = [Interval(Integral, 0, None, closed='left'), _InstancesOf(bool), _InstancesOf(np.bool_)]
+ def is_satisfied_by(self, val):
+ return any(c.is_satisfied_by(val) for c in self._constraints)
+
def __str__(self):
- return f'{', '.join([str(c) for c in self._constraints[:-1]])} or {self._constraints[-1]}'
+ constraints_str = ', '.join(str(c) for c in self._constraints[:-1])
+ return f'{constraints_str} or {self._constraints[-1]}'
class MissingValues(_Constraint):
"""Helper constraint for the `missing_values` parameters.
@@ -385,8 +514,12 @@ if sklearn_version < parse_version('1.4'):
if not self.numeric_only:
self._constraints.extend([_InstancesOf(str), _NoneConstraint()])
+ def is_satisfied_by(self, val):
+ return any(c.is_satisfied_by(val) for c in self._constraints)
+
def __str__(self):
- return f'{', '.join([str(c) for c in self._constraints[:-1]])} or {self._constraints[-1]}'
+ constraints_str = ', '.join(str(c) for c in self._constraints[:-1])
+ return f'{constraints_str} or {self._constraints[-1]}'
class HasMethods(_Constraint):
"""Constraint representing objects that expose specific methods.
@@ -407,16 +540,23 @@ if sklearn_version < parse_version('1.4'):
methods = [methods]
self.methods = methods
+ def is_satisfied_by(self, val):
+ return all(hasattr(val, method) and callable(getattr(val, method)) for method in self.methods)
+
def __str__(self):
if len(self.methods) == 1:
- methods = f'{self.methods[0]!r}'
+ methods = repr(self.methods[0])
else:
- methods = f'{', '.join([repr(m) for m in self.methods[:-1]])} and {self.methods[-1]!r}'
+ methods_str = ', '.join(repr(m) for m in self.methods[:-1])
+ methods = f'{methods_str} and {self.methods[-1]!r}'
return f'an object implementing {methods}'
class _IterablesNotString(_Constraint):
"""Constraint representing iterables that are not strings."""
+ def is_satisfied_by(self, val):
+ return isinstance(val, Iterable) and not isinstance(val, str)
+
def __str__(self):
return 'an iterable'
@@ -436,8 +576,12 @@ if sklearn_version < parse_version('1.4'):
super().__init__()
self._constraints = [Interval(Integral, 2, None, closed='left'), HasMethods(['split', 'get_n_splits']), _IterablesNotString(), _NoneConstraint()]
+ def is_satisfied_by(self, val):
+ return any(c.is_satisfied_by(val) for c in self._constraints)
+
def __str__(self):
- return f'{', '.join([str(c) for c in self._constraints[:-1]])} or {self._constraints[-1]}'
+ constraints_str = ', '.join(str(c) for c in self._constraints[:-1])
+ return f'{constraints_str} or {self._constraints[-1]}'
class Hidden:
"""Class encapsulating a constraint not meant to be exposed to the user.
diff --git a/imblearn/utils/fixes.py b/imblearn/utils/fixes.py
index 6b49d43..47aa9a9 100644
--- a/imblearn/utils/fixes.py
+++ b/imblearn/utils/fixes.py
@@ -21,7 +21,7 @@ else:
def _is_arraylike_not_scalar(array):
"""Return True if array is array-like and not a scalar"""
- pass
+ return _is_arraylike(array) and not np.isscalar(array)
if sklearn_version < parse_version('1.3'):
def _fit_context(*, prefer_skip_nested_validation):
@@ -47,7 +47,13 @@ if sklearn_version < parse_version('1.3'):
decorated_fit : method
The decorated fit method.
"""
- pass
+ def decorator(fit_method):
+ @functools.wraps(fit_method)
+ def wrapper(estimator, *args, **kwargs):
+ global_skip_validation = get_config()["skip_parameter_validation"]
+ # Validate the estimator's parameters before fitting, unless
+ # validation is globally disabled.
+ if not global_skip_validation and hasattr(estimator, "_validate_params"):
+ estimator._validate_params()
+ with config_context(skip_parameter_validation=prefer_skip_nested_validation or global_skip_validation):
+ return fit_method(estimator, *args, **kwargs)
+ return wrapper
+ return decorator
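+
+ # Usage sketch: decorate an estimator's fit as
+ # @_fit_context(prefer_skip_nested_validation=True)
+ # so parameters are validated once and nested estimators skip re-validation.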
else:
from sklearn.base import _fit_context
if sklearn_version < parse_version('1.3'):
@@ -76,7 +82,17 @@ if sklearn_version < parse_version('1.3'):
fitted : bool
Whether the estimator is fitted.
"""
- pass
+ if attributes is not None:
+ if not isinstance(attributes, (list, tuple)):
+ attributes = [attributes]
+ return all_or_any([hasattr(estimator, attr) for attr in attributes])
+
+ # Estimators can advertise their fitted state explicitly.
+ if hasattr(estimator, "__sklearn_is_fitted__"):
+ return estimator.__sklearn_is_fitted__()
+
+ # Otherwise fall back to the convention that fitted attributes end with an
+ # underscore and do not start with one.
+ fitted_attrs = [v for v in vars(estimator) if v.endswith("_") and not v.startswith("__")]
+ return len(fitted_attrs) > 0
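+
+ # e.g. _is_fitted(StandardScaler()) is False; after `fit`, the scaler gains
+ # trailing-underscore attributes (mean_, scale_, ...) and this returns True.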
else:
from sklearn.utils.validation import _is_fitted
try:
@@ -85,4 +101,19 @@ except ImportError:
def _is_pandas_df(X):
"""Return True if the X is a pandas dataframe."""
- pass
\ No newline at end of file
+ try:
+ import pandas as pd
+ return isinstance(X, pd.DataFrame)
+ except ImportError:
+ return False
+
+def _mode(a, axis=0):
+ """Return the mode of an array along a given axis.
+
+ Compatibility wrapper for scipy.stats.mode: SciPy 1.9 added the
+ ``keepdims`` argument, and passing ``keepdims=True`` preserves the
+ pre-1.9 output shape.
+ """
+ if sp_version >= parse_version('1.9.0'):
+ mode_result = scipy.stats.mode(a, axis=axis, keepdims=True)
+ return mode_result.mode, mode_result.count
+ else:
+ return scipy.stats.mode(a, axis=axis)
\ No newline at end of file