Reference (Gold): imbalanced-learn
Pytest Summary for tests
| status    | count |
|-----------|-------|
| passed    | 1441  |
| xpassed   | 4     |
| failed    | 21    |
| xfailed   | 1     |
| skipped   | 1     |
| total     | 1468  |
| collected | 1468  |
Failed tests:
test_common.py::test_estimators_compatibility_sklearn[RandomOverSampler(random_state=0)-check_complex_data]
test_common.py::test_estimators_compatibility_sklearn[RandomUnderSampler(random_state=0)-check_complex_data]
test_common.py::test_estimators_imblearn[AllKNN()-check_samplers_sparse]
estimator = AllKNN()
check = functools.partial(<function check_samplers_sparse at ...>, 'AllKNN')
request = <FixtureRequest ...>

    @parametrize_with_checks(list(_tested_estimators()))
    def test_estimators_imblearn(estimator, check, request):
        # Common tests for estimator instances
        with ignore_warnings(
            category=(
                FutureWarning,
                ConvergenceWarning,
                UserWarning,
                FutureWarning,
            )
        ):
            _set_checking_parameters(estimator)
>           check(estimator)

imblearn/tests/test_common.py:71:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'AllKNN', sampler_orig = AllKNN()

    def check_samplers_sparse(name, sampler_orig):
        sampler = clone(sampler_orig)
        # check that sparse matrices can be passed through the sampler leading to
        # the same results than dense
        X, y = sample_dataset_generator()
        X_sparse = sparse.csr_matrix(X)
        X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
        sampler = clone(sampler)
        X_res, y_res = sampler.fit_resample(X, y)
        assert sparse.issparse(X_res_sparse)
>       assert_allclose(X_res_sparse.A, X_res, rtol=1e-5)
E       AttributeError: 'csr_matrix' object has no attribute 'A'

imblearn/utils/estimator_checks.py:312: AttributeError
test_common.py::test_estimators_imblearn[BorderlineSMOTE(random_state=0)-check_samplers_sparse]
estimator = BorderlineSMOTE(random_state=0)
check = functools.partial(<function check_samplers_sparse at ...>, 'BorderlineSMOTE')
request = <FixtureRequest ...>

    @parametrize_with_checks(list(_tested_estimators()))
    def test_estimators_imblearn(estimator, check, request):
        # Common tests for estimator instances
        with ignore_warnings(
            category=(
                FutureWarning,
                ConvergenceWarning,
                UserWarning,
                FutureWarning,
            )
        ):
            _set_checking_parameters(estimator)
>           check(estimator)

imblearn/tests/test_common.py:71:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'BorderlineSMOTE', sampler_orig = BorderlineSMOTE(random_state=0)

    def check_samplers_sparse(name, sampler_orig):
        sampler = clone(sampler_orig)
        # check that sparse matrices can be passed through the sampler leading to
        # the same results than dense
        X, y = sample_dataset_generator()
        X_sparse = sparse.csr_matrix(X)
        X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
        sampler = clone(sampler)
        X_res, y_res = sampler.fit_resample(X, y)
        assert sparse.issparse(X_res_sparse)
>       assert_allclose(X_res_sparse.A, X_res, rtol=1e-5)
E       AttributeError: 'csr_matrix' object has no attribute 'A'

imblearn/utils/estimator_checks.py:312: AttributeError
test_common.py::test_estimators_imblearn[ClusterCentroids(random_state=0)-check_samplers_sparse]
estimator = ClusterCentroids(estimator=KMeans(n_init=1, random_state=0), random_state=0, voting='soft')
check = functools.partial(<function check_samplers_sparse at ...>, 'ClusterCentroids')
request = <FixtureRequest ...>

    @parametrize_with_checks(list(_tested_estimators()))
    def test_estimators_imblearn(estimator, check, request):
        # Common tests for estimator instances
        with ignore_warnings(
            category=(
                FutureWarning,
                ConvergenceWarning,
                UserWarning,
                FutureWarning,
            )
        ):
            _set_checking_parameters(estimator)
>           check(estimator)

imblearn/tests/test_common.py:71:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'ClusterCentroids'
sampler_orig = ClusterCentroids(estimator=KMeans(n_init=1, random_state=0), random_state=0, voting='soft')

    def check_samplers_sparse(name, sampler_orig):
        sampler = clone(sampler_orig)
        # check that sparse matrices can be passed through the sampler leading to
        # the same results than dense
        X, y = sample_dataset_generator()
        X_sparse = sparse.csr_matrix(X)
        X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
        sampler = clone(sampler)
        X_res, y_res = sampler.fit_resample(X, y)
        assert sparse.issparse(X_res_sparse)
>       assert_allclose(X_res_sparse.A, X_res, rtol=1e-5)
E       AttributeError: 'csr_matrix' object has no attribute 'A'

imblearn/utils/estimator_checks.py:312: AttributeError
test_common.py::test_estimators_imblearn[CondensedNearestNeighbour(random_state=0)-check_samplers_sparse]
estimator = CondensedNearestNeighbour(random_state=0)
check = functools.partial(<function check_samplers_sparse at ...>, 'CondensedNearestNeighbour')
request = <FixtureRequest ...>

    @parametrize_with_checks(list(_tested_estimators()))
    def test_estimators_imblearn(estimator, check, request):
        # Common tests for estimator instances
        with ignore_warnings(
            category=(
                FutureWarning,
                ConvergenceWarning,
                UserWarning,
                FutureWarning,
            )
        ):
            _set_checking_parameters(estimator)
>           check(estimator)

imblearn/tests/test_common.py:71:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'CondensedNearestNeighbour'
sampler_orig = CondensedNearestNeighbour(random_state=0)

    def check_samplers_sparse(name, sampler_orig):
        sampler = clone(sampler_orig)
        # check that sparse matrices can be passed through the sampler leading to
        # the same results than dense
        X, y = sample_dataset_generator()
        X_sparse = sparse.csr_matrix(X)
        X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
        sampler = clone(sampler)
        X_res, y_res = sampler.fit_resample(X, y)
        assert sparse.issparse(X_res_sparse)
>       assert_allclose(X_res_sparse.A, X_res, rtol=1e-5)
E       AttributeError: 'csr_matrix' object has no attribute 'A'

imblearn/utils/estimator_checks.py:312: AttributeError
test_common.py::test_estimators_imblearn[EditedNearestNeighbours()-check_samplers_sparse]
estimator = EditedNearestNeighbours()
check = functools.partial(<function check_samplers_sparse at ...>, 'EditedNearestNeighbours')
request = <FixtureRequest ...>

    @parametrize_with_checks(list(_tested_estimators()))
    def test_estimators_imblearn(estimator, check, request):
        # Common tests for estimator instances
        with ignore_warnings(
            category=(
                FutureWarning,
                ConvergenceWarning,
                UserWarning,
                FutureWarning,
            )
        ):
            _set_checking_parameters(estimator)
>           check(estimator)

imblearn/tests/test_common.py:71:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'EditedNearestNeighbours', sampler_orig = EditedNearestNeighbours()

    def check_samplers_sparse(name, sampler_orig):
        sampler = clone(sampler_orig)
        # check that sparse matrices can be passed through the sampler leading to
        # the same results than dense
        X, y = sample_dataset_generator()
        X_sparse = sparse.csr_matrix(X)
        X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
        sampler = clone(sampler)
        X_res, y_res = sampler.fit_resample(X, y)
        assert sparse.issparse(X_res_sparse)
>       assert_allclose(X_res_sparse.A, X_res, rtol=1e-5)
E       AttributeError: 'csr_matrix' object has no attribute 'A'

imblearn/utils/estimator_checks.py:312: AttributeError
test_common.py::test_estimators_imblearn[FunctionSampler()-check_samplers_sparse]
estimator = FunctionSampler()
check = functools.partial(<function check_samplers_sparse at ...>, 'FunctionSampler')
request = <FixtureRequest ...>

    @parametrize_with_checks(list(_tested_estimators()))
    def test_estimators_imblearn(estimator, check, request):
        # Common tests for estimator instances
        with ignore_warnings(
            category=(
                FutureWarning,
                ConvergenceWarning,
                UserWarning,
                FutureWarning,
            )
        ):
            _set_checking_parameters(estimator)
>           check(estimator)

imblearn/tests/test_common.py:71:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'FunctionSampler', sampler_orig = FunctionSampler()

    def check_samplers_sparse(name, sampler_orig):
        sampler = clone(sampler_orig)
        # check that sparse matrices can be passed through the sampler leading to
        # the same results than dense
        X, y = sample_dataset_generator()
        X_sparse = sparse.csr_matrix(X)
        X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
        sampler = clone(sampler)
        X_res, y_res = sampler.fit_resample(X, y)
        assert sparse.issparse(X_res_sparse)
>       assert_allclose(X_res_sparse.A, X_res, rtol=1e-5)
E       AttributeError: 'csr_matrix' object has no attribute 'A'

imblearn/utils/estimator_checks.py:312: AttributeError
test_common.py::test_estimators_imblearn[InstanceHardnessThreshold(random_state=0)-check_samplers_sparse]
estimator = InstanceHardnessThreshold(random_state=0)
check = functools.partial(<function check_samplers_sparse at ...>, 'InstanceHardnessThreshold')
request = <FixtureRequest ...>

    @parametrize_with_checks(list(_tested_estimators()))
    def test_estimators_imblearn(estimator, check, request):
        # Common tests for estimator instances
        with ignore_warnings(
            category=(
                FutureWarning,
                ConvergenceWarning,
                UserWarning,
                FutureWarning,
            )
        ):
            _set_checking_parameters(estimator)
>           check(estimator)

imblearn/tests/test_common.py:71:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'InstanceHardnessThreshold'
sampler_orig = InstanceHardnessThreshold(random_state=0)

    def check_samplers_sparse(name, sampler_orig):
        sampler = clone(sampler_orig)
        # check that sparse matrices can be passed through the sampler leading to
        # the same results than dense
        X, y = sample_dataset_generator()
        X_sparse = sparse.csr_matrix(X)
        X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
        sampler = clone(sampler)
        X_res, y_res = sampler.fit_resample(X, y)
        assert sparse.issparse(X_res_sparse)
>       assert_allclose(X_res_sparse.A, X_res, rtol=1e-5)
E       AttributeError: 'csr_matrix' object has no attribute 'A'

imblearn/utils/estimator_checks.py:312: AttributeError
test_common.py::test_estimators_imblearn[KMeansSMOTE(random_state=0)-check_samplers_sparse]
estimator = KMeansSMOTE(kmeans_estimator=12, random_state=0)
check = functools.partial(<function check_samplers_sparse at ...>, 'KMeansSMOTE')
request = <FixtureRequest ...>

    @parametrize_with_checks(list(_tested_estimators()))
    def test_estimators_imblearn(estimator, check, request):
        # Common tests for estimator instances
        with ignore_warnings(
            category=(
                FutureWarning,
                ConvergenceWarning,
                UserWarning,
                FutureWarning,
            )
        ):
            _set_checking_parameters(estimator)
>           check(estimator)

imblearn/tests/test_common.py:71:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'KMeansSMOTE'
sampler_orig = KMeansSMOTE(kmeans_estimator=12, random_state=0)

    def check_samplers_sparse(name, sampler_orig):
        sampler = clone(sampler_orig)
        # check that sparse matrices can be passed through the sampler leading to
        # the same results than dense
        X, y = sample_dataset_generator()
        X_sparse = sparse.csr_matrix(X)
        X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
        sampler = clone(sampler)
        X_res, y_res = sampler.fit_resample(X, y)
        assert sparse.issparse(X_res_sparse)
>       assert_allclose(X_res_sparse.A, X_res, rtol=1e-5)
E       AttributeError: 'csr_matrix' object has no attribute 'A'

imblearn/utils/estimator_checks.py:312: AttributeError
test_common.py::test_estimators_imblearn[NearMiss()-check_samplers_fit_resample]
test_common.py::test_estimators_imblearn[NearMiss()-check_samplers_sparse]
estimator = NearMiss()
check = functools.partial(<function check_samplers_sparse at ...>, 'NearMiss')
request = <FixtureRequest ...>

    @parametrize_with_checks(list(_tested_estimators()))
    def test_estimators_imblearn(estimator, check, request):
        # Common tests for estimator instances
        with ignore_warnings(
            category=(
                FutureWarning,
                ConvergenceWarning,
                UserWarning,
                FutureWarning,
            )
        ):
            _set_checking_parameters(estimator)
>           check(estimator)

imblearn/tests/test_common.py:71:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'NearMiss', sampler_orig = NearMiss()

    def check_samplers_sparse(name, sampler_orig):
        sampler = clone(sampler_orig)
        # check that sparse matrices can be passed through the sampler leading to
        # the same results than dense
        X, y = sample_dataset_generator()
        X_sparse = sparse.csr_matrix(X)
        X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
        sampler = clone(sampler)
        X_res, y_res = sampler.fit_resample(X, y)
        assert sparse.issparse(X_res_sparse)
>       assert_allclose(X_res_sparse.A, X_res, rtol=1e-5)
E       AttributeError: 'csr_matrix' object has no attribute 'A'

imblearn/utils/estimator_checks.py:312: AttributeError
test_common.py::test_estimators_imblearn[NearMiss(version=2)-check_samplers_fit_resample]
test_common.py::test_estimators_imblearn[NearMiss(version=2)-check_samplers_sparse]
estimator = NearMiss(version=2)
check = functools.partial(<function check_samplers_sparse at ...>, 'NearMiss')
request = <FixtureRequest ...>

    @parametrize_with_checks(list(_tested_estimators()))
    def test_estimators_imblearn(estimator, check, request):
        # Common tests for estimator instances
        with ignore_warnings(
            category=(
                FutureWarning,
                ConvergenceWarning,
                UserWarning,
                FutureWarning,
            )
        ):
            _set_checking_parameters(estimator)
>           check(estimator)

imblearn/tests/test_common.py:71:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'NearMiss', sampler_orig = NearMiss(version=2)

    def check_samplers_sparse(name, sampler_orig):
        sampler = clone(sampler_orig)
        # check that sparse matrices can be passed through the sampler leading to
        # the same results than dense
        X, y = sample_dataset_generator()
        X_sparse = sparse.csr_matrix(X)
        X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
        sampler = clone(sampler)
        X_res, y_res = sampler.fit_resample(X, y)
        assert sparse.issparse(X_res_sparse)
>       assert_allclose(X_res_sparse.A, X_res, rtol=1e-5)
E       AttributeError: 'csr_matrix' object has no attribute 'A'

imblearn/utils/estimator_checks.py:312: AttributeError
test_common.py::test_estimators_imblearn[NearMiss(version=3)-check_samplers_fit_resample]
estimator = NearMiss(version=3)
check = functools.partial(<function check_samplers_fit_resample at ...>, 'NearMiss')
request = <FixtureRequest ...>

    @parametrize_with_checks(list(_tested_estimators()))
    def test_estimators_imblearn(estimator, check, request):
        # Common tests for estimator instances
        with ignore_warnings(
            category=(
                FutureWarning,
                ConvergenceWarning,
                UserWarning,
                FutureWarning,
            )
        ):
            _set_checking_parameters(estimator)
>           check(estimator)

imblearn/tests/test_common.py:71:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'NearMiss', sampler_orig = NearMiss(version=3)

    def check_samplers_fit_resample(name, sampler_orig):
        sampler = clone(sampler_orig)
        X, y = sample_dataset_generator()
        target_stats = Counter(y)
        X_res, y_res = sampler.fit_resample(X, y)
        if isinstance(sampler, BaseOverSampler):
            target_stats_res = Counter(y_res)
            n_samples = max(target_stats.values())
            assert all(value >= n_samples for value in Counter(y_res).values())
        elif isinstance(sampler, BaseUnderSampler):
            n_samples = min(target_stats.values())
            if name == "InstanceHardnessThreshold":
                # IHT does not enforce the number of samples but provide a number
                # of samples the closest to the desired target.
                assert all(
                    Counter(y_res)[k] <= target_stats[k] for k in target_stats.keys()
                )
            else:
>               assert all(value == n_samples for value in Counter(y_res).values())
E               AssertionError

imblearn/utils/estimator_checks.py:269: AssertionError
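For context, this check requires an under-sampler to leave every class with exactly the minority-class count. A toy sketch of the failing condition (the labels below are hypothetical, not taken from the test data):

```python
from collections import Counter

y_res = [0] * 25 + [1] * 20  # hypothetical resampled labels
n_samples = min(Counter(y_res).values())

# NearMiss(version=3) may keep a different number of majority samples,
# so this evaluates to False and the assert in
# check_samplers_fit_resample raises AssertionError.
print(all(value == n_samples for value in Counter(y_res).values()))  # False
```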
test_common.py::test_estimators_imblearn[NearMiss(version=3)-check_samplers_sparse]
estimator = NearMiss(version=3)
check = functools.partial(<function check_samplers_sparse at ...>, 'NearMiss')
request = <FixtureRequest ...>

    @parametrize_with_checks(list(_tested_estimators()))
    def test_estimators_imblearn(estimator, check, request):
        # Common tests for estimator instances
        with ignore_warnings(
            category=(
                FutureWarning,
                ConvergenceWarning,
                UserWarning,
                FutureWarning,
            )
        ):
            _set_checking_parameters(estimator)
>           check(estimator)

imblearn/tests/test_common.py:71:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'NearMiss', sampler_orig = NearMiss(version=3)

    def check_samplers_sparse(name, sampler_orig):
        sampler = clone(sampler_orig)
        # check that sparse matrices can be passed through the sampler leading to
        # the same results than dense
        X, y = sample_dataset_generator()
        X_sparse = sparse.csr_matrix(X)
        X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
        sampler = clone(sampler)
        X_res, y_res = sampler.fit_resample(X, y)
        assert sparse.issparse(X_res_sparse)
>       assert_allclose(X_res_sparse.A, X_res, rtol=1e-5)
E       AttributeError: 'csr_matrix' object has no attribute 'A'

imblearn/utils/estimator_checks.py:312: AttributeError
test_common.py::test_estimators_imblearn[NeighbourhoodCleaningRule()-check_samplers_sparse]
estimator = NeighbourhoodCleaningRule()
check = functools.partial(<function check_samplers_sparse at ...>, 'NeighbourhoodCleaningRule')
request = <FixtureRequest ...>

    @parametrize_with_checks(list(_tested_estimators()))
    def test_estimators_imblearn(estimator, check, request):
        # Common tests for estimator instances
        with ignore_warnings(
            category=(
                FutureWarning,
                ConvergenceWarning,
                UserWarning,
                FutureWarning,
            )
        ):
            _set_checking_parameters(estimator)
>           check(estimator)

imblearn/tests/test_common.py:71:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'NeighbourhoodCleaningRule', sampler_orig = NeighbourhoodCleaningRule()

    def check_samplers_sparse(name, sampler_orig):
        sampler = clone(sampler_orig)
        # check that sparse matrices can be passed through the sampler leading to
        # the same results than dense
        X, y = sample_dataset_generator()
        X_sparse = sparse.csr_matrix(X)
        X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
        sampler = clone(sampler)
        X_res, y_res = sampler.fit_resample(X, y)
        assert sparse.issparse(X_res_sparse)
>       assert_allclose(X_res_sparse.A, X_res, rtol=1e-5)
E       AttributeError: 'csr_matrix' object has no attribute 'A'

imblearn/utils/estimator_checks.py:312: AttributeError
test_common.py::test_estimators_imblearn[OneSidedSelection(random_state=0)-check_samplers_sparse]
estimator = OneSidedSelection(random_state=0)
check = functools.partial(<function check_samplers_sparse at ...>, 'OneSidedSelection')
request = <FixtureRequest ...>

    @parametrize_with_checks(list(_tested_estimators()))
    def test_estimators_imblearn(estimator, check, request):
        # Common tests for estimator instances
        with ignore_warnings(
            category=(
                FutureWarning,
                ConvergenceWarning,
                UserWarning,
                FutureWarning,
            )
        ):
            _set_checking_parameters(estimator)
>           check(estimator)

imblearn/tests/test_common.py:71:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'OneSidedSelection', sampler_orig = OneSidedSelection(random_state=0)

    def check_samplers_sparse(name, sampler_orig):
        sampler = clone(sampler_orig)
        # check that sparse matrices can be passed through the sampler leading to
        # the same results than dense
        X, y = sample_dataset_generator()
        X_sparse = sparse.csr_matrix(X)
        X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
        sampler = clone(sampler)
        X_res, y_res = sampler.fit_resample(X, y)
        assert sparse.issparse(X_res_sparse)
>       assert_allclose(X_res_sparse.A, X_res, rtol=1e-5)
E       AttributeError: 'csr_matrix' object has no attribute 'A'

imblearn/utils/estimator_checks.py:312: AttributeError
test_common.py::test_estimators_imblearn[RandomOverSampler(random_state=0)-check_samplers_sparse]
estimator = RandomOverSampler(random_state=0)
check = functools.partial(<function check_samplers_sparse at ...>, 'RandomOverSampler')
request = <FixtureRequest ...>

    @parametrize_with_checks(list(_tested_estimators()))
    def test_estimators_imblearn(estimator, check, request):
        # Common tests for estimator instances
        with ignore_warnings(
            category=(
                FutureWarning,
                ConvergenceWarning,
                UserWarning,
                FutureWarning,
            )
        ):
            _set_checking_parameters(estimator)
>           check(estimator)

imblearn/tests/test_common.py:71:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'RandomOverSampler', sampler_orig = RandomOverSampler(random_state=0)

    def check_samplers_sparse(name, sampler_orig):
        sampler = clone(sampler_orig)
        # check that sparse matrices can be passed through the sampler leading to
        # the same results than dense
        X, y = sample_dataset_generator()
        X_sparse = sparse.csr_matrix(X)
        X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
        sampler = clone(sampler)
        X_res, y_res = sampler.fit_resample(X, y)
        assert sparse.issparse(X_res_sparse)
>       assert_allclose(X_res_sparse.A, X_res, rtol=1e-5)
E       AttributeError: 'csr_matrix' object has no attribute 'A'

imblearn/utils/estimator_checks.py:312: AttributeError
test_common.py::test_estimators_imblearn[RandomUnderSampler(random_state=0)-check_samplers_sparse]
estimator = RandomUnderSampler(random_state=0)
check = functools.partial(<function check_samplers_sparse at ...>, 'RandomUnderSampler')
request = <FixtureRequest ...>

    @parametrize_with_checks(list(_tested_estimators()))
    def test_estimators_imblearn(estimator, check, request):
        # Common tests for estimator instances
        with ignore_warnings(
            category=(
                FutureWarning,
                ConvergenceWarning,
                UserWarning,
                FutureWarning,
            )
        ):
            _set_checking_parameters(estimator)
>           check(estimator)

imblearn/tests/test_common.py:71:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'RandomUnderSampler', sampler_orig = RandomUnderSampler(random_state=0)

    def check_samplers_sparse(name, sampler_orig):
        sampler = clone(sampler_orig)
        # check that sparse matrices can be passed through the sampler leading to
        # the same results than dense
        X, y = sample_dataset_generator()
        X_sparse = sparse.csr_matrix(X)
        X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
        sampler = clone(sampler)
        X_res, y_res = sampler.fit_resample(X, y)
        assert sparse.issparse(X_res_sparse)
>       assert_allclose(X_res_sparse.A, X_res, rtol=1e-5)
E       AttributeError: 'csr_matrix' object has no attribute 'A'

imblearn/utils/estimator_checks.py:312: AttributeError
test_common.py::test_estimators_imblearn[RepeatedEditedNearestNeighbours()-check_samplers_sparse]
estimator = RepeatedEditedNearestNeighbours()
check = functools.partial(<function check_samplers_sparse at ...>, 'RepeatedEditedNearestNeighbours')
request = <FixtureRequest ...>

    @parametrize_with_checks(list(_tested_estimators()))
    def test_estimators_imblearn(estimator, check, request):
        # Common tests for estimator instances
        with ignore_warnings(
            category=(
                FutureWarning,
                ConvergenceWarning,
                UserWarning,
                FutureWarning,
            )
        ):
            _set_checking_parameters(estimator)
>           check(estimator)

imblearn/tests/test_common.py:71:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'RepeatedEditedNearestNeighbours'
sampler_orig = RepeatedEditedNearestNeighbours()

    def check_samplers_sparse(name, sampler_orig):
        sampler = clone(sampler_orig)
        # check that sparse matrices can be passed through the sampler leading to
        # the same results than dense
        X, y = sample_dataset_generator()
        X_sparse = sparse.csr_matrix(X)
        X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
        sampler = clone(sampler)
        X_res, y_res = sampler.fit_resample(X, y)
        assert sparse.issparse(X_res_sparse)
>       assert_allclose(X_res_sparse.A, X_res, rtol=1e-5)
E       AttributeError: 'csr_matrix' object has no attribute 'A'

imblearn/utils/estimator_checks.py:312: AttributeError
test_common.py::test_estimators_imblearn[SMOTE(random_state=0)-check_samplers_sparse]
estimator = SMOTE(random_state=0)
check = functools.partial(<function check_samplers_sparse at ...>, 'SMOTE')
request = <FixtureRequest ...>

    @parametrize_with_checks(list(_tested_estimators()))
    def test_estimators_imblearn(estimator, check, request):
        # Common tests for estimator instances
        with ignore_warnings(
            category=(
                FutureWarning,
                ConvergenceWarning,
                UserWarning,
                FutureWarning,
            )
        ):
            _set_checking_parameters(estimator)
>           check(estimator)

imblearn/tests/test_common.py:71:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'SMOTE', sampler_orig = SMOTE(random_state=0)

    def check_samplers_sparse(name, sampler_orig):
        sampler = clone(sampler_orig)
        # check that sparse matrices can be passed through the sampler leading to
        # the same results than dense
        X, y = sample_dataset_generator()
        X_sparse = sparse.csr_matrix(X)
        X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
        sampler = clone(sampler)
        X_res, y_res = sampler.fit_resample(X, y)
        assert sparse.issparse(X_res_sparse)
>       assert_allclose(X_res_sparse.A, X_res, rtol=1e-5)
E       AttributeError: 'csr_matrix' object has no attribute 'A'

imblearn/utils/estimator_checks.py:312: AttributeError
test_common.py::test_estimators_imblearn[SMOTEENN(random_state=0)-check_samplers_sparse]
estimator = SMOTEENN(random_state=0)
check = functools.partial(<function check_samplers_sparse at ...>, 'SMOTEENN')
request = <FixtureRequest ...>

    @parametrize_with_checks(list(_tested_estimators()))
    def test_estimators_imblearn(estimator, check, request):
        # Common tests for estimator instances
        with ignore_warnings(
            category=(
                FutureWarning,
                ConvergenceWarning,
                UserWarning,
                FutureWarning,
            )
        ):
            _set_checking_parameters(estimator)
>           check(estimator)

imblearn/tests/test_common.py:71:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'SMOTEENN', sampler_orig = SMOTEENN(random_state=0)

    def check_samplers_sparse(name, sampler_orig):
        sampler = clone(sampler_orig)
        # check that sparse matrices can be passed through the sampler leading to
        # the same results than dense
        X, y = sample_dataset_generator()
        X_sparse = sparse.csr_matrix(X)
        X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
        sampler = clone(sampler)
        X_res, y_res = sampler.fit_resample(X, y)
        assert sparse.issparse(X_res_sparse)
>       assert_allclose(X_res_sparse.A, X_res, rtol=1e-5)
E       AttributeError: 'csr_matrix' object has no attribute 'A'

imblearn/utils/estimator_checks.py:312: AttributeError
test_common.py::test_estimators_imblearn[SMOTETomek(random_state=0)-check_samplers_sparse]
estimator = SMOTETomek(random_state=0)
check = functools.partial(<function check_samplers_sparse at ...>, 'SMOTETomek')
request = <FixtureRequest ...>

    @parametrize_with_checks(list(_tested_estimators()))
    def test_estimators_imblearn(estimator, check, request):
        # Common tests for estimator instances
        with ignore_warnings(
            category=(
                FutureWarning,
                ConvergenceWarning,
                UserWarning,
                FutureWarning,
            )
        ):
            _set_checking_parameters(estimator)
>           check(estimator)

imblearn/tests/test_common.py:71:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'SMOTETomek', sampler_orig = SMOTETomek(random_state=0)

    def check_samplers_sparse(name, sampler_orig):
        sampler = clone(sampler_orig)
        # check that sparse matrices can be passed through the sampler leading to
        # the same results than dense
        X, y = sample_dataset_generator()
        X_sparse = sparse.csr_matrix(X)
        X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
        sampler = clone(sampler)
        X_res, y_res = sampler.fit_resample(X, y)
        assert sparse.issparse(X_res_sparse)
>       assert_allclose(X_res_sparse.A, X_res, rtol=1e-5)
E       AttributeError: 'csr_matrix' object has no attribute 'A'

imblearn/utils/estimator_checks.py:312: AttributeError
test_common.py::test_estimators_imblearn[SVMSMOTE(random_state=0)-check_samplers_sparse]
estimator = SVMSMOTE(random_state=0)
check = functools.partial(<function check_samplers_sparse at ...>, 'SVMSMOTE')
request = <FixtureRequest ...>

    @parametrize_with_checks(list(_tested_estimators()))
    def test_estimators_imblearn(estimator, check, request):
        # Common tests for estimator instances
        with ignore_warnings(
            category=(
                FutureWarning,
                ConvergenceWarning,
                UserWarning,
                FutureWarning,
            )
        ):
            _set_checking_parameters(estimator)
>           check(estimator)

imblearn/tests/test_common.py:71:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'SVMSMOTE', sampler_orig = SVMSMOTE(random_state=0)

    def check_samplers_sparse(name, sampler_orig):
        sampler = clone(sampler_orig)
        # check that sparse matrices can be passed through the sampler leading to
        # the same results than dense
        X, y = sample_dataset_generator()
        X_sparse = sparse.csr_matrix(X)
        X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
        sampler = clone(sampler)
        X_res, y_res = sampler.fit_resample(X, y)
        assert sparse.issparse(X_res_sparse)
>       assert_allclose(X_res_sparse.A, X_res, rtol=1e-5)
E       AttributeError: 'csr_matrix' object has no attribute 'A'

imblearn/utils/estimator_checks.py:312: AttributeError
test_common.py::test_estimators_imblearn[TomekLinks()-check_samplers_sparse]
estimator = TomekLinks()
check = functools.partial(<function check_samplers_sparse at ...>, 'TomekLinks')
request = <FixtureRequest ...>

    @parametrize_with_checks(list(_tested_estimators()))
    def test_estimators_imblearn(estimator, check, request):
        # Common tests for estimator instances
        with ignore_warnings(
            category=(
                FutureWarning,
                ConvergenceWarning,
                UserWarning,
                FutureWarning,
            )
        ):
            _set_checking_parameters(estimator)
>           check(estimator)

imblearn/tests/test_common.py:71:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'TomekLinks', sampler_orig = TomekLinks()

    def check_samplers_sparse(name, sampler_orig):
        sampler = clone(sampler_orig)
        # check that sparse matrices can be passed through the sampler leading to
        # the same results than dense
        X, y = sample_dataset_generator()
        X_sparse = sparse.csr_matrix(X)
        X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
        sampler = clone(sampler)
        X_res, y_res = sampler.fit_resample(X, y)
        assert sparse.issparse(X_res_sparse)
>       assert_allclose(X_res_sparse.A, X_res, rtol=1e-5)
E       AttributeError: 'csr_matrix' object has no attribute 'A'

imblearn/utils/estimator_checks.py:312: AttributeError
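Every `check_samplers_sparse` failure above has the same root cause: recent SciPy releases removed the `.A` shorthand for densifying a sparse matrix, so `X_res_sparse.A` raises `AttributeError`. A minimal sketch of the portable replacement, assuming only NumPy and SciPy (the array values are illustrative):

```python
import numpy as np
from numpy.testing import assert_allclose
from scipy import sparse

X = np.array([[0.0, 1.0], [2.0, 0.0]])
X_sparse = sparse.csr_matrix(X)

# `X_sparse.A` is gone on recent SciPy; `toarray()` is the equivalent,
# version-stable way to densify before comparing against the dense result.
assert_allclose(X_sparse.toarray(), X, rtol=1e-5)
```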
Patch diff
diff --git a/imblearn/_config.py b/imblearn/_config.py
index 88884fe..ef98e73 100644
--- a/imblearn/_config.py
+++ b/imblearn/_config.py
@@ -5,23 +5,34 @@ We remove the array_api_dispatch for the moment.
import os
import threading
from contextlib import contextmanager as contextmanager
+
import sklearn
from sklearn.utils.fixes import parse_version
+
sklearn_version = parse_version(sklearn.__version__)
-if sklearn_version < parse_version('1.3'):
- _global_config = {'assume_finite': bool(os.environ.get(
- 'SKLEARN_ASSUME_FINITE', False)), 'working_memory': int(os.environ.
- get('SKLEARN_WORKING_MEMORY', 1024)), 'print_changed_only': True,
- 'display': 'diagram', 'pairwise_dist_chunk_size': int(os.environ.
- get('SKLEARN_PAIRWISE_DIST_CHUNK_SIZE', 256)),
- 'enable_cython_pairwise_dist': True, 'transform_output': 'default',
- 'enable_metadata_routing': False, 'skip_parameter_validation': False}
+
+if sklearn_version < parse_version("1.3"):
+ _global_config = {
+ "assume_finite": bool(os.environ.get("SKLEARN_ASSUME_FINITE", False)),
+ "working_memory": int(os.environ.get("SKLEARN_WORKING_MEMORY", 1024)),
+ "print_changed_only": True,
+ "display": "diagram",
+ "pairwise_dist_chunk_size": int(
+ os.environ.get("SKLEARN_PAIRWISE_DIST_CHUNK_SIZE", 256)
+ ),
+ "enable_cython_pairwise_dist": True,
+ "transform_output": "default",
+ "enable_metadata_routing": False,
+ "skip_parameter_validation": False,
+ }
_threadlocal = threading.local()
def _get_threadlocal_config():
"""Get a threadlocal **mutable** configuration. If the configuration
does not exist, copy the default global configuration."""
- pass
+ if not hasattr(_threadlocal, "global_config"):
+ _threadlocal.global_config = _global_config.copy()
+ return _threadlocal.global_config
def get_config():
"""Retrieve current values for configuration set by :func:`set_config`.
@@ -36,12 +47,21 @@ if sklearn_version < parse_version('1.3'):
config_context : Context manager for global scikit-learn configuration.
set_config : Set global scikit-learn configuration.
"""
- pass
-
- def set_config(assume_finite=None, working_memory=None,
- print_changed_only=None, display=None, pairwise_dist_chunk_size=
- None, enable_cython_pairwise_dist=None, transform_output=None,
- enable_metadata_routing=None, skip_parameter_validation=None):
+ # Return a copy of the threadlocal configuration so that users will
+ # not be able to modify the configuration with the returned dict.
+ return _get_threadlocal_config().copy()
+
+ def set_config(
+ assume_finite=None,
+ working_memory=None,
+ print_changed_only=None,
+ display=None,
+ pairwise_dist_chunk_size=None,
+ enable_cython_pairwise_dist=None,
+ transform_output=None,
+ enable_metadata_routing=None,
+ skip_parameter_validation=None,
+ ):
"""Set global scikit-learn configuration
.. versionadded:: 0.19
@@ -142,13 +162,40 @@ if sklearn_version < parse_version('1.3'):
config_context : Context manager for global scikit-learn configuration.
get_config : Retrieve current values of the global configuration.
"""
- pass
+ local_config = _get_threadlocal_config()
+
+ if assume_finite is not None:
+ local_config["assume_finite"] = assume_finite
+ if working_memory is not None:
+ local_config["working_memory"] = working_memory
+ if print_changed_only is not None:
+ local_config["print_changed_only"] = print_changed_only
+ if display is not None:
+ local_config["display"] = display
+ if pairwise_dist_chunk_size is not None:
+ local_config["pairwise_dist_chunk_size"] = pairwise_dist_chunk_size
+ if enable_cython_pairwise_dist is not None:
+ local_config["enable_cython_pairwise_dist"] = enable_cython_pairwise_dist
+ if transform_output is not None:
+ local_config["transform_output"] = transform_output
+ if enable_metadata_routing is not None:
+ local_config["enable_metadata_routing"] = enable_metadata_routing
+ if skip_parameter_validation is not None:
+ local_config["skip_parameter_validation"] = skip_parameter_validation
@contextmanager
- def config_context(*, assume_finite=None, working_memory=None,
- print_changed_only=None, display=None, pairwise_dist_chunk_size=
- None, enable_cython_pairwise_dist=None, transform_output=None,
- enable_metadata_routing=None, skip_parameter_validation=None):
+ def config_context(
+ *,
+ assume_finite=None,
+ working_memory=None,
+ print_changed_only=None,
+ display=None,
+ pairwise_dist_chunk_size=None,
+ enable_cython_pairwise_dist=None,
+ transform_output=None,
+ enable_metadata_routing=None,
+ skip_parameter_validation=None,
+ ):
"""Context manager for global scikit-learn configuration.
Parameters
@@ -270,6 +317,28 @@ if sklearn_version < parse_version('1.3'):
...
ValueError: Input contains NaN...
"""
- pass
+ old_config = get_config()
+ set_config(
+ assume_finite=assume_finite,
+ working_memory=working_memory,
+ print_changed_only=print_changed_only,
+ display=display,
+ pairwise_dist_chunk_size=pairwise_dist_chunk_size,
+ enable_cython_pairwise_dist=enable_cython_pairwise_dist,
+ transform_output=transform_output,
+ enable_metadata_routing=enable_metadata_routing,
+ skip_parameter_validation=skip_parameter_validation,
+ )
+
+ try:
+ yield
+ finally:
+ set_config(**old_config)
+
else:
- from sklearn._config import _get_threadlocal_config, _global_config, config_context, get_config
+ from sklearn._config import ( # type: ignore[no-redef]
+ _get_threadlocal_config,
+ _global_config,
+ config_context, # noqa
+ get_config,
+ )
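A quick usage sketch of the backported helpers above, assuming the scikit-learn < 1.3 branch is active (only then does this module define `set_config` itself):

```python
from imblearn._config import config_context, get_config, set_config

set_config(assume_finite=True)
assert get_config()["assume_finite"] is True

# config_context restores the previous configuration on exit,
# even if the body raises.
with config_context(assume_finite=False):
    assert get_config()["assume_finite"] is False
assert get_config()["assume_finite"] is True
```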
diff --git a/imblearn/_min_dependencies.py b/imblearn/_min_dependencies.py
index d713aa5..ec1f5de 100644
--- a/imblearn/_min_dependencies.py
+++ b/imblearn/_min_dependencies.py
@@ -1,38 +1,60 @@
"""All minimum dependencies for imbalanced-learn."""
import argparse
-NUMPY_MIN_VERSION = '1.17.3'
-SCIPY_MIN_VERSION = '1.5.0'
-PANDAS_MIN_VERSION = '1.0.5'
-SKLEARN_MIN_VERSION = '1.0.2'
-TENSORFLOW_MIN_VERSION = '2.4.3'
-KERAS_MIN_VERSION = '2.4.3'
-JOBLIB_MIN_VERSION = '1.1.1'
-THREADPOOLCTL_MIN_VERSION = '2.0.0'
-PYTEST_MIN_VERSION = '5.0.1'
-dependent_packages = {'numpy': (NUMPY_MIN_VERSION, 'install'), 'scipy': (
- SCIPY_MIN_VERSION, 'install'), 'scikit-learn': (SKLEARN_MIN_VERSION,
- 'install'), 'joblib': (JOBLIB_MIN_VERSION, 'install'), 'threadpoolctl':
- (THREADPOOLCTL_MIN_VERSION, 'install'), 'pandas': (PANDAS_MIN_VERSION,
- 'optional, docs, examples, tests'), 'tensorflow': (
- TENSORFLOW_MIN_VERSION, 'optional, docs, examples, tests'), 'keras': (
- KERAS_MIN_VERSION, 'optional, docs, examples, tests'), 'matplotlib': (
- '3.1.2', 'docs, examples'), 'seaborn': ('0.9.0', 'docs, examples'),
- 'memory_profiler': ('0.57.0', 'docs'), 'pytest': (PYTEST_MIN_VERSION,
- 'tests'), 'pytest-cov': ('2.9.0', 'tests'), 'flake8': ('3.8.2', 'tests'
- ), 'black': ('23.3.0', 'tests'), 'mypy': ('1.3.0', 'tests'), 'sphinx':
- ('6.0.0', 'docs'), 'sphinx-gallery': ('0.13.0', 'docs'),
- 'sphinx-copybutton': ('0.5.2', 'docs'), 'numpydoc': ('1.5.0', 'docs'),
- 'sphinxcontrib-bibtex': ('2.4.1', 'docs'), 'pydata-sphinx-theme': (
- '0.13.3', 'docs'), 'sphinx-design': ('0.5.0', 'docs')}
-tag_to_packages: dict = {extra: [] for extra in ['install', 'optional',
- 'docs', 'examples', 'tests']}
+
+NUMPY_MIN_VERSION = "1.17.3"
+SCIPY_MIN_VERSION = "1.5.0"
+PANDAS_MIN_VERSION = "1.0.5"
+SKLEARN_MIN_VERSION = "1.0.2"
+TENSORFLOW_MIN_VERSION = "2.4.3"
+KERAS_MIN_VERSION = "2.4.3"
+JOBLIB_MIN_VERSION = "1.1.1"
+THREADPOOLCTL_MIN_VERSION = "2.0.0"
+PYTEST_MIN_VERSION = "5.0.1"
+
+# 'build' and 'install' is included to have structured metadata for CI.
+# It will NOT be included in setup's extras_require
+# The values are (version_spec, comma separated tags)
+dependent_packages = {
+ "numpy": (NUMPY_MIN_VERSION, "install"),
+ "scipy": (SCIPY_MIN_VERSION, "install"),
+ "scikit-learn": (SKLEARN_MIN_VERSION, "install"),
+ "joblib": (JOBLIB_MIN_VERSION, "install"),
+ "threadpoolctl": (THREADPOOLCTL_MIN_VERSION, "install"),
+ "pandas": (PANDAS_MIN_VERSION, "optional, docs, examples, tests"),
+ "tensorflow": (TENSORFLOW_MIN_VERSION, "optional, docs, examples, tests"),
+ "keras": (KERAS_MIN_VERSION, "optional, docs, examples, tests"),
+ "matplotlib": ("3.1.2", "docs, examples"),
+ "seaborn": ("0.9.0", "docs, examples"),
+ "memory_profiler": ("0.57.0", "docs"),
+ "pytest": (PYTEST_MIN_VERSION, "tests"),
+ "pytest-cov": ("2.9.0", "tests"),
+ "flake8": ("3.8.2", "tests"),
+ "black": ("23.3.0", "tests"),
+ "mypy": ("1.3.0", "tests"),
+ "sphinx": ("6.0.0", "docs"),
+ "sphinx-gallery": ("0.13.0", "docs"),
+ "sphinx-copybutton": ("0.5.2", "docs"),
+ "numpydoc": ("1.5.0", "docs"),
+ "sphinxcontrib-bibtex": ("2.4.1", "docs"),
+ "pydata-sphinx-theme": ("0.13.3", "docs"),
+ "sphinx-design": ("0.5.0", "docs"),
+}
+
+
+# create inverse mapping for setuptools
+tag_to_packages: dict = {
+ extra: [] for extra in ["install", "optional", "docs", "examples", "tests"]
+}
for package, (min_version, extras) in dependent_packages.items():
- for extra in extras.split(', '):
- tag_to_packages[extra].append('{}>={}'.format(package, min_version))
-if __name__ == '__main__':
- parser = argparse.ArgumentParser(description=
- 'Get min dependencies for a package')
- parser.add_argument('package', choices=dependent_packages)
+ for extra in extras.split(", "):
+ tag_to_packages[extra].append("{}>={}".format(package, min_version))
+
+
+# Used by CI to get the min dependencies
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Get min dependencies for a package")
+
+ parser.add_argument("package", choices=dependent_packages)
args = parser.parse_args()
min_version = dependent_packages[args.package][0]
print(min_version)
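A small sketch of how the structured metadata above can be queried (names come straight from `dependent_packages` and `tag_to_packages` in the patch):

```python
from imblearn._min_dependencies import dependent_packages, tag_to_packages

# Minimum version pinned for scikit-learn.
print(dependent_packages["scikit-learn"][0])  # 1.0.2

# Requirement strings generated for the 'install' extra.
print(tag_to_packages["install"])  # ['numpy>=1.17.3', 'scipy>=1.5.0', ...]
```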
diff --git a/imblearn/_version.py b/imblearn/_version.py
index ed49005..19c405e 100644
--- a/imblearn/_version.py
+++ b/imblearn/_version.py
@@ -2,4 +2,24 @@
``imbalanced-learn`` is a set of python methods to deal with imbalanced
datset in machine learning and pattern recognition.
"""
-__version__ = '0.12.3'
+# Based on NiLearn package
+# License: simplified BSD
+
+# PEP0440 compatible formatted version, see:
+# https://www.python.org/dev/peps/pep-0440/
+#
+# Generic release markers:
+# X.Y
+# X.Y.Z # For bugfix releases
+#
+# Admissible pre-release markers:
+# X.YaN # Alpha release
+# X.YbN # Beta release
+# X.YrcN # Release Candidate
+# X.Y # Final release
+#
+# Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
+# 'X.Y.dev0' is the canonical version of 'X.Y.dev'
+#
+
+__version__ = "0.12.3"
diff --git a/imblearn/base.py b/imblearn/base.py
index b4c50a8..0b2d94e 100644
--- a/imblearn/base.py
+++ b/imblearn/base.py
@@ -1,18 +1,29 @@
-"""Base class for sampling"""
+"""Base class for sampling"""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
from abc import ABCMeta, abstractmethod
+
import numpy as np
import sklearn
from sklearn.base import BaseEstimator
+
try:
+ # scikit-learn >= 1.2
from sklearn.base import OneToOneFeatureMixin
except ImportError:
from sklearn.base import _OneToOneFeatureMixin as OneToOneFeatureMixin
+
from sklearn.preprocessing import label_binarize
from sklearn.utils.fixes import parse_version
from sklearn.utils.multiclass import check_classification_targets
+
from .utils import check_sampling_strategy, check_target_type
from .utils._param_validation import validate_parameter_constraints
from .utils._validation import ArraysTransformer
+
sklearn_version = parse_version(sklearn.__version__)
@@ -27,7 +38,12 @@ class _ParamsValidationMixin:
the docstring of `validate_parameter_constraints` for a description of the
accepted constraints.
"""
- pass
+ if hasattr(self, "_parameter_constraints"):
+ validate_parameter_constraints(
+ self._parameter_constraints,
+ self.get_params(deep=False),
+ caller_name=self.__class__.__name__,
+ )
class SamplerMixin(_ParamsValidationMixin, BaseEstimator, metaclass=ABCMeta):
@@ -36,7 +52,8 @@ class SamplerMixin(_ParamsValidationMixin, BaseEstimator, metaclass=ABCMeta):
Warning: This class should not be used directly. Use the derive classes
instead.
"""
- _estimator_type = 'sampler'
+
+ _estimator_type = "sampler"
def fit(self, X, y):
"""Check inputs and statistics of the sampler.
@@ -45,7 +62,8 @@ class SamplerMixin(_ParamsValidationMixin, BaseEstimator, metaclass=ABCMeta):
Parameters
----------
- X : {array-like, dataframe, sparse matrix} of shape (n_samples, n_features)
+ X : {array-like, dataframe, sparse matrix} of shape \
+ (n_samples, n_features)
Data array.
y : array-like of shape (n_samples,)
@@ -56,14 +74,19 @@ class SamplerMixin(_ParamsValidationMixin, BaseEstimator, metaclass=ABCMeta):
self : object
Return the instance itself.
"""
- pass
+ X, y, _ = self._check_X_y(X, y)
+ self.sampling_strategy_ = check_sampling_strategy(
+ self.sampling_strategy, y, self._sampling_type
+ )
+ return self
def fit_resample(self, X, y):
"""Resample the dataset.
Parameters
----------
- X : {array-like, dataframe, sparse matrix} of shape (n_samples, n_features)
+ X : {array-like, dataframe, sparse matrix} of shape \
+ (n_samples, n_features)
Matrix containing the data which have to be sampled.
y : array-like of shape (n_samples,)
@@ -71,13 +94,29 @@ class SamplerMixin(_ParamsValidationMixin, BaseEstimator, metaclass=ABCMeta):
Returns
-------
- X_resampled : {array-like, dataframe, sparse matrix} of shape (n_samples_new, n_features)
+ X_resampled : {array-like, dataframe, sparse matrix} of shape \
+ (n_samples_new, n_features)
The array containing the resampled data.
y_resampled : array-like of shape (n_samples_new,)
The corresponding label of `X_resampled`.
"""
- pass
+ check_classification_targets(y)
+ arrays_transformer = ArraysTransformer(X, y)
+ X, y, binarize_y = self._check_X_y(X, y)
+
+ self.sampling_strategy_ = check_sampling_strategy(
+ self.sampling_strategy, y, self._sampling_type
+ )
+
+ output = self._fit_resample(X, y)
+
+ y_ = (
+ label_binarize(output[1], classes=np.unique(y)) if binarize_y else output[1]
+ )
+
+ X_, y_ = arrays_transformer.transform(output[0], y_)
+ return (X_, y_) if len(output) == 2 else (X_, y_, output[2])
@abstractmethod
def _fit_resample(self, X, y):
@@ -94,7 +133,8 @@ class SamplerMixin(_ParamsValidationMixin, BaseEstimator, metaclass=ABCMeta):
Returns
-------
- X_resampled : {ndarray, sparse matrix} of shape (n_samples_new, n_features)
+ X_resampled : {ndarray, sparse matrix} of shape \
+ (n_samples_new, n_features)
The array containing the resampled data.
y_resampled : ndarray of shape (n_samples_new,)
@@ -111,9 +151,16 @@ class BaseSampler(SamplerMixin, OneToOneFeatureMixin):
instead.
"""
- def __init__(self, sampling_strategy='auto'):
+ def __init__(self, sampling_strategy="auto"):
self.sampling_strategy = sampling_strategy
+ def _check_X_y(self, X, y, accept_sparse=None):
+ if accept_sparse is None:
+ accept_sparse = ["csr", "csc"]
+ y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
+ X, y = self._validate_data(X, y, reset=True, accept_sparse=accept_sparse)
+ return X, y, binarize_y
+
def fit(self, X, y):
"""Check inputs and statistics of the sampler.
@@ -121,7 +168,8 @@ class BaseSampler(SamplerMixin, OneToOneFeatureMixin):
Parameters
----------
- X : {array-like, dataframe, sparse matrix} of shape (n_samples, n_features)
+ X : {array-like, dataframe, sparse matrix} of shape \
+ (n_samples, n_features)
Data array.
y : array-like of shape (n_samples,)
@@ -132,14 +180,16 @@ class BaseSampler(SamplerMixin, OneToOneFeatureMixin):
self : object
Return the instance itself.
"""
- pass
+ self._validate_params()
+ return super().fit(X, y)
def fit_resample(self, X, y):
"""Resample the dataset.
Parameters
----------
- X : {array-like, dataframe, sparse matrix} of shape (n_samples, n_features)
+ X : {array-like, dataframe, sparse matrix} of shape \
+ (n_samples, n_features)
Matrix containing the data which have to be sampled.
y : array-like of shape (n_samples,)
@@ -147,13 +197,22 @@ class BaseSampler(SamplerMixin, OneToOneFeatureMixin):
Returns
-------
- X_resampled : {array-like, dataframe, sparse matrix} of shape (n_samples_new, n_features)
+ X_resampled : {array-like, dataframe, sparse matrix} of shape \
+ (n_samples_new, n_features)
The array containing the resampled data.
y_resampled : array-like of shape (n_samples_new,)
The corresponding label of `X_resampled`.
"""
- pass
+ self._validate_params()
+ return super().fit_resample(X, y)
+
+ def _more_tags(self):
+ return {"X_types": ["2darray", "sparse", "dataframe"]}
+
+
+def _identity(X, y):
+ return X, y
def is_sampler(estimator):
@@ -169,7 +228,9 @@ def is_sampler(estimator):
is_sampler : bool
True if estimator is a sampler, otherwise False.
"""
- pass
+ if estimator._estimator_type == "sampler":
+ return True
+ return False
class FunctionSampler(BaseSampler):
@@ -260,13 +321,17 @@ class FunctionSampler(BaseSampler):
>>> print(f'Resampled dataset shape {sorted(Counter(y_res).items())}')
Resampled dataset shape [(0, 100), (1, 100)]
"""
- _sampling_type = 'bypass'
- _parameter_constraints: dict = {'func': [callable, None],
- 'accept_sparse': ['boolean'], 'kw_args': [dict, None], 'validate':
- ['boolean']}
- def __init__(self, *, func=None, accept_sparse=True, kw_args=None,
- validate=True):
+ _sampling_type = "bypass"
+
+ _parameter_constraints: dict = {
+ "func": [callable, None],
+ "accept_sparse": ["boolean"],
+ "kw_args": [dict, None],
+ "validate": ["boolean"],
+ }
+
+ def __init__(self, *, func=None, accept_sparse=True, kw_args=None, validate=True):
super().__init__()
self.func = func
self.accept_sparse = accept_sparse
@@ -280,7 +345,8 @@ class FunctionSampler(BaseSampler):
Parameters
----------
- X : {array-like, dataframe, sparse matrix} of shape (n_samples, n_features)
+ X : {array-like, dataframe, sparse matrix} of shape \
+ (n_samples, n_features)
Data array.
y : array-like of shape (n_samples,)
@@ -291,7 +357,17 @@ class FunctionSampler(BaseSampler):
self : object
Return the instance itself.
"""
- pass
+ self._validate_params()
+ # we need to overwrite SamplerMixin.fit to bypass the validation
+ if self.validate:
+ check_classification_targets(y)
+ X, y, _ = self._check_X_y(X, y, accept_sparse=self.accept_sparse)
+
+ self.sampling_strategy_ = check_sampling_strategy(
+ self.sampling_strategy, y, self._sampling_type
+ )
+
+ return self
def fit_resample(self, X, y):
"""Resample the dataset.
@@ -306,10 +382,38 @@ class FunctionSampler(BaseSampler):
Returns
-------
- X_resampled : {array-like, sparse matrix} of shape (n_samples_new, n_features)
+ X_resampled : {array-like, sparse matrix} of shape \
+ (n_samples_new, n_features)
The array containing the resampled data.
y_resampled : array-like of shape (n_samples_new,)
The corresponding label of `X_resampled`.
"""
- pass
+ self._validate_params()
+ arrays_transformer = ArraysTransformer(X, y)
+
+ if self.validate:
+ check_classification_targets(y)
+ X, y, binarize_y = self._check_X_y(X, y, accept_sparse=self.accept_sparse)
+
+ self.sampling_strategy_ = check_sampling_strategy(
+ self.sampling_strategy, y, self._sampling_type
+ )
+
+ output = self._fit_resample(X, y)
+
+ if self.validate:
+ y_ = (
+ label_binarize(output[1], classes=np.unique(y))
+ if binarize_y
+ else output[1]
+ )
+ X_, y_ = arrays_transformer.transform(output[0], y_)
+ return (X_, y_) if len(output) == 2 else (X_, y_, output[2])
+
+ return output
+
+ def _fit_resample(self, X, y):
+ func = _identity if self.func is None else self.func
+ output = func(X, y, **(self.kw_args if self.kw_args else {}))
+ return output
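To illustrate the `_identity` fallback wired up above: with `func=None`, `FunctionSampler` validates the input and returns it unchanged. A minimal sketch:

```python
import numpy as np
from imblearn import FunctionSampler

X = np.arange(10, dtype=float).reshape(5, 2)
y = np.array([0, 0, 1, 1, 1])

sampler = FunctionSampler()  # func=None falls back to _identity
X_res, y_res = sampler.fit_resample(X, y)
assert np.array_equal(X_res, X) and np.array_equal(y_res, y)
```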
diff --git a/imblearn/combine/_smote_enn.py b/imblearn/combine/_smote_enn.py
index 451604e..1b0ffe0 100644
--- a/imblearn/combine/_smote_enn.py
+++ b/imblearn/combine/_smote_enn.py
@@ -1,7 +1,14 @@
"""Class to perform over-sampling using SMOTE and cleaning using ENN."""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import numbers
+
from sklearn.base import clone
from sklearn.utils import check_X_y
+
from ..base import BaseSampler
from ..over_sampling import SMOTE
from ..over_sampling.base import BaseOverSampler
@@ -10,9 +17,11 @@ from ..utils import Substitution, check_target_type
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
-@Substitution(sampling_strategy=BaseOverSampler.
- _sampling_strategy_docstring, n_jobs=_n_jobs_docstring, random_state=
- _random_state_docstring)
+@Substitution(
+ sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
+ n_jobs=_n_jobs_docstring,
+ random_state=_random_state_docstring,
+)
class SMOTEENN(BaseSampler):
"""Over-sampling using SMOTE and cleaning using ENN.
@@ -98,13 +107,25 @@ class SMOTEENN(BaseSampler):
>>> print('Resampled dataset shape %s' % Counter(y_res))
Resampled dataset shape Counter({{0: 900, 1: 881}})
"""
- _sampling_type = 'over-sampling'
- _parameter_constraints: dict = {**BaseOverSampler.
- _parameter_constraints, 'smote': [SMOTE, None], 'enn': [
- EditedNearestNeighbours, None], 'n_jobs': [numbers.Integral, None]}
- def __init__(self, *, sampling_strategy='auto', random_state=None,
- smote=None, enn=None, n_jobs=None):
+ _sampling_type = "over-sampling"
+
+ _parameter_constraints: dict = {
+ **BaseOverSampler._parameter_constraints,
+ "smote": [SMOTE, None],
+ "enn": [EditedNearestNeighbours, None],
+ "n_jobs": [numbers.Integral, None],
+ }
+
+ def __init__(
+ self,
+ *,
+ sampling_strategy="auto",
+ random_state=None,
+ smote=None,
+ enn=None,
+ n_jobs=None,
+ ):
super().__init__()
self.sampling_strategy = sampling_strategy
self.random_state = random_state
@@ -113,5 +134,28 @@ class SMOTEENN(BaseSampler):
self.n_jobs = n_jobs
def _validate_estimator(self):
- """Private function to validate SMOTE and ENN objects"""
- pass
+ "Private function to validate SMOTE and ENN objects"
+ if self.smote is not None:
+ self.smote_ = clone(self.smote)
+ else:
+ self.smote_ = SMOTE(
+ sampling_strategy=self.sampling_strategy,
+ random_state=self.random_state,
+ n_jobs=self.n_jobs,
+ )
+
+ if self.enn is not None:
+ self.enn_ = clone(self.enn)
+ else:
+ self.enn_ = EditedNearestNeighbours(
+ sampling_strategy="all", n_jobs=self.n_jobs
+ )
+
+ def _fit_resample(self, X, y):
+ self._validate_estimator()
+ y = check_target_type(y)
+ X, y = check_X_y(X, y, accept_sparse=["csr", "csc"])
+ self.sampling_strategy_ = self.sampling_strategy
+
+ X_res, y_res = self.smote_.fit_resample(X, y)
+ return self.enn_.fit_resample(X_res, y_res)
diff --git a/imblearn/combine/_smote_tomek.py b/imblearn/combine/_smote_tomek.py
index 2bbf9bf..94d7c4d 100644
--- a/imblearn/combine/_smote_tomek.py
+++ b/imblearn/combine/_smote_tomek.py
@@ -1,8 +1,15 @@
"""Class to perform over-sampling using SMOTE and cleaning using Tomek
links."""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import numbers
+
from sklearn.base import clone
from sklearn.utils import check_X_y
+
from ..base import BaseSampler
from ..over_sampling import SMOTE
from ..over_sampling.base import BaseOverSampler
@@ -11,9 +18,11 @@ from ..utils import Substitution, check_target_type
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
-@Substitution(sampling_strategy=BaseOverSampler.
- _sampling_strategy_docstring, n_jobs=_n_jobs_docstring, random_state=
- _random_state_docstring)
+@Substitution(
+ sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
+ n_jobs=_n_jobs_docstring,
+ random_state=_random_state_docstring,
+)
class SMOTETomek(BaseSampler):
"""Over-sampling using SMOTE and cleaning using Tomek links.
@@ -96,13 +105,25 @@ class SMOTETomek(BaseSampler):
>>> print('Resampled dataset shape %s' % Counter(y_res))
Resampled dataset shape Counter({{0: 900, 1: 900}})
"""
- _sampling_type = 'over-sampling'
- _parameter_constraints: dict = {**BaseOverSampler.
- _parameter_constraints, 'smote': [SMOTE, None], 'tomek': [
- TomekLinks, None], 'n_jobs': [numbers.Integral, None]}
- def __init__(self, *, sampling_strategy='auto', random_state=None,
- smote=None, tomek=None, n_jobs=None):
+ _sampling_type = "over-sampling"
+
+ _parameter_constraints: dict = {
+ **BaseOverSampler._parameter_constraints,
+ "smote": [SMOTE, None],
+ "tomek": [TomekLinks, None],
+ "n_jobs": [numbers.Integral, None],
+ }
+
+ def __init__(
+ self,
+ *,
+ sampling_strategy="auto",
+ random_state=None,
+ smote=None,
+ tomek=None,
+ n_jobs=None,
+ ):
super().__init__()
self.sampling_strategy = sampling_strategy
self.random_state = random_state
@@ -111,5 +132,27 @@ class SMOTETomek(BaseSampler):
self.n_jobs = n_jobs
def _validate_estimator(self):
- """Private function to validate SMOTE and ENN objects"""
- pass
+ "Private function to validate SMOTE and ENN objects"
+
+ if self.smote is not None:
+ self.smote_ = clone(self.smote)
+ else:
+ self.smote_ = SMOTE(
+ sampling_strategy=self.sampling_strategy,
+ random_state=self.random_state,
+ n_jobs=self.n_jobs,
+ )
+
+ if self.tomek is not None:
+ self.tomek_ = clone(self.tomek)
+ else:
+ self.tomek_ = TomekLinks(sampling_strategy="all", n_jobs=self.n_jobs)
+
+ def _fit_resample(self, X, y):
+ self._validate_estimator()
+ y = check_target_type(y)
+ X, y = check_X_y(X, y, accept_sparse=["csr", "csc"])
+ self.sampling_strategy_ = self.sampling_strategy
+
+ X_res, y_res = self.smote_.fit_resample(X, y)
+ return self.tomek_.fit_resample(X_res, y_res)
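Both combined samplers above share the same two-step `_fit_resample`: over-sample with SMOTE, then clean with ENN or Tomek links. A short usage sketch (dataset parameters are illustrative):

```python
from collections import Counter
from sklearn.datasets import make_classification
from imblearn.combine import SMOTETomek

X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], random_state=0)
X_res, y_res = SMOTETomek(random_state=0).fit_resample(X, y)
print(sorted(Counter(y_res).items()))  # roughly balanced classes
```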
diff --git a/imblearn/combine/tests/test_smote_enn.py b/imblearn/combine/tests/test_smote_enn.py
index f6dabe0..df72cc7 100644
--- a/imblearn/combine/tests/test_smote_enn.py
+++ b/imblearn/combine/tests/test_smote_enn.py
@@ -1,18 +1,157 @@
"""Test the module SMOTE ENN."""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import numpy as np
from sklearn.utils._testing import assert_allclose, assert_array_equal
+
from imblearn.combine import SMOTEENN
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import EditedNearestNeighbours
+
RND_SEED = 0
-X = np.array([[0.11622591, -0.0317206], [0.77481731, 0.60935141], [
- 1.25192108, -0.22367336], [0.53366841, -0.30312976], [1.52091956, -
- 0.49283504], [-0.28162401, -2.10400981], [0.83680821, 1.72827342], [
- 0.3084254, 0.33299982], [0.70472253, -0.73309052], [0.28893132, -
- 0.38761769], [1.15514042, 0.0129463], [0.88407872, 0.35454207], [
- 1.31301027, -0.92648734], [-1.11515198, -0.93689695], [-0.18410027, -
- 0.45194484], [0.9281014, 0.53085498], [-0.14374509, 0.27370049], [-
- 0.41635887, -0.38299653], [0.08711622, 0.93259929], [1.70580611, -
- 0.11219234]])
+X = np.array(
+ [
+ [0.11622591, -0.0317206],
+ [0.77481731, 0.60935141],
+ [1.25192108, -0.22367336],
+ [0.53366841, -0.30312976],
+ [1.52091956, -0.49283504],
+ [-0.28162401, -2.10400981],
+ [0.83680821, 1.72827342],
+ [0.3084254, 0.33299982],
+ [0.70472253, -0.73309052],
+ [0.28893132, -0.38761769],
+ [1.15514042, 0.0129463],
+ [0.88407872, 0.35454207],
+ [1.31301027, -0.92648734],
+ [-1.11515198, -0.93689695],
+ [-0.18410027, -0.45194484],
+ [0.9281014, 0.53085498],
+ [-0.14374509, 0.27370049],
+ [-0.41635887, -0.38299653],
+ [0.08711622, 0.93259929],
+ [1.70580611, -0.11219234],
+ ]
+)
Y = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0])
-R_TOL = 0.0001
+R_TOL = 1e-4
+
+
+def test_sample_regular():
+ smote = SMOTEENN(random_state=RND_SEED)
+ X_resampled, y_resampled = smote.fit_resample(X, Y)
+
+ X_gt = np.array(
+ [
+ [1.52091956, -0.49283504],
+ [0.84976473, -0.15570176],
+ [0.61319159, -0.11571667],
+ [0.66052536, -0.28246518],
+ [-0.28162401, -2.10400981],
+ [0.83680821, 1.72827342],
+ [0.08711622, 0.93259929],
+ ]
+ )
+ y_gt = np.array([0, 0, 0, 0, 1, 1, 1])
+ assert_allclose(X_resampled, X_gt, rtol=R_TOL)
+ assert_array_equal(y_resampled, y_gt)
+
+
+def test_sample_regular_pass_smote_enn():
+ smote = SMOTEENN(
+ smote=SMOTE(sampling_strategy="auto", random_state=RND_SEED),
+ enn=EditedNearestNeighbours(sampling_strategy="all"),
+ random_state=RND_SEED,
+ )
+ X_resampled, y_resampled = smote.fit_resample(X, Y)
+
+ X_gt = np.array(
+ [
+ [1.52091956, -0.49283504],
+ [0.84976473, -0.15570176],
+ [0.61319159, -0.11571667],
+ [0.66052536, -0.28246518],
+ [-0.28162401, -2.10400981],
+ [0.83680821, 1.72827342],
+ [0.08711622, 0.93259929],
+ ]
+ )
+ y_gt = np.array([0, 0, 0, 0, 1, 1, 1])
+ assert_allclose(X_resampled, X_gt, rtol=R_TOL)
+ assert_array_equal(y_resampled, y_gt)
+
+
+def test_sample_regular_half():
+ sampling_strategy = {0: 10, 1: 12}
+ smote = SMOTEENN(sampling_strategy=sampling_strategy, random_state=RND_SEED)
+ X_resampled, y_resampled = smote.fit_resample(X, Y)
+
+ X_gt = np.array(
+ [
+ [1.52091956, -0.49283504],
+ [-0.28162401, -2.10400981],
+ [0.83680821, 1.72827342],
+ [0.08711622, 0.93259929],
+ ]
+ )
+ y_gt = np.array([0, 1, 1, 1])
+ assert_allclose(X_resampled, X_gt)
+ assert_array_equal(y_resampled, y_gt)
+
+
+def test_validate_estimator_init():
+ smote = SMOTE(random_state=RND_SEED)
+ enn = EditedNearestNeighbours(sampling_strategy="all")
+ smt = SMOTEENN(smote=smote, enn=enn, random_state=RND_SEED)
+ X_resampled, y_resampled = smt.fit_resample(X, Y)
+ X_gt = np.array(
+ [
+ [1.52091956, -0.49283504],
+ [0.84976473, -0.15570176],
+ [0.61319159, -0.11571667],
+ [0.66052536, -0.28246518],
+ [-0.28162401, -2.10400981],
+ [0.83680821, 1.72827342],
+ [0.08711622, 0.93259929],
+ ]
+ )
+ y_gt = np.array([0, 0, 0, 0, 1, 1, 1])
+ assert_allclose(X_resampled, X_gt, rtol=R_TOL)
+ assert_array_equal(y_resampled, y_gt)
+
+
+def test_validate_estimator_default():
+ smt = SMOTEENN(random_state=RND_SEED)
+ X_resampled, y_resampled = smt.fit_resample(X, Y)
+ X_gt = np.array(
+ [
+ [1.52091956, -0.49283504],
+ [0.84976473, -0.15570176],
+ [0.61319159, -0.11571667],
+ [0.66052536, -0.28246518],
+ [-0.28162401, -2.10400981],
+ [0.83680821, 1.72827342],
+ [0.08711622, 0.93259929],
+ ]
+ )
+ y_gt = np.array([0, 0, 0, 0, 1, 1, 1])
+ assert_allclose(X_resampled, X_gt, rtol=R_TOL)
+ assert_array_equal(y_resampled, y_gt)
+
+
+def test_parallelisation():
+    # Check that the default job count is None
+ smt = SMOTEENN(random_state=RND_SEED)
+ smt._validate_estimator()
+ assert smt.n_jobs is None
+ assert smt.smote_.n_jobs is None
+ assert smt.enn_.n_jobs is None
+
+ # Check if job count is set
+ smt = SMOTEENN(random_state=RND_SEED, n_jobs=8)
+ smt._validate_estimator()
+ assert smt.n_jobs == 8
+ assert smt.smote_.n_jobs == 8
+ assert smt.enn_.n_jobs == 8
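
The `test_parallelisation` assertions rely on `_validate_estimator` materialising `smote_` and `enn_` from the constructor arguments. A small sketch of the resulting attributes, mirroring the SMOTETomek logic above (the `sampling_strategy="all"` default for `enn_` is inferred from that parallel, not shown in this diff):

    from imblearn.combine import SMOTEENN

    smt = SMOTEENN(random_state=0, n_jobs=4)
    smt._validate_estimator()  # builds the sub-samplers without fitting anything

    # both sub-samplers inherit n_jobs from the wrapper
    assert smt.smote_.n_jobs == 4
    assert smt.enn_.n_jobs == 4
    # ENN is expected to clean every class by default
    assert smt.enn_.sampling_strategy == "all"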
diff --git a/imblearn/combine/tests/test_smote_tomek.py b/imblearn/combine/tests/test_smote_tomek.py
index 5685726..2ca3e38 100644
--- a/imblearn/combine/tests/test_smote_tomek.py
+++ b/imblearn/combine/tests/test_smote_tomek.py
@@ -1,18 +1,167 @@
"""Test the module SMOTE ENN."""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import numpy as np
from sklearn.utils._testing import assert_allclose, assert_array_equal
+
from imblearn.combine import SMOTETomek
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import TomekLinks
+
RND_SEED = 0
-X = np.array([[0.20622591, 0.0582794], [0.68481731, 0.51935141], [
- 1.34192108, -0.13367336], [0.62366841, -0.21312976], [1.61091956, -
- 0.40283504], [-0.37162401, -2.19400981], [0.74680821, 1.63827342], [
- 0.2184254, 0.24299982], [0.61472253, -0.82309052], [0.19893132, -
- 0.47761769], [1.06514042, -0.0770537], [0.97407872, 0.44454207], [
- 1.40301027, -0.83648734], [-1.20515198, -1.02689695], [-0.27410027, -
- 0.54194484], [0.8381014, 0.44085498], [-0.23374509, 0.18370049], [-
- 0.32635887, -0.29299653], [-0.00288378, 0.84259929], [1.79580611, -
- 0.02219234]])
+X = np.array(
+ [
+ [0.20622591, 0.0582794],
+ [0.68481731, 0.51935141],
+ [1.34192108, -0.13367336],
+ [0.62366841, -0.21312976],
+ [1.61091956, -0.40283504],
+ [-0.37162401, -2.19400981],
+ [0.74680821, 1.63827342],
+ [0.2184254, 0.24299982],
+ [0.61472253, -0.82309052],
+ [0.19893132, -0.47761769],
+ [1.06514042, -0.0770537],
+ [0.97407872, 0.44454207],
+ [1.40301027, -0.83648734],
+ [-1.20515198, -1.02689695],
+ [-0.27410027, -0.54194484],
+ [0.8381014, 0.44085498],
+ [-0.23374509, 0.18370049],
+ [-0.32635887, -0.29299653],
+ [-0.00288378, 0.84259929],
+ [1.79580611, -0.02219234],
+ ]
+)
Y = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0])
-R_TOL = 0.0001
+R_TOL = 1e-4
+
+
+def test_sample_regular():
+ smote = SMOTETomek(random_state=RND_SEED)
+ X_resampled, y_resampled = smote.fit_resample(X, Y)
+ X_gt = np.array(
+ [
+ [0.68481731, 0.51935141],
+ [1.34192108, -0.13367336],
+ [0.62366841, -0.21312976],
+ [1.61091956, -0.40283504],
+ [-0.37162401, -2.19400981],
+ [0.74680821, 1.63827342],
+ [0.61472253, -0.82309052],
+ [0.19893132, -0.47761769],
+ [1.40301027, -0.83648734],
+ [-1.20515198, -1.02689695],
+ [-0.23374509, 0.18370049],
+ [-0.00288378, 0.84259929],
+ [1.79580611, -0.02219234],
+ [0.38307743, -0.05670439],
+ [0.70319159, -0.02571667],
+ [0.75052536, -0.19246518],
+ ]
+ )
+ y_gt = np.array([1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0])
+ assert_allclose(X_resampled, X_gt, rtol=R_TOL)
+ assert_array_equal(y_resampled, y_gt)
+
+
+def test_sample_regular_half():
+ sampling_strategy = {0: 9, 1: 12}
+ smote = SMOTETomek(sampling_strategy=sampling_strategy, random_state=RND_SEED)
+ X_resampled, y_resampled = smote.fit_resample(X, Y)
+ X_gt = np.array(
+ [
+ [0.68481731, 0.51935141],
+ [0.62366841, -0.21312976],
+ [1.61091956, -0.40283504],
+ [-0.37162401, -2.19400981],
+ [0.74680821, 1.63827342],
+ [0.61472253, -0.82309052],
+ [0.19893132, -0.47761769],
+ [1.40301027, -0.83648734],
+ [-1.20515198, -1.02689695],
+ [-0.23374509, 0.18370049],
+ [-0.00288378, 0.84259929],
+ [1.79580611, -0.02219234],
+ [0.45784496, -0.1053161],
+ ]
+ )
+ y_gt = np.array([1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0])
+ assert_allclose(X_resampled, X_gt, rtol=R_TOL)
+ assert_array_equal(y_resampled, y_gt)
+
+
+def test_validate_estimator_init():
+ smote = SMOTE(random_state=RND_SEED)
+ tomek = TomekLinks(sampling_strategy="all")
+ smt = SMOTETomek(smote=smote, tomek=tomek, random_state=RND_SEED)
+ X_resampled, y_resampled = smt.fit_resample(X, Y)
+ X_gt = np.array(
+ [
+ [0.68481731, 0.51935141],
+ [1.34192108, -0.13367336],
+ [0.62366841, -0.21312976],
+ [1.61091956, -0.40283504],
+ [-0.37162401, -2.19400981],
+ [0.74680821, 1.63827342],
+ [0.61472253, -0.82309052],
+ [0.19893132, -0.47761769],
+ [1.40301027, -0.83648734],
+ [-1.20515198, -1.02689695],
+ [-0.23374509, 0.18370049],
+ [-0.00288378, 0.84259929],
+ [1.79580611, -0.02219234],
+ [0.38307743, -0.05670439],
+ [0.70319159, -0.02571667],
+ [0.75052536, -0.19246518],
+ ]
+ )
+ y_gt = np.array([1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0])
+ assert_allclose(X_resampled, X_gt, rtol=R_TOL)
+ assert_array_equal(y_resampled, y_gt)
+
+
+def test_validate_estimator_default():
+ smt = SMOTETomek(random_state=RND_SEED)
+ X_resampled, y_resampled = smt.fit_resample(X, Y)
+ X_gt = np.array(
+ [
+ [0.68481731, 0.51935141],
+ [1.34192108, -0.13367336],
+ [0.62366841, -0.21312976],
+ [1.61091956, -0.40283504],
+ [-0.37162401, -2.19400981],
+ [0.74680821, 1.63827342],
+ [0.61472253, -0.82309052],
+ [0.19893132, -0.47761769],
+ [1.40301027, -0.83648734],
+ [-1.20515198, -1.02689695],
+ [-0.23374509, 0.18370049],
+ [-0.00288378, 0.84259929],
+ [1.79580611, -0.02219234],
+ [0.38307743, -0.05670439],
+ [0.70319159, -0.02571667],
+ [0.75052536, -0.19246518],
+ ]
+ )
+ y_gt = np.array([1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0])
+ assert_allclose(X_resampled, X_gt, rtol=R_TOL)
+ assert_array_equal(y_resampled, y_gt)
+
+
+def test_parallelisation():
+ # Check if default job count is None
+ smt = SMOTETomek(random_state=RND_SEED)
+ smt._validate_estimator()
+ assert smt.n_jobs is None
+ assert smt.smote_.n_jobs is None
+ assert smt.tomek_.n_jobs is None
+
+ # Check if job count is set
+ smt = SMOTETomek(random_state=RND_SEED, n_jobs=8)
+ smt._validate_estimator()
+ assert smt.n_jobs == 8
+ assert smt.smote_.n_jobs == 8
+ assert smt.tomek_.n_jobs == 8
diff --git a/imblearn/datasets/_imbalance.py b/imblearn/datasets/_imbalance.py
index 53e40de..9e6e512 100644
--- a/imblearn/datasets/_imbalance.py
+++ b/imblearn/datasets/_imbalance.py
@@ -1,17 +1,31 @@
"""Transform a dataset into an imbalanced dataset."""
+
+# Authors: Dayvid Oliveira
+# Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
from collections import Counter
from collections.abc import Mapping
+
from ..under_sampling import RandomUnderSampler
from ..utils import check_sampling_strategy
from ..utils._param_validation import validate_params
-@validate_params({'X': ['array-like'], 'y': ['array-like'],
- 'sampling_strategy': [Mapping, callable, None], 'random_state': [
- 'random_state'], 'verbose': ['boolean']}, prefer_skip_nested_validation
- =True)
-def make_imbalance(X, y, *, sampling_strategy=None, random_state=None,
- verbose=False, **kwargs):
+@validate_params(
+ {
+ "X": ["array-like"],
+ "y": ["array-like"],
+ "sampling_strategy": [Mapping, callable, None],
+ "random_state": ["random_state"],
+ "verbose": ["boolean"],
+ },
+ prefer_skip_nested_validation=True,
+)
+def make_imbalance(
+ X, y, *, sampling_strategy=None, random_state=None, verbose=False, **kwargs
+):
"""Turn a dataset into an imbalanced dataset with a specific sampling strategy.
A simple toy dataset to visualize clustering and classification
@@ -82,4 +96,22 @@ def make_imbalance(X, y, *, sampling_strategy=None, random_state=None,
>>> print(f'Distribution after imbalancing: {Counter(y_res)}')
Distribution after imbalancing: Counter({2: 30, 1: 20, 0: 10})
"""
- pass
+ target_stats = Counter(y)
+    # restrict sampling_strategy to be a dict or a callable
+ if isinstance(sampling_strategy, Mapping) or callable(sampling_strategy):
+ sampling_strategy_ = check_sampling_strategy(
+ sampling_strategy, y, "under-sampling", **kwargs
+ )
+
+ if verbose:
+ print(f"The original target distribution in the dataset is: {target_stats}")
+ rus = RandomUnderSampler(
+ sampling_strategy=sampling_strategy_,
+ replacement=False,
+ random_state=random_state,
+ )
+ X_resampled, y_resampled = rus.fit_resample(X, y)
+ if verbose:
+ print(f"Make the dataset imbalanced: {Counter(y_resampled)}")
+
+ return X_resampled, y_resampled
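
Because `check_sampling_strategy` receives `**kwargs`, `sampling_strategy` can also be a callable whose extra keyword arguments are forwarded through `make_imbalance`. A sketch of that path (the `ratio` helper is hypothetical):

    from collections import Counter

    from sklearn.datasets import load_iris

    from imblearn.datasets import make_imbalance

    def ratio(y, multiplier=0.5):
        # keep a fraction of every class; `multiplier` arrives via **kwargs
        return {klass: int(count * multiplier) for klass, count in Counter(y).items()}

    X, y = load_iris(return_X_y=True)
    X_res, y_res = make_imbalance(X, y, sampling_strategy=ratio, multiplier=0.4)
    print(Counter(y_res))  # each of the three classes reduced from 50 to 20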
diff --git a/imblearn/datasets/_zenodo.py b/imblearn/datasets/_zenodo.py
index f9062b9..a73ef37 100644
--- a/imblearn/datasets/_zenodo.py
+++ b/imblearn/datasets/_zenodo.py
@@ -39,25 +39,57 @@ References
Imbalanced Data Learning and their Application in Bioinformatics."
Dissertation, Georgia State University, (2011).
"""
+
+# Author: Guillaume Lemaitre
+# License: BSD 3 clause
+
import tarfile
from collections import OrderedDict
from io import BytesIO
from os import makedirs
from os.path import isfile, join
from urllib.request import urlopen
+
import numpy as np
from sklearn.datasets import get_data_home
from sklearn.utils import Bunch, check_random_state
+
from ..utils._param_validation import validate_params
-URL = 'https://zenodo.org/record/61452/files/benchmark-imbalanced-learn.tar.gz'
-PRE_FILENAME = 'x'
-POST_FILENAME = 'data.npz'
-MAP_NAME_ID_KEYS = ['ecoli', 'optical_digits', 'satimage', 'pen_digits',
- 'abalone', 'sick_euthyroid', 'spectrometer', 'car_eval_34', 'isolet',
- 'us_crime', 'yeast_ml8', 'scene', 'libras_move', 'thyroid_sick',
- 'coil_2000', 'arrhythmia', 'solar_flare_m0', 'oil', 'car_eval_4',
- 'wine_quality', 'letter_img', 'yeast_me2', 'webpage', 'ozone_level',
- 'mammography', 'protein_homo', 'abalone_19']
+
+URL = "https://zenodo.org/record/61452/files/benchmark-imbalanced-learn.tar.gz"
+PRE_FILENAME = "x"
+POST_FILENAME = "data.npz"
+
+MAP_NAME_ID_KEYS = [
+ "ecoli",
+ "optical_digits",
+ "satimage",
+ "pen_digits",
+ "abalone",
+ "sick_euthyroid",
+ "spectrometer",
+ "car_eval_34",
+ "isolet",
+ "us_crime",
+ "yeast_ml8",
+ "scene",
+ "libras_move",
+ "thyroid_sick",
+ "coil_2000",
+ "arrhythmia",
+ "solar_flare_m0",
+ "oil",
+ "car_eval_4",
+ "wine_quality",
+ "letter_img",
+ "yeast_me2",
+ "webpage",
+ "ozone_level",
+ "mammography",
+ "protein_homo",
+ "abalone_19",
+]
+
MAP_NAME_ID = OrderedDict()
MAP_ID_NAME = OrderedDict()
for v, k in enumerate(MAP_NAME_ID_KEYS):
@@ -65,12 +97,26 @@ for v, k in enumerate(MAP_NAME_ID_KEYS):
MAP_ID_NAME[v + 1] = k
-@validate_params({'data_home': [None, str], 'filter_data': [None, tuple],
- 'download_if_missing': ['boolean'], 'random_state': ['random_state'],
- 'shuffle': ['boolean'], 'verbose': ['boolean']},
- prefer_skip_nested_validation=True)
-def fetch_datasets(*, data_home=None, filter_data=None, download_if_missing
- =True, random_state=None, shuffle=False, verbose=False):
+@validate_params(
+ {
+ "data_home": [None, str],
+ "filter_data": [None, tuple],
+ "download_if_missing": ["boolean"],
+ "random_state": ["random_state"],
+ "shuffle": ["boolean"],
+ "verbose": ["boolean"],
+ },
+ prefer_skip_nested_validation=True,
+)
+def fetch_datasets(
+ *,
+ data_home=None,
+ filter_data=None,
+ download_if_missing=True,
+ random_state=None,
+ shuffle=False,
+ verbose=False,
+):
"""Load the benchmark datasets from Zenodo, downloading it if necessary.
.. versionadded:: 0.3
@@ -185,4 +231,68 @@ def fetch_datasets(*, data_home=None, filter_data=None, download_if_missing
Imbalanced Data Learning and their Application in Bioinformatics."
Dissertation, Georgia State University, (2011).
"""
- pass
+
+ data_home = get_data_home(data_home=data_home)
+ zenodo_dir = join(data_home, "zenodo")
+ datasets = OrderedDict()
+
+ if filter_data is None:
+ filter_data_ = MAP_NAME_ID.keys()
+ else:
+ list_data = MAP_NAME_ID.keys()
+ filter_data_ = []
+ for it in filter_data:
+ if isinstance(it, str):
+ if it not in list_data:
+ raise ValueError(
+ f"{it} is not a dataset available. "
+ f"The available datasets are {list_data}"
+ )
+ else:
+ filter_data_.append(it)
+ elif isinstance(it, int):
+ if it < 1 or it > 27:
+ raise ValueError(
+ f"The dataset with the ID={it} is not an "
+ f"available dataset. The IDs are "
+ f"{range(1, 28)}"
+ )
+ else:
+                    # IDs are one-based; map each back to its dataset name.
+ filter_data_.append(MAP_ID_NAME[it])
+ else:
+ raise ValueError(
+ f"The value in the tuple should be str or int."
+ f" Got {type(it)} instead."
+ )
+
+ # go through the list and check if the data are available
+ for it in filter_data_:
+ filename = PRE_FILENAME + str(MAP_NAME_ID[it]) + POST_FILENAME
+ filename = join(zenodo_dir, filename)
+ available = isfile(filename)
+
+ if download_if_missing and not available:
+ makedirs(zenodo_dir, exist_ok=True)
+ if verbose:
+ print("Downloading %s" % URL)
+ f = BytesIO(urlopen(URL).read())
+ tar = tarfile.open(fileobj=f)
+ tar.extractall(path=zenodo_dir)
+ elif not download_if_missing and not available:
+ raise IOError("Data not found and `download_if_missing` is False")
+
+ data = np.load(filename)
+ X, y = data["data"], data["label"]
+
+ if shuffle:
+ ind = np.arange(X.shape[0])
+ rng = check_random_state(random_state)
+ rng.shuffle(ind)
+ X = X[ind]
+ y = y[ind]
+
+ datasets[it] = Bunch(data=X, target=y, DESCR=it)
+
+ return datasets
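
The filtering logic accepts dataset names and one-based IDs interchangeably in the same tuple; anything else raises `ValueError`. A usage sketch (this requires the Zenodo archive to be downloadable or already cached under `data_home`):

    from imblearn.datasets import fetch_datasets

    # names and one-based IDs can be mixed; both resolve through MAP_NAME_ID
    datasets = fetch_datasets(filter_data=("ecoli", 3), verbose=True)
    print(datasets["ecoli"].data.shape)  # (336, 7)
    print(datasets["satimage"].DESCR)    # 'satimage' (ID 3)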
diff --git a/imblearn/datasets/tests/test_imbalance.py b/imblearn/datasets/tests/test_imbalance.py
index 1067628..ac3b417 100644
--- a/imblearn/datasets/tests/test_imbalance.py
+++ b/imblearn/datasets/tests/test_imbalance.py
@@ -1,6 +1,80 @@
"""Test the module easy ensemble."""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
from collections import Counter
+
import numpy as np
import pytest
from sklearn.datasets import load_iris
+
from imblearn.datasets import make_imbalance
+
+
+@pytest.fixture
+def iris():
+ return load_iris(return_X_y=True)
+
+
+@pytest.mark.parametrize(
+ "sampling_strategy, err_msg",
+ [
+ ({0: -100, 1: 50, 2: 50}, "in a class cannot be negative"),
+ ({0: 10, 1: 70}, "should be less or equal to the original"),
+ ],
+)
+def test_make_imbalance_error(iris, sampling_strategy, err_msg):
+    # we are reusing part of utils.check_sampling_strategy; however, it is not
+    # covered by the common tests, so we repeat it here
+ X, y = iris
+ with pytest.raises(ValueError, match=err_msg):
+ make_imbalance(X, y, sampling_strategy=sampling_strategy)
+
+
+def test_make_imbalance_error_single_class(iris):
+ X, y = iris
+ y = np.zeros_like(y)
+ with pytest.raises(ValueError, match="needs to have more than 1 class."):
+ make_imbalance(X, y, sampling_strategy={0: 10})
+
+
+@pytest.mark.parametrize(
+ "sampling_strategy, expected_counts",
+ [
+ ({0: 10, 1: 20, 2: 30}, {0: 10, 1: 20, 2: 30}),
+ ({0: 10, 1: 20}, {0: 10, 1: 20, 2: 50}),
+ ],
+)
+def test_make_imbalance_dict(iris, sampling_strategy, expected_counts):
+ X, y = iris
+ _, y_ = make_imbalance(X, y, sampling_strategy=sampling_strategy)
+ assert Counter(y_) == expected_counts
+
+
+@pytest.mark.parametrize("as_frame", [True, False], ids=["dataframe", "array"])
+@pytest.mark.parametrize(
+ "sampling_strategy, expected_counts",
+ [
+ (
+ {"setosa": 10, "versicolor": 20, "virginica": 30},
+ {"setosa": 10, "versicolor": 20, "virginica": 30},
+ ),
+ (
+ {"setosa": 10, "versicolor": 20},
+ {"setosa": 10, "versicolor": 20, "virginica": 50},
+ ),
+ ],
+)
+def test_make_imbalanced_iris(as_frame, sampling_strategy, expected_counts):
+ pd = pytest.importorskip("pandas")
+ iris = load_iris(as_frame=as_frame)
+ X, y = iris.data, iris.target
+ y = iris.target_names[iris.target]
+ if as_frame:
+ y = pd.Series(iris.target_names[iris.target], name="target")
+ X_res, y_res = make_imbalance(X, y, sampling_strategy=sampling_strategy)
+ if as_frame:
+ assert hasattr(X_res, "loc")
+ pd.testing.assert_index_equal(X_res.index, y_res.index)
+ assert Counter(y_res) == expected_counts
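
The dataframe variant of the test relies on the resampled `X` and `y` sharing one index, which keeps them aligned row for row. In isolation:

    from sklearn.datasets import load_iris

    from imblearn.datasets import make_imbalance

    iris = load_iris(as_frame=True)
    X, y = iris.data, iris.target
    X_res, y_res = make_imbalance(X, y, sampling_strategy={0: 10, 1: 20})
    # the resampled frame and series carry the same index
    assert X_res.index.equals(y_res.index)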
diff --git a/imblearn/datasets/tests/test_zenodo.py b/imblearn/datasets/tests/test_zenodo.py
index b9c2288..3854fd2 100644
--- a/imblearn/datasets/tests/test_zenodo.py
+++ b/imblearn/datasets/tests/test_zenodo.py
@@ -2,17 +2,97 @@
Skipped if the datasets are not already downloaded to data_home.
"""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import pytest
from sklearn.utils._testing import SkipTest
+
from imblearn.datasets import fetch_datasets
-DATASET_SHAPE = {'ecoli': (336, 7), 'optical_digits': (5620, 64),
- 'satimage': (6435, 36), 'pen_digits': (10992, 16), 'abalone': (4177, 10
- ), 'sick_euthyroid': (3163, 42), 'spectrometer': (531, 93),
- 'car_eval_34': (1728, 21), 'isolet': (7797, 617), 'us_crime': (1994,
- 100), 'yeast_ml8': (2417, 103), 'scene': (2407, 294), 'libras_move': (
- 360, 90), 'thyroid_sick': (3772, 52), 'coil_2000': (9822, 85),
- 'arrhythmia': (452, 278), 'solar_flare_m0': (1389, 32), 'oil': (937, 49
- ), 'car_eval_4': (1728, 21), 'wine_quality': (4898, 11), 'letter_img':
- (20000, 16), 'yeast_me2': (1484, 8), 'webpage': (34780, 300),
- 'ozone_level': (2536, 72), 'mammography': (11183, 6), 'protein_homo': (
- 145751, 74), 'abalone_19': (4177, 10)}
+
+DATASET_SHAPE = {
+ "ecoli": (336, 7),
+ "optical_digits": (5620, 64),
+ "satimage": (6435, 36),
+ "pen_digits": (10992, 16),
+ "abalone": (4177, 10),
+ "sick_euthyroid": (3163, 42),
+ "spectrometer": (531, 93),
+ "car_eval_34": (1728, 21),
+ "isolet": (7797, 617),
+ "us_crime": (1994, 100),
+ "yeast_ml8": (2417, 103),
+ "scene": (2407, 294),
+ "libras_move": (360, 90),
+ "thyroid_sick": (3772, 52),
+ "coil_2000": (9822, 85),
+ "arrhythmia": (452, 278),
+ "solar_flare_m0": (1389, 32),
+ "oil": (937, 49),
+ "car_eval_4": (1728, 21),
+ "wine_quality": (4898, 11),
+ "letter_img": (20000, 16),
+ "yeast_me2": (1484, 8),
+ "webpage": (34780, 300),
+ "ozone_level": (2536, 72),
+ "mammography": (11183, 6),
+ "protein_homo": (145751, 74),
+ "abalone_19": (4177, 10),
+}
+
+
+def fetch(*args, **kwargs):
+ return fetch_datasets(*args, download_if_missing=True, **kwargs)
+
+
+@pytest.mark.xfail
+def test_fetch():
+ try:
+ datasets1 = fetch(shuffle=True, random_state=42)
+ except IOError:
+ raise SkipTest("Zenodo dataset can not be loaded.")
+
+ datasets2 = fetch(shuffle=True, random_state=37)
+
+ for k in DATASET_SHAPE.keys():
+ X1, X2 = datasets1[k].data, datasets2[k].data
+ assert DATASET_SHAPE[k] == X1.shape
+ assert X1.shape == X2.shape
+
+ y1, y2 = datasets1[k].target, datasets2[k].target
+ assert (X1.shape[0],) == y1.shape
+ assert (X1.shape[0],) == y2.shape
+
+
+def test_fetch_filter():
+ try:
+ datasets1 = fetch(filter_data=tuple([1]), shuffle=True, random_state=42)
+ except IOError:
+ raise SkipTest("Zenodo dataset can not be loaded.")
+
+ datasets2 = fetch(filter_data=tuple(["ecoli"]), shuffle=True, random_state=37)
+
+ X1, X2 = datasets1["ecoli"].data, datasets2["ecoli"].data
+ assert DATASET_SHAPE["ecoli"] == X1.shape
+ assert X1.shape == X2.shape
+
+ assert X1.sum() == pytest.approx(X2.sum())
+
+ y1, y2 = datasets1["ecoli"].target, datasets2["ecoli"].target
+ assert (X1.shape[0],) == y1.shape
+ assert (X1.shape[0],) == y2.shape
+
+
+@pytest.mark.parametrize(
+ "filter_data, err_msg",
+ [
+ (("rnf",), "is not a dataset available"),
+ ((-1,), "dataset with the ID="),
+ ((100,), "dataset with the ID="),
+ ((1.00,), "value in the tuple"),
+ ],
+)
+def test_fetch_error(filter_data, err_msg):
+ with pytest.raises(ValueError, match=err_msg):
+ fetch_datasets(filter_data=filter_data)
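
`test_fetch_filter` compares `X1.sum()` with `X2.sum()` rather than the arrays themselves because different seeds permute the rows while leaving the set of rows unchanged. The shuffle step from `_zenodo.py`, in isolation:

    import numpy as np
    from sklearn.utils import check_random_state

    X = np.arange(12, dtype=float).reshape(6, 2)
    ind = np.arange(X.shape[0])
    check_random_state(37).shuffle(ind)  # in-place permutation of the indices
    X_shuffled = X[ind]
    # row order changes, row content does not
    assert X_shuffled.sum() == X.sum()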
diff --git a/imblearn/ensemble/_bagging.py b/imblearn/ensemble/_bagging.py
index b1905ed..acb0c70 100644
--- a/imblearn/ensemble/_bagging.py
+++ b/imblearn/ensemble/_bagging.py
@@ -1,7 +1,13 @@
"""Bagging classifier trained on balanced bootstrap samples."""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import copy
import numbers
import warnings
+
import numpy as np
import sklearn
from sklearn.base import clone
@@ -12,11 +18,14 @@ from sklearn.exceptions import NotFittedError
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils.fixes import parse_version
from sklearn.utils.validation import check_is_fitted
+
try:
+ # scikit-learn >= 1.2
from sklearn.utils.parallel import Parallel, delayed
except (ImportError, ModuleNotFoundError):
from joblib import Parallel
from sklearn.utils.fixes import delayed
+
from ..base import _ParamsValidationMixin
from ..pipeline import Pipeline
from ..under_sampling import RandomUnderSampler
@@ -27,12 +36,15 @@ from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
from ..utils._param_validation import HasMethods, Interval, StrOptions
from ..utils.fixes import _fit_context
from ._common import _bagging_parameter_constraints, _estimator_has
+
sklearn_version = parse_version(sklearn.__version__)
-@Substitution(sampling_strategy=BaseUnderSampler.
- _sampling_strategy_docstring, n_jobs=_n_jobs_docstring, random_state=
- _random_state_docstring)
+@Substitution(
+ sampling_strategy=BaseUnderSampler._sampling_strategy_docstring,
+ n_jobs=_n_jobs_docstring,
+ random_state=_random_state_docstring,
+)
class BalancedBaggingClassifier(_ParamsValidationMixin, BaggingClassifier):
"""A Bagging classifier with additional balancing.
@@ -235,43 +247,108 @@ class BalancedBaggingClassifier(_ParamsValidationMixin, BaggingClassifier):
[[ 23 0]
[ 2 225]]
"""
- if sklearn_version >= parse_version('1.4'):
- _parameter_constraints = copy.deepcopy(BaggingClassifier.
- _parameter_constraints)
+
+ # make a deepcopy to not modify the original dictionary
+ if sklearn_version >= parse_version("1.4"):
+        _parameter_constraints = copy.deepcopy(
+            BaggingClassifier._parameter_constraints
+        )
else:
_parameter_constraints = copy.deepcopy(_bagging_parameter_constraints)
- _parameter_constraints.update({'sampling_strategy': [Interval(numbers.
- Real, 0, 1, closed='right'), StrOptions({'auto', 'majority',
- 'not minority', 'not majority', 'all'}), dict, callable],
- 'replacement': ['boolean'], 'sampler': [HasMethods(['fit_resample']
- ), None]})
- if 'base_estimator' in _parameter_constraints:
- del _parameter_constraints['base_estimator']
-
- def __init__(self, estimator=None, n_estimators=10, *, max_samples=1.0,
- max_features=1.0, bootstrap=True, bootstrap_features=False,
- oob_score=False, warm_start=False, sampling_strategy='auto',
- replacement=False, n_jobs=None, random_state=None, verbose=0,
- sampler=None):
- super().__init__(n_estimators=n_estimators, max_samples=max_samples,
- max_features=max_features, bootstrap=bootstrap,
- bootstrap_features=bootstrap_features, oob_score=oob_score,
- warm_start=warm_start, n_jobs=n_jobs, random_state=random_state,
- verbose=verbose)
+
+ _parameter_constraints.update(
+ {
+ "sampling_strategy": [
+ Interval(numbers.Real, 0, 1, closed="right"),
+ StrOptions({"auto", "majority", "not minority", "not majority", "all"}),
+ dict,
+ callable,
+ ],
+ "replacement": ["boolean"],
+ "sampler": [HasMethods(["fit_resample"]), None],
+ }
+ )
+ # TODO: remove when minimum supported version of scikit-learn is 1.4
+ if "base_estimator" in _parameter_constraints:
+ del _parameter_constraints["base_estimator"]
+
+ def __init__(
+ self,
+ estimator=None,
+ n_estimators=10,
+ *,
+ max_samples=1.0,
+ max_features=1.0,
+ bootstrap=True,
+ bootstrap_features=False,
+ oob_score=False,
+ warm_start=False,
+ sampling_strategy="auto",
+ replacement=False,
+ n_jobs=None,
+ random_state=None,
+ verbose=0,
+ sampler=None,
+ ):
+ super().__init__(
+ n_estimators=n_estimators,
+ max_samples=max_samples,
+ max_features=max_features,
+ bootstrap=bootstrap,
+ bootstrap_features=bootstrap_features,
+ oob_score=oob_score,
+ warm_start=warm_start,
+ n_jobs=n_jobs,
+ random_state=random_state,
+ verbose=verbose,
+ )
self.estimator = estimator
self.sampling_strategy = sampling_strategy
self.replacement = replacement
self.sampler = sampler
+ def _validate_y(self, y):
+ y_encoded = super()._validate_y(y)
+ if (
+ isinstance(self.sampling_strategy, dict)
+ and self.sampler_._sampling_type != "bypass"
+ ):
+ self._sampling_strategy = {
+ np.where(self.classes_ == key)[0][0]: value
+ for key, value in check_sampling_strategy(
+ self.sampling_strategy,
+ y,
+ self.sampler_._sampling_type,
+ ).items()
+ }
+ else:
+ self._sampling_strategy = self.sampling_strategy
+ return y_encoded
+
def _validate_estimator(self, default=DecisionTreeClassifier()):
"""Check the estimator and the n_estimator attribute, set the
`estimator_` attribute."""
- pass
+ if self.estimator is not None:
+ estimator = clone(self.estimator)
+ else:
+ estimator = clone(default)
+ if self.sampler_._sampling_type != "bypass":
+ self.sampler_.set_params(sampling_strategy=self._sampling_strategy)
+
+ self.estimator_ = Pipeline(
+ [("sampler", self.sampler_), ("classifier", estimator)]
+ )
+
+ # TODO: remove when supporting scikit-learn>=1.2
@property
def n_features_(self):
"""Number of features when ``fit`` is performed."""
- pass
+ warnings.warn(
+ "`n_features_` was deprecated in scikit-learn 1.0. This attribute will "
+ "not be accessible when the minimum supported version of scikit-learn "
+ "is 1.2.",
+ FutureWarning,
+ )
+ return self.n_features_in_
@_fit_context(prefer_skip_nested_validation=False)
def fit(self, X, y):
@@ -292,9 +369,27 @@ class BalancedBaggingClassifier(_ParamsValidationMixin, BaggingClassifier):
self : object
Fitted estimator.
"""
- pass
-
- @available_if(_estimator_has('decision_function'))
+        # override the base class method to disallow `sample_weight`
+ self._validate_params()
+ return super().fit(X, y)
+
+ def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):
+ check_target_type(y)
+        # the sampler needs to be validated before calling _fit because
+        # _validate_y runs before _validate_estimator and needs to know
+        # which type of sampler is being used.
+ if self.sampler is None:
+ self.sampler_ = RandomUnderSampler(
+ replacement=self.replacement,
+ )
+ else:
+ self.sampler_ = clone(self.sampler)
+        # RandomUnderSampler does not support sample_weight, so it is not
+        # forwarded (the parent's _fit defaults it to None).
+ return super()._fit(X, y, self.max_samples)
+
+ # TODO: remove when minimum supported version of scikit-learn is 1.1
+ @available_if(_estimator_has("decision_function"))
def decision_function(self, X):
"""Average of the decision functions of the base classifiers.
@@ -312,9 +407,57 @@ class BalancedBaggingClassifier(_ParamsValidationMixin, BaggingClassifier):
``classes_``. Regression and binary classification are special
cases with ``k == 1``, otherwise ``k==n_classes``.
"""
- pass
+ check_is_fitted(self)
+
+ # Check data
+ X = self._validate_data(
+ X,
+ accept_sparse=["csr", "csc"],
+ dtype=None,
+ force_all_finite=False,
+ reset=False,
+ )
+
+ # Parallel loop
+ n_jobs, _, starts = _partition_estimators(self.n_estimators, self.n_jobs)
+
+ all_decisions = Parallel(n_jobs=n_jobs, verbose=self.verbose)(
+ delayed(_parallel_decision_function)(
+ self.estimators_[starts[i] : starts[i + 1]],
+ self.estimators_features_[starts[i] : starts[i + 1]],
+ X,
+ )
+ for i in range(n_jobs)
+ )
+
+ # Reduce
+ decisions = sum(all_decisions) / self.n_estimators
+
+ return decisions
@property
def base_estimator_(self):
"""Attribute for older sklearn version compatibility."""
- pass
+ error = AttributeError(
+ f"{self.__class__.__name__} object has no attribute 'base_estimator_'."
+ )
+ if sklearn_version < parse_version("1.2"):
+            # The base class requires the attribute to be defined; for
+            # scikit-learn >= 1.2 we raise an error instead.
+ try:
+ check_is_fitted(self)
+ return self.estimator_
+ except NotFittedError:
+ raise error
+ raise error
+
+ def _more_tags(self):
+ tags = super()._more_tags()
+ tags_key = "_xfail_checks"
+ failing_test = "check_estimators_nan_inf"
+ reason = "Fails because the sampler removed infinity and NaN values"
+ if tags_key in tags:
+ tags[tags_key][failing_test] = reason
+ else:
+ tags[tags_key] = {failing_test: reason}
+ return tags
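
The `_validate_y` override re-keys a dict `sampling_strategy` from user-facing class labels to the label-encoded indices that `BaggingClassifier` uses internally. The mapping step on its own (a standalone sketch, not the estimator code path):

    import numpy as np

    classes_ = np.array(["cat", "dog", "fish"])  # as stored after label encoding
    sampling_strategy = {"dog": 30, "fish": 10}  # keyed by original labels

    encoded = {
        np.where(classes_ == key)[0][0]: value
        for key, value in sampling_strategy.items()
    }
    # keys are now positions in classes_: dog -> 1, fish -> 2
    print(encoded)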
diff --git a/imblearn/ensemble/_common.py b/imblearn/ensemble/_common.py
index f7dcb6e..588fa5e 100644
--- a/imblearn/ensemble/_common.py
+++ b/imblearn/ensemble/_common.py
@@ -1,6 +1,14 @@
from numbers import Integral, Real
+
from sklearn.tree._criterion import Criterion
-from ..utils._param_validation import HasMethods, Hidden, Interval, RealNotInt, StrOptions
+
+from ..utils._param_validation import (
+ HasMethods,
+ Hidden,
+ Interval,
+ RealNotInt,
+ StrOptions,
+)
def _estimator_has(attr):
@@ -8,41 +16,90 @@ def _estimator_has(attr):
First, we check the first fitted estimator if available, otherwise we
check the estimator attribute.
"""
- pass
-
-
-_bagging_parameter_constraints = {'estimator': [HasMethods(['fit',
- 'predict']), None], 'n_estimators': [Interval(Integral, 1, None, closed
- ='left')], 'max_samples': [Interval(Integral, 1, None, closed='left'),
- Interval(RealNotInt, 0, 1, closed='right')], 'max_features': [Interval(
- Integral, 1, None, closed='left'), Interval(RealNotInt, 0, 1, closed=
- 'right')], 'bootstrap': ['boolean'], 'bootstrap_features': ['boolean'],
- 'oob_score': ['boolean'], 'warm_start': ['boolean'], 'n_jobs': [None,
- Integral], 'random_state': ['random_state'], 'verbose': ['verbose'],
- 'base_estimator': [HasMethods(['fit', 'predict']), StrOptions({
- 'deprecated'}), None]}
-_adaboost_classifier_parameter_constraints = {'estimator': [HasMethods([
- 'fit', 'predict']), None], 'n_estimators': [Interval(Integral, 1, None,
- closed='left')], 'learning_rate': [Interval(Real, 0, None, closed=
- 'neither')], 'random_state': ['random_state'], 'base_estimator': [
- HasMethods(['fit', 'predict']), StrOptions({'deprecated'})],
- 'algorithm': [StrOptions({'SAMME', 'SAMME.R'})]}
-_random_forest_classifier_parameter_constraints = {'n_estimators': [
- Interval(Integral, 1, None, closed='left')], 'bootstrap': ['boolean'],
- 'oob_score': ['boolean'], 'n_jobs': [Integral, None], 'random_state': [
- 'random_state'], 'verbose': ['verbose'], 'warm_start': ['boolean'],
- 'criterion': [StrOptions({'gini', 'entropy', 'log_loss'}), Hidden(
- Criterion)], 'max_samples': [None, Interval(Real, 0.0, 1.0, closed=
- 'right'), Interval(Integral, 1, None, closed='left')], 'max_depth': [
- Interval(Integral, 1, None, closed='left'), None], 'min_samples_split':
- [Interval(Integral, 2, None, closed='left'), Interval(RealNotInt, 0.0,
- 1.0, closed='right')], 'min_samples_leaf': [Interval(Integral, 1, None,
- closed='left'), Interval(RealNotInt, 0.0, 1.0, closed='neither')],
- 'min_weight_fraction_leaf': [Interval(Real, 0.0, 0.5, closed='both')],
- 'max_features': [Interval(Integral, 1, None, closed='left'), Interval(
- RealNotInt, 0.0, 1.0, closed='right'), StrOptions({'sqrt', 'log2'}),
- None], 'max_leaf_nodes': [Interval(Integral, 2, None, closed='left'),
- None], 'min_impurity_decrease': [Interval(Real, 0.0, None, closed=
- 'left')], 'ccp_alpha': [Interval(Real, 0.0, None, closed='left')],
- 'class_weight': [StrOptions({'balanced_subsample', 'balanced'}), dict,
- list, None], 'monotonic_cst': ['array-like', None]}
+
+ def check(self):
+ if hasattr(self, "estimators_"):
+ return hasattr(self.estimators_[0], attr)
+ elif self.estimator is not None:
+ return hasattr(self.estimator, attr)
+ else: # TODO(1.4): Remove when the base_estimator deprecation cycle ends
+ return hasattr(self.base_estimator, attr)
+
+ return check
+
+
+_bagging_parameter_constraints = {
+ "estimator": [HasMethods(["fit", "predict"]), None],
+ "n_estimators": [Interval(Integral, 1, None, closed="left")],
+ "max_samples": [
+ Interval(Integral, 1, None, closed="left"),
+ Interval(RealNotInt, 0, 1, closed="right"),
+ ],
+ "max_features": [
+ Interval(Integral, 1, None, closed="left"),
+ Interval(RealNotInt, 0, 1, closed="right"),
+ ],
+ "bootstrap": ["boolean"],
+ "bootstrap_features": ["boolean"],
+ "oob_score": ["boolean"],
+ "warm_start": ["boolean"],
+ "n_jobs": [None, Integral],
+ "random_state": ["random_state"],
+ "verbose": ["verbose"],
+ "base_estimator": [
+ HasMethods(["fit", "predict"]),
+ StrOptions({"deprecated"}),
+ None,
+ ],
+}
+
+_adaboost_classifier_parameter_constraints = {
+ "estimator": [HasMethods(["fit", "predict"]), None],
+ "n_estimators": [Interval(Integral, 1, None, closed="left")],
+ "learning_rate": [Interval(Real, 0, None, closed="neither")],
+ "random_state": ["random_state"],
+ "base_estimator": [HasMethods(["fit", "predict"]), StrOptions({"deprecated"})],
+ "algorithm": [StrOptions({"SAMME", "SAMME.R"})],
+}
+
+_random_forest_classifier_parameter_constraints = {
+ "n_estimators": [Interval(Integral, 1, None, closed="left")],
+ "bootstrap": ["boolean"],
+ "oob_score": ["boolean"],
+ "n_jobs": [Integral, None],
+ "random_state": ["random_state"],
+ "verbose": ["verbose"],
+ "warm_start": ["boolean"],
+ "criterion": [StrOptions({"gini", "entropy", "log_loss"}), Hidden(Criterion)],
+ "max_samples": [
+ None,
+ Interval(Real, 0.0, 1.0, closed="right"),
+ Interval(Integral, 1, None, closed="left"),
+ ],
+ "max_depth": [Interval(Integral, 1, None, closed="left"), None],
+ "min_samples_split": [
+ Interval(Integral, 2, None, closed="left"),
+ Interval(RealNotInt, 0.0, 1.0, closed="right"),
+ ],
+ "min_samples_leaf": [
+ Interval(Integral, 1, None, closed="left"),
+ Interval(RealNotInt, 0.0, 1.0, closed="neither"),
+ ],
+ "min_weight_fraction_leaf": [Interval(Real, 0.0, 0.5, closed="both")],
+ "max_features": [
+ Interval(Integral, 1, None, closed="left"),
+ Interval(RealNotInt, 0.0, 1.0, closed="right"),
+ StrOptions({"sqrt", "log2"}),
+ None,
+ ],
+ "max_leaf_nodes": [Interval(Integral, 2, None, closed="left"), None],
+ "min_impurity_decrease": [Interval(Real, 0.0, None, closed="left")],
+ "ccp_alpha": [Interval(Real, 0.0, None, closed="left")],
+ "class_weight": [
+ StrOptions({"balanced_subsample", "balanced"}),
+ dict,
+ list,
+ None,
+ ],
+ "monotonic_cst": ["array-like", None],
+}
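
`_estimator_has` returns a predicate that `available_if` evaluates lazily, which is how `decision_function` only surfaces on the ensembles when the underlying estimator provides one. The same pattern in miniature (names here are illustrative, not from the library):

    from sklearn.utils.metaestimators import available_if

    def _delegate_has(attr):
        def check(self):
            return hasattr(self.inner, attr)
        return check

    class Wrapper:
        def __init__(self, inner):
            self.inner = inner

        @available_if(_delegate_has("decision_function"))
        def decision_function(self, X):
            return self.inner.decision_function(X)

    class WithDecision:
        def decision_function(self, X):
            return X

    print(hasattr(Wrapper(WithDecision()), "decision_function"))  # True
    print(hasattr(Wrapper(object()), "decision_function"))        # False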
diff --git a/imblearn/ensemble/_easy_ensemble.py b/imblearn/ensemble/_easy_ensemble.py
index 6d79bd6..e3c8574 100644
--- a/imblearn/ensemble/_easy_ensemble.py
+++ b/imblearn/ensemble/_easy_ensemble.py
@@ -1,7 +1,13 @@
"""Class to perform under-sampling using easy ensemble."""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import copy
import numbers
import warnings
+
import numpy as np
import sklearn
from sklearn.base import clone
@@ -12,11 +18,14 @@ from sklearn.exceptions import NotFittedError
from sklearn.utils._tags import _safe_tags
from sklearn.utils.fixes import parse_version
from sklearn.utils.validation import check_is_fitted
+
try:
+ # scikit-learn >= 1.2
from sklearn.utils.parallel import Parallel, delayed
except (ImportError, ModuleNotFoundError):
from joblib import Parallel
from sklearn.utils.fixes import delayed
+
from ..base import _ParamsValidationMixin
from ..pipeline import Pipeline
from ..under_sampling import RandomUnderSampler
@@ -27,13 +36,16 @@ from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
from ..utils._param_validation import Interval, StrOptions
from ..utils.fixes import _fit_context
from ._common import _bagging_parameter_constraints, _estimator_has
+
MAX_INT = np.iinfo(np.int32).max
sklearn_version = parse_version(sklearn.__version__)
-@Substitution(sampling_strategy=BaseUnderSampler.
- _sampling_strategy_docstring, n_jobs=_n_jobs_docstring, random_state=
- _random_state_docstring)
+@Substitution(
+ sampling_strategy=BaseUnderSampler._sampling_strategy_docstring,
+ n_jobs=_n_jobs_docstring,
+ random_state=_random_state_docstring,
+)
class EasyEnsembleClassifier(_ParamsValidationMixin, BaggingClassifier):
"""Bag of balanced boosted learners also known as EasyEnsemble.
@@ -160,43 +172,106 @@ class EasyEnsembleClassifier(_ParamsValidationMixin, BaggingClassifier):
[[ 23 0]
[ 2 225]]
"""
- if sklearn_version >= parse_version('1.4'):
- _parameter_constraints = copy.deepcopy(BaggingClassifier.
- _parameter_constraints)
+
+ # make a deepcopy to not modify the original dictionary
+ if sklearn_version >= parse_version("1.4"):
+        _parameter_constraints = copy.deepcopy(
+            BaggingClassifier._parameter_constraints
+        )
else:
_parameter_constraints = copy.deepcopy(_bagging_parameter_constraints)
- excluded_params = {'bootstrap', 'bootstrap_features', 'max_features',
- 'oob_score', 'max_samples'}
+
+ excluded_params = {
+ "bootstrap",
+ "bootstrap_features",
+ "max_features",
+ "oob_score",
+ "max_samples",
+ }
for param in excluded_params:
_parameter_constraints.pop(param, None)
- _parameter_constraints.update({'sampling_strategy': [Interval(numbers.
- Real, 0, 1, closed='right'), StrOptions({'auto', 'majority',
- 'not minority', 'not majority', 'all'}), dict, callable],
- 'replacement': ['boolean']})
- if 'base_estimator' in _parameter_constraints:
- del _parameter_constraints['base_estimator']
-
- def __init__(self, n_estimators=10, estimator=None, *, warm_start=False,
- sampling_strategy='auto', replacement=False, n_jobs=None,
- random_state=None, verbose=0):
- super().__init__(n_estimators=n_estimators, max_samples=1.0,
- max_features=1.0, bootstrap=False, bootstrap_features=False,
- oob_score=False, warm_start=warm_start, n_jobs=n_jobs,
- random_state=random_state, verbose=verbose)
+
+ _parameter_constraints.update(
+ {
+ "sampling_strategy": [
+ Interval(numbers.Real, 0, 1, closed="right"),
+ StrOptions({"auto", "majority", "not minority", "not majority", "all"}),
+ dict,
+ callable,
+ ],
+ "replacement": ["boolean"],
+ }
+ )
+ # TODO: remove when minimum supported version of scikit-learn is 1.4
+ if "base_estimator" in _parameter_constraints:
+ del _parameter_constraints["base_estimator"]
+
+ def __init__(
+ self,
+ n_estimators=10,
+ estimator=None,
+ *,
+ warm_start=False,
+ sampling_strategy="auto",
+ replacement=False,
+ n_jobs=None,
+ random_state=None,
+ verbose=0,
+ ):
+ super().__init__(
+ n_estimators=n_estimators,
+ max_samples=1.0,
+ max_features=1.0,
+ bootstrap=False,
+ bootstrap_features=False,
+ oob_score=False,
+ warm_start=warm_start,
+ n_jobs=n_jobs,
+ random_state=random_state,
+ verbose=verbose,
+ )
self.estimator = estimator
self.sampling_strategy = sampling_strategy
self.replacement = replacement
- def _validate_estimator(self, default=AdaBoostClassifier(algorithm='SAMME')
- ):
+ def _validate_y(self, y):
+ y_encoded = super()._validate_y(y)
+ if isinstance(self.sampling_strategy, dict):
+ self._sampling_strategy = {
+ np.where(self.classes_ == key)[0][0]: value
+ for key, value in check_sampling_strategy(
+ self.sampling_strategy,
+ y,
+ "under-sampling",
+ ).items()
+ }
+ else:
+ self._sampling_strategy = self.sampling_strategy
+ return y_encoded
+
+ def _validate_estimator(self, default=AdaBoostClassifier(algorithm="SAMME")):
"""Check the estimator and the n_estimator attribute, set the
`estimator_` attribute."""
- pass
-
+ if self.estimator is not None:
+ estimator = clone(self.estimator)
+ else:
+ estimator = clone(default)
+
+ sampler = RandomUnderSampler(
+ sampling_strategy=self._sampling_strategy,
+ replacement=self.replacement,
+ )
+ self.estimator_ = Pipeline([("sampler", sampler), ("classifier", estimator)])
+
+ # TODO: remove when supporting scikit-learn>=1.2
@property
def n_features_(self):
"""Number of features when ``fit`` is performed."""
- pass
+ warnings.warn(
+ "`n_features_` was deprecated in scikit-learn 1.0. This attribute will "
+ "not be accessible when the minimum supported version of scikit-learn "
+ "is 1.2.",
+ FutureWarning,
+ )
+ return self.n_features_in_
@_fit_context(prefer_skip_nested_validation=False)
def fit(self, X, y):
@@ -217,9 +292,18 @@ class EasyEnsembleClassifier(_ParamsValidationMixin, BaggingClassifier):
self : object
Fitted estimator.
"""
- pass
-
- @available_if(_estimator_has('decision_function'))
+ self._validate_params()
+        # override the base class method to disallow `sample_weight`
+ return super().fit(X, y)
+
+ def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):
+ check_target_type(y)
+        # RandomUnderSampler does not support sample_weight, so it is not
+        # forwarded (the parent's _fit defaults it to None).
+ return super()._fit(X, y, self.max_samples)
+
+ # TODO: remove when minimum supported version of scikit-learn is 1.1
+ @available_if(_estimator_has("decision_function"))
def decision_function(self, X):
"""Average of the decision functions of the base classifiers.
@@ -237,9 +321,55 @@ class EasyEnsembleClassifier(_ParamsValidationMixin, BaggingClassifier):
``classes_``. Regression and binary classification are special
cases with ``k == 1``, otherwise ``k==n_classes``.
"""
- pass
+ check_is_fitted(self)
+
+ # Check data
+ X = self._validate_data(
+ X,
+ accept_sparse=["csr", "csc"],
+ dtype=None,
+ force_all_finite=False,
+ reset=False,
+ )
+
+ # Parallel loop
+ n_jobs, _, starts = _partition_estimators(self.n_estimators, self.n_jobs)
+
+ all_decisions = Parallel(n_jobs=n_jobs, verbose=self.verbose)(
+ delayed(_parallel_decision_function)(
+ self.estimators_[starts[i] : starts[i + 1]],
+ self.estimators_features_[starts[i] : starts[i + 1]],
+ X,
+ )
+ for i in range(n_jobs)
+ )
+
+ # Reduce
+ decisions = sum(all_decisions) / self.n_estimators
+
+ return decisions
@property
def base_estimator_(self):
"""Attribute for older sklearn version compatibility."""
- pass
+ error = AttributeError(
+ f"{self.__class__.__name__} object has no attribute 'base_estimator_'."
+ )
+ if sklearn_version < parse_version("1.2"):
+            # The base class requires the attribute to be defined; for
+            # scikit-learn >= 1.2 we raise an error instead.
+ try:
+ check_is_fitted(self)
+ return self.estimator_
+ except NotFittedError:
+ raise error
+ raise error
+
+ def _get_estimator(self):
+ if self.estimator is None:
+ return AdaBoostClassifier(algorithm="SAMME")
+ return self.estimator
+
+ # TODO: remove when minimum supported version of scikit-learn is 1.5
+ def _more_tags(self):
+ return {"allow_nan": _safe_tags(self._get_estimator(), "allow_nan")}
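
Each fitted member of the ensemble is therefore a `Pipeline` that re-balances its view of the data before boosting. A usage sketch of the fitted structure (public API only; the dataset is illustrative):

    from sklearn.datasets import make_classification

    from imblearn.ensemble import EasyEnsembleClassifier

    X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], random_state=0)
    clf = EasyEnsembleClassifier(n_estimators=5, random_state=0).fit(X, y)

    # every member pairs a RandomUnderSampler with AdaBoost(algorithm="SAMME")
    pipe = clf.estimators_[0]
    print(pipe.named_steps["sampler"])
    print(pipe.named_steps["classifier"])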
diff --git a/imblearn/ensemble/_forest.py b/imblearn/ensemble/_forest.py
index 6a8de5d..5f8d08e 100644
--- a/imblearn/ensemble/_forest.py
+++ b/imblearn/ensemble/_forest.py
@@ -1,7 +1,12 @@
"""Forest classifiers trained on balanced boostrasp samples."""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# License: MIT
+
import numbers
from copy import deepcopy
from warnings import warn
+
import numpy as np
import sklearn
from numpy import float32 as DTYPE
@@ -10,18 +15,25 @@ from scipy.sparse import issparse
from sklearn.base import clone, is_classifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble._base import _set_random_states
-from sklearn.ensemble._forest import _generate_unsampled_indices, _get_n_samples_bootstrap, _parallel_build_trees
+from sklearn.ensemble._forest import (
+ _generate_unsampled_indices,
+ _get_n_samples_bootstrap,
+ _parallel_build_trees,
+)
from sklearn.exceptions import DataConversionWarning
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils import _safe_indexing, check_random_state
from sklearn.utils.fixes import parse_version
from sklearn.utils.multiclass import type_of_target
from sklearn.utils.validation import _check_sample_weight
+
try:
+ # scikit-learn >= 1.2
from sklearn.utils.parallel import Parallel, delayed
except (ImportError, ModuleNotFoundError):
from joblib import Parallel
from sklearn.utils.fixes import delayed
+
from ..base import _ParamsValidationMixin
from ..pipeline import make_pipeline
from ..under_sampling import RandomUnderSampler
@@ -31,13 +43,69 @@ from ..utils._param_validation import Hidden, Interval, StrOptions
from ..utils._validation import check_sampling_strategy
from ..utils.fixes import _fit_context
from ._common import _random_forest_classifier_parameter_constraints
+
MAX_INT = np.iinfo(np.int32).max
sklearn_version = parse_version(sklearn.__version__)
-@Substitution(n_jobs=_n_jobs_docstring, random_state=_random_state_docstring)
-class BalancedRandomForestClassifier(_ParamsValidationMixin,
- RandomForestClassifier):
+def _local_parallel_build_trees(
+ sampler,
+ tree,
+ bootstrap,
+ X,
+ y,
+ sample_weight,
+ tree_idx,
+ n_trees,
+ verbose=0,
+ class_weight=None,
+ n_samples_bootstrap=None,
+ forest=None,
+ missing_values_in_feature_mask=None,
+):
+    # resample before fitting the tree
+ X_resampled, y_resampled = sampler.fit_resample(X, y)
+ if sample_weight is not None:
+ sample_weight = _safe_indexing(sample_weight, sampler.sample_indices_)
+    if n_samples_bootstrap is not None:
+ n_samples_bootstrap = min(n_samples_bootstrap, X_resampled.shape[0])
+
+ params_parallel_build_trees = {
+ "tree": tree,
+ "X": X_resampled,
+ "y": y_resampled,
+ "sample_weight": sample_weight,
+ "tree_idx": tree_idx,
+ "n_trees": n_trees,
+ "verbose": verbose,
+ "class_weight": class_weight,
+ "n_samples_bootstrap": n_samples_bootstrap,
+ }
+
+ if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
+ # TODO: remove when the minimum supported version of scikit-learn will be 1.4
+ # support for missing values
+ params_parallel_build_trees[
+ "missing_values_in_feature_mask"
+ ] = missing_values_in_feature_mask
+
+ # TODO: remove when the minimum supported version of scikit-learn will be 1.1
+ # change of signature in scikit-learn 1.1
+ if parse_version(sklearn_version.base_version) >= parse_version("1.1"):
+ params_parallel_build_trees["bootstrap"] = bootstrap
+ else:
+ params_parallel_build_trees["forest"] = forest
+
+ tree = _parallel_build_trees(**params_parallel_build_trees)
+
+ return sampler, tree
+
+
+@Substitution(
+ n_jobs=_n_jobs_docstring,
+ random_state=_random_state_docstring,
+)
+class BalancedRandomForestClassifier(_ParamsValidationMixin, RandomForestClassifier):
"""A balanced random forest classifier.
A balanced random forest differs from a classical random forest by the
@@ -85,7 +153,8 @@ class BalancedRandomForestClassifier(_ParamsValidationMixin,
the input samples) required to be at a leaf node. Samples have
equal weight when sample_weight is not provided.
- max_features : {{"auto", "sqrt", "log2"}}, int, float, or None, default="sqrt"
+ max_features : {{"auto", "sqrt", "log2"}}, int, float, or None, \
+ default="sqrt"
The number of features to consider when looking for the best split:
- If int, then consider `max_features` features at each split.
@@ -195,7 +264,8 @@ class BalancedRandomForestClassifier(_ParamsValidationMixin,
and add more estimators to the ensemble, otherwise, just fit a whole
new forest.
- class_weight : dict, list of dicts, {{"balanced", "balanced_subsample"}}, default=None
+ class_weight : dict, list of dicts, {{"balanced", "balanced_subsample"}}, \
+ default=None
Weights associated with classes in the form dictionary with the key
being the class_label and the value the weight.
If not given, all classes are supposed to have weight one. For
@@ -355,59 +425,124 @@ class BalancedRandomForestClassifier(_ParamsValidationMixin,
... 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))
[1]
"""
- if sklearn_version >= parse_version('1.4'):
- _parameter_constraints = deepcopy(RandomForestClassifier.
- _parameter_constraints)
+
+ # make a deepcopy to not modify the original dictionary
+ if sklearn_version >= parse_version("1.4"):
+        _parameter_constraints = deepcopy(
+            RandomForestClassifier._parameter_constraints
+        )
else:
_parameter_constraints = deepcopy(
- _random_forest_classifier_parameter_constraints)
- _parameter_constraints.update({'bootstrap': ['boolean', Hidden(
- StrOptions({'warn'}))], 'sampling_strategy': [Interval(numbers.Real,
- 0, 1, closed='right'), StrOptions({'auto', 'majority',
- 'not minority', 'not majority', 'all'}), dict, callable, Hidden(
- StrOptions({'warn'}))], 'replacement': ['boolean', Hidden(
- StrOptions({'warn'}))]})
-
- def __init__(self, n_estimators=100, *, criterion='gini', max_depth=
- None, min_samples_split=2, min_samples_leaf=1,
- min_weight_fraction_leaf=0.0, max_features='sqrt', max_leaf_nodes=
- None, min_impurity_decrease=0.0, bootstrap='warn', oob_score=False,
- sampling_strategy='warn', replacement='warn', n_jobs=None,
- random_state=None, verbose=0, warm_start=False, class_weight=None,
- ccp_alpha=0.0, max_samples=None, monotonic_cst=None):
- params_random_forest = {'criterion': criterion, 'max_depth':
- max_depth, 'n_estimators': n_estimators, 'bootstrap': bootstrap,
- 'oob_score': oob_score, 'n_jobs': n_jobs, 'random_state':
- random_state, 'verbose': verbose, 'warm_start': warm_start,
- 'class_weight': class_weight, 'min_samples_split':
- min_samples_split, 'min_samples_leaf': min_samples_leaf,
- 'min_weight_fraction_leaf': min_weight_fraction_leaf,
- 'max_features': max_features, 'max_leaf_nodes': max_leaf_nodes,
- 'min_impurity_decrease': min_impurity_decrease, 'ccp_alpha':
- ccp_alpha, 'max_samples': max_samples}
- if parse_version(sklearn_version.base_version) >= parse_version('1.4'):
- params_random_forest['monotonic_cst'] = monotonic_cst
+ _random_forest_classifier_parameter_constraints
+ )
+
+ _parameter_constraints.update(
+ {
+ "bootstrap": ["boolean", Hidden(StrOptions({"warn"}))],
+ "sampling_strategy": [
+ Interval(numbers.Real, 0, 1, closed="right"),
+ StrOptions({"auto", "majority", "not minority", "not majority", "all"}),
+ dict,
+ callable,
+ Hidden(StrOptions({"warn"})),
+ ],
+ "replacement": ["boolean", Hidden(StrOptions({"warn"}))],
+ }
+ )
+
+ def __init__(
+ self,
+ n_estimators=100,
+ *,
+ criterion="gini",
+ max_depth=None,
+ min_samples_split=2,
+ min_samples_leaf=1,
+ min_weight_fraction_leaf=0.0,
+ max_features="sqrt",
+ max_leaf_nodes=None,
+ min_impurity_decrease=0.0,
+ bootstrap="warn",
+ oob_score=False,
+ sampling_strategy="warn",
+ replacement="warn",
+ n_jobs=None,
+ random_state=None,
+ verbose=0,
+ warm_start=False,
+ class_weight=None,
+ ccp_alpha=0.0,
+ max_samples=None,
+ monotonic_cst=None,
+ ):
+ params_random_forest = {
+ "criterion": criterion,
+ "max_depth": max_depth,
+ "n_estimators": n_estimators,
+ "bootstrap": bootstrap,
+ "oob_score": oob_score,
+ "n_jobs": n_jobs,
+ "random_state": random_state,
+ "verbose": verbose,
+ "warm_start": warm_start,
+ "class_weight": class_weight,
+ "min_samples_split": min_samples_split,
+ "min_samples_leaf": min_samples_leaf,
+ "min_weight_fraction_leaf": min_weight_fraction_leaf,
+ "max_features": max_features,
+ "max_leaf_nodes": max_leaf_nodes,
+ "min_impurity_decrease": min_impurity_decrease,
+ "ccp_alpha": ccp_alpha,
+ "max_samples": max_samples,
+ }
+ # TODO: remove when the minimum supported version of scikit-learn will be 1.4
+ if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
+ # use scikit-learn support for monotonic constraints
+ params_random_forest["monotonic_cst"] = monotonic_cst
else:
if monotonic_cst is not None:
raise ValueError(
- 'Monotonic constraints are not supported for scikit-learn version < 1.4.'
- )
+ "Monotonic constraints are not supported for scikit-learn "
+ "version < 1.4."
+ )
+ # create an attribute for compatibility with other scikit-learn tools such
+ # as HTML representation.
self.monotonic_cst = monotonic_cst
super().__init__(**params_random_forest)
+
self.sampling_strategy = sampling_strategy
self.replacement = replacement
def _validate_estimator(self, default=DecisionTreeClassifier()):
"""Check the estimator and the n_estimator attribute, set the
`estimator_` attribute."""
- pass
+ if hasattr(self, "estimator"):
+ base_estimator = self.estimator
+ else:
+ base_estimator = self.base_estimator
+
+ if base_estimator is not None:
+ self.estimator_ = clone(base_estimator)
+ else:
+ self.estimator_ = clone(default)
+
+ self.base_sampler_ = RandomUnderSampler(
+ sampling_strategy=self._sampling_strategy,
+ replacement=self._replacement,
+ )
def _make_sampler_estimator(self, random_state=None):
"""Make and configure a copy of the `base_estimator_` attribute.
Warning: This method should be used to properly instantiate new
sub-estimators.
"""
- pass
+ estimator = clone(self.estimator_)
+ estimator.set_params(**{p: getattr(self, p) for p in self.estimator_params})
+ sampler = clone(self.base_sampler_)
+
+ if random_state is not None:
+ _set_random_states(estimator, random_state)
+ _set_random_states(sampler, random_state)
+
+ return estimator, sampler
@_fit_context(prefer_skip_nested_validation=True)
def fit(self, X, y, sample_weight=None):
@@ -436,7 +571,238 @@ class BalancedRandomForestClassifier(_ParamsValidationMixin,
self : object
The fitted instance.
"""
- pass
+ self._validate_params()
+ # TODO: remove in 0.13
+ if self.sampling_strategy == "warn":
+ warn(
+ "The default of `sampling_strategy` will change from `'auto'` to "
+ "`'all'` in version 0.13. This change will follow the implementation "
+ "proposed in the original paper. Set to `'all'` to silence this "
+ "warning and adopt the future behaviour.",
+ FutureWarning,
+ )
+ self._sampling_strategy = "auto"
+ else:
+ self._sampling_strategy = self.sampling_strategy
+
+ if self.replacement == "warn":
+ warn(
+ "The default of `replacement` will change from `False` to "
+ "`True` in version 0.13. This change will follow the implementation "
+ "proposed in the original paper. Set to `True` to silence this "
+ "warning and adopt the future behaviour.",
+ FutureWarning,
+ )
+ self._replacement = False
+ else:
+ self._replacement = self.replacement
+
+ if self.bootstrap == "warn":
+ warn(
+ "The default of `bootstrap` will change from `True` to "
+ "`False` in version 0.13. This change will follow the implementation "
+ "proposed in the original paper. Set to `False` to silence this "
+ "warning and adopt the future behaviour.",
+ FutureWarning,
+ )
+ self._bootstrap = True
+ else:
+ self._bootstrap = self.bootstrap
+
+ # Validate or convert input data
+ if issparse(y):
+ raise ValueError("sparse multilabel-indicator for y is not supported.")
+
+        # TODO: remove when the minimum supported version of scikit-learn is 1.4
+ # Support for missing values
+ if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
+ force_all_finite = False
+ else:
+ force_all_finite = True
+
+ X, y = self._validate_data(
+ X,
+ y,
+ multi_output=True,
+ accept_sparse="csc",
+ dtype=DTYPE,
+ force_all_finite=force_all_finite,
+ )
+
+        # TODO: remove when the minimum supported version of scikit-learn is 1.4
+ if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
+ # _compute_missing_values_in_feature_mask checks if X has missing values and
+ # will raise an error if the underlying tree base estimator can't handle
+ # missing values. Only the criterion is required to determine if the tree
+ # supports missing values.
+ estimator = type(self.estimator)(criterion=self.criterion)
+ missing_values_in_feature_mask = (
+ estimator._compute_missing_values_in_feature_mask(
+ X, estimator_name=self.__class__.__name__
+ )
+ )
+ else:
+ missing_values_in_feature_mask = None
+
+ if sample_weight is not None:
+ sample_weight = _check_sample_weight(sample_weight, X)
+
+ self._n_features = X.shape[1]
+
+ if issparse(X):
+            # Pre-sort indices so that each individual tree of the ensemble
+            # does not have to sort them.
+ X.sort_indices()
+
+ y = np.atleast_1d(y)
+ if y.ndim == 2 and y.shape[1] == 1:
+ warn(
+ "A column-vector y was passed when a 1d array was"
+ " expected. Please change the shape of y to "
+ "(n_samples,), for example using ravel().",
+ DataConversionWarning,
+ stacklevel=2,
+ )
+
+ if y.ndim == 1:
+            # reshape is necessary to preserve data contiguity, which
+            # [:, np.newaxis] does not.
+ y = np.reshape(y, (-1, 1))
+
+ self.n_outputs_ = y.shape[1]
+
+ y_encoded, expanded_class_weight = self._validate_y_class_weight(y)
+
+ if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
+ y_encoded = np.ascontiguousarray(y_encoded, dtype=DOUBLE)
+
+ if isinstance(self._sampling_strategy, dict):
+ self._sampling_strategy = {
+ np.where(self.classes_[0] == key)[0][0]: value
+ for key, value in check_sampling_strategy(
+ self.sampling_strategy,
+ y,
+ "under-sampling",
+ ).items()
+ }
+
+ if expanded_class_weight is not None:
+ if sample_weight is not None:
+ sample_weight = sample_weight * expanded_class_weight
+ else:
+ sample_weight = expanded_class_weight
+
+ # Get bootstrap sample size
+ n_samples_bootstrap = _get_n_samples_bootstrap(
+ n_samples=X.shape[0], max_samples=self.max_samples
+ )
+
+ # Check parameters
+ self._validate_estimator()
+
+ if not self._bootstrap and self.oob_score:
+ raise ValueError("Out of bag estimation only available if bootstrap=True")
+
+ random_state = check_random_state(self.random_state)
+
+ if not self.warm_start or not hasattr(self, "estimators_"):
+ # Free allocated memory, if any
+ self.estimators_ = []
+ self.samplers_ = []
+ self.pipelines_ = []
+
+ n_more_estimators = self.n_estimators - len(self.estimators_)
+
+ if n_more_estimators < 0:
+ raise ValueError(
+ "n_estimators=%d must be larger or equal to "
+ "len(estimators_)=%d when warm_start==True"
+ % (self.n_estimators, len(self.estimators_))
+ )
+
+ elif n_more_estimators == 0:
+ warn(
+ "Warm-start fitting without increasing n_estimators does not "
+ "fit new trees."
+ )
+ else:
+ if self.warm_start and len(self.estimators_) > 0:
+ # We draw from the random state to get the random state we
+ # would have got if we hadn't used a warm_start.
+ random_state.randint(MAX_INT, size=len(self.estimators_))
+
+ trees = []
+ samplers = []
+ for _ in range(n_more_estimators):
+ tree, sampler = self._make_sampler_estimator(random_state=random_state)
+ trees.append(tree)
+ samplers.append(sampler)
+
+ # Parallel loop: we prefer the threading backend as the Cython code
+ # for fitting the trees is internally releasing the Python GIL
+ # making threading more efficient than multiprocessing in
+ # that case. However, we respect any parallel_backend contexts set
+ # at a higher level, since correctness does not rely on using
+ # threads.
+ samplers_trees = Parallel(
+ n_jobs=self.n_jobs,
+ verbose=self.verbose,
+ prefer="threads",
+ )(
+ delayed(_local_parallel_build_trees)(
+ s,
+ t,
+ self._bootstrap,
+ X,
+ y_encoded,
+ sample_weight,
+ i,
+ len(trees),
+ verbose=self.verbose,
+ class_weight=self.class_weight,
+ n_samples_bootstrap=n_samples_bootstrap,
+ forest=self,
+ missing_values_in_feature_mask=missing_values_in_feature_mask,
+ )
+ for i, (s, t) in enumerate(zip(samplers, trees))
+ )
+ samplers, trees = zip(*samplers_trees)
+
+ # Collect newly grown trees
+ self.estimators_.extend(trees)
+ self.samplers_.extend(samplers)
+
+ # Create pipeline with the fitted samplers and trees
+ self.pipelines_.extend(
+ [
+ make_pipeline(deepcopy(s), deepcopy(t))
+ for s, t in zip(samplers, trees)
+ ]
+ )
+
+ if self.oob_score:
+ y_type = type_of_target(y)
+ if y_type in ("multiclass-multioutput", "unknown"):
+ # FIXME: we could consider to support multiclass-multioutput if
+ # we introduce or reuse a constructor parameter (e.g.
+ # oob_score) allowing our user to pass a callable defining the
+ # scoring strategy on OOB sample.
+ raise ValueError(
+ "The type of target cannot be used to compute OOB "
+ f"estimates. Got {y_type} while only the following are "
+ "supported: continuous, continuous-multioutput, binary, "
+ "multiclass, multilabel-indicator."
+ )
+ self._set_oob_score_and_attributes(X, y_encoded)
+
+ # Decapsulate classes_ attributes
+ if hasattr(self, "classes_") and self.n_outputs_ == 1:
+ self.n_classes_ = self.n_classes_[0]
+ self.classes_ = self.classes_[0]
+
+ return self
def _set_oob_score_and_attributes(self, X, y):
"""Compute and set the OOB score and attributes.
@@ -448,7 +814,15 @@ class BalancedRandomForestClassifier(_ParamsValidationMixin,
y : ndarray of shape (n_samples, n_outputs)
The target matrix.
"""
- pass
+        from sklearn.metrics import accuracy_score
+
+        self.oob_decision_function_ = self._compute_oob_predictions(X, y)
+        if self.oob_decision_function_.shape[-1] == 1:
+            # drop the n_outputs axis if there is a single output
+            self.oob_decision_function_ = self.oob_decision_function_.squeeze(axis=-1)
+        self.oob_score_ = accuracy_score(
+            y, np.argmax(self.oob_decision_function_, axis=1)
+        )
def _compute_oob_predictions(self, X, y):
"""Compute and set the OOB score.
@@ -462,12 +836,79 @@ class BalancedRandomForestClassifier(_ParamsValidationMixin,
Returns
-------
- oob_pred : ndarray of shape (n_samples, n_classes, n_outputs) or (n_samples, 1, n_outputs)
+ oob_pred : ndarray of shape (n_samples, n_classes, n_outputs) or \
+ (n_samples, 1, n_outputs)
The OOB predictions.
"""
- pass
-
+ # Prediction requires X to be in CSR format
+ if issparse(X):
+ X = X.tocsr()
+
+ n_samples = y.shape[0]
+ n_outputs = self.n_outputs_
+
+ if is_classifier(self) and hasattr(self, "n_classes_"):
+ # n_classes_ is a ndarray at this stage
+ # all the supported type of target will have the same number of
+ # classes in all outputs
+ oob_pred_shape = (n_samples, self.n_classes_[0], n_outputs)
+ else:
+ # for regression, n_classes_ does not exist and we create an empty
+ # axis to be consistent with the classification case and make
+ # the array operations compatible with the 2 settings
+ oob_pred_shape = (n_samples, 1, n_outputs)
+
+ oob_pred = np.zeros(shape=oob_pred_shape, dtype=np.float64)
+ n_oob_pred = np.zeros((n_samples, n_outputs), dtype=np.int64)
+
+ for sampler, estimator in zip(self.samplers_, self.estimators_):
+ X_resample = X[sampler.sample_indices_]
+ y_resample = y[sampler.sample_indices_]
+
+ n_sample_subset = y_resample.shape[0]
+ n_samples_bootstrap = _get_n_samples_bootstrap(
+ n_sample_subset, self.max_samples
+ )
+
+ unsampled_indices = _generate_unsampled_indices(
+ estimator.random_state, n_sample_subset, n_samples_bootstrap
+ )
+
+ y_pred = self._get_oob_predictions(
+ estimator, X_resample[unsampled_indices, :]
+ )
+
+ indices = sampler.sample_indices_[unsampled_indices]
+ oob_pred[indices, ...] += y_pred
+ n_oob_pred[indices, :] += 1
+
+ for k in range(n_outputs):
+ if (n_oob_pred == 0).any():
+ warn(
+ "Some inputs do not have OOB scores. This probably means "
+ "too few trees were used to compute any reliable OOB "
+ "estimates.",
+ UserWarning,
+ )
+ n_oob_pred[n_oob_pred == 0] = 1
+ oob_pred[..., k] /= n_oob_pred[..., [k]]
+
+ return oob_pred
+
+ # TODO: remove when supporting scikit-learn>=1.2
@property
def n_features_(self):
"""Number of features when ``fit`` is performed."""
- pass
+ warn(
+ "`n_features_` was deprecated in scikit-learn 1.0. This attribute will "
+ "not be accessible when the minimum supported version of scikit-learn "
+ "is 1.2.",
+ FutureWarning,
+ )
+ return self.n_features_in_
+
+ def _more_tags(self):
+ return {
+ "multioutput": False,
+ "multilabel": False,
+ }
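
Taken together, the hunks above replace the `pass` placeholders of `BalancedRandomForestClassifier` with a working `fit`, OOB scoring, and the 0.13 deprecation cycle for `sampling_strategy`, `replacement`, and `bootstrap`. A minimal usage sketch of the result, with illustrative parameter values, passing the future defaults explicitly to silence the `FutureWarning`s raised in `fit`:

```python
# Sketch only: exercises the fit() filled in by the hunks above.
from sklearn.datasets import make_classification
from imblearn.ensemble import BalancedRandomForestClassifier

X, y = make_classification(n_samples=1_000, weights=[0.1, 0.9], random_state=0)

brf = BalancedRandomForestClassifier(
    n_estimators=50,
    sampling_strategy="all",  # future default, silences the FutureWarning
    replacement=True,         # future default
    bootstrap=False,          # future default
    random_state=0,
)
brf.fit(X, y)
# fit() keeps one RandomUnderSampler/tree pair per ensemble member.
assert len(brf.estimators_) == len(brf.samplers_) == len(brf.pipelines_) == 50
```
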
diff --git a/imblearn/ensemble/_weight_boosting.py b/imblearn/ensemble/_weight_boosting.py
index 26f43c4..9da0225 100644
--- a/imblearn/ensemble/_weight_boosting.py
+++ b/imblearn/ensemble/_weight_boosting.py
@@ -1,6 +1,7 @@
import copy
import numbers
from copy import deepcopy
+
import numpy as np
import sklearn
from sklearn.base import clone
@@ -10,6 +11,7 @@ from sklearn.tree import DecisionTreeClassifier
from sklearn.utils import _safe_indexing
from sklearn.utils.fixes import parse_version
from sklearn.utils.validation import has_fit_parameter
+
from ..base import _ParamsValidationMixin
from ..pipeline import make_pipeline
from ..under_sampling import RandomUnderSampler
@@ -19,11 +21,14 @@ from ..utils._docstring import _random_state_docstring
from ..utils._param_validation import Interval, StrOptions
from ..utils.fixes import _fit_context
from ._common import _adaboost_classifier_parameter_constraints
+
sklearn_version = parse_version(sklearn.__version__)
-@Substitution(sampling_strategy=BaseUnderSampler.
- _sampling_strategy_docstring, random_state=_random_state_docstring)
+@Substitution(
+ sampling_strategy=BaseUnderSampler._sampling_strategy_docstring,
+ random_state=_random_state_docstring,
+)
class RUSBoostClassifier(_ParamsValidationMixin, AdaBoostClassifier):
"""Random under-sampling integrated in the learning of AdaBoost.
@@ -149,24 +154,49 @@ class RUSBoostClassifier(_ParamsValidationMixin, AdaBoostClassifier):
>>> clf.predict(X)
array([...])
"""
- if sklearn_version >= parse_version('1.4'):
- _parameter_constraints = copy.deepcopy(AdaBoostClassifier.
- _parameter_constraints)
+
+ # make a deepcopy to not modify the original dictionary
+ if sklearn_version >= parse_version("1.4"):
+ _parameter_constraints = copy.deepcopy(
+ AdaBoostClassifier._parameter_constraints
+ )
else:
_parameter_constraints = copy.deepcopy(
- _adaboost_classifier_parameter_constraints)
- _parameter_constraints.update({'sampling_strategy': [Interval(numbers.
- Real, 0, 1, closed='right'), StrOptions({'auto', 'majority',
- 'not minority', 'not majority', 'all'}), dict, callable],
- 'replacement': ['boolean']})
- if 'base_estimator' in _parameter_constraints:
- del _parameter_constraints['base_estimator']
-
- def __init__(self, estimator=None, *, n_estimators=50, learning_rate=
- 1.0, algorithm='SAMME.R', sampling_strategy='auto', replacement=
- False, random_state=None):
- super().__init__(n_estimators=n_estimators, learning_rate=
- learning_rate, algorithm=algorithm, random_state=random_state)
+ _adaboost_classifier_parameter_constraints
+ )
+
+ _parameter_constraints.update(
+ {
+ "sampling_strategy": [
+ Interval(numbers.Real, 0, 1, closed="right"),
+ StrOptions({"auto", "majority", "not minority", "not majority", "all"}),
+ dict,
+ callable,
+ ],
+ "replacement": ["boolean"],
+ }
+ )
+ # TODO: remove when minimum supported version of scikit-learn is 1.4
+ if "base_estimator" in _parameter_constraints:
+ del _parameter_constraints["base_estimator"]
+
+ def __init__(
+ self,
+ estimator=None,
+ *,
+ n_estimators=50,
+ learning_rate=1.0,
+ algorithm="SAMME.R",
+ sampling_strategy="auto",
+ replacement=False,
+ random_state=None,
+ ):
+ super().__init__(
+ n_estimators=n_estimators,
+ learning_rate=learning_rate,
+ algorithm=algorithm,
+ random_state=random_state,
+ )
self.estimator = estimator
self.sampling_strategy = sampling_strategy
self.replacement = replacement
@@ -193,26 +223,174 @@ class RUSBoostClassifier(_ParamsValidationMixin, AdaBoostClassifier):
self : object
Returns self.
"""
- pass
+ self._validate_params()
+ check_target_type(y)
+ self.samplers_ = []
+ self.pipelines_ = []
+ super().fit(X, y, sample_weight)
+ return self
def _validate_estimator(self):
"""Check the estimator and the n_estimator attribute.
Sets the `estimator_` attributes.
"""
- pass
+ default = DecisionTreeClassifier(max_depth=1)
+ if self.estimator is not None:
+ self.estimator_ = clone(self.estimator)
+ else:
+ self.estimator_ = clone(default)
+
+        # SAMME.R requires predict_proba-enabled base estimators
+ if self.algorithm == "SAMME.R":
+ if not hasattr(self.estimator_, "predict_proba"):
+ raise TypeError(
+ "AdaBoostClassifier with algorithm='SAMME.R' requires "
+ "that the weak learner supports the calculation of class "
+ "probabilities with a predict_proba method.\n"
+ "Please change the base estimator or set "
+ "algorithm='SAMME' instead."
+ )
+ if not has_fit_parameter(self.estimator_, "sample_weight"):
+ raise ValueError(
+ f"{self.estimator_.__class__.__name__} doesn't support sample_weight."
+ )
+
+ self.base_sampler_ = RandomUnderSampler(
+ sampling_strategy=self.sampling_strategy,
+ replacement=self.replacement,
+ )
def _make_sampler_estimator(self, append=True, random_state=None):
"""Make and configure a copy of the `base_estimator_` attribute.
Warning: This method should be used to properly instantiate new
sub-estimators.
"""
- pass
+ estimator = clone(self.estimator_)
+ estimator.set_params(**{p: getattr(self, p) for p in self.estimator_params})
+ sampler = clone(self.base_sampler_)
+
+ if random_state is not None:
+ _set_random_states(estimator, random_state)
+ _set_random_states(sampler, random_state)
+
+ if append:
+ self.estimators_.append(estimator)
+ self.samplers_.append(sampler)
+ self.pipelines_.append(
+ make_pipeline(deepcopy(sampler), deepcopy(estimator))
+ )
+
+ return estimator, sampler
def _boost_real(self, iboost, X, y, sample_weight, random_state):
"""Implement a single boost using the SAMME.R real algorithm."""
- pass
+ estimator, sampler = self._make_sampler_estimator(random_state=random_state)
+
+ X_res, y_res = sampler.fit_resample(X, y)
+ sample_weight_res = _safe_indexing(sample_weight, sampler.sample_indices_)
+ estimator.fit(X_res, y_res, sample_weight=sample_weight_res)
+
+ y_predict_proba = estimator.predict_proba(X)
+
+ if iboost == 0:
+ self.classes_ = getattr(estimator, "classes_", None)
+ self.n_classes_ = len(self.classes_)
+
+ y_predict = self.classes_.take(np.argmax(y_predict_proba, axis=1), axis=0)
+
+ # Instances incorrectly classified
+ incorrect = y_predict != y
+
+ # Error fraction
+ estimator_error = np.mean(np.average(incorrect, weights=sample_weight, axis=0))
+
+ # Stop if classification is perfect
+ if estimator_error <= 0:
+ return sample_weight, 1.0, 0.0
+
+ # Construct y coding as described in Zhu et al [2]:
+ #
+ # y_k = 1 if c == k else -1 / (K - 1)
+ #
+ # where K == n_classes_ and c, k in [0, K) are indices along the second
+ # axis of the y coding with c being the index corresponding to the true
+ # class label.
+ n_classes = self.n_classes_
+ classes = self.classes_
+ y_codes = np.array([-1.0 / (n_classes - 1), 1.0])
+ y_coding = y_codes.take(classes == y[:, np.newaxis])
+
+ # Displace zero probabilities so the log is defined.
+ # Also fix negative elements which may occur with
+ # negative sample weights.
+ proba = y_predict_proba # alias for readability
+ np.clip(proba, np.finfo(proba.dtype).eps, None, out=proba)
+
+ # Boost weight using multi-class AdaBoost SAMME.R alg
+ estimator_weight = (
+ -1.0
+ * self.learning_rate
+ * ((n_classes - 1.0) / n_classes)
+ * (y_coding * np.log(y_predict_proba)).sum(axis=1)
+ )
+
+ # Only boost the weights if it will fit again
+        if iboost != self.n_estimators - 1:
+ # Only boost positive weights
+ sample_weight *= np.exp(
+ estimator_weight * ((sample_weight > 0) | (estimator_weight < 0))
+ )
+
+ return sample_weight, 1.0, estimator_error
def _boost_discrete(self, iboost, X, y, sample_weight, random_state):
"""Implement a single boost using the SAMME discrete algorithm."""
- pass
+ estimator, sampler = self._make_sampler_estimator(random_state=random_state)
+
+ X_res, y_res = sampler.fit_resample(X, y)
+ sample_weight_res = _safe_indexing(sample_weight, sampler.sample_indices_)
+ estimator.fit(X_res, y_res, sample_weight=sample_weight_res)
+
+ y_predict = estimator.predict(X)
+
+ if iboost == 0:
+ self.classes_ = getattr(estimator, "classes_", None)
+ self.n_classes_ = len(self.classes_)
+
+ # Instances incorrectly classified
+ incorrect = y_predict != y
+
+ # Error fraction
+ estimator_error = np.mean(np.average(incorrect, weights=sample_weight, axis=0))
+
+ # Stop if classification is perfect
+ if estimator_error <= 0:
+ return sample_weight, 1.0, 0.0
+
+ n_classes = self.n_classes_
+
+ # Stop if the error is at least as bad as random guessing
+ if estimator_error >= 1.0 - (1.0 / n_classes):
+ self.estimators_.pop(-1)
+ self.samplers_.pop(-1)
+ self.pipelines_.pop(-1)
+ if len(self.estimators_) == 0:
+ raise ValueError(
+ "BaseClassifier in AdaBoostClassifier "
+ "ensemble is worse than random, ensemble "
+ "can not be fit."
+ )
+ return None, None, None
+
+ # Boost weight using multi-class AdaBoost SAMME alg
+ estimator_weight = self.learning_rate * (
+ np.log((1.0 - estimator_error) / estimator_error) + np.log(n_classes - 1.0)
+ )
+
+        # Only boost the weights if it will fit again
+        if iboost != self.n_estimators - 1:
+ # Only boost positive weights
+ sample_weight *= np.exp(estimator_weight * incorrect * (sample_weight > 0))
+
+ return sample_weight, estimator_weight, estimator_error
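
The `_boost_real` and `_boost_discrete` implementations above follow scikit-learn's SAMME.R/SAMME weight updates, with a `RandomUnderSampler` applied before each weak learner is fit. A minimal sketch of the resulting behaviour (illustrative values only):

```python
# Sketch only: each boosting round under-samples before fitting the learner.
from sklearn.datasets import make_classification
from imblearn.ensemble import RUSBoostClassifier

X, y = make_classification(n_samples=1_000, weights=[0.1, 0.9], random_state=0)
rusboost = RUSBoostClassifier(n_estimators=20, algorithm="SAMME", random_state=0)
rusboost.fit(X, y)

# _make_sampler_estimator appends one sampler/estimator/pipeline per round.
assert len(rusboost.samplers_) == len(rusboost.estimators_)
assert len(rusboost.pipelines_) == len(rusboost.estimators_)
```
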
diff --git a/imblearn/ensemble/tests/test_bagging.py b/imblearn/ensemble/tests/test_bagging.py
index 02d90ed..3825971 100644
--- a/imblearn/ensemble/tests/test_bagging.py
+++ b/imblearn/ensemble/tests/test_bagging.py
@@ -1,5 +1,10 @@
"""Test the module ensemble classifiers."""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
from collections import Counter
+
import numpy as np
import pytest
import sklearn
@@ -12,23 +17,578 @@ from sklearn.model_selection import GridSearchCV, ParameterGrid, train_test_spli
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
-from sklearn.utils._testing import assert_allclose, assert_array_almost_equal, assert_array_equal
+from sklearn.utils._testing import (
+ assert_allclose,
+ assert_array_almost_equal,
+ assert_array_equal,
+)
from sklearn.utils.fixes import parse_version
+
from imblearn import FunctionSampler
from imblearn.datasets import make_imbalance
from imblearn.ensemble import BalancedBaggingClassifier
from imblearn.over_sampling import SMOTE, RandomOverSampler
from imblearn.pipeline import make_pipeline
from imblearn.under_sampling import ClusterCentroids, RandomUnderSampler
+
sklearn_version = parse_version(sklearn.__version__)
iris = load_iris()
+@pytest.mark.parametrize(
+ "estimator",
+ [
+ None,
+ DummyClassifier(strategy="prior"),
+ Perceptron(max_iter=1000, tol=1e-3),
+ DecisionTreeClassifier(),
+ KNeighborsClassifier(),
+ SVC(gamma="scale"),
+ ],
+)
+@pytest.mark.parametrize(
+ "params",
+ ParameterGrid(
+ {
+ "max_samples": [0.5, 1.0],
+ "max_features": [1, 2, 4],
+ "bootstrap": [True, False],
+ "bootstrap_features": [True, False],
+ }
+ ),
+)
+def test_balanced_bagging_classifier(estimator, params):
+ # Check classification for various parameter settings.
+ X, y = make_imbalance(
+ iris.data,
+ iris.target,
+ sampling_strategy={0: 20, 1: 25, 2: 50},
+ random_state=0,
+ )
+ X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+
+ bag = BalancedBaggingClassifier(estimator=estimator, random_state=0, **params).fit(
+ X_train, y_train
+ )
+ bag.predict(X_test)
+ bag.predict_proba(X_test)
+ bag.score(X_test, y_test)
+ if hasattr(estimator, "decision_function"):
+ bag.decision_function(X_test)
+
+
+def test_bootstrap_samples():
+ # Test that bootstrapping samples generate non-perfect base estimators.
+ X, y = make_imbalance(
+ iris.data,
+ iris.target,
+ sampling_strategy={0: 20, 1: 25, 2: 50},
+ random_state=0,
+ )
+ X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+
+ estimator = DecisionTreeClassifier().fit(X_train, y_train)
+
+ # without bootstrap, all trees are perfect on the training set
+ # disable the resampling by passing an empty dictionary.
+ ensemble = BalancedBaggingClassifier(
+ estimator=DecisionTreeClassifier(),
+ max_samples=1.0,
+ bootstrap=False,
+ n_estimators=10,
+ sampling_strategy={},
+ random_state=0,
+ ).fit(X_train, y_train)
+
+ assert ensemble.score(X_train, y_train) == estimator.score(X_train, y_train)
+
+ # with bootstrap, trees are no longer perfect on the training set
+ ensemble = BalancedBaggingClassifier(
+ estimator=DecisionTreeClassifier(),
+ max_samples=1.0,
+ bootstrap=True,
+ random_state=0,
+ ).fit(X_train, y_train)
+
+ assert ensemble.score(X_train, y_train) < estimator.score(X_train, y_train)
+
+
+def test_bootstrap_features():
+ # Test that bootstrapping features may generate duplicate features.
+ X, y = make_imbalance(
+ iris.data,
+ iris.target,
+ sampling_strategy={0: 20, 1: 25, 2: 50},
+ random_state=0,
+ )
+ X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+
+ ensemble = BalancedBaggingClassifier(
+ estimator=DecisionTreeClassifier(),
+ max_features=1.0,
+ bootstrap_features=False,
+ random_state=0,
+ ).fit(X_train, y_train)
+
+ for features in ensemble.estimators_features_:
+ assert np.unique(features).shape[0] == X.shape[1]
+
+ ensemble = BalancedBaggingClassifier(
+ estimator=DecisionTreeClassifier(),
+ max_features=1.0,
+ bootstrap_features=True,
+ random_state=0,
+ ).fit(X_train, y_train)
+
+ unique_features = [
+ np.unique(features).shape[0] for features in ensemble.estimators_features_
+ ]
+ assert np.median(unique_features) < X.shape[1]
+
+
+def test_probability():
+ # Predict probabilities.
+ X, y = make_imbalance(
+ iris.data,
+ iris.target,
+ sampling_strategy={0: 20, 1: 25, 2: 50},
+ random_state=0,
+ )
+ X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+
+ with np.errstate(divide="ignore", invalid="ignore"):
+ # Normal case
+ ensemble = BalancedBaggingClassifier(
+ estimator=DecisionTreeClassifier(), random_state=0
+ ).fit(X_train, y_train)
+
+ assert_array_almost_equal(
+ np.sum(ensemble.predict_proba(X_test), axis=1),
+ np.ones(len(X_test)),
+ )
+
+ assert_array_almost_equal(
+ ensemble.predict_proba(X_test),
+ np.exp(ensemble.predict_log_proba(X_test)),
+ )
+
+ # Degenerate case, where some classes are missing
+ ensemble = BalancedBaggingClassifier(
+ estimator=LogisticRegression(solver="lbfgs"),
+ random_state=0,
+ max_samples=5,
+ )
+ ensemble.fit(X_train, y_train)
+
+ assert_array_almost_equal(
+ np.sum(ensemble.predict_proba(X_test), axis=1),
+ np.ones(len(X_test)),
+ )
+
+ assert_array_almost_equal(
+ ensemble.predict_proba(X_test),
+ np.exp(ensemble.predict_log_proba(X_test)),
+ )
+
+
+def test_oob_score_classification():
+ # Check that oob prediction is a good estimation of the generalization
+ # error.
+ X, y = make_imbalance(
+ iris.data,
+ iris.target,
+ sampling_strategy={0: 20, 1: 25, 2: 50},
+ random_state=0,
+ )
+ X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+
+ for estimator in [DecisionTreeClassifier(), SVC(gamma="scale")]:
+ clf = BalancedBaggingClassifier(
+ estimator=estimator,
+ n_estimators=100,
+ bootstrap=True,
+ oob_score=True,
+ random_state=0,
+ ).fit(X_train, y_train)
+
+ test_score = clf.score(X_test, y_test)
+
+ assert abs(test_score - clf.oob_score_) < 0.1
+
+ # Test with few estimators
+ with pytest.warns(UserWarning):
+ BalancedBaggingClassifier(
+ estimator=estimator,
+ n_estimators=1,
+ bootstrap=True,
+ oob_score=True,
+ random_state=0,
+ ).fit(X_train, y_train)
+
+
+def test_single_estimator():
+ # Check singleton ensembles.
+ X, y = make_imbalance(
+ iris.data,
+ iris.target,
+ sampling_strategy={0: 20, 1: 25, 2: 50},
+ random_state=0,
+ )
+ X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+
+ clf1 = BalancedBaggingClassifier(
+ estimator=KNeighborsClassifier(),
+ n_estimators=1,
+ bootstrap=False,
+ bootstrap_features=False,
+ random_state=0,
+ ).fit(X_train, y_train)
+
+ clf2 = make_pipeline(
+ RandomUnderSampler(random_state=clf1.estimators_[0].steps[0][1].random_state),
+ KNeighborsClassifier(),
+ ).fit(X_train, y_train)
+
+ assert_array_equal(clf1.predict(X_test), clf2.predict(X_test))
+
+
+def test_gridsearch():
+ # Check that bagging ensembles can be grid-searched.
+ # Transform iris into a binary classification task
+ X, y = iris.data, iris.target.copy()
+ y[y == 2] = 1
+
+ # Grid search with scoring based on decision_function
+ parameters = {"n_estimators": (1, 2), "estimator__C": (1, 2)}
+
+ GridSearchCV(
+ BalancedBaggingClassifier(SVC(gamma="scale")),
+ parameters,
+ cv=3,
+ scoring="roc_auc",
+ ).fit(X, y)
+
+
+def test_estimator():
+ # Check estimator and its default values.
+ X, y = make_imbalance(
+ iris.data,
+ iris.target,
+ sampling_strategy={0: 20, 1: 25, 2: 50},
+ random_state=0,
+ )
+ X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+
+ ensemble = BalancedBaggingClassifier(None, n_jobs=3, random_state=0).fit(
+ X_train, y_train
+ )
+
+ assert isinstance(ensemble.estimator_.steps[-1][1], DecisionTreeClassifier)
+
+ ensemble = BalancedBaggingClassifier(
+ DecisionTreeClassifier(), n_jobs=3, random_state=0
+ ).fit(X_train, y_train)
+
+ assert isinstance(ensemble.estimator_.steps[-1][1], DecisionTreeClassifier)
+
+ ensemble = BalancedBaggingClassifier(
+ Perceptron(max_iter=1000, tol=1e-3), n_jobs=3, random_state=0
+ ).fit(X_train, y_train)
+
+ assert isinstance(ensemble.estimator_.steps[-1][1], Perceptron)
+
+
+def test_bagging_with_pipeline():
+ X, y = make_imbalance(
+ iris.data,
+ iris.target,
+ sampling_strategy={0: 20, 1: 25, 2: 50},
+ random_state=0,
+ )
+ estimator = BalancedBaggingClassifier(
+ make_pipeline(SelectKBest(k=1), DecisionTreeClassifier()),
+ max_features=2,
+ )
+ estimator.fit(X, y).predict(X)
+
+
+def test_warm_start(random_state=42):
+ # Test if fitting incrementally with warm start gives a forest of the
+ # right size and the same results as a normal fit.
+ X, y = make_hastie_10_2(n_samples=20, random_state=1)
+
+ clf_ws = None
+ for n_estimators in [5, 10]:
+ if clf_ws is None:
+ clf_ws = BalancedBaggingClassifier(
+ n_estimators=n_estimators,
+ random_state=random_state,
+ warm_start=True,
+ )
+ else:
+ clf_ws.set_params(n_estimators=n_estimators)
+ clf_ws.fit(X, y)
+ assert len(clf_ws) == n_estimators
+
+ clf_no_ws = BalancedBaggingClassifier(
+ n_estimators=10, random_state=random_state, warm_start=False
+ )
+ clf_no_ws.fit(X, y)
+
+ assert {pipe.steps[-1][1].random_state for pipe in clf_ws} == {
+ pipe.steps[-1][1].random_state for pipe in clf_no_ws
+ }
+
+
+def test_warm_start_smaller_n_estimators():
+ # Test if warm start'ed second fit with smaller n_estimators raises error.
+ X, y = make_hastie_10_2(n_samples=20, random_state=1)
+ clf = BalancedBaggingClassifier(n_estimators=5, warm_start=True)
+ clf.fit(X, y)
+ clf.set_params(n_estimators=4)
+ with pytest.raises(ValueError):
+ clf.fit(X, y)
+
+
+def test_warm_start_equal_n_estimators():
+ # Test that nothing happens when fitting without increasing n_estimators
+ X, y = make_hastie_10_2(n_samples=20, random_state=1)
+ X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=43)
+
+ clf = BalancedBaggingClassifier(n_estimators=5, warm_start=True, random_state=83)
+ clf.fit(X_train, y_train)
+
+ y_pred = clf.predict(X_test)
+ # modify X to nonsense values, this should not change anything
+ X_train += 1.0
+
+ warn_msg = "Warm-start fitting without increasing n_estimators does not"
+ with pytest.warns(UserWarning, match=warn_msg):
+ clf.fit(X_train, y_train)
+ assert_array_equal(y_pred, clf.predict(X_test))
+
+
+def test_warm_start_equivalence():
+ # warm started classifier with 5+5 estimators should be equivalent to
+ # one classifier with 10 estimators
+ X, y = make_hastie_10_2(n_samples=20, random_state=1)
+ X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=43)
+
+ clf_ws = BalancedBaggingClassifier(
+ n_estimators=5, warm_start=True, random_state=3141
+ )
+ clf_ws.fit(X_train, y_train)
+ clf_ws.set_params(n_estimators=10)
+ clf_ws.fit(X_train, y_train)
+ y1 = clf_ws.predict(X_test)
+
+ clf = BalancedBaggingClassifier(
+ n_estimators=10, warm_start=False, random_state=3141
+ )
+ clf.fit(X_train, y_train)
+ y2 = clf.predict(X_test)
+
+ assert_array_almost_equal(y1, y2)
+
+
+def test_warm_start_with_oob_score_fails():
+ # Check using oob_score and warm_start simultaneously fails
+ X, y = make_hastie_10_2(n_samples=20, random_state=1)
+ clf = BalancedBaggingClassifier(n_estimators=5, warm_start=True, oob_score=True)
+ with pytest.raises(ValueError):
+ clf.fit(X, y)
+
+
+def test_oob_score_removed_on_warm_start():
+ X, y = make_hastie_10_2(n_samples=2000, random_state=1)
+
+ clf = BalancedBaggingClassifier(n_estimators=50, oob_score=True)
+ clf.fit(X, y)
+
+ clf.set_params(warm_start=True, oob_score=False, n_estimators=100)
+ clf.fit(X, y)
+
+ with pytest.raises(AttributeError):
+ getattr(clf, "oob_score_")
+
+
+def test_oob_score_consistency():
+ # Make sure OOB scores are identical when random_state, estimator, and
+ # training data are fixed and fitting is done twice
+ X, y = make_hastie_10_2(n_samples=200, random_state=1)
+ bagging = BalancedBaggingClassifier(
+ KNeighborsClassifier(),
+ max_samples=0.5,
+ max_features=0.5,
+ oob_score=True,
+ random_state=1,
+ )
+ assert bagging.fit(X, y).oob_score_ == bagging.fit(X, y).oob_score_
+
+
+def test_estimators_samples():
+ # Check that format of estimators_samples_ is correct and that results
+ # generated at fit time can be identically reproduced at a later time
+ # using data saved in object attributes.
+ X, y = make_hastie_10_2(n_samples=200, random_state=1)
+
+    # remap the y outside of the BalancedBaggingClassifier
+ # _, y = np.unique(y, return_inverse=True)
+ bagging = BalancedBaggingClassifier(
+ LogisticRegression(),
+ max_samples=0.5,
+ max_features=0.5,
+ random_state=1,
+ bootstrap=False,
+ )
+ bagging.fit(X, y)
+
+ # Get relevant attributes
+ estimators_samples = bagging.estimators_samples_
+ estimators_features = bagging.estimators_features_
+ estimators = bagging.estimators_
+
+ # Test for correct formatting
+ assert len(estimators_samples) == len(estimators)
+ assert len(estimators_samples[0]) == len(X) // 2
+ assert estimators_samples[0].dtype.kind == "i"
+
+ # Re-fit single estimator to test for consistent sampling
+ estimator_index = 0
+ estimator_samples = estimators_samples[estimator_index]
+ estimator_features = estimators_features[estimator_index]
+ estimator = estimators[estimator_index]
+
+ X_train = (X[estimator_samples])[:, estimator_features]
+ y_train = y[estimator_samples]
+
+ orig_coefs = estimator.steps[-1][1].coef_
+ estimator.fit(X_train, y_train)
+ new_coefs = estimator.steps[-1][1].coef_
+
+ assert_allclose(orig_coefs, new_coefs)
+
+
+def test_max_samples_consistency():
+ # Make sure validated max_samples and original max_samples are identical
+ # when valid integer max_samples supplied by user
+ max_samples = 100
+ X, y = make_hastie_10_2(n_samples=2 * max_samples, random_state=1)
+ bagging = BalancedBaggingClassifier(
+ KNeighborsClassifier(),
+ max_samples=max_samples,
+ max_features=0.5,
+ random_state=1,
+ )
+ bagging.fit(X, y)
+ assert bagging._max_samples == max_samples
+
+
class CountDecisionTreeClassifier(DecisionTreeClassifier):
"""DecisionTreeClassifier that will memorize the number of samples seen
at fit."""
+ def fit(self, X, y, sample_weight=None):
+ self.class_counts_ = Counter(y)
+ return super().fit(X, y, sample_weight=sample_weight)
+
+
+@pytest.mark.filterwarnings("ignore:Number of distinct clusters")
+@pytest.mark.parametrize(
+ "sampler, n_samples_bootstrap",
+ [
+ (None, 15),
+ (RandomUnderSampler(), 15), # under-sampling with sample_indices_
+ (
+ ClusterCentroids(estimator=KMeans(n_init=1)),
+ 15,
+ ), # under-sampling without sample_indices_
+ (RandomOverSampler(), 40), # over-sampling with sample_indices_
+ (SMOTE(), 40), # over-sampling without sample_indices_
+ ],
+)
+def test_balanced_bagging_classifier_samplers(sampler, n_samples_bootstrap):
+ # check that we can pass any kind of sampler to a bagging classifier
+ X, y = make_imbalance(
+ iris.data,
+ iris.target,
+ sampling_strategy={0: 20, 1: 25, 2: 50},
+ random_state=0,
+ )
+ X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+ clf = BalancedBaggingClassifier(
+ estimator=CountDecisionTreeClassifier(),
+ n_estimators=2,
+ sampler=sampler,
+ random_state=0,
+ )
+ clf.fit(X_train, y_train)
+ clf.predict(X_test)
+
+ # check that we have balanced class with the right counts of class
+ # sample depending on the sampling strategy
+ assert_array_equal(
+ list(clf.estimators_[0][-1].class_counts_.values()), n_samples_bootstrap
+ )
+
+
+@pytest.mark.parametrize("replace", [True, False])
+def test_balanced_bagging_classifier_with_function_sampler(replace):
+ # check that we can provide a FunctionSampler in BalancedBaggingClassifier
+ X, y = make_classification(
+ n_samples=1_000,
+ n_features=10,
+ n_classes=2,
+ weights=[0.3, 0.7],
+ random_state=0,
+ )
+
+ def roughly_balanced_bagging(X, y, replace=False):
+ """Implementation of Roughly Balanced Bagging for binary problem."""
+ # find the minority and majority classes
+ class_counts = Counter(y)
+ majority_class = max(class_counts, key=class_counts.get)
+ minority_class = min(class_counts, key=class_counts.get)
+
+ # compute the number of sample to draw from the majority class using
+ # a negative binomial distribution
+ n_minority_class = class_counts[minority_class]
+ n_majority_resampled = np.random.negative_binomial(n=n_minority_class, p=0.5)
+
+ # draw randomly with or without replacement
+ majority_indices = np.random.choice(
+ np.flatnonzero(y == majority_class),
+ size=n_majority_resampled,
+ replace=replace,
+ )
+ minority_indices = np.random.choice(
+ np.flatnonzero(y == minority_class),
+ size=n_minority_class,
+ replace=replace,
+ )
+ indices = np.hstack([majority_indices, minority_indices])
+
+ return X[indices], y[indices]
+
+ # Roughly Balanced Bagging
+ rbb = BalancedBaggingClassifier(
+ estimator=CountDecisionTreeClassifier(random_state=0),
+ n_estimators=2,
+ sampler=FunctionSampler(
+ func=roughly_balanced_bagging, kw_args={"replace": replace}
+ ),
+ random_state=0,
+ )
+ rbb.fit(X, y)
+
+ for estimator in rbb.estimators_:
+ class_counts = estimator[-1].class_counts_
+ assert (class_counts[0] / class_counts[1]) > 0.78
+
def test_balanced_bagging_classifier_n_features():
"""Check that we raise a FutureWarning when accessing `n_features_`."""
- pass
+ X, y = load_iris(return_X_y=True)
+ estimator = BalancedBaggingClassifier().fit(X, y)
+ with pytest.warns(FutureWarning, match="`n_features_` was deprecated"):
+ estimator.n_features_
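
The sampler-related tests above pin down the general contract: any imbalanced-learn sampler, or a `FunctionSampler` wrapping custom resampling logic, can be placed in front of each bagged estimator. A condensed sketch of that pattern, reusing the iris setup from these tests:

```python
# Sketch only: plug an arbitrary sampler into the bagging ensemble.
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from imblearn.datasets import make_imbalance
from imblearn.ensemble import BalancedBaggingClassifier
from imblearn.over_sampling import SMOTE

iris = load_iris()
X, y = make_imbalance(
    iris.data, iris.target, sampling_strategy={0: 20, 1: 25, 2: 50}, random_state=0
)
clf = BalancedBaggingClassifier(
    estimator=DecisionTreeClassifier(),
    sampler=SMOTE(),  # any sampler; None falls back to RandomUnderSampler
    n_estimators=2,
    random_state=0,
).fit(X, y)
```
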
diff --git a/imblearn/ensemble/tests/test_easy_ensemble.py b/imblearn/ensemble/tests/test_easy_ensemble.py
index 3b667a8..7dc0441 100644
--- a/imblearn/ensemble/tests/test_easy_ensemble.py
+++ b/imblearn/ensemble/tests/test_easy_ensemble.py
@@ -1,4 +1,8 @@
"""Test the module easy ensemble."""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import numpy as np
import pytest
import sklearn
@@ -8,21 +12,224 @@ from sklearn.feature_selection import SelectKBest
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.utils._testing import assert_allclose, assert_array_equal
from sklearn.utils.fixes import parse_version
+
from imblearn.datasets import make_imbalance
from imblearn.ensemble import EasyEnsembleClassifier
from imblearn.pipeline import make_pipeline
from imblearn.under_sampling import RandomUnderSampler
+
sklearn_version = parse_version(sklearn.__version__)
iris = load_iris()
+
+# Generate a global dataset to use
RND_SEED = 0
-X = np.array([[0.5220963, 0.11349303], [0.59091459, 0.40692742], [
- 1.10915364, 0.05718352], [0.22039505, 0.26469445], [1.35269503,
- 0.44812421], [0.85117925, 1.0185556], [-2.10724436, 0.70263997], [-
- 0.23627356, 0.30254174], [-1.23195149, 0.15427291], [-0.58539673,
- 0.62515052]])
+X = np.array(
+ [
+ [0.5220963, 0.11349303],
+ [0.59091459, 0.40692742],
+ [1.10915364, 0.05718352],
+ [0.22039505, 0.26469445],
+ [1.35269503, 0.44812421],
+ [0.85117925, 1.0185556],
+ [-2.10724436, 0.70263997],
+ [-0.23627356, 0.30254174],
+ [-1.23195149, 0.15427291],
+ [-0.58539673, 0.62515052],
+ ]
+)
Y = np.array([1, 2, 2, 2, 1, 0, 1, 1, 1, 0])
+@pytest.mark.parametrize("n_estimators", [10, 20])
+@pytest.mark.parametrize(
+ "estimator",
+ [
+ AdaBoostClassifier(algorithm="SAMME", n_estimators=5),
+ AdaBoostClassifier(algorithm="SAMME", n_estimators=10),
+ ],
+)
+def test_easy_ensemble_classifier(n_estimators, estimator):
+ # Check classification for various parameter settings.
+ X, y = make_imbalance(
+ iris.data,
+ iris.target,
+ sampling_strategy={0: 20, 1: 25, 2: 50},
+ random_state=0,
+ )
+ X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+
+ eec = EasyEnsembleClassifier(
+ n_estimators=n_estimators,
+ estimator=estimator,
+ n_jobs=-1,
+ random_state=RND_SEED,
+ )
+ eec.fit(X_train, y_train).score(X_test, y_test)
+ assert len(eec.estimators_) == n_estimators
+ for est in eec.estimators_:
+ assert len(est.named_steps["classifier"]) == estimator.n_estimators
+ # test the different prediction function
+ eec.predict(X_test)
+ eec.predict_proba(X_test)
+ eec.predict_log_proba(X_test)
+ eec.decision_function(X_test)
+
+
+def test_estimator():
+ # Check estimator and its default values.
+ X, y = make_imbalance(
+ iris.data,
+ iris.target,
+ sampling_strategy={0: 20, 1: 25, 2: 50},
+ random_state=0,
+ )
+ X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+
+ ensemble = EasyEnsembleClassifier(2, None, n_jobs=-1, random_state=0).fit(
+ X_train, y_train
+ )
+
+ assert isinstance(ensemble.estimator_.steps[-1][1], AdaBoostClassifier)
+
+ ensemble = EasyEnsembleClassifier(
+ 2, AdaBoostClassifier(algorithm="SAMME"), n_jobs=-1, random_state=0
+ ).fit(X_train, y_train)
+
+ assert isinstance(ensemble.estimator_.steps[-1][1], AdaBoostClassifier)
+
+
+def test_bagging_with_pipeline():
+ X, y = make_imbalance(
+ iris.data,
+ iris.target,
+ sampling_strategy={0: 20, 1: 25, 2: 50},
+ random_state=0,
+ )
+ estimator = EasyEnsembleClassifier(
+ n_estimators=2,
+ estimator=make_pipeline(
+ SelectKBest(k=1), AdaBoostClassifier(algorithm="SAMME")
+ ),
+ )
+ estimator.fit(X, y).predict(X)
+
+
+def test_warm_start(random_state=42):
+ # Test if fitting incrementally with warm start gives a forest of the
+ # right size and the same results as a normal fit.
+ X, y = make_hastie_10_2(n_samples=20, random_state=1)
+
+ clf_ws = None
+ for n_estimators in [5, 10]:
+ if clf_ws is None:
+ clf_ws = EasyEnsembleClassifier(
+ n_estimators=n_estimators,
+ random_state=random_state,
+ warm_start=True,
+ )
+ else:
+ clf_ws.set_params(n_estimators=n_estimators)
+ clf_ws.fit(X, y)
+ assert len(clf_ws) == n_estimators
+
+ clf_no_ws = EasyEnsembleClassifier(
+ n_estimators=10, random_state=random_state, warm_start=False
+ )
+ clf_no_ws.fit(X, y)
+
+ assert {pipe.steps[-1][1].random_state for pipe in clf_ws} == {
+ pipe.steps[-1][1].random_state for pipe in clf_no_ws
+ }
+
+
+def test_warm_start_smaller_n_estimators():
+ # Test if warm start'ed second fit with smaller n_estimators raises error.
+ X, y = make_hastie_10_2(n_samples=20, random_state=1)
+ clf = EasyEnsembleClassifier(n_estimators=5, warm_start=True)
+ clf.fit(X, y)
+ clf.set_params(n_estimators=4)
+ with pytest.raises(ValueError):
+ clf.fit(X, y)
+
+
+def test_warm_start_equal_n_estimators():
+ # Test that nothing happens when fitting without increasing n_estimators
+ X, y = make_hastie_10_2(n_samples=20, random_state=1)
+ X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=43)
+
+ clf = EasyEnsembleClassifier(n_estimators=5, warm_start=True, random_state=83)
+ clf.fit(X_train, y_train)
+
+ y_pred = clf.predict(X_test)
+ # modify X to nonsense values, this should not change anything
+ X_train += 1.0
+
+ warn_msg = "Warm-start fitting without increasing n_estimators"
+ with pytest.warns(UserWarning, match=warn_msg):
+ clf.fit(X_train, y_train)
+ assert_array_equal(y_pred, clf.predict(X_test))
+
+
+def test_warm_start_equivalence():
+ # warm started classifier with 5+5 estimators should be equivalent to
+ # one classifier with 10 estimators
+ X, y = make_hastie_10_2(n_samples=20, random_state=1)
+ X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=43)
+
+ clf_ws = EasyEnsembleClassifier(n_estimators=5, warm_start=True, random_state=3141)
+ clf_ws.fit(X_train, y_train)
+ clf_ws.set_params(n_estimators=10)
+ clf_ws.fit(X_train, y_train)
+ y1 = clf_ws.predict(X_test)
+
+ clf = EasyEnsembleClassifier(n_estimators=10, warm_start=False, random_state=3141)
+ clf.fit(X_train, y_train)
+ y2 = clf.predict(X_test)
+
+ assert_allclose(y1, y2)
+
+
+def test_easy_ensemble_classifier_single_estimator():
+ X, y = make_imbalance(
+ iris.data,
+ iris.target,
+ sampling_strategy={0: 20, 1: 25, 2: 50},
+ random_state=0,
+ )
+ X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+
+ clf1 = EasyEnsembleClassifier(n_estimators=1, random_state=0).fit(X_train, y_train)
+ clf2 = make_pipeline(
+ RandomUnderSampler(random_state=0),
+ AdaBoostClassifier(algorithm="SAMME", random_state=0),
+ ).fit(X_train, y_train)
+
+ assert_array_equal(clf1.predict(X_test), clf2.predict(X_test))
+
+
+def test_easy_ensemble_classifier_grid_search():
+ X, y = make_imbalance(
+ iris.data,
+ iris.target,
+ sampling_strategy={0: 20, 1: 25, 2: 50},
+ random_state=0,
+ )
+
+ parameters = {
+ "n_estimators": [1, 2],
+ "estimator__n_estimators": [3, 4],
+ }
+ grid_search = GridSearchCV(
+ EasyEnsembleClassifier(estimator=AdaBoostClassifier(algorithm="SAMME")),
+ parameters,
+ cv=5,
+ )
+ grid_search.fit(X, y)
+
+
def test_easy_ensemble_classifier_n_features():
"""Check that we raise a FutureWarning when accessing `n_features_`."""
- pass
+ X, y = load_iris(return_X_y=True)
+ estimator = EasyEnsembleClassifier().fit(X, y)
+ with pytest.warns(FutureWarning, match="`n_features_` was deprecated"):
+ estimator.n_features_
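
As the tests above show, each `EasyEnsembleClassifier` member is a `RandomUnderSampler` + `AdaBoostClassifier` pipeline. A minimal sketch with illustrative parameters:

```python
# Sketch only: EasyEnsemble bags AdaBoost learners over balanced subsets.
from sklearn.datasets import load_iris
from sklearn.ensemble import AdaBoostClassifier
from imblearn.datasets import make_imbalance
from imblearn.ensemble import EasyEnsembleClassifier

iris = load_iris()
X, y = make_imbalance(
    iris.data, iris.target, sampling_strategy={0: 20, 1: 25, 2: 50}, random_state=0
)
eec = EasyEnsembleClassifier(
    n_estimators=5,
    estimator=AdaBoostClassifier(algorithm="SAMME"),
    random_state=0,
).fit(X, y)
print(eec.estimators_[0].named_steps)  # randomundersampler -> adaboostclassifier
```
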
diff --git a/imblearn/ensemble/tests/test_forest.py b/imblearn/ensemble/tests/test_forest.py
index 2742293..3719568 100644
--- a/imblearn/ensemble/tests/test_forest.py
+++ b/imblearn/ensemble/tests/test_forest.py
@@ -5,32 +5,340 @@ from sklearn.datasets import load_iris, make_classification
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.utils._testing import assert_allclose, assert_array_equal
from sklearn.utils.fixes import parse_version
+
from imblearn.ensemble import BalancedRandomForestClassifier
+
sklearn_version = parse_version(sklearn.__version__)
+@pytest.fixture
+def imbalanced_dataset():
+ return make_classification(
+ n_samples=10000,
+ n_features=2,
+ n_informative=2,
+ n_redundant=0,
+ n_repeated=0,
+ n_classes=3,
+ n_clusters_per_class=1,
+ weights=[0.01, 0.05, 0.94],
+ class_sep=0.8,
+ random_state=0,
+ )
+
+
+def test_balanced_random_forest_error_warning_warm_start(imbalanced_dataset):
+ brf = BalancedRandomForestClassifier(
+ n_estimators=5, sampling_strategy="all", replacement=True, bootstrap=False
+ )
+ brf.fit(*imbalanced_dataset)
+
+ with pytest.raises(ValueError, match="must be larger or equal to"):
+ brf.set_params(warm_start=True, n_estimators=2)
+ brf.fit(*imbalanced_dataset)
+
+ brf.set_params(n_estimators=10)
+ brf.fit(*imbalanced_dataset)
+
+ with pytest.warns(UserWarning, match="Warm-start fitting without"):
+ brf.fit(*imbalanced_dataset)
+
+
+def test_balanced_random_forest(imbalanced_dataset):
+ n_estimators = 10
+ brf = BalancedRandomForestClassifier(
+ n_estimators=n_estimators,
+ random_state=0,
+ sampling_strategy="all",
+ replacement=True,
+ bootstrap=False,
+ )
+ brf.fit(*imbalanced_dataset)
+
+ assert len(brf.samplers_) == n_estimators
+ assert len(brf.estimators_) == n_estimators
+ assert len(brf.pipelines_) == n_estimators
+ assert len(brf.feature_importances_) == imbalanced_dataset[0].shape[1]
+
+
+def test_balanced_random_forest_attributes(imbalanced_dataset):
+ X, y = imbalanced_dataset
+ n_estimators = 10
+ brf = BalancedRandomForestClassifier(
+ n_estimators=n_estimators,
+ random_state=0,
+ sampling_strategy="all",
+ replacement=True,
+ bootstrap=False,
+ )
+ brf.fit(X, y)
+
+ for idx in range(n_estimators):
+ X_res, y_res = brf.samplers_[idx].fit_resample(X, y)
+ X_res_2, y_res_2 = (
+ brf.pipelines_[idx].named_steps["randomundersampler"].fit_resample(X, y)
+ )
+ assert_allclose(X_res, X_res_2)
+ assert_array_equal(y_res, y_res_2)
+
+ y_pred = brf.estimators_[idx].fit(X_res, y_res).predict(X)
+ y_pred_2 = brf.pipelines_[idx].fit(X, y).predict(X)
+ assert_array_equal(y_pred, y_pred_2)
+
+ y_pred = brf.estimators_[idx].fit(X_res, y_res).predict_proba(X)
+ y_pred_2 = brf.pipelines_[idx].fit(X, y).predict_proba(X)
+ assert_array_equal(y_pred, y_pred_2)
+
+
+def test_balanced_random_forest_sample_weight(imbalanced_dataset):
+ rng = np.random.RandomState(42)
+ X, y = imbalanced_dataset
+ sample_weight = rng.rand(y.shape[0])
+ brf = BalancedRandomForestClassifier(
+ n_estimators=5,
+ random_state=0,
+ sampling_strategy="all",
+ replacement=True,
+ bootstrap=False,
+ )
+ brf.fit(X, y, sample_weight)
+
+
+@pytest.mark.filterwarnings("ignore:Some inputs do not have OOB scores")
+def test_balanced_random_forest_oob(imbalanced_dataset):
+ X, y = imbalanced_dataset
+ X_train, X_test, y_train, y_test = train_test_split(
+ X, y, random_state=42, stratify=y
+ )
+ est = BalancedRandomForestClassifier(
+ oob_score=True,
+ random_state=0,
+ n_estimators=1000,
+ min_samples_leaf=2,
+ sampling_strategy="all",
+ replacement=True,
+ bootstrap=True,
+ )
+
+ est.fit(X_train, y_train)
+ test_score = est.score(X_test, y_test)
+
+ assert abs(test_score - est.oob_score_) < 0.1
+
+ # Check warning if not enough estimators
+ est = BalancedRandomForestClassifier(
+ oob_score=True,
+ random_state=0,
+ n_estimators=1,
+ bootstrap=True,
+ sampling_strategy="all",
+ replacement=True,
+ )
+    with pytest.warns(UserWarning), np.errstate(divide="ignore", invalid="ignore"):
+ est.fit(X, y)
+
+
+def test_balanced_random_forest_grid_search(imbalanced_dataset):
+ brf = BalancedRandomForestClassifier(
+ sampling_strategy="all", replacement=True, bootstrap=False
+ )
+ grid = GridSearchCV(brf, {"n_estimators": (1, 2), "max_depth": (1, 2)}, cv=3)
+ grid.fit(*imbalanced_dataset)
+
+
+def test_little_tree_with_small_max_samples():
+ rng = np.random.RandomState(1)
+
+ X = rng.randn(10000, 2)
+ y = rng.randn(10000) > 0
+
+ # First fit with no restriction on max samples
+ est1 = BalancedRandomForestClassifier(
+ n_estimators=1,
+ random_state=rng,
+ max_samples=None,
+ sampling_strategy="all",
+ replacement=True,
+ bootstrap=True,
+ )
+
+ # Second fit with max samples restricted to just 2
+ est2 = BalancedRandomForestClassifier(
+ n_estimators=1,
+ random_state=rng,
+ max_samples=2,
+ sampling_strategy="all",
+ replacement=True,
+ bootstrap=True,
+ )
+
+ est1.fit(X, y)
+ est2.fit(X, y)
+
+ tree1 = est1.estimators_[0].tree_
+ tree2 = est2.estimators_[0].tree_
+
+ msg = "Tree without `max_samples` restriction should have more nodes"
+ assert tree1.node_count > tree2.node_count, msg
+
+
+def test_balanced_random_forest_pruning(imbalanced_dataset):
+ brf = BalancedRandomForestClassifier(
+ sampling_strategy="all", replacement=True, bootstrap=False
+ )
+ brf.fit(*imbalanced_dataset)
+ n_nodes_no_pruning = brf.estimators_[0].tree_.node_count
+
+ brf_pruned = BalancedRandomForestClassifier(
+ ccp_alpha=0.015, sampling_strategy="all", replacement=True, bootstrap=False
+ )
+ brf_pruned.fit(*imbalanced_dataset)
+ n_nodes_pruning = brf_pruned.estimators_[0].tree_.node_count
+
+ assert n_nodes_no_pruning > n_nodes_pruning
+
+
+@pytest.mark.parametrize("ratio", [0.5, 0.1])
+@pytest.mark.filterwarnings("ignore:Some inputs do not have OOB scores")
+def test_balanced_random_forest_oob_binomial(ratio):
+    # Regression test for #655: check that the oob score is close to 0.5 in
+    # a binomial experiment.
+ rng = np.random.RandomState(42)
+ n_samples = 1000
+ X = np.arange(n_samples).reshape(-1, 1)
+ y = rng.binomial(1, ratio, size=n_samples)
+
+ erf = BalancedRandomForestClassifier(
+ oob_score=True,
+ random_state=42,
+ sampling_strategy="not minority",
+ replacement=False,
+ bootstrap=True,
+ )
+ erf.fit(X, y)
+ assert np.abs(erf.oob_score_ - 0.5) < 0.1
+
+
def test_balanced_bagging_classifier_n_features():
"""Check that we raise a FutureWarning when accessing `n_features_`."""
- pass
+ X, y = load_iris(return_X_y=True)
+ estimator = BalancedRandomForestClassifier(
+ sampling_strategy="all", replacement=True, bootstrap=False
+ ).fit(X, y)
+ with pytest.warns(FutureWarning, match="`n_features_` was deprecated"):
+ estimator.n_features_
+# TODO: remove in 0.13
def test_balanced_random_forest_change_behaviour(imbalanced_dataset):
"""Check that we raise a change of behaviour for the parameters `sampling_strategy`
and `replacement`.
"""
- pass
+ estimator = BalancedRandomForestClassifier(sampling_strategy="all", bootstrap=False)
+ with pytest.warns(FutureWarning, match="The default of `replacement`"):
+ estimator.fit(*imbalanced_dataset)
+ estimator = BalancedRandomForestClassifier(replacement=True, bootstrap=False)
+ with pytest.warns(FutureWarning, match="The default of `sampling_strategy`"):
+ estimator.fit(*imbalanced_dataset)
+ estimator = BalancedRandomForestClassifier(
+ sampling_strategy="all", replacement=True
+ )
+ with pytest.warns(FutureWarning, match="The default of `bootstrap`"):
+ estimator.fit(*imbalanced_dataset)
-@pytest.mark.skipif(parse_version(sklearn_version.base_version) <
- parse_version('1.4'), reason='scikit-learn should be >= 1.4')
+@pytest.mark.skipif(
+ parse_version(sklearn_version.base_version) < parse_version("1.4"),
+ reason="scikit-learn should be >= 1.4",
+)
def test_missing_values_is_resilient():
"""Check that forest can deal with missing values and has decent performance."""
- pass
+ rng = np.random.RandomState(0)
+ n_samples, n_features = 1000, 10
+ X, y = make_classification(
+ n_samples=n_samples, n_features=n_features, random_state=rng
+ )
-@pytest.mark.skipif(parse_version(sklearn_version.base_version) <
- parse_version('1.4'), reason='scikit-learn should be >= 1.4')
+ # Create dataset with missing values
+ X_missing = X.copy()
+ X_missing[rng.choice([False, True], size=X.shape, p=[0.95, 0.05])] = np.nan
+ assert np.isnan(X_missing).any()
+
+ X_missing_train, X_missing_test, y_train, y_test = train_test_split(
+ X_missing, y, random_state=0
+ )
+
+ # Train forest with missing values
+ forest_with_missing = BalancedRandomForestClassifier(
+ sampling_strategy="all",
+ replacement=True,
+ bootstrap=False,
+ random_state=rng,
+ n_estimators=50,
+ )
+ forest_with_missing.fit(X_missing_train, y_train)
+ score_with_missing = forest_with_missing.score(X_missing_test, y_test)
+
+ # Train forest without missing values
+ X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+ forest = BalancedRandomForestClassifier(
+ sampling_strategy="all",
+ replacement=True,
+ bootstrap=False,
+ random_state=rng,
+ n_estimators=50,
+ )
+ forest.fit(X_train, y_train)
+ score_without_missing = forest.score(X_test, y_test)
+
+ # Score is still 80 percent of the forest's score that had no missing values
+ assert score_with_missing >= 0.80 * score_without_missing
+
+
+@pytest.mark.skipif(
+ parse_version(sklearn_version.base_version) < parse_version("1.4"),
+ reason="scikit-learn should be >= 1.4",
+)
def test_missing_value_is_predictive():
"""Check that the forest learns when missing values are only present for
a predictive feature."""
- pass
+ rng = np.random.RandomState(0)
+ n_samples = 300
+
+ X_non_predictive = rng.standard_normal(size=(n_samples, 10))
+ y = rng.randint(0, high=2, size=n_samples)
+
+ # Create a predictive feature using `y` and with some noise
+ X_random_mask = rng.choice([False, True], size=n_samples, p=[0.95, 0.05])
+ y_mask = y.astype(bool)
+ y_mask[X_random_mask] = ~y_mask[X_random_mask]
+
+ predictive_feature = rng.standard_normal(size=n_samples)
+ predictive_feature[y_mask] = np.nan
+ assert np.isnan(predictive_feature).any()
+
+ X_predictive = X_non_predictive.copy()
+ X_predictive[:, 5] = predictive_feature
+
+ (
+ X_predictive_train,
+ X_predictive_test,
+ X_non_predictive_train,
+ X_non_predictive_test,
+ y_train,
+ y_test,
+ ) = train_test_split(X_predictive, X_non_predictive, y, random_state=0)
+ forest_predictive = BalancedRandomForestClassifier(
+ sampling_strategy="all", replacement=True, bootstrap=False, random_state=0
+ ).fit(X_predictive_train, y_train)
+ forest_non_predictive = BalancedRandomForestClassifier(
+ sampling_strategy="all", replacement=True, bootstrap=False, random_state=0
+ ).fit(X_non_predictive_train, y_train)
+
+ predictive_test_score = forest_predictive.score(X_predictive_test, y_test)
+
+ assert predictive_test_score >= 0.75
+ assert predictive_test_score >= forest_non_predictive.score(
+ X_non_predictive_test, y_test
+ )
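
Both missing-value tests are gated on scikit-learn >= 1.4, matching the `force_all_finite` and `_compute_missing_values_in_feature_mask` branches added to `fit` earlier in this diff. A hedged sketch of the same gate from user code (version handling only, not a library API):

```python
# Sketch only: NaN handling in the forest is delegated to scikit-learn >= 1.4.
import numpy as np
import sklearn
from sklearn.datasets import make_classification
from sklearn.utils.fixes import parse_version
from imblearn.ensemble import BalancedRandomForestClassifier

X, y = make_classification(n_samples=200, random_state=0)
X[::20, 0] = np.nan  # inject some missing values

sklearn_version = parse_version(sklearn.__version__)
if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
    BalancedRandomForestClassifier(
        sampling_strategy="all", replacement=True, bootstrap=False, random_state=0
    ).fit(X, y)
```
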
diff --git a/imblearn/ensemble/tests/test_weight_boosting.py b/imblearn/ensemble/tests/test_weight_boosting.py
index c1d0d1c..ad3dbca 100644
--- a/imblearn/ensemble/tests/test_weight_boosting.py
+++ b/imblearn/ensemble/tests/test_weight_boosting.py
@@ -5,5 +5,90 @@ from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.utils._testing import assert_array_equal
from sklearn.utils.fixes import parse_version
+
from imblearn.ensemble import RUSBoostClassifier
+
sklearn_version = parse_version(sklearn.__version__)
+
+
+@pytest.fixture
+def imbalanced_dataset():
+ return make_classification(
+ n_samples=10000,
+ n_features=3,
+ n_informative=2,
+ n_redundant=0,
+ n_repeated=0,
+ n_classes=3,
+ n_clusters_per_class=1,
+ weights=[0.01, 0.05, 0.94],
+ class_sep=0.8,
+ random_state=0,
+ )
+
+
+@pytest.mark.parametrize("algorithm", ["SAMME", "SAMME.R"])
+@pytest.mark.filterwarnings(r"ignore:The SAMME.R algorithm \(the default\) is")
+def test_rusboost(imbalanced_dataset, algorithm):
+ X, y = imbalanced_dataset
+ X_train, X_test, y_train, y_test = train_test_split(
+ X, y, stratify=y, random_state=1
+ )
+ classes = np.unique(y)
+
+ n_estimators = 500
+ rusboost = RUSBoostClassifier(
+ n_estimators=n_estimators, algorithm=algorithm, random_state=0
+ )
+ rusboost.fit(X_train, y_train)
+ assert_array_equal(classes, rusboost.classes_)
+
+ # check that we have an ensemble of samplers and estimators with a
+ # consistent size
+ assert len(rusboost.estimators_) > 1
+ assert len(rusboost.estimators_) == len(rusboost.samplers_)
+ assert len(rusboost.pipelines_) == len(rusboost.samplers_)
+
+ # each sampler in the ensemble should have different random state
+ assert len({sampler.random_state for sampler in rusboost.samplers_}) == len(
+ rusboost.samplers_
+ )
+ # each estimator in the ensemble should have different random state
+ assert len({est.random_state for est in rusboost.estimators_}) == len(
+ rusboost.estimators_
+ )
+
+ # check the consistency of the feature importances
+ assert len(rusboost.feature_importances_) == imbalanced_dataset[0].shape[1]
+
+    # check the consistency of the prediction outputs
+ y_pred = rusboost.predict_proba(X_test)
+ assert y_pred.shape[1] == len(classes)
+ assert rusboost.decision_function(X_test).shape[1] == len(classes)
+
+ score = rusboost.score(X_test, y_test)
+ assert score > 0.6, f"Failed with algorithm {algorithm} and score {score}"
+
+ y_pred = rusboost.predict(X_test)
+ assert y_pred.shape == y_test.shape
+
+
+@pytest.mark.parametrize("algorithm", ["SAMME", "SAMME.R"])
+@pytest.mark.filterwarnings(r"ignore:The SAMME.R algorithm \(the default\) is")
+def test_rusboost_sample_weight(imbalanced_dataset, algorithm):
+ X, y = imbalanced_dataset
+ sample_weight = np.ones_like(y)
+ rusboost = RUSBoostClassifier(algorithm=algorithm, random_state=0)
+
+ # Predictions should be the same when sample_weight are all ones
+ y_pred_sample_weight = rusboost.fit(X, y, sample_weight).predict(X)
+ y_pred_no_sample_weight = rusboost.fit(X, y).predict(X)
+
+ assert_array_equal(y_pred_sample_weight, y_pred_no_sample_weight)
+
+ rng = np.random.RandomState(42)
+ sample_weight = rng.rand(y.shape[0])
+ y_pred_sample_weight = rusboost.fit(X, y, sample_weight).predict(X)
+
+ with pytest.raises(AssertionError):
+ assert_array_equal(y_pred_no_sample_weight, y_pred_sample_weight)
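
The RUSBoost tests check that the fitted ensemble keeps its boosted estimators, its per-round random under-samplers, and the pipelines chaining the two in parallel lists. A short sketch of the same invariants, assuming only an illustrative `make_classification` dataset:

```python
from sklearn.datasets import make_classification
from imblearn.ensemble import RUSBoostClassifier

X, y = make_classification(n_samples=2000, weights=[0.1, 0.9], random_state=0)

rusboost = RUSBoostClassifier(n_estimators=20, random_state=0).fit(X, y)
# one RandomUnderSampler and one pipeline per boosting round
assert len(rusboost.estimators_) == len(rusboost.samplers_)
assert len(rusboost.pipelines_) == len(rusboost.samplers_)
print(rusboost.score(X, y))
```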
diff --git a/imblearn/exceptions.py b/imblearn/exceptions.py
index a78c1ad..1011d14 100644
--- a/imblearn/exceptions.py
+++ b/imblearn/exceptions.py
@@ -3,6 +3,9 @@ The :mod:`imblearn.exceptions` module includes all custom warnings and error
classes and functions used across imbalanced-learn.
"""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# License: MIT
+
def raise_isinstance_error(variable_name, possible_type, variable):
"""Raise consistent error message for isinstance() function.
@@ -23,4 +26,7 @@ def raise_isinstance_error(variable_name, possible_type, variable):
ValueError
If the instance is not of the possible type.
"""
- pass
+ raise ValueError(
+ f"{variable_name} has to be one of {possible_type}. "
+ f"Got {type(variable)} instead."
+ )
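
For reference, the restored helper produces messages of the following shape; a quick usage sketch:

```python
from imblearn.exceptions import raise_isinstance_error

try:
    raise_isinstance_error("n_neighbors", [int], "3")
except ValueError as exc:
    # n_neighbors has to be one of [<class 'int'>]. Got <class 'str'> instead.
    print(exc)
```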
diff --git a/imblearn/keras/_generator.py b/imblearn/keras/_generator.py
index f3de87c..ce49935 100644
--- a/imblearn/keras/_generator.py
+++ b/imblearn/keras/_generator.py
@@ -1,26 +1,67 @@
"""Implement generators for ``keras`` which will balance the data."""
+# This is a trick to avoid an error during test collection with pytest: we
+# avoid the error at import time and raise it only at the moment of
+# creating the instance.
def import_keras():
"""Try to import keras from keras and tensorflow.
    It is possible to import ``Sequence`` from either keras or tensorflow.
"""
- pass
+
+ def import_from_keras():
+ try:
+ import keras # noqa
+
+ if hasattr(keras.utils, "Sequence"):
+ return (keras.utils.Sequence,), True
+ else:
+ return (keras.utils.data_utils.Sequence,), True
+ except ImportError:
+ return tuple(), False
+
+    def import_from_tensorflow():
+ try:
+ from tensorflow import keras
+
+ if hasattr(keras.utils, "Sequence"):
+ return (keras.utils.Sequence,), True
+ else:
+ return (keras.utils.data_utils.Sequence,), True
+ except ImportError:
+ return tuple(), False
+
+ ParentClassKeras, has_keras_k = import_from_keras()
+    ParentClassTensorflow, has_keras_tf = import_from_tensorflow()
+ has_keras = has_keras_k or has_keras_tf
+ if has_keras:
+ if has_keras_k:
+ ParentClass = ParentClassKeras
+ else:
+ ParentClass = ParentClassTensorflow
+ else:
+ ParentClass = (object,)
+ return ParentClass, has_keras
ParentClass, HAS_KERAS = import_keras()
-from scipy.sparse import issparse
-from sklearn.base import clone
-from sklearn.utils import _safe_indexing
-from sklearn.utils import check_random_state
-from ..tensorflow import balanced_batch_generator as tf_bbg
-from ..under_sampling import RandomUnderSampler
-from ..utils import Substitution
-from ..utils._docstring import _random_state_docstring
+from scipy.sparse import issparse # noqa
+from sklearn.base import clone # noqa
+from sklearn.utils import _safe_indexing # noqa
+from sklearn.utils import check_random_state # noqa
+
+from ..tensorflow import balanced_batch_generator as tf_bbg # noqa
+from ..under_sampling import RandomUnderSampler # noqa
+from ..utils import Substitution # noqa
+from ..utils._docstring import _random_state_docstring # noqa
-class BalancedBatchGenerator(*ParentClass):
+
+class BalancedBatchGenerator(*ParentClass): # type: ignore
"""Create balanced batches when training a keras model.
Create a keras ``Sequence`` which is given to ``fit``. The
@@ -96,10 +137,21 @@ class BalancedBatchGenerator(*ParentClass):
... X, y, sampler=NearMiss(), batch_size=10, random_state=42)
>>> callback_history = model.fit(training_generator, epochs=10, verbose=0)
"""
+
+ # flag for keras sequence duck-typing
use_sequence_api = True
- def __init__(self, X, y, *, sample_weight=None, sampler=None,
- batch_size=32, keep_sparse=False, random_state=None):
+ def __init__(
+ self,
+ X,
+ y,
+ *,
+ sample_weight=None,
+ sampler=None,
+ batch_size=32,
+ keep_sparse=False,
+ random_state=None,
+ ):
if not HAS_KERAS:
            raise ImportError("No module named 'keras'")
self.X = X
@@ -111,20 +163,39 @@ class BalancedBatchGenerator(*ParentClass):
self.random_state = random_state
self._sample()
+ def _sample(self):
+ random_state = check_random_state(self.random_state)
+ if self.sampler is None:
+ self.sampler_ = RandomUnderSampler(random_state=random_state)
+ else:
+ self.sampler_ = clone(self.sampler)
+ self.sampler_.fit_resample(self.X, self.y)
+ if not hasattr(self.sampler_, "sample_indices_"):
+ raise ValueError("'sampler' needs to have an attribute 'sample_indices_'.")
+ self.indices_ = self.sampler_.sample_indices_
+        # shuffle the indices since the sampler packs them by class
+ random_state.shuffle(self.indices_)
+
def __len__(self):
return int(self.indices_.size // self.batch_size)
def __getitem__(self, index):
- X_resampled = _safe_indexing(self.X, self.indices_[index * self.
- batch_size:(index + 1) * self.batch_size])
- y_resampled = _safe_indexing(self.y, self.indices_[index * self.
- batch_size:(index + 1) * self.batch_size])
+ X_resampled = _safe_indexing(
+ self.X,
+ self.indices_[index * self.batch_size : (index + 1) * self.batch_size],
+ )
+ y_resampled = _safe_indexing(
+ self.y,
+ self.indices_[index * self.batch_size : (index + 1) * self.batch_size],
+ )
if issparse(X_resampled) and not self.keep_sparse:
X_resampled = X_resampled.toarray()
if self.sample_weight is not None:
- sample_weight_resampled = _safe_indexing(self.sample_weight,
- self.indices_[index * self.batch_size:(index + 1) * self.
- batch_size])
+ sample_weight_resampled = _safe_indexing(
+ self.sample_weight,
+ self.indices_[index * self.batch_size : (index + 1) * self.batch_size],
+ )
+
if self.sample_weight is None:
return X_resampled, y_resampled
else:
@@ -132,8 +203,16 @@ class BalancedBatchGenerator(*ParentClass):
@Substitution(random_state=_random_state_docstring)
-def balanced_batch_generator(X, y, *, sample_weight=None, sampler=None,
- batch_size=32, keep_sparse=False, random_state=None):
+def balanced_batch_generator(
+ X,
+ y,
+ *,
+ sample_weight=None,
+ sampler=None,
+ batch_size=32,
+ keep_sparse=False,
+ random_state=None,
+):
"""Create a balanced batch generator to train keras model.
Returns a generator --- as well as the number of step per epoch --- which
@@ -204,4 +283,13 @@ def balanced_batch_generator(X, y, *, sample_weight=None, sampler=None,
... steps_per_epoch=steps_per_epoch,
... epochs=10, verbose=0)
"""
- pass
+
+ return tf_bbg(
+ X=X,
+ y=y,
+ sample_weight=sample_weight,
+ sampler=sampler,
+ batch_size=batch_size,
+ keep_sparse=keep_sparse,
+ random_state=random_state,
+ )
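
Stripped of the keras plumbing, the generator logic restored above boils down to: resample once, shuffle the resulting indices (the sampler packs them by class), then serve contiguous `batch_size` slices. A self-contained sketch of that arithmetic, mirroring `_sample`, `__len__`, and `__getitem__`:

```python
import numpy as np
from sklearn.datasets import make_classification
from imblearn.under_sampling import RandomUnderSampler

X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], random_state=0)

sampler = RandomUnderSampler(random_state=0)
sampler.fit_resample(X, y)
indices = sampler.sample_indices_            # balanced, grouped by class
np.random.RandomState(42).shuffle(indices)   # mix classes inside batches

batch_size = 32
n_batches = indices.size // batch_size       # the __len__ arithmetic
for index in range(n_batches):               # the __getitem__ slicing
    batch = indices[index * batch_size : (index + 1) * batch_size]
    X_batch, y_batch = X[batch], y[batch]
```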
diff --git a/imblearn/keras/tests/test_generator.py b/imblearn/keras/tests/test_generator.py
index 59ccc2c..a073d84 100644
--- a/imblearn/keras/tests/test_generator.py
+++ b/imblearn/keras/tests/test_generator.py
@@ -4,11 +4,144 @@ from scipy import sparse
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from sklearn.preprocessing import LabelBinarizer
-keras = pytest.importorskip('keras')
-from keras.layers import Dense
-from keras.models import Sequential
-from imblearn.datasets import make_imbalance
-from imblearn.keras import BalancedBatchGenerator, balanced_batch_generator
-from imblearn.over_sampling import RandomOverSampler
-from imblearn.under_sampling import ClusterCentroids, NearMiss
+
+keras = pytest.importorskip("keras")
+from keras.layers import Dense # noqa: E402
+from keras.models import Sequential # noqa: E402
+
+from imblearn.datasets import make_imbalance # noqa: E402
+from imblearn.keras import (
+ BalancedBatchGenerator, # noqa: E402
+ balanced_batch_generator, # noqa: E402
+)
+from imblearn.over_sampling import RandomOverSampler # noqa: E402
+from imblearn.under_sampling import (
+ ClusterCentroids, # noqa: E402
+ NearMiss, # noqa: E402
+)
+
+
+@pytest.fixture
+def data():
+ iris = load_iris()
+ X, y = make_imbalance(
+ iris.data, iris.target, sampling_strategy={0: 30, 1: 50, 2: 40}
+ )
+ y = LabelBinarizer().fit_transform(y)
+ return X, y
+
+
+def _build_keras_model(n_classes, n_features):
+ model = Sequential()
+ model.add(Dense(n_classes, input_dim=n_features, activation="softmax"))
+ model.compile(
+ optimizer="sgd", loss="categorical_crossentropy", metrics=["accuracy"]
+ )
+ return model
+
+
+def test_balanced_batch_generator_class_no_return_indices(data):
+ with pytest.raises(ValueError, match="needs to have an attribute"):
+ BalancedBatchGenerator(
+ *data, sampler=ClusterCentroids(estimator=KMeans(n_init=1)), batch_size=10
+ )
+
+
+@pytest.mark.filterwarnings("ignore:`wait_time` is not used") # keras 2.2.4
+@pytest.mark.parametrize(
+ "sampler, sample_weight",
+ [
+ (None, None),
+ (RandomOverSampler(), None),
+ (NearMiss(), None),
+ (None, np.random.uniform(size=120)),
+ ],
+)
+def test_balanced_batch_generator_class(data, sampler, sample_weight):
+ X, y = data
+ model = _build_keras_model(y.shape[1], X.shape[1])
+ training_generator = BalancedBatchGenerator(
+ X,
+ y,
+ sample_weight=sample_weight,
+ sampler=sampler,
+ batch_size=10,
+ random_state=42,
+ )
+ model.fit(training_generator, epochs=10)
+
+
+@pytest.mark.parametrize("keep_sparse", [True, False])
+def test_balanced_batch_generator_class_sparse(data, keep_sparse):
+ X, y = data
+ training_generator = BalancedBatchGenerator(
+ sparse.csr_matrix(X),
+ y,
+ batch_size=10,
+ keep_sparse=keep_sparse,
+ random_state=42,
+ )
+ for idx in range(len(training_generator)):
+ X_batch, _ = training_generator.__getitem__(idx)
+ if keep_sparse:
+ assert sparse.issparse(X_batch)
+ else:
+ assert not sparse.issparse(X_batch)
+
+
+def test_balanced_batch_generator_function_no_return_indices(data):
+ with pytest.raises(ValueError, match="needs to have an attribute"):
+ balanced_batch_generator(
+ *data,
+ sampler=ClusterCentroids(estimator=KMeans(n_init=10)),
+ batch_size=10,
+ random_state=42,
+ )
+
+
+@pytest.mark.filterwarnings("ignore:`wait_time` is not used") # keras 2.2.4
+@pytest.mark.parametrize(
+ "sampler, sample_weight",
+ [
+ (None, None),
+ (RandomOverSampler(), None),
+ (NearMiss(), None),
+ (None, np.random.uniform(size=120)),
+ ],
+)
+def test_balanced_batch_generator_function(data, sampler, sample_weight):
+ X, y = data
+ model = _build_keras_model(y.shape[1], X.shape[1])
+ training_generator, steps_per_epoch = balanced_batch_generator(
+ X,
+ y,
+ sample_weight=sample_weight,
+ sampler=sampler,
+ batch_size=10,
+ random_state=42,
+ )
+ model.fit(
+ training_generator,
+ steps_per_epoch=steps_per_epoch,
+ epochs=10,
+ )
+
+
+@pytest.mark.parametrize("keep_sparse", [True, False])
+def test_balanced_batch_generator_function_sparse(data, keep_sparse):
+ X, y = data
+ training_generator, steps_per_epoch = balanced_batch_generator(
+ sparse.csr_matrix(X),
+ y,
+ keep_sparse=keep_sparse,
+ batch_size=10,
+ random_state=42,
+ )
+ for _ in range(steps_per_epoch):
+ X_batch, _ = next(training_generator)
+ if keep_sparse:
+ assert sparse.issparse(X_batch)
+ else:
+ assert not sparse.issparse(X_batch)
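
The two `no_return_indices` tests above hinge on one property: the generators index into the original arrays, so the sampler must expose `sample_indices_`. Samplers that synthesize new samples, such as `ClusterCentroids`, cannot provide it. A quick check on an illustrative dataset:

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_classification
from imblearn.under_sampling import ClusterCentroids, RandomUnderSampler

X, y = make_classification(n_samples=200, weights=[0.8, 0.2], random_state=0)

rus = RandomUnderSampler(random_state=0)
rus.fit_resample(X, y)
print(hasattr(rus, "sample_indices_"))  # True: accepted by the generators

cc = ClusterCentroids(estimator=KMeans(n_init=1), random_state=0)
cc.fit_resample(X, y)
print(hasattr(cc, "sample_indices_"))   # False: rejected with a ValueError
```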
diff --git a/imblearn/metrics/_classification.py b/imblearn/metrics/_classification.py
index 6723b08..489066d 100644
--- a/imblearn/metrics/_classification.py
+++ b/imblearn/metrics/_classification.py
@@ -1,3 +1,4 @@
+# coding: utf-8
"""Metrics to assess performance on a classification task given class
predictions. The available metrics are complementary from the metrics available
in scikit-learn.
@@ -8,10 +9,16 @@ the better
Function named as ``*_error`` or ``*_loss`` return a scalar value to minimize:
the lower the better
"""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Dariusz Brzezinski
+# License: MIT
+
import functools
import numbers
import warnings
from inspect import signature
+
import numpy as np
import scipy as sp
from sklearn.metrics import mean_absolute_error, precision_recall_fscore_support
@@ -19,17 +26,35 @@ from sklearn.metrics._classification import _check_targets, _prf_divide
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import check_consistent_length, column_or_1d
+
from ..utils._param_validation import Interval, StrOptions, validate_params
-@validate_params({'y_true': ['array-like'], 'y_pred': ['array-like'],
- 'labels': ['array-like', None], 'pos_label': [str, numbers.Integral,
- None], 'average': [None, StrOptions({'binary', 'micro', 'macro',
- 'weighted', 'samples'})], 'warn_for': ['array-like'], 'sample_weight':
- ['array-like', None]}, prefer_skip_nested_validation=True)
-def sensitivity_specificity_support(y_true, y_pred, *, labels=None,
- pos_label=1, average=None, warn_for=('sensitivity', 'specificity'),
- sample_weight=None):
+@validate_params(
+ {
+ "y_true": ["array-like"],
+ "y_pred": ["array-like"],
+ "labels": ["array-like", None],
+ "pos_label": [str, numbers.Integral, None],
+ "average": [
+ None,
+ StrOptions({"binary", "micro", "macro", "weighted", "samples"}),
+ ],
+ "warn_for": ["array-like"],
+ "sample_weight": ["array-like", None],
+ },
+ prefer_skip_nested_validation=True,
+)
+def sensitivity_specificity_support(
+ y_true,
+ y_pred,
+ *,
+ labels=None,
+ pos_label=1,
+ average=None,
+ warn_for=("sensitivity", "specificity"),
+ sample_weight=None,
+):
"""Compute sensitivity, specificity, and support for each class.
The sensitivity is the ratio ``tp / (tp + fn)`` where ``tp`` is the number
@@ -106,13 +131,16 @@ def sensitivity_specificity_support(y_true, y_pred, *, labels=None,
Returns
-------
- sensitivity : float (if `average is None`) or ndarray of shape (n_unique_labels,)
+ sensitivity : float (if `average is None`) or ndarray of \
+ shape (n_unique_labels,)
The sensitivity metric.
- specificity : float (if `average is None`) or ndarray of shape (n_unique_labels,)
+ specificity : float (if `average is None`) or ndarray of \
+ shape (n_unique_labels,)
The specificity metric.
- support : int (if `average is None`) or ndarray of shape (n_unique_labels,)
+ support : int (if `average is None`) or ndarray of \
+ shape (n_unique_labels,)
The number of occurrences of each label in ``y_true``.
References
@@ -133,16 +161,163 @@ def sensitivity_specificity_support(y_true, y_pred, *, labels=None,
>>> sensitivity_specificity_support(y_true, y_pred, average='weighted')
(0.33..., 0.66..., None)
"""
- pass
-
-
-@validate_params({'y_true': ['array-like'], 'y_pred': ['array-like'],
- 'labels': ['array-like', None], 'pos_label': [str, numbers.Integral,
- None], 'average': [None, StrOptions({'binary', 'micro', 'macro',
- 'weighted', 'samples'})], 'sample_weight': ['array-like', None]},
- prefer_skip_nested_validation=True)
-def sensitivity_score(y_true, y_pred, *, labels=None, pos_label=1, average=
- 'binary', sample_weight=None):
+ average_options = (None, "micro", "macro", "weighted", "samples")
+ if average not in average_options and average != "binary":
+ raise ValueError("average has to be one of " + str(average_options))
+
+ y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+ present_labels = unique_labels(y_true, y_pred)
+
+ if average == "binary":
+ if y_type == "binary":
+ if pos_label not in present_labels:
+ if len(present_labels) < 2:
+ # Only negative labels
+ return (0.0, 0.0, 0)
+ else:
+ raise ValueError(
+ "pos_label=%r is not a valid label: %r"
+ % (pos_label, present_labels)
+ )
+ labels = [pos_label]
+ else:
+ raise ValueError(
+ "Target is %s but average='binary'. Please "
+ "choose another average setting." % y_type
+ )
+ elif pos_label not in (None, 1):
+ warnings.warn(
+ "Note that pos_label (set to %r) is ignored when "
+ "average != 'binary' (got %r). You may use "
+ "labels=[pos_label] to specify a single positive class."
+ % (pos_label, average),
+ UserWarning,
+ )
+
+ if labels is None:
+ labels = present_labels
+ n_labels = None
+ else:
+ n_labels = len(labels)
+ labels = np.hstack(
+ [labels, np.setdiff1d(present_labels, labels, assume_unique=True)]
+ )
+
+    # Calculate tp_sum, pred_sum, true_sum
+
+ if y_type.startswith("multilabel"):
+ raise ValueError("imblearn does not support multilabel")
+ elif average == "samples":
+ raise ValueError(
+ "Sample-based precision, recall, fscore is "
+ "not meaningful outside multilabel "
+ "classification. See the accuracy_score instead."
+ )
+ else:
+ le = LabelEncoder()
+ le.fit(labels)
+ y_true = le.transform(y_true)
+ y_pred = le.transform(y_pred)
+ sorted_labels = le.classes_
+
+ # labels are now from 0 to len(labels) - 1 -> use bincount
+ tp = y_true == y_pred
+ tp_bins = y_true[tp]
+ if sample_weight is not None:
+ tp_bins_weights = np.asarray(sample_weight)[tp]
+ else:
+ tp_bins_weights = None
+
+ if len(tp_bins):
+ tp_sum = np.bincount(
+ tp_bins, weights=tp_bins_weights, minlength=len(labels)
+ )
+ else:
+ # Pathological case
+ true_sum = pred_sum = tp_sum = np.zeros(len(labels))
+ if len(y_pred):
+ pred_sum = np.bincount(y_pred, weights=sample_weight, minlength=len(labels))
+ if len(y_true):
+ true_sum = np.bincount(y_true, weights=sample_weight, minlength=len(labels))
+
+ # Compute the true negative
+ tn_sum = y_true.size - (pred_sum + true_sum - tp_sum)
+
+ # Retain only selected labels
+ indices = np.searchsorted(sorted_labels, labels[:n_labels])
+ tp_sum = tp_sum[indices]
+ true_sum = true_sum[indices]
+ pred_sum = pred_sum[indices]
+ tn_sum = tn_sum[indices]
+
+ if average == "micro":
+ tp_sum = np.array([tp_sum.sum()])
+ pred_sum = np.array([pred_sum.sum()])
+ true_sum = np.array([true_sum.sum()])
+ tn_sum = np.array([tn_sum.sum()])
+
+    # Finally, we have all our sufficient statistics. Divide!
+
+ with np.errstate(divide="ignore", invalid="ignore"):
+ # Divide, and on zero-division, set scores to 0 and warn:
+
+ # Oddly, we may get an "invalid" rather than a "divide" error
+ # here.
+ specificity = _prf_divide(
+ tn_sum,
+ tn_sum + pred_sum - tp_sum,
+ "specificity",
+ "predicted",
+ average,
+ warn_for,
+ )
+ sensitivity = _prf_divide(
+ tp_sum, true_sum, "sensitivity", "true", average, warn_for
+ )
+
+ # Average the results
+
+ if average == "weighted":
+ weights = true_sum
+ if weights.sum() == 0:
+ return 0, 0, None
+ elif average == "samples":
+ weights = sample_weight
+ else:
+ weights = None
+
+ if average is not None:
+ assert average != "binary" or len(specificity) == 1
+ specificity = np.average(specificity, weights=weights)
+ sensitivity = np.average(sensitivity, weights=weights)
+ true_sum = None # return no support
+
+ return sensitivity, specificity, true_sum
+
+
+@validate_params(
+ {
+ "y_true": ["array-like"],
+ "y_pred": ["array-like"],
+ "labels": ["array-like", None],
+ "pos_label": [str, numbers.Integral, None],
+ "average": [
+ None,
+ StrOptions({"binary", "micro", "macro", "weighted", "samples"}),
+ ],
+ "sample_weight": ["array-like", None],
+ },
+ prefer_skip_nested_validation=True,
+)
+def sensitivity_score(
+ y_true,
+ y_pred,
+ *,
+ labels=None,
+ pos_label=1,
+ average="binary",
+ sample_weight=None,
+):
"""Compute the sensitivity.
The sensitivity is the ratio ``tp / (tp + fn)`` where ``tp`` is the number
@@ -204,7 +379,8 @@ def sensitivity_score(y_true, y_pred, *, labels=None, pos_label=1, average=
Returns
-------
- specificity : float (if `average is None`) or ndarray of shape (n_unique_labels,)
-        The specifcity metric.
+    sensitivity : float (if `average is None`) or ndarray of \
+        shape (n_unique_labels,)
+        The sensitivity metric.
Examples
@@ -222,16 +398,42 @@ def sensitivity_score(y_true, y_pred, *, labels=None, pos_label=1, average=
>>> sensitivity_score(y_true, y_pred, average=None)
array([1., 0., 0.])
"""
- pass
-
-
-@validate_params({'y_true': ['array-like'], 'y_pred': ['array-like'],
- 'labels': ['array-like', None], 'pos_label': [str, numbers.Integral,
- None], 'average': [None, StrOptions({'binary', 'micro', 'macro',
- 'weighted', 'samples'})], 'sample_weight': ['array-like', None]},
- prefer_skip_nested_validation=True)
-def specificity_score(y_true, y_pred, *, labels=None, pos_label=1, average=
- 'binary', sample_weight=None):
+ s, _, _ = sensitivity_specificity_support(
+ y_true,
+ y_pred,
+ labels=labels,
+ pos_label=pos_label,
+ average=average,
+ warn_for=("sensitivity",),
+ sample_weight=sample_weight,
+ )
+
+ return s
+
+
+@validate_params(
+ {
+ "y_true": ["array-like"],
+ "y_pred": ["array-like"],
+ "labels": ["array-like", None],
+ "pos_label": [str, numbers.Integral, None],
+ "average": [
+ None,
+ StrOptions({"binary", "micro", "macro", "weighted", "samples"}),
+ ],
+ "sample_weight": ["array-like", None],
+ },
+ prefer_skip_nested_validation=True,
+)
+def specificity_score(
+ y_true,
+ y_pred,
+ *,
+ labels=None,
+ pos_label=1,
+ average="binary",
+ sample_weight=None,
+):
"""Compute the specificity.
The specificity is the ratio ``tn / (tn + fp)`` where ``tn`` is the number
@@ -293,7 +495,8 @@ def specificity_score(y_true, y_pred, *, labels=None, pos_label=1, average=
Returns
-------
- specificity : float (if `average is None`) or ndarray of shape (n_unique_labels,)
+ specificity : float (if `average is None`) or ndarray of \
+ shape (n_unique_labels,)
The specificity metric.
Examples
@@ -311,17 +514,46 @@ def specificity_score(y_true, y_pred, *, labels=None, pos_label=1, average=
>>> specificity_score(y_true, y_pred, average=None)
array([0.75, 0.5 , 0.75])
"""
- pass
-
-
-@validate_params({'y_true': ['array-like'], 'y_pred': ['array-like'],
- 'labels': ['array-like', None], 'pos_label': [str, numbers.Integral,
- None], 'average': [None, StrOptions({'binary', 'micro', 'macro',
- 'weighted', 'samples', 'multiclass'})], 'sample_weight': ['array-like',
- None], 'correction': [Interval(numbers.Real, 0, None, closed='left')]},
- prefer_skip_nested_validation=True)
-def geometric_mean_score(y_true, y_pred, *, labels=None, pos_label=1,
- average='multiclass', sample_weight=None, correction=0.0):
+ _, s, _ = sensitivity_specificity_support(
+ y_true,
+ y_pred,
+ labels=labels,
+ pos_label=pos_label,
+ average=average,
+ warn_for=("specificity",),
+ sample_weight=sample_weight,
+ )
+
+ return s
+
+
+@validate_params(
+ {
+ "y_true": ["array-like"],
+ "y_pred": ["array-like"],
+ "labels": ["array-like", None],
+ "pos_label": [str, numbers.Integral, None],
+ "average": [
+ None,
+ StrOptions(
+ {"binary", "micro", "macro", "weighted", "samples", "multiclass"}
+ ),
+ ],
+ "sample_weight": ["array-like", None],
+ "correction": [Interval(numbers.Real, 0, None, closed="left")],
+ },
+ prefer_skip_nested_validation=True,
+)
+def geometric_mean_score(
+ y_true,
+ y_pred,
+ *,
+ labels=None,
+ pos_label=1,
+ average="multiclass",
+ sample_weight=None,
+ correction=0.0,
+):
"""Compute the geometric mean.
The geometric mean (G-mean) is the root of the product of class-wise
@@ -435,11 +667,76 @@ def geometric_mean_score(y_true, y_pred, *, labels=None, pos_label=1,
>>> geometric_mean_score(y_true, y_pred, average=None)
array([0.866..., 0. , 0. ])
"""
- pass
-
-
-@validate_params({'alpha': [numbers.Real], 'squared': ['boolean']},
- prefer_skip_nested_validation=True)
+ if average is None or average != "multiclass":
+ sen, spe, _ = sensitivity_specificity_support(
+ y_true,
+ y_pred,
+ labels=labels,
+ pos_label=pos_label,
+ average=average,
+ warn_for=("specificity", "specificity"),
+ sample_weight=sample_weight,
+ )
+
+ return np.sqrt(sen * spe)
+ else:
+ present_labels = unique_labels(y_true, y_pred)
+
+ if labels is None:
+ labels = present_labels
+ n_labels = None
+ else:
+ n_labels = len(labels)
+ labels = np.hstack(
+ [labels, np.setdiff1d(present_labels, labels, assume_unique=True)]
+ )
+
+ le = LabelEncoder()
+ le.fit(labels)
+ y_true = le.transform(y_true)
+ y_pred = le.transform(y_pred)
+ sorted_labels = le.classes_
+
+ # labels are now from 0 to len(labels) - 1 -> use bincount
+ tp = y_true == y_pred
+ tp_bins = y_true[tp]
+
+ if sample_weight is not None:
+ tp_bins_weights = np.asarray(sample_weight)[tp]
+ else:
+ tp_bins_weights = None
+
+ if len(tp_bins):
+ tp_sum = np.bincount(
+ tp_bins, weights=tp_bins_weights, minlength=len(labels)
+ )
+ else:
+ # Pathological case
+ true_sum = tp_sum = np.zeros(len(labels))
+ if len(y_true):
+ true_sum = np.bincount(y_true, weights=sample_weight, minlength=len(labels))
+
+ # Retain only selected labels
+ indices = np.searchsorted(sorted_labels, labels[:n_labels])
+ tp_sum = tp_sum[indices]
+ true_sum = true_sum[indices]
+
+ with np.errstate(divide="ignore", invalid="ignore"):
+ recall = _prf_divide(tp_sum, true_sum, "recall", "true", None, "recall")
+ recall[recall == 0] = correction
+
+ with np.errstate(divide="ignore", invalid="ignore"):
+ gmean = sp.stats.gmean(recall)
+ # old version of scipy return MaskedConstant instead of 0.0
+ if isinstance(gmean, np.ma.core.MaskedConstant):
+ return 0.0
+ return gmean
+
+
+@validate_params(
+ {"alpha": [numbers.Real], "squared": ["boolean"]},
+ prefer_skip_nested_validation=True,
+)
def make_index_balanced_accuracy(*, alpha=0.1, squared=True):
"""Balance any scoring function using the index balanced accuracy.
@@ -489,19 +786,91 @@ def make_index_balanced_accuracy(*, alpha=0.1, squared=True):
>>> print(gmean(y_true, y_pred, average=None))
[0.44... 0.44...]
"""
- pass
-
-
-@validate_params({'y_true': ['array-like'], 'y_pred': ['array-like'],
- 'labels': ['array-like', None], 'target_names': ['array-like', None],
- 'sample_weight': ['array-like', None], 'digits': [Interval(numbers.
- Integral, 0, None, closed='left')], 'alpha': [numbers.Real],
- 'output_dict': ['boolean'], 'zero_division': [StrOptions({'warn'}),
- Interval(numbers.Integral, 0, 1, closed='both')]},
- prefer_skip_nested_validation=True)
-def classification_report_imbalanced(y_true, y_pred, *, labels=None,
- target_names=None, sample_weight=None, digits=2, alpha=0.1, output_dict
- =False, zero_division='warn'):
+
+ def decorate(scoring_func):
+ @functools.wraps(scoring_func)
+ def compute_score(*args, **kwargs):
+ signature_scoring_func = signature(scoring_func)
+ params_scoring_func = set(signature_scoring_func.parameters.keys())
+
+ # check that the scoring function does not need a score
+ # and only a prediction
+            prohibited_y_pred = set(["y_score", "y_prob", "y2"])
+            if prohibited_y_pred.intersection(params_scoring_func):
+                raise AttributeError(
+                    f"The function {scoring_func.__name__} has an unsupported"
+                    " attribute. Metrics taking `y_pred` are the only"
+                    " supported metrics."
+                )
+
+ args_scoring_func = signature_scoring_func.bind(*args, **kwargs)
+ args_scoring_func.apply_defaults()
+ _score = scoring_func(*args_scoring_func.args, **args_scoring_func.kwargs)
+ if squared:
+ _score = np.power(_score, 2)
+
+ signature_sens_spec = signature(sensitivity_specificity_support)
+ params_sens_spec = set(signature_sens_spec.parameters.keys())
+ common_params = params_sens_spec.intersection(
+ set(args_scoring_func.arguments.keys())
+ )
+
+ args_sens_spec = {k: args_scoring_func.arguments[k] for k in common_params}
+
+ if scoring_func.__name__ == "geometric_mean_score":
+ if "average" in args_sens_spec:
+ if args_sens_spec["average"] == "multiclass":
+ args_sens_spec["average"] = "macro"
+ elif (
+ scoring_func.__name__ == "accuracy_score"
+ or scoring_func.__name__ == "jaccard_score"
+ ):
+ # We do not support multilabel so the only average supported
+ # is binary
+ args_sens_spec["average"] = "binary"
+
+ sensitivity, specificity, _ = sensitivity_specificity_support(
+ **args_sens_spec
+ )
+
+ dominance = sensitivity - specificity
+ return (1.0 + alpha * dominance) * _score
+
+ return compute_score
+
+ return decorate
+
+
+@validate_params(
+ {
+ "y_true": ["array-like"],
+ "y_pred": ["array-like"],
+ "labels": ["array-like", None],
+ "target_names": ["array-like", None],
+ "sample_weight": ["array-like", None],
+ "digits": [Interval(numbers.Integral, 0, None, closed="left")],
+ "alpha": [numbers.Real],
+ "output_dict": ["boolean"],
+ "zero_division": [
+ StrOptions({"warn"}),
+ Interval(numbers.Integral, 0, 1, closed="both"),
+ ],
+ },
+ prefer_skip_nested_validation=True,
+)
+def classification_report_imbalanced(
+ y_true,
+ y_pred,
+ *,
+ labels=None,
+ target_names=None,
+ sample_weight=None,
+ digits=2,
+ alpha=0.1,
+ output_dict=False,
+ zero_division="warn",
+):
"""Build a classification report based on metrics used with imbalanced dataset.
Specific metrics have been proposed to evaluate the classification
@@ -571,21 +940,140 @@ def classification_report_imbalanced(y_true, y_pred, *, labels=None,
>>> y_true = [0, 1, 2, 2, 2]
>>> y_pred = [0, 0, 2, 2, 1]
>>> target_names = ['class 0', 'class 1', 'class 2']
- >>> print(classification_report_imbalanced(y_true, y_pred, target_names=target_names))
- pre rec spe f1 geo iba sup
+ >>> print(classification_report_imbalanced(y_true, y_pred, \
+ target_names=target_names))
+ pre rec spe f1 geo iba\
+ sup
<BLANKLINE>
- class 0 0.50 1.00 0.75 0.67 0.87 0.77 1
- class 1 0.00 0.00 0.75 0.00 0.00 0.00 1
- class 2 1.00 0.67 1.00 0.80 0.82 0.64 3
+ class 0 0.50 1.00 0.75 0.67 0.87 0.77\
+ 1
+ class 1 0.00 0.00 0.75 0.00 0.00 0.00\
+ 1
+ class 2 1.00 0.67 1.00 0.80 0.82 0.64\
+ 3
<BLANKLINE>
- avg / total 0.70 0.60 0.90 0.61 0.66 0.54 5
+ avg / total 0.70 0.60 0.90 0.61 0.66 0.54\
+ 5
<BLANKLINE>
"""
- pass
-
-@validate_params({'y_true': ['array-like'], 'y_pred': ['array-like'],
- 'sample_weight': ['array-like', None]}, prefer_skip_nested_validation=True)
+ if labels is None:
+ labels = unique_labels(y_true, y_pred)
+ else:
+ labels = np.asarray(labels)
+
+ last_line_heading = "avg / total"
+
+ if target_names is None:
+ target_names = [f"{label}" for label in labels]
+ name_width = max(len(cn) for cn in target_names)
+ width = max(name_width, len(last_line_heading), digits)
+
+ headers = ["pre", "rec", "spe", "f1", "geo", "iba", "sup"]
+ fmt = "%% %ds" % width # first column: class name
+ fmt += " "
+ fmt += " ".join(["% 9s" for _ in headers])
+ fmt += "\n"
+
+ headers = [""] + headers
+ report = fmt % tuple(headers)
+ report += "\n"
+
+ # Compute the different metrics
+ # Precision/recall/f1
+ precision, recall, f1, support = precision_recall_fscore_support(
+ y_true,
+ y_pred,
+ labels=labels,
+ average=None,
+ sample_weight=sample_weight,
+ zero_division=zero_division,
+ )
+ # Specificity
+ specificity = specificity_score(
+ y_true,
+ y_pred,
+ labels=labels,
+ average=None,
+ sample_weight=sample_weight,
+ )
+ # Geometric mean
+ geo_mean = geometric_mean_score(
+ y_true,
+ y_pred,
+ labels=labels,
+ average=None,
+ sample_weight=sample_weight,
+ )
+ # Index balanced accuracy
+ iba_gmean = make_index_balanced_accuracy(alpha=alpha, squared=True)(
+ geometric_mean_score
+ )
+ iba = iba_gmean(
+ y_true,
+ y_pred,
+ labels=labels,
+ average=None,
+ sample_weight=sample_weight,
+ )
+
+ report_dict = {}
+ for i, label in enumerate(labels):
+ report_dict_label = {}
+ values = [target_names[i]]
+ for score_name, score_value in zip(
+ headers[1:-1],
+ [
+ precision[i],
+ recall[i],
+ specificity[i],
+ f1[i],
+ geo_mean[i],
+ iba[i],
+ ],
+ ):
+ values += ["{0:0.{1}f}".format(score_value, digits)]
+ report_dict_label[score_name] = score_value
+ values += [f"{support[i]}"]
+ report_dict_label[headers[-1]] = support[i]
+ report += fmt % tuple(values)
+
+ report_dict[target_names[i]] = report_dict_label
+
+ report += "\n"
+
+ # compute averages
+ values = [last_line_heading]
+ for score_name, score_value in zip(
+ headers[1:-1],
+ [
+ np.average(precision, weights=support),
+ np.average(recall, weights=support),
+ np.average(specificity, weights=support),
+ np.average(f1, weights=support),
+ np.average(geo_mean, weights=support),
+ np.average(iba, weights=support),
+ ],
+ ):
+ values += ["{0:0.{1}f}".format(score_value, digits)]
+ report_dict[f"avg_{score_name}"] = score_value
+ values += [f"{np.sum(support)}"]
+ report += fmt % tuple(values)
+ report_dict["total_support"] = np.sum(support)
+
+ if output_dict:
+ return report_dict
+ return report
+
+
+@validate_params(
+ {
+ "y_true": ["array-like"],
+ "y_pred": ["array-like"],
+ "sample_weight": ["array-like", None],
+ },
+ prefer_skip_nested_validation=True,
+)
def macro_averaged_mean_absolute_error(y_true, y_pred, *, sample_weight=None):
"""Compute Macro-Averaged MAE for imbalanced ordinal classification.
@@ -630,4 +1118,23 @@ def macro_averaged_mean_absolute_error(y_true, y_pred, *, sample_weight=None):
>>> macro_averaged_mean_absolute_error(y_true_imbalanced, y_pred)
0.16...
"""
- pass
+ _, y_true, y_pred = _check_targets(y_true, y_pred)
+ if sample_weight is not None:
+ sample_weight = column_or_1d(sample_weight)
+ else:
+ sample_weight = np.ones(y_true.shape)
+ check_consistent_length(y_true, y_pred, sample_weight)
+ labels = unique_labels(y_true, y_pred)
+ mae = []
+ for possible_class in labels:
+ indices = np.flatnonzero(y_true == possible_class)
+
+ mae.append(
+ mean_absolute_error(
+ y_true[indices],
+ y_pred[indices],
+ sample_weight=sample_weight[indices],
+ )
+ )
+
+ return np.sum(mae) / len(mae)
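
A hand-check of the formulas restored above on a tiny binary problem: sensitivity is `tp / (tp + fn)`, specificity is `tn / (tn + fp)`, and the binary G-mean is the square root of their product. The expected values below are worked out from the confusion counts, not taken from the library:

```python
import numpy as np
from imblearn.metrics import (
    geometric_mean_score,
    sensitivity_score,
    specificity_score,
)

y_true = [1, 1, 1, 1, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 0, 0, 1, 1]
# With pos_label=1: tp=3, fn=1, tn=2, fp=2
assert sensitivity_score(y_true, y_pred) == 3 / 4
assert specificity_score(y_true, y_pred) == 2 / 4
np.testing.assert_allclose(
    geometric_mean_score(y_true, y_pred), np.sqrt(3 / 4 * 2 / 4)
)
```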
diff --git a/imblearn/metrics/pairwise.py b/imblearn/metrics/pairwise.py
index d73edae..11f654f 100644
--- a/imblearn/metrics/pairwise.py
+++ b/imblearn/metrics/pairwise.py
@@ -1,24 +1,30 @@
"""Metrics to perform pairwise computation."""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# License: MIT
+
import numbers
+
import numpy as np
from scipy.spatial import distance_matrix
from sklearn.base import BaseEstimator
from sklearn.utils import check_consistent_length
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import check_is_fitted
+
from ..base import _ParamsValidationMixin
from ..utils._param_validation import StrOptions
class ValueDifferenceMetric(_ParamsValidationMixin, BaseEstimator):
- """Class implementing the Value Difference Metric.
+ r"""Class implementing the Value Difference Metric.
This metric computes the distance between samples containing only
categorical features. The distance between feature values of two samples is
defined as:
.. math::
- \\delta(x, y) = \\sum_{c=1}^{C} |p(c|x_{f}) - p(c|y_{f})|^{k} \\ ,
+ \delta(x, y) = \sum_{c=1}^{C} |p(c|x_{f}) - p(c|y_{f})|^{k} \ ,
where :math:`x` and :math:`y` are two samples and :math:`f` a given
feature, :math:`C` is the number of classes, :math:`p(c|x_{f})` is the
@@ -30,7 +36,7 @@ class ValueDifferenceMetric(_ParamsValidationMixin, BaseEstimator):
subsequently defined as:
.. math::
- \\Delta(X, Y) = \\sum_{f=1}^{F} \\delta(X_{f}, Y_{f})^{r} \\ ,
+ \Delta(X, Y) = \sum_{f=1}^{F} \delta(X_{f}, Y_{f})^{r} \ ,
where :math:`F` is the number of feature and :math:`r` an exponent usually
defined equal to 1 or 2.
@@ -112,10 +118,13 @@ class ValueDifferenceMetric(_ParamsValidationMixin, BaseEstimator):
[0.04, 0. , 1.44],
[1.96, 1.44, 0. ]])
"""
- _parameter_constraints: dict = {'n_categories': [StrOptions({'auto'}),
- 'array-like'], 'k': [numbers.Integral], 'r': [numbers.Integral]}
+ _parameter_constraints: dict = {
+ "n_categories": [StrOptions({"auto"}), "array-like"],
+ "k": [numbers.Integral],
+ "r": [numbers.Integral],
+ }
- def __init__(self, *, n_categories='auto', k=1, r=2):
+ def __init__(self, *, n_categories="auto", k=1, r=2):
self.n_categories = n_categories
self.k = k
self.r = r
@@ -137,7 +146,47 @@ class ValueDifferenceMetric(_ParamsValidationMixin, BaseEstimator):
self : object
Return the instance itself.
"""
- pass
+ self._validate_params()
+ check_consistent_length(X, y)
+ X, y = self._validate_data(X, y, reset=True, dtype=np.int32)
+
+ if isinstance(self.n_categories, str) and self.n_categories == "auto":
+ # categories are expected to be encoded from 0 to n_categories - 1
+ self.n_categories_ = X.max(axis=0) + 1
+ else:
+ if len(self.n_categories) != self.n_features_in_:
+ raise ValueError(
+ f"The length of n_categories is not consistent with the "
+ f"number of feature in X. Got {len(self.n_categories)} "
+ f"elements in n_categories and {self.n_features_in_} in "
+ f"X."
+ )
+ self.n_categories_ = np.array(self.n_categories, copy=False)
+ classes = unique_labels(y)
+
+ # list of length n_features of ndarray (n_categories, n_classes)
+ # compute the counts
+ self.proba_per_class_ = [
+ np.empty(shape=(n_cat, len(classes)), dtype=np.float64)
+ for n_cat in self.n_categories_
+ ]
+ for feature_idx in range(self.n_features_in_):
+ for klass_idx, klass in enumerate(classes):
+ self.proba_per_class_[feature_idx][:, klass_idx] = np.bincount(
+ X[y == klass, feature_idx],
+ minlength=self.n_categories_[feature_idx],
+ )
+
+        # normalize by summing over the classes
+ with np.errstate(invalid="ignore"):
+ # silence potential warning due to in-place division by zero
+ for feature_idx in range(self.n_features_in_):
+ self.proba_per_class_[feature_idx] /= (
+ self.proba_per_class_[feature_idx].sum(axis=1).reshape(-1, 1)
+ )
+ np.nan_to_num(self.proba_per_class_[feature_idx], copy=False)
+
+ return self
def pairwise(self, X, Y=None):
"""Compute the VDM distance pairwise.
@@ -157,4 +206,29 @@ class ValueDifferenceMetric(_ParamsValidationMixin, BaseEstimator):
distance_matrix : ndarray of shape (n_samples, n_samples)
The VDM pairwise distance.
"""
- pass
+ check_is_fitted(self)
+ X = self._validate_data(X, reset=False, dtype=np.int32)
+ n_samples_X = X.shape[0]
+
+ if Y is not None:
+ Y = self._validate_data(Y, reset=False, dtype=np.int32)
+ n_samples_Y = Y.shape[0]
+ else:
+ n_samples_Y = n_samples_X
+
+ distance = np.zeros(shape=(n_samples_X, n_samples_Y), dtype=np.float64)
+ for feature_idx in range(self.n_features_in_):
+ proba_feature_X = self.proba_per_class_[feature_idx][X[:, feature_idx]]
+ if Y is not None:
+ proba_feature_Y = self.proba_per_class_[feature_idx][Y[:, feature_idx]]
+ else:
+ proba_feature_Y = proba_feature_X
+ distance += (
+ distance_matrix(proba_feature_X, proba_feature_Y, p=self.k) ** self.r
+ )
+ return distance
+
+ def _more_tags(self):
+ return {
+ "requires_positive_X": True, # X should be encoded with OrdinalEncoder
+ }
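
To make the restored VDM concrete: with the defaults `k=1, r=2`, the distance between two category values is the squared L1 distance between their per-class conditional probability vectors. A worked instance (the fruit-style data is illustrative, in the spirit of the docstring example):

```python
import numpy as np
from sklearn.preprocessing import OrdinalEncoder
from imblearn.metrics.pairwise import ValueDifferenceMetric

X = np.array(["green"] * 10 + ["red"] * 10 + ["blue"] * 10).reshape(-1, 1)
y = np.array([1] * 8 + [0] * 2 + [1] * 7 + [0] * 3 + [0] * 10)

encoder = OrdinalEncoder(dtype=np.int32)
X_encoded = encoder.fit_transform(X)
vdm = ValueDifferenceMetric().fit(X_encoded, y)

# p(class 1 | green) = 0.8 and p(class 1 | red) = 0.7, so delta(green, red)
# = (|0.2 - 0.3| + |0.8 - 0.7|) ** 2 = 0.04
dist = vdm.pairwise(
    encoder.transform([["green"]]), encoder.transform([["red"]])
)
np.testing.assert_allclose(dist, [[0.04]])
```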
diff --git a/imblearn/metrics/tests/test_classification.py b/imblearn/metrics/tests/test_classification.py
index 7fce78f..8169cee 100644
--- a/imblearn/metrics/tests/test_classification.py
+++ b/imblearn/metrics/tests/test_classification.py
@@ -1,15 +1,47 @@
+# coding: utf-8
"""Testing the metric for classification with imbalanced dataset"""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
from functools import partial
+
import numpy as np
import pytest
from sklearn import datasets, svm
-from sklearn.metrics import accuracy_score, average_precision_score, brier_score_loss, cohen_kappa_score, jaccard_score, precision_score, recall_score, roc_auc_score
+from sklearn.metrics import (
+ accuracy_score,
+ average_precision_score,
+ brier_score_loss,
+ cohen_kappa_score,
+ jaccard_score,
+ precision_score,
+ recall_score,
+ roc_auc_score,
+)
from sklearn.preprocessing import label_binarize
-from sklearn.utils._testing import assert_allclose, assert_array_equal, assert_no_warnings
+from sklearn.utils._testing import (
+ assert_allclose,
+ assert_array_equal,
+ assert_no_warnings,
+)
from sklearn.utils.validation import check_random_state
-from imblearn.metrics import classification_report_imbalanced, geometric_mean_score, macro_averaged_mean_absolute_error, make_index_balanced_accuracy, sensitivity_score, sensitivity_specificity_support, specificity_score
+
+from imblearn.metrics import (
+ classification_report_imbalanced,
+ geometric_mean_score,
+ macro_averaged_mean_absolute_error,
+ make_index_balanced_accuracy,
+ sensitivity_score,
+ sensitivity_specificity_support,
+ specificity_score,
+)
+
RND_SEED = 42
-R_TOL = 0.01
+R_TOL = 1e-2
+
+###############################################################################
+# Utilities for testing
def make_prediction(dataset=None, binary=False):
@@ -17,4 +49,505 @@ def make_prediction(dataset=None, binary=False):
If binary is True restrict to a binary classification problem instead of a
multiclass classification problem
"""
- pass
+
+ if dataset is None:
+ # import some data to play with
+ dataset = datasets.load_iris()
+
+ X = dataset.data
+ y = dataset.target
+
+ if binary:
+ # restrict to a binary classification task
+ X, y = X[y < 2], y[y < 2]
+
+ n_samples, n_features = X.shape
+ p = np.arange(n_samples)
+
+ rng = check_random_state(37)
+ rng.shuffle(p)
+ X, y = X[p], y[p]
+ half = int(n_samples / 2)
+
+ # add noisy features to make the problem harder and avoid perfect results
+ rng = np.random.RandomState(0)
+ X = np.c_[X, rng.randn(n_samples, 200 * n_features)]
+
+ # run classifier, get class probabilities and label predictions
+ clf = svm.SVC(kernel="linear", probability=True, random_state=0)
+ probas_pred = clf.fit(X[:half], y[:half]).predict_proba(X[half:])
+
+ if binary:
+ # only interested in probabilities of the positive case
+ # XXX: do we really want a special API for the binary case?
+ probas_pred = probas_pred[:, 1]
+
+ y_pred = clf.predict(X[half:])
+ y_true = y[half:]
+
+ return y_true, y_pred, probas_pred
+
+
+###############################################################################
+# Tests
+
+
+def test_sensitivity_specificity_score_binary():
+ y_true, y_pred, _ = make_prediction(binary=True)
+
+ # detailed measures for each class
+ sen, spe, sup = sensitivity_specificity_support(y_true, y_pred, average=None)
+ assert_allclose(sen, [0.88, 0.68], rtol=R_TOL)
+ assert_allclose(spe, [0.68, 0.88], rtol=R_TOL)
+ assert_array_equal(sup, [25, 25])
+
+ # individual scoring function that can be used for grid search: in the
+ # binary class case the score is the value of the measure for the positive
+ # class (e.g. label == 1). This is deprecated for average != 'binary'.
+ for kwargs in ({}, {"average": "binary"}):
+ sen = assert_no_warnings(sensitivity_score, y_true, y_pred, **kwargs)
+ assert sen == pytest.approx(0.68, rel=R_TOL)
+
+ spe = assert_no_warnings(specificity_score, y_true, y_pred, **kwargs)
+ assert spe == pytest.approx(0.88, rel=R_TOL)
+
+
+@pytest.mark.filterwarnings("ignore:Specificity is ill-defined")
+@pytest.mark.parametrize(
+ "y_pred, expected_sensitivity, expected_specificity",
+ [(([1, 1], [1, 1]), 1.0, 0.0), (([-1, -1], [-1, -1]), 0.0, 0.0)],
+)
+def test_sensitivity_specificity_f_binary_single_class(
+ y_pred, expected_sensitivity, expected_specificity
+):
+ # Such a case may occur with non-stratified cross-validation
+ assert sensitivity_score(*y_pred) == expected_sensitivity
+ assert specificity_score(*y_pred) == expected_specificity
+
+
+@pytest.mark.parametrize(
+ "average, expected_specificty",
+ [
+ (None, [1.0, 0.67, 1.0, 1.0, 1.0]),
+ ("macro", np.mean([1.0, 0.67, 1.0, 1.0, 1.0])),
+ ("micro", 15 / 16),
+ ],
+)
+def test_sensitivity_specificity_extra_labels(average, expected_specificty):
+ y_true = [1, 3, 3, 2]
+ y_pred = [1, 1, 3, 2]
+
+ actual = specificity_score(y_true, y_pred, labels=[0, 1, 2, 3, 4], average=average)
+ assert_allclose(expected_specificty, actual, rtol=R_TOL)
+
+
+def test_sensitivity_specificity_ignored_labels():
+ y_true = [1, 1, 2, 3]
+ y_pred = [1, 3, 3, 3]
+
+ specificity_13 = partial(specificity_score, y_true, y_pred, labels=[1, 3])
+ specificity_all = partial(specificity_score, y_true, y_pred, labels=None)
+
+ assert_allclose([1.0, 0.33], specificity_13(average=None), rtol=R_TOL)
+ assert_allclose(np.mean([1.0, 0.33]), specificity_13(average="macro"), rtol=R_TOL)
+ assert_allclose(
+ np.average([1.0, 0.33], weights=[2.0, 1.0]),
+ specificity_13(average="weighted"),
+ rtol=R_TOL,
+ )
+ assert_allclose(3.0 / (3.0 + 2.0), specificity_13(average="micro"), rtol=R_TOL)
+
+ # ensure the above were meaningful tests:
+ for each in ["macro", "weighted", "micro"]:
+ assert specificity_13(average=each) != specificity_all(average=each)
+
+
+def test_sensitivity_specificity_error_multilabels():
+ y_true = [1, 3, 3, 2]
+ y_pred = [1, 1, 3, 2]
+ y_true_bin = label_binarize(y_true, classes=np.arange(5))
+ y_pred_bin = label_binarize(y_pred, classes=np.arange(5))
+
+ with pytest.raises(ValueError):
+ sensitivity_score(y_true_bin, y_pred_bin)
+
+
+def test_sensitivity_specificity_support_errors():
+ y_true, y_pred, _ = make_prediction(binary=True)
+
+ # Bad pos_label
+ with pytest.raises(ValueError):
+ sensitivity_specificity_support(y_true, y_pred, pos_label=2, average="binary")
+
+ # Bad average option
+ with pytest.raises(ValueError):
+ sensitivity_specificity_support([0, 1, 2], [1, 2, 0], average="mega")
+
+
+def test_sensitivity_specificity_unused_pos_label():
+ # but average != 'binary'; even if data is binary
+ msg = r"use labels=\[pos_label\] to specify a single"
+ with pytest.warns(UserWarning, match=msg):
+ sensitivity_specificity_support(
+ [1, 2, 1], [1, 2, 2], pos_label=2, average="macro"
+ )
+
+
+def test_geometric_mean_support_binary():
+ y_true, y_pred, _ = make_prediction(binary=True)
+
+ # compute the geometric mean for the binary problem
+ geo_mean = geometric_mean_score(y_true, y_pred)
+
+ assert_allclose(geo_mean, 0.77, rtol=R_TOL)
+
+
+@pytest.mark.filterwarnings("ignore:Recall is ill-defined")
+@pytest.mark.parametrize(
+ "y_true, y_pred, correction, expected_gmean",
+ [
+ ([0, 0, 1, 1], [0, 0, 1, 1], 0.0, 1.0),
+ ([0, 0, 0, 0], [1, 1, 1, 1], 0.0, 0.0),
+ ([0, 0, 0, 0], [0, 0, 0, 0], 0.001, 1.0),
+ ([0, 0, 0, 0], [1, 1, 1, 1], 0.001, 0.001),
+ ([0, 0, 1, 1], [0, 1, 1, 0], 0.001, 0.5),
+ (
+ [0, 1, 2, 0, 1, 2],
+ [0, 2, 1, 0, 0, 1],
+ 0.001,
+ (0.001**2) ** (1 / 3),
+ ),
+ ([0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5], 0.001, 1),
+ ([0, 1, 1, 1, 1, 0], [0, 0, 1, 1, 1, 1], 0.001, (0.5 * 0.75) ** 0.5),
+ ],
+)
+def test_geometric_mean_multiclass(y_true, y_pred, correction, expected_gmean):
+ gmean = geometric_mean_score(y_true, y_pred, correction=correction)
+ assert gmean == pytest.approx(expected_gmean, rel=R_TOL)
+
+
+@pytest.mark.filterwarnings("ignore:Recall is ill-defined")
+@pytest.mark.parametrize(
+ "y_true, y_pred, average, expected_gmean",
+ [
+ ([0, 1, 2, 0, 1, 2], [0, 2, 1, 0, 0, 1], "macro", 0.471),
+ ([0, 1, 2, 0, 1, 2], [0, 2, 1, 0, 0, 1], "micro", 0.471),
+ ([0, 1, 2, 0, 1, 2], [0, 2, 1, 0, 0, 1], "weighted", 0.471),
+ ([0, 1, 2, 0, 1, 2], [0, 2, 1, 0, 0, 1], None, [0.8660254, 0.0, 0.0]),
+ ],
+)
+def test_geometric_mean_average(y_true, y_pred, average, expected_gmean):
+ gmean = geometric_mean_score(y_true, y_pred, average=average)
+ assert gmean == pytest.approx(expected_gmean, rel=R_TOL)
+
+
+@pytest.mark.parametrize(
+ "y_true, y_pred, sample_weight, average, expected_gmean",
+ [
+ ([0, 1, 2, 0, 1, 2], [0, 1, 1, 0, 0, 1], None, "multiclass", 0.707),
+ (
+ [0, 1, 2, 0, 1, 2],
+ [0, 1, 1, 0, 0, 1],
+ [1, 2, 1, 1, 2, 1],
+ "multiclass",
+ 0.707,
+ ),
+ (
+ [0, 1, 2, 0, 1, 2],
+ [0, 1, 1, 0, 0, 1],
+ [1, 2, 1, 1, 2, 1],
+ "weighted",
+ 0.333,
+ ),
+ ],
+)
+def test_geometric_mean_sample_weight(
+ y_true, y_pred, sample_weight, average, expected_gmean
+):
+ gmean = geometric_mean_score(
+ y_true,
+ y_pred,
+ labels=[0, 1],
+ sample_weight=sample_weight,
+ average=average,
+ )
+ assert gmean == pytest.approx(expected_gmean, rel=R_TOL)
+
+
+@pytest.mark.parametrize(
+ "average, expected_gmean",
+ [
+ ("multiclass", 0.41),
+ (None, [0.85, 0.29, 0.7]),
+ ("macro", 0.68),
+ ("weighted", 0.65),
+ ],
+)
+def test_geometric_mean_score_prediction(average, expected_gmean):
+ y_true, y_pred, _ = make_prediction(binary=False)
+
+ gmean = geometric_mean_score(y_true, y_pred, average=average)
+ assert gmean == pytest.approx(expected_gmean, rel=R_TOL)
+
+
+def test_iba_geo_mean_binary():
+ y_true, y_pred, _ = make_prediction(binary=True)
+
+ iba_gmean = make_index_balanced_accuracy(alpha=0.5, squared=True)(
+ geometric_mean_score
+ )
+ iba = iba_gmean(y_true, y_pred)
+
+ assert_allclose(iba, 0.5948, rtol=R_TOL)
+
+
+def _format_report(report):
+ return " ".join(report.split())
+
+
+def test_classification_report_imbalanced_multiclass():
+ iris = datasets.load_iris()
+ y_true, y_pred, _ = make_prediction(dataset=iris, binary=False)
+
+ # print classification report with class names
+ expected_report = (
+ "pre rec spe f1 geo iba sup setosa 0.83 0.79 0.92 "
+ "0.81 0.85 0.72 24 versicolor 0.33 0.10 0.86 0.15 "
+ "0.29 0.08 31 virginica 0.42 0.90 0.55 0.57 0.70 "
+ "0.51 20 avg / total 0.51 0.53 0.80 0.47 0.58 0.40 75"
+ )
+
+ report = classification_report_imbalanced(
+ y_true,
+ y_pred,
+ labels=np.arange(len(iris.target_names)),
+ target_names=iris.target_names,
+ )
+ assert _format_report(report) == expected_report
+ # print classification report with label detection
+ expected_report = (
+ "pre rec spe f1 geo iba sup 0 0.83 0.79 0.92 0.81 "
+ "0.85 0.72 24 1 0.33 0.10 0.86 0.15 0.29 0.08 31 "
+ "2 0.42 0.90 0.55 0.57 0.70 0.51 20 avg / total "
+ "0.51 0.53 0.80 0.47 0.58 0.40 75"
+ )
+
+ report = classification_report_imbalanced(y_true, y_pred)
+ assert _format_report(report) == expected_report
+
+
+def test_classification_report_imbalanced_multiclass_with_digits():
+ iris = datasets.load_iris()
+ y_true, y_pred, _ = make_prediction(dataset=iris, binary=False)
+
+ # print classification report with class names
+ expected_report = (
+ "pre rec spe f1 geo iba sup setosa 0.82609 0.79167 "
+ "0.92157 0.80851 0.85415 0.72010 24 versicolor "
+ "0.33333 0.09677 0.86364 0.15000 0.28910 0.07717 "
+ "31 virginica 0.41860 0.90000 0.54545 0.57143 0.70065 "
+ "0.50831 20 avg / total 0.51375 0.53333 0.79733 "
+ "0.47310 0.57966 0.39788 75"
+ )
+ report = classification_report_imbalanced(
+ y_true,
+ y_pred,
+ labels=np.arange(len(iris.target_names)),
+ target_names=iris.target_names,
+ digits=5,
+ )
+ assert _format_report(report) == expected_report
+ # print classification report with label detection
+ expected_report = (
+ "pre rec spe f1 geo iba sup 0 0.83 0.79 0.92 0.81 "
+ "0.85 0.72 24 1 0.33 0.10 0.86 0.15 0.29 0.08 31 "
+ "2 0.42 0.90 0.55 0.57 0.70 0.51 20 avg / total 0.51 "
+ "0.53 0.80 0.47 0.58 0.40 75"
+ )
+ report = classification_report_imbalanced(y_true, y_pred)
+ assert _format_report(report) == expected_report
+
+
+def test_classification_report_imbalanced_multiclass_with_string_label():
+ y_true, y_pred, _ = make_prediction(binary=False)
+
+ y_true = np.array(["blue", "green", "red"])[y_true]
+ y_pred = np.array(["blue", "green", "red"])[y_pred]
+
+ expected_report = (
+ "pre rec spe f1 geo iba sup blue 0.83 0.79 0.92 0.81 "
+ "0.85 0.72 24 green 0.33 0.10 0.86 0.15 0.29 0.08 31 "
+ "red 0.42 0.90 0.55 0.57 0.70 0.51 20 avg / total "
+ "0.51 0.53 0.80 0.47 0.58 0.40 75"
+ )
+ report = classification_report_imbalanced(y_true, y_pred)
+ assert _format_report(report) == expected_report
+
+ expected_report = (
+ "pre rec spe f1 geo iba sup a 0.83 0.79 0.92 0.81 0.85 "
+ "0.72 24 b 0.33 0.10 0.86 0.15 0.29 0.08 31 c 0.42 "
+ "0.90 0.55 0.57 0.70 0.51 20 avg / total 0.51 0.53 "
+ "0.80 0.47 0.58 0.40 75"
+ )
+ report = classification_report_imbalanced(
+ y_true, y_pred, target_names=["a", "b", "c"]
+ )
+ assert _format_report(report) == expected_report
+
+
+def test_classification_report_imbalanced_multiclass_with_unicode_label():
+ y_true, y_pred, _ = make_prediction(binary=False)
+
+ labels = np.array(["blue\xa2", "green\xa2", "red\xa2"])
+ y_true = labels[y_true]
+ y_pred = labels[y_pred]
+
+ expected_report = (
+ "pre rec spe f1 geo iba sup blue¢ 0.83 0.79 0.92 0.81 "
+ "0.85 0.72 24 green¢ 0.33 0.10 0.86 0.15 0.29 0.08 31 "
+ "red¢ 0.42 0.90 0.55 0.57 0.70 0.51 20 avg / total "
+ "0.51 0.53 0.80 0.47 0.58 0.40 75"
+ )
+ report = classification_report_imbalanced(y_true, y_pred)
+ assert _format_report(report) == expected_report
+
+
+def test_classification_report_imbalanced_multiclass_with_long_string_label():
+ y_true, y_pred, _ = make_prediction(binary=False)
+
+ labels = np.array(["blue", "green" * 5, "red"])
+ y_true = labels[y_true]
+ y_pred = labels[y_pred]
+
+ expected_report = (
+ "pre rec spe f1 geo iba sup blue 0.83 0.79 0.92 0.81 "
+ "0.85 0.72 24 greengreengreengreengreen 0.33 0.10 "
+ "0.86 0.15 0.29 0.08 31 red 0.42 0.90 0.55 0.57 0.70 "
+ "0.51 20 avg / total 0.51 0.53 0.80 0.47 0.58 0.40 75"
+ )
+
+ report = classification_report_imbalanced(y_true, y_pred)
+ assert _format_report(report) == expected_report
+
+
+@pytest.mark.parametrize(
+ "score, expected_score",
+ [
+ (accuracy_score, 0.54756),
+ (jaccard_score, 0.33176),
+ (precision_score, 0.65025),
+ (recall_score, 0.41616),
+ ],
+)
+def test_iba_sklearn_metrics(score, expected_score):
+ y_true, y_pred, _ = make_prediction(binary=True)
+
+ score_iba = make_index_balanced_accuracy(alpha=0.5, squared=True)(score)
+ score = score_iba(y_true, y_pred)
+ assert score == pytest.approx(expected_score)
+
+
+@pytest.mark.parametrize(
+ "score_loss",
+ [average_precision_score, brier_score_loss, cohen_kappa_score, roc_auc_score],
+)
+def test_iba_error_y_score_prob_error(score_loss):
+ y_true, y_pred, _ = make_prediction(binary=True)
+
+ aps = make_index_balanced_accuracy(alpha=0.5, squared=True)(score_loss)
+ with pytest.raises(AttributeError):
+ aps(y_true, y_pred)
+
+
+def test_classification_report_imbalanced_dict_with_target_names():
+ iris = datasets.load_iris()
+ y_true, y_pred, _ = make_prediction(dataset=iris, binary=False)
+
+ report = classification_report_imbalanced(
+ y_true,
+ y_pred,
+ labels=np.arange(len(iris.target_names)),
+ target_names=iris.target_names,
+ output_dict=True,
+ )
+ outer_keys = set(report.keys())
+ inner_keys = set(report["setosa"].keys())
+
+ expected_outer_keys = {
+ "setosa",
+ "versicolor",
+ "virginica",
+ "avg_pre",
+ "avg_rec",
+ "avg_spe",
+ "avg_f1",
+ "avg_geo",
+ "avg_iba",
+ "total_support",
+ }
+ expected_inner_keys = {"spe", "f1", "sup", "rec", "geo", "iba", "pre"}
+
+ assert outer_keys == expected_outer_keys
+ assert inner_keys == expected_inner_keys
+
+
+def test_classification_report_imbalanced_dict_without_target_names():
+ iris = datasets.load_iris()
+ y_true, y_pred, _ = make_prediction(dataset=iris, binary=False)
+ report = classification_report_imbalanced(
+ y_true,
+ y_pred,
+ labels=np.arange(len(iris.target_names)),
+ output_dict=True,
+ )
+ outer_keys = set(report.keys())
+ inner_keys = set(report["0"].keys())
+
+ expected_outer_keys = {
+ "0",
+ "1",
+ "2",
+ "avg_pre",
+ "avg_rec",
+ "avg_spe",
+ "avg_f1",
+ "avg_geo",
+ "avg_iba",
+ "total_support",
+ }
+ expected_inner_keys = {"spe", "f1", "sup", "rec", "geo", "iba", "pre"}
+
+ assert outer_keys == expected_outer_keys
+ assert inner_keys == expected_inner_keys
+
+
+@pytest.mark.parametrize(
+ "y_true, y_pred, expected_ma_mae",
+ [
+ ([1, 1, 1, 2, 2, 2], [1, 2, 1, 2, 1, 2], 0.333),
+ ([1, 1, 1, 1, 1, 2], [1, 2, 1, 2, 1, 2], 0.2),
+ ([1, 1, 1, 2, 2, 2, 3, 3, 3], [1, 3, 1, 2, 1, 1, 2, 3, 3], 0.555),
+ ([1, 1, 1, 1, 1, 1, 2, 3, 3], [1, 3, 1, 2, 1, 1, 2, 3, 3], 0.166),
+ ],
+)
+def test_macro_averaged_mean_absolute_error(y_true, y_pred, expected_ma_mae):
+ ma_mae = macro_averaged_mean_absolute_error(y_true, y_pred)
+ assert ma_mae == pytest.approx(expected_ma_mae, rel=R_TOL)
+
+
+def test_macro_averaged_mean_absolute_error_sample_weight():
+ y_true = [1, 1, 1, 2, 2, 2]
+ y_pred = [1, 2, 1, 2, 1, 2]
+
+ ma_mae_no_weights = macro_averaged_mean_absolute_error(y_true, y_pred)
+
+ sample_weight = [1, 1, 1, 1, 1, 1]
+ ma_mae_unit_weights = macro_averaged_mean_absolute_error(
+ y_true,
+ y_pred,
+ sample_weight=sample_weight,
+ )
+
+ assert ma_mae_unit_weights == pytest.approx(ma_mae_no_weights)
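
As a worked instance of the macro-averaged MAE cases parametrized above: the MAE is computed per true class and the per-class values are then averaged, so the majority class cannot dominate the score:

```python
import numpy as np
from imblearn.metrics import macro_averaged_mean_absolute_error

y_true = [1, 1, 1, 2, 2, 2]
y_pred = [1, 2, 1, 2, 1, 2]
# class 1: errors |1-1|, |1-2|, |1-1| -> MAE = 1/3
# class 2: errors |2-2|, |2-1|, |2-2| -> MAE = 1/3
# macro average = (1/3 + 1/3) / 2 = 1/3
np.testing.assert_allclose(
    macro_averaged_mean_absolute_error(y_true, y_pred), 1 / 3
)
```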
diff --git a/imblearn/metrics/tests/test_pairwise.py b/imblearn/metrics/tests/test_pairwise.py
index ccbfede..d724591 100644
--- a/imblearn/metrics/tests/test_pairwise.py
+++ b/imblearn/metrics/tests/test_pairwise.py
@@ -1,7 +1,172 @@
"""Test for the metrics that perform pairwise distance computation."""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# License: MIT
+
import numpy as np
import pytest
from sklearn.exceptions import NotFittedError
from sklearn.preprocessing import LabelEncoder, OrdinalEncoder
from sklearn.utils._testing import _convert_container
+
from imblearn.metrics.pairwise import ValueDifferenceMetric
+
+
+@pytest.fixture
+def data():
+ rng = np.random.RandomState(0)
+
+ feature_1 = ["A"] * 10 + ["B"] * 20 + ["C"] * 30
+ feature_2 = ["A"] * 40 + ["B"] * 20
+ feature_3 = ["A"] * 20 + ["B"] * 20 + ["C"] * 10 + ["D"] * 10
+ X = np.array([feature_1, feature_2, feature_3], dtype=object).T
+ rng.shuffle(X)
+ y = rng.randint(low=0, high=2, size=X.shape[0])
+ y_labels = np.array(["not apple", "apple"], dtype=object)
+ y = y_labels[y]
+ return X, y
+
+
+@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.float32, np.float64])
+@pytest.mark.parametrize("k, r", [(1, 1), (1, 2), (2, 1), (2, 2)])
+@pytest.mark.parametrize("y_type", ["list", "array"])
+@pytest.mark.parametrize("encode_label", [True, False])
+def test_value_difference_metric(data, dtype, k, r, y_type, encode_label):
+    # Check basic features of the metric:
+    # * the shape of the distance matrix is (n_samples, n_samples)
+    # * computing the pairwise distances of X alone is the same as computing
+    #   them explicitly between X and X.
+ X, y = data
+ y = _convert_container(y, y_type)
+ if encode_label:
+ y = LabelEncoder().fit_transform(y)
+
+ encoder = OrdinalEncoder(dtype=dtype)
+ X_encoded = encoder.fit_transform(X)
+
+ vdm = ValueDifferenceMetric(k=k, r=r)
+ vdm.fit(X_encoded, y)
+
+ dist_1 = vdm.pairwise(X_encoded)
+ dist_2 = vdm.pairwise(X_encoded, X_encoded)
+
+ np.testing.assert_allclose(dist_1, dist_2)
+ assert dist_1.shape == (X.shape[0], X.shape[0])
+ assert dist_2.shape == (X.shape[0], X.shape[0])
+
+
+@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.float32, np.float64])
+@pytest.mark.parametrize("k, r", [(1, 1), (1, 2), (2, 1), (2, 2)])
+@pytest.mark.parametrize("y_type", ["list", "array"])
+@pytest.mark.parametrize("encode_label", [True, False])
+def test_value_difference_metric_property(dtype, k, r, y_type, encode_label):
+ # Check the property of the vdm distance. Let's check the property
+ # described in "Improved Heterogeneous Distance Functions", D.R. Wilson and
+ # T.R. Martinez, Journal of Artificial Intelligence Research 6 (1997) 1-34
+ # https://arxiv.org/pdf/cs/9701101.pdf
+ #
+ # "if an attribute color has three values red, green and blue, and the
+ # application is to identify whether or not an object is an apple, red and
+ # green would be considered closer than red and blue because the former two
+ # both have similar correlations with the output class apple."
+
+    # define our feature
+ X = np.array(["green"] * 10 + ["red"] * 10 + ["blue"] * 10).reshape(-1, 1)
+ # 0 - not an apple / 1 - an apple
+ y = np.array([1] * 8 + [0] * 5 + [1] * 7 + [0] * 9 + [1])
+ y_labels = np.array(["not apple", "apple"], dtype=object)
+ y = y_labels[y]
+ y = _convert_container(y, y_type)
+ if encode_label:
+ y = LabelEncoder().fit_transform(y)
+
+ encoder = OrdinalEncoder(dtype=dtype)
+ X_encoded = encoder.fit_transform(X)
+
+ vdm = ValueDifferenceMetric(k=k, r=r)
+ vdm.fit(X_encoded, y)
+
+ sample_green = encoder.transform([["green"]])
+ sample_red = encoder.transform([["red"]])
+ sample_blue = encoder.transform([["blue"]])
+
+ for sample in (sample_green, sample_red, sample_blue):
+        # the distance between a sample and itself should be null
+ dist = vdm.pairwise(sample).squeeze()
+ assert dist == pytest.approx(0)
+
+ # check the property explained in the introduction example
+ dist_1 = vdm.pairwise(sample_green, sample_red).squeeze()
+ dist_2 = vdm.pairwise(sample_blue, sample_red).squeeze()
+ dist_3 = vdm.pairwise(sample_blue, sample_green).squeeze()
+
+ # green and red are very close
+ # blue is closer to red than green
+ assert dist_1 < dist_2
+ assert dist_1 < dist_3
+ assert dist_2 < dist_3
+
+
+def test_value_difference_metric_categories(data):
+ # Check that "auto" is equivalent to provide the number categories
+ # beforehand
+ X, y = data
+
+ encoder = OrdinalEncoder(dtype=np.int32)
+ X_encoded = encoder.fit_transform(X)
+ n_categories = np.array([len(cat) for cat in encoder.categories_])
+
+ vdm_auto = ValueDifferenceMetric().fit(X_encoded, y)
+ vdm_categories = ValueDifferenceMetric(n_categories=n_categories)
+ vdm_categories.fit(X_encoded, y)
+
+ np.testing.assert_array_equal(vdm_auto.n_categories_, n_categories)
+ np.testing.assert_array_equal(vdm_auto.n_categories_, vdm_categories.n_categories_)
+
+
+def test_value_difference_metric_categories_error(data):
+ # Check that we raise an error if n_categories is inconsistent with the
+ # number of features in X
+ X, y = data
+
+ encoder = OrdinalEncoder(dtype=np.int32)
+ X_encoded = encoder.fit_transform(X)
+ n_categories = [1, 2]
+
+ vdm = ValueDifferenceMetric(n_categories=n_categories)
+ err_msg = "The length of n_categories is not consistent with the number"
+ with pytest.raises(ValueError, match=err_msg):
+ vdm.fit(X_encoded, y)
+
+
+def test_value_difference_metric_missing_categories(data):
+    # Check that we don't run into issues when a category between 0 and
+    # n_categories - 1 is missing
+ X, y = data
+
+ encoder = OrdinalEncoder(dtype=np.int32)
+ X_encoded = encoder.fit_transform(X)
+ n_categories = np.array([len(cat) for cat in encoder.categories_])
+
+    # remove a category lying between 0 and n_categories - 1
+ X_encoded[X_encoded[:, -1] == 1] = 0
+ np.testing.assert_array_equal(np.unique(X_encoded[:, -1]), [0, 2, 3])
+
+ vdm = ValueDifferenceMetric(n_categories=n_categories)
+ vdm.fit(X_encoded, y)
+
+ for n_cats, proba in zip(n_categories, vdm.proba_per_class_):
+ assert proba.shape == (n_cats, len(np.unique(y)))
+
+
+def test_value_difference_value_unfitted(data):
+    # Check that we raise a NotFittedError when `fit` is not called before
+    # `pairwise`.
+ X, y = data
+
+ encoder = OrdinalEncoder(dtype=np.int32)
+ X_encoded = encoder.fit_transform(X)
+
+ with pytest.raises(NotFittedError):
+ ValueDifferenceMetric().pairwise(X_encoded)
diff --git a/imblearn/metrics/tests/test_score_objects.py b/imblearn/metrics/tests/test_score_objects.py
index f80a93e..10a1ced 100644
--- a/imblearn/metrics/tests/test_score_objects.py
+++ b/imblearn/metrics/tests/test_score_objects.py
@@ -1,8 +1,78 @@
"""Test for score"""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import pytest
from sklearn.datasets import make_blobs
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import make_scorer
from sklearn.model_selection import GridSearchCV, train_test_split
-from imblearn.metrics import geometric_mean_score, make_index_balanced_accuracy, sensitivity_score, specificity_score
-R_TOL = 0.01
+
+from imblearn.metrics import (
+ geometric_mean_score,
+ make_index_balanced_accuracy,
+ sensitivity_score,
+ specificity_score,
+)
+
+R_TOL = 1e-2
+
+
+@pytest.fixture
+def data():
+ X, y = make_blobs(random_state=0, centers=2)
+ return train_test_split(X, y, random_state=0)
+
+
+@pytest.mark.parametrize(
+ "score, expected_score",
+ [
+ (sensitivity_score, 0.90),
+ (specificity_score, 0.90),
+ (geometric_mean_score, 0.90),
+ (make_index_balanced_accuracy()(geometric_mean_score), 0.82),
+ ],
+)
+@pytest.mark.parametrize("average", ["macro", "weighted", "micro"])
+def test_scorer_common_average(data, score, expected_score, average):
+ X_train, X_test, y_train, _ = data
+
+ scorer = make_scorer(score, pos_label=None, average=average)
+ grid = GridSearchCV(
+ LogisticRegression(),
+ param_grid={"C": [1, 10]},
+ scoring=scorer,
+ cv=3,
+ )
+ grid.fit(X_train, y_train).predict(X_test)
+
+ assert grid.best_score_ >= expected_score
+
+
+@pytest.mark.parametrize(
+ "score, average, expected_score",
+ [
+ (sensitivity_score, "binary", 0.94),
+ (specificity_score, "binary", 0.89),
+ (geometric_mean_score, "multiclass", 0.90),
+ (
+ make_index_balanced_accuracy()(geometric_mean_score),
+ "multiclass",
+ 0.82,
+ ),
+ ],
+)
+def test_scorer_default_average(data, score, average, expected_score):
+ X_train, X_test, y_train, _ = data
+
+ scorer = make_scorer(score, pos_label=1, average=average)
+ grid = GridSearchCV(
+ LogisticRegression(),
+ param_grid={"C": [1, 10]},
+ scoring=scorer,
+ cv=3,
+ )
+ grid.fit(X_train, y_train).predict(X_test)
+
+ assert grid.best_score_ >= expected_score
diff --git a/imblearn/over_sampling/_adasyn.py b/imblearn/over_sampling/_adasyn.py
index d8159df..54e88b7 100644
--- a/imblearn/over_sampling/_adasyn.py
+++ b/imblearn/over_sampling/_adasyn.py
@@ -1,18 +1,27 @@
-"""Class to perform over-sampling using ADASYN."""
+"""Class to perform over-sampling using ADASYN."""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import numbers
import warnings
+
import numpy as np
from scipy import sparse
from sklearn.utils import _safe_indexing, check_random_state
+
from ..utils import Substitution, check_neighbors_object
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
from ..utils._param_validation import HasMethods, Interval
from .base import BaseOverSampler
-@Substitution(sampling_strategy=BaseOverSampler.
- _sampling_strategy_docstring, n_jobs=_n_jobs_docstring, random_state=
- _random_state_docstring)
+@Substitution(
+ sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
+ n_jobs=_n_jobs_docstring,
+ random_state=_random_state_docstring,
+)
class ADASYN(BaseOverSampler):
"""Oversample using Adaptive Synthetic (ADASYN) algorithm.
@@ -112,13 +121,24 @@ class ADASYN(BaseOverSampler):
>>> print('Resampled dataset shape %s' % Counter(y_res))
Resampled dataset shape Counter({{0: 904, 1: 900}})
"""
- _parameter_constraints: dict = {**BaseOverSampler.
- _parameter_constraints, 'n_neighbors': [Interval(numbers.Integral,
- 1, None, closed='left'), HasMethods(['kneighbors',
- 'kneighbors_graph'])], 'n_jobs': [numbers.Integral, None]}
- def __init__(self, *, sampling_strategy='auto', random_state=None,
- n_neighbors=5, n_jobs=None):
+ _parameter_constraints: dict = {
+ **BaseOverSampler._parameter_constraints,
+ "n_neighbors": [
+ Interval(numbers.Integral, 1, None, closed="left"),
+ HasMethods(["kneighbors", "kneighbors_graph"]),
+ ],
+ "n_jobs": [numbers.Integral, None],
+ }
+
+ def __init__(
+ self,
+ *,
+ sampling_strategy="auto",
+ random_state=None,
+ n_neighbors=5,
+ n_jobs=None,
+ ):
super().__init__(sampling_strategy=sampling_strategy)
self.random_state = random_state
self.n_neighbors = n_neighbors
@@ -126,4 +146,88 @@ class ADASYN(BaseOverSampler):
def _validate_estimator(self):
"""Create the necessary objects for ADASYN"""
- pass
+ self.nn_ = check_neighbors_object(
+ "n_neighbors", self.n_neighbors, additional_neighbor=1
+ )
+
+ def _fit_resample(self, X, y):
+ # FIXME: to be removed in 0.12
+ if self.n_jobs is not None:
+ warnings.warn(
+ "The parameter `n_jobs` has been deprecated in 0.10 and will be "
+ "removed in 0.12. You can pass an nearest neighbors estimator where "
+ "`n_jobs` is already set instead.",
+ FutureWarning,
+ )
+
+ self._validate_estimator()
+ random_state = check_random_state(self.random_state)
+
+ X_resampled = [X.copy()]
+ y_resampled = [y.copy()]
+
+ for class_sample, n_samples in self.sampling_strategy_.items():
+ if n_samples == 0:
+ continue
+ target_class_indices = np.flatnonzero(y == class_sample)
+ X_class = _safe_indexing(X, target_class_indices)
+
+ self.nn_.fit(X)
+ nns = self.nn_.kneighbors(X_class, return_distance=False)[:, 1:]
+ # The ratio is computed using a one-vs-rest manner. Using majority
+ # in multi-class would lead to slightly different results at the
+ # cost of introducing a new parameter.
+ n_neighbors = self.nn_.n_neighbors - 1
+ ratio_nn = np.sum(y[nns] != class_sample, axis=1) / n_neighbors
+ if not np.sum(ratio_nn):
+ raise RuntimeError(
+ "Not any neigbours belong to the majority"
+ " class. This case will induce a NaN case"
+ " with a division by zero. ADASYN is not"
+ " suited for this specific dataset."
+ " Use SMOTE instead."
+ )
+ ratio_nn /= np.sum(ratio_nn)
+ n_samples_generate = np.rint(ratio_nn * n_samples).astype(int)
+ # rounding may cause new amount for n_samples
+ n_samples = np.sum(n_samples_generate)
+ if not n_samples:
+ raise ValueError(
+ "No samples will be generated with the provided ratio settings."
+ )
+
+ # the nearest neighbors need to be fitted only on the current class
+ # to find the class NN to generate new samples
+ self.nn_.fit(X_class)
+ nns = self.nn_.kneighbors(X_class, return_distance=False)[:, 1:]
+
+ enumerated_class_indices = np.arange(len(target_class_indices))
+ rows = np.repeat(enumerated_class_indices, n_samples_generate)
+ cols = random_state.choice(n_neighbors, size=n_samples)
+ diffs = X_class[nns[rows, cols]] - X_class[rows]
+ steps = random_state.uniform(size=(n_samples, 1))
+
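+            # each synthetic sample lies on the segment joining a minority
+            # sample and one of its neighbours:
+            # x_new = x_i + u * (x_nn - x_i) with u ~ U(0, 1)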
+ if sparse.issparse(X):
+ sparse_func = type(X).__name__
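+                # rebuild `steps` in the same sparse container so `.multiply`
+                # applies each per-sample step elementwise to the diffs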
+ steps = getattr(sparse, sparse_func)(steps)
+ X_new = X_class[rows] + steps.multiply(diffs)
+ else:
+ X_new = X_class[rows] + steps * diffs
+
+ X_new = X_new.astype(X.dtype)
+ y_new = np.full(n_samples, fill_value=class_sample, dtype=y.dtype)
+ X_resampled.append(X_new)
+ y_resampled.append(y_new)
+
+ if sparse.issparse(X):
+ X_resampled = sparse.vstack(X_resampled, format=X.format)
+ else:
+ X_resampled = np.vstack(X_resampled)
+ y_resampled = np.hstack(y_resampled)
+
+ return X_resampled, y_resampled
+
+ def _more_tags(self):
+ return {
+ "X_types": ["2darray"],
+ }
diff --git a/imblearn/over_sampling/_random_over_sampler.py b/imblearn/over_sampling/_random_over_sampler.py
index cffe043..63b5a66 100644
--- a/imblearn/over_sampling/_random_over_sampler.py
+++ b/imblearn/over_sampling/_random_over_sampler.py
@@ -1,10 +1,17 @@
-"""Class to perform random over-sampling."""
+"""Class to perform random over-sampling."""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
from collections.abc import Mapping
from numbers import Real
+
import numpy as np
from scipy import sparse
from sklearn.utils import _safe_indexing, check_array, check_random_state
from sklearn.utils.sparsefuncs import mean_variance_axis
+
from ..utils import Substitution, check_target_type
from ..utils._docstring import _random_state_docstring
from ..utils._param_validation import Interval
@@ -12,8 +19,10 @@ from ..utils._validation import _check_X
from .base import BaseOverSampler
-@Substitution(sampling_strategy=BaseOverSampler.
- _sampling_strategy_docstring, random_state=_random_state_docstring)
+@Substitution(
+ sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
+ random_state=_random_state_docstring,
+)
class RandomOverSampler(BaseOverSampler):
"""Class to perform random over-sampling.
@@ -127,12 +136,125 @@ class RandomOverSampler(BaseOverSampler):
>>> print('Resampled dataset shape %s' % Counter(y_res))
Resampled dataset shape Counter({{0: 900, 1: 900}})
"""
- _parameter_constraints: dict = {**BaseOverSampler.
- _parameter_constraints, 'shrinkage': [Interval(Real, 0, None,
- closed='left'), dict, None]}
- def __init__(self, *, sampling_strategy='auto', random_state=None,
- shrinkage=None):
+ _parameter_constraints: dict = {
+ **BaseOverSampler._parameter_constraints,
+ "shrinkage": [Interval(Real, 0, None, closed="left"), dict, None],
+ }
+
+ def __init__(
+ self,
+ *,
+ sampling_strategy="auto",
+ random_state=None,
+ shrinkage=None,
+ ):
super().__init__(sampling_strategy=sampling_strategy)
self.random_state = random_state
self.shrinkage = shrinkage
+
+ def _check_X_y(self, X, y):
+ y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
+ X = _check_X(X)
+ self._check_n_features(X, reset=True)
+ self._check_feature_names(X, reset=True)
+ return X, y, binarize_y
+
+ def _fit_resample(self, X, y):
+ random_state = check_random_state(self.random_state)
+
+ if isinstance(self.shrinkage, Real):
+ self.shrinkage_ = {
+ klass: self.shrinkage for klass in self.sampling_strategy_
+ }
+ elif self.shrinkage is None or isinstance(self.shrinkage, Mapping):
+ self.shrinkage_ = self.shrinkage
+
+ if self.shrinkage_ is not None:
+ missing_shrinkage_keys = (
+ self.sampling_strategy_.keys() - self.shrinkage_.keys()
+ )
+ if missing_shrinkage_keys:
+ raise ValueError(
+ f"`shrinkage` should contain a shrinkage factor for "
+ f"each class that will be resampled. The missing "
+ f"classes are: {repr(missing_shrinkage_keys)}"
+ )
+
+ for klass, shrink_factor in self.shrinkage_.items():
+ if shrink_factor < 0:
+ raise ValueError(
+ f"The shrinkage factor needs to be >= 0. "
+ f"Got {shrink_factor} for class {klass}."
+ )
+
+            # the smoothed bootstrap requires numerical operations; we need
+            # to be sure to have only numerical data in X
+ try:
+ X = check_array(X, accept_sparse=["csr", "csc"], dtype="numeric")
+ except ValueError as exc:
+ raise ValueError(
+ "When shrinkage is not None, X needs to contain only "
+ "numerical data to later generate a smoothed bootstrap "
+ "sample."
+ ) from exc
+
+ X_resampled = [X.copy()]
+ y_resampled = [y.copy()]
+
+ sample_indices = range(X.shape[0])
+ for class_sample, num_samples in self.sampling_strategy_.items():
+ target_class_indices = np.flatnonzero(y == class_sample)
+ bootstrap_indices = random_state.choice(
+ target_class_indices,
+ size=num_samples,
+ replace=True,
+ )
+ sample_indices = np.append(sample_indices, bootstrap_indices)
+ if self.shrinkage_ is not None:
+ # generate a smoothed bootstrap with a perturbation
+ n_samples, n_features = X.shape
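+                # smoothing bandwidth from the rule-of-thumb used for the
+                # smoothed bootstrap (Silverman's multivariate rule), scaled
+                # below by the per-class standard deviation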
+ smoothing_constant = (4 / ((n_features + 2) * n_samples)) ** (
+ 1 / (n_features + 4)
+ )
+ if sparse.issparse(X):
+ _, X_class_variance = mean_variance_axis(
+ X[target_class_indices, :],
+ axis=0,
+ )
+ X_class_scale = np.sqrt(X_class_variance, out=X_class_variance)
+ else:
+ X_class_scale = np.std(X[target_class_indices, :], axis=0)
+ smoothing_matrix = np.diagflat(
+ self.shrinkage_[class_sample] * smoothing_constant * X_class_scale
+ )
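+                # perturb each bootstrap sample with Gaussian noise whose
+                # per-feature scale is the shrunk class standard deviation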
+ X_new = random_state.randn(num_samples, n_features)
+ X_new = X_new.dot(smoothing_matrix) + X[bootstrap_indices, :]
+ if sparse.issparse(X):
+ X_new = sparse.csr_matrix(X_new, dtype=X.dtype)
+ X_resampled.append(X_new)
+ else:
+ # generate a bootstrap
+ X_resampled.append(_safe_indexing(X, bootstrap_indices))
+
+ y_resampled.append(_safe_indexing(y, bootstrap_indices))
+
+ self.sample_indices_ = np.array(sample_indices)
+
+ if sparse.issparse(X):
+ X_resampled = sparse.vstack(X_resampled, format=X.format)
+ else:
+ X_resampled = np.vstack(X_resampled)
+ y_resampled = np.hstack(y_resampled)
+
+ return X_resampled, y_resampled
+
+ def _more_tags(self):
+ return {
+ "X_types": ["2darray", "string", "sparse", "dataframe"],
+ "sample_indices": True,
+ "allow_nan": True,
+ "_xfail_checks": {
+ "check_complex_data": "Robust to this type of data.",
+ },
+ }
diff --git a/imblearn/over_sampling/_smote/base.py b/imblearn/over_sampling/_smote/base.py
index 968941c..8ef9029 100644
--- a/imblearn/over_sampling/_smote/base.py
+++ b/imblearn/over_sampling/_smote/base.py
@@ -1,17 +1,32 @@
"""Base class and original SMOTE methods for over-sampling"""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Fernando Nogueira
+# Christos Aridas
+# Dzianis Dudnik
+# License: MIT
+
import math
import numbers
import warnings
+
import numpy as np
import sklearn
from scipy import sparse
from sklearn.base import clone
from sklearn.exceptions import DataConversionWarning
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
-from sklearn.utils import _safe_indexing, check_array, check_random_state
+from sklearn.utils import (
+ _safe_indexing,
+ check_array,
+ check_random_state,
+)
from sklearn.utils.fixes import parse_version
-from sklearn.utils.sparsefuncs_fast import csr_mean_variance_axis0
+from sklearn.utils.sparsefuncs_fast import (
+ csr_mean_variance_axis0,
+)
from sklearn.utils.validation import _num_features
+
from ...metrics.pairwise import ValueDifferenceMetric
from ...utils import Substitution, check_neighbors_object, check_target_type
from ...utils._docstring import _n_jobs_docstring, _random_state_docstring
@@ -19,8 +34,9 @@ from ...utils._param_validation import HasMethods, Interval, StrOptions
from ...utils._validation import _check_X
from ...utils.fixes import _is_pandas_df, _mode
from ..base import BaseOverSampler
+
sklearn_version = parse_version(sklearn.__version__).base_version
-if parse_version(sklearn_version) < parse_version('1.5'):
+if parse_version(sklearn_version) < parse_version("1.5"):
from sklearn.utils import _get_column_indices
else:
from sklearn.utils._indexing import _get_column_indices
@@ -28,13 +44,23 @@ else:
class BaseSMOTE(BaseOverSampler):
"""Base class for the different SMOTE algorithms."""
- _parameter_constraints: dict = {**BaseOverSampler.
- _parameter_constraints, 'k_neighbors': [Interval(numbers.Integral,
- 1, None, closed='left'), HasMethods(['kneighbors',
- 'kneighbors_graph'])], 'n_jobs': [numbers.Integral, None]}
- def __init__(self, sampling_strategy='auto', random_state=None,
- k_neighbors=5, n_jobs=None):
+ _parameter_constraints: dict = {
+ **BaseOverSampler._parameter_constraints,
+ "k_neighbors": [
+ Interval(numbers.Integral, 1, None, closed="left"),
+ HasMethods(["kneighbors", "kneighbors_graph"]),
+ ],
+ "n_jobs": [numbers.Integral, None],
+ }
+
+ def __init__(
+ self,
+ sampling_strategy="auto",
+ random_state=None,
+ k_neighbors=5,
+ n_jobs=None,
+ ):
super().__init__(sampling_strategy=sampling_strategy)
self.random_state = random_state
self.k_neighbors = k_neighbors
@@ -44,10 +70,13 @@ class BaseSMOTE(BaseOverSampler):
"""Check the NN estimators shared across the different SMOTE
algorithms.
"""
- pass
+ self.nn_k_ = check_neighbors_object(
+ "k_neighbors", self.k_neighbors, additional_neighbor=1
+ )
- def _make_samples(self, X, y_dtype, y_type, nn_data, nn_num, n_samples,
- step_size=1.0, y=None):
+ def _make_samples(
+ self, X, y_dtype, y_type, nn_data, nn_num, n_samples, step_size=1.0, y=None
+ ):
"""A support function that returns artificial samples constructed along
the line connecting nearest neighbours.
@@ -88,21 +117,32 @@ class BaseSMOTE(BaseOverSampler):
y_new : ndarray of shape (n_samples_new,)
Target values for synthetic samples.
"""
- pass
+ random_state = check_random_state(self.random_state)
+ samples_indices = random_state.randint(low=0, high=nn_num.size, size=n_samples)
+
+        # np.newaxis for backwards compatibility with random_state
+ steps = step_size * random_state.uniform(size=n_samples)[:, np.newaxis]
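+        # decompose each flat index into (row, col): `row` picks the seed
+        # sample, `col` picks which of its k neighbours to interpolate towards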
+ rows = np.floor_divide(samples_indices, nn_num.shape[1])
+ cols = np.mod(samples_indices, nn_num.shape[1])
+
+ X_new = self._generate_samples(X, nn_data, nn_num, rows, cols, steps, y_type, y)
+ y_new = np.full(n_samples, fill_value=y_type, dtype=y_dtype)
+ return X_new, y_new
- def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps,
- y_type=None, y=None):
- """Generate a synthetic sample.
+ def _generate_samples(
+ self, X, nn_data, nn_num, rows, cols, steps, y_type=None, y=None
+ ):
+ r"""Generate a synthetic sample.
The rule for the generation is:
.. math::
- \\mathbf{s_{s}} = \\mathbf{s_{i}} + \\mathcal{u}(0, 1) \\times
- (\\mathbf{s_{i}} - \\mathbf{s_{nn}}) \\,
+           \mathbf{s_{s}} = \mathbf{s_{i}} + \mathcal{u}(0, 1) \times
+           (\mathbf{s_{nn}} - \mathbf{s_{i}}) \,
- where \\mathbf{s_{s}} is the new synthetic samples, \\mathbf{s_{i}} is
- the current sample, \\mathbf{s_{nn}} is a randomly selected neighbors of
- \\mathbf{s_{i}} and \\mathcal{u}(0, 1) is a random number between [0, 1).
+        where \mathbf{s_{s}} is the new synthetic sample, \mathbf{s_{i}} is
+        the current sample, \mathbf{s_{nn}} is a randomly selected neighbour of
+        \mathbf{s_{i}}, and \mathcal{u}(0, 1) is a random number in [0, 1).
Parameters
----------
@@ -139,10 +179,24 @@ class BaseSMOTE(BaseOverSampler):
X_new : {ndarray, sparse matrix} of shape (n_samples, n_features)
Synthetically generated samples.
"""
- pass
-
- def _in_danger_noise(self, nn_estimator, samples, target_class, y, kind
- ='danger'):
+ diffs = nn_data[nn_num[rows, cols]] - X[rows]
+ if y is not None: # only entering for BorderlineSMOTE-2
+ random_state = check_random_state(self.random_state)
+ mask_pair_samples = y[nn_num[rows, cols]] != y_type
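+            # when the paired neighbour comes from another class, scale the
+            # difference by a factor in [0, 0.5) so the synthetic point stays
+            # closer to the minority seed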
+ diffs[mask_pair_samples] *= random_state.uniform(
+ low=0.0, high=0.5, size=(mask_pair_samples.sum(), 1)
+ )
+
+ if sparse.issparse(X):
+ sparse_func = type(X).__name__
+ steps = getattr(sparse, sparse_func)(steps)
+ X_new = X[rows] + steps.multiply(diffs)
+ else:
+ X_new = X[rows] + steps * diffs
+
+ return X_new.astype(X.dtype)
+
+ def _in_danger_noise(self, nn_estimator, samples, target_class, y, kind="danger"):
"""Estimate if a set of sample are in danger or noise.
Used by BorderlineSMOTE and SVMSMOTE.
@@ -174,12 +228,26 @@ class BaseSMOTE(BaseOverSampler):
output : ndarray of shape (n_samples,)
A boolean array where True refer to samples in danger or noise.
"""
- pass
-
-
-@Substitution(sampling_strategy=BaseOverSampler.
- _sampling_strategy_docstring, n_jobs=_n_jobs_docstring, random_state=
- _random_state_docstring)
+ x = nn_estimator.kneighbors(samples, return_distance=False)[:, 1:]
+ nn_label = (y[x] != target_class).astype(int)
+ n_maj = np.sum(nn_label, axis=1)
+
+ if kind == "danger":
+ # Samples are in danger for m/2 <= m' < m
+ return np.bitwise_and(
+ n_maj >= (nn_estimator.n_neighbors - 1) / 2,
+ n_maj < nn_estimator.n_neighbors - 1,
+ )
+ else: # kind == "noise":
+ # Samples are noise for m = m'
+ return n_maj == nn_estimator.n_neighbors - 1
+
+
+@Substitution(
+ sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
+ n_jobs=_n_jobs_docstring,
+ random_state=_random_state_docstring,
+)
class SMOTE(BaseSMOTE):
"""Class to perform over-sampling using SMOTE.
@@ -281,15 +349,64 @@ class SMOTE(BaseSMOTE):
Resampled dataset shape Counter({{0: 900, 1: 900}})
"""
- def __init__(self, *, sampling_strategy='auto', random_state=None,
- k_neighbors=5, n_jobs=None):
- super().__init__(sampling_strategy=sampling_strategy, random_state=
- random_state, k_neighbors=k_neighbors, n_jobs=n_jobs)
-
-
-@Substitution(sampling_strategy=BaseOverSampler.
- _sampling_strategy_docstring, n_jobs=_n_jobs_docstring, random_state=
- _random_state_docstring)
+ def __init__(
+ self,
+ *,
+ sampling_strategy="auto",
+ random_state=None,
+ k_neighbors=5,
+ n_jobs=None,
+ ):
+ super().__init__(
+ sampling_strategy=sampling_strategy,
+ random_state=random_state,
+ k_neighbors=k_neighbors,
+ n_jobs=n_jobs,
+ )
+
+ def _fit_resample(self, X, y):
+ # FIXME: to be removed in 0.12
+ if self.n_jobs is not None:
+ warnings.warn(
+ "The parameter `n_jobs` has been deprecated in 0.10 and will be "
+ "removed in 0.12. You can pass an nearest neighbors estimator where "
+ "`n_jobs` is already set instead.",
+ FutureWarning,
+ )
+
+ self._validate_estimator()
+
+ X_resampled = [X.copy()]
+ y_resampled = [y.copy()]
+
+ for class_sample, n_samples in self.sampling_strategy_.items():
+ if n_samples == 0:
+ continue
+ target_class_indices = np.flatnonzero(y == class_sample)
+ X_class = _safe_indexing(X, target_class_indices)
+
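+            # plain SMOTE searches neighbours within the class being
+            # oversampled only, then interpolates between each sample and one
+            # of those neighbours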
+ self.nn_k_.fit(X_class)
+ nns = self.nn_k_.kneighbors(X_class, return_distance=False)[:, 1:]
+ X_new, y_new = self._make_samples(
+ X_class, y.dtype, class_sample, X_class, nns, n_samples, 1.0
+ )
+ X_resampled.append(X_new)
+ y_resampled.append(y_new)
+
+ if sparse.issparse(X):
+ X_resampled = sparse.vstack(X_resampled, format=X.format)
+ else:
+ X_resampled = np.vstack(X_resampled)
+ y_resampled = np.hstack(y_resampled)
+
+ return X_resampled, y_resampled
+
+
+@Substitution(
+ sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
+ n_jobs=_n_jobs_docstring,
+ random_state=_random_state_docstring,
+)
class SMOTENC(SMOTE):
"""Synthetic Minority Over-sampling Technique for Nominal and Continuous.
@@ -303,7 +420,8 @@ class SMOTENC(SMOTE):
Parameters
----------
- categorical_features : "infer" or array-like of shape (n_cat_features,) or (n_features,), dtype={{bool, int, str}}
+ categorical_features : "infer" or array-like of shape (n_cat_features,) or \
+ (n_features,), dtype={{bool, int, str}}
Specified which features are categorical. Can either be:
- "auto" (default) to automatically detect categorical features. Only
@@ -444,17 +562,34 @@ class SMOTENC(SMOTE):
>>> print(f'Resampled dataset samples per class {{Counter(y_res)}}')
Resampled dataset samples per class Counter({{0: 900, 1: 900}})
"""
- _required_parameters = ['categorical_features']
- _parameter_constraints: dict = {**SMOTE._parameter_constraints,
- 'categorical_features': ['array-like', StrOptions({'auto'})],
- 'categorical_encoder': [HasMethods(['fit_transform',
- 'inverse_transform']), None]}
-
- def __init__(self, categorical_features, *, categorical_encoder=None,
- sampling_strategy='auto', random_state=None, k_neighbors=5, n_jobs=None
- ):
- super().__init__(sampling_strategy=sampling_strategy, random_state=
- random_state, k_neighbors=k_neighbors, n_jobs=n_jobs)
+
+ _required_parameters = ["categorical_features"]
+
+ _parameter_constraints: dict = {
+ **SMOTE._parameter_constraints,
+ "categorical_features": ["array-like", StrOptions({"auto"})],
+ "categorical_encoder": [
+ HasMethods(["fit_transform", "inverse_transform"]),
+ None,
+ ],
+ }
+
+ def __init__(
+ self,
+ categorical_features,
+ *,
+ categorical_encoder=None,
+ sampling_strategy="auto",
+ random_state=None,
+ k_neighbors=5,
+ n_jobs=None,
+ ):
+ super().__init__(
+ sampling_strategy=sampling_strategy,
+ random_state=random_state,
+ k_neighbors=k_neighbors,
+ n_jobs=n_jobs,
+ )
self.categorical_features = categorical_features
self.categorical_encoder = categorical_encoder
@@ -462,14 +597,170 @@ class SMOTENC(SMOTE):
"""Overwrite the checking to let pass some string for categorical
features.
"""
- pass
+ y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
+ X = _check_X(X)
+ self._check_n_features(X, reset=True)
+ self._check_feature_names(X, reset=True)
+ return X, y, binarize_y
def _validate_column_types(self, X):
"""Compute the indices of the categorical and continuous features."""
- pass
+ if self.categorical_features == "auto":
+ if not _is_pandas_df(X):
+ raise ValueError(
+ "When `categorical_features='auto'`, the input data "
+ f"should be a pandas.DataFrame. Got {type(X)} instead."
+ )
+ import pandas as pd # safely import pandas now
+
+ are_columns_categorical = np.array(
+ [isinstance(col_dtype, pd.CategoricalDtype) for col_dtype in X.dtypes]
+ )
+ self.categorical_features_ = np.flatnonzero(are_columns_categorical)
+ self.continuous_features_ = np.flatnonzero(~are_columns_categorical)
+ else:
+ self.categorical_features_ = np.array(
+ _get_column_indices(X, self.categorical_features)
+ )
+ self.continuous_features_ = np.setdiff1d(
+ np.arange(self.n_features_), self.categorical_features_
+ )
- def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps,
- y_type, y=None):
+ def _validate_estimator(self):
+ super()._validate_estimator()
+ if self.categorical_features_.size == self.n_features_in_:
+ raise ValueError(
+ "SMOTE-NC is not designed to work only with categorical "
+ "features. It requires some numerical features."
+ )
+ elif self.categorical_features_.size == 0:
+ raise ValueError(
+ "SMOTE-NC is not designed to work only with numerical "
+ "features. It requires some categorical features."
+ )
+
+ def _fit_resample(self, X, y):
+ # FIXME: to be removed in 0.12
+ if self.n_jobs is not None:
+ warnings.warn(
+ "The parameter `n_jobs` has been deprecated in 0.10 and will be "
+ "removed in 0.12. You can pass an nearest neighbors estimator where "
+ "`n_jobs` is already set instead.",
+ FutureWarning,
+ )
+
+ self.n_features_ = _num_features(X)
+ self._validate_column_types(X)
+ self._validate_estimator()
+
+ X_continuous = _safe_indexing(X, self.continuous_features_, axis=1)
+ X_continuous = check_array(X_continuous, accept_sparse=["csr", "csc"])
+ X_categorical = _safe_indexing(X, self.categorical_features_, axis=1)
+ if X_continuous.dtype.name != "object":
+ dtype_ohe = X_continuous.dtype
+ else:
+ dtype_ohe = np.float64
+
+ if self.categorical_encoder is None:
+ self.categorical_encoder_ = OneHotEncoder(
+ handle_unknown="ignore", dtype=dtype_ohe
+ )
+ else:
+ self.categorical_encoder_ = clone(self.categorical_encoder)
+
+ # the input of the OneHotEncoder needs to be dense
+ X_ohe = self.categorical_encoder_.fit_transform(
+ X_categorical.toarray() if sparse.issparse(X_categorical) else X_categorical
+ )
+ if not sparse.issparse(X_ohe):
+ X_ohe = sparse.csr_matrix(X_ohe, dtype=dtype_ohe)
+
+ X_encoded = sparse.hstack((X_continuous, X_ohe), format="csr", dtype=dtype_ohe)
+ X_resampled = [X_encoded.copy()]
+ y_resampled = [y.copy()]
+
+ # SMOTE resampling starts here
+ self.median_std_ = {}
+ for class_sample, n_samples in self.sampling_strategy_.items():
+ if n_samples == 0:
+ continue
+ target_class_indices = np.flatnonzero(y == class_sample)
+ X_class = _safe_indexing(X_encoded, target_class_indices)
+
+ _, var = csr_mean_variance_axis0(
+ X_class[:, : self.continuous_features_.size]
+ )
+ self.median_std_[class_sample] = np.median(np.sqrt(var))
+
+            # In the edge case where the median of the std is equal to 0, the 1s
+            # entries will also be nullified. In this case, we store the original
+            # categorical encoding which will later be used for inverting the OHE
+ if math.isclose(self.median_std_[class_sample], 0):
+ # This variable will be used when generating data
+ self._X_categorical_minority_encoded = X_class[
+ :, self.continuous_features_.size :
+ ].toarray()
+
+ # we can replace the 1 entries of the categorical features with the
+ # median of the standard deviation. It will ensure that whenever
+ # distance is computed between 2 samples, the difference will be equal
+ # to the median of the standard deviation as in the original paper.
+ X_class_categorical = X_class[:, self.continuous_features_.size :]
+ # With one-hot encoding, the median will be repeated twice. We need
+ # to divide by sqrt(2) such that we only have one median value
+ # contributing to the Euclidean distance
+ X_class_categorical.data[:] = self.median_std_[class_sample] / np.sqrt(2)
+ X_class[:, self.continuous_features_.size :] = X_class_categorical
+
+ self.nn_k_.fit(X_class)
+ nns = self.nn_k_.kneighbors(X_class, return_distance=False)[:, 1:]
+ X_new, y_new = self._make_samples(
+ X_class, y.dtype, class_sample, X_class, nns, n_samples, 1.0
+ )
+ X_resampled.append(X_new)
+ y_resampled.append(y_new)
+
+ X_resampled = sparse.vstack(X_resampled, format=X_encoded.format)
+ y_resampled = np.hstack(y_resampled)
+ # SMOTE resampling ends here
+
+ # reverse the encoding of the categorical features
+ X_res_cat = X_resampled[:, self.continuous_features_.size :]
+ X_res_cat.data = np.ones_like(X_res_cat.data)
+ X_res_cat_dec = self.categorical_encoder_.inverse_transform(X_res_cat)
+
+ if sparse.issparse(X):
+ X_resampled = sparse.hstack(
+ (
+ X_resampled[:, : self.continuous_features_.size],
+ X_res_cat_dec,
+ ),
+ format="csr",
+ )
+ else:
+ X_resampled = np.hstack(
+ (
+ X_resampled[:, : self.continuous_features_.size].toarray(),
+ X_res_cat_dec,
+ )
+ )
+
+ indices_reordered = np.argsort(
+ np.hstack((self.continuous_features_, self.categorical_features_))
+ )
+ if sparse.issparse(X_resampled):
+ # the matrix is supposed to be in the CSR format after the stacking
+ col_indices = X_resampled.indices.copy()
+ for idx, col_idx in enumerate(indices_reordered):
+ mask = X_resampled.indices == col_idx
+ col_indices[mask] = idx
+ X_resampled.indices = col_indices
+ else:
+ X_resampled = X_resampled[:, indices_reordered]
+
+ return X_resampled, y_resampled
+
+ def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps, y_type, y=None):
"""Generate a synthetic sample with an additional steps for the
categorical features.
@@ -477,17 +768,59 @@ class SMOTENC(SMOTE):
categorical features are mapped to the most frequent nearest neighbors
of the majority class.
"""
- pass
+ rng = check_random_state(self.random_state)
+ X_new = super()._generate_samples(X, nn_data, nn_num, rows, cols, steps)
+ # change in sparsity structure more efficient with LIL than CSR
+ X_new = X_new.tolil() if sparse.issparse(X_new) else X_new
+
+ # convert to dense array since scipy.sparse doesn't handle 3D
+ nn_data = nn_data.toarray() if sparse.issparse(nn_data) else nn_data
+
+        # In the case that the median std was equal to zero, we have to
+        # create non-null entries based on the OHE encoding
+ if math.isclose(self.median_std_[y_type], 0):
+ nn_data[
+ :, self.continuous_features_.size :
+ ] = self._X_categorical_minority_encoded
+
+ all_neighbors = nn_data[nn_num[rows]]
+
+ categories_size = [self.continuous_features_.size] + [
+ cat.size for cat in self.categorical_encoder_.categories_
+ ]
+
+ for start_idx, end_idx in zip(
+ np.cumsum(categories_size)[:-1], np.cumsum(categories_size)[1:]
+ ):
+ col_maxs = all_neighbors[:, :, start_idx:end_idx].sum(axis=1)
+ # tie breaking argmax
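+            # among the categories that reach the maximal count, one is picked
+            # at random for each generated sample (np.unique keeps the first
+            # occurrence in the permuted order)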
+ is_max = np.isclose(col_maxs, col_maxs.max(axis=1, keepdims=True))
+ max_idxs = rng.permutation(np.argwhere(is_max))
+ xs, idx_sels = np.unique(max_idxs[:, 0], return_index=True)
+ col_sels = max_idxs[idx_sels, 1]
+
+ ys = start_idx + col_sels
+ X_new[:, start_idx:end_idx] = 0
+ X_new[xs, ys] = 1
+
+ return X_new
@property
def ohe_(self):
"""One-hot encoder used to encode the categorical features."""
- pass
-
-
-@Substitution(sampling_strategy=BaseOverSampler.
- _sampling_strategy_docstring, n_jobs=_n_jobs_docstring, random_state=
- _random_state_docstring)
+ warnings.warn(
+ "'ohe_' attribute has been deprecated in 0.11 and will be removed "
+ "in 0.13. Use 'categorical_encoder_' instead.",
+ FutureWarning,
+ )
+ return self.categorical_encoder_
+
+
+@Substitution(
+ sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
+ n_jobs=_n_jobs_docstring,
+ random_state=_random_state_docstring,
+)
class SMOTEN(SMOTE):
"""Synthetic Minority Over-sampling Technique for Nominal.
@@ -595,20 +928,128 @@ class SMOTEN(SMOTE):
>>> print(f"Class counts after resampling {{Counter(y_res)}}")
Class counts after resampling Counter({{0: 40, 1: 40}})
"""
- _parameter_constraints: dict = {**SMOTE._parameter_constraints,
- 'categorical_encoder': [HasMethods(['fit_transform',
- 'inverse_transform']), None]}
-
- def __init__(self, categorical_encoder=None, *, sampling_strategy=
- 'auto', random_state=None, k_neighbors=5, n_jobs=None):
- super().__init__(sampling_strategy=sampling_strategy, random_state=
- random_state, k_neighbors=k_neighbors, n_jobs=n_jobs)
+
+ _parameter_constraints: dict = {
+ **SMOTE._parameter_constraints,
+ "categorical_encoder": [
+ HasMethods(["fit_transform", "inverse_transform"]),
+ None,
+ ],
+ }
+
+ def __init__(
+ self,
+ categorical_encoder=None,
+ *,
+ sampling_strategy="auto",
+ random_state=None,
+ k_neighbors=5,
+ n_jobs=None,
+ ):
+ super().__init__(
+ sampling_strategy=sampling_strategy,
+ random_state=random_state,
+ k_neighbors=k_neighbors,
+ n_jobs=n_jobs,
+ )
self.categorical_encoder = categorical_encoder
def _check_X_y(self, X, y):
"""Check should accept strings and not sparse matrices."""
- pass
+ y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
+ X, y = self._validate_data(
+ X,
+ y,
+ reset=True,
+ dtype=None,
+ accept_sparse=["csr", "csc"],
+ )
+ return X, y, binarize_y
def _validate_estimator(self):
"""Force to use precomputed distance matrix."""
- pass
+ super()._validate_estimator()
+ self.nn_k_.set_params(metric="precomputed")
+
+ def _make_samples(self, X_class, klass, y_dtype, nn_indices, n_samples):
+ random_state = check_random_state(self.random_state)
+ # generate sample indices that will be used to generate new samples
+ samples_indices = random_state.choice(
+ np.arange(X_class.shape[0]), size=n_samples, replace=True
+ )
+ # for each drawn samples, select its k-neighbors and generate a sample
+ # where for each feature individually, each category generated is the
+ # most common category
+ X_new = np.squeeze(
+ _mode(X_class[nn_indices[samples_indices]], axis=1).mode, axis=1
+ )
+ y_new = np.full(n_samples, fill_value=klass, dtype=y_dtype)
+ return X_new, y_new
+
+ def _fit_resample(self, X, y):
+ # FIXME: to be removed in 0.12
+ if self.n_jobs is not None:
+ warnings.warn(
+ "The parameter `n_jobs` has been deprecated in 0.10 and will be "
+ "removed in 0.12. You can pass an nearest neighbors estimator where "
+ "`n_jobs` is already set instead.",
+ FutureWarning,
+ )
+
+ if sparse.issparse(X):
+ X_sparse_format = X.format
+ X = X.toarray()
+ warnings.warn(
+ "Passing a sparse matrix to SMOTEN is not really efficient since it is"
+ " converted to a dense array internally.",
+ DataConversionWarning,
+ )
+ else:
+ X_sparse_format = None
+
+ self._validate_estimator()
+
+ X_resampled = [X.copy()]
+ y_resampled = [y.copy()]
+
+ if self.categorical_encoder is None:
+ self.categorical_encoder_ = OrdinalEncoder(dtype=np.int32)
+ else:
+ self.categorical_encoder_ = clone(self.categorical_encoder)
+ X_encoded = self.categorical_encoder_.fit_transform(X)
+
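+        # nominal features have no natural metric: pairwise distances are
+        # computed with the Value Difference Metric and the NN search then
+        # runs on this precomputed matrix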
+ vdm = ValueDifferenceMetric(
+ n_categories=[len(cat) for cat in self.categorical_encoder_.categories_]
+ ).fit(X_encoded, y)
+
+ for class_sample, n_samples in self.sampling_strategy_.items():
+ if n_samples == 0:
+ continue
+ target_class_indices = np.flatnonzero(y == class_sample)
+ X_class = _safe_indexing(X_encoded, target_class_indices)
+
+ X_class_dist = vdm.pairwise(X_class)
+ self.nn_k_.fit(X_class_dist)
+            # the kneighbors search will include the sample itself which is
+            # expected from the original algorithm
+ nn_indices = self.nn_k_.kneighbors(X_class_dist, return_distance=False)
+ X_new, y_new = self._make_samples(
+ X_class, class_sample, y.dtype, nn_indices, n_samples
+ )
+
+ X_new = self.categorical_encoder_.inverse_transform(X_new)
+ X_resampled.append(X_new)
+ y_resampled.append(y_new)
+
+ X_resampled = np.vstack(X_resampled)
+ y_resampled = np.hstack(y_resampled)
+
+ if X_sparse_format == "csr":
+ return sparse.csr_matrix(X_resampled), y_resampled
+ elif X_sparse_format == "csc":
+ return sparse.csc_matrix(X_resampled), y_resampled
+ else:
+ return X_resampled, y_resampled
+
+ def _more_tags(self):
+ return {"X_types": ["2darray", "dataframe", "string"]}
diff --git a/imblearn/over_sampling/_smote/cluster.py b/imblearn/over_sampling/_smote/cluster.py
index 31fb344..2852cfd 100644
--- a/imblearn/over_sampling/_smote/cluster.py
+++ b/imblearn/over_sampling/_smote/cluster.py
@@ -1,12 +1,20 @@
"""SMOTE variant employing some clustering before the generation."""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Fernando Nogueira
+# Christos Aridas
+# License: MIT
+
import math
import numbers
+
import numpy as np
from scipy import sparse
from sklearn.base import clone
from sklearn.cluster import MiniBatchKMeans
from sklearn.metrics import pairwise_distances
from sklearn.utils import _safe_indexing
+
from ...utils import Substitution
from ...utils._docstring import _n_jobs_docstring, _random_state_docstring
from ...utils._param_validation import HasMethods, Interval, StrOptions
@@ -14,9 +22,11 @@ from ..base import BaseOverSampler
from .base import BaseSMOTE
-@Substitution(sampling_strategy=BaseOverSampler.
- _sampling_strategy_docstring, n_jobs=_n_jobs_docstring, random_state=
- _random_state_docstring)
+@Substitution(
+ sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
+ n_jobs=_n_jobs_docstring,
+ random_state=_random_state_docstring,
+)
class KMeansSMOTE(BaseSMOTE):
"""Apply a KMeans clustering before to over-sample using SMOTE.
@@ -135,21 +145,163 @@ class KMeansSMOTE(BaseSMOTE):
>>> print("More 0 samples: %s" % ((y_res == 0).sum() > (y == 0).sum()))
More 0 samples: True
"""
- _parameter_constraints: dict = {**BaseSMOTE._parameter_constraints,
- 'kmeans_estimator': [HasMethods(['fit', 'predict']), Interval(
- numbers.Integral, 1, None, closed='left'), None],
- 'cluster_balance_threshold': [StrOptions({'auto'}), numbers.Real],
- 'density_exponent': [StrOptions({'auto'}), numbers.Real]}
-
- def __init__(self, *, sampling_strategy='auto', random_state=None,
- k_neighbors=2, n_jobs=None, kmeans_estimator=None,
- cluster_balance_threshold='auto', density_exponent='auto'):
- super().__init__(sampling_strategy=sampling_strategy, random_state=
- random_state, k_neighbors=k_neighbors, n_jobs=n_jobs)
+
+ _parameter_constraints: dict = {
+ **BaseSMOTE._parameter_constraints,
+ "kmeans_estimator": [
+ HasMethods(["fit", "predict"]),
+ Interval(numbers.Integral, 1, None, closed="left"),
+ None,
+ ],
+ "cluster_balance_threshold": [StrOptions({"auto"}), numbers.Real],
+ "density_exponent": [StrOptions({"auto"}), numbers.Real],
+ }
+
+ def __init__(
+ self,
+ *,
+ sampling_strategy="auto",
+ random_state=None,
+ k_neighbors=2,
+ n_jobs=None,
+ kmeans_estimator=None,
+ cluster_balance_threshold="auto",
+ density_exponent="auto",
+ ):
+ super().__init__(
+ sampling_strategy=sampling_strategy,
+ random_state=random_state,
+ k_neighbors=k_neighbors,
+ n_jobs=n_jobs,
+ )
self.kmeans_estimator = kmeans_estimator
self.cluster_balance_threshold = cluster_balance_threshold
self.density_exponent = density_exponent
+ def _validate_estimator(self):
+ super()._validate_estimator()
+ if self.kmeans_estimator is None:
+ self.kmeans_estimator_ = MiniBatchKMeans(random_state=self.random_state)
+ elif isinstance(self.kmeans_estimator, int):
+ self.kmeans_estimator_ = MiniBatchKMeans(
+ n_clusters=self.kmeans_estimator,
+ random_state=self.random_state,
+ )
+ else:
+ self.kmeans_estimator_ = clone(self.kmeans_estimator)
+
+ self.cluster_balance_threshold_ = (
+ self.cluster_balance_threshold
+ if self.kmeans_estimator_.n_clusters != 1
+ else -np.inf
+ )
+
def _find_cluster_sparsity(self, X):
"""Compute the cluster sparsity."""
- pass
+ euclidean_distances = pairwise_distances(
+ X, metric="euclidean", n_jobs=self.n_jobs
+ )
+        # zero out the self-distances on the diagonal
+ for ind in range(X.shape[0]):
+ euclidean_distances[ind, ind] = 0
+
+ non_diag_elements = (X.shape[0] ** 2) - X.shape[0]
+ mean_distance = euclidean_distances.sum() / non_diag_elements
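+        # the sparsity grows with the mean pairwise distance and shrinks with
+        # the cluster size; sparser clusters get a larger share of samples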
+ exponent = (
+ math.log(X.shape[0], 1.6) ** 1.8 * 0.16
+ if self.density_exponent == "auto"
+ else self.density_exponent
+ )
+ return (mean_distance**exponent) / X.shape[0]
+
+ def _fit_resample(self, X, y):
+ self._validate_estimator()
+ X_resampled = X.copy()
+ y_resampled = y.copy()
+ total_inp_samples = sum(self.sampling_strategy_.values())
+
+ for class_sample, n_samples in self.sampling_strategy_.items():
+ if n_samples == 0:
+ continue
+
+ X_clusters = self.kmeans_estimator_.fit_predict(X)
+ valid_clusters = []
+ cluster_sparsities = []
+
+            # identify the clusters which satisfy the requirements
+ for cluster_idx in range(self.kmeans_estimator_.n_clusters):
+ cluster_mask = np.flatnonzero(X_clusters == cluster_idx)
+
+ if cluster_mask.size == 0:
+ # empty cluster
+ continue
+
+ X_cluster = _safe_indexing(X, cluster_mask)
+ y_cluster = _safe_indexing(y, cluster_mask)
+
+ cluster_class_mean = (y_cluster == class_sample).mean()
+
+ if self.cluster_balance_threshold_ == "auto":
+ balance_threshold = n_samples / total_inp_samples / 2
+ else:
+ balance_threshold = self.cluster_balance_threshold_
+
+                # skip the cluster when the class to oversample is not
+                # sufficiently represented in it
+ if cluster_class_mean < balance_threshold:
+ continue
+
+ # not enough samples to apply SMOTE
+ anticipated_samples = cluster_class_mean * X_cluster.shape[0]
+ if anticipated_samples < self.nn_k_.n_neighbors:
+ continue
+
+ X_cluster_class = _safe_indexing(
+ X_cluster, np.flatnonzero(y_cluster == class_sample)
+ )
+
+ valid_clusters.append(cluster_mask)
+ cluster_sparsities.append(self._find_cluster_sparsity(X_cluster_class))
+
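+            # each valid cluster receives a share of the synthetic samples
+            # proportional to its sparsity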
+ cluster_sparsities = np.array(cluster_sparsities)
+ cluster_weights = cluster_sparsities / cluster_sparsities.sum()
+
+ if not valid_clusters:
+ raise RuntimeError(
+ f"No clusters found with sufficient samples of "
+ f"class {class_sample}. Try lowering the "
+ f"cluster_balance_threshold or increasing the number of "
+ f"clusters."
+ )
+
+ for valid_cluster_idx, valid_cluster in enumerate(valid_clusters):
+ X_cluster = _safe_indexing(X, valid_cluster)
+ y_cluster = _safe_indexing(y, valid_cluster)
+
+ X_cluster_class = _safe_indexing(
+ X_cluster, np.flatnonzero(y_cluster == class_sample)
+ )
+
+ self.nn_k_.fit(X_cluster_class)
+ nns = self.nn_k_.kneighbors(X_cluster_class, return_distance=False)[
+ :, 1:
+ ]
+
+ cluster_n_samples = int(
+ math.ceil(n_samples * cluster_weights[valid_cluster_idx])
+ )
+
+ X_new, y_new = self._make_samples(
+ X_cluster_class,
+ y.dtype,
+ class_sample,
+ X_cluster_class,
+ nns,
+ cluster_n_samples,
+ 1.0,
+ )
+
+ stack = [np.vstack, sparse.vstack][int(sparse.issparse(X_new))]
+ X_resampled = stack((X_resampled, X_new))
+ y_resampled = np.hstack((y_resampled, y_new))
+
+ return X_resampled, y_resampled
diff --git a/imblearn/over_sampling/_smote/filter.py b/imblearn/over_sampling/_smote/filter.py
index 454b627..2916b68 100644
--- a/imblearn/over_sampling/_smote/filter.py
+++ b/imblearn/over_sampling/_smote/filter.py
@@ -1,11 +1,20 @@
-"""SMOTE variant applying some filtering before the generation process."""
+"""SMOTE variant applying some filtering before the generation process."""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Fernando Nogueira
+# Christos Aridas
+# Dzianis Dudnik
+# License: MIT
+
import numbers
import warnings
+
import numpy as np
from scipy import sparse
from sklearn.base import clone
from sklearn.svm import SVC
from sklearn.utils import _safe_indexing, check_random_state
+
from ...utils import Substitution, check_neighbors_object
from ...utils._docstring import _n_jobs_docstring, _random_state_docstring
from ...utils._param_validation import HasMethods, Interval, StrOptions
@@ -13,9 +22,11 @@ from ..base import BaseOverSampler
from .base import BaseSMOTE
-@Substitution(sampling_strategy=BaseOverSampler.
- _sampling_strategy_docstring, n_jobs=_n_jobs_docstring, random_state=
- _random_state_docstring)
+@Substitution(
+ sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
+ n_jobs=_n_jobs_docstring,
+ random_state=_random_state_docstring,
+)
class BorderlineSMOTE(BaseSMOTE):
"""Over-sampling using Borderline SMOTE.
@@ -145,22 +156,104 @@ class BorderlineSMOTE(BaseSMOTE):
>>> print('Resampled dataset shape %s' % Counter(y_res))
Resampled dataset shape Counter({{0: 900, 1: 900}})
"""
- _parameter_constraints: dict = {**BaseSMOTE._parameter_constraints,
- 'm_neighbors': [Interval(numbers.Integral, 1, None, closed='left'),
- HasMethods(['kneighbors', 'kneighbors_graph'])], 'kind': [
- StrOptions({'borderline-1', 'borderline-2'})]}
-
- def __init__(self, *, sampling_strategy='auto', random_state=None,
- k_neighbors=5, n_jobs=None, m_neighbors=10, kind='borderline-1'):
- super().__init__(sampling_strategy=sampling_strategy, random_state=
- random_state, k_neighbors=k_neighbors, n_jobs=n_jobs)
+
+ _parameter_constraints: dict = {
+ **BaseSMOTE._parameter_constraints,
+ "m_neighbors": [
+ Interval(numbers.Integral, 1, None, closed="left"),
+ HasMethods(["kneighbors", "kneighbors_graph"]),
+ ],
+ "kind": [StrOptions({"borderline-1", "borderline-2"})],
+ }
+
+ def __init__(
+ self,
+ *,
+ sampling_strategy="auto",
+ random_state=None,
+ k_neighbors=5,
+ n_jobs=None,
+ m_neighbors=10,
+ kind="borderline-1",
+ ):
+ super().__init__(
+ sampling_strategy=sampling_strategy,
+ random_state=random_state,
+ k_neighbors=k_neighbors,
+ n_jobs=n_jobs,
+ )
self.m_neighbors = m_neighbors
self.kind = kind
-
-@Substitution(sampling_strategy=BaseOverSampler.
- _sampling_strategy_docstring, n_jobs=_n_jobs_docstring, random_state=
- _random_state_docstring)
+ def _validate_estimator(self):
+ super()._validate_estimator()
+ self.nn_m_ = check_neighbors_object(
+ "m_neighbors", self.m_neighbors, additional_neighbor=1
+ )
+
+ def _fit_resample(self, X, y):
+ # FIXME: to be removed in 0.12
+ if self.n_jobs is not None:
+ warnings.warn(
+ "The parameter `n_jobs` has been deprecated in 0.10 and will be "
+ "removed in 0.12. You can pass an nearest neighbors estimator where "
+ "`n_jobs` is already set instead.",
+ FutureWarning,
+ )
+
+ self._validate_estimator()
+
+ X_resampled = X.copy()
+ y_resampled = y.copy()
+
+ self.in_danger_indices = {}
+ for class_sample, n_samples in self.sampling_strategy_.items():
+ if n_samples == 0:
+ continue
+ target_class_indices = np.flatnonzero(y == class_sample)
+ X_class = _safe_indexing(X, target_class_indices)
+
+ self.nn_m_.fit(X)
+ mask_danger = self._in_danger_noise(
+ self.nn_m_, X_class, class_sample, y, kind="danger"
+ )
+ if not any(mask_danger):
+ continue
+ X_danger = _safe_indexing(X_class, mask_danger)
+ self.in_danger_indices[class_sample] = target_class_indices[mask_danger]
+
+ if self.kind == "borderline-1":
+ X_to_sample_from = X_class # consider the positive class only
+ y_to_check_neighbors = None
+ else: # self.kind == "borderline-2"
+ X_to_sample_from = X # consider the whole dataset
+ y_to_check_neighbors = y
+
+ self.nn_k_.fit(X_to_sample_from)
+ nns = self.nn_k_.kneighbors(X_danger, return_distance=False)[:, 1:]
+ X_new, y_new = self._make_samples(
+ X_danger,
+ y.dtype,
+ class_sample,
+ X_to_sample_from,
+ nns,
+ n_samples,
+ y=y_to_check_neighbors,
+ )
+ if sparse.issparse(X_new):
+ X_resampled = sparse.vstack([X_resampled, X_new])
+ else:
+ X_resampled = np.vstack((X_resampled, X_new))
+ y_resampled = np.hstack((y_resampled, y_new))
+
+ return X_resampled, y_resampled
+
+
+@Substitution(
+ sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
+ n_jobs=_n_jobs_docstring,
+ random_state=_random_state_docstring,
+)
class SVMSMOTE(BaseSMOTE):
"""Over-sampling using SVM-SMOTE.
@@ -295,17 +388,151 @@ class SVMSMOTE(BaseSMOTE):
>>> print('Resampled dataset shape %s' % Counter(y_res))
Resampled dataset shape Counter({{0: 900, 1: 900}})
"""
- _parameter_constraints: dict = {**BaseSMOTE._parameter_constraints,
- 'm_neighbors': [Interval(numbers.Integral, 1, None, closed='left'),
- HasMethods(['kneighbors', 'kneighbors_graph'])], 'svm_estimator': [
- HasMethods(['fit', 'predict']), None], 'out_step': [Interval(
- numbers.Real, 0, 1, closed='both')]}
-
- def __init__(self, *, sampling_strategy='auto', random_state=None,
- k_neighbors=5, n_jobs=None, m_neighbors=10, svm_estimator=None,
- out_step=0.5):
- super().__init__(sampling_strategy=sampling_strategy, random_state=
- random_state, k_neighbors=k_neighbors, n_jobs=n_jobs)
+
+ _parameter_constraints: dict = {
+ **BaseSMOTE._parameter_constraints,
+ "m_neighbors": [
+ Interval(numbers.Integral, 1, None, closed="left"),
+ HasMethods(["kneighbors", "kneighbors_graph"]),
+ ],
+ "svm_estimator": [HasMethods(["fit", "predict"]), None],
+ "out_step": [Interval(numbers.Real, 0, 1, closed="both")],
+ }
+
+ def __init__(
+ self,
+ *,
+ sampling_strategy="auto",
+ random_state=None,
+ k_neighbors=5,
+ n_jobs=None,
+ m_neighbors=10,
+ svm_estimator=None,
+ out_step=0.5,
+ ):
+ super().__init__(
+ sampling_strategy=sampling_strategy,
+ random_state=random_state,
+ k_neighbors=k_neighbors,
+ n_jobs=n_jobs,
+ )
self.m_neighbors = m_neighbors
self.svm_estimator = svm_estimator
self.out_step = out_step
+
+ def _validate_estimator(self):
+ super()._validate_estimator()
+ self.nn_m_ = check_neighbors_object(
+ "m_neighbors", self.m_neighbors, additional_neighbor=1
+ )
+
+ if self.svm_estimator is None:
+ self.svm_estimator_ = SVC(gamma="scale", random_state=self.random_state)
+ else:
+ self.svm_estimator_ = clone(self.svm_estimator)
+
+ def _fit_resample(self, X, y):
+ # FIXME: to be removed in 0.12
+ if self.n_jobs is not None:
+ warnings.warn(
+ "The parameter `n_jobs` has been deprecated in 0.10 and will be "
+ "removed in 0.12. You can pass an nearest neighbors estimator where "
+ "`n_jobs` is already set instead.",
+ FutureWarning,
+ )
+
+ self._validate_estimator()
+ random_state = check_random_state(self.random_state)
+ X_resampled = X.copy()
+ y_resampled = y.copy()
+
+ for class_sample, n_samples in self.sampling_strategy_.items():
+ if n_samples == 0:
+ continue
+ target_class_indices = np.flatnonzero(y == class_sample)
+ X_class = _safe_indexing(X, target_class_indices)
+
+ self.svm_estimator_.fit(X, y)
+ if not hasattr(self.svm_estimator_, "support_"):
+ raise RuntimeError(
+ "`svm_estimator` is required to exposed a `support_` fitted "
+ "attribute. Such estimator belongs to the familly of Support "
+ "Vector Machine."
+ )
+ support_index = self.svm_estimator_.support_[
+ y[self.svm_estimator_.support_] == class_sample
+ ]
+ support_vector = _safe_indexing(X, support_index)
+
+ self.nn_m_.fit(X)
+ noise_bool = self._in_danger_noise(
+ self.nn_m_, support_vector, class_sample, y, kind="noise"
+ )
+ support_vector = _safe_indexing(
+ support_vector, np.flatnonzero(np.logical_not(noise_bool))
+ )
+ if support_vector.shape[0] == 0:
+ raise ValueError(
+ "All support vectors are considered as noise. SVM-SMOTE is not "
+ "adapted to your dataset. Try another SMOTE variant."
+ )
+ danger_bool = self._in_danger_noise(
+ self.nn_m_, support_vector, class_sample, y, kind="danger"
+ )
+ safety_bool = np.logical_not(danger_bool)
+
+ self.nn_k_.fit(X_class)
+ fractions = random_state.beta(10, 10)
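+            # split the generation budget between interpolation (danger SVs,
+            # positive step) and extrapolation (safe SVs, negative out_step);
+            # Beta(10, 10) draws a fraction concentrated around 0.5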
+ n_generated_samples = int(fractions * (n_samples + 1))
+ if np.count_nonzero(danger_bool) > 0:
+ nns = self.nn_k_.kneighbors(
+ _safe_indexing(support_vector, np.flatnonzero(danger_bool)),
+ return_distance=False,
+ )[:, 1:]
+
+ X_new_1, y_new_1 = self._make_samples(
+ _safe_indexing(support_vector, np.flatnonzero(danger_bool)),
+ y.dtype,
+ class_sample,
+ X_class,
+ nns,
+ n_generated_samples,
+ step_size=1.0,
+ )
+
+ if np.count_nonzero(safety_bool) > 0:
+ nns = self.nn_k_.kneighbors(
+ _safe_indexing(support_vector, np.flatnonzero(safety_bool)),
+ return_distance=False,
+ )[:, 1:]
+
+ X_new_2, y_new_2 = self._make_samples(
+ _safe_indexing(support_vector, np.flatnonzero(safety_bool)),
+ y.dtype,
+ class_sample,
+ X_class,
+ nns,
+ n_samples - n_generated_samples,
+ step_size=-self.out_step,
+ )
+
+ if np.count_nonzero(danger_bool) > 0 and np.count_nonzero(safety_bool) > 0:
+ if sparse.issparse(X_resampled):
+ X_resampled = sparse.vstack([X_resampled, X_new_1, X_new_2])
+ else:
+ X_resampled = np.vstack((X_resampled, X_new_1, X_new_2))
+ y_resampled = np.concatenate((y_resampled, y_new_1, y_new_2), axis=0)
+ elif np.count_nonzero(danger_bool) == 0:
+ if sparse.issparse(X_resampled):
+ X_resampled = sparse.vstack([X_resampled, X_new_2])
+ else:
+ X_resampled = np.vstack((X_resampled, X_new_2))
+ y_resampled = np.concatenate((y_resampled, y_new_2), axis=0)
+ elif np.count_nonzero(safety_bool) == 0:
+ if sparse.issparse(X_resampled):
+ X_resampled = sparse.vstack([X_resampled, X_new_1])
+ else:
+ X_resampled = np.vstack((X_resampled, X_new_1))
+ y_resampled = np.concatenate((y_resampled, y_new_1), axis=0)
+
+ return X_resampled, y_resampled
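
The `_fit_resample` above fits the SVM on the full dataset, keeps the minority-class support vectors, drops the ones flagged as noise, then interpolates around "danger" vectors (`step_size=1.0`) and extrapolates from "safe" ones (`step_size=-self.out_step`). A minimal, hedged usage sketch of the public `SVMSMOTE` API this implements (the dataset below is illustrative, not from the test suite):

```python
from collections import Counter

from sklearn.datasets import make_classification

from imblearn.over_sampling import SVMSMOTE

# Illustrative imbalanced dataset; class_sep=1.5 keeps support vectors from
# being flagged as noise (which would raise the ValueError shown above).
X, y = make_classification(
    n_samples=500, weights=[0.9, 0.1], class_sep=1.5, random_state=0
)
print(Counter(y))  # roughly {0: 450, 1: 50}

# out_step=0.5 is the default extrapolation step shown in __init__ above.
X_res, y_res = SVMSMOTE(out_step=0.5, random_state=0).fit_resample(X, y)
print(Counter(y_res))  # classes balanced by default
```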
diff --git a/imblearn/over_sampling/_smote/tests/test_borderline_smote.py b/imblearn/over_sampling/_smote/tests/test_borderline_smote.py
index b11e0ea..0d85c4d 100644
--- a/imblearn/over_sampling/_smote/tests/test_borderline_smote.py
+++ b/imblearn/over_sampling/_smote/tests/test_borderline_smote.py
@@ -1,17 +1,36 @@
from collections import Counter
+
import pytest
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.utils._testing import assert_allclose, assert_array_equal
+
from imblearn.over_sampling import BorderlineSMOTE
-@pytest.mark.parametrize('kind', ['borderline-1', 'borderline-2'])
+@pytest.mark.parametrize("kind", ["borderline-1", "borderline-2"])
def test_borderline_smote_no_in_danger_samples(kind):
"""Check that the algorithm behave properly even on a dataset without any sample
in danger.
"""
- pass
+ X, y = make_classification(
+ n_samples=500,
+ n_features=2,
+ n_informative=2,
+ n_redundant=0,
+ n_repeated=0,
+ n_clusters_per_class=1,
+ n_classes=3,
+ weights=[0.1, 0.2, 0.7],
+ class_sep=1.5,
+ random_state=1,
+ )
+ smote = BorderlineSMOTE(kind=kind, m_neighbors=3, k_neighbors=5, random_state=0)
+ X_res, y_res = smote.fit_resample(X, y)
+
+ assert_allclose(X, X_res)
+ assert_allclose(y, y_res)
+ assert not smote.in_danger_indices
def test_borderline_smote_kind():
@@ -21,4 +40,71 @@ def test_borderline_smote_kind():
"borderline-1". We generate an example where a logistic regression will perform
worse on "borderline-2" than on "borderline-1".
"""
- pass
+ X, y = make_classification(
+ n_samples=500,
+ n_features=2,
+ n_informative=2,
+ n_redundant=0,
+ n_repeated=0,
+ n_clusters_per_class=1,
+ n_classes=3,
+ weights=[0.1, 0.2, 0.7],
+ class_sep=1.0,
+ random_state=1,
+ )
+ smote = BorderlineSMOTE(
+ kind="borderline-1", m_neighbors=9, k_neighbors=5, random_state=0
+ )
+ X_res_borderline_1, y_res_borderline_1 = smote.fit_resample(X, y)
+ smote.set_params(kind="borderline-2")
+ X_res_borderline_2, y_res_borderline_2 = smote.fit_resample(X, y)
+
+ score_borderline_1 = (
+ LogisticRegression()
+ .fit(X_res_borderline_1, y_res_borderline_1)
+ .score(X_res_borderline_1, y_res_borderline_1)
+ )
+ score_borderline_2 = (
+ LogisticRegression()
+ .fit(X_res_borderline_2, y_res_borderline_2)
+ .score(X_res_borderline_2, y_res_borderline_2)
+ )
+ assert score_borderline_1 > score_borderline_2
+
+
+def test_borderline_smote_in_danger():
+ X, y = make_classification(
+ n_samples=500,
+ n_features=2,
+ n_informative=2,
+ n_redundant=0,
+ n_repeated=0,
+ n_clusters_per_class=1,
+ n_classes=3,
+ weights=[0.1, 0.2, 0.7],
+ class_sep=0.8,
+ random_state=1,
+ )
+ smote = BorderlineSMOTE(
+ kind="borderline-1",
+ m_neighbors=9,
+ k_neighbors=5,
+ random_state=0,
+ )
+ _, y_res_1 = smote.fit_resample(X, y)
+ in_danger_indices_borderline_1 = smote.in_danger_indices
+ smote.set_params(kind="borderline-2")
+ _, y_res_2 = smote.fit_resample(X, y)
+ in_danger_indices_borderline_2 = smote.in_danger_indices
+
+ for key1, key2 in zip(
+ in_danger_indices_borderline_1, in_danger_indices_borderline_2
+ ):
+ assert_array_equal(
+ in_danger_indices_borderline_1[key1], in_danger_indices_borderline_2[key2]
+ )
+ assert len(in_danger_indices_borderline_1) == len(in_danger_indices_borderline_2)
+ counter = Counter(y_res_1)
+ assert counter[0] == counter[1] == counter[2]
+ counter = Counter(y_res_2)
+ assert counter[0] == counter[1] == counter[2]
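
As a companion to the tests above, a hedged sketch contrasting the two `kind` values: "borderline-1" interpolates in-danger minority samples only with minority neighbours, while "borderline-2" also uses neighbours from other classes, which tends to blur the decision boundary (hence the accuracy ordering asserted in `test_borderline_smote_kind`). Dataset parameters below are illustrative:

```python
from collections import Counter

from sklearn.datasets import make_classification

from imblearn.over_sampling import BorderlineSMOTE

X, y = make_classification(
    n_samples=500,
    n_features=2,
    n_informative=2,
    n_redundant=0,
    n_clusters_per_class=1,
    n_classes=3,
    weights=[0.1, 0.2, 0.7],
    class_sep=0.8,  # enough class overlap to produce in-danger samples
    random_state=1,
)
for kind in ("borderline-1", "borderline-2"):
    smote = BorderlineSMOTE(kind=kind, m_neighbors=9, k_neighbors=5, random_state=0)
    _, y_res = smote.fit_resample(X, y)
    print(kind, Counter(y_res))  # both variants balance the three classes
```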
diff --git a/imblearn/over_sampling/_smote/tests/test_kmeans_smote.py b/imblearn/over_sampling/_smote/tests/test_kmeans_smote.py
index e69baba..71fa47c 100644
--- a/imblearn/over_sampling/_smote/tests/test_kmeans_smote.py
+++ b/imblearn/over_sampling/_smote/tests/test_kmeans_smote.py
@@ -4,4 +4,105 @@ from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.datasets import make_classification
from sklearn.neighbors import NearestNeighbors
from sklearn.utils._testing import assert_allclose, assert_array_equal
+
from imblearn.over_sampling import SMOTE, KMeansSMOTE
+
+
+@pytest.fixture
+def data():
+ X = np.array(
+ [
+ [0.11622591, -0.0317206],
+ [0.77481731, 0.60935141],
+ [1.25192108, -0.22367336],
+ [0.53366841, -0.30312976],
+ [1.52091956, -0.49283504],
+ [-0.28162401, -2.10400981],
+ [0.83680821, 1.72827342],
+ [0.3084254, 0.33299982],
+ [0.70472253, -0.73309052],
+ [0.28893132, -0.38761769],
+ [1.15514042, 0.0129463],
+ [0.88407872, 0.35454207],
+ [1.31301027, -0.92648734],
+ [-1.11515198, -0.93689695],
+ [-0.18410027, -0.45194484],
+ [0.9281014, 0.53085498],
+ [-0.14374509, 0.27370049],
+ [-0.41635887, -0.38299653],
+ [0.08711622, 0.93259929],
+ [1.70580611, -0.11219234],
+ ]
+ )
+ y = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0])
+ return X, y
+
+
+@pytest.mark.filterwarnings("ignore:The default value of `n_init` will change")
+def test_kmeans_smote(data):
+ X, y = data
+ kmeans_smote = KMeansSMOTE(
+ kmeans_estimator=1,
+ random_state=42,
+ cluster_balance_threshold=0.0,
+ k_neighbors=5,
+ )
+ smote = SMOTE(random_state=42)
+
+ X_res_1, y_res_1 = kmeans_smote.fit_resample(X, y)
+ X_res_2, y_res_2 = smote.fit_resample(X, y)
+
+ assert_allclose(X_res_1, X_res_2)
+ assert_array_equal(y_res_1, y_res_2)
+
+ assert kmeans_smote.nn_k_.n_neighbors == 6
+ assert kmeans_smote.kmeans_estimator_.n_clusters == 1
+ assert "batch_size" in kmeans_smote.kmeans_estimator_.get_params()
+
+
+@pytest.mark.filterwarnings("ignore:The default value of `n_init` will change")
+@pytest.mark.parametrize("k_neighbors", [2, NearestNeighbors(n_neighbors=3)])
+@pytest.mark.parametrize(
+ "kmeans_estimator",
+ [
+ 3,
+ KMeans(n_clusters=3, n_init=1, random_state=42),
+ MiniBatchKMeans(n_clusters=3, n_init=1, random_state=42),
+ ],
+)
+def test_sample_kmeans_custom(data, k_neighbors, kmeans_estimator):
+ X, y = data
+ kmeans_smote = KMeansSMOTE(
+ random_state=42,
+ kmeans_estimator=kmeans_estimator,
+ k_neighbors=k_neighbors,
+ )
+ X_resampled, y_resampled = kmeans_smote.fit_resample(X, y)
+ assert X_resampled.shape == (24, 2)
+ assert y_resampled.shape == (24,)
+
+ assert kmeans_smote.nn_k_.n_neighbors == 3
+ assert kmeans_smote.kmeans_estimator_.n_clusters == 3
+
+
+@pytest.mark.filterwarnings("ignore:The default value of `n_init` will change")
+def test_sample_kmeans_not_enough_clusters(data):
+ X, y = data
+ smote = KMeansSMOTE(cluster_balance_threshold=10, random_state=42)
+ with pytest.raises(RuntimeError):
+ smote.fit_resample(X, y)
+
+
+@pytest.mark.parametrize("density_exponent", ["auto", 10])
+@pytest.mark.parametrize("cluster_balance_threshold", ["auto", 0.1])
+def test_sample_kmeans_density_estimation(density_exponent, cluster_balance_threshold):
+ X, y = make_classification(
+ n_samples=10_000, n_classes=2, weights=[0.3, 0.7], random_state=42
+ )
+ smote = KMeansSMOTE(
+ kmeans_estimator=MiniBatchKMeans(n_init=1, random_state=42),
+ random_state=0,
+ density_exponent=density_exponent,
+ cluster_balance_threshold=cluster_balance_threshold,
+ )
+ smote.fit_resample(X, y)
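
For orientation, the flow these tests exercise: KMeansSMOTE first clusters the feature space, keeps the clusters whose minority share passes `cluster_balance_threshold`, and applies SMOTE within each kept cluster; when no cluster qualifies, the `RuntimeError` asserted in `test_sample_kmeans_not_enough_clusters` is raised. A hedged sketch (the estimator choice mirrors the tests, the rest is illustrative):

```python
from sklearn.cluster import MiniBatchKMeans
from sklearn.datasets import make_classification

from imblearn.over_sampling import KMeansSMOTE

X, y = make_classification(
    n_samples=10_000, n_classes=2, weights=[0.3, 0.7], random_state=42
)
smote = KMeansSMOTE(
    kmeans_estimator=MiniBatchKMeans(n_init=1, random_state=42),
    # "auto" lets the sampler derive the minority-share threshold a cluster
    # must pass before synthetic samples are generated inside it.
    cluster_balance_threshold="auto",
    random_state=0,
)
X_res, y_res = smote.fit_resample(X, y)
```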
diff --git a/imblearn/over_sampling/_smote/tests/test_smote.py b/imblearn/over_sampling/_smote/tests/test_smote.py
index f8343f1..060ac8c 100644
--- a/imblearn/over_sampling/_smote/tests/test_smote.py
+++ b/imblearn/over_sampling/_smote/tests/test_smote.py
@@ -1,17 +1,149 @@
"""Test the module SMOTE."""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.utils._testing import assert_allclose, assert_array_equal
+
from imblearn.over_sampling import SMOTE
+
RND_SEED = 0
-X = np.array([[0.11622591, -0.0317206], [0.77481731, 0.60935141], [
- 1.25192108, -0.22367336], [0.53366841, -0.30312976], [1.52091956, -
- 0.49283504], [-0.28162401, -2.10400981], [0.83680821, 1.72827342], [
- 0.3084254, 0.33299982], [0.70472253, -0.73309052], [0.28893132, -
- 0.38761769], [1.15514042, 0.0129463], [0.88407872, 0.35454207], [
- 1.31301027, -0.92648734], [-1.11515198, -0.93689695], [-0.18410027, -
- 0.45194484], [0.9281014, 0.53085498], [-0.14374509, 0.27370049], [-
- 0.41635887, -0.38299653], [0.08711622, 0.93259929], [1.70580611, -
- 0.11219234]])
+X = np.array(
+ [
+ [0.11622591, -0.0317206],
+ [0.77481731, 0.60935141],
+ [1.25192108, -0.22367336],
+ [0.53366841, -0.30312976],
+ [1.52091956, -0.49283504],
+ [-0.28162401, -2.10400981],
+ [0.83680821, 1.72827342],
+ [0.3084254, 0.33299982],
+ [0.70472253, -0.73309052],
+ [0.28893132, -0.38761769],
+ [1.15514042, 0.0129463],
+ [0.88407872, 0.35454207],
+ [1.31301027, -0.92648734],
+ [-1.11515198, -0.93689695],
+ [-0.18410027, -0.45194484],
+ [0.9281014, 0.53085498],
+ [-0.14374509, 0.27370049],
+ [-0.41635887, -0.38299653],
+ [0.08711622, 0.93259929],
+ [1.70580611, -0.11219234],
+ ]
+)
Y = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0])
-R_TOL = 0.0001
+R_TOL = 1e-4
+
+
+def test_sample_regular():
+ smote = SMOTE(random_state=RND_SEED)
+ X_resampled, y_resampled = smote.fit_resample(X, Y)
+ X_gt = np.array(
+ [
+ [0.11622591, -0.0317206],
+ [0.77481731, 0.60935141],
+ [1.25192108, -0.22367336],
+ [0.53366841, -0.30312976],
+ [1.52091956, -0.49283504],
+ [-0.28162401, -2.10400981],
+ [0.83680821, 1.72827342],
+ [0.3084254, 0.33299982],
+ [0.70472253, -0.73309052],
+ [0.28893132, -0.38761769],
+ [1.15514042, 0.0129463],
+ [0.88407872, 0.35454207],
+ [1.31301027, -0.92648734],
+ [-1.11515198, -0.93689695],
+ [-0.18410027, -0.45194484],
+ [0.9281014, 0.53085498],
+ [-0.14374509, 0.27370049],
+ [-0.41635887, -0.38299653],
+ [0.08711622, 0.93259929],
+ [1.70580611, -0.11219234],
+ [0.29307743, -0.14670439],
+ [0.84976473, -0.15570176],
+ [0.61319159, -0.11571668],
+ [0.66052536, -0.28246517],
+ ]
+ )
+ y_gt = np.array(
+ [0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0]
+ )
+ assert_allclose(X_resampled, X_gt, rtol=R_TOL)
+ assert_array_equal(y_resampled, y_gt)
+
+
+def test_sample_regular_half():
+ sampling_strategy = {0: 9, 1: 12}
+ smote = SMOTE(sampling_strategy=sampling_strategy, random_state=RND_SEED)
+ X_resampled, y_resampled = smote.fit_resample(X, Y)
+ X_gt = np.array(
+ [
+ [0.11622591, -0.0317206],
+ [0.77481731, 0.60935141],
+ [1.25192108, -0.22367336],
+ [0.53366841, -0.30312976],
+ [1.52091956, -0.49283504],
+ [-0.28162401, -2.10400981],
+ [0.83680821, 1.72827342],
+ [0.3084254, 0.33299982],
+ [0.70472253, -0.73309052],
+ [0.28893132, -0.38761769],
+ [1.15514042, 0.0129463],
+ [0.88407872, 0.35454207],
+ [1.31301027, -0.92648734],
+ [-1.11515198, -0.93689695],
+ [-0.18410027, -0.45194484],
+ [0.9281014, 0.53085498],
+ [-0.14374509, 0.27370049],
+ [-0.41635887, -0.38299653],
+ [0.08711622, 0.93259929],
+ [1.70580611, -0.11219234],
+ [0.36784496, -0.1953161],
+ ]
+ )
+ y_gt = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0])
+ assert_allclose(X_resampled, X_gt, rtol=R_TOL)
+ assert_array_equal(y_resampled, y_gt)
+
+
+def test_sample_regular_with_nn():
+ nn_k = NearestNeighbors(n_neighbors=6)
+ smote = SMOTE(random_state=RND_SEED, k_neighbors=nn_k)
+ X_resampled, y_resampled = smote.fit_resample(X, Y)
+ X_gt = np.array(
+ [
+ [0.11622591, -0.0317206],
+ [0.77481731, 0.60935141],
+ [1.25192108, -0.22367336],
+ [0.53366841, -0.30312976],
+ [1.52091956, -0.49283504],
+ [-0.28162401, -2.10400981],
+ [0.83680821, 1.72827342],
+ [0.3084254, 0.33299982],
+ [0.70472253, -0.73309052],
+ [0.28893132, -0.38761769],
+ [1.15514042, 0.0129463],
+ [0.88407872, 0.35454207],
+ [1.31301027, -0.92648734],
+ [-1.11515198, -0.93689695],
+ [-0.18410027, -0.45194484],
+ [0.9281014, 0.53085498],
+ [-0.14374509, 0.27370049],
+ [-0.41635887, -0.38299653],
+ [0.08711622, 0.93259929],
+ [1.70580611, -0.11219234],
+ [0.29307743, -0.14670439],
+ [0.84976473, -0.15570176],
+ [0.61319159, -0.11571668],
+ [0.66052536, -0.28246517],
+ ]
+ )
+ y_gt = np.array(
+ [0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0]
+ )
+ assert_allclose(X_resampled, X_gt, rtol=R_TOL)
+ assert_array_equal(y_resampled, y_gt)
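
`test_sample_regular_with_nn` passes a `NearestNeighbors` object instead of an int; the sketch below spells out the off-by-one convention, inferred from the `[:, 1:]` slicing visible in the SVMSMOTE code earlier (each query point comes back as its own first neighbour and is discarded):

```python
from sklearn.datasets import make_classification
from sklearn.neighbors import NearestNeighbors

from imblearn.over_sampling import SMOTE

X, y = make_classification(n_samples=200, weights=[0.8, 0.2], random_state=0)

# k_neighbors=5 (the int form) corresponds to NearestNeighbors(n_neighbors=6):
# the extra neighbour is the query point itself, which gets sliced off.
nn_k = NearestNeighbors(n_neighbors=6, n_jobs=2)
X_res, y_res = SMOTE(random_state=0, k_neighbors=nn_k).fit_resample(X, y)
```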
diff --git a/imblearn/over_sampling/_smote/tests/test_smote_nc.py b/imblearn/over_sampling/_smote/tests/test_smote_nc.py
index 06080ca..1314ea9 100644
--- a/imblearn/over_sampling/_smote/tests/test_smote_nc.py
+++ b/imblearn/over_sampling/_smote/tests/test_smote_nc.py
@@ -1,5 +1,11 @@
"""Test the module SMOTENC."""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# Dzianis Dudnik
+# License: MIT
+
from collections import Counter
+
import numpy as np
import pytest
import sklearn
@@ -8,26 +14,301 @@ from sklearn.datasets import make_classification
from sklearn.preprocessing import OneHotEncoder
from sklearn.utils._testing import assert_allclose, assert_array_equal
from sklearn.utils.fixes import parse_version
+
from imblearn.over_sampling import SMOTENC
-from imblearn.utils.estimator_checks import _set_checking_parameters, check_param_validation
+from imblearn.utils.estimator_checks import (
+ _set_checking_parameters,
+ check_param_validation,
+)
+
sklearn_version = parse_version(sklearn.__version__)
+def data_heterogneous_ordered():
+ rng = np.random.RandomState(42)
+ X = np.empty((30, 4), dtype=object)
+    # create 2 random continuous features
+ X[:, :2] = rng.randn(30, 2)
+ # create a categorical feature using some string
+ X[:, 2] = rng.choice(["a", "b", "c"], size=30).astype(object)
+ # create a categorical feature using some integer
+ X[:, 3] = rng.randint(3, size=30)
+ y = np.array([0] * 10 + [1] * 20)
+ # return the categories
+ return X, y, [2, 3]
+
+
+def data_heterogneous_unordered():
+ rng = np.random.RandomState(42)
+ X = np.empty((30, 4), dtype=object)
+    # create 2 random continuous features
+ X[:, [1, 2]] = rng.randn(30, 2)
+ # create a categorical feature using some string
+ X[:, 0] = rng.choice(["a", "b", "c"], size=30).astype(object)
+ # create a categorical feature using some integer
+ X[:, 3] = rng.randint(3, size=30)
+ y = np.array([0] * 10 + [1] * 20)
+ # return the categories
+ return X, y, [0, 3]
+
+
+def data_heterogneous_masked():
+ rng = np.random.RandomState(42)
+ X = np.empty((30, 4), dtype=object)
+    # create 2 random continuous features
+ X[:, [1, 2]] = rng.randn(30, 2)
+ # create a categorical feature using some string
+ X[:, 0] = rng.choice(["a", "b", "c"], size=30).astype(object)
+ # create a categorical feature using some integer
+ X[:, 3] = rng.randint(3, size=30)
+ y = np.array([0] * 10 + [1] * 20)
+ # return the categories
+ return X, y, [True, False, False, True]
+
+
+def data_heterogneous_unordered_multiclass():
+ rng = np.random.RandomState(42)
+ X = np.empty((50, 4), dtype=object)
+    # create 2 random continuous features
+ X[:, [1, 2]] = rng.randn(50, 2)
+ # create a categorical feature using some string
+ X[:, 0] = rng.choice(["a", "b", "c"], size=50).astype(object)
+ # create a categorical feature using some integer
+ X[:, 3] = rng.randint(3, size=50)
+ y = np.array([0] * 10 + [1] * 15 + [2] * 25)
+ # return the categories
+ return X, y, [0, 3]
+
+
+def data_sparse(format):
+ rng = np.random.RandomState(42)
+ X = np.empty((30, 4), dtype=np.float64)
+    # create 2 random continuous features
+    X[:, [1, 2]] = rng.randn(30, 2)
+    # create a categorical feature using some integer (sparse matrices
+    # cannot store strings)
+    X[:, 0] = rng.randint(3, size=30)
+    # create another categorical feature using some integer
+ X[:, 3] = rng.randint(3, size=30)
+ y = np.array([0] * 10 + [1] * 20)
+ X = sparse.csr_matrix(X) if format == "csr" else sparse.csc_matrix(X)
+ return X, y, [0, 3]
+
+
+def test_smotenc_error():
+ X, y, _ = data_heterogneous_unordered()
+ categorical_features = [0, 10]
+ smote = SMOTENC(random_state=0, categorical_features=categorical_features)
+ with pytest.raises(ValueError, match="all features must be in"):
+ smote.fit_resample(X, y)
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ data_heterogneous_ordered(),
+ data_heterogneous_unordered(),
+ data_heterogneous_masked(),
+ data_sparse("csr"),
+ data_sparse("csc"),
+ ],
+)
+def test_smotenc(data):
+ X, y, categorical_features = data
+ smote = SMOTENC(random_state=0, categorical_features=categorical_features)
+ X_resampled, y_resampled = smote.fit_resample(X, y)
+
+ assert X_resampled.dtype == X.dtype
+
+ categorical_features = np.array(categorical_features)
+ if categorical_features.dtype == bool:
+ categorical_features = np.flatnonzero(categorical_features)
+ for cat_idx in categorical_features:
+ if sparse.issparse(X):
+ assert set(X[:, cat_idx].data) == set(X_resampled[:, cat_idx].data)
+ assert X[:, cat_idx].dtype == X_resampled[:, cat_idx].dtype
+ else:
+ assert set(X[:, cat_idx]) == set(X_resampled[:, cat_idx])
+ assert X[:, cat_idx].dtype == X_resampled[:, cat_idx].dtype
+
+ assert isinstance(smote.median_std_, dict)
+
+
+# part of the common tests which apply to SMOTE-NC even though it is not
+# default constructible
+def test_smotenc_check_target_type():
+ X, _, categorical_features = data_heterogneous_unordered()
+ y = np.linspace(0, 1, 30)
+ smote = SMOTENC(categorical_features=categorical_features, random_state=0)
+ with pytest.raises(ValueError, match="Unknown label type"):
+ smote.fit_resample(X, y)
+ rng = np.random.RandomState(42)
+ y = rng.randint(2, size=(20, 3))
+ msg = "Multilabel and multioutput targets are not supported."
+ with pytest.raises(ValueError, match=msg):
+ smote.fit_resample(X, y)
+
+
+def test_smotenc_samplers_one_label():
+ X, _, categorical_features = data_heterogneous_unordered()
+ y = np.zeros(30)
+ smote = SMOTENC(categorical_features=categorical_features, random_state=0)
+ with pytest.raises(ValueError, match="needs to have more than 1 class"):
+ smote.fit(X, y)
+
+
+def test_smotenc_fit():
+ X, y, categorical_features = data_heterogneous_unordered()
+ smote = SMOTENC(categorical_features=categorical_features, random_state=0)
+ smote.fit_resample(X, y)
+ assert hasattr(
+ smote, "sampling_strategy_"
+ ), "No fitted attribute sampling_strategy_"
+
+
+def test_smotenc_fit_resample():
+ X, y, categorical_features = data_heterogneous_unordered()
+ target_stats = Counter(y)
+ smote = SMOTENC(categorical_features=categorical_features, random_state=0)
+ _, y_res = smote.fit_resample(X, y)
+ n_samples = max(target_stats.values())
+ assert all(value >= n_samples for value in Counter(y_res).values())
+
+
+def test_smotenc_fit_resample_sampling_strategy():
+ X, y, categorical_features = data_heterogneous_unordered_multiclass()
+ expected_stat = Counter(y)[1]
+ smote = SMOTENC(categorical_features=categorical_features, random_state=0)
+ sampling_strategy = {2: 25, 0: 25}
+ smote.set_params(sampling_strategy=sampling_strategy)
+ X_res, y_res = smote.fit_resample(X, y)
+ assert Counter(y_res)[1] == expected_stat
+
+
+def test_smotenc_pandas():
+ pd = pytest.importorskip("pandas")
+ # Check that the samplers handle pandas dataframe and pandas series
+ X, y, categorical_features = data_heterogneous_unordered_multiclass()
+ X_pd = pd.DataFrame(X)
+ smote = SMOTENC(categorical_features=categorical_features, random_state=0)
+ X_res_pd, y_res_pd = smote.fit_resample(X_pd, y)
+ X_res, y_res = smote.fit_resample(X, y)
+ assert_array_equal(X_res_pd.to_numpy(), X_res)
+ assert_allclose(y_res_pd, y_res)
+ assert set(smote.median_std_.keys()) == {0, 1}
+
+
+def test_smotenc_preserve_dtype():
+ X, y = make_classification(
+ n_samples=50,
+ n_classes=3,
+ n_informative=4,
+ weights=[0.2, 0.3, 0.5],
+ random_state=0,
+ )
+ # Cast X and y to not default dtype
+ X = X.astype(np.float32)
+ y = y.astype(np.int32)
+ smote = SMOTENC(categorical_features=[1], random_state=0)
+ X_res, y_res = smote.fit_resample(X, y)
+ assert X.dtype == X_res.dtype, "X dtype is not preserved"
+ assert y.dtype == y_res.dtype, "y dtype is not preserved"
+
+
+@pytest.mark.parametrize("categorical_features", [[True, True, True], [0, 1, 2]])
+def test_smotenc_raising_error_all_categorical(categorical_features):
+ X, y = make_classification(
+ n_features=3,
+ n_informative=1,
+ n_redundant=1,
+ n_repeated=0,
+ n_clusters_per_class=1,
+ )
+ smote = SMOTENC(categorical_features=categorical_features)
+ err_msg = "SMOTE-NC is not designed to work only with categorical features"
+ with pytest.raises(ValueError, match=err_msg):
+ smote.fit_resample(X, y)
+
+
+def test_smote_nc_with_null_median_std():
+ # Non-regression test for #662
+ # https://github.com/scikit-learn-contrib/imbalanced-learn/issues/662
+ data = np.array(
+ [
+ [1, 2, 1, "A"],
+ [2, 1, 2, "A"],
+ [2, 1, 2, "A"],
+ [1, 2, 3, "B"],
+ [1, 2, 4, "C"],
+ [1, 2, 5, "C"],
+ [1, 2, 4, "C"],
+ [1, 2, 4, "C"],
+ [1, 2, 4, "C"],
+ ],
+ dtype="object",
+ )
+ labels = np.array(
+ [
+ "class_1",
+ "class_1",
+ "class_1",
+ "class_1",
+ "class_2",
+ "class_2",
+ "class_3",
+ "class_3",
+ "class_3",
+ ],
+ dtype=object,
+ )
+ smote = SMOTENC(categorical_features=[3], k_neighbors=1, random_state=0)
+ X_res, y_res = smote.fit_resample(data, labels)
+    # check that the categorical feature is not random but corresponds to
+    # the categories seen in the minority class samples
+ assert_array_equal(X_res[-3:, -1], np.array(["C", "C", "C"], dtype=object))
+ assert smote.median_std_ == {"class_2": 0.0, "class_3": 0.0}
+
+
def test_smotenc_categorical_encoder():
"""Check that we can pass our own categorical encoder."""
- pass
+ # TODO: only use `sparse_output` when sklearn >= 1.2
+ param = "sparse" if sklearn_version < parse_version("1.2") else "sparse_output"
+
+ X, y, categorical_features = data_heterogneous_unordered()
+ smote = SMOTENC(categorical_features=categorical_features, random_state=0)
+ smote.fit_resample(X, y)
+ assert getattr(smote.categorical_encoder_, param) is True
+
+ encoder = OneHotEncoder()
+ encoder.set_params(**{param: False})
+ smote.set_params(categorical_encoder=encoder).fit_resample(X, y)
+ assert smote.categorical_encoder is encoder
+ assert smote.categorical_encoder_ is not encoder
+ assert getattr(smote.categorical_encoder_, param) is False
+
+
+# TODO(0.13): remove this test
def test_smotenc_deprecation_ohe_():
"""Check that we raise a deprecation warning when using `ohe_`."""
- pass
+ X, y, categorical_features = data_heterogneous_unordered()
+ smote = SMOTENC(categorical_features=categorical_features, random_state=0)
+ smote.fit_resample(X, y)
+
+ with pytest.warns(FutureWarning, match="'ohe_' attribute has been deprecated"):
+ smote.ohe_
def test_smotenc_param_validation():
"""Check that we validate the parameters correctly since this estimator requires
a specific parameter.
"""
- pass
+ categorical_features = [0]
+ smote = SMOTENC(categorical_features=categorical_features, random_state=0)
+ name = smote.__class__.__name__
+ _set_checking_parameters(smote)
+ check_param_validation(name, smote)
def test_smotenc_bool_categorical():
@@ -37,23 +318,102 @@ def test_smotenc_bool_categorical():
Non-regression test for:
https://github.com/scikit-learn-contrib/imbalanced-learn/issues/974
"""
- pass
+ pd = pytest.importorskip("pandas")
+
+ X = pd.DataFrame(
+ {
+ "c": pd.Categorical([x for x in "abbacaba" * 3]),
+ "f": [0.3, 0.5, 0.1, 0.2] * 6,
+ "b": [False, False, True] * 8,
+ }
+ )
+ y = pd.DataFrame({"out": [1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0] * 2})
+ smote = SMOTENC(categorical_features=[0])
+
+ X_res, y_res = smote.fit_resample(X, y)
+ pd.testing.assert_series_equal(X_res.dtypes, X.dtypes)
+ assert len(X_res) == len(y_res)
+
+ smote.set_params(categorical_features=[0, 2])
+ X_res, y_res = smote.fit_resample(X, y)
+ pd.testing.assert_series_equal(X_res.dtypes, X.dtypes)
+ assert len(X_res) == len(y_res)
+
+ X = X.astype({"b": "category"})
+ X_res, y_res = smote.fit_resample(X, y)
+ pd.testing.assert_series_equal(X_res.dtypes, X.dtypes)
+ assert len(X_res) == len(y_res)
def test_smotenc_categorical_features_str():
"""Check that we support array-like of strings for `categorical_features` using
pandas dataframe.
"""
- pass
+ pd = pytest.importorskip("pandas")
+
+ X = pd.DataFrame(
+ {
+ "A": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ "B": ["a", "b"] * 5,
+ "C": ["a", "b", "c"] * 3 + ["a"],
+ }
+ )
+ X = pd.concat([X] * 10, ignore_index=True)
+ y = np.array([0] * 70 + [1] * 30)
+ smote = SMOTENC(categorical_features=["B", "C"], random_state=0)
+ X_res, y_res = smote.fit_resample(X, y)
+ assert X_res["B"].isin(["a", "b"]).all()
+ assert X_res["C"].isin(["a", "b", "c"]).all()
+ counter = Counter(y_res)
+ assert counter[0] == counter[1] == 70
+ assert_array_equal(smote.categorical_features_, [1, 2])
+ assert_array_equal(smote.continuous_features_, [0])
def test_smotenc_categorical_features_auto():
"""Check that we can automatically detect categorical features based on pandas
dataframe.
"""
- pass
+ pd = pytest.importorskip("pandas")
+
+ X = pd.DataFrame(
+ {
+ "A": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ "B": ["a", "b"] * 5,
+ "C": ["a", "b", "c"] * 3 + ["a"],
+ }
+ )
+ X = pd.concat([X] * 10, ignore_index=True)
+ X["B"] = X["B"].astype("category")
+ X["C"] = X["C"].astype("category")
+ y = np.array([0] * 70 + [1] * 30)
+ smote = SMOTENC(categorical_features="auto", random_state=0)
+ X_res, y_res = smote.fit_resample(X, y)
+ assert X_res["B"].isin(["a", "b"]).all()
+ assert X_res["C"].isin(["a", "b", "c"]).all()
+ counter = Counter(y_res)
+ assert counter[0] == counter[1] == 70
+ assert_array_equal(smote.categorical_features_, [1, 2])
+ assert_array_equal(smote.continuous_features_, [0])
def test_smote_nc_categorical_features_auto_error():
"""Check that we raise a proper error when we cannot use the `'auto'` mode."""
- pass
+ pd = pytest.importorskip("pandas")
+
+ X = pd.DataFrame(
+ {
+ "A": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ "B": ["a", "b"] * 5,
+ "C": ["a", "b", "c"] * 3 + ["a"],
+ }
+ )
+ y = np.array([0] * 70 + [1] * 30)
+ smote = SMOTENC(categorical_features="auto", random_state=0)
+
+ with pytest.raises(ValueError, match="the input data should be a pandas.DataFrame"):
+ smote.fit_resample(X.to_numpy(), y)
+
+ err_msg = "SMOTE-NC is not designed to work only with numerical features"
+ with pytest.raises(ValueError, match=err_msg):
+ smote.fit_resample(X, y)
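
A compact SMOTENC usage sketch mirroring the `data_heterogneous_*` helpers above; the behaviour stated in the closing comment is the one `test_smotenc` checks (generated categorical values come from categories already present). Sizes and column indices are illustrative:

```python
import numpy as np

from imblearn.over_sampling import SMOTENC

rng = np.random.RandomState(42)
X = np.empty((30, 4), dtype=object)
X[:, :2] = rng.randn(30, 2)                                    # continuous
X[:, 2] = rng.choice(["a", "b", "c"], size=30).astype(object)  # categorical
X[:, 3] = rng.randint(3, size=30)                              # categorical
y = np.array([0] * 10 + [1] * 20)

smote = SMOTENC(categorical_features=[2, 3], random_state=0)
X_res, y_res = smote.fit_resample(X, y)
# Continuous columns are interpolated as in plain SMOTE; categorical columns
# take the most frequent category among the nearest neighbours, so no new
# category can appear in X_res.
```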
diff --git a/imblearn/over_sampling/_smote/tests/test_smoten.py b/imblearn/over_sampling/_smote/tests/test_smoten.py
index b4fceeb..2e30e3f 100644
--- a/imblearn/over_sampling/_smote/tests/test_smoten.py
+++ b/imblearn/over_sampling/_smote/tests/test_smoten.py
@@ -3,19 +3,93 @@ import pytest
from sklearn.exceptions import DataConversionWarning
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
from sklearn.utils._testing import _convert_container
+
from imblearn.over_sampling import SMOTEN
-@pytest.mark.parametrize('sparse_format', ['sparse_csr', 'sparse_csc'])
+@pytest.fixture
+def data():
+ rng = np.random.RandomState(0)
+
+ feature_1 = ["A"] * 10 + ["B"] * 20 + ["C"] * 30
+ feature_2 = ["A"] * 40 + ["B"] * 20
+ feature_3 = ["A"] * 20 + ["B"] * 20 + ["C"] * 10 + ["D"] * 10
+ X = np.array([feature_1, feature_2, feature_3], dtype=object).T
+ rng.shuffle(X)
+ y = np.array([0] * 20 + [1] * 40, dtype=np.int32)
+ y_labels = np.array(["not apple", "apple"], dtype=object)
+ y = y_labels[y]
+ return X, y
+
+
+def test_smoten(data):
+ # overall check for SMOTEN
+ X, y = data
+ sampler = SMOTEN(random_state=0)
+ X_res, y_res = sampler.fit_resample(X, y)
+
+ assert X_res.shape == (80, 3)
+ assert y_res.shape == (80,)
+ assert isinstance(sampler.categorical_encoder_, OrdinalEncoder)
+
+
+def test_smoten_resampling():
+    # check that SMOTEN resamples the data as expected
+    # we generate data such that "not apple" is the minority class and
+    # samples from this class will be generated. We force the "blue"
+    # category to be associated with this class; therefore, the newly
+    # generated samples should also be from the "blue" category.
+ X = np.array(["green"] * 5 + ["red"] * 10 + ["blue"] * 7, dtype=object).reshape(
+ -1, 1
+ )
+ y = np.array(
+ ["apple"] * 5
+ + ["not apple"] * 3
+ + ["apple"] * 7
+ + ["not apple"] * 5
+ + ["apple"] * 2,
+ dtype=object,
+ )
+ sampler = SMOTEN(random_state=0)
+ X_res, y_res = sampler.fit_resample(X, y)
+
+ X_generated, y_generated = X_res[X.shape[0] :], y_res[X.shape[0] :]
+ np.testing.assert_array_equal(X_generated, "blue")
+ np.testing.assert_array_equal(y_generated, "not apple")
+
+
+@pytest.mark.parametrize("sparse_format", ["sparse_csr", "sparse_csc"])
def test_smoten_sparse_input(data, sparse_format):
"""Check that we handle sparse input in SMOTEN even if it is not efficient.
Non-regression test for:
https://github.com/scikit-learn-contrib/imbalanced-learn/issues/971
"""
- pass
+ X, y = data
+ X = OneHotEncoder().fit_transform(X).toarray()
+ X = _convert_container(X, sparse_format)
+
+ with pytest.warns(DataConversionWarning, match="is not really efficient"):
+ X_res, y_res = SMOTEN(random_state=0).fit_resample(X, y)
+
+ assert X_res.format == X.format
+ assert X_res.shape[0] == len(y_res)
def test_smoten_categorical_encoder(data):
"""Check that `categorical_encoder` is used when provided."""
- pass
+
+ X, y = data
+ sampler = SMOTEN(random_state=0)
+ sampler.fit_resample(X, y)
+
+ assert isinstance(sampler.categorical_encoder_, OrdinalEncoder)
+ assert sampler.categorical_encoder_.dtype == np.int32
+
+ encoder = OrdinalEncoder(dtype=np.int64)
+ sampler.set_params(categorical_encoder=encoder).fit_resample(X, y)
+
+ assert isinstance(sampler.categorical_encoder_, OrdinalEncoder)
+ assert sampler.categorical_encoder is encoder
+ assert sampler.categorical_encoder_ is not encoder
+ assert sampler.categorical_encoder_.dtype == np.int64
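
A minimal SMOTEN sketch matching `test_smoten_resampling` above: the sampler is meant for data made only of categorical features, and new samples take the most common category among the nearest minority-class neighbours. Class sizes here are illustrative:

```python
from collections import Counter

import numpy as np

from imblearn.over_sampling import SMOTEN

X = np.array(
    ["green"] * 5 + ["red"] * 10 + ["blue"] * 7, dtype=object
).reshape(-1, 1)
y = np.array(["apple"] * 15 + ["not apple"] * 7, dtype=object)

X_res, y_res = SMOTEN(random_state=0).fit_resample(X, y)
print(Counter(y_res))  # {'apple': 15, 'not apple': 15}
```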
diff --git a/imblearn/over_sampling/_smote/tests/test_svm_smote.py b/imblearn/over_sampling/_smote/tests/test_svm_smote.py
index dd43004..49e01f6 100644
--- a/imblearn/over_sampling/_smote/tests/test_svm_smote.py
+++ b/imblearn/over_sampling/_smote/tests/test_svm_smote.py
@@ -5,13 +5,63 @@ from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import NearestNeighbors
from sklearn.svm import SVC
from sklearn.utils._testing import assert_allclose, assert_array_equal
+
from imblearn.over_sampling import SVMSMOTE
+@pytest.fixture
+def data():
+ X = np.array(
+ [
+ [0.11622591, -0.0317206],
+ [0.77481731, 0.60935141],
+ [1.25192108, -0.22367336],
+ [0.53366841, -0.30312976],
+ [1.52091956, -0.49283504],
+ [-0.28162401, -2.10400981],
+ [0.83680821, 1.72827342],
+ [0.3084254, 0.33299982],
+ [0.70472253, -0.73309052],
+ [0.28893132, -0.38761769],
+ [1.15514042, 0.0129463],
+ [0.88407872, 0.35454207],
+ [1.31301027, -0.92648734],
+ [-1.11515198, -0.93689695],
+ [-0.18410027, -0.45194484],
+ [0.9281014, 0.53085498],
+ [-0.14374509, 0.27370049],
+ [-0.41635887, -0.38299653],
+ [0.08711622, 0.93259929],
+ [1.70580611, -0.11219234],
+ ]
+ )
+ y = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0])
+ return X, y
+
+
+def test_svm_smote(data):
+ svm_smote = SVMSMOTE(random_state=42)
+ svm_smote_nn = SVMSMOTE(
+ random_state=42,
+ k_neighbors=NearestNeighbors(n_neighbors=6),
+ m_neighbors=NearestNeighbors(n_neighbors=11),
+ svm_estimator=SVC(gamma="scale", random_state=42),
+ )
+
+ X_res_1, y_res_1 = svm_smote.fit_resample(*data)
+ X_res_2, y_res_2 = svm_smote_nn.fit_resample(*data)
+
+ assert_allclose(X_res_1, X_res_2)
+ assert_array_equal(y_res_1, y_res_2)
+
+
def test_svm_smote_not_svm(data):
"""Check that we raise a proper error if passing an estimator that does not
expose a `support_` fitted attribute."""
- pass
+
+ err_msg = "`svm_estimator` is required to exposed a `support_` fitted attribute."
+ with pytest.raises(RuntimeError, match=err_msg):
+ SVMSMOTE(svm_estimator=LogisticRegression()).fit_resample(*data)
def test_svm_smote_all_noise(data):
@@ -21,4 +71,18 @@ def test_svm_smote_all_noise(data):
Non-regression test for:
https://github.com/scikit-learn-contrib/imbalanced-learn/issues/742
"""
- pass
+ X, y = make_classification(
+ n_classes=3,
+ class_sep=0.001,
+ weights=[0.004, 0.451, 0.545],
+ n_informative=3,
+ n_redundant=0,
+ flip_y=0,
+ n_features=3,
+ n_clusters_per_class=2,
+ n_samples=1000,
+ random_state=10,
+ )
+
+ with pytest.raises(ValueError, match="SVM-SMOTE is not adapted to your dataset"):
+ SVMSMOTE(k_neighbors=4, random_state=42).fit_resample(X, y)
diff --git a/imblearn/over_sampling/base.py b/imblearn/over_sampling/base.py
index f71cf78..fbd982b 100644
--- a/imblearn/over_sampling/base.py
+++ b/imblearn/over_sampling/base.py
@@ -1,8 +1,13 @@
"""
Base class for the over-sampling method.
"""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import numbers
from collections.abc import Mapping
+
from ..base import BaseSampler
from ..utils._param_validation import Interval, StrOptions
@@ -13,9 +18,10 @@ class BaseOverSampler(BaseSampler):
    Warning: This class should not be used directly. Use the derived classes
instead.
"""
- _sampling_type = 'over-sampling'
- _sampling_strategy_docstring = (
- """sampling_strategy : float, str, dict or callable, default='auto'
+
+ _sampling_type = "over-sampling"
+
+ _sampling_strategy_docstring = """sampling_strategy : float, str, dict or callable, default='auto'
Sampling information to resample the data set.
- When ``float``, it corresponds to the desired ratio of the number of
@@ -50,9 +56,14 @@ class BaseOverSampler(BaseSampler):
- When callable, function taking ``y`` and returns a ``dict``. The keys
correspond to the targeted classes. The values correspond to the
desired number of samples for each class.
- """
- .strip())
- _parameter_constraints: dict = {'sampling_strategy': [Interval(numbers.
- Real, 0, 1, closed='right'), StrOptions({'auto', 'minority',
- 'not minority', 'not majority', 'all'}), Mapping, callable],
- 'random_state': ['random_state']}
+ """.strip() # noqa: E501
+
+ _parameter_constraints: dict = {
+ "sampling_strategy": [
+ Interval(numbers.Real, 0, 1, closed="right"),
+ StrOptions({"auto", "minority", "not minority", "not majority", "all"}),
+ Mapping,
+ callable,
+ ],
+ "random_state": ["random_state"],
+ }
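
The docstring above enumerates the accepted `sampling_strategy` forms; a short sketch using `RandomOverSampler` as a concrete `BaseOverSampler` subclass (class counts are illustrative):

```python
from collections import Counter

from sklearn.datasets import make_classification

from imblearn.over_sampling import RandomOverSampler

X, y = make_classification(
    n_samples=100,
    n_classes=3,
    n_informative=3,
    n_clusters_per_class=1,
    weights=[0.1, 0.3, 0.6],
    random_state=0,
)  # roughly 10 / 30 / 60 samples per class

# str form: resample every class but the majority one
ros = RandomOverSampler(sampling_strategy="not majority", random_state=0)
print(Counter(ros.fit_resample(X, y)[1]))  # all classes brought to 60

# dict form: explicit per-class target counts (cannot shrink a class)
ros = RandomOverSampler(sampling_strategy={0: 40, 1: 40}, random_state=0)
print(Counter(ros.fit_resample(X, y)[1]))  # class 2 left untouched

# float form: binary problems only, desired minority/majority ratio
```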
diff --git a/imblearn/over_sampling/tests/test_adasyn.py b/imblearn/over_sampling/tests/test_adasyn.py
index ea8f98e..4df6362 100644
--- a/imblearn/over_sampling/tests/test_adasyn.py
+++ b/imblearn/over_sampling/tests/test_adasyn.py
@@ -1,17 +1,121 @@
"""Test the module under sampler."""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.utils._testing import assert_allclose, assert_array_equal
+
from imblearn.over_sampling import ADASYN
+
RND_SEED = 0
-X = np.array([[0.11622591, -0.0317206], [0.77481731, 0.60935141], [
- 1.25192108, -0.22367336], [0.53366841, -0.30312976], [1.52091956, -
- 0.49283504], [-0.28162401, -2.10400981], [0.83680821, 1.72827342], [
- 0.3084254, 0.33299982], [0.70472253, -0.73309052], [0.28893132, -
- 0.38761769], [1.15514042, 0.0129463], [0.88407872, 0.35454207], [
- 1.31301027, -0.92648734], [-1.11515198, -0.93689695], [-0.18410027, -
- 0.45194484], [0.9281014, 0.53085498], [-0.14374509, 0.27370049], [-
- 0.41635887, -0.38299653], [0.08711622, 0.93259929], [1.70580611, -
- 0.11219234]])
+X = np.array(
+ [
+ [0.11622591, -0.0317206],
+ [0.77481731, 0.60935141],
+ [1.25192108, -0.22367336],
+ [0.53366841, -0.30312976],
+ [1.52091956, -0.49283504],
+ [-0.28162401, -2.10400981],
+ [0.83680821, 1.72827342],
+ [0.3084254, 0.33299982],
+ [0.70472253, -0.73309052],
+ [0.28893132, -0.38761769],
+ [1.15514042, 0.0129463],
+ [0.88407872, 0.35454207],
+ [1.31301027, -0.92648734],
+ [-1.11515198, -0.93689695],
+ [-0.18410027, -0.45194484],
+ [0.9281014, 0.53085498],
+ [-0.14374509, 0.27370049],
+ [-0.41635887, -0.38299653],
+ [0.08711622, 0.93259929],
+ [1.70580611, -0.11219234],
+ ]
+)
Y = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0])
-R_TOL = 0.0001
+R_TOL = 1e-4
+
+
+def test_ada_init():
+ sampling_strategy = "auto"
+ ada = ADASYN(sampling_strategy=sampling_strategy, random_state=RND_SEED)
+ assert ada.random_state == RND_SEED
+
+
+def test_ada_fit_resample():
+ ada = ADASYN(random_state=RND_SEED)
+ X_resampled, y_resampled = ada.fit_resample(X, Y)
+ X_gt = np.array(
+ [
+ [0.11622591, -0.0317206],
+ [0.77481731, 0.60935141],
+ [1.25192108, -0.22367336],
+ [0.53366841, -0.30312976],
+ [1.52091956, -0.49283504],
+ [-0.28162401, -2.10400981],
+ [0.83680821, 1.72827342],
+ [0.3084254, 0.33299982],
+ [0.70472253, -0.73309052],
+ [0.28893132, -0.38761769],
+ [1.15514042, 0.0129463],
+ [0.88407872, 0.35454207],
+ [1.31301027, -0.92648734],
+ [-1.11515198, -0.93689695],
+ [-0.18410027, -0.45194484],
+ [0.9281014, 0.53085498],
+ [-0.14374509, 0.27370049],
+ [-0.41635887, -0.38299653],
+ [0.08711622, 0.93259929],
+ [1.70580611, -0.11219234],
+ [0.88161986, -0.2829741],
+ [0.35681689, -0.18814597],
+ [1.4148276, 0.05308106],
+ [0.3136591, -0.31327875],
+ ]
+ )
+ y_gt = np.array(
+ [0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0]
+ )
+ assert_allclose(X_resampled, X_gt, rtol=R_TOL)
+ assert_array_equal(y_resampled, y_gt)
+
+
+def test_ada_fit_resample_nn_obj():
+ nn = NearestNeighbors(n_neighbors=6)
+ ada = ADASYN(random_state=RND_SEED, n_neighbors=nn)
+ X_resampled, y_resampled = ada.fit_resample(X, Y)
+ X_gt = np.array(
+ [
+ [0.11622591, -0.0317206],
+ [0.77481731, 0.60935141],
+ [1.25192108, -0.22367336],
+ [0.53366841, -0.30312976],
+ [1.52091956, -0.49283504],
+ [-0.28162401, -2.10400981],
+ [0.83680821, 1.72827342],
+ [0.3084254, 0.33299982],
+ [0.70472253, -0.73309052],
+ [0.28893132, -0.38761769],
+ [1.15514042, 0.0129463],
+ [0.88407872, 0.35454207],
+ [1.31301027, -0.92648734],
+ [-1.11515198, -0.93689695],
+ [-0.18410027, -0.45194484],
+ [0.9281014, 0.53085498],
+ [-0.14374509, 0.27370049],
+ [-0.41635887, -0.38299653],
+ [0.08711622, 0.93259929],
+ [1.70580611, -0.11219234],
+ [0.88161986, -0.2829741],
+ [0.35681689, -0.18814597],
+ [1.4148276, 0.05308106],
+ [0.3136591, -0.31327875],
+ ]
+ )
+ y_gt = np.array(
+ [0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0]
+ )
+ assert_allclose(X_resampled, X_gt, rtol=R_TOL)
+ assert_array_equal(y_resampled, y_gt)
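
Worth noting next to these ground-truth arrays: unlike SMOTE, ADASYN weights each minority sample by the share of majority neighbours around it, so generation concentrates on hard regions and the final counts are only approximately balanced. A hedged sketch with illustrative parameters:

```python
from collections import Counter

from sklearn.datasets import make_classification

from imblearn.over_sampling import ADASYN

X, y = make_classification(
    n_samples=500,
    n_features=2,
    n_informative=2,
    n_redundant=0,
    n_clusters_per_class=1,
    weights=[0.85, 0.15],
    class_sep=0.5,  # overlap ensures minority samples have majority neighbours
    random_state=0,
)
X_res, y_res = ADASYN(random_state=0).fit_resample(X, y)
print(Counter(y_res))  # approximately, not exactly, balanced
```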
diff --git a/imblearn/over_sampling/tests/test_common.py b/imblearn/over_sampling/tests/test_common.py
index e0a133d..cdd85c1 100644
--- a/imblearn/over_sampling/tests/test_common.py
+++ b/imblearn/over_sampling/tests/test_common.py
@@ -1,6 +1,145 @@
from collections import Counter
+
import numpy as np
import pytest
from sklearn.cluster import MiniBatchKMeans
-from imblearn.over_sampling import ADASYN, SMOTE, SMOTEN, SMOTENC, SVMSMOTE, BorderlineSMOTE, KMeansSMOTE
+
+from imblearn.over_sampling import (
+ ADASYN,
+ SMOTE,
+ SMOTEN,
+ SMOTENC,
+ SVMSMOTE,
+ BorderlineSMOTE,
+ KMeansSMOTE,
+)
from imblearn.utils.testing import _CustomNearestNeighbors
+
+
+@pytest.fixture
+def numerical_data():
+ rng = np.random.RandomState(0)
+ X = rng.randn(100, 2)
+ y = np.repeat([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0], 5)
+
+ return X, y
+
+
+@pytest.fixture
+def categorical_data():
+ rng = np.random.RandomState(0)
+
+ feature_1 = ["A"] * 10 + ["B"] * 20 + ["C"] * 30
+ feature_2 = ["A"] * 40 + ["B"] * 20
+ feature_3 = ["A"] * 20 + ["B"] * 20 + ["C"] * 10 + ["D"] * 10
+ X = np.array([feature_1, feature_2, feature_3], dtype=object).T
+ rng.shuffle(X)
+ y = np.array([0] * 20 + [1] * 40, dtype=np.int32)
+ y_labels = np.array(["not apple", "apple"], dtype=object)
+ y = y_labels[y]
+ return X, y
+
+
+@pytest.fixture
+def heterogeneous_data():
+ rng = np.random.RandomState(42)
+ X = np.empty((30, 4), dtype=object)
+ X[:, :2] = rng.randn(30, 2)
+ X[:, 2] = rng.choice(["a", "b", "c"], size=30).astype(object)
+ X[:, 3] = rng.randint(3, size=30)
+ y = np.array([0] * 10 + [1] * 20)
+ return X, y, [2, 3]
+
+
+@pytest.mark.parametrize(
+ "smote", [BorderlineSMOTE(), SVMSMOTE()], ids=["borderline", "svm"]
+)
+def test_smote_m_neighbors(numerical_data, smote):
+ # check that m_neighbors is properly set. Regression test for:
+ # https://github.com/scikit-learn-contrib/imbalanced-learn/issues/568
+ X, y = numerical_data
+ _ = smote.fit_resample(X, y)
+ assert smote.nn_k_.n_neighbors == 6
+ assert smote.nn_m_.n_neighbors == 11
+
+
+@pytest.mark.parametrize(
+ "smote, neighbor_estimator_name",
+ [
+ (ADASYN(random_state=0), "n_neighbors"),
+ (BorderlineSMOTE(random_state=0), "k_neighbors"),
+ (
+ KMeansSMOTE(
+ kmeans_estimator=MiniBatchKMeans(n_init=1, random_state=0),
+ random_state=1,
+ ),
+ "k_neighbors",
+ ),
+ (SMOTE(random_state=0), "k_neighbors"),
+ (SVMSMOTE(random_state=0), "k_neighbors"),
+ ],
+ ids=["adasyn", "borderline", "kmeans", "smote", "svm"],
+)
+def test_numerical_smote_custom_nn(numerical_data, smote, neighbor_estimator_name):
+ X, y = numerical_data
+ params = {
+ neighbor_estimator_name: _CustomNearestNeighbors(n_neighbors=5),
+ }
+ smote.set_params(**params)
+ X_res, _ = smote.fit_resample(X, y)
+
+ assert X_res.shape[0] >= 120
+
+
+def test_categorical_smote_k_custom_nn(categorical_data):
+ X, y = categorical_data
+ smote = SMOTEN(k_neighbors=_CustomNearestNeighbors(n_neighbors=5))
+ X_res, y_res = smote.fit_resample(X, y)
+
+ assert X_res.shape == (80, 3)
+ assert Counter(y_res) == {"apple": 40, "not apple": 40}
+
+
+def test_heterogeneous_smote_k_custom_nn(heterogeneous_data):
+ X, y, categorical_features = heterogeneous_data
+ smote = SMOTENC(
+ categorical_features, k_neighbors=_CustomNearestNeighbors(n_neighbors=5)
+ )
+ X_res, y_res = smote.fit_resample(X, y)
+
+ assert X_res.shape == (40, 4)
+ assert Counter(y_res) == {0: 20, 1: 20}
+
+
+@pytest.mark.parametrize(
+ "smote",
+ [BorderlineSMOTE(random_state=0), SVMSMOTE(random_state=0)],
+ ids=["borderline", "svm"],
+)
+def test_numerical_smote_extra_custom_nn(numerical_data, smote):
+ X, y = numerical_data
+ smote.set_params(m_neighbors=_CustomNearestNeighbors(n_neighbors=5))
+ X_res, y_res = smote.fit_resample(X, y)
+
+ assert X_res.shape == (120, 2)
+ assert Counter(y_res) == {0: 60, 1: 60}
+
+
+# FIXME: to be removed in 0.12
+@pytest.mark.parametrize(
+ "sampler",
+ [
+ ADASYN(random_state=0),
+ BorderlineSMOTE(random_state=0),
+ SMOTE(random_state=0),
+ SMOTEN(random_state=0),
+ SMOTENC([0], random_state=0),
+ SVMSMOTE(random_state=0),
+ ],
+)
+def test_n_jobs_deprecation_warning(numerical_data, sampler):
+ X, y = numerical_data
+ sampler.set_params(n_jobs=2)
+ warning_msg = "The parameter `n_jobs` has been deprecated"
+ with pytest.warns(FutureWarning, match=warning_msg):
+ sampler.fit_resample(X, y)
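
Since `test_n_jobs_deprecation_warning` pins the `FutureWarning`, the supported replacement is to set `n_jobs` on the neighbours estimators themselves. A sketch, remembering the `+1` offset used throughout these tests (`k_neighbors=5` maps to `n_neighbors=6`, `m_neighbors=10` to `n_neighbors=11`):

```python
from sklearn.datasets import make_classification
from sklearn.neighbors import NearestNeighbors

from imblearn.over_sampling import BorderlineSMOTE

X, y = make_classification(
    n_samples=300, weights=[0.8, 0.2], class_sep=0.8, random_state=0
)

# Parallelism lives on the neighbours estimators, not on the sampler.
smote = BorderlineSMOTE(
    k_neighbors=NearestNeighbors(n_neighbors=6, n_jobs=2),
    m_neighbors=NearestNeighbors(n_neighbors=11, n_jobs=2),
    random_state=0,
)
X_res, y_res = smote.fit_resample(X, y)  # no FutureWarning emitted
```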
diff --git a/imblearn/over_sampling/tests/test_random_over_sampler.py b/imblearn/over_sampling/tests/test_random_over_sampler.py
index c239b21..efa40c8 100644
--- a/imblearn/over_sampling/tests/test_random_over_sampler.py
+++ b/imblearn/over_sampling/tests/test_random_over_sampler.py
@@ -1,25 +1,292 @@
"""Test the module under sampler."""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
from collections import Counter
from datetime import datetime
+
import numpy as np
import pytest
from sklearn.datasets import make_classification
-from sklearn.utils._testing import _convert_container, assert_allclose, assert_array_equal
+from sklearn.utils._testing import (
+ _convert_container,
+ assert_allclose,
+ assert_array_equal,
+)
+
from imblearn.over_sampling import RandomOverSampler
+
RND_SEED = 0
-@pytest.mark.parametrize('sampling_strategy', ['auto', 'minority',
- 'not minority', 'not majority', 'all'])
+@pytest.fixture
+def data():
+ X = np.array(
+ [
+ [0.04352327, -0.20515826],
+ [0.92923648, 0.76103773],
+ [0.20792588, 1.49407907],
+ [0.47104475, 0.44386323],
+ [0.22950086, 0.33367433],
+ [0.15490546, 0.3130677],
+ [0.09125309, -0.85409574],
+ [0.12372842, 0.6536186],
+ [0.13347175, 0.12167502],
+ [0.094035, -2.55298982],
+ ]
+ )
+ Y = np.array([1, 0, 1, 0, 1, 1, 1, 1, 0, 1])
+ return X, Y
+
+
+def test_ros_init():
+ sampling_strategy = "auto"
+ ros = RandomOverSampler(sampling_strategy=sampling_strategy, random_state=RND_SEED)
+ assert ros.random_state == RND_SEED
+
+
+@pytest.mark.parametrize(
+ "params", [{"shrinkage": None}, {"shrinkage": 0}, {"shrinkage": {0: 0}}]
+)
+@pytest.mark.parametrize("X_type", ["array", "dataframe"])
+def test_ros_fit_resample(X_type, data, params):
+ X, Y = data
+ X_ = _convert_container(X, X_type)
+ ros = RandomOverSampler(**params, random_state=RND_SEED)
+ X_resampled, y_resampled = ros.fit_resample(X_, Y)
+ X_gt = np.array(
+ [
+ [0.04352327, -0.20515826],
+ [0.92923648, 0.76103773],
+ [0.20792588, 1.49407907],
+ [0.47104475, 0.44386323],
+ [0.22950086, 0.33367433],
+ [0.15490546, 0.3130677],
+ [0.09125309, -0.85409574],
+ [0.12372842, 0.6536186],
+ [0.13347175, 0.12167502],
+ [0.094035, -2.55298982],
+ [0.92923648, 0.76103773],
+ [0.47104475, 0.44386323],
+ [0.92923648, 0.76103773],
+ [0.47104475, 0.44386323],
+ ]
+ )
+ y_gt = np.array([1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0])
+
+ if X_type == "dataframe":
+ assert hasattr(X_resampled, "loc")
+ # FIXME: we should use to_numpy with pandas >= 0.25
+ X_resampled = X_resampled.values
+
+ assert_allclose(X_resampled, X_gt)
+ assert_array_equal(y_resampled, y_gt)
+
+ if params["shrinkage"] is None:
+ assert ros.shrinkage_ is None
+ else:
+ assert ros.shrinkage_ == {0: 0}
+
+
+@pytest.mark.parametrize("params", [{"shrinkage": None}, {"shrinkage": 0}])
+def test_ros_fit_resample_half(data, params):
+ X, Y = data
+ sampling_strategy = {0: 3, 1: 7}
+ ros = RandomOverSampler(
+ **params, sampling_strategy=sampling_strategy, random_state=RND_SEED
+ )
+ X_resampled, y_resampled = ros.fit_resample(X, Y)
+ X_gt = np.array(
+ [
+ [0.04352327, -0.20515826],
+ [0.92923648, 0.76103773],
+ [0.20792588, 1.49407907],
+ [0.47104475, 0.44386323],
+ [0.22950086, 0.33367433],
+ [0.15490546, 0.3130677],
+ [0.09125309, -0.85409574],
+ [0.12372842, 0.6536186],
+ [0.13347175, 0.12167502],
+ [0.094035, -2.55298982],
+ ]
+ )
+ y_gt = np.array([1, 0, 1, 0, 1, 1, 1, 1, 0, 1])
+ assert_allclose(X_resampled, X_gt)
+ assert_array_equal(y_resampled, y_gt)
+
+ if params["shrinkage"] is None:
+ assert ros.shrinkage_ is None
+ else:
+ assert ros.shrinkage_ == {0: 0, 1: 0}
+
+
+@pytest.mark.parametrize("params", [{"shrinkage": None}, {"shrinkage": 0}])
+def test_multiclass_fit_resample(data, params):
+ # check the random over-sampling with a multiclass problem
+ X, Y = data
+ y = Y.copy()
+ y[5] = 2
+ y[6] = 2
+ ros = RandomOverSampler(**params, random_state=RND_SEED)
+ X_resampled, y_resampled = ros.fit_resample(X, y)
+ count_y_res = Counter(y_resampled)
+ assert count_y_res[0] == 5
+ assert count_y_res[1] == 5
+ assert count_y_res[2] == 5
+
+ if params["shrinkage"] is None:
+ assert ros.shrinkage_ is None
+ else:
+ assert ros.shrinkage_ == {0: 0, 2: 0}
+
+
+def test_random_over_sampling_heterogeneous_data():
+    # check that resampling with heterogeneous dtypes works with plain
+    # random over-sampling
+ X_hetero = np.array(
+ [["xxx", 1, 1.0], ["yyy", 2, 2.0], ["zzz", 3, 3.0]], dtype=object
+ )
+ y = np.array([0, 0, 1])
+ ros = RandomOverSampler(random_state=RND_SEED)
+ X_res, y_res = ros.fit_resample(X_hetero, y)
+
+ assert X_res.shape[0] == 4
+ assert y_res.shape[0] == 4
+ assert X_res.dtype == object
+ assert X_res[-1, 0] in X_hetero[:, 0]
+
+
+def test_random_over_sampling_nan_inf(data):
+ # check that we can oversample even with missing or infinite data
+ # regression tests for #605
+ X, Y = data
+ rng = np.random.RandomState(42)
+ n_not_finite = X.shape[0] // 3
+ row_indices = rng.choice(np.arange(X.shape[0]), size=n_not_finite)
+ col_indices = rng.randint(0, X.shape[1], size=n_not_finite)
+ not_finite_values = rng.choice([np.nan, np.inf], size=n_not_finite)
+
+ X_ = X.copy()
+ X_[row_indices, col_indices] = not_finite_values
+
+ ros = RandomOverSampler(random_state=0)
+ X_res, y_res = ros.fit_resample(X_, Y)
+
+ assert y_res.shape == (14,)
+ assert X_res.shape == (14, 2)
+ assert np.any(~np.isfinite(X_res))
+
+
+def test_random_over_sampling_heterogeneous_data_smoothed_bootstrap():
+ # check that we raise an error when heterogeneous dtype data are given
+ # and a smoothed bootstrap is requested
+ X_hetero = np.array(
+ [["xxx", 1, 1.0], ["yyy", 2, 2.0], ["zzz", 3, 3.0]], dtype=object
+ )
+ y = np.array([0, 0, 1])
+ ros = RandomOverSampler(shrinkage=1, random_state=RND_SEED)
+ err_msg = "When shrinkage is not None, X needs to contain only numerical"
+ with pytest.raises(ValueError, match=err_msg):
+ ros.fit_resample(X_hetero, y)
+
+
+@pytest.mark.parametrize("X_type", ["dataframe", "array", "sparse_csr", "sparse_csc"])
+def test_random_over_sampler_smoothed_bootstrap(X_type, data):
+    # check that the smoothed bootstrap works for numerical arrays
+ X, y = data
+ sampler = RandomOverSampler(shrinkage=1)
+ X = _convert_container(X, X_type)
+ X_res, y_res = sampler.fit_resample(X, y)
+
+ assert y_res.shape == (14,)
+ assert X_res.shape == (14, 2)
+
+ if X_type == "dataframe":
+ assert hasattr(X_res, "loc")
+
+
+def test_random_over_sampler_equivalence_shrinkage(data):
+    # check that a shrinkage factor of 0 is equivalent to not creating a
+    # smoothed bootstrap
+ X, y = data
+
+ ros_not_shrink = RandomOverSampler(shrinkage=0, random_state=0)
+ ros_hard_bootstrap = RandomOverSampler(shrinkage=None, random_state=0)
+
+ X_res_not_shrink, y_res_not_shrink = ros_not_shrink.fit_resample(X, y)
+ X_res, y_res = ros_hard_bootstrap.fit_resample(X, y)
+
+ assert_allclose(X_res_not_shrink, X_res)
+ assert_allclose(y_res_not_shrink, y_res)
+
+ assert y_res.shape == (14,)
+ assert X_res.shape == (14, 2)
+ assert y_res_not_shrink.shape == (14,)
+ assert X_res_not_shrink.shape == (14, 2)
+
+
+def test_random_over_sampler_shrinkage_behaviour(data):
+ # check the behaviour of the shrinkage parameter
+ # the covariance of the data generated with the larger shrinkage factor
+ # should also be larger.
+ X, y = data
+
+ ros = RandomOverSampler(shrinkage=1, random_state=0)
+    X_res_shrink_1, y_res_shrink_1 = ros.fit_resample(X, y)
+
+    ros.set_params(shrinkage=5)
+    X_res_shrink_5, y_res_shrink_5 = ros.fit_resample(X, y)
+
+    dispersion_shrink_1 = np.linalg.det(np.cov(X_res_shrink_1[y_res_shrink_1 == 0].T))
+    dispersion_shrink_5 = np.linalg.det(np.cov(X_res_shrink_5[y_res_shrink_5 == 0].T))
+
+    assert dispersion_shrink_1 < dispersion_shrink_5
+
+
+@pytest.mark.parametrize(
+ "shrinkage, err_msg",
+ [
+ ({}, "`shrinkage` should contain a shrinkage factor for each class"),
+ ({0: -1}, "The shrinkage factor needs to be >= 0"),
+ ],
+)
+def test_random_over_sampler_shrinkage_error(data, shrinkage, err_msg):
+ # check the validation of the shrinkage parameter
+ X, y = data
+ ros = RandomOverSampler(shrinkage=shrinkage)
+ with pytest.raises(ValueError, match=err_msg):
+ ros.fit_resample(X, y)
+
+
+@pytest.mark.parametrize(
+ "sampling_strategy", ["auto", "minority", "not minority", "not majority", "all"]
+)
def test_random_over_sampler_strings(sampling_strategy):
"""Check that we support all supposed strings as `sampling_strategy` in
a sampler inheriting from `BaseOverSampler`."""
- pass
+
+ X, y = make_classification(
+ n_samples=100,
+ n_clusters_per_class=1,
+ n_classes=3,
+ weights=[0.1, 0.3, 0.6],
+ random_state=0,
+ )
+ RandomOverSampler(sampling_strategy=sampling_strategy).fit_resample(X, y)
def test_random_over_sampling_datetime():
"""Check that we don't convert input data and only sample from it."""
- pass
+ pd = pytest.importorskip("pandas")
+ X = pd.DataFrame({"label": [0, 0, 0, 1], "td": [datetime.now()] * 4})
+ y = X["label"]
+ ros = RandomOverSampler(random_state=0)
+ X_res, y_res = ros.fit_resample(X, y)
+
+ pd.testing.assert_series_equal(X_res.dtypes, X.dtypes)
+ pd.testing.assert_index_equal(X_res.index, y_res.index)
+ assert_array_equal(y_res.to_numpy(), np.array([0, 0, 0, 1, 1, 1]))
def test_random_over_sampler_full_nat():
@@ -28,4 +295,18 @@ def test_random_over_sampler_full_nat():
Non-regression test for:
https://github.com/scikit-learn-contrib/imbalanced-learn/issues/1055
"""
- pass
+ pd = pytest.importorskip("pandas")
+
+ X = pd.DataFrame(
+ {
+ "col_str": ["abc", "def", "xyz"],
+ "col_timedelta": pd.to_timedelta([np.nan, np.nan, np.nan]),
+ }
+ )
+ y = np.array([0, 0, 1])
+
+ X_res, y_res = RandomOverSampler().fit_resample(X, y)
+ assert X_res.shape == (4, 2)
+ assert y_res.shape == (4,)
+
+ assert X_res["col_timedelta"].dtype == "timedelta64[ns]"
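
To summarize the shrinkage semantics these tests pin down: `shrinkage=None` (or `0`, per the equivalence test) repeats existing rows verbatim, while a positive value jitters each bootstrap draw with Gaussian noise, i.e. a smoothed bootstrap. A sketch on illustrative data:

```python
import numpy as np

from imblearn.over_sampling import RandomOverSampler

rng = np.random.RandomState(0)
X = rng.randn(20, 2)
y = np.array([0] * 15 + [1] * 5)

# Plain bootstrap: the 10 extra minority rows are exact copies.
X_dup, _ = RandomOverSampler(shrinkage=None, random_state=0).fit_resample(X, y)
assert np.unique(X_dup, axis=0).shape[0] < X_dup.shape[0]

# Smoothed bootstrap: Gaussian noise around each draw, all rows distinct.
X_smooth, _ = RandomOverSampler(shrinkage=1.0, random_state=0).fit_resample(X, y)
assert np.unique(X_smooth, axis=0).shape[0] == X_smooth.shape[0]
```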
diff --git a/imblearn/pipeline.py b/imblearn/pipeline.py
index 6c4f580..7453446 100644
--- a/imblearn/pipeline.py
+++ b/imblearn/pipeline.py
@@ -1,7 +1,17 @@
-"""
+"""
The :mod:`imblearn.pipeline` module implements utilities to build a
composite estimator, as a chain of transforms, samples and estimators.
"""
+# Adapted from scikit-learn
+
+# Author: Edouard Duchesnay
+# Gael Varoquaux
+# Virgile Fritsch
+# Alexandre Gramfort
+# Lars Buitinck
+# Christos Aridas
+# Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# License: BSD
import sklearn
from sklearn import pipeline
from sklearn.base import clone
@@ -9,14 +19,25 @@ from sklearn.utils import Bunch
from sklearn.utils.fixes import parse_version
from sklearn.utils.metaestimators import available_if
from sklearn.utils.validation import check_memory
+
from .base import _ParamsValidationMixin
-from .utils._metadata_requests import METHODS, MetadataRouter, MethodMapping, _raise_for_params, _routing_enabled, process_routing
+from .utils._metadata_requests import (
+ METHODS,
+ MetadataRouter,
+ MethodMapping,
+ _raise_for_params,
+ _routing_enabled,
+ process_routing,
+)
from .utils._param_validation import HasMethods, validate_params
from .utils.fixes import _fit_context
-METHODS.append('fit_resample')
-__all__ = ['Pipeline', 'make_pipeline']
+
+METHODS.append("fit_resample")
+
+__all__ = ["Pipeline", "make_pipeline"]
+
sklearn_version = parse_version(sklearn.__version__).base_version
-if parse_version(sklearn_version) < parse_version('1.5'):
+if parse_version(sklearn_version) < parse_version("1.5"):
from sklearn.utils import _print_elapsed_time
else:
from sklearn.utils._user_interface import _print_elapsed_time
@@ -128,20 +149,138 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
weighted avg 0.99 0.98 0.98 250
<BLANKLINE>
"""
- _parameter_constraints: dict = {'steps': 'no_validation', 'memory': [
- None, str, HasMethods(['cache'])], 'verbose': ['boolean']}
- def _iter(self, with_final=True, filter_passthrough=True,
- filter_resample=True):
+ _parameter_constraints: dict = {
+ "steps": "no_validation", # validated in `_validate_steps`
+ "memory": [None, str, HasMethods(["cache"])],
+ "verbose": ["boolean"],
+ }
+
+ # BaseEstimator interface
+
+ def _validate_steps(self):
+ names, estimators = zip(*self.steps)
+
+ # validate names
+ self._validate_names(names)
+
+ # validate estimators
+ transformers = estimators[:-1]
+ estimator = estimators[-1]
+
+ for t in transformers:
+ if t is None or t == "passthrough":
+ continue
+
+            is_transformer = hasattr(t, "fit") and hasattr(t, "transform")
+            is_sampler = hasattr(t, "fit_resample")
+            is_not_transformer_or_sampler = not (is_transformer or is_sampler)
+
+            if is_not_transformer_or_sampler:
+                raise TypeError(
+                    "All intermediate steps of the chain should "
+                    "be estimators that implement fit and transform or "
+                    "fit_resample (but not both) or be a string 'passthrough' "
+                    "('%s' (type %s) doesn't)" % (t, type(t))
+                )
+
+            if is_transformer and is_sampler:
+                raise TypeError(
+                    "All intermediate steps of the chain should "
+                    "be estimators that implement fit and transform or "
+                    "fit_resample."
+                    " '%s' implements both." % (t)
+                )
+
+ if isinstance(t, pipeline.Pipeline):
+ raise TypeError(
+ "All intermediate steps of the chain should not be Pipelines"
+ )
+
+ # We allow last estimator to be None as an identity transformation
+ if (
+ estimator is not None
+ and estimator != "passthrough"
+ and not hasattr(estimator, "fit")
+ ):
+ raise TypeError(
+ "Last step of Pipeline should implement fit or be "
+ "the string 'passthrough'. '%s' (type %s) doesn't"
+ % (estimator, type(estimator))
+ )
+
+ def _iter(self, with_final=True, filter_passthrough=True, filter_resample=True):
"""Generate (idx, (name, trans)) tuples from self.steps.
When `filter_passthrough` is `True`, 'passthrough' and None
transformers are filtered out. When `filter_resample` is `True`,
         estimators with a `fit_resample` method are filtered out.
"""
- pass
-
- @_fit_context(prefer_skip_nested_validation=False)
+ it = super()._iter(with_final, filter_passthrough)
+ if filter_resample:
+ return filter(lambda x: not hasattr(x[-1], "fit_resample"), it)
+ else:
+ return it
+
+ # Estimator interface
+
+ def _fit(self, X, y=None, routed_params=None):
+ self.steps = list(self.steps)
+ self._validate_steps()
+ # Setup the memory
+ memory = check_memory(self.memory)
+
+ fit_transform_one_cached = memory.cache(_fit_transform_one)
+ fit_resample_one_cached = memory.cache(_fit_resample_one)
+
+ for step_idx, name, transformer in self._iter(
+ with_final=False, filter_passthrough=False, filter_resample=False
+ ):
+ if transformer is None or transformer == "passthrough":
+ with _print_elapsed_time("Pipeline", self._log_message(step_idx)):
+ continue
+
+ if hasattr(memory, "location") and memory.location is None:
+ # we do not clone when caching is disabled to
+ # preserve backward compatibility
+ cloned_transformer = transformer
+ else:
+ cloned_transformer = clone(transformer)
+
+ # Fit or load from cache the current transformer
+ if hasattr(cloned_transformer, "transform") or hasattr(
+ cloned_transformer, "fit_transform"
+ ):
+ X, fitted_transformer = fit_transform_one_cached(
+ cloned_transformer,
+ X,
+ y,
+ None,
+ message_clsname="Pipeline",
+ message=self._log_message(step_idx),
+ params=routed_params[name],
+ )
+ elif hasattr(cloned_transformer, "fit_resample"):
+ X, y, fitted_transformer = fit_resample_one_cached(
+ cloned_transformer,
+ X,
+ y,
+ message_clsname="Pipeline",
+ message=self._log_message(step_idx),
+ params=routed_params[name],
+ )
+ # Replace the transformer of the step with the fitted
+ # transformer. This is necessary when loading the transformer
+ # from the cache.
+ self.steps[step_idx] = (name, fitted_transformer)
+ return X, y
+
+ # The `fit_*` methods need to be overridden to support the samplers.
+ @_fit_context(
+ # estimators in Pipeline.steps are not validated yet
+ prefer_skip_nested_validation=False
+ )
def fit(self, X, y=None, **params):
"""Fit the model.
@@ -186,10 +325,26 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
self : Pipeline
This estimator.
"""
- pass
+ routed_params = self._check_method_params(method="fit", props=params)
+ Xt, yt = self._fit(X, y, routed_params)
+ with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
+ if self._final_estimator != "passthrough":
+ last_step_params = routed_params[self.steps[-1][0]]
+ self._final_estimator.fit(Xt, yt, **last_step_params["fit"])
+ return self
+
+ def _can_fit_transform(self):
+ return (
+ self._final_estimator == "passthrough"
+ or hasattr(self._final_estimator, "transform")
+ or hasattr(self._final_estimator, "fit_transform")
+ )
@available_if(_can_fit_transform)
- @_fit_context(prefer_skip_nested_validation=False)
+ @_fit_context(
+ # estimators in Pipeline.steps are not validated yet
+ prefer_skip_nested_validation=False
+ )
def fit_transform(self, X, y=None, **params):
"""Fit the model and transform with the final estimator.
@@ -233,9 +388,24 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
Xt : array-like of shape (n_samples, n_transformed_features)
Transformed samples.
"""
- pass
-
- @available_if(pipeline._final_estimator_has('predict'))
+ routed_params = self._check_method_params(method="fit_transform", props=params)
+ Xt, yt = self._fit(X, y, routed_params)
+
+ last_step = self._final_estimator
+ with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
+ if last_step == "passthrough":
+ return Xt
+ last_step_params = routed_params[self.steps[-1][0]]
+ if hasattr(last_step, "fit_transform"):
+ return last_step.fit_transform(
+ Xt, yt, **last_step_params["fit_transform"]
+ )
+ else:
+                return last_step.fit(Xt, yt, **last_step_params["fit"]).transform(
+ Xt, **last_step_params["transform"]
+ )
+
+ @available_if(pipeline._final_estimator_has("predict"))
def predict(self, X, **params):
"""Transform the data, and apply `predict` with the final estimator.
@@ -282,10 +452,29 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
y_pred : ndarray
Result of calling `predict` on the final estimator.
"""
- pass
+ Xt = X
+
+ if not _routing_enabled():
+ for _, name, transform in self._iter(with_final=False):
+ Xt = transform.transform(Xt)
+ return self.steps[-1][1].predict(Xt, **params)
+
+ # metadata routing enabled
+ routed_params = process_routing(self, "predict", **params)
+ for _, name, transform in self._iter(with_final=False):
+ Xt = transform.transform(Xt, **routed_params[name].transform)
+ return self.steps[-1][1].predict(Xt, **routed_params[self.steps[-1][0]].predict)
+
+ def _can_fit_resample(self):
+ return self._final_estimator == "passthrough" or hasattr(
+ self._final_estimator, "fit_resample"
+ )
@available_if(_can_fit_resample)
- @_fit_context(prefer_skip_nested_validation=False)
+ @_fit_context(
+ # estimators in Pipeline.steps are not validated yet
+ prefer_skip_nested_validation=False
+ )
def fit_resample(self, X, y=None, **params):
"""Fit the model and sample with the final estimator.
@@ -332,10 +521,23 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
         yt : array-like of shape (n_samples,)
Transformed target.
"""
- pass
-
- @available_if(pipeline._final_estimator_has('fit_predict'))
- @_fit_context(prefer_skip_nested_validation=False)
+ routed_params = self._check_method_params(method="fit_resample", props=params)
+ Xt, yt = self._fit(X, y, routed_params)
+ last_step = self._final_estimator
+ with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
+ if last_step == "passthrough":
+ return Xt
+ last_step_params = routed_params[self.steps[-1][0]]
+ if hasattr(last_step, "fit_resample"):
+ return last_step.fit_resample(
+ Xt, yt, **last_step_params["fit_resample"]
+ )
+
+ @available_if(pipeline._final_estimator_has("fit_predict"))
+ @_fit_context(
+ # estimators in Pipeline.steps are not validated yet
+ prefer_skip_nested_validation=False
+ )
def fit_predict(self, X, y=None, **params):
"""Apply `fit_predict` of last step in pipeline after transforms.
@@ -385,9 +587,20 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
y_pred : ndarray of shape (n_samples,)
The predicted target.
"""
- pass
-
- @available_if(pipeline._final_estimator_has('predict_proba'))
+ routed_params = self._check_method_params(method="fit_predict", props=params)
+ Xt, yt = self._fit(X, y, routed_params)
+
+ params_last_step = routed_params[self.steps[-1][0]]
+ with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
+ y_pred = self.steps[-1][-1].fit_predict(
+ Xt, yt, **params_last_step.get("fit_predict", {})
+ )
+ return y_pred
+
+ # TODO: remove the following methods when the minimum scikit-learn >= 1.4
+ # They do not depend on resampling but we need to redefine them for the
+ # compatibility with the metadata routing framework.
+ @available_if(pipeline._final_estimator_has("predict_proba"))
def predict_proba(self, X, **params):
"""Transform the data, and apply `predict_proba` with the final estimator.
@@ -429,9 +642,22 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
y_proba : ndarray of shape (n_samples, n_classes)
Result of calling `predict_proba` on the final estimator.
"""
- pass
-
- @available_if(pipeline._final_estimator_has('decision_function'))
+ Xt = X
+
+ if not _routing_enabled():
+ for _, name, transform in self._iter(with_final=False):
+ Xt = transform.transform(Xt)
+ return self.steps[-1][1].predict_proba(Xt, **params)
+
+ # metadata routing enabled
+ routed_params = process_routing(self, "predict_proba", **params)
+ for _, name, transform in self._iter(with_final=False):
+ Xt = transform.transform(Xt, **routed_params[name].transform)
+ return self.steps[-1][1].predict_proba(
+ Xt, **routed_params[self.steps[-1][0]].predict_proba
+ )
+
+ @available_if(pipeline._final_estimator_has("decision_function"))
def decision_function(self, X, **params):
"""Transform the data, and apply `decision_function` with the final estimator.
@@ -461,9 +687,22 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
y_score : ndarray of shape (n_samples, n_classes)
Result of calling `decision_function` on the final estimator.
"""
- pass
-
- @available_if(pipeline._final_estimator_has('score_samples'))
+ _raise_for_params(params, self, "decision_function")
+
+ # not branching here since params is only available if
+ # enable_metadata_routing=True
+ routed_params = process_routing(self, "decision_function", **params)
+
+ Xt = X
+ for _, name, transform in self._iter(with_final=False):
+ Xt = transform.transform(
+ Xt, **routed_params.get(name, {}).get("transform", {})
+ )
+ return self.steps[-1][1].decision_function(
+ Xt, **routed_params.get(self.steps[-1][0], {}).get("decision_function", {})
+ )
+
+ @available_if(pipeline._final_estimator_has("score_samples"))
def score_samples(self, X):
"""Transform the data, and apply `score_samples` with the final estimator.
@@ -483,9 +722,12 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
y_score : ndarray of shape (n_samples,)
Result of calling `score_samples` on the final estimator.
"""
- pass
+ Xt = X
+ for _, _, transformer in self._iter(with_final=False):
+ Xt = transformer.transform(Xt)
+ return self.steps[-1][1].score_samples(Xt)
- @available_if(pipeline._final_estimator_has('predict_log_proba'))
+ @available_if(pipeline._final_estimator_has("predict_log_proba"))
def predict_log_proba(self, X, **params):
"""Transform the data, and apply `predict_log_proba` with the final estimator.
@@ -527,7 +769,25 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
y_log_proba : ndarray of shape (n_samples, n_classes)
Result of calling `predict_log_proba` on the final estimator.
"""
- pass
+ Xt = X
+
+ if not _routing_enabled():
+ for _, name, transform in self._iter(with_final=False):
+ Xt = transform.transform(Xt)
+ return self.steps[-1][1].predict_log_proba(Xt, **params)
+
+ # metadata routing enabled
+ routed_params = process_routing(self, "predict_log_proba", **params)
+ for _, name, transform in self._iter(with_final=False):
+ Xt = transform.transform(Xt, **routed_params[name].transform)
+ return self.steps[-1][1].predict_log_proba(
+ Xt, **routed_params[self.steps[-1][0]].predict_log_proba
+ )
+
+ def _can_transform(self):
+ return self._final_estimator == "passthrough" or hasattr(
+ self._final_estimator, "transform"
+ )
@available_if(_can_transform)
def transform(self, X, **params):
@@ -562,7 +822,18 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
Xt : ndarray of shape (n_samples, n_transformed_features)
Transformed data.
"""
- pass
+ _raise_for_params(params, self, "transform")
+
+ # not branching here since params is only available if
+ # enable_metadata_routing=True
+ routed_params = process_routing(self, "transform", **params)
+ Xt = X
+ for _, name, transform in self._iter():
+ Xt = transform.transform(Xt, **routed_params[name].transform)
+ return Xt
+
+ def _can_inverse_transform(self):
+ return all(hasattr(t, "inverse_transform") for _, _, t in self._iter())
@available_if(_can_inverse_transform)
def inverse_transform(self, Xt, **params):
@@ -594,9 +865,19 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
Inverse transformed data, that is, data in the original feature
space.
"""
- pass
-
- @available_if(pipeline._final_estimator_has('score'))
+ _raise_for_params(params, self, "inverse_transform")
+
+ # we don't have to branch here, since params is only non-empty if
+ # enable_metadata_routing=True.
+ routed_params = process_routing(self, "inverse_transform", **params)
+ reverse_iter = reversed(list(self._iter()))
+ for _, name, transform in reverse_iter:
+ Xt = transform.inverse_transform(
+ Xt, **routed_params[name].inverse_transform
+ )
+ return Xt
+
+ @available_if(pipeline._final_estimator_has("score"))
def score(self, X, y=None, sample_weight=None, **params):
"""Transform the data, and apply `score` with the final estimator.
@@ -633,8 +914,27 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
score : float
Result of calling `score` on the final estimator.
"""
- pass
-
+ Xt = X
+ if not _routing_enabled():
+ for _, name, transform in self._iter(with_final=False):
+ Xt = transform.transform(Xt)
+ score_params = {}
+ if sample_weight is not None:
+ score_params["sample_weight"] = sample_weight
+ return self.steps[-1][1].score(Xt, y, **score_params)
+
+ # metadata routing is enabled.
+ routed_params = process_routing(
+ self, "score", sample_weight=sample_weight, **params
+ )
+
+ Xt = X
+ for _, name, transform in self._iter(with_final=False):
+ Xt = transform.transform(Xt, **routed_params[name].transform)
+ return self.steps[-1][1].score(Xt, y, **routed_params[self.steps[-1][0]].score)
+
+ # TODO: once scikit-learn >= 1.4, the following function should be simplified by
+ # calling `super().get_metadata_routing()`
def get_metadata_routing(self):
"""Get metadata routing of this object.
@@ -647,7 +947,116 @@ class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
A :class:`~utils.metadata_routing.MetadataRouter` encapsulating
routing information.
"""
- pass
+ router = MetadataRouter(owner=self.__class__.__name__)
+
+ # first we add all steps except the last one
+ for _, name, trans in self._iter(with_final=False, filter_passthrough=True):
+ method_mapping = MethodMapping()
+ # fit, fit_predict, and fit_transform call fit_transform if it
+ # exists, or else fit and transform
+ if hasattr(trans, "fit_transform"):
+ (
+ method_mapping.add(caller="fit", callee="fit_transform")
+ .add(caller="fit_transform", callee="fit_transform")
+ .add(caller="fit_predict", callee="fit_transform")
+ .add(caller="fit_resample", callee="fit_transform")
+ )
+ else:
+ (
+ method_mapping.add(caller="fit", callee="fit")
+ .add(caller="fit", callee="transform")
+ .add(caller="fit_transform", callee="fit")
+ .add(caller="fit_transform", callee="transform")
+ .add(caller="fit_predict", callee="fit")
+ .add(caller="fit_predict", callee="transform")
+ .add(caller="fit_resample", callee="fit")
+ .add(caller="fit_resample", callee="transform")
+ )
+
+ (
+ method_mapping.add(caller="predict", callee="transform")
+ .add(caller="predict", callee="transform")
+ .add(caller="predict_proba", callee="transform")
+ .add(caller="decision_function", callee="transform")
+ .add(caller="predict_log_proba", callee="transform")
+ .add(caller="transform", callee="transform")
+ .add(caller="inverse_transform", callee="inverse_transform")
+ .add(caller="score", callee="transform")
+ .add(caller="fit_resample", callee="transform")
+ )
+
+ router.add(method_mapping=method_mapping, **{name: trans})
+
+ final_name, final_est = self.steps[-1]
+ if final_est is None or final_est == "passthrough":
+ return router
+
+ # then we add the last step
+ method_mapping = MethodMapping()
+ if hasattr(final_est, "fit_transform"):
+ (
+ method_mapping.add(caller="fit_transform", callee="fit_transform").add(
+ caller="fit_resample", callee="fit_transform"
+ )
+ )
+ else:
+ (
+ method_mapping.add(caller="fit", callee="fit")
+ .add(caller="fit", callee="transform")
+ .add(caller="fit_resample", callee="fit")
+ .add(caller="fit_resample", callee="transform")
+ )
+ (
+ method_mapping.add(caller="fit", callee="fit")
+ .add(caller="predict", callee="predict")
+ .add(caller="fit_predict", callee="fit_predict")
+ .add(caller="predict_proba", callee="predict_proba")
+ .add(caller="decision_function", callee="decision_function")
+ .add(caller="predict_log_proba", callee="predict_log_proba")
+ .add(caller="transform", callee="transform")
+ .add(caller="inverse_transform", callee="inverse_transform")
+ .add(caller="score", callee="score")
+ .add(caller="fit_resample", callee="fit_resample")
+ )
+
+ router.add(method_mapping=method_mapping, **{final_name: final_est})
+ return router
+
+ def _check_method_params(self, method, props, **kwargs):
+ if _routing_enabled():
+ routed_params = process_routing(self, method, **props, **kwargs)
+ return routed_params
+ else:
+ fit_params_steps = Bunch(
+ **{
+ name: Bunch(**{method: {} for method in METHODS})
+ for name, step in self.steps
+ if step is not None
+ }
+ )
+ for pname, pval in props.items():
+ if "__" not in pname:
+ raise ValueError(
+ "Pipeline.fit does not accept the {} parameter. "
+ "You can pass parameters to specific steps of your "
+ "pipeline using the stepname__parameter format, e.g. "
+ "`Pipeline.fit(X, y, logisticregression__sample_weight"
+ "=sample_weight)`.".format(pname)
+ )
+                step, param = pname.split("__", 1)
+                fit_params_steps[step]["fit"][param] = pval
+                # without metadata routing, fit_transform and fit_predict
+                # all get the same params, which are passed on to the last fit
+                fit_params_steps[step]["fit_transform"][param] = pval
+                fit_params_steps[step]["fit_predict"][param] = pval
+ return fit_params_steps
+
+
+def _fit_resample_one(sampler, X, y, message_clsname="", message=None, params=None):
+ with _print_elapsed_time(message_clsname, message):
+ X_res, y_res = sampler.fit_resample(X, y, **params.get("fit_resample", {}))
+
+ return X_res, y_res, sampler
def _transform_one(transformer, X, y, weight, params):
@@ -672,11 +1081,16 @@ def _transform_one(transformer, X, y, weight, params):
This should be of the form ``process_routing()["step_name"]``.
"""
- pass
+ res = transformer.transform(X, **params.transform)
+ # if we have a weight for this transformer, multiply output
+ if weight is None:
+ return res
+ return res * weight
-def _fit_transform_one(transformer, X, y, weight, message_clsname='',
- message=None, params=None):
+def _fit_transform_one(
+ transformer, X, y, weight, message_clsname="", message=None, params=None
+):
"""
Fits ``transformer`` to ``X`` and ``y``. The transformed result is returned
with the fitted transformer. If ``weight`` is not ``None``, the result will
@@ -684,11 +1098,24 @@ def _fit_transform_one(transformer, X, y, weight, message_clsname='',
``params`` needs to be of the form ``process_routing()["step_name"]``.
"""
- pass
-
-
-@validate_params({'memory': [None, str, HasMethods(['cache'])], 'verbose':
- ['boolean']}, prefer_skip_nested_validation=True)
+ params = params or {}
+ with _print_elapsed_time(message_clsname, message):
+ if hasattr(transformer, "fit_transform"):
+ res = transformer.fit_transform(X, y, **params.get("fit_transform", {}))
+ else:
+ res = transformer.fit(X, y, **params.get("fit", {})).transform(
+ X, **params.get("transform", {})
+ )
+
+ if weight is None:
+ return res, transformer
+ return res * weight, transformer
+
+
+@validate_params(
+ {"memory": [None, str, HasMethods(["cache"])], "verbose": ["boolean"]},
+ prefer_skip_nested_validation=True,
+)
def make_pipeline(*steps, memory=None, verbose=False):
"""Construct a Pipeline from the given estimators.
@@ -733,4 +1160,4 @@ def make_pipeline(*steps, memory=None, verbose=False):
Pipeline(steps=[('standardscaler', StandardScaler()),
('gaussiannb', GaussianNB())])
"""
- pass
+ return Pipeline(pipeline._name_estimators(steps), memory=memory, verbose=verbose)
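
For orientation, here is a minimal usage sketch of the pipeline restored above; the dataset size and the choice of LogisticRegression are illustrative and not part of the patch. The sampler step participates in fit (via `_fit_resample_one`), while `_iter` filters it out of the transform chain used at prediction time.

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    from imblearn.pipeline import make_pipeline
    from imblearn.under_sampling import RandomUnderSampler

    X, y = make_classification(n_samples=500, weights=[0.9, 0.1], random_state=0)
    pipe = make_pipeline(RandomUnderSampler(random_state=0), LogisticRegression())
    pipe.fit(X, y)              # the sampler resamples X, y before the final fit
    print(pipe.predict(X[:5]))  # the sampler step is skipped at predict time
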
diff --git a/imblearn/tensorflow/_generator.py b/imblearn/tensorflow/_generator.py
index c55dd52..7e50322 100644
--- a/imblearn/tensorflow/_generator.py
+++ b/imblearn/tensorflow/_generator.py
@@ -1,15 +1,25 @@
"""Implement generators for ``tensorflow`` which will balance the data."""
+
from scipy.sparse import issparse
from sklearn.base import clone
from sklearn.utils import _safe_indexing, check_random_state
+
from ..under_sampling import RandomUnderSampler
from ..utils import Substitution
from ..utils._docstring import _random_state_docstring
@Substitution(random_state=_random_state_docstring)
-def balanced_batch_generator(X, y, *, sample_weight=None, sampler=None,
- batch_size=32, keep_sparse=False, random_state=None):
+def balanced_batch_generator(
+ X,
+ y,
+ *,
+ sample_weight=None,
+ sampler=None,
+ batch_size=32,
+ keep_sparse=False,
+ random_state=None,
+):
"""Create a balanced batch generator to train tensorflow model.
Returns a generator --- as well as the number of step per epoch --- to
@@ -53,4 +63,35 @@ def balanced_batch_generator(X, y, *, sample_weight=None, sampler=None,
steps_per_epoch : int
The number of samples per epoch.
"""
- pass
+
+ random_state = check_random_state(random_state)
+ if sampler is None:
+ sampler_ = RandomUnderSampler(random_state=random_state)
+ else:
+ sampler_ = clone(sampler)
+ sampler_.fit_resample(X, y)
+ if not hasattr(sampler_, "sample_indices_"):
+ raise ValueError("'sampler' needs to have an attribute 'sample_indices_'.")
+ indices = sampler_.sample_indices_
+    # shuffle the indices since the sampler packs them by class
+ random_state.shuffle(indices)
+
+ def generator(X, y, sample_weight, indices, batch_size):
+ while True:
+ for index in range(0, len(indices), batch_size):
+ X_res = _safe_indexing(X, indices[index : index + batch_size])
+ y_res = _safe_indexing(y, indices[index : index + batch_size])
+ if issparse(X_res) and not keep_sparse:
+ X_res = X_res.toarray()
+ if sample_weight is None:
+ yield X_res, y_res
+ else:
+ sw_res = _safe_indexing(
+ sample_weight, indices[index : index + batch_size]
+ )
+ yield X_res, y_res, sw_res
+
+ return (
+ generator(X, y, sample_weight, indices, batch_size),
+ int(indices.size // batch_size),
+ )
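
A minimal sketch of how this generator is consumed (the class sizes are illustrative): the generator yields batches indefinitely, so a training loop should bound each epoch by the returned steps_per_epoch.

    import numpy as np
    from sklearn.datasets import load_iris

    from imblearn.datasets import make_imbalance
    from imblearn.tensorflow import balanced_batch_generator

    X, y = load_iris(return_X_y=True)
    X, y = make_imbalance(X, y, sampling_strategy={0: 30, 1: 50, 2: 40})
    generator, steps_per_epoch = balanced_batch_generator(
        X.astype(np.float32), y, batch_size=10, random_state=42
    )
    X_batch, y_batch = next(generator)  # batches cycle over the balanced indices
    print(X_batch.shape, steps_per_epoch)
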
diff --git a/imblearn/tensorflow/tests/test_generator.py b/imblearn/tensorflow/tests/test_generator.py
index 752979f..e0c7a91 100644
--- a/imblearn/tensorflow/tests/test_generator.py
+++ b/imblearn/tensorflow/tests/test_generator.py
@@ -3,8 +3,169 @@ import pytest
from scipy import sparse
from sklearn.datasets import load_iris
from sklearn.utils.fixes import parse_version
+
from imblearn.datasets import make_imbalance
from imblearn.over_sampling import RandomOverSampler
from imblearn.tensorflow import balanced_batch_generator
from imblearn.under_sampling import NearMiss
-tf = pytest.importorskip('tensorflow')
+
+tf = pytest.importorskip("tensorflow")
+
+
+@pytest.fixture
+def data():
+ X, y = load_iris(return_X_y=True)
+ X, y = make_imbalance(X, y, sampling_strategy={0: 30, 1: 50, 2: 40})
+ X = X.astype(np.float32)
+ return X, y
+
+
+def check_balanced_batch_generator_tf_1_X_X(dataset, sampler):
+ X, y = dataset
+ batch_size = 10
+ training_generator, steps_per_epoch = balanced_batch_generator(
+ X,
+ y,
+ sample_weight=None,
+ sampler=sampler,
+ batch_size=batch_size,
+ random_state=42,
+ )
+
+ learning_rate = 0.01
+ epochs = 10
+ input_size = X.shape[1]
+ output_size = 3
+
+ # helper functions
+ def init_weights(shape):
+ return tf.Variable(tf.random_normal(shape, stddev=0.01))
+
+ def accuracy(y_true, y_pred):
+ return np.mean(np.argmax(y_pred, axis=1) == y_true)
+
+ # input and output
+ data = tf.placeholder("float32", shape=[None, input_size])
+ targets = tf.placeholder("int32", shape=[None])
+
+ # build the model and weights
+ W = init_weights([input_size, output_size])
+ b = init_weights([output_size])
+ out_act = tf.nn.sigmoid(tf.matmul(data, W) + b)
+
+ # build the loss, predict, and train operator
+ cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ logits=out_act, labels=targets
+ )
+ loss = tf.reduce_sum(cross_entropy)
+ optimizer = tf.train.GradientDescentOptimizer(learning_rate)
+ train_op = optimizer.minimize(loss)
+ predict = tf.nn.softmax(out_act)
+
+ # Initialization of all variables in the graph
+ init = tf.global_variables_initializer()
+
+ with tf.Session() as sess:
+ sess.run(init)
+
+ for e in range(epochs):
+ for i in range(steps_per_epoch):
+ X_batch, y_batch = next(training_generator)
+ sess.run(
+ [train_op, loss],
+ feed_dict={data: X_batch, targets: y_batch},
+ )
+
+            # For each epoch, report the accuracy on the training set
+ predicts_train = sess.run(predict, feed_dict={data: X})
+ print(f"epoch: {e} train accuracy: {accuracy(y, predicts_train):.3f}")
+
+
+def check_balanced_batch_generator_tf_2_X_X_compat_1_X_X(dataset, sampler):
+ tf.compat.v1.disable_eager_execution()
+
+ X, y = dataset
+ batch_size = 10
+ training_generator, steps_per_epoch = balanced_batch_generator(
+ X,
+ y,
+ sample_weight=None,
+ sampler=sampler,
+ batch_size=batch_size,
+ random_state=42,
+ )
+
+ learning_rate = 0.01
+ epochs = 10
+ input_size = X.shape[1]
+ output_size = 3
+
+ # helper functions
+ def init_weights(shape):
+ return tf.Variable(tf.random.normal(shape, stddev=0.01))
+
+ def accuracy(y_true, y_pred):
+ return np.mean(np.argmax(y_pred, axis=1) == y_true)
+
+ # input and output
+ data = tf.compat.v1.placeholder("float32", shape=[None, input_size])
+ targets = tf.compat.v1.placeholder("int32", shape=[None])
+
+ # build the model and weights
+ W = init_weights([input_size, output_size])
+ b = init_weights([output_size])
+ out_act = tf.nn.sigmoid(tf.matmul(data, W) + b)
+
+ # build the loss, predict, and train operator
+ cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ logits=out_act, labels=targets
+ )
+ loss = tf.reduce_sum(input_tensor=cross_entropy)
+ optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
+ train_op = optimizer.minimize(loss)
+ predict = tf.nn.softmax(out_act)
+
+ # Initialization of all variables in the graph
+ init = tf.compat.v1.global_variables_initializer()
+
+ with tf.compat.v1.Session() as sess:
+ sess.run(init)
+
+ for e in range(epochs):
+ for i in range(steps_per_epoch):
+ X_batch, y_batch = next(training_generator)
+ sess.run(
+ [train_op, loss],
+ feed_dict={data: X_batch, targets: y_batch},
+ )
+
+            # For each epoch, report the accuracy on the training set
+ predicts_train = sess.run(predict, feed_dict={data: X})
+ print(f"epoch: {e} train accuracy: {accuracy(y, predicts_train):.3f}")
+
+
+@pytest.mark.parametrize("sampler", [None, NearMiss(), RandomOverSampler()])
+def test_balanced_batch_generator(data, sampler):
+ if parse_version(tf.__version__) < parse_version("2.0.0"):
+ check_balanced_batch_generator_tf_1_X_X(data, sampler)
+ else:
+ check_balanced_batch_generator_tf_2_X_X_compat_1_X_X(data, sampler)
+
+
+@pytest.mark.parametrize("keep_sparse", [True, False])
+def test_balanced_batch_generator_function_sparse(data, keep_sparse):
+ X, y = data
+
+ training_generator, steps_per_epoch = balanced_batch_generator(
+ sparse.csr_matrix(X),
+ y,
+ keep_sparse=keep_sparse,
+ batch_size=10,
+ random_state=42,
+ )
+ for idx in range(steps_per_epoch):
+ X_batch, y_batch = next(training_generator)
+ if keep_sparse:
+ assert sparse.issparse(X_batch)
+ else:
+ assert not sparse.issparse(X_batch)
diff --git a/imblearn/under_sampling/_prototype_generation/_cluster_centroids.py b/imblearn/under_sampling/_prototype_generation/_cluster_centroids.py
index 8b6d169..5e2ca3a 100644
--- a/imblearn/under_sampling/_prototype_generation/_cluster_centroids.py
+++ b/imblearn/under_sampling/_prototype_generation/_cluster_centroids.py
@@ -1,20 +1,30 @@
"""Class to perform under-sampling by generating centroids based on
clustering."""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Fernando Nogueira
+# Christos Aridas
+# License: MIT
+
import numpy as np
from scipy import sparse
from sklearn.base import clone
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import _safe_indexing
+
from ...utils import Substitution
from ...utils._docstring import _random_state_docstring
from ...utils._param_validation import HasMethods, StrOptions
from ..base import BaseUnderSampler
-VOTING_KIND = 'auto', 'hard', 'soft'
+
+VOTING_KIND = ("auto", "hard", "soft")
-@Substitution(sampling_strategy=BaseUnderSampler.
- _sampling_strategy_docstring, random_state=_random_state_docstring)
+@Substitution(
+ sampling_strategy=BaseUnderSampler._sampling_strategy_docstring,
+ random_state=_random_state_docstring,
+)
class ClusterCentroids(BaseUnderSampler):
"""Undersample by generating centroids based on clustering methods.
@@ -103,13 +113,22 @@ class ClusterCentroids(BaseUnderSampler):
>>> print('Resampled dataset shape %s' % Counter(y_res))
Resampled dataset shape Counter({{...}})
"""
- _parameter_constraints: dict = {**BaseUnderSampler.
- _parameter_constraints, 'estimator': [HasMethods(['fit', 'predict']
- ), None], 'voting': [StrOptions({'auto', 'hard', 'soft'})],
- 'random_state': ['random_state']}
- def __init__(self, *, sampling_strategy='auto', random_state=None,
- estimator=None, voting='auto'):
+ _parameter_constraints: dict = {
+ **BaseUnderSampler._parameter_constraints,
+ "estimator": [HasMethods(["fit", "predict"]), None],
+ "voting": [StrOptions({"auto", "hard", "soft"})],
+ "random_state": ["random_state"],
+ }
+
+ def __init__(
+ self,
+ *,
+ sampling_strategy="auto",
+ random_state=None,
+ estimator=None,
+ voting="auto",
+ ):
super().__init__(sampling_strategy=sampling_strategy)
self.random_state = random_state
self.estimator = estimator
@@ -117,4 +136,70 @@ class ClusterCentroids(BaseUnderSampler):
def _validate_estimator(self):
"""Private function to create the KMeans estimator"""
- pass
+ if self.estimator is None:
+ self.estimator_ = KMeans(random_state=self.random_state)
+ else:
+ self.estimator_ = clone(self.estimator)
+ if "n_clusters" not in self.estimator_.get_params():
+ raise ValueError(
+ "`estimator` should be a clustering estimator exposing a parameter"
+ " `n_clusters` and a fitted parameter `cluster_centers_`."
+ )
+
+ def _generate_sample(self, X, y, centroids, target_class):
+ if self.voting_ == "hard":
+ nearest_neighbors = NearestNeighbors(n_neighbors=1)
+ nearest_neighbors.fit(X, y)
+ indices = nearest_neighbors.kneighbors(centroids, return_distance=False)
+ X_new = _safe_indexing(X, np.squeeze(indices))
+ else:
+ if sparse.issparse(X):
+ X_new = sparse.csr_matrix(centroids, dtype=X.dtype)
+ else:
+ X_new = centroids
+ y_new = np.array([target_class] * centroids.shape[0], dtype=y.dtype)
+
+ return X_new, y_new
+
+ def _fit_resample(self, X, y):
+ self._validate_estimator()
+
+ if self.voting == "auto":
+ self.voting_ = "hard" if sparse.issparse(X) else "soft"
+ else:
+ self.voting_ = self.voting
+
+ X_resampled, y_resampled = [], []
+ for target_class in np.unique(y):
+ target_class_indices = np.flatnonzero(y == target_class)
+ if target_class in self.sampling_strategy_.keys():
+ n_samples = self.sampling_strategy_[target_class]
+ self.estimator_.set_params(**{"n_clusters": n_samples})
+ self.estimator_.fit(_safe_indexing(X, target_class_indices))
+ if not hasattr(self.estimator_, "cluster_centers_"):
+ raise RuntimeError(
+ "`estimator` should be a clustering estimator exposing a "
+ "fitted parameter `cluster_centers_`."
+ )
+ X_new, y_new = self._generate_sample(
+ _safe_indexing(X, target_class_indices),
+ _safe_indexing(y, target_class_indices),
+ self.estimator_.cluster_centers_,
+ target_class,
+ )
+ X_resampled.append(X_new)
+ y_resampled.append(y_new)
+ else:
+ X_resampled.append(_safe_indexing(X, target_class_indices))
+ y_resampled.append(_safe_indexing(y, target_class_indices))
+
+ if sparse.issparse(X):
+ X_resampled = sparse.vstack(X_resampled)
+ else:
+ X_resampled = np.vstack(X_resampled)
+ y_resampled = np.hstack(y_resampled)
+
+ return X_resampled, np.array(y_resampled, dtype=y.dtype)
+
+ def _more_tags(self):
+ return {"sample_indices": False}
diff --git a/imblearn/under_sampling/_prototype_generation/tests/test_cluster_centroids.py b/imblearn/under_sampling/_prototype_generation/tests/test_cluster_centroids.py
index df6ac55..b51e350 100644
--- a/imblearn/under_sampling/_prototype_generation/tests/test_cluster_centroids.py
+++ b/imblearn/under_sampling/_prototype_generation/tests/test_cluster_centroids.py
@@ -1,17 +1,163 @@
"""Test the module cluster centroids."""
from collections import Counter
+
import numpy as np
import pytest
from scipy import sparse
from sklearn.cluster import KMeans
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
+
from imblearn.under_sampling import ClusterCentroids
from imblearn.utils.testing import _CustomClusterer
+
RND_SEED = 0
-X = np.array([[0.04352327, -0.20515826], [0.92923648, 0.76103773], [
- 0.20792588, 1.49407907], [0.47104475, 0.44386323], [0.22950086,
- 0.33367433], [0.15490546, 0.3130677], [0.09125309, -0.85409574], [
- 0.12372842, 0.6536186], [0.13347175, 0.12167502], [0.094035, -2.55298982]])
+X = np.array(
+ [
+ [0.04352327, -0.20515826],
+ [0.92923648, 0.76103773],
+ [0.20792588, 1.49407907],
+ [0.47104475, 0.44386323],
+ [0.22950086, 0.33367433],
+ [0.15490546, 0.3130677],
+ [0.09125309, -0.85409574],
+ [0.12372842, 0.6536186],
+ [0.13347175, 0.12167502],
+ [0.094035, -2.55298982],
+ ]
+)
Y = np.array([1, 0, 1, 0, 1, 1, 1, 1, 0, 1])
-R_TOL = 0.0001
+R_TOL = 1e-4
+
+
+@pytest.mark.parametrize(
+ "X, expected_voting", [(X, "soft"), (sparse.csr_matrix(X), "hard")]
+)
+@pytest.mark.filterwarnings("ignore:The default value of `n_init` will change")
+def test_fit_resample_check_voting(X, expected_voting):
+ cc = ClusterCentroids(random_state=RND_SEED)
+ cc.fit_resample(X, Y)
+ assert cc.voting_ == expected_voting
+
+
+@pytest.mark.filterwarnings("ignore:The default value of `n_init` will change")
+def test_fit_resample_auto():
+ sampling_strategy = "auto"
+ cc = ClusterCentroids(sampling_strategy=sampling_strategy, random_state=RND_SEED)
+ X_resampled, y_resampled = cc.fit_resample(X, Y)
+ assert X_resampled.shape == (6, 2)
+ assert y_resampled.shape == (6,)
+
+
+@pytest.mark.filterwarnings("ignore:The default value of `n_init` will change")
+def test_fit_resample_half():
+ sampling_strategy = {0: 3, 1: 6}
+ cc = ClusterCentroids(sampling_strategy=sampling_strategy, random_state=RND_SEED)
+ X_resampled, y_resampled = cc.fit_resample(X, Y)
+ assert X_resampled.shape == (9, 2)
+ assert y_resampled.shape == (9,)
+
+
+@pytest.mark.filterwarnings("ignore:The default value of `n_init` will change")
+def test_multiclass_fit_resample():
+ y = Y.copy()
+ y[5] = 2
+ y[6] = 2
+ cc = ClusterCentroids(random_state=RND_SEED)
+ _, y_resampled = cc.fit_resample(X, y)
+ count_y_res = Counter(y_resampled)
+ assert count_y_res[0] == 2
+ assert count_y_res[1] == 2
+ assert count_y_res[2] == 2
+
+
+def test_fit_resample_object():
+ sampling_strategy = "auto"
+ cluster = KMeans(random_state=RND_SEED, n_init=1)
+ cc = ClusterCentroids(
+ sampling_strategy=sampling_strategy,
+ random_state=RND_SEED,
+ estimator=cluster,
+ )
+
+ X_resampled, y_resampled = cc.fit_resample(X, Y)
+ assert X_resampled.shape == (6, 2)
+ assert y_resampled.shape == (6,)
+
+
+def test_fit_hard_voting():
+ sampling_strategy = "auto"
+ voting = "hard"
+ cluster = KMeans(random_state=RND_SEED, n_init=1)
+ cc = ClusterCentroids(
+ sampling_strategy=sampling_strategy,
+ random_state=RND_SEED,
+ estimator=cluster,
+ voting=voting,
+ )
+
+ X_resampled, y_resampled = cc.fit_resample(X, Y)
+ assert X_resampled.shape == (6, 2)
+ assert y_resampled.shape == (6,)
+ for x in X_resampled:
+ assert np.any(np.all(x == X, axis=1))
+
+
+@pytest.mark.filterwarnings("ignore:The default value of `n_init` will change")
+def test_cluster_centroids_hard_target_class():
+    # check that the samples selected by hard voting correspond to the
+    # targeted class
+ # non-regression test for:
+ # https://github.com/scikit-learn-contrib/imbalanced-learn/issues/738
+ X, y = make_classification(
+ n_samples=1000,
+ n_features=2,
+ n_informative=1,
+ n_redundant=0,
+ n_repeated=0,
+ n_clusters_per_class=1,
+ weights=[0.3, 0.7],
+ class_sep=0.01,
+ random_state=0,
+ )
+
+ cc = ClusterCentroids(voting="hard", random_state=0)
+ X_res, y_res = cc.fit_resample(X, y)
+
+ minority_class_indices = np.flatnonzero(y == 0)
+ X_minority_class = X[minority_class_indices]
+
+ resampled_majority_class_indices = np.flatnonzero(y_res == 1)
+ X_res_majority = X_res[resampled_majority_class_indices]
+
+ sample_from_minority_in_majority = [
+ np.all(np.isclose(selected_sample, minority_sample))
+ for selected_sample in X_res_majority
+ for minority_sample in X_minority_class
+ ]
+ assert sum(sample_from_minority_in_majority) == 0
+
+
+def test_cluster_centroids_custom_clusterer():
+ clusterer = _CustomClusterer()
+ cc = ClusterCentroids(estimator=clusterer, random_state=RND_SEED)
+ cc.fit_resample(X, Y)
+ assert isinstance(cc.estimator_.cluster_centers_, np.ndarray)
+
+ clusterer = _CustomClusterer(expose_cluster_centers=False)
+ cc = ClusterCentroids(estimator=clusterer, random_state=RND_SEED)
+ err_msg = (
+ "`estimator` should be a clustering estimator exposing a fitted parameter "
+ "`cluster_centers_`."
+ )
+ with pytest.raises(RuntimeError, match=err_msg):
+ cc.fit_resample(X, Y)
+
+ clusterer = LogisticRegression()
+ cc = ClusterCentroids(estimator=clusterer, random_state=RND_SEED)
+ err_msg = (
+ "`estimator` should be a clustering estimator exposing a parameter "
+ "`n_clusters` and a fitted parameter `cluster_centers_`."
+ )
+ with pytest.raises(ValueError, match=err_msg):
+ cc.fit_resample(X, Y)
diff --git a/imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py b/imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py
index 5bd5434..ac012a0 100644
--- a/imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py
+++ b/imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py
@@ -1,22 +1,31 @@
"""Class to perform under-sampling based on the condensed nearest neighbour
method."""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import numbers
import warnings
from collections import Counter
+
import numpy as np
from scipy.sparse import issparse
from sklearn.base import clone
from sklearn.neighbors import KNeighborsClassifier
from sklearn.utils import _safe_indexing, check_random_state
+
from ...utils import Substitution
from ...utils._docstring import _n_jobs_docstring, _random_state_docstring
from ...utils._param_validation import HasMethods, Interval
from ..base import BaseCleaningSampler
-@Substitution(sampling_strategy=BaseCleaningSampler.
- _sampling_strategy_docstring, n_jobs=_n_jobs_docstring, random_state=
- _random_state_docstring)
+@Substitution(
+ sampling_strategy=BaseCleaningSampler._sampling_strategy_docstring,
+ n_jobs=_n_jobs_docstring,
+ random_state=_random_state_docstring,
+)
class CondensedNearestNeighbour(BaseCleaningSampler):
"""Undersample based on the condensed nearest neighbour method.
@@ -103,25 +112,41 @@ class CondensedNearestNeighbour(BaseCleaningSampler):
>>> from collections import Counter # doctest: +SKIP
>>> from sklearn.datasets import fetch_openml # doctest: +SKIP
>>> from sklearn.preprocessing import scale # doctest: +SKIP
- >>> from imblearn.under_sampling import CondensedNearestNeighbour # doctest: +SKIP
+ >>> from imblearn.under_sampling import \
+CondensedNearestNeighbour # doctest: +SKIP
>>> X, y = fetch_openml('diabetes', version=1, return_X_y=True) # doctest: +SKIP
>>> X = scale(X) # doctest: +SKIP
>>> print('Original dataset shape %s' % Counter(y)) # doctest: +SKIP
- Original dataset shape Counter({{'tested_negative': 500, 'tested_positive': 268}}) # doctest: +SKIP
+ Original dataset shape Counter({{'tested_negative': 500, \
+ 'tested_positive': 268}}) # doctest: +SKIP
>>> cnn = CondensedNearestNeighbour(random_state=42) # doctest: +SKIP
>>> X_res, y_res = cnn.fit_resample(X, y) #doctest: +SKIP
>>> print('Resampled dataset shape %s' % Counter(y_res)) # doctest: +SKIP
- Resampled dataset shape Counter({{'tested_positive': 268, 'tested_negative': 181}}) # doctest: +SKIP
+ Resampled dataset shape Counter({{'tested_positive': 268, \
+ 'tested_negative': 181}}) # doctest: +SKIP
"""
- _parameter_constraints: dict = {**BaseCleaningSampler.
- _parameter_constraints, 'n_neighbors': [Interval(numbers.Integral,
- 1, None, closed='left'), HasMethods(['kneighbors',
- 'kneighbors_graph']), None], 'n_seeds_S': [Interval(numbers.
- Integral, 1, None, closed='left')], 'n_jobs': [numbers.Integral,
- None], 'random_state': ['random_state']}
-
- def __init__(self, *, sampling_strategy='auto', random_state=None,
- n_neighbors=None, n_seeds_S=1, n_jobs=None):
+
+ _parameter_constraints: dict = {
+ **BaseCleaningSampler._parameter_constraints,
+ "n_neighbors": [
+ Interval(numbers.Integral, 1, None, closed="left"),
+ HasMethods(["kneighbors", "kneighbors_graph"]),
+ None,
+ ],
+ "n_seeds_S": [Interval(numbers.Integral, 1, None, closed="left")],
+ "n_jobs": [numbers.Integral, None],
+ "random_state": ["random_state"],
+ }
+
+ def __init__(
+ self,
+ *,
+ sampling_strategy="auto",
+ random_state=None,
+ n_neighbors=None,
+ n_seeds_S=1,
+ n_jobs=None,
+ ):
super().__init__(sampling_strategy=sampling_strategy)
self.random_state = random_state
self.n_neighbors = n_neighbors
@@ -130,9 +155,107 @@ class CondensedNearestNeighbour(BaseCleaningSampler):
def _validate_estimator(self):
"""Private function to create the NN estimator"""
- pass
+ if self.n_neighbors is None:
+ estimator = KNeighborsClassifier(n_neighbors=1, n_jobs=self.n_jobs)
+ elif isinstance(self.n_neighbors, numbers.Integral):
+ estimator = KNeighborsClassifier(
+ n_neighbors=self.n_neighbors, n_jobs=self.n_jobs
+ )
+ elif isinstance(self.n_neighbors, KNeighborsClassifier):
+ estimator = clone(self.n_neighbors)
+
+ return estimator
+
+ def _fit_resample(self, X, y):
+ estimator = self._validate_estimator()
+
+ random_state = check_random_state(self.random_state)
+ target_stats = Counter(y)
+ class_minority = min(target_stats, key=target_stats.get)
+ idx_under = np.empty((0,), dtype=int)
+
+ self.estimators_ = []
+ for target_class in np.unique(y):
+ if target_class in self.sampling_strategy_.keys():
+ # Randomly get one sample from the majority class
+ # Generate the index to select
+ idx_maj = np.flatnonzero(y == target_class)
+ idx_maj_sample = idx_maj[
+ random_state.randint(
+ low=0,
+ high=target_stats[target_class],
+ size=self.n_seeds_S,
+ )
+ ]
+
+                # Create the set C - one majority sample and all minority samples
+ C_indices = np.append(
+ np.flatnonzero(y == class_minority), idx_maj_sample
+ )
+ C_x = _safe_indexing(X, C_indices)
+ C_y = _safe_indexing(y, C_indices)
+
+ # Create the set S - all majority samples
+ S_indices = np.flatnonzero(y == target_class)
+ S_x = _safe_indexing(X, S_indices)
+ S_y = _safe_indexing(y, S_indices)
+
+ # fit knn on C
+ self.estimators_.append(clone(estimator).fit(C_x, C_y))
+
+ good_classif_label = idx_maj_sample.copy()
+ # Check each sample in S if we keep it or drop it
+ for idx_sam, (x_sam, y_sam) in enumerate(zip(S_x, S_y)):
+ # Do not select sample which are already well classified
+ if idx_sam in good_classif_label:
+ continue
+
+ # Classify on S
+ if not issparse(x_sam):
+ x_sam = x_sam.reshape(1, -1)
+ pred_y = self.estimators_[-1].predict(x_sam)
+
+                    # If the prediction does not agree with the true label,
+                    # append the sample to C
+ if y_sam != pred_y:
+ # Keep the index for later
+ idx_maj_sample = np.append(idx_maj_sample, idx_maj[idx_sam])
+
+ # Update C
+ C_indices = np.append(C_indices, idx_maj[idx_sam])
+ C_x = _safe_indexing(X, C_indices)
+ C_y = _safe_indexing(y, C_indices)
+
+ # fit a knn on C
+ self.estimators_[-1].fit(C_x, C_y)
+
+                        # This is experimental, to speed up the search:
+                        # classify all the elements in S and avoid testing
+                        # the well classified elements again
+ pred_S_y = self.estimators_[-1].predict(S_x)
+ good_classif_label = np.unique(
+ np.append(idx_maj_sample, np.flatnonzero(pred_S_y == S_y))
+ )
+
+ idx_under = np.concatenate((idx_under, idx_maj_sample), axis=0)
+ else:
+ idx_under = np.concatenate(
+ (idx_under, np.flatnonzero(y == target_class)), axis=0
+ )
+
+ self.sample_indices_ = idx_under
+
+ return _safe_indexing(X, idx_under), _safe_indexing(y, idx_under)
@property
def estimator_(self):
"""Last fitted k-NN estimator."""
- pass
+ warnings.warn(
+ "`estimator_` attribute has been deprecated in 0.12 and will be "
+ "removed in 0.14. Use `estimators_` instead.",
+ FutureWarning,
+ )
+ return self.estimators_[-1]
+
+ def _more_tags(self):
+ return {"sample_indices": True}
diff --git a/imblearn/under_sampling/_prototype_selection/_edited_nearest_neighbours.py b/imblearn/under_sampling/_prototype_selection/_edited_nearest_neighbours.py
index 067fe55..38abd4b 100644
--- a/imblearn/under_sampling/_prototype_selection/_edited_nearest_neighbours.py
+++ b/imblearn/under_sampling/_prototype_selection/_edited_nearest_neighbours.py
@@ -1,19 +1,30 @@
"""Classes to perform under-sampling based on the edited nearest neighbour
method."""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Dayvid Oliveira
+# Christos Aridas
+# License: MIT
+
import numbers
from collections import Counter
+
import numpy as np
from sklearn.utils import _safe_indexing
+
from ...utils import Substitution, check_neighbors_object
from ...utils._docstring import _n_jobs_docstring
from ...utils._param_validation import HasMethods, Interval, StrOptions
from ...utils.fixes import _mode
from ..base import BaseCleaningSampler
-SEL_KIND = 'all', 'mode'
+
+SEL_KIND = ("all", "mode")
-@Substitution(sampling_strategy=BaseCleaningSampler.
- _sampling_strategy_docstring, n_jobs=_n_jobs_docstring)
+@Substitution(
+ sampling_strategy=BaseCleaningSampler._sampling_strategy_docstring,
+ n_jobs=_n_jobs_docstring,
+)
class EditedNearestNeighbours(BaseCleaningSampler):
"""Undersample based on the edited nearest neighbour method.
@@ -111,14 +122,25 @@ class EditedNearestNeighbours(BaseCleaningSampler):
>>> print('Resampled dataset shape %s' % Counter(y_res))
Resampled dataset shape Counter({{1: 887, 0: 100}})
"""
- _parameter_constraints: dict = {**BaseCleaningSampler.
- _parameter_constraints, 'n_neighbors': [Interval(numbers.Integral,
- 1, None, closed='left'), HasMethods(['kneighbors',
- 'kneighbors_graph'])], 'kind_sel': [StrOptions({'all', 'mode'})],
- 'n_jobs': [numbers.Integral, None]}
-
- def __init__(self, *, sampling_strategy='auto', n_neighbors=3, kind_sel
- ='all', n_jobs=None):
+
+ _parameter_constraints: dict = {
+ **BaseCleaningSampler._parameter_constraints,
+ "n_neighbors": [
+ Interval(numbers.Integral, 1, None, closed="left"),
+ HasMethods(["kneighbors", "kneighbors_graph"]),
+ ],
+ "kind_sel": [StrOptions({"all", "mode"})],
+ "n_jobs": [numbers.Integral, None],
+ }
+
+ def __init__(
+ self,
+ *,
+ sampling_strategy="auto",
+ n_neighbors=3,
+ kind_sel="all",
+ n_jobs=None,
+ ):
super().__init__(sampling_strategy=sampling_strategy)
self.n_neighbors = n_neighbors
self.kind_sel = kind_sel
@@ -126,11 +148,55 @@ class EditedNearestNeighbours(BaseCleaningSampler):
def _validate_estimator(self):
"""Validate the estimator created in the ENN."""
- pass
-
-
-@Substitution(sampling_strategy=BaseCleaningSampler.
- _sampling_strategy_docstring, n_jobs=_n_jobs_docstring)
+ self.nn_ = check_neighbors_object(
+ "n_neighbors", self.n_neighbors, additional_neighbor=1
+ )
+ self.nn_.set_params(**{"n_jobs": self.n_jobs})
+
+ def _fit_resample(self, X, y):
+ self._validate_estimator()
+
+ idx_under = np.empty((0,), dtype=int)
+
+ self.nn_.fit(X)
+
+ for target_class in np.unique(y):
+ if target_class in self.sampling_strategy_.keys():
+ target_class_indices = np.flatnonzero(y == target_class)
+ X_class = _safe_indexing(X, target_class_indices)
+ y_class = _safe_indexing(y, target_class_indices)
+ nnhood_idx = self.nn_.kneighbors(X_class, return_distance=False)[:, 1:]
+ nnhood_label = y[nnhood_idx]
+ if self.kind_sel == "mode":
+ nnhood_label, _ = _mode(nnhood_label, axis=1)
+ nnhood_bool = np.ravel(nnhood_label) == y_class
+ elif self.kind_sel == "all":
+ nnhood_label = nnhood_label == target_class
+ nnhood_bool = np.all(nnhood_label, axis=1)
+ index_target_class = np.flatnonzero(nnhood_bool)
+ else:
+ index_target_class = slice(None)
+
+ idx_under = np.concatenate(
+ (
+ idx_under,
+ np.flatnonzero(y == target_class)[index_target_class],
+ ),
+ axis=0,
+ )
+
+ self.sample_indices_ = idx_under
+
+ return _safe_indexing(X, idx_under), _safe_indexing(y, idx_under)
+
+ def _more_tags(self):
+ return {"sample_indices": True}
+
+
+@Substitution(
+ sampling_strategy=BaseCleaningSampler._sampling_strategy_docstring,
+ n_jobs=_n_jobs_docstring,
+)
class RepeatedEditedNearestNeighbours(BaseCleaningSampler):
"""Undersample based on the repeated edited nearest neighbour method.
@@ -241,15 +307,27 @@ class RepeatedEditedNearestNeighbours(BaseCleaningSampler):
>>> print('Resampled dataset shape %s' % Counter(y_res))
Resampled dataset shape Counter({{1: 887, 0: 100}})
"""
- _parameter_constraints: dict = {**BaseCleaningSampler.
- _parameter_constraints, 'n_neighbors': [Interval(numbers.Integral,
- 1, None, closed='left'), HasMethods(['kneighbors',
- 'kneighbors_graph'])], 'max_iter': [Interval(numbers.Integral, 1,
- None, closed='left')], 'kind_sel': [StrOptions({'all', 'mode'})],
- 'n_jobs': [numbers.Integral, None]}
-
- def __init__(self, *, sampling_strategy='auto', n_neighbors=3, max_iter
- =100, kind_sel='all', n_jobs=None):
+
+ _parameter_constraints: dict = {
+ **BaseCleaningSampler._parameter_constraints,
+ "n_neighbors": [
+ Interval(numbers.Integral, 1, None, closed="left"),
+ HasMethods(["kneighbors", "kneighbors_graph"]),
+ ],
+ "max_iter": [Interval(numbers.Integral, 1, None, closed="left")],
+ "kind_sel": [StrOptions({"all", "mode"})],
+ "n_jobs": [numbers.Integral, None],
+ }
+
+ def __init__(
+ self,
+ *,
+ sampling_strategy="auto",
+ n_neighbors=3,
+ max_iter=100,
+ kind_sel="all",
+ n_jobs=None,
+ ):
super().__init__(sampling_strategy=sampling_strategy)
self.n_neighbors = n_neighbors
self.kind_sel = kind_sel
@@ -258,11 +336,88 @@ class RepeatedEditedNearestNeighbours(BaseCleaningSampler):
def _validate_estimator(self):
"""Private function to create the NN estimator"""
- pass
-
-
-@Substitution(sampling_strategy=BaseCleaningSampler.
- _sampling_strategy_docstring, n_jobs=_n_jobs_docstring)
+ self.nn_ = check_neighbors_object(
+ "n_neighbors", self.n_neighbors, additional_neighbor=1
+ )
+
+ self.enn_ = EditedNearestNeighbours(
+ sampling_strategy=self.sampling_strategy,
+ n_neighbors=self.nn_,
+ kind_sel=self.kind_sel,
+ n_jobs=self.n_jobs,
+ )
+
+ def _fit_resample(self, X, y):
+ self._validate_estimator()
+
+ X_, y_ = X, y
+ self.sample_indices_ = np.arange(X.shape[0], dtype=int)
+ target_stats = Counter(y)
+ class_minority = min(target_stats, key=target_stats.get)
+
+ for n_iter in range(self.max_iter):
+ prev_len = y_.shape[0]
+ X_enn, y_enn = self.enn_.fit_resample(X_, y_)
+
+            # Check the stopping criteria:
+            # 1. the vector y is unchanged over one iteration
+            # 2. the number of samples in a non-minority class becomes smaller
+            #    than the number of samples in the original minority class
+            # 3. one of the classes disappears
+
+ # Case 1
+ b_conv = prev_len == y_enn.shape[0]
+
+ # Case 2
+ stats_enn = Counter(y_enn)
+ count_non_min = np.array(
+ [
+ val
+ for val, key in zip(stats_enn.values(), stats_enn.keys())
+ if key != class_minority
+ ]
+ )
+ b_min_bec_maj = np.any(count_non_min < target_stats[class_minority])
+
+ # Case 3
+ b_remove_maj_class = len(stats_enn) < len(target_stats)
+
+            X_, y_ = X_enn, y_enn
+ self.sample_indices_ = self.sample_indices_[self.enn_.sample_indices_]
+
+ if b_conv or b_min_bec_maj or b_remove_maj_class:
+ if b_conv:
+                    X_, y_ = X_enn, y_enn
+ self.sample_indices_ = self.sample_indices_[
+ self.enn_.sample_indices_
+ ]
+ break
+
+ self.n_iter_ = n_iter + 1
+ X_resampled, y_resampled = X_, y_
+
+ return X_resampled, y_resampled
+
+ def _more_tags(self):
+ return {"sample_indices": True}
+
+
+@Substitution(
+ sampling_strategy=BaseCleaningSampler._sampling_strategy_docstring,
+ n_jobs=_n_jobs_docstring,
+)
class AllKNN(BaseCleaningSampler):
"""Undersample based on the AllKNN method.
@@ -372,14 +527,27 @@ class AllKNN(BaseCleaningSampler):
>>> print('Resampled dataset shape %s' % Counter(y_res))
Resampled dataset shape Counter({{1: 887, 0: 100}})
"""
- _parameter_constraints: dict = {**BaseCleaningSampler.
- _parameter_constraints, 'n_neighbors': [Interval(numbers.Integral,
- 1, None, closed='left'), HasMethods(['kneighbors',
- 'kneighbors_graph'])], 'kind_sel': [StrOptions({'all', 'mode'})],
- 'allow_minority': ['boolean'], 'n_jobs': [numbers.Integral, None]}
-
- def __init__(self, *, sampling_strategy='auto', n_neighbors=3, kind_sel
- ='all', allow_minority=False, n_jobs=None):
+
+ _parameter_constraints: dict = {
+ **BaseCleaningSampler._parameter_constraints,
+ "n_neighbors": [
+ Interval(numbers.Integral, 1, None, closed="left"),
+ HasMethods(["kneighbors", "kneighbors_graph"]),
+ ],
+ "kind_sel": [StrOptions({"all", "mode"})],
+ "allow_minority": ["boolean"],
+ "n_jobs": [numbers.Integral, None],
+ }
+
+ def __init__(
+ self,
+ *,
+ sampling_strategy="auto",
+ n_neighbors=3,
+ kind_sel="all",
+ allow_minority=False,
+ n_jobs=None,
+ ):
super().__init__(sampling_strategy=sampling_strategy)
self.n_neighbors = n_neighbors
self.kind_sel = kind_sel
@@ -388,4 +556,68 @@ class AllKNN(BaseCleaningSampler):
def _validate_estimator(self):
"""Create objects required by AllKNN"""
- pass
+ self.nn_ = check_neighbors_object(
+ "n_neighbors", self.n_neighbors, additional_neighbor=1
+ )
+
+ self.enn_ = EditedNearestNeighbours(
+ sampling_strategy=self.sampling_strategy,
+ n_neighbors=self.nn_,
+ kind_sel=self.kind_sel,
+ n_jobs=self.n_jobs,
+ )
+
+ def _fit_resample(self, X, y):
+ self._validate_estimator()
+
+ X_, y_ = X, y
+ target_stats = Counter(y)
+ class_minority = min(target_stats, key=target_stats.get)
+
+ self.sample_indices_ = np.arange(X.shape[0], dtype=int)
+
+ for curr_size_ngh in range(1, self.nn_.n_neighbors):
+ self.enn_.n_neighbors = curr_size_ngh
+
+ X_enn, y_enn = self.enn_.fit_resample(X_, y_)
+
+            # Check the stopping criteria:
+            # 1. the number of samples in a non-minority class becomes smaller
+            #    than the number of samples in the original minority class
+            # 2. one of the classes disappears
+
+            # Case 1
+ stats_enn = Counter(y_enn)
+ count_non_min = np.array(
+ [
+ val
+ for val, key in zip(stats_enn.values(), stats_enn.keys())
+ if key != class_minority
+ ]
+ )
+ b_min_bec_maj = np.any(count_non_min < target_stats[class_minority])
+ if self.allow_minority:
+ # overwrite b_min_bec_maj
+ b_min_bec_maj = False
+
+ # Case 2
+ b_remove_maj_class = len(stats_enn) < len(target_stats)
+
+            X_, y_ = X_enn, y_enn
+ self.sample_indices_ = self.sample_indices_[self.enn_.sample_indices_]
+
+ if b_min_bec_maj or b_remove_maj_class:
+ break
+
+ X_resampled, y_resampled = X_, y_
+
+ return X_resampled, y_resampled
+
+ def _more_tags(self):
+ return {"sample_indices": True}
diff --git a/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py b/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py
index c858b97..dac3f3c 100644
--- a/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py
+++ b/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py
@@ -1,22 +1,32 @@
"""Class to perform under-sampling based on the instance hardness
threshold."""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Dayvid Oliveira
+# Christos Aridas
+# License: MIT
+
import numbers
from collections import Counter
+
import numpy as np
from sklearn.base import clone, is_classifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble._base import _set_random_states
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.utils import _safe_indexing, check_random_state
+
from ...utils import Substitution
from ...utils._docstring import _n_jobs_docstring, _random_state_docstring
from ...utils._param_validation import HasMethods
from ..base import BaseUnderSampler
-@Substitution(sampling_strategy=BaseUnderSampler.
- _sampling_strategy_docstring, n_jobs=_n_jobs_docstring, random_state=
- _random_state_docstring)
+@Substitution(
+ sampling_strategy=BaseUnderSampler._sampling_strategy_docstring,
+ n_jobs=_n_jobs_docstring,
+ random_state=_random_state_docstring,
+)
class InstanceHardnessThreshold(BaseUnderSampler):
"""Undersample based on the instance hardness threshold.
@@ -98,13 +108,27 @@ class InstanceHardnessThreshold(BaseUnderSampler):
>>> print('Resampled dataset shape %s' % Counter(y_res))
Resampled dataset shape Counter({{1: 5..., 0: 100}})
"""
- _parameter_constraints: dict = {**BaseUnderSampler.
- _parameter_constraints, 'estimator': [HasMethods(['fit',
- 'predict_proba']), None], 'cv': ['cv_object'], 'n_jobs': [numbers.
- Integral, None], 'random_state': ['random_state']}
- def __init__(self, *, estimator=None, sampling_strategy='auto',
- random_state=None, cv=5, n_jobs=None):
+ _parameter_constraints: dict = {
+ **BaseUnderSampler._parameter_constraints,
+ "estimator": [
+ HasMethods(["fit", "predict_proba"]),
+ None,
+ ],
+ "cv": ["cv_object"],
+ "n_jobs": [numbers.Integral, None],
+ "random_state": ["random_state"],
+ }
+
+ def __init__(
+ self,
+ *,
+ estimator=None,
+ sampling_strategy="auto",
+ random_state=None,
+ cv=5,
+ n_jobs=None,
+ ):
super().__init__(sampling_strategy=sampling_strategy)
self.random_state = random_state
self.estimator = estimator
@@ -113,4 +137,68 @@ class InstanceHardnessThreshold(BaseUnderSampler):
def _validate_estimator(self, random_state):
"""Private function to create the classifier"""
- pass
+
+ if (
+ self.estimator is not None
+ and is_classifier(self.estimator)
+ and hasattr(self.estimator, "predict_proba")
+ ):
+ self.estimator_ = clone(self.estimator)
+ _set_random_states(self.estimator_, random_state)
+
+ elif self.estimator is None:
+ self.estimator_ = RandomForestClassifier(
+ n_estimators=100,
+ random_state=self.random_state,
+ n_jobs=self.n_jobs,
+ )
+
+ def _fit_resample(self, X, y):
+ random_state = check_random_state(self.random_state)
+ self._validate_estimator(random_state)
+
+ target_stats = Counter(y)
+ skf = StratifiedKFold(
+ n_splits=self.cv,
+ shuffle=True,
+ random_state=random_state,
+ )
+ probabilities = cross_val_predict(
+ self.estimator_,
+ X,
+ y,
+ cv=skf,
+ n_jobs=self.n_jobs,
+ method="predict_proba",
+ )
+ probabilities = probabilities[range(len(y)), y]
+
+ idx_under = np.empty((0,), dtype=int)
+
+ for target_class in np.unique(y):
+ if target_class in self.sampling_strategy_.keys():
+ n_samples = self.sampling_strategy_[target_class]
+ threshold = np.percentile(
+ probabilities[y == target_class],
+ (1.0 - (n_samples / target_stats[target_class])) * 100.0,
+ )
+ index_target_class = np.flatnonzero(
+ probabilities[y == target_class] >= threshold
+ )
+ else:
+ index_target_class = slice(None)
+
+ idx_under = np.concatenate(
+ (
+ idx_under,
+ np.flatnonzero(y == target_class)[index_target_class],
+ ),
+ axis=0,
+ )
+
+ self.sample_indices_ = idx_under
+
+ return _safe_indexing(X, idx_under), _safe_indexing(y, idx_under)
+
+ def _more_tags(self):
+ return {"sample_indices": True}
diff --git a/imblearn/under_sampling/_prototype_selection/_nearmiss.py b/imblearn/under_sampling/_prototype_selection/_nearmiss.py
index f64b76a..70f647f 100644
--- a/imblearn/under_sampling/_prototype_selection/_nearmiss.py
+++ b/imblearn/under_sampling/_prototype_selection/_nearmiss.py
@@ -1,17 +1,26 @@
"""Class to perform under-sampling based on nearmiss methods."""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import numbers
import warnings
from collections import Counter
+
import numpy as np
from sklearn.utils import _safe_indexing
+
from ...utils import Substitution, check_neighbors_object
from ...utils._docstring import _n_jobs_docstring
from ...utils._param_validation import HasMethods, Interval
from ..base import BaseUnderSampler
-@Substitution(sampling_strategy=BaseUnderSampler.
- _sampling_strategy_docstring, n_jobs=_n_jobs_docstring)
+@Substitution(
+ sampling_strategy=BaseUnderSampler._sampling_strategy_docstring,
+ n_jobs=_n_jobs_docstring,
+)
class NearMiss(BaseUnderSampler):
"""Class to perform under-sampling based on NearMiss methods.
@@ -102,24 +111,39 @@ class NearMiss(BaseUnderSampler):
>>> print('Resampled dataset shape %s' % Counter(y_res))
Resampled dataset shape Counter({{0: 100, 1: 100}})
"""
- _parameter_constraints: dict = {**BaseUnderSampler.
- _parameter_constraints, 'version': [Interval(numbers.Integral, 1, 3,
- closed='both')], 'n_neighbors': [Interval(numbers.Integral, 1, None,
- closed='left'), HasMethods(['kneighbors', 'kneighbors_graph'])],
- 'n_neighbors_ver3': [Interval(numbers.Integral, 1, None, closed=
- 'left'), HasMethods(['kneighbors', 'kneighbors_graph'])], 'n_jobs':
- [numbers.Integral, None]}
-
- def __init__(self, *, sampling_strategy='auto', version=1, n_neighbors=
- 3, n_neighbors_ver3=3, n_jobs=None):
+
+ _parameter_constraints: dict = {
+ **BaseUnderSampler._parameter_constraints,
+ "version": [Interval(numbers.Integral, 1, 3, closed="both")],
+ "n_neighbors": [
+ Interval(numbers.Integral, 1, None, closed="left"),
+ HasMethods(["kneighbors", "kneighbors_graph"]),
+ ],
+ "n_neighbors_ver3": [
+ Interval(numbers.Integral, 1, None, closed="left"),
+ HasMethods(["kneighbors", "kneighbors_graph"]),
+ ],
+ "n_jobs": [numbers.Integral, None],
+ }
+
+ def __init__(
+ self,
+ *,
+ sampling_strategy="auto",
+ version=1,
+ n_neighbors=3,
+ n_neighbors_ver3=3,
+ n_jobs=None,
+ ):
super().__init__(sampling_strategy=sampling_strategy)
self.version = version
self.n_neighbors = n_neighbors
self.n_neighbors_ver3 = n_neighbors_ver3
self.n_jobs = n_jobs
- def _selection_dist_based(self, X, y, dist_vec, num_samples, key,
- sel_strategy='nearest'):
+ def _selection_dist_based(
+ self, X, y, dist_vec, num_samples, key, sel_strategy="nearest"
+ ):
"""Select the appropriate samples depending of the strategy selected.
Parameters
@@ -148,8 +172,143 @@ class NearMiss(BaseUnderSampler):
The list of the indices of the selected samples.
"""
- pass
+
+ # Compute the distance considering the farthest neighbour
+ dist_avg_vec = np.sum(dist_vec[:, -self.nn_.n_neighbors :], axis=1)
+
+ target_class_indices = np.flatnonzero(y == key)
+ if dist_vec.shape[0] != _safe_indexing(X, target_class_indices).shape[0]:
+ raise RuntimeError(
+ "The samples to be selected do not correspond"
+ " to the distance matrix given. Ensure that"
+ " both `X[y == key]` and `dist_vec` are"
+ " related."
+ )
+
+ # Sort the list of distance and get the index
+ if sel_strategy == "nearest":
+ sort_way = False
+ else: # sel_strategy == "farthest":
+ sort_way = True
+
+ sorted_idx = sorted(
+ range(len(dist_avg_vec)),
+ key=dist_avg_vec.__getitem__,
+ reverse=sort_way,
+ )
+
+ # Throw a warning to tell the user that we did not have enough samples
+ # to select and that we just select everything
+ if len(sorted_idx) < num_samples:
+ warnings.warn(
+ "The number of the samples to be selected is larger"
+ " than the number of samples available. The"
+ " balancing ratio cannot be ensure and all samples"
+ " will be returned."
+ )
+
+ # Select the desired number of samples
+ return sorted_idx[:num_samples]
def _validate_estimator(self):
"""Private function to create the NN estimator"""
- pass
+
+ self.nn_ = check_neighbors_object("n_neighbors", self.n_neighbors)
+ self.nn_.set_params(**{"n_jobs": self.n_jobs})
+
+ if self.version == 3:
+ self.nn_ver3_ = check_neighbors_object(
+ "n_neighbors_ver3", self.n_neighbors_ver3
+ )
+ self.nn_ver3_.set_params(**{"n_jobs": self.n_jobs})
+
+ def _fit_resample(self, X, y):
+ self._validate_estimator()
+
+ idx_under = np.empty((0,), dtype=int)
+
+ target_stats = Counter(y)
+ class_minority = min(target_stats, key=target_stats.get)
+ minority_class_indices = np.flatnonzero(y == class_minority)
+
+ self.nn_.fit(_safe_indexing(X, minority_class_indices))
+
+ for target_class in np.unique(y):
+ if target_class in self.sampling_strategy_.keys():
+ n_samples = self.sampling_strategy_[target_class]
+ target_class_indices = np.flatnonzero(y == target_class)
+ X_class = _safe_indexing(X, target_class_indices)
+ y_class = _safe_indexing(y, target_class_indices)
+
+ if self.version == 1:
+ dist_vec, idx_vec = self.nn_.kneighbors(
+ X_class, n_neighbors=self.nn_.n_neighbors
+ )
+ index_target_class = self._selection_dist_based(
+ X,
+ y,
+ dist_vec,
+ n_samples,
+ target_class,
+ sel_strategy="nearest",
+ )
+ elif self.version == 2:
+ dist_vec, idx_vec = self.nn_.kneighbors(
+ X_class, n_neighbors=target_stats[class_minority]
+ )
+ index_target_class = self._selection_dist_based(
+ X,
+ y,
+ dist_vec,
+ n_samples,
+ target_class,
+ sel_strategy="nearest",
+ )
+ elif self.version == 3:
+ self.nn_ver3_.fit(X_class)
+ dist_vec, idx_vec = self.nn_ver3_.kneighbors(
+ _safe_indexing(X, minority_class_indices)
+ )
+ idx_vec_farthest = np.unique(idx_vec.reshape(-1))
+ X_class_selected = _safe_indexing(X_class, idx_vec_farthest)
+ y_class_selected = _safe_indexing(y_class, idx_vec_farthest)
+
+ dist_vec, idx_vec = self.nn_.kneighbors(
+ X_class_selected, n_neighbors=self.nn_.n_neighbors
+ )
+ index_target_class = self._selection_dist_based(
+ X_class_selected,
+ y_class_selected,
+ dist_vec,
+ n_samples,
+ target_class,
+ sel_strategy="farthest",
+ )
+                    # index_target_class is relative to the subset selected
+                    # in the previous step; map it back to indices within
+                    # the full class through `idx_vec_farthest`
+ index_target_class = idx_vec_farthest[index_target_class]
+ else:
+ index_target_class = slice(None)
+
+ idx_under = np.concatenate(
+ (
+ idx_under,
+ np.flatnonzero(y == target_class)[index_target_class],
+ ),
+ axis=0,
+ )
+
+ self.sample_indices_ = idx_under
+
+ return _safe_indexing(X, idx_under), _safe_indexing(y, idx_under)
+
+ # fmt: off
+ def _more_tags(self):
+ return {
+ "sample_indices": True,
+ "_xfail_checks": {
+ "check_samplers_fit_resample":
+ "Fails for NearMiss-3 with less samples than expected"
+ }
+ }
+ # fmt: on
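
`_selection_dist_based` above ranks each candidate by the summed distance to its nearest minority neighbours, keeping the candidates with the smallest sums ("nearest", versions 1 and 2) or the largest ("farthest", used by NearMiss-3). A sketch of the "nearest" strategy with invented distances:

```python
import numpy as np

# rows: candidate samples; columns: distances to their 3 minority neighbours
dist_vec = np.array([[0.1, 0.2, 0.3],
                     [0.5, 0.6, 0.7],
                     [0.2, 0.3, 0.4]])
dist_avg_vec = dist_vec.sum(axis=1)  # 0.6, 1.8, 0.9
num_samples = 2
selected = sorted(
    range(len(dist_avg_vec)), key=dist_avg_vec.__getitem__
)[:num_samples]
print(selected)  # [0, 2]: the two candidates closest to the minority class
```
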
diff --git a/imblearn/under_sampling/_prototype_selection/_neighbourhood_cleaning_rule.py b/imblearn/under_sampling/_prototype_selection/_neighbourhood_cleaning_rule.py
index e0a2f31..188ba32 100644
--- a/imblearn/under_sampling/_prototype_selection/_neighbourhood_cleaning_rule.py
+++ b/imblearn/under_sampling/_prototype_selection/_neighbourhood_cleaning_rule.py
@@ -1,21 +1,31 @@
"""Class performing under-sampling based on the neighbourhood cleaning rule."""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import numbers
import warnings
from collections import Counter
+
import numpy as np
from sklearn.base import clone
from sklearn.neighbors import KNeighborsClassifier, NearestNeighbors
from sklearn.utils import _safe_indexing
+
from ...utils import Substitution
from ...utils._docstring import _n_jobs_docstring
from ...utils._param_validation import HasMethods, Hidden, Interval, StrOptions
from ..base import BaseCleaningSampler
from ._edited_nearest_neighbours import EditedNearestNeighbours
-SEL_KIND = 'all', 'mode'
+
+SEL_KIND = ("all", "mode")
-@Substitution(sampling_strategy=BaseCleaningSampler.
- _sampling_strategy_docstring, n_jobs=_n_jobs_docstring)
+@Substitution(
+ sampling_strategy=BaseCleaningSampler._sampling_strategy_docstring,
+ n_jobs=_n_jobs_docstring,
+)
class NeighbourhoodCleaningRule(BaseCleaningSampler):
"""Undersample based on the neighbourhood cleaning rule.
@@ -129,18 +139,32 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler):
>>> print('Resampled dataset shape %s' % Counter(y_res))
Resampled dataset shape Counter({{1: 888, 0: 100}})
"""
- _parameter_constraints: dict = {**BaseCleaningSampler.
- _parameter_constraints, 'edited_nearest_neighbours': [HasMethods([
- 'fit_resample']), None], 'n_neighbors': [Interval(numbers.Integral,
- 1, None, closed='left'), HasMethods(['kneighbors',
- 'kneighbors_graph'])], 'kind_sel': [StrOptions({'all', 'mode'}),
- Hidden(StrOptions({'deprecated'}))], 'threshold_cleaning': [
- Interval(numbers.Real, 0, None, closed='neither')], 'n_jobs': [
- numbers.Integral, None]}
-
- def __init__(self, *, sampling_strategy='auto',
- edited_nearest_neighbours=None, n_neighbors=3, kind_sel=
- 'deprecated', threshold_cleaning=0.5, n_jobs=None):
+
+ _parameter_constraints: dict = {
+ **BaseCleaningSampler._parameter_constraints,
+ "edited_nearest_neighbours": [
+ HasMethods(["fit_resample"]),
+ None,
+ ],
+ "n_neighbors": [
+ Interval(numbers.Integral, 1, None, closed="left"),
+ HasMethods(["kneighbors", "kneighbors_graph"]),
+ ],
+ "kind_sel": [StrOptions({"all", "mode"}), Hidden(StrOptions({"deprecated"}))],
+ "threshold_cleaning": [Interval(numbers.Real, 0, None, closed="neither")],
+ "n_jobs": [numbers.Integral, None],
+ }
+
+ def __init__(
+ self,
+ *,
+ sampling_strategy="auto",
+ edited_nearest_neighbours=None,
+ n_neighbors=3,
+ kind_sel="deprecated",
+ threshold_cleaning=0.5,
+ n_jobs=None,
+ ):
super().__init__(sampling_strategy=sampling_strategy)
self.edited_nearest_neighbours = edited_nearest_neighbours
self.n_neighbors = n_neighbors
@@ -150,4 +174,85 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler):
def _validate_estimator(self):
"""Create the objects required by NCR."""
- pass
+ if isinstance(self.n_neighbors, numbers.Integral):
+ self.nn_ = KNeighborsClassifier(
+ n_neighbors=self.n_neighbors, n_jobs=self.n_jobs
+ )
+ elif isinstance(self.n_neighbors, NearestNeighbors):
+ # backward compatibility when passing a NearestNeighbors object
+ self.nn_ = KNeighborsClassifier(
+ n_neighbors=self.n_neighbors.n_neighbors - 1, n_jobs=self.n_jobs
+ )
+ else:
+ self.nn_ = clone(self.n_neighbors)
+
+ if self.edited_nearest_neighbours is None:
+ self.edited_nearest_neighbours_ = EditedNearestNeighbours(
+ sampling_strategy=self.sampling_strategy,
+ n_neighbors=self.n_neighbors,
+ kind_sel="mode",
+ n_jobs=self.n_jobs,
+ )
+ else:
+ self.edited_nearest_neighbours_ = clone(self.edited_nearest_neighbours)
+
+ def _fit_resample(self, X, y):
+ if self.kind_sel != "deprecated":
+ warnings.warn(
+ "`kind_sel` is deprecated in 0.12 and will be removed in 0.14. "
+ "It already has not effect and corresponds to the `'all'` option.",
+ FutureWarning,
+ )
+ self._validate_estimator()
+ self.edited_nearest_neighbours_.fit_resample(X, y)
+ index_not_a1 = self.edited_nearest_neighbours_.sample_indices_
+ index_a1 = np.ones(y.shape, dtype=bool)
+ index_a1[index_not_a1] = False
+ index_a1 = np.flatnonzero(index_a1)
+
+ # clean the neighborhood
+ target_stats = Counter(y)
+ class_minority = min(target_stats, key=target_stats.get)
+ # compute which classes to consider for cleaning for the A2 group
+ self.classes_to_clean_ = [
+ c
+ for c, n_samples in target_stats.items()
+ if (
+ c in self.sampling_strategy_.keys()
+ and (n_samples > target_stats[class_minority] * self.threshold_cleaning)
+ )
+ ]
+ self.nn_.fit(X, y)
+
+ class_minority_indices = np.flatnonzero(y == class_minority)
+ X_minority = _safe_indexing(X, class_minority_indices)
+ y_minority = _safe_indexing(y, class_minority_indices)
+
+ y_pred_minority = self.nn_.predict(X_minority)
+        # query one extra neighbor since the query points are part of the
+        # fitted dataset and each sample is its own nearest neighbor
+ neighbors_to_minority_indices = self.nn_.kneighbors(
+ X_minority, n_neighbors=self.nn_.n_neighbors + 1, return_distance=False
+ )[:, 1:]
+
+ mask_misclassified_minority = y_pred_minority != y_minority
+ index_a2 = np.ravel(neighbors_to_minority_indices[mask_misclassified_minority])
+ index_a2 = np.array(
+ [
+ index
+ for index in np.unique(index_a2)
+ if y[index] in self.classes_to_clean_
+ ]
+ )
+
+ union_a1_a2 = np.union1d(index_a1, index_a2).astype(int)
+ selected_samples = np.ones(y.shape, dtype=bool)
+ selected_samples[union_a1_a2] = False
+ self.sample_indices_ = np.flatnonzero(selected_samples)
+
+ return (
+ _safe_indexing(X, self.sample_indices_),
+ _safe_indexing(y, self.sample_indices_),
+ )
+
+ def _more_tags(self):
+ return {"sample_indices": True}
diff --git a/imblearn/under_sampling/_prototype_selection/_one_sided_selection.py b/imblearn/under_sampling/_prototype_selection/_one_sided_selection.py
index 6c0b322..e0e5b41 100644
--- a/imblearn/under_sampling/_prototype_selection/_one_sided_selection.py
+++ b/imblearn/under_sampling/_prototype_selection/_one_sided_selection.py
@@ -1,11 +1,18 @@
"""Class to perform under-sampling based on one-sided selection method."""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import numbers
import warnings
from collections import Counter
+
import numpy as np
from sklearn.base import clone
from sklearn.neighbors import KNeighborsClassifier
from sklearn.utils import _safe_indexing, check_random_state
+
from ...utils import Substitution
from ...utils._docstring import _n_jobs_docstring, _random_state_docstring
from ...utils._param_validation import HasMethods, Interval
@@ -13,9 +20,11 @@ from ..base import BaseCleaningSampler
from ._tomek_links import TomekLinks
-@Substitution(sampling_strategy=BaseCleaningSampler.
- _sampling_strategy_docstring, n_jobs=_n_jobs_docstring, random_state=
- _random_state_docstring)
+@Substitution(
+ sampling_strategy=BaseCleaningSampler._sampling_strategy_docstring,
+ n_jobs=_n_jobs_docstring,
+ random_state=_random_state_docstring,
+)
class OneSidedSelection(BaseCleaningSampler):
"""Class to perform under-sampling based on one-sided selection method.
@@ -109,15 +118,28 @@ class OneSidedSelection(BaseCleaningSampler):
>>> print('Resampled dataset shape %s' % Counter(y_res))
Resampled dataset shape Counter({{1: 496, 0: 100}})
"""
- _parameter_constraints: dict = {**BaseCleaningSampler.
- _parameter_constraints, 'n_neighbors': [Interval(numbers.Integral,
- 1, None, closed='left'), HasMethods(['kneighbors',
- 'kneighbors_graph']), None], 'n_seeds_S': [Interval(numbers.
- Integral, 1, None, closed='left')], 'n_jobs': [numbers.Integral,
- None], 'random_state': ['random_state']}
-
- def __init__(self, *, sampling_strategy='auto', random_state=None,
- n_neighbors=None, n_seeds_S=1, n_jobs=None):
+
+ _parameter_constraints: dict = {
+ **BaseCleaningSampler._parameter_constraints,
+ "n_neighbors": [
+ Interval(numbers.Integral, 1, None, closed="left"),
+ HasMethods(["kneighbors", "kneighbors_graph"]),
+ None,
+ ],
+ "n_seeds_S": [Interval(numbers.Integral, 1, None, closed="left")],
+ "n_jobs": [numbers.Integral, None],
+ "random_state": ["random_state"],
+ }
+
+ def __init__(
+ self,
+ *,
+ sampling_strategy="auto",
+ random_state=None,
+ n_neighbors=None,
+ n_seeds_S=1,
+ n_jobs=None,
+ ):
super().__init__(sampling_strategy=sampling_strategy)
self.random_state = random_state
self.n_neighbors = n_neighbors
@@ -126,9 +148,80 @@ class OneSidedSelection(BaseCleaningSampler):
def _validate_estimator(self):
"""Private function to create the NN estimator"""
- pass
+ if self.n_neighbors is None:
+ estimator = KNeighborsClassifier(n_neighbors=1, n_jobs=self.n_jobs)
+ elif isinstance(self.n_neighbors, int):
+ estimator = KNeighborsClassifier(
+ n_neighbors=self.n_neighbors, n_jobs=self.n_jobs
+ )
+ elif isinstance(self.n_neighbors, KNeighborsClassifier):
+ estimator = clone(self.n_neighbors)
+
+ return estimator
+
+ def _fit_resample(self, X, y):
+ estimator = self._validate_estimator()
+
+ random_state = check_random_state(self.random_state)
+ target_stats = Counter(y)
+ class_minority = min(target_stats, key=target_stats.get)
+
+ idx_under = np.empty((0,), dtype=int)
+
+ self.estimators_ = []
+ for target_class in np.unique(y):
+ if target_class in self.sampling_strategy_.keys():
+ # select a sample from the current class
+ idx_maj = np.flatnonzero(y == target_class)
+ sel_idx_maj = random_state.randint(
+ low=0, high=target_stats[target_class], size=self.n_seeds_S
+ )
+ idx_maj_sample = idx_maj[sel_idx_maj]
+
+ minority_class_indices = np.flatnonzero(y == class_minority)
+ C_indices = np.append(minority_class_indices, idx_maj_sample)
+
+                # create the set C composed of all minority samples and the
+                # seed sample(s) drawn from the current class
+ C_x = _safe_indexing(X, C_indices)
+ C_y = _safe_indexing(y, C_indices)
+
+                # create the set S by removing the seeds, since they are
+                # already included in C
+ idx_maj_extracted = np.delete(idx_maj, sel_idx_maj, axis=0)
+ S_x = _safe_indexing(X, idx_maj_extracted)
+ S_y = _safe_indexing(y, idx_maj_extracted)
+ self.estimators_.append(clone(estimator).fit(C_x, C_y))
+ pred_S_y = self.estimators_[-1].predict(S_x)
+
+ S_misclassified_indices = np.flatnonzero(pred_S_y != S_y)
+ idx_tmp = idx_maj_extracted[S_misclassified_indices]
+ idx_under = np.concatenate((idx_under, idx_maj_sample, idx_tmp), axis=0)
+ else:
+ idx_under = np.concatenate(
+ (idx_under, np.flatnonzero(y == target_class)), axis=0
+ )
+
+ X_resampled = _safe_indexing(X, idx_under)
+ y_resampled = _safe_indexing(y, idx_under)
+
+ # apply Tomek cleaning
+ tl = TomekLinks(sampling_strategy=list(self.sampling_strategy_.keys()))
+ X_cleaned, y_cleaned = tl.fit_resample(X_resampled, y_resampled)
+
+ self.sample_indices_ = _safe_indexing(idx_under, tl.sample_indices_)
+
+ return X_cleaned, y_cleaned
@property
def estimator_(self):
"""Last fitted k-NN estimator."""
- pass
+ warnings.warn(
+ "`estimator_` attribute has been deprecated in 0.12 and will be "
+ "removed in 0.14. Use `estimators_` instead.",
+ FutureWarning,
+ )
+ return self.estimators_[-1]
+
+ def _more_tags(self):
+ return {"sample_indices": True}
diff --git a/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py b/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py
index 8e943b8..876195a 100644
--- a/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py
+++ b/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py
@@ -1,14 +1,22 @@
"""Class to perform random under-sampling."""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import numpy as np
from sklearn.utils import _safe_indexing, check_random_state
+
from ...utils import Substitution, check_target_type
from ...utils._docstring import _random_state_docstring
from ...utils._validation import _check_X
from ..base import BaseUnderSampler
-@Substitution(sampling_strategy=BaseUnderSampler.
- _sampling_strategy_docstring, random_state=_random_state_docstring)
+@Substitution(
+ sampling_strategy=BaseUnderSampler._sampling_strategy_docstring,
+ random_state=_random_state_docstring,
+)
class RandomUnderSampler(BaseUnderSampler):
"""Class to perform random under-sampling.
@@ -74,12 +82,61 @@ class RandomUnderSampler(BaseUnderSampler):
>>> print('Resampled dataset shape %s' % Counter(y_res))
Resampled dataset shape Counter({{0: 100, 1: 100}})
"""
- _parameter_constraints: dict = {**BaseUnderSampler.
- _parameter_constraints, 'replacement': ['boolean'], 'random_state':
- ['random_state']}
- def __init__(self, *, sampling_strategy='auto', random_state=None,
- replacement=False):
+ _parameter_constraints: dict = {
+ **BaseUnderSampler._parameter_constraints,
+ "replacement": ["boolean"],
+ "random_state": ["random_state"],
+ }
+
+ def __init__(
+ self, *, sampling_strategy="auto", random_state=None, replacement=False
+ ):
super().__init__(sampling_strategy=sampling_strategy)
self.random_state = random_state
self.replacement = replacement
+
+ def _check_X_y(self, X, y):
+ y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
+ X = _check_X(X)
+ self._check_n_features(X, reset=True)
+ self._check_feature_names(X, reset=True)
+ return X, y, binarize_y
+
+ def _fit_resample(self, X, y):
+ random_state = check_random_state(self.random_state)
+
+ idx_under = np.empty((0,), dtype=int)
+
+ for target_class in np.unique(y):
+ if target_class in self.sampling_strategy_.keys():
+ n_samples = self.sampling_strategy_[target_class]
+ index_target_class = random_state.choice(
+ range(np.count_nonzero(y == target_class)),
+ size=n_samples,
+ replace=self.replacement,
+ )
+ else:
+ index_target_class = slice(None)
+
+ idx_under = np.concatenate(
+ (
+ idx_under,
+ np.flatnonzero(y == target_class)[index_target_class],
+ ),
+ axis=0,
+ )
+
+ self.sample_indices_ = idx_under
+
+ return _safe_indexing(X, idx_under), _safe_indexing(y, idx_under)
+
+ def _more_tags(self):
+ return {
+ "X_types": ["2darray", "string", "sparse", "dataframe"],
+ "sample_indices": True,
+ "allow_nan": True,
+ "_xfail_checks": {
+ "check_complex_data": "Robust to this type of data.",
+ },
+ }
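
The per-class draw in `_fit_resample` samples positions within a class and then maps them back to absolute dataset indices. A minimal numpy sketch of that mapping (labels invented):

```python
import numpy as np

rng = np.random.RandomState(0)
y = np.array([0, 1, 1, 0, 1, 1, 1])
n_samples = 2  # target count for class 1
# positions within class 1, drawn without replacement by default
within = rng.choice(np.count_nonzero(y == 1), size=n_samples, replace=False)
idx_under = np.flatnonzero(y == 1)[within]  # absolute indices into X and y
print(idx_under)
```
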
diff --git a/imblearn/under_sampling/_prototype_selection/_tomek_links.py b/imblearn/under_sampling/_prototype_selection/_tomek_links.py
index a64dc2a..b0f9549 100644
--- a/imblearn/under_sampling/_prototype_selection/_tomek_links.py
+++ b/imblearn/under_sampling/_prototype_selection/_tomek_links.py
@@ -1,15 +1,25 @@
"""Class to perform under-sampling by removing Tomek's links."""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Fernando Nogueira
+# Christos Aridas
+# License: MIT
+
import numbers
+
import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import _safe_indexing
+
from ...utils import Substitution
from ...utils._docstring import _n_jobs_docstring
from ..base import BaseCleaningSampler
-@Substitution(sampling_strategy=BaseCleaningSampler.
- _sampling_strategy_docstring, n_jobs=_n_jobs_docstring)
+@Substitution(
+ sampling_strategy=BaseCleaningSampler._sampling_strategy_docstring,
+ n_jobs=_n_jobs_docstring,
+)
class TomekLinks(BaseCleaningSampler):
"""Under-sampling by removing Tomek's links.
@@ -79,10 +89,13 @@ class TomekLinks(BaseCleaningSampler):
>>> print('Resampled dataset shape %s' % Counter(y_res))
Resampled dataset shape Counter({{1: 897, 0: 100}})
"""
- _parameter_constraints: dict = {**BaseCleaningSampler.
- _parameter_constraints, 'n_jobs': [numbers.Integral, None]}
- def __init__(self, *, sampling_strategy='auto', n_jobs=None):
+ _parameter_constraints: dict = {
+ **BaseCleaningSampler._parameter_constraints,
+ "n_jobs": [numbers.Integral, None],
+ }
+
+ def __init__(self, *, sampling_strategy="auto", n_jobs=None):
super().__init__(sampling_strategy=sampling_strategy)
self.n_jobs = n_jobs
@@ -112,4 +125,36 @@ class TomekLinks(BaseCleaningSampler):
Boolean vector on len( # samples ), with True for majority samples
that are Tomek links.
"""
- pass
+ links = np.zeros(len(y), dtype=bool)
+
+ # find which class to not consider
+ class_excluded = [c for c in np.unique(y) if c not in class_type]
+
+        # two samples form a Tomek link if they are each other's nearest
+        # neighbor and belong to different classes
+ for index_sample, target_sample in enumerate(y):
+ if target_sample in class_excluded:
+ continue
+
+ if y[nn_index[index_sample]] != target_sample:
+ if nn_index[nn_index[index_sample]] == index_sample:
+ links[index_sample] = True
+
+ return links
+
+ def _fit_resample(self, X, y):
+ # Find the nearest neighbour of every point
+ nn = NearestNeighbors(n_neighbors=2, n_jobs=self.n_jobs)
+ nn.fit(X)
+ nns = nn.kneighbors(X, return_distance=False)[:, 1]
+
+ links = self.is_tomek(y, nns, self.sampling_strategy_)
+ self.sample_indices_ = np.flatnonzero(np.logical_not(links))
+
+ return (
+ _safe_indexing(X, self.sample_indices_),
+ _safe_indexing(y, self.sample_indices_),
+ )
+
+ def _more_tags(self):
+ return {"sample_indices": True}
diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_allknn.py b/imblearn/under_sampling/_prototype_selection/tests/test_allknn.py
index c2868a7..131f959 100644
--- a/imblearn/under_sampling/_prototype_selection/tests/test_allknn.py
+++ b/imblearn/under_sampling/_prototype_selection/tests/test_allknn.py
@@ -1,27 +1,357 @@
"""Test the module repeated edited nearest neighbour."""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import numpy as np
import pytest
from sklearn.datasets import make_classification
from sklearn.neighbors import NearestNeighbors
from sklearn.utils._testing import assert_allclose, assert_array_equal
+
from imblearn.under_sampling import AllKNN
-X = np.array([[-0.12840393, 0.66446571], [1.32319756, -0.13181616], [
- 0.04296502, -0.37981873], [0.83631853, 0.18569783], [1.02956816,
- 0.36061601], [1.12202806, 0.33811558], [-0.53171468, -0.53735182], [
- 1.3381556, 0.35956356], [-0.35946678, 0.72510189], [1.32326943,
- 0.28393874], [2.94290565, -0.13986434], [0.28294738, -1.00125525], [
- 0.34218094, -0.58781961], [-0.88864036, -0.33782387], [-1.10146139,
- 0.91782682], [-0.7969716, -0.50493969], [0.73489726, 0.43915195], [
- 0.2096964, -0.61814058], [-0.28479268, 0.70459548], [1.84864913,
- 0.14729596], [1.59068979, -0.96622933], [0.73418199, -0.02222847], [
- 0.50307437, 0.498805], [0.84929742, 0.41042894], [0.62649535,
- 0.46600596], [0.79270821, -0.41386668], [1.16606871, -0.25641059], [
- 1.57356906, 0.30390519], [1.0304995, -0.16955962], [1.67314371,
- 0.19231498], [0.98382284, 0.37184502], [0.48921682, -1.38504507], [-
- 0.46226554, -0.50481004], [-0.03918551, -0.68540745], [0.24991051, -
- 1.00864997], [0.80541964, -0.34465185], [0.1732627, -1.61323172], [
- 0.69804044, 0.44810796], [-0.5506368, -0.42072426], [-0.34474418,
- 0.21969797]])
-Y = np.array([1, 2, 2, 2, 1, 1, 0, 2, 1, 1, 1, 2, 2, 0, 1, 2, 1, 2, 1, 1, 2,
- 2, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 0, 2, 2, 2, 2, 1, 2, 0])
-R_TOL = 0.0001
+
+X = np.array(
+ [
+ [-0.12840393, 0.66446571],
+ [1.32319756, -0.13181616],
+ [0.04296502, -0.37981873],
+ [0.83631853, 0.18569783],
+ [1.02956816, 0.36061601],
+ [1.12202806, 0.33811558],
+ [-0.53171468, -0.53735182],
+ [1.3381556, 0.35956356],
+ [-0.35946678, 0.72510189],
+ [1.32326943, 0.28393874],
+ [2.94290565, -0.13986434],
+ [0.28294738, -1.00125525],
+ [0.34218094, -0.58781961],
+ [-0.88864036, -0.33782387],
+ [-1.10146139, 0.91782682],
+ [-0.7969716, -0.50493969],
+ [0.73489726, 0.43915195],
+ [0.2096964, -0.61814058],
+ [-0.28479268, 0.70459548],
+ [1.84864913, 0.14729596],
+ [1.59068979, -0.96622933],
+ [0.73418199, -0.02222847],
+ [0.50307437, 0.498805],
+ [0.84929742, 0.41042894],
+ [0.62649535, 0.46600596],
+ [0.79270821, -0.41386668],
+ [1.16606871, -0.25641059],
+ [1.57356906, 0.30390519],
+ [1.0304995, -0.16955962],
+ [1.67314371, 0.19231498],
+ [0.98382284, 0.37184502],
+ [0.48921682, -1.38504507],
+ [-0.46226554, -0.50481004],
+ [-0.03918551, -0.68540745],
+ [0.24991051, -1.00864997],
+ [0.80541964, -0.34465185],
+ [0.1732627, -1.61323172],
+ [0.69804044, 0.44810796],
+ [-0.5506368, -0.42072426],
+ [-0.34474418, 0.21969797],
+ ]
+)
+Y = np.array(
+    [1, 2, 2, 2, 1, 1, 0, 2, 1, 1, 1, 2, 2, 0, 1, 2, 1, 2, 1, 1,
+     2, 2, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 0, 2, 2, 2, 2, 1, 2, 0]
+)
+R_TOL = 1e-4
+
+
+def test_allknn_fit_resample():
+ allknn = AllKNN()
+ X_resampled, y_resampled = allknn.fit_resample(X, Y)
+
+ X_gt = np.array(
+ [
+ [-0.53171468, -0.53735182],
+ [-0.88864036, -0.33782387],
+ [-0.46226554, -0.50481004],
+ [-0.34474418, 0.21969797],
+ [1.02956816, 0.36061601],
+ [1.12202806, 0.33811558],
+ [-1.10146139, 0.91782682],
+ [0.73489726, 0.43915195],
+ [0.50307437, 0.498805],
+ [0.84929742, 0.41042894],
+ [0.62649535, 0.46600596],
+ [0.98382284, 0.37184502],
+ [0.69804044, 0.44810796],
+ [0.04296502, -0.37981873],
+ [0.28294738, -1.00125525],
+ [0.34218094, -0.58781961],
+ [0.2096964, -0.61814058],
+ [1.59068979, -0.96622933],
+ [0.73418199, -0.02222847],
+ [0.79270821, -0.41386668],
+ [1.16606871, -0.25641059],
+ [1.0304995, -0.16955962],
+ [0.48921682, -1.38504507],
+ [-0.03918551, -0.68540745],
+ [0.24991051, -1.00864997],
+ [0.80541964, -0.34465185],
+ [0.1732627, -1.61323172],
+ ]
+ )
+    y_gt = np.array(
+        [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
+    )
+ assert_allclose(X_resampled, X_gt, rtol=R_TOL)
+ assert_allclose(y_resampled, y_gt, rtol=R_TOL)
+
+
+def test_all_knn_allow_minority():
+ X, y = make_classification(
+ n_samples=10000,
+ n_features=2,
+ n_informative=2,
+ n_redundant=0,
+ n_repeated=0,
+ n_classes=3,
+ n_clusters_per_class=1,
+ weights=[0.2, 0.3, 0.5],
+ class_sep=0.4,
+ random_state=0,
+ )
+
+ allknn = AllKNN(allow_minority=True)
+ X_res_1, y_res_1 = allknn.fit_resample(X, y)
+ allknn = AllKNN()
+ X_res_2, y_res_2 = allknn.fit_resample(X, y)
+ assert len(y_res_1) < len(y_res_2)
+
+
+def test_allknn_fit_resample_mode():
+ allknn = AllKNN(kind_sel="mode")
+ X_resampled, y_resampled = allknn.fit_resample(X, Y)
+
+ X_gt = np.array(
+ [
+ [-0.53171468, -0.53735182],
+ [-0.88864036, -0.33782387],
+ [-0.46226554, -0.50481004],
+ [-0.34474418, 0.21969797],
+ [-0.12840393, 0.66446571],
+ [1.02956816, 0.36061601],
+ [1.12202806, 0.33811558],
+ [-0.35946678, 0.72510189],
+ [-1.10146139, 0.91782682],
+ [0.73489726, 0.43915195],
+ [-0.28479268, 0.70459548],
+ [0.50307437, 0.498805],
+ [0.84929742, 0.41042894],
+ [0.62649535, 0.46600596],
+ [0.98382284, 0.37184502],
+ [0.69804044, 0.44810796],
+ [1.32319756, -0.13181616],
+ [0.04296502, -0.37981873],
+ [0.28294738, -1.00125525],
+ [0.34218094, -0.58781961],
+ [0.2096964, -0.61814058],
+ [1.59068979, -0.96622933],
+ [0.73418199, -0.02222847],
+ [0.79270821, -0.41386668],
+ [1.16606871, -0.25641059],
+ [1.0304995, -0.16955962],
+ [0.48921682, -1.38504507],
+ [-0.03918551, -0.68540745],
+ [0.24991051, -1.00864997],
+ [0.80541964, -0.34465185],
+ [0.1732627, -1.61323172],
+ ]
+ )
+    y_gt = np.array(
+        [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
+    )
+ assert_array_equal(X_resampled, X_gt)
+ assert_array_equal(y_resampled, y_gt)
+
+
+def test_allknn_fit_resample_with_nn_object():
+ nn = NearestNeighbors(n_neighbors=4)
+ allknn = AllKNN(n_neighbors=nn, kind_sel="mode")
+ X_resampled, y_resampled = allknn.fit_resample(X, Y)
+
+ X_gt = np.array(
+ [
+ [-0.53171468, -0.53735182],
+ [-0.88864036, -0.33782387],
+ [-0.46226554, -0.50481004],
+ [-0.34474418, 0.21969797],
+ [-0.12840393, 0.66446571],
+ [1.02956816, 0.36061601],
+ [1.12202806, 0.33811558],
+ [-0.35946678, 0.72510189],
+ [-1.10146139, 0.91782682],
+ [0.73489726, 0.43915195],
+ [-0.28479268, 0.70459548],
+ [0.50307437, 0.498805],
+ [0.84929742, 0.41042894],
+ [0.62649535, 0.46600596],
+ [0.98382284, 0.37184502],
+ [0.69804044, 0.44810796],
+ [1.32319756, -0.13181616],
+ [0.04296502, -0.37981873],
+ [0.28294738, -1.00125525],
+ [0.34218094, -0.58781961],
+ [0.2096964, -0.61814058],
+ [1.59068979, -0.96622933],
+ [0.73418199, -0.02222847],
+ [0.79270821, -0.41386668],
+ [1.16606871, -0.25641059],
+ [1.0304995, -0.16955962],
+ [0.48921682, -1.38504507],
+ [-0.03918551, -0.68540745],
+ [0.24991051, -1.00864997],
+ [0.80541964, -0.34465185],
+ [0.1732627, -1.61323172],
+ ]
+ )
+    y_gt = np.array(
+        [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
+    )
+ assert_array_equal(X_resampled, X_gt)
+ assert_array_equal(y_resampled, y_gt)
+
+
+def test_alknn_not_good_object():
+ nn = "rnd"
+ allknn = AllKNN(n_neighbors=nn, kind_sel="mode")
+ with pytest.raises(ValueError):
+ allknn.fit_resample(X, Y)
diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_condensed_nearest_neighbour.py b/imblearn/under_sampling/_prototype_selection/tests/test_condensed_nearest_neighbour.py
index b14a1ef..5cc8f41 100644
--- a/imblearn/under_sampling/_prototype_selection/tests/test_condensed_nearest_neighbour.py
+++ b/imblearn/under_sampling/_prototype_selection/tests/test_condensed_nearest_neighbour.py
@@ -1,28 +1,129 @@
"""Test the module condensed nearest neighbour."""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import numpy as np
import pytest
from sklearn.datasets import make_classification
from sklearn.neighbors import KNeighborsClassifier
from sklearn.utils._testing import assert_array_equal
+
from imblearn.under_sampling import CondensedNearestNeighbour
+
RND_SEED = 0
-X = np.array([[2.59928271, 0.93323465], [0.25738379, 0.95564169], [
- 1.42772181, 0.526027], [1.92365863, 0.82718767], [-0.10903849, -
- 0.12085181], [-0.284881, -0.62730973], [0.57062627, 1.19528323], [
- 0.03394306, 0.03986753], [0.78318102, 2.59153329], [0.35831463,
- 1.33483198], [-0.14313184, -1.0412815], [0.01936241, 0.17799828], [-
- 1.25020462, -0.40402054], [-0.09816301, -0.74662486], [-0.01252787,
- 0.34102657], [0.52726792, -0.38735648], [0.2821046, -0.07862747], [
- 0.05230552, 0.09043907], [0.15198585, 0.12512646], [0.70524765,
- 0.39816382]])
+X = np.array(
+ [
+ [2.59928271, 0.93323465],
+ [0.25738379, 0.95564169],
+ [1.42772181, 0.526027],
+ [1.92365863, 0.82718767],
+ [-0.10903849, -0.12085181],
+ [-0.284881, -0.62730973],
+ [0.57062627, 1.19528323],
+ [0.03394306, 0.03986753],
+ [0.78318102, 2.59153329],
+ [0.35831463, 1.33483198],
+ [-0.14313184, -1.0412815],
+ [0.01936241, 0.17799828],
+ [-1.25020462, -0.40402054],
+ [-0.09816301, -0.74662486],
+ [-0.01252787, 0.34102657],
+ [0.52726792, -0.38735648],
+ [0.2821046, -0.07862747],
+ [0.05230552, 0.09043907],
+ [0.15198585, 0.12512646],
+ [0.70524765, 0.39816382],
+ ]
+)
Y = np.array([1, 2, 1, 1, 0, 2, 2, 2, 2, 2, 2, 0, 1, 2, 2, 2, 2, 1, 2, 1])
+
+
+def test_cnn_init():
+ cnn = CondensedNearestNeighbour(random_state=RND_SEED)
+
+ assert cnn.n_seeds_S == 1
+ assert cnn.n_jobs is None
+
+
+def test_cnn_fit_resample():
+ cnn = CondensedNearestNeighbour(random_state=RND_SEED)
+ X_resampled, y_resampled = cnn.fit_resample(X, Y)
+
+ X_gt = np.array(
+ [
+ [-0.10903849, -0.12085181],
+ [0.01936241, 0.17799828],
+ [0.05230552, 0.09043907],
+ [-1.25020462, -0.40402054],
+ [0.70524765, 0.39816382],
+ [0.35831463, 1.33483198],
+ [-0.284881, -0.62730973],
+ [0.03394306, 0.03986753],
+ [-0.01252787, 0.34102657],
+ [0.15198585, 0.12512646],
+ ]
+ )
+ y_gt = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 2])
+ assert_array_equal(X_resampled, X_gt)
+ assert_array_equal(y_resampled, y_gt)
+
+
+@pytest.mark.parametrize("n_neighbors", [1, KNeighborsClassifier(n_neighbors=1)])
+def test_cnn_fit_resample_with_object(n_neighbors):
+ cnn = CondensedNearestNeighbour(random_state=RND_SEED, n_neighbors=n_neighbors)
+ X_resampled, y_resampled = cnn.fit_resample(X, Y)
+
+ X_gt = np.array(
+ [
+ [-0.10903849, -0.12085181],
+ [0.01936241, 0.17799828],
+ [0.05230552, 0.09043907],
+ [-1.25020462, -0.40402054],
+ [0.70524765, 0.39816382],
+ [0.35831463, 1.33483198],
+ [-0.284881, -0.62730973],
+ [0.03394306, 0.03986753],
+ [-0.01252787, 0.34102657],
+ [0.15198585, 0.12512646],
+ ]
+ )
+ y_gt = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 2])
+ assert_array_equal(X_resampled, X_gt)
+ assert_array_equal(y_resampled, y_gt)
+
+ cnn = CondensedNearestNeighbour(random_state=RND_SEED, n_neighbors=1)
+ X_resampled, y_resampled = cnn.fit_resample(X, Y)
+ assert_array_equal(X_resampled, X_gt)
+ assert_array_equal(y_resampled, y_gt)
+
+
def test_condensed_nearest_neighbour_multiclass():
"""Check the validity of the fitted attributes `estimators_`."""
- pass
+ X, y = make_classification(
+ n_samples=1_000,
+ n_classes=4,
+ weights=[0.1, 0.2, 0.2, 0.5],
+ n_clusters_per_class=1,
+ random_state=0,
+ )
+ cnn = CondensedNearestNeighbour(random_state=RND_SEED)
+ cnn.fit_resample(X, y)
+
+ assert len(cnn.estimators_) == len(cnn.sampling_strategy_)
+ other_classes = []
+ for est in cnn.estimators_:
+ assert est.classes_[0] == 0 # minority class
+ assert est.classes_[1] in {1, 2, 3} # other classes
+ other_classes.append(est.classes_[1])
+ assert len(set(other_classes)) == len(other_classes)
+
+
+# TODO: remove in 0.14
def test_condensed_nearest_neighbors_deprecation():
"""Check that we raise a FutureWarning when accessing the parameter `estimator_`."""
- pass
+ cnn = CondensedNearestNeighbour(random_state=RND_SEED)
+ cnn.fit_resample(X, Y)
+ warn_msg = "`estimator_` attribute has been deprecated"
+ with pytest.warns(FutureWarning, match=warn_msg):
+ cnn.estimator_
diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_edited_nearest_neighbours.py b/imblearn/under_sampling/_prototype_selection/tests/test_edited_nearest_neighbours.py
index 9333224..00a0ce5 100644
--- a/imblearn/under_sampling/_prototype_selection/tests/test_edited_nearest_neighbours.py
+++ b/imblearn/under_sampling/_prototype_selection/tests/test_edited_nearest_neighbours.py
@@ -1,22 +1,140 @@
"""Test the module edited nearest neighbour."""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import numpy as np
from sklearn.datasets import make_classification
from sklearn.neighbors import NearestNeighbors
from sklearn.utils._testing import assert_array_equal
+
from imblearn.under_sampling import EditedNearestNeighbours
-X = np.array([[2.59928271, 0.93323465], [0.25738379, 0.95564169], [
- 1.42772181, 0.526027], [1.92365863, 0.82718767], [-0.10903849, -
- 0.12085181], [-0.284881, -0.62730973], [0.57062627, 1.19528323], [
- 0.03394306, 0.03986753], [0.78318102, 2.59153329], [0.35831463,
- 1.33483198], [-0.14313184, -1.0412815], [0.01936241, 0.17799828], [-
- 1.25020462, -0.40402054], [-0.09816301, -0.74662486], [-0.01252787,
- 0.34102657], [0.52726792, -0.38735648], [0.2821046, -0.07862747], [
- 0.05230552, 0.09043907], [0.15198585, 0.12512646], [0.70524765,
- 0.39816382]])
+
+X = np.array(
+ [
+ [2.59928271, 0.93323465],
+ [0.25738379, 0.95564169],
+ [1.42772181, 0.526027],
+ [1.92365863, 0.82718767],
+ [-0.10903849, -0.12085181],
+ [-0.284881, -0.62730973],
+ [0.57062627, 1.19528323],
+ [0.03394306, 0.03986753],
+ [0.78318102, 2.59153329],
+ [0.35831463, 1.33483198],
+ [-0.14313184, -1.0412815],
+ [0.01936241, 0.17799828],
+ [-1.25020462, -0.40402054],
+ [-0.09816301, -0.74662486],
+ [-0.01252787, 0.34102657],
+ [0.52726792, -0.38735648],
+ [0.2821046, -0.07862747],
+ [0.05230552, 0.09043907],
+ [0.15198585, 0.12512646],
+ [0.70524765, 0.39816382],
+ ]
+)
Y = np.array([1, 2, 1, 1, 0, 2, 2, 2, 2, 2, 2, 0, 1, 2, 2, 2, 2, 1, 2, 1])
+
+
+def test_enn_init():
+ enn = EditedNearestNeighbours()
+
+ assert enn.n_neighbors == 3
+ assert enn.kind_sel == "all"
+ assert enn.n_jobs is None
+
+
+def test_enn_fit_resample():
+ enn = EditedNearestNeighbours()
+ X_resampled, y_resampled = enn.fit_resample(X, Y)
+
+ X_gt = np.array(
+ [
+ [-0.10903849, -0.12085181],
+ [0.01936241, 0.17799828],
+ [2.59928271, 0.93323465],
+ [1.92365863, 0.82718767],
+ [0.25738379, 0.95564169],
+ [0.78318102, 2.59153329],
+ [0.52726792, -0.38735648],
+ ]
+ )
+ y_gt = np.array([0, 0, 1, 1, 2, 2, 2])
+ assert_array_equal(X_resampled, X_gt)
+ assert_array_equal(y_resampled, y_gt)
+
+
+def test_enn_fit_resample_mode():
+ enn = EditedNearestNeighbours(kind_sel="mode")
+ X_resampled, y_resampled = enn.fit_resample(X, Y)
+
+ X_gt = np.array(
+ [
+ [-0.10903849, -0.12085181],
+ [0.01936241, 0.17799828],
+ [2.59928271, 0.93323465],
+ [1.42772181, 0.526027],
+ [1.92365863, 0.82718767],
+ [0.25738379, 0.95564169],
+ [-0.284881, -0.62730973],
+ [0.57062627, 1.19528323],
+ [0.78318102, 2.59153329],
+ [0.35831463, 1.33483198],
+ [-0.14313184, -1.0412815],
+ [-0.09816301, -0.74662486],
+ [0.52726792, -0.38735648],
+ [0.2821046, -0.07862747],
+ ]
+ )
+ y_gt = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2])
+ assert_array_equal(X_resampled, X_gt)
+ assert_array_equal(y_resampled, y_gt)
+
+
+def test_enn_fit_resample_with_nn_object():
+ nn = NearestNeighbors(n_neighbors=4)
+ enn = EditedNearestNeighbours(n_neighbors=nn, kind_sel="mode")
+ X_resampled, y_resampled = enn.fit_resample(X, Y)
+
+ X_gt = np.array(
+ [
+ [-0.10903849, -0.12085181],
+ [0.01936241, 0.17799828],
+ [2.59928271, 0.93323465],
+ [1.42772181, 0.526027],
+ [1.92365863, 0.82718767],
+ [0.25738379, 0.95564169],
+ [-0.284881, -0.62730973],
+ [0.57062627, 1.19528323],
+ [0.78318102, 2.59153329],
+ [0.35831463, 1.33483198],
+ [-0.14313184, -1.0412815],
+ [-0.09816301, -0.74662486],
+ [0.52726792, -0.38735648],
+ [0.2821046, -0.07862747],
+ ]
+ )
+ y_gt = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2])
+ assert_array_equal(X_resampled, X_gt)
+ assert_array_equal(y_resampled, y_gt)
+
+
def test_enn_check_kind_selection():
"""Check that `check_sel="all"` is more conservative than
`check_sel="mode"`."""
- pass
+
+ X, y = make_classification(
+ n_samples=1000,
+ n_classes=2,
+ weights=[0.3, 0.7],
+ random_state=0,
+ )
+
+ enn_all = EditedNearestNeighbours(kind_sel="all")
+ enn_mode = EditedNearestNeighbours(kind_sel="mode")
+
+ enn_all.fit_resample(X, y)
+ enn_mode.fit_resample(X, y)
+
+ assert enn_all.sample_indices_.size < enn_mode.sample_indices_.size
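
The assertion above holds because `kind_sel="all"` keeps a sample only if every neighbour agrees with its label, while `"mode"` only needs a plurality. A schematic sketch of the two rules (not the library's internal code):

```python
import numpy as np

own_label = 1
neigh_labels = np.array([[1, 1, 2], [1, 1, 1]])  # labels of the 3 neighbours
keep_all = np.all(neigh_labels == own_label, axis=1)  # [False  True]
keep_mode = np.array(
    [np.bincount(row).argmax() == own_label for row in neigh_labels]
)  # [ True  True]: "all" keeps a subset of what "mode" keeps
print(keep_all, keep_mode)
```
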
diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py b/imblearn/under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py
index bdb3a01..a63bb45 100644
--- a/imblearn/under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py
+++ b/imblearn/under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py
@@ -1,22 +1,101 @@
"""Test the module ."""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.naive_bayes import GaussianNB as NB
from sklearn.pipeline import make_pipeline
from sklearn.utils._testing import assert_array_equal
+
from imblearn.under_sampling import InstanceHardnessThreshold
+
RND_SEED = 0
-X = np.array([[-0.3879569, 0.6894251], [-0.09322739, 1.28177189], [-
- 0.77740357, 0.74097941], [0.91542919, -0.65453327], [-0.03852113,
- 0.40910479], [-0.43877303, 1.07366684], [-0.85795321, 0.82980738], [-
- 0.18430329, 0.52328473], [-0.30126957, -0.66268378], [-0.65571327,
- 0.42412021], [-0.28305528, 0.30284991], [0.20246714, -0.34727125], [
- 1.06446472, -1.09279772], [0.30543283, -0.02589502], [-0.00717161,
- 0.00318087]])
+X = np.array(
+ [
+ [-0.3879569, 0.6894251],
+ [-0.09322739, 1.28177189],
+ [-0.77740357, 0.74097941],
+ [0.91542919, -0.65453327],
+ [-0.03852113, 0.40910479],
+ [-0.43877303, 1.07366684],
+ [-0.85795321, 0.82980738],
+ [-0.18430329, 0.52328473],
+ [-0.30126957, -0.66268378],
+ [-0.65571327, 0.42412021],
+ [-0.28305528, 0.30284991],
+ [0.20246714, -0.34727125],
+ [1.06446472, -1.09279772],
+ [0.30543283, -0.02589502],
+ [-0.00717161, 0.00318087],
+ ]
+)
Y = np.array([0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0])
ESTIMATOR = GradientBoostingClassifier(random_state=RND_SEED)
+
+
+def test_iht_init():
+ sampling_strategy = "auto"
+ iht = InstanceHardnessThreshold(
+ estimator=ESTIMATOR,
+ sampling_strategy=sampling_strategy,
+ random_state=RND_SEED,
+ )
+
+ assert iht.sampling_strategy == sampling_strategy
+ assert iht.random_state == RND_SEED
+
+
+def test_iht_fit_resample():
+ iht = InstanceHardnessThreshold(estimator=ESTIMATOR, random_state=RND_SEED)
+ X_resampled, y_resampled = iht.fit_resample(X, Y)
+ assert X_resampled.shape == (12, 2)
+ assert y_resampled.shape == (12,)
+
+
+def test_iht_fit_resample_half():
+ sampling_strategy = {0: 3, 1: 3}
+ iht = InstanceHardnessThreshold(
+ estimator=NB(),
+ sampling_strategy=sampling_strategy,
+ random_state=RND_SEED,
+ )
+ X_resampled, y_resampled = iht.fit_resample(X, Y)
+ assert X_resampled.shape == (6, 2)
+ assert y_resampled.shape == (6,)
+
+
+def test_iht_fit_resample_class_obj():
+ est = GradientBoostingClassifier(random_state=RND_SEED)
+ iht = InstanceHardnessThreshold(estimator=est, random_state=RND_SEED)
+ X_resampled, y_resampled = iht.fit_resample(X, Y)
+ assert X_resampled.shape == (12, 2)
+ assert y_resampled.shape == (12,)
+
+
+def test_iht_reproducibility():
+ from sklearn.datasets import load_digits
+
+ X_digits, y_digits = load_digits(return_X_y=True)
+ idx_sampled = []
+ for seed in range(5):
+ est = RandomForestClassifier(n_estimators=10, random_state=seed)
+ iht = InstanceHardnessThreshold(estimator=est, random_state=RND_SEED)
+ iht.fit_resample(X_digits, y_digits)
+ idx_sampled.append(iht.sample_indices_.copy())
+ for idx_1, idx_2 in zip(idx_sampled, idx_sampled[1:]):
+ assert_array_equal(idx_1, idx_2)
+
+
+def test_iht_fit_resample_default_estimator():
+ iht = InstanceHardnessThreshold(estimator=None, random_state=RND_SEED)
+ X_resampled, y_resampled = iht.fit_resample(X, Y)
+ assert isinstance(iht.estimator_, RandomForestClassifier)
+ assert X_resampled.shape == (12, 2)
+ assert y_resampled.shape == (12,)
+
+
def test_iht_estimator_pipeline():
"""Check that we can pass a pipeline containing a classifier.
@@ -26,4 +105,8 @@ def test_iht_estimator_pipeline():
Non-regression test for:
https://github.com/scikit-learn-contrib/imbalanced-learn/pull/1049
"""
- pass
+ model = make_pipeline(GradientBoostingClassifier(random_state=RND_SEED))
+ iht = InstanceHardnessThreshold(estimator=model, random_state=RND_SEED)
+ X_resampled, y_resampled = iht.fit_resample(X, Y)
+ assert X_resampled.shape == (12, 2)
+ assert y_resampled.shape == (12,)
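
The pipeline test passes because `_validate_estimator` only requires `is_classifier` and a `predict_proba` method, both of which a `Pipeline` ending in a classifier delegates to its final step. A usage sketch with an assumed scaler in front:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

from imblearn.under_sampling import InstanceHardnessThreshold

X, y = make_classification(n_samples=200, weights=[0.3, 0.7], random_state=0)
model = make_pipeline(StandardScaler(), LogisticRegression())
iht = InstanceHardnessThreshold(estimator=model, random_state=0)
X_res, y_res = iht.fit_resample(X, y)
```
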
diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_nearmiss.py b/imblearn/under_sampling/_prototype_selection/tests/test_nearmiss.py
index a80ec95..9ab0da4 100644
--- a/imblearn/under_sampling/_prototype_selection/tests/test_nearmiss.py
+++ b/imblearn/under_sampling/_prototype_selection/tests/test_nearmiss.py
@@ -1,14 +1,210 @@
"""Test the module nearmiss."""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.utils._testing import assert_array_equal
+
from imblearn.under_sampling import NearMiss
-X = np.array([[1.17737838, -0.2002118], [0.4960075, 0.86130762], [-
- 0.05903827, 0.10947647], [0.91464286, 1.61369212], [-0.54619583,
- 1.73009918], [-0.60413357, 0.24628718], [0.45713638, 1.31069295], [-
- 0.04032409, 3.01186964], [0.03142011, 0.12323596], [0.50701028, -
- 0.17636928], [-0.80809175, -1.09917302], [-0.20497017, -0.26630228], [
- 0.99272351, -0.11631728], [-1.95581933, 0.69609604], [1.15157493, -
- 1.2981518]])
+
+X = np.array(
+ [
+ [1.17737838, -0.2002118],
+ [0.4960075, 0.86130762],
+ [-0.05903827, 0.10947647],
+ [0.91464286, 1.61369212],
+ [-0.54619583, 1.73009918],
+ [-0.60413357, 0.24628718],
+ [0.45713638, 1.31069295],
+ [-0.04032409, 3.01186964],
+ [0.03142011, 0.12323596],
+ [0.50701028, -0.17636928],
+ [-0.80809175, -1.09917302],
+ [-0.20497017, -0.26630228],
+ [0.99272351, -0.11631728],
+ [-1.95581933, 0.69609604],
+ [1.15157493, -1.2981518],
+ ]
+)
Y = np.array([1, 2, 1, 0, 2, 1, 2, 2, 1, 2, 0, 0, 2, 1, 2])
-VERSION_NEARMISS = 1, 2, 3
+
+VERSION_NEARMISS = (1, 2, 3)
+
+
+def test_nm_fit_resample_auto():
+ sampling_strategy = "auto"
+ X_gt = [
+ np.array(
+ [
+ [0.91464286, 1.61369212],
+ [-0.80809175, -1.09917302],
+ [-0.20497017, -0.26630228],
+ [-0.05903827, 0.10947647],
+ [0.03142011, 0.12323596],
+ [-0.60413357, 0.24628718],
+ [0.50701028, -0.17636928],
+ [0.4960075, 0.86130762],
+ [0.45713638, 1.31069295],
+ ]
+ ),
+ np.array(
+ [
+ [0.91464286, 1.61369212],
+ [-0.80809175, -1.09917302],
+ [-0.20497017, -0.26630228],
+ [-0.05903827, 0.10947647],
+ [0.03142011, 0.12323596],
+ [-0.60413357, 0.24628718],
+ [0.50701028, -0.17636928],
+ [0.4960075, 0.86130762],
+ [0.45713638, 1.31069295],
+ ]
+ ),
+ np.array(
+ [
+ [0.91464286, 1.61369212],
+ [-0.80809175, -1.09917302],
+ [-0.20497017, -0.26630228],
+ [1.17737838, -0.2002118],
+ [-0.60413357, 0.24628718],
+ [0.03142011, 0.12323596],
+ [1.15157493, -1.2981518],
+ [-0.54619583, 1.73009918],
+ [0.99272351, -0.11631728],
+ ]
+ ),
+ ]
+ y_gt = [
+ np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]),
+ np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]),
+ np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]),
+ ]
+ for version_idx, version in enumerate(VERSION_NEARMISS):
+ nm = NearMiss(sampling_strategy=sampling_strategy, version=version)
+ X_resampled, y_resampled = nm.fit_resample(X, Y)
+ assert_array_equal(X_resampled, X_gt[version_idx])
+ assert_array_equal(y_resampled, y_gt[version_idx])
+
+
+def test_nm_fit_resample_float_sampling_strategy():
+ sampling_strategy = {0: 3, 1: 4, 2: 4}
+ X_gt = [
+ np.array(
+ [
+ [-0.20497017, -0.26630228],
+ [-0.80809175, -1.09917302],
+ [0.91464286, 1.61369212],
+ [-0.05903827, 0.10947647],
+ [0.03142011, 0.12323596],
+ [-0.60413357, 0.24628718],
+ [1.17737838, -0.2002118],
+ [0.50701028, -0.17636928],
+ [0.4960075, 0.86130762],
+ [0.45713638, 1.31069295],
+ [0.99272351, -0.11631728],
+ ]
+ ),
+ np.array(
+ [
+ [-0.20497017, -0.26630228],
+ [-0.80809175, -1.09917302],
+ [0.91464286, 1.61369212],
+ [-0.05903827, 0.10947647],
+ [0.03142011, 0.12323596],
+ [-0.60413357, 0.24628718],
+ [1.17737838, -0.2002118],
+ [0.50701028, -0.17636928],
+ [0.4960075, 0.86130762],
+ [0.45713638, 1.31069295],
+ [0.99272351, -0.11631728],
+ ]
+ ),
+ np.array(
+ [
+ [0.91464286, 1.61369212],
+ [-0.80809175, -1.09917302],
+ [-0.20497017, -0.26630228],
+ [1.17737838, -0.2002118],
+ [-0.60413357, 0.24628718],
+ [0.03142011, 0.12323596],
+ [-0.05903827, 0.10947647],
+ [1.15157493, -1.2981518],
+ [-0.54619583, 1.73009918],
+ [0.99272351, -0.11631728],
+ [0.45713638, 1.31069295],
+ ]
+ ),
+ ]
+ y_gt = [
+ np.array([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]),
+ np.array([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]),
+ np.array([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]),
+ ]
+
+ for version_idx, version in enumerate(VERSION_NEARMISS):
+ nm = NearMiss(sampling_strategy=sampling_strategy, version=version)
+ X_resampled, y_resampled = nm.fit_resample(X, Y)
+ assert_array_equal(X_resampled, X_gt[version_idx])
+ assert_array_equal(y_resampled, y_gt[version_idx])
+
+
+def test_nm_fit_resample_nn_obj():
+ sampling_strategy = "auto"
+ nn = NearestNeighbors(n_neighbors=3)
+ X_gt = [
+ np.array(
+ [
+ [0.91464286, 1.61369212],
+ [-0.80809175, -1.09917302],
+ [-0.20497017, -0.26630228],
+ [-0.05903827, 0.10947647],
+ [0.03142011, 0.12323596],
+ [-0.60413357, 0.24628718],
+ [0.50701028, -0.17636928],
+ [0.4960075, 0.86130762],
+ [0.45713638, 1.31069295],
+ ]
+ ),
+ np.array(
+ [
+ [0.91464286, 1.61369212],
+ [-0.80809175, -1.09917302],
+ [-0.20497017, -0.26630228],
+ [-0.05903827, 0.10947647],
+ [0.03142011, 0.12323596],
+ [-0.60413357, 0.24628718],
+ [0.50701028, -0.17636928],
+ [0.4960075, 0.86130762],
+ [0.45713638, 1.31069295],
+ ]
+ ),
+ np.array(
+ [
+ [0.91464286, 1.61369212],
+ [-0.80809175, -1.09917302],
+ [-0.20497017, -0.26630228],
+ [1.17737838, -0.2002118],
+ [-0.60413357, 0.24628718],
+ [0.03142011, 0.12323596],
+ [1.15157493, -1.2981518],
+ [-0.54619583, 1.73009918],
+ [0.99272351, -0.11631728],
+ ]
+ ),
+ ]
+ y_gt = [
+ np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]),
+ np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]),
+ np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]),
+ ]
+ for version_idx, version in enumerate(VERSION_NEARMISS):
+ nm = NearMiss(
+ sampling_strategy=sampling_strategy,
+ version=version,
+ n_neighbors=nn,
+ )
+ X_resampled, y_resampled = nm.fit_resample(X, Y)
+ assert_array_equal(X_resampled, X_gt[version_idx])
+ assert_array_equal(y_resampled, y_gt[version_idx])
diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_neighbourhood_cleaning_rule.py b/imblearn/under_sampling/_prototype_selection/tests/test_neighbourhood_cleaning_rule.py
index 97c8fd5..97a1d02 100644
--- a/imblearn/under_sampling/_prototype_selection/tests/test_neighbourhood_cleaning_rule.py
+++ b/imblearn/under_sampling/_prototype_selection/tests/test_neighbourhood_cleaning_rule.py
@@ -1,17 +1,84 @@
"""Test the module neighbourhood cleaning rule."""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
from collections import Counter
+
import numpy as np
import pytest
from sklearn.datasets import make_classification
from sklearn.utils._testing import assert_array_equal
+
from imblearn.under_sampling import EditedNearestNeighbours, NeighbourhoodCleaningRule
+@pytest.fixture(scope="module")
+def data():
+ return make_classification(
+ n_samples=200,
+ n_features=2,
+ n_informative=2,
+ n_redundant=0,
+ n_repeated=0,
+ n_clusters_per_class=1,
+ n_classes=3,
+ weights=[0.1, 0.3, 0.6],
+ random_state=0,
+ )
+
+
def test_ncr_threshold_cleaning(data):
"""Test the effect of the `threshold_cleaning` parameter."""
- pass
+ X, y = data
+ # with a large `threshold_cleaning`, the algorithm is equivalent to ENN
+ enn = EditedNearestNeighbours()
+ ncr = NeighbourhoodCleaningRule(
+ edited_nearest_neighbours=enn, n_neighbors=10, threshold_cleaning=10
+ )
+
+ enn.fit_resample(X, y)
+ ncr.fit_resample(X, y)
+
+ assert_array_equal(np.sort(enn.sample_indices_), np.sort(ncr.sample_indices_))
+ assert ncr.classes_to_clean_ == []
+
+    # set a threshold such that only class #2 is considered for cleaning
+ counter = Counter(y)
+ threshold = counter[1] / counter[0]
+ ncr.set_params(threshold_cleaning=threshold)
+ ncr.fit_resample(X, y)
+
+ assert set(ncr.classes_to_clean_) == {2}
+
+    # make the threshold slightly smaller so that class #1 is also cleaned
+ ncr.set_params(threshold_cleaning=threshold - np.finfo(np.float32).eps)
+ ncr.fit_resample(X, y)
+
+ assert set(ncr.classes_to_clean_) == {1, 2}
def test_ncr_n_neighbors(data):
"""Check the effect of the NN on the cleaning of the second phase."""
- pass
+ X, y = data
+
+ enn = EditedNearestNeighbours()
+ ncr = NeighbourhoodCleaningRule(edited_nearest_neighbours=enn, n_neighbors=3)
+
+ ncr.fit_resample(X, y)
+ sample_indices_3_nn = ncr.sample_indices_
+
+ ncr.set_params(n_neighbors=10).fit_resample(X, y)
+ sample_indices_10_nn = ncr.sample_indices_
+
+    # the cleaning should be more aggressive when n_neighbors is larger
+ assert len(sample_indices_3_nn) > len(sample_indices_10_nn)
+
+
+# TODO: remove in 0.14
+@pytest.mark.parametrize("kind_sel", ["all", "mode"])
+def test_ncr_deprecate_kind_sel(data, kind_sel):
+ X, y = data
+
+ with pytest.warns(FutureWarning, match="`kind_sel` is deprecated"):
+ NeighbourhoodCleaningRule(kind_sel=kind_sel).fit_resample(X, y)
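
`test_ncr_threshold_cleaning` leans on the qualifying rule from `_fit_resample`: a class is cleaned only when its count exceeds `n_minority * threshold_cleaning`. A quick check of that arithmetic with invented counts:

```python
from collections import Counter

counts = Counter({0: 20, 1: 60, 2: 120})  # class 0 is the minority
threshold = counts[1] / counts[0]  # 3.0, as computed in the test
to_clean = [c for c, n in counts.items() if n > counts[0] * threshold]
print(to_clean)  # [2]: class 1 sits exactly on the boundary and is skipped
```
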
diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_one_sided_selection.py b/imblearn/under_sampling/_prototype_selection/tests/test_one_sided_selection.py
index e861896..3fb5458 100644
--- a/imblearn/under_sampling/_prototype_selection/tests/test_one_sided_selection.py
+++ b/imblearn/under_sampling/_prototype_selection/tests/test_one_sided_selection.py
@@ -1,26 +1,129 @@
"""Test the module one-sided selection."""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import numpy as np
import pytest
from sklearn.datasets import make_classification
from sklearn.neighbors import KNeighborsClassifier
from sklearn.utils._testing import assert_array_equal
+
from imblearn.under_sampling import OneSidedSelection
+
RND_SEED = 0
-X = np.array([[-0.3879569, 0.6894251], [-0.09322739, 1.28177189], [-
- 0.77740357, 0.74097941], [0.91542919, -0.65453327], [-0.03852113,
- 0.40910479], [-0.43877303, 1.07366684], [-0.85795321, 0.82980738], [-
- 0.18430329, 0.52328473], [-0.30126957, -0.66268378], [-0.65571327,
- 0.42412021], [-0.28305528, 0.30284991], [0.20246714, -0.34727125], [
- 1.06446472, -1.09279772], [0.30543283, -0.02589502], [-0.00717161,
- 0.00318087]])
+X = np.array(
+ [
+ [-0.3879569, 0.6894251],
+ [-0.09322739, 1.28177189],
+ [-0.77740357, 0.74097941],
+ [0.91542919, -0.65453327],
+ [-0.03852113, 0.40910479],
+ [-0.43877303, 1.07366684],
+ [-0.85795321, 0.82980738],
+ [-0.18430329, 0.52328473],
+ [-0.30126957, -0.66268378],
+ [-0.65571327, 0.42412021],
+ [-0.28305528, 0.30284991],
+ [0.20246714, -0.34727125],
+ [1.06446472, -1.09279772],
+ [0.30543283, -0.02589502],
+ [-0.00717161, 0.00318087],
+ ]
+)
Y = np.array([0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0])
+def test_oss_init():
+ oss = OneSidedSelection(random_state=RND_SEED)
+
+ assert oss.n_seeds_S == 1
+ assert oss.n_jobs is None
+ assert oss.random_state == RND_SEED
+
+
+def test_oss_fit_resample():
+ oss = OneSidedSelection(random_state=RND_SEED)
+ X_resampled, y_resampled = oss.fit_resample(X, Y)
+
+ X_gt = np.array(
+ [
+ [-0.3879569, 0.6894251],
+ [0.91542919, -0.65453327],
+ [-0.65571327, 0.42412021],
+ [1.06446472, -1.09279772],
+ [0.30543283, -0.02589502],
+ [-0.00717161, 0.00318087],
+ [-0.09322739, 1.28177189],
+ [-0.77740357, 0.74097941],
+ [-0.43877303, 1.07366684],
+ [-0.85795321, 0.82980738],
+ [-0.30126957, -0.66268378],
+ [0.20246714, -0.34727125],
+ ]
+ )
+ y_gt = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
+ assert_array_equal(X_resampled, X_gt)
+ assert_array_equal(y_resampled, y_gt)
+
+
+@pytest.mark.parametrize("n_neighbors", [1, KNeighborsClassifier(n_neighbors=1)])
+def test_oss_with_object(n_neighbors):
+ oss = OneSidedSelection(random_state=RND_SEED, n_neighbors=n_neighbors)
+ X_resampled, y_resampled = oss.fit_resample(X, Y)
+
+ X_gt = np.array(
+ [
+ [-0.3879569, 0.6894251],
+ [0.91542919, -0.65453327],
+ [-0.65571327, 0.42412021],
+ [1.06446472, -1.09279772],
+ [0.30543283, -0.02589502],
+ [-0.00717161, 0.00318087],
+ [-0.09322739, 1.28177189],
+ [-0.77740357, 0.74097941],
+ [-0.43877303, 1.07366684],
+ [-0.85795321, 0.82980738],
+ [-0.30126957, -0.66268378],
+ [0.20246714, -0.34727125],
+ ]
+ )
+ y_gt = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
+ assert_array_equal(X_resampled, X_gt)
+ assert_array_equal(y_resampled, y_gt)
+ knn = 1
+ oss = OneSidedSelection(random_state=RND_SEED, n_neighbors=knn)
+ X_resampled, y_resampled = oss.fit_resample(X, Y)
+ assert_array_equal(X_resampled, X_gt)
+ assert_array_equal(y_resampled, y_gt)
+
+
def test_one_sided_selection_multiclass():
"""Check the validity of the fitted attributes `estimators_`."""
- pass
+ X, y = make_classification(
+ n_samples=1_000,
+ n_classes=4,
+ weights=[0.1, 0.2, 0.2, 0.5],
+ n_clusters_per_class=1,
+ random_state=0,
+ )
+ oss = OneSidedSelection(random_state=RND_SEED)
+ oss.fit_resample(X, y)
+
+ assert len(oss.estimators_) == len(oss.sampling_strategy_)
+ other_classes = []
+ for est in oss.estimators_:
+ assert est.classes_[0] == 0 # minority class
+ assert est.classes_[1] in {1, 2, 3} # other classes
+ other_classes.append(est.classes_[1])
+ assert len(set(other_classes)) == len(other_classes)
+# TODO: remove in 0.14
def test_one_sided_selection_deprecation():
"""Check that we raise a FutureWarning when accessing the parameter `estimator_`."""
- pass
+ oss = OneSidedSelection(random_state=RND_SEED)
+ oss.fit_resample(X, Y)
+ warn_msg = "`estimator_` attribute has been deprecated"
+ with pytest.warns(FutureWarning, match=warn_msg):
+ oss.estimator_
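
As the multiclass test above shows, `OneSidedSelection` fits one 1-NN classifier per (minority class, other class) pair and exposes them through `estimators_`, with `sampling_strategy_` holding the fitted per-class targets. A short usage sketch built on the same dataset parameters as the test:

```python
from collections import Counter

from sklearn.datasets import make_classification

from imblearn.under_sampling import OneSidedSelection

X, y = make_classification(
    n_samples=1_000,
    n_classes=4,
    weights=[0.1, 0.2, 0.2, 0.5],
    n_clusters_per_class=1,
    random_state=0,
)
oss = OneSidedSelection(random_state=0)
X_res, y_res = oss.fit_resample(X, y)

# one estimator per class targeted by the sampling strategy
print(len(oss.estimators_), len(oss.sampling_strategy_))
print(Counter(y_res))
```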
diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_random_under_sampler.py b/imblearn/under_sampling/_prototype_selection/tests/test_random_under_sampler.py
index 96745c6..f4e9279 100644
--- a/imblearn/under_sampling/_prototype_selection/tests/test_random_under_sampler.py
+++ b/imblearn/under_sampling/_prototype_selection/tests/test_random_under_sampler.py
@@ -1,30 +1,167 @@
"""Test the module random under sampler."""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
from collections import Counter
from datetime import datetime
+
import numpy as np
import pytest
from sklearn.datasets import make_classification
from sklearn.utils._testing import assert_array_equal
+
from imblearn.under_sampling import RandomUnderSampler
+
RND_SEED = 0
-X = np.array([[0.04352327, -0.20515826], [0.92923648, 0.76103773], [
- 0.20792588, 1.49407907], [0.47104475, 0.44386323], [0.22950086,
- 0.33367433], [0.15490546, 0.3130677], [0.09125309, -0.85409574], [
- 0.12372842, 0.6536186], [0.13347175, 0.12167502], [0.094035, -2.55298982]])
+X = np.array(
+ [
+ [0.04352327, -0.20515826],
+ [0.92923648, 0.76103773],
+ [0.20792588, 1.49407907],
+ [0.47104475, 0.44386323],
+ [0.22950086, 0.33367433],
+ [0.15490546, 0.3130677],
+ [0.09125309, -0.85409574],
+ [0.12372842, 0.6536186],
+ [0.13347175, 0.12167502],
+ [0.094035, -2.55298982],
+ ]
+)
Y = np.array([1, 0, 1, 0, 1, 1, 1, 1, 0, 1])
-@pytest.mark.parametrize('sampling_strategy', ['auto', 'majority',
- 'not minority', 'not majority', 'all'])
+@pytest.mark.parametrize("as_frame", [True, False], ids=["dataframe", "array"])
+def test_rus_fit_resample(as_frame):
+ if as_frame:
+ pd = pytest.importorskip("pandas")
+ X_ = pd.DataFrame(X)
+ else:
+ X_ = X
+ rus = RandomUnderSampler(random_state=RND_SEED, replacement=True)
+ X_resampled, y_resampled = rus.fit_resample(X_, Y)
+
+ X_gt = np.array(
+ [
+ [0.92923648, 0.76103773],
+ [0.47104475, 0.44386323],
+ [0.13347175, 0.12167502],
+ [0.09125309, -0.85409574],
+ [0.12372842, 0.6536186],
+ [0.04352327, -0.20515826],
+ ]
+ )
+ y_gt = np.array([0, 0, 0, 1, 1, 1])
+
+ if as_frame:
+ assert hasattr(X_resampled, "loc")
+ # FIXME: we should use to_numpy with pandas >= 0.25
+ X_resampled = X_resampled.values
+
+ assert_array_equal(X_resampled, X_gt)
+ assert_array_equal(y_resampled, y_gt)
+
+
+def test_rus_fit_resample_half():
+ sampling_strategy = {0: 3, 1: 6}
+ rus = RandomUnderSampler(
+ sampling_strategy=sampling_strategy,
+ random_state=RND_SEED,
+ replacement=True,
+ )
+ X_resampled, y_resampled = rus.fit_resample(X, Y)
+
+ X_gt = np.array(
+ [
+ [0.92923648, 0.76103773],
+ [0.47104475, 0.44386323],
+ [0.92923648, 0.76103773],
+ [0.15490546, 0.3130677],
+ [0.15490546, 0.3130677],
+ [0.15490546, 0.3130677],
+ [0.20792588, 1.49407907],
+ [0.15490546, 0.3130677],
+ [0.12372842, 0.6536186],
+ ]
+ )
+ y_gt = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1])
+ assert_array_equal(X_resampled, X_gt)
+ assert_array_equal(y_resampled, y_gt)
+
+
+def test_multiclass_fit_resample():
+ y = Y.copy()
+ y[5] = 2
+ y[6] = 2
+ rus = RandomUnderSampler(random_state=RND_SEED)
+ X_resampled, y_resampled = rus.fit_resample(X, y)
+ count_y_res = Counter(y_resampled)
+ assert count_y_res[0] == 2
+ assert count_y_res[1] == 2
+ assert count_y_res[2] == 2
+
+
+def test_random_under_sampling_heterogeneous_data():
+ X_hetero = np.array(
+ [["xxx", 1, 1.0], ["yyy", 2, 2.0], ["zzz", 3, 3.0]], dtype=object
+ )
+ y = np.array([0, 0, 1])
+ rus = RandomUnderSampler(random_state=RND_SEED)
+ X_res, y_res = rus.fit_resample(X_hetero, y)
+
+ assert X_res.shape[0] == 2
+ assert y_res.shape[0] == 2
+ assert X_res.dtype == object
+
+
+def test_random_under_sampling_nan_inf():
+ # check that we can undersample even with missing or infinite data
+ # regression tests for #605
+ rng = np.random.RandomState(42)
+ n_not_finite = X.shape[0] // 3
+ row_indices = rng.choice(np.arange(X.shape[0]), size=n_not_finite)
+ col_indices = rng.randint(0, X.shape[1], size=n_not_finite)
+ not_finite_values = rng.choice([np.nan, np.inf], size=n_not_finite)
+
+ X_ = X.copy()
+ X_[row_indices, col_indices] = not_finite_values
+
+ rus = RandomUnderSampler(random_state=0)
+ X_res, y_res = rus.fit_resample(X_, Y)
+
+ assert y_res.shape == (6,)
+ assert X_res.shape == (6, 2)
+ assert np.any(~np.isfinite(X_res))
+
+
+@pytest.mark.parametrize(
+ "sampling_strategy", ["auto", "majority", "not minority", "not majority", "all"]
+)
def test_random_under_sampler_strings(sampling_strategy):
"""Check that we support all supposed strings as `sampling_strategy` in
a sampler inheriting from `BaseUnderSampler`."""
- pass
+
+ X, y = make_classification(
+ n_samples=100,
+ n_clusters_per_class=1,
+ n_classes=3,
+ weights=[0.1, 0.3, 0.6],
+ random_state=0,
+ )
+ RandomUnderSampler(sampling_strategy=sampling_strategy).fit_resample(X, y)
def test_random_under_sampling_datetime():
"""Check that we don't convert input data and only sample from it."""
- pass
+ pd = pytest.importorskip("pandas")
+ X = pd.DataFrame({"label": [0, 0, 0, 1], "td": [datetime.now()] * 4})
+ y = X["label"]
+ rus = RandomUnderSampler(random_state=0)
+ X_res, y_res = rus.fit_resample(X, y)
+
+ pd.testing.assert_series_equal(X_res.dtypes, X.dtypes)
+ pd.testing.assert_index_equal(X_res.index, y_res.index)
+ assert_array_equal(y_res.to_numpy(), np.array([0, 1]))
def test_random_under_sampler_full_nat():
@@ -33,4 +170,18 @@ def test_random_under_sampler_full_nat():
Non-regression test for:
https://github.com/scikit-learn-contrib/imbalanced-learn/issues/1055
"""
- pass
+ pd = pytest.importorskip("pandas")
+
+ X = pd.DataFrame(
+ {
+ "col_str": ["abc", "def", "xyz"],
+ "col_timedelta": pd.to_timedelta([np.nan, np.nan, np.nan]),
+ }
+ )
+ y = np.array([0, 0, 1])
+
+ X_res, y_res = RandomUnderSampler().fit_resample(X, y)
+ assert X_res.shape == (2, 2)
+ assert y_res.shape == (2,)
+
+ assert X_res["col_timedelta"].dtype == "timedelta64[ns]"
diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_repeated_edited_nearest_neighbours.py b/imblearn/under_sampling/_prototype_selection/tests/test_repeated_edited_nearest_neighbours.py
index 92df66a..edd3a91 100644
--- a/imblearn/under_sampling/_prototype_selection/tests/test_repeated_edited_nearest_neighbours.py
+++ b/imblearn/under_sampling/_prototype_selection/tests/test_repeated_edited_nearest_neighbours.py
@@ -1,25 +1,338 @@
"""Test the module repeated edited nearest neighbour."""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import numpy as np
import pytest
from sklearn.neighbors import NearestNeighbors
from sklearn.utils._testing import assert_array_equal
+
from imblearn.under_sampling import RepeatedEditedNearestNeighbours
-X = np.array([[-0.12840393, 0.66446571], [1.32319756, -0.13181616], [
- 0.04296502, -0.37981873], [0.83631853, 0.18569783], [1.02956816,
- 0.36061601], [1.12202806, 0.33811558], [-0.53171468, -0.53735182], [
- 1.3381556, 0.35956356], [-0.35946678, 0.72510189], [1.32326943,
- 0.28393874], [2.94290565, -0.13986434], [0.28294738, -1.00125525], [
- 0.34218094, -0.58781961], [-0.88864036, -0.33782387], [-1.10146139,
- 0.91782682], [-0.7969716, -0.50493969], [0.73489726, 0.43915195], [
- 0.2096964, -0.61814058], [-0.28479268, 0.70459548], [1.84864913,
- 0.14729596], [1.59068979, -0.96622933], [0.73418199, -0.02222847], [
- 0.50307437, 0.498805], [0.84929742, 0.41042894], [0.62649535,
- 0.46600596], [0.79270821, -0.41386668], [1.16606871, -0.25641059], [
- 1.57356906, 0.30390519], [1.0304995, -0.16955962], [1.67314371,
- 0.19231498], [0.98382284, 0.37184502], [0.48921682, -1.38504507], [-
- 0.46226554, -0.50481004], [-0.03918551, -0.68540745], [0.24991051, -
- 1.00864997], [0.80541964, -0.34465185], [0.1732627, -1.61323172], [
- 0.69804044, 0.44810796], [-0.5506368, -0.42072426], [-0.34474418,
- 0.21969797]])
-Y = np.array([1, 2, 2, 2, 1, 1, 0, 2, 1, 1, 1, 2, 2, 0, 1, 2, 1, 2, 1, 1, 2,
- 2, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 0, 2, 2, 2, 2, 1, 2, 0])
+
+X = np.array(
+ [
+ [-0.12840393, 0.66446571],
+ [1.32319756, -0.13181616],
+ [0.04296502, -0.37981873],
+ [0.83631853, 0.18569783],
+ [1.02956816, 0.36061601],
+ [1.12202806, 0.33811558],
+ [-0.53171468, -0.53735182],
+ [1.3381556, 0.35956356],
+ [-0.35946678, 0.72510189],
+ [1.32326943, 0.28393874],
+ [2.94290565, -0.13986434],
+ [0.28294738, -1.00125525],
+ [0.34218094, -0.58781961],
+ [-0.88864036, -0.33782387],
+ [-1.10146139, 0.91782682],
+ [-0.7969716, -0.50493969],
+ [0.73489726, 0.43915195],
+ [0.2096964, -0.61814058],
+ [-0.28479268, 0.70459548],
+ [1.84864913, 0.14729596],
+ [1.59068979, -0.96622933],
+ [0.73418199, -0.02222847],
+ [0.50307437, 0.498805],
+ [0.84929742, 0.41042894],
+ [0.62649535, 0.46600596],
+ [0.79270821, -0.41386668],
+ [1.16606871, -0.25641059],
+ [1.57356906, 0.30390519],
+ [1.0304995, -0.16955962],
+ [1.67314371, 0.19231498],
+ [0.98382284, 0.37184502],
+ [0.48921682, -1.38504507],
+ [-0.46226554, -0.50481004],
+ [-0.03918551, -0.68540745],
+ [0.24991051, -1.00864997],
+ [0.80541964, -0.34465185],
+ [0.1732627, -1.61323172],
+ [0.69804044, 0.44810796],
+ [-0.5506368, -0.42072426],
+ [-0.34474418, 0.21969797],
+ ]
+)
+Y = np.array(
+ [
+ 1,
+ 2,
+ 2,
+ 2,
+ 1,
+ 1,
+ 0,
+ 2,
+ 1,
+ 1,
+ 1,
+ 2,
+ 2,
+ 0,
+ 1,
+ 2,
+ 1,
+ 2,
+ 1,
+ 1,
+ 2,
+ 2,
+ 1,
+ 1,
+ 1,
+ 2,
+ 2,
+ 2,
+ 2,
+ 1,
+ 1,
+ 2,
+ 0,
+ 2,
+ 2,
+ 2,
+ 2,
+ 1,
+ 2,
+ 0,
+ ]
+)
+
+
+def test_renn_init():
+ renn = RepeatedEditedNearestNeighbours()
+
+ assert renn.n_neighbors == 3
+ assert renn.kind_sel == "all"
+ assert renn.n_jobs is None
+
+
+def test_renn_iter_wrong():
+ max_iter = -1
+ renn = RepeatedEditedNearestNeighbours(max_iter=max_iter)
+ with pytest.raises(ValueError):
+ renn.fit_resample(X, Y)
+
+
+def test_renn_fit_resample():
+ renn = RepeatedEditedNearestNeighbours()
+ X_resampled, y_resampled = renn.fit_resample(X, Y)
+
+ X_gt = np.array(
+ [
+ [-0.53171468, -0.53735182],
+ [-0.88864036, -0.33782387],
+ [-0.46226554, -0.50481004],
+ [-0.34474418, 0.21969797],
+ [1.02956816, 0.36061601],
+ [1.12202806, 0.33811558],
+ [0.73489726, 0.43915195],
+ [0.50307437, 0.498805],
+ [0.84929742, 0.41042894],
+ [0.62649535, 0.46600596],
+ [0.98382284, 0.37184502],
+ [0.69804044, 0.44810796],
+ [0.04296502, -0.37981873],
+ [0.28294738, -1.00125525],
+ [0.34218094, -0.58781961],
+ [0.2096964, -0.61814058],
+ [1.59068979, -0.96622933],
+ [0.73418199, -0.02222847],
+ [0.79270821, -0.41386668],
+ [1.16606871, -0.25641059],
+ [1.0304995, -0.16955962],
+ [0.48921682, -1.38504507],
+ [-0.03918551, -0.68540745],
+ [0.24991051, -1.00864997],
+ [0.80541964, -0.34465185],
+ [0.1732627, -1.61323172],
+ ]
+ )
+ y_gt = np.array(
+ [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
+ )
+ assert_array_equal(X_resampled, X_gt)
+ assert_array_equal(y_resampled, y_gt)
+ assert 0 < renn.n_iter_ <= renn.max_iter
+
+
+def test_renn_fit_resample_mode_object():
+ renn = RepeatedEditedNearestNeighbours(kind_sel="mode")
+ X_resampled, y_resampled = renn.fit_resample(X, Y)
+
+ X_gt = np.array(
+ [
+ [-0.53171468, -0.53735182],
+ [-0.88864036, -0.33782387],
+ [-0.46226554, -0.50481004],
+ [-0.34474418, 0.21969797],
+ [-0.12840393, 0.66446571],
+ [1.02956816, 0.36061601],
+ [1.12202806, 0.33811558],
+ [-0.35946678, 0.72510189],
+ [2.94290565, -0.13986434],
+ [-1.10146139, 0.91782682],
+ [0.73489726, 0.43915195],
+ [-0.28479268, 0.70459548],
+ [1.84864913, 0.14729596],
+ [0.50307437, 0.498805],
+ [0.84929742, 0.41042894],
+ [0.62649535, 0.46600596],
+ [1.67314371, 0.19231498],
+ [0.98382284, 0.37184502],
+ [0.69804044, 0.44810796],
+ [1.32319756, -0.13181616],
+ [0.04296502, -0.37981873],
+ [0.28294738, -1.00125525],
+ [0.34218094, -0.58781961],
+ [0.2096964, -0.61814058],
+ [1.59068979, -0.96622933],
+ [0.73418199, -0.02222847],
+ [0.79270821, -0.41386668],
+ [1.16606871, -0.25641059],
+ [1.0304995, -0.16955962],
+ [0.48921682, -1.38504507],
+ [-0.03918551, -0.68540745],
+ [0.24991051, -1.00864997],
+ [0.80541964, -0.34465185],
+ [0.1732627, -1.61323172],
+ ]
+ )
+ y_gt = np.array(
+ [
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ ]
+ )
+ assert_array_equal(X_resampled, X_gt)
+ assert_array_equal(y_resampled, y_gt)
+ assert 0 < renn.n_iter_ <= renn.max_iter
+
+
+def test_renn_fit_resample_mode():
+ nn = NearestNeighbors(n_neighbors=4)
+ renn = RepeatedEditedNearestNeighbours(n_neighbors=nn, kind_sel="mode")
+ X_resampled, y_resampled = renn.fit_resample(X, Y)
+
+ X_gt = np.array(
+ [
+ [-0.53171468, -0.53735182],
+ [-0.88864036, -0.33782387],
+ [-0.46226554, -0.50481004],
+ [-0.34474418, 0.21969797],
+ [-0.12840393, 0.66446571],
+ [1.02956816, 0.36061601],
+ [1.12202806, 0.33811558],
+ [-0.35946678, 0.72510189],
+ [2.94290565, -0.13986434],
+ [-1.10146139, 0.91782682],
+ [0.73489726, 0.43915195],
+ [-0.28479268, 0.70459548],
+ [1.84864913, 0.14729596],
+ [0.50307437, 0.498805],
+ [0.84929742, 0.41042894],
+ [0.62649535, 0.46600596],
+ [1.67314371, 0.19231498],
+ [0.98382284, 0.37184502],
+ [0.69804044, 0.44810796],
+ [1.32319756, -0.13181616],
+ [0.04296502, -0.37981873],
+ [0.28294738, -1.00125525],
+ [0.34218094, -0.58781961],
+ [0.2096964, -0.61814058],
+ [1.59068979, -0.96622933],
+ [0.73418199, -0.02222847],
+ [0.79270821, -0.41386668],
+ [1.16606871, -0.25641059],
+ [1.0304995, -0.16955962],
+ [0.48921682, -1.38504507],
+ [-0.03918551, -0.68540745],
+ [0.24991051, -1.00864997],
+ [0.80541964, -0.34465185],
+ [0.1732627, -1.61323172],
+ ]
+ )
+ y_gt = np.array(
+ [
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ ]
+ )
+ assert_array_equal(X_resampled, X_gt)
+ assert_array_equal(y_resampled, y_gt)
+ assert 0 < renn.n_iter_ <= renn.max_iter
+
+
+@pytest.mark.parametrize(
+ "max_iter, n_iter",
+ [(2, 2), (5, 3)],
+)
+def test_renn_iter_attribute(max_iter, n_iter):
+ renn = RepeatedEditedNearestNeighbours(max_iter=max_iter)
+ renn.fit_resample(X, Y)
+ assert renn.n_iter_ == n_iter
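
`RepeatedEditedNearestNeighbours` reapplies ENN until an iteration removes nothing (or a stopping guard triggers), recording the actual number of passes in `n_iter_`, which the parametrized test above pins down. A sketch comparing a single ENN pass with the repeated variant (dataset parameters are illustrative):

```python
from sklearn.datasets import make_classification

from imblearn.under_sampling import (
    EditedNearestNeighbours,
    RepeatedEditedNearestNeighbours,
)

X, y = make_classification(
    n_samples=500,
    n_classes=3,
    n_informative=3,
    n_clusters_per_class=1,
    weights=[0.1, 0.3, 0.6],
    random_state=0,
)

enn = EditedNearestNeighbours()
renn = RepeatedEditedNearestNeighbours(max_iter=100)

n_removed_enn = len(y) - len(enn.fit_resample(X, y)[1])
n_removed_renn = len(y) - len(renn.fit_resample(X, y)[1])

# the first RENN pass is an ENN pass, so repeating can only remove more
print(renn.n_iter_, n_removed_enn, n_removed_renn)
assert n_removed_renn >= n_removed_enn
```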
diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_tomek_links.py b/imblearn/under_sampling/_prototype_selection/tests/test_tomek_links.py
index c1fd8e5..5fd8378 100644
--- a/imblearn/under_sampling/_prototype_selection/tests/test_tomek_links.py
+++ b/imblearn/under_sampling/_prototype_selection/tests/test_tomek_links.py
@@ -1,24 +1,89 @@
"""Test the module Tomek's links."""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import numpy as np
import pytest
from sklearn.datasets import make_classification
from sklearn.utils._testing import assert_array_equal
+
from imblearn.under_sampling import TomekLinks
-X = np.array([[0.31230513, 0.1216318], [0.68481731, 0.51935141], [
- 1.34192108, -0.13367336], [0.62366841, -0.21312976], [1.61091956, -
- 0.40283504], [-0.37162401, -2.19400981], [0.74680821, 1.63827342], [
- 0.2184254, 0.24299982], [0.61472253, -0.82309052], [0.19893132, -
- 0.47761769], [1.06514042, -0.0770537], [0.97407872, 0.44454207], [
- 1.40301027, -0.83648734], [-1.20515198, -1.02689695], [-0.27410027, -
- 0.54194484], [0.8381014, 0.44085498], [-0.23374509, 0.18370049], [-
- 0.32635887, -0.29299653], [-0.00288378, 0.84259929], [1.79580611, -
- 0.02219234]])
+
+X = np.array(
+ [
+ [0.31230513, 0.1216318],
+ [0.68481731, 0.51935141],
+ [1.34192108, -0.13367336],
+ [0.62366841, -0.21312976],
+ [1.61091956, -0.40283504],
+ [-0.37162401, -2.19400981],
+ [0.74680821, 1.63827342],
+ [0.2184254, 0.24299982],
+ [0.61472253, -0.82309052],
+ [0.19893132, -0.47761769],
+ [1.06514042, -0.0770537],
+ [0.97407872, 0.44454207],
+ [1.40301027, -0.83648734],
+ [-1.20515198, -1.02689695],
+ [-0.27410027, -0.54194484],
+ [0.8381014, 0.44085498],
+ [-0.23374509, 0.18370049],
+ [-0.32635887, -0.29299653],
+ [-0.00288378, 0.84259929],
+ [1.79580611, -0.02219234],
+ ]
+)
Y = np.array([1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0])
-@pytest.mark.parametrize('sampling_strategy', ['auto', 'majority',
- 'not minority', 'not majority', 'all'])
+def test_tl_init():
+ tl = TomekLinks()
+ assert tl.n_jobs is None
+
+
+def test_tl_fit_resample():
+ tl = TomekLinks()
+ X_resampled, y_resampled = tl.fit_resample(X, Y)
+
+ X_gt = np.array(
+ [
+ [0.31230513, 0.1216318],
+ [0.68481731, 0.51935141],
+ [1.34192108, -0.13367336],
+ [0.62366841, -0.21312976],
+ [1.61091956, -0.40283504],
+ [-0.37162401, -2.19400981],
+ [0.74680821, 1.63827342],
+ [0.2184254, 0.24299982],
+ [0.61472253, -0.82309052],
+ [0.19893132, -0.47761769],
+ [0.97407872, 0.44454207],
+ [1.40301027, -0.83648734],
+ [-1.20515198, -1.02689695],
+ [-0.23374509, 0.18370049],
+ [-0.32635887, -0.29299653],
+ [-0.00288378, 0.84259929],
+ [1.79580611, -0.02219234],
+ ]
+ )
+ y_gt = np.array([1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0])
+ assert_array_equal(X_resampled, X_gt)
+ assert_array_equal(y_resampled, y_gt)
+
+
+@pytest.mark.parametrize(
+ "sampling_strategy", ["auto", "majority", "not minority", "not majority", "all"]
+)
def test_tomek_links_strings(sampling_strategy):
"""Check that we support all supposed strings as `sampling_strategy` in
a sampler inheriting from `BaseCleaningSampler`."""
- pass
+
+ X, y = make_classification(
+ n_samples=100,
+ n_clusters_per_class=1,
+ n_classes=3,
+ weights=[0.1, 0.3, 0.6],
+ random_state=0,
+ )
+ TomekLinks(sampling_strategy=sampling_strategy).fit_resample(X, y)
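
A Tomek link is a pair of mutually nearest neighbours carrying different labels; with the default `sampling_strategy='auto'` only the non-minority member of each link is dropped, which is what the ground-truth arrays above encode. A tiny 1-D sketch where the link is easy to see:

```python
import numpy as np

from imblearn.under_sampling import TomekLinks

# 1.0 (class 0) and 1.1 (class 1) are mutual nearest neighbours of
# different classes, i.e. a Tomek link across the boundary.
X = np.array([[0.0], [1.0], [1.1], [2.0], [3.0], [3.1]])
y = np.array([0, 0, 1, 1, 1, 1])

tl = TomekLinks()  # 'auto': remove only from the non-minority classes
X_res, y_res = tl.fit_resample(X, y)

# only the majority member of the link (1.1) is expected to be dropped
print(X_res.ravel(), y_res)
print(tl.sample_indices_)
```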
diff --git a/imblearn/under_sampling/base.py b/imblearn/under_sampling/base.py
index f502ba5..92da457 100644
--- a/imblearn/under_sampling/base.py
+++ b/imblearn/under_sampling/base.py
@@ -1,8 +1,12 @@
"""
Base class for the under-sampling method.
"""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# License: MIT
+
import numbers
from collections.abc import Mapping
+
from ..base import BaseSampler
from ..utils._param_validation import Interval, StrOptions
@@ -13,9 +17,10 @@ class BaseUnderSampler(BaseSampler):
     Warning: This class should not be used directly. Use the derived classes
instead.
"""
- _sampling_type = 'under-sampling'
- _sampling_strategy_docstring = (
- """sampling_strategy : float, str, dict, callable, default='auto'
+
+ _sampling_type = "under-sampling"
+
+ _sampling_strategy_docstring = """sampling_strategy : float, str, dict, callable, default='auto'
Sampling information to sample the data set.
- When ``float``, it corresponds to the desired ratio of the number of
@@ -51,11 +56,16 @@ class BaseUnderSampler(BaseSampler):
- When callable, function taking ``y`` and returns a ``dict``. The keys
correspond to the targeted classes. The values correspond to the
desired number of samples for each class.
- """
- .rstrip())
- _parameter_constraints: dict = {'sampling_strategy': [Interval(numbers.
- Real, 0, 1, closed='right'), StrOptions({'auto', 'majority',
- 'not minority', 'not majority', 'all'}), Mapping, callable]}
+ """.rstrip() # noqa: E501
+
+ _parameter_constraints: dict = {
+ "sampling_strategy": [
+ Interval(numbers.Real, 0, 1, closed="right"),
+ StrOptions({"auto", "majority", "not minority", "not majority", "all"}),
+ Mapping,
+ callable,
+ ],
+ }
class BaseCleaningSampler(BaseSampler):
@@ -64,9 +74,10 @@ class BaseCleaningSampler(BaseSampler):
     Warning: This class should not be used directly. Use the derived classes
instead.
"""
- _sampling_type = 'clean-sampling'
- _sampling_strategy_docstring = (
- """sampling_strategy : str, list or callable
+
+ _sampling_type = "clean-sampling"
+
+ _sampling_strategy_docstring = """sampling_strategy : str, list or callable
Sampling information to sample the data set.
- When ``str``, specify the class targeted by the resampling. Note the
@@ -89,8 +100,13 @@ class BaseCleaningSampler(BaseSampler):
- When callable, function taking ``y`` and returns a ``dict``. The keys
correspond to the targeted classes. The values correspond to the
desired number of samples for each class.
- """
- .rstrip())
- _parameter_constraints: dict = {'sampling_strategy': [Interval(numbers.
- Real, 0, 1, closed='right'), StrOptions({'auto', 'majority',
- 'not minority', 'not majority', 'all'}), list, callable]}
+ """.rstrip()
+
+ _parameter_constraints: dict = {
+ "sampling_strategy": [
+ Interval(numbers.Real, 0, 1, closed="right"),
+ StrOptions({"auto", "majority", "not minority", "not majority", "all"}),
+ list,
+ callable,
+ ],
+ }
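
The `_parameter_constraints` above spell out everything `sampling_strategy` may be for an under-sampler: a float in (0, 1], one of the five strings, a mapping, or a callable. A short sketch of the mapping and callable forms (counts here are illustrative):

```python
from collections import Counter

from sklearn.datasets import make_classification

from imblearn.under_sampling import RandomUnderSampler

X, y = make_classification(
    n_samples=300,
    n_classes=3,
    n_informative=3,
    n_clusters_per_class=1,
    weights=[0.1, 0.3, 0.6],
    random_state=0,
)
print(Counter(y))

# mapping: desired number of samples for each targeted class
rus = RandomUnderSampler(sampling_strategy={2: 50}, random_state=0)
print(Counter(rus.fit_resample(X, y)[1]))


# callable: receives y and must return such a mapping
def balance_to_minority(y):
    counts = Counter(y)
    return {klass: min(counts.values()) for klass in counts}


rus = RandomUnderSampler(sampling_strategy=balance_to_minority, random_state=0)
print(Counter(rus.fit_resample(X, y)[1]))
```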
diff --git a/imblearn/utils/_available_if.py b/imblearn/utils/_available_if.py
index 51d5fc6..bca75e7 100644
--- a/imblearn/utils/_available_if.py
+++ b/imblearn/utils/_available_if.py
@@ -1,13 +1,17 @@
"""This is a copy of sklearn/utils/_available_if.py. It can be removed when
we support scikit-learn >= 1.1.
"""
+# mypy: ignore-errors
+
from functools import update_wrapper, wraps
from types import MethodType
+
import sklearn
from sklearn.utils.fixes import parse_version
+
sklearn_version = parse_version(sklearn.__version__)
-if sklearn_version < parse_version('1.1'):
+if sklearn_version < parse_version("1.1"):
class _AvailableIfDescriptor:
"""Implements a conditional property using the descriptor protocol.
@@ -24,23 +28,30 @@ if sklearn_version < parse_version('1.1'):
self.fn = fn
self.check = check
self.attribute_name = attribute_name
+
+ # update the docstring of the descriptor
update_wrapper(self, fn)
def __get__(self, obj, owner=None):
attr_err = AttributeError(
- f'This {owner.__name__!r} has no attribute {self.attribute_name!r}'
- )
+ f"This {owner.__name__!r} has no attribute {self.attribute_name!r}"
+ )
if obj is not None:
+ # delegate only on instances, not the classes.
+ # this is to allow access to the docstrings.
if not self.check(obj):
raise attr_err
out = MethodType(self.fn, obj)
- else:
+ else:
+ # This makes it possible to use the decorated method as an
+ # unbound method, for instance when monkeypatching.
@wraps(self.fn)
def out(*args, **kwargs):
if not self.check(args[0]):
raise attr_err
return self.fn(*args, **kwargs)
+
return out
def available_if(check):
@@ -82,6 +93,7 @@ if sklearn_version < parse_version('1.1'):
>>> obj.say_hello()
Hello
"""
- pass
+ return lambda fn: _AvailableIfDescriptor(fn, check, attribute_name=fn.__name__)
+
else:
- from sklearn.utils.metaestimators import available_if
+ from sklearn.utils.metaestimators import available_if # noqa
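
The backported descriptor above makes a method appear on an instance only when `check(instance)` is truthy and raise `AttributeError` otherwise, so `hasattr` keeps working. A minimal sketch of the pattern (the `Wrapper` class is hypothetical, defined only for illustration):

```python
from imblearn.utils._available_if import available_if


class Wrapper:
    """Expose ``predict_proba`` only when the wrapped estimator has one."""

    def __init__(self, estimator):
        self.estimator = estimator

    def _estimator_has_proba(self):
        return hasattr(self.estimator, "predict_proba")

    @available_if(_estimator_has_proba)
    def predict_proba(self, X):
        return self.estimator.predict_proba(X)


class NoProba:
    pass


# the check fails, so the attribute is reported as missing
print(hasattr(Wrapper(NoProba()), "predict_proba"))  # False
```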
diff --git a/imblearn/utils/_docstring.py b/imblearn/utils/_docstring.py
index b678c3d..61921c3 100644
--- a/imblearn/utils/_docstring.py
+++ b/imblearn/utils/_docstring.py
@@ -1,5 +1,8 @@
"""Utilities for docstring in imbalanced-learn."""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# License: MIT
+
class Substitution:
"""Decorate a function's or a class' docstring to perform string
@@ -11,7 +14,8 @@ class Substitution:
def __init__(self, *args, **kwargs):
if args and kwargs:
- raise AssertionError('Only positional or keyword args are allowed')
+ raise AssertionError("Only positional or keyword args are allowed")
+
self.params = args or kwargs
def __call__(self, obj):
@@ -20,8 +24,7 @@ class Substitution:
return obj
-_random_state_docstring = (
- """random_state : int, RandomState instance, default=None
+_random_state_docstring = """random_state : int, RandomState instance, default=None
Control the randomization of the algorithm.
- If int, ``random_state`` is the seed used by the random number
@@ -30,14 +33,12 @@ _random_state_docstring = (
generator;
- If ``None``, the random number generator is the ``RandomState``
instance used by ``np.random``.
- """
- .rstrip())
-_n_jobs_docstring = (
- """n_jobs : int, default=None
+ """.rstrip()
+
+_n_jobs_docstring = """n_jobs : int, default=None
Number of CPU cores used during the cross-validation loop.
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
``-1`` means using all processors. See
`Glossary <https://scikit-learn.org/stable/glossary.html#term-n-jobs>`_
for more details.
- """
- .rstrip())
+ """.rstrip()
diff --git a/imblearn/utils/_metadata_requests.py b/imblearn/utils/_metadata_requests.py
index aa6e024..c81aa4f 100644
--- a/imblearn/utils/_metadata_requests.py
+++ b/imblearn/utils/_metadata_requests.py
@@ -76,21 +76,49 @@ of the ``RequestMethod`` descriptor to classes, which is done in the
This mixin also implements the ``get_metadata_routing``, which meta-estimators
need to override, but it works for simple consumers as is.
"""
+
+# Author: Adrin Jalali <adrin.jalali@gmail.com>
+# License: BSD 3 clause
+
import inspect
from collections import namedtuple
from copy import deepcopy
from typing import TYPE_CHECKING, Optional, Union
from warnings import warn
+
from sklearn import __version__, get_config
from sklearn.utils import Bunch
from sklearn.utils.fixes import parse_version
+
sklearn_version = parse_version(__version__)
-if parse_version(sklearn_version.base_version) < parse_version('1.4'):
- SIMPLE_METHODS = ['fit', 'partial_fit', 'predict', 'predict_proba',
- 'predict_log_proba', 'decision_function', 'score', 'split',
- 'transform', 'inverse_transform']
- COMPOSITE_METHODS = {'fit_transform': ['fit', 'transform'],
- 'fit_predict': ['fit', 'predict']}
+
+if parse_version(sklearn_version.base_version) < parse_version("1.4"):
+ # Only the following methods are supported in the routing mechanism. Adding new
+ # methods at the moment involves monkeypatching this list.
+ # Note that if this list is changed or monkeypatched, the corresponding method
+ # needs to be added under a TYPE_CHECKING condition like the one done here in
+ # _MetadataRequester
+ SIMPLE_METHODS = [
+ "fit",
+ "partial_fit",
+ "predict",
+ "predict_proba",
+ "predict_log_proba",
+ "decision_function",
+ "score",
+ "split",
+ "transform",
+ "inverse_transform",
+ ]
+
+ # These methods are a composite of other methods and one cannot set their
+ # requests directly. Instead they should be set by setting the requests of the
+ # simple methods which make the composite ones.
+ COMPOSITE_METHODS = {
+ "fit_transform": ["fit", "transform"],
+ "fit_predict": ["fit", "predict"],
+ }
+
METHODS = SIMPLE_METHODS + list(COMPOSITE_METHODS.keys())
def _routing_enabled():
@@ -104,7 +132,7 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
Whether metadata routing is enabled. If the config is not set, it
defaults to False.
"""
- pass
+ return get_config().get("enable_metadata_routing", False)
def _raise_for_params(params, owner, method):
"""Raise an error if metadata routing is not enabled and params are passed.
@@ -127,7 +155,19 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
ValueError
If metadata routing is not enabled and params are passed.
"""
- pass
+ caller = (
+ f"{owner.__class__.__name__}.{method}"
+ if method
+ else owner.__class__.__name__
+ )
+ if not _routing_enabled() and params:
+ raise ValueError(
+ f"Passing extra keyword arguments to {caller} is only supported if"
+ " enable_metadata_routing=True, which you can set using"
+ " `sklearn.set_config`. See the User Guide"
+ " <https://scikit-learn.org/stable/metadata_routing.html> for more"
+ f" details. Extra parameters passed are: {set(params)}"
+ )
def _raise_for_unsupported_routing(obj, method, **kwargs):
"""Raise when metadata routing is enabled and metadata is passed.
@@ -149,8 +189,14 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
**kwargs : dict
The metadata passed to the method.
"""
- pass
-
+ kwargs = {key: value for key, value in kwargs.items() if value is not None}
+ if _routing_enabled() and kwargs:
+ cls_name = obj.__class__.__name__
+ raise NotImplementedError(
+ f"{cls_name}.{method} cannot accept given metadata "
+ f"({set(kwargs.keys())}) since metadata routing is not yet implemented "
+ f"for {cls_name}."
+ )
class _RoutingNotSupportedMixin:
"""A mixin to be used to remove the default `get_metadata_routing`.
@@ -166,10 +212,29 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
"""Raise `NotImplementedError`.
This estimator does not support metadata routing yet."""
- pass
- UNUSED = '$UNUSED$'
- WARN = '$WARN$'
- UNCHANGED = '$UNCHANGED$'
+ raise NotImplementedError(
+ f"{self.__class__.__name__} has not implemented metadata routing yet."
+ )
+
+ # Request values
+ # ==============
+ # Each request value needs to be one of the following values, or an alias.
+
+ # this is used in `__metadata_request__*` attributes to indicate that a
+ # metadata is not present even though it may be present in the
+ # corresponding method's signature.
+ UNUSED = "$UNUSED$"
+
+ # this is used whenever a default value is changed, and therefore the user
+ # should explicitly set the value, otherwise a warning is shown. An example
+ # is when a meta-estimator is only a router, but then becomes also a
+ # consumer in a new release.
+ WARN = "$WARN$"
+
+ # this is the default used in `set_{method}_request` methods to indicate no
+ # change requested by the user.
+ UNCHANGED = "$UNCHANGED$"
+
VALID_REQUEST_VALUES = [False, True, None, UNUSED, WARN]
def request_is_alias(item):
@@ -188,7 +253,11 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
result : bool
Whether the given item is a valid alias.
"""
- pass
+ if item in VALID_REQUEST_VALUES:
+ return False
+
+ # item is only an alias if it's a valid identifier
+ return isinstance(item, str) and item.isidentifier()
def request_is_valid(item):
"""Check if an item is a valid request value (and not an alias).
@@ -203,8 +272,12 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
result : bool
Whether the given item is valid.
"""
- pass
+ return item in VALID_REQUEST_VALUES
+ # Metadata Request for Simple Consumers
+ # =====================================
+ # This section includes MethodMetadataRequest and MetadataRequest which are
+ # used in simple consumers.
class MethodMetadataRequest:
"""A prescription of how metadata is to be passed to a single method.
@@ -233,9 +306,14 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
@property
def requests(self):
"""Dictionary of the form: ``{key: alias}``."""
- pass
-
- def add_request(self, *, param, alias):
+ return self._requests
+
+ def add_request(
+ self,
+ *,
+ param,
+ alias,
+ ):
"""Add request info for a metadata.
Parameters
@@ -255,7 +333,28 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
- None: error if passed
"""
- pass
+ if not request_is_alias(alias) and not request_is_valid(alias):
+ raise ValueError(
+ f"The alias you're setting for `{param}` should be either a "
+ "valid identifier or one of {None, True, False}, but given "
+ f"value is: `{alias}`"
+ )
+
+ if alias == param:
+ alias = True
+
+ if alias == UNUSED:
+ if param in self._requests:
+ del self._requests[param]
+ else:
+ raise ValueError(
+ f"Trying to remove parameter {param} with UNUSED which doesn't"
+ " exist."
+ )
+ else:
+ self._requests[param] = alias
+
+ return self
def _get_param_names(self, return_alias):
"""Get names of all metadata that can be consumed or routed by this method.
@@ -274,7 +373,11 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
names : set of str
A set of strings with the names of all parameters.
"""
- pass
+ return set(
+ alias if return_alias and not request_is_valid(alias) else prop
+ for prop, alias in self._requests.items()
+ if not request_is_valid(alias) or alias is not False
+ )
def _check_warnings(self, *, params):
"""Check whether metadata is passed which is marked as WARN.
@@ -286,7 +389,19 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
params : dict
The metadata passed to a method.
"""
- pass
+ params = {} if params is None else params
+ warn_params = {
+ prop
+ for prop, alias in self._requests.items()
+ if alias == WARN and prop in params
+ }
+ for param in warn_params:
+ warn(
+ f"Support for {param} has recently been added to this class. "
+ "To maintain backward compatibility, it is ignored now. "
+ "You can set the request value to False to silence this "
+ "warning, or to True to consume and use the metadata."
+ )
def _route_params(self, params):
"""Prepare the given parameters to be passed to the method.
@@ -305,7 +420,30 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
A :class:`~sklearn.utils.Bunch` of {prop: value} which can be given to
the corresponding method.
"""
- pass
+ self._check_warnings(params=params)
+ unrequested = dict()
+ args = {arg: value for arg, value in params.items() if value is not None}
+ res = Bunch()
+ for prop, alias in self._requests.items():
+ if alias is False or alias == WARN:
+ continue
+ elif alias is True and prop in args:
+ res[prop] = args[prop]
+ elif alias is None and prop in args:
+ unrequested[prop] = args[prop]
+ elif alias in args:
+ res[prop] = args[alias]
+ if unrequested:
+ raise UnsetMetadataPassedError(
+ message=(
+ f"[{', '.join([key for key in unrequested])}] are passed but "
+ "are not explicitly set as requested or not for"
+ f" {self.owner}.{self.method}"
+ ),
+ unrequested_params=unrequested,
+ routed_params=res,
+ )
+ return res
def _consumes(self, params):
"""Check whether the given parameters are consumed by this method.
@@ -320,7 +458,14 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
consumed : set of str
A set of parameters which are consumed by this method.
"""
- pass
+ params = set(params)
+ res = set()
+ for prop, alias in self._requests.items():
+ if alias is True and prop in params:
+ res.add(prop)
+ elif isinstance(alias, str) and alias in params:
+ res.add(alias)
+ return res
def _serialize(self):
"""Serialize the object.
@@ -330,7 +475,7 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
obj : dict
A serialized version of the instance in the form of a dictionary.
"""
- pass
+ return self._requests
def __repr__(self):
return str(self._serialize())
@@ -338,7 +483,6 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
def __str__(self):
return str(repr(self))
-
class MetadataRequest:
"""Contains the metadata request info of a consumer.
@@ -355,13 +499,20 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
owner : str
The name of the object to which these requests belong.
"""
- _type = 'metadata_request'
+
+ # this is here for us to use this attribute's value instead of doing
+ # `isinstance` in our checks, so that we avoid issues when people vendor
+ # this file instead of using it directly from scikit-learn.
+ _type = "metadata_request"
def __init__(self, owner):
self.owner = owner
for method in SIMPLE_METHODS:
- setattr(self, method, MethodMetadataRequest(owner=owner,
- method=method))
+ setattr(
+ self,
+ method,
+ MethodMetadataRequest(owner=owner, method=method),
+ )
def consumes(self, method, params):
"""Check whether the given parameters are consumed by the given method.
@@ -381,32 +532,45 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
consumed : set of str
A set of parameters which are consumed by the given method.
"""
- pass
+ return getattr(self, method)._consumes(params=params)
def __getattr__(self, name):
+ # Called when the default attribute access fails with an AttributeError
+ # (either __getattribute__() raises an AttributeError because name is
+ # not an instance attribute or an attribute in the class tree for self;
+ # or __get__() of a name property raises AttributeError). This method
+ # should either return the (computed) attribute value or raise an
+ # AttributeError exception.
+ # https://docs.python.org/3/reference/datamodel.html#object.__getattr__
if name not in COMPOSITE_METHODS:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
- )
+ )
+
requests = {}
for method in COMPOSITE_METHODS[name]:
mmr = getattr(self, method)
existing = set(requests.keys())
upcoming = set(mmr.requests.keys())
common = existing & upcoming
- conflicts = [key for key in common if requests[key] != mmr.
- _requests[key]]
+ conflicts = [
+ key for key in common if requests[key] != mmr._requests[key]
+ ]
if conflicts:
raise ValueError(
- f"Conflicting metadata requests for {', '.join(conflicts)} while composing the requests for {name}. Metadata with the same name for methods {', '.join(COMPOSITE_METHODS[name])} should have the same request value."
- )
+ f"Conflicting metadata requests for {', '.join(conflicts)} "
+ f"while composing the requests for {name}. Metadata with the "
+ f"same name for methods {', '.join(COMPOSITE_METHODS[name])} "
+ "should have the same request value."
+ )
requests.update(mmr._requests)
- return MethodMetadataRequest(owner=self.owner, method=name,
- requests=requests)
+ return MethodMetadataRequest(
+ owner=self.owner, method=name, requests=requests
+ )
- def _get_param_names(self, method, return_alias,
- ignore_self_request=None):
- """Get names of all metadata that can be consumed or routed by specified method.
+ def _get_param_names(self, method, return_alias, ignore_self_request=None):
+ """Get names of all metadata that can be consumed or routed by specified \
+ method.
This method returns the names of all metadata, even the ``False``
ones.
@@ -428,7 +592,7 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
names : set of str
A set of strings with the names of all parameters.
"""
- pass
+ return getattr(self, method)._get_param_names(return_alias=return_alias)
def _route_params(self, *, method, params):
"""Prepare the given parameters to be passed to the method.
@@ -451,7 +615,7 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
A :class:`~sklearn.utils.Bunch` of {prop: value} which can be given to
the corresponding method.
"""
- pass
+ return getattr(self, method)._route_params(params=params)
def _check_warnings(self, *, method, params):
"""Check whether metadata is passed which is marked as WARN.
@@ -466,7 +630,7 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
params : dict
The metadata passed to a method.
"""
- pass
+ getattr(self, method)._check_warnings(params=params)
def _serialize(self):
"""Serialize the object.
@@ -476,16 +640,32 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
obj : dict
A serialized version of the instance in the form of a dictionary.
"""
- pass
+ output = dict()
+ for method in SIMPLE_METHODS:
+ mmr = getattr(self, method)
+ if len(mmr.requests):
+ output[method] = mmr._serialize()
+ return output
def __repr__(self):
return str(self._serialize())
def __str__(self):
return str(repr(self))
- RouterMappingPair = namedtuple('RouterMappingPair', ['mapping', 'router'])
- MethodPair = namedtuple('MethodPair', ['callee', 'caller'])
+ # Metadata Request for Routers
+ # ============================
+ # This section includes all objects required for MetadataRouter which is used
+ # in routers, returned by their ``get_metadata_routing``.
+
+ # This namedtuple is used to store a (mapping, routing) pair. Mapping is a
+ # MethodMapping object, and routing is the output of `get_metadata_routing`.
+ # MetadataRouter stores a collection of these namedtuples.
+ RouterMappingPair = namedtuple("RouterMappingPair", ["mapping", "router"])
+
+ # A namedtuple storing a single method route. A collection of these namedtuples
+ # is stored in a MetadataRouter.
+ MethodPair = namedtuple("MethodPair", ["callee", "caller"])
class MethodMapping:
"""Stores the mapping between callee and caller methods for a router.
@@ -523,7 +703,18 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
self : MethodMapping
Returns self.
"""
- pass
+ if callee not in METHODS:
+ raise ValueError(
+ f"Given callee:{callee} is not a valid method. Valid methods are:"
+ f" {METHODS}"
+ )
+ if caller not in METHODS:
+ raise ValueError(
+ f"Given caller:{caller} is not a valid method. Valid methods are:"
+ f" {METHODS}"
+ )
+ self._routes.append(MethodPair(callee=callee, caller=caller))
+ return self
def _serialize(self):
"""Serialize the object.
@@ -533,7 +724,10 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
obj : list
A serialized version of the instance in the form of a list.
"""
- pass
+ result = list()
+ for route in self._routes:
+ result.append({"callee": route.callee, "caller": route.caller})
+ return result
@classmethod
def from_str(cls, route):
@@ -554,7 +748,15 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
A :class:`~sklearn.utils.metadata_routing.MethodMapping` instance
constructed from the given string.
"""
- pass
+ routing = cls()
+ if route == "one-to-one":
+ for method in METHODS:
+ routing.add(callee=method, caller=method)
+ elif route in METHODS:
+ routing.add(callee=route, caller=route)
+ else:
+ raise ValueError("route should be 'one-to-one' or a single method!")
+ return routing
def __repr__(self):
return str(self._serialize())
@@ -562,7 +764,6 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
def __str__(self):
return str(repr(self))
-
class MetadataRouter:
"""Stores and handles metadata routing for a router object.
@@ -581,10 +782,18 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
owner : str
The name of the object to which these requests belong.
"""
- _type = 'metadata_router'
+
+ # this is here for us to use this attribute's value instead of doing
+    # `isinstance` in our checks, so that we avoid issues when people vendor
+ # this file instead of using it directly from scikit-learn.
+ _type = "metadata_router"
def __init__(self, owner):
self._route_mappings = dict()
+ # `_self_request` is used if the router is also a consumer.
+ # _self_request, (added using `add_self_request()`) is treated
+ # differently from the other objects which are stored in
+ # _route_mappings.
self._self_request = None
self.owner = owner
@@ -612,7 +821,17 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
self : MetadataRouter
Returns `self`.
"""
- pass
+ if getattr(obj, "_type", None) == "metadata_request":
+ self._self_request = deepcopy(obj)
+ elif hasattr(obj, "_get_metadata_request"):
+ self._self_request = deepcopy(obj._get_metadata_request())
+ else:
+ raise ValueError(
+ "Given `obj` is neither a `MetadataRequest` nor does it implement "
+ "the required API. Inheriting from `BaseEstimator` implements the "
+ "required API."
+ )
+ return self
def add(self, *, method_mapping, **objs):
"""Add named objects with their corresponding method mapping.
@@ -633,7 +852,16 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
self : MetadataRouter
Returns `self`.
"""
- pass
+ if isinstance(method_mapping, str):
+ method_mapping = MethodMapping.from_str(method_mapping)
+ else:
+ method_mapping = deepcopy(method_mapping)
+
+ for name, obj in objs.items():
+ self._route_mappings[name] = RouterMappingPair(
+ mapping=method_mapping, router=get_routing_for_object(obj)
+ )
+ return self
def consumes(self, method, params):
"""Check whether the given parameters are consumed by the given method.
@@ -653,11 +881,22 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
consumed : set of str
A set of parameters which are consumed by the given method.
"""
- pass
+ res = set()
+ if self._self_request:
+ res = res | self._self_request.consumes(method=method, params=params)
+
+ for _, route_mapping in self._route_mappings.items():
+ for callee, caller in route_mapping.mapping:
+ if caller == method:
+ res = res | route_mapping.router.consumes(
+ method=callee, params=params
+ )
- def _get_param_names(self, *, method, return_alias, ignore_self_request
- ):
- """Get names of all metadata that can be consumed or routed by specified method.
+ return res
+
+ def _get_param_names(self, *, method, return_alias, ignore_self_request):
+ """Get names of all metadata that can be consumed or routed by specified \
+ method.
This method returns the names of all metadata, even the ``False``
ones.
@@ -681,7 +920,25 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
names : set of str
A set of strings with the names of all parameters.
"""
- pass
+ res = set()
+ if self._self_request and not ignore_self_request:
+ res = res.union(
+ self._self_request._get_param_names(
+ method=method, return_alias=return_alias
+ )
+ )
+
+ for name, route_mapping in self._route_mappings.items():
+ for callee, caller in route_mapping.mapping:
+ if caller == method:
+ res = res.union(
+ route_mapping.router._get_param_names(
+ method=callee,
+ return_alias=True,
+ ignore_self_request=False,
+ )
+ )
+ return res
def _route_params(self, *, params, method):
"""Prepare the given parameters to be passed to the method.
@@ -708,7 +965,31 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
A :class:`~sklearn.utils.Bunch` of {prop: value} which can be given to
the corresponding method.
"""
- pass
+ res = Bunch()
+ if self._self_request:
+ res.update(
+ self._self_request._route_params(params=params, method=method)
+ )
+
+ param_names = self._get_param_names(
+ method=method, return_alias=True, ignore_self_request=True
+ )
+ child_params = {
+ key: value for key, value in params.items() if key in param_names
+ }
+ for key in set(res.keys()).intersection(child_params.keys()):
+ # conflicts are okay if the passed objects are the same, but it's
+ # an issue if they're different objects.
+ if child_params[key] is not res[key]:
+ raise ValueError(
+ f"In {self.owner}, there is a conflict on {key} between what is"
+ " requested for this estimator and what is requested by its"
+ " children. You can resolve this conflict by using an alias for"
+ " the child estimator(s) requested metadata."
+ )
+
+ res.update(child_params)
+ return res
def route_params(self, *, caller, params):
"""Return the input parameters requested by child objects.
@@ -738,7 +1019,20 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
used to pass the required metadata to corresponding methods or
corresponding child objects.
"""
- pass
+ if self._self_request:
+ self._self_request._check_warnings(params=params, method=caller)
+
+ res = Bunch()
+ for name, route_mapping in self._route_mappings.items():
+ router, mapping = route_mapping.router, route_mapping.mapping
+
+ res[name] = Bunch()
+ for _callee, _caller in mapping:
+ if _caller == caller:
+ res[name][_callee] = router._route_params(
+ params=params, method=_callee
+ )
+ return res
def validate_metadata(self, *, method, params):
"""Validate given metadata for a method.
@@ -756,7 +1050,21 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
params : dict
A dictionary of provided metadata.
"""
- pass
+ param_names = self._get_param_names(
+ method=method, return_alias=False, ignore_self_request=False
+ )
+ if self._self_request:
+ self_params = self._self_request._get_param_names(
+ method=method, return_alias=False
+ )
+ else:
+ self_params = set()
+ extra_keys = set(params.keys()) - param_names - self_params
+ if extra_keys:
+ raise TypeError(
+ f"{self.owner}.{method} got unexpected argument(s) {extra_keys}, "
+ "which are not requested metadata in any object."
+ )
def _serialize(self):
"""Serialize the object.
@@ -766,15 +1074,27 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
obj : dict
A serialized version of the instance in the form of a dictionary.
"""
- pass
+ res = dict()
+ if self._self_request:
+ res["$self_request"] = self._self_request._serialize()
+ for name, route_mapping in self._route_mappings.items():
+ res[name] = dict()
+ res[name]["mapping"] = route_mapping.mapping._serialize()
+ res[name]["router"] = route_mapping.router._serialize()
+
+ return res
def __iter__(self):
if self._self_request:
- yield '$self_request', RouterMappingPair(mapping=
- MethodMapping.from_str('one-to-one'), router=self.
- _self_request)
+ yield (
+ "$self_request",
+ RouterMappingPair(
+ mapping=MethodMapping.from_str("one-to-one"),
+ router=self._self_request,
+ ),
+ )
for name, route_mapping in self._route_mappings.items():
- yield name, route_mapping
+ yield (name, route_mapping)
def __repr__(self):
return str(self._serialize())
@@ -813,7 +1133,23 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
A ``MetadataRequest`` or a ``MetadataRouting`` taken or created from
the given object.
"""
- pass
+ # doing this instead of a try/except since an AttributeError could be raised
+ # for other reasons.
+ if hasattr(obj, "get_metadata_routing"):
+ return deepcopy(obj.get_metadata_routing())
+
+ elif getattr(obj, "_type", None) in ["metadata_request", "metadata_router"]:
+ return deepcopy(obj)
+
+ return MetadataRequest(owner=None)
+
+ # Request method
+ # ==============
+ # This section includes what's needed for the request method descriptor and
+ # their dynamic generation in a meta class.
+
+ # These strings are used to dynamically generate the docstrings for
+ # set_{method}_request methods.
REQUESTER_DOC = """ Request metadata passed to the ``{method}`` method.
Note that this method is only relevant if
@@ -823,13 +1159,18 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
The options for each parameter are:
- - ``True``: metadata is requested, and passed to ``{method}`` if provided. The request is ignored if metadata is not provided.
+ - ``True``: metadata is requested, and \
+ passed to ``{method}`` if provided. The request is ignored if \
+ metadata is not provided.
- - ``False``: metadata is not requested and the meta-estimator will not pass it to ``{method}``.
+ - ``False``: metadata is not requested and the meta-estimator \
+ will not pass it to ``{method}``.
- - ``None``: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
+ - ``None``: metadata is not requested, and the meta-estimator \
+ will raise an error if the user provides it.
- - ``str``: metadata should be passed to the meta-estimator with this given alias instead of the original name.
+ - ``str``: metadata should be passed to the meta-estimator with \
+ this given alias instead of the original name.
The default (``sklearn.utils.metadata_routing.UNCHANGED``) retains the
existing request. This allows you to change the request for some
@@ -845,7 +1186,8 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
Parameters
----------
"""
- REQUESTER_DOC_PARAM = """ {metadata} : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
+ REQUESTER_DOC_PARAM = """ {metadata} : str, True, False, or None, \
+ default=sklearn.utils.metadata_routing.UNCHANGED
Metadata routing for ``{metadata}`` parameter in ``{method}``.
"""
@@ -855,7 +1197,6 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
The updated object.
"""
-
class RequestMethod:
"""
A descriptor for request methods.
@@ -895,7 +1236,7 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
self.validate_keys = validate_keys
def __get__(self, instance, owner):
-
+ # we would want to have a method which accepts only the expected args
def func(*args, **kw):
"""Updates the request for provided parameters
@@ -904,46 +1245,76 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
"""
if not _routing_enabled():
raise RuntimeError(
- 'This method is only available when metadata routing is enabled. You can enable it using sklearn.set_config(enable_metadata_routing=True).'
- )
- if self.validate_keys and set(kw) - set(self.keys):
+ "This method is only available when metadata routing is "
+ "enabled. You can enable it using"
+ " sklearn.set_config(enable_metadata_routing=True)."
+ )
+
+ if self.validate_keys and (set(kw) - set(self.keys)):
raise TypeError(
- f'Unexpected args: {set(kw) - set(self.keys)}. Accepted arguments are: {set(self.keys)}'
- )
+ f"Unexpected args: {set(kw) - set(self.keys)}. Accepted "
+ f"arguments are: {set(self.keys)}"
+ )
+
+ # This makes it possible to use the decorated method as an unbound
+ # method, for instance when monkeypatching.
+ # https://github.com/scikit-learn/scikit-learn/issues/28632
if instance is None:
_instance = args[0]
args = args[1:]
else:
_instance = instance
+
+ # Replicating python's behavior when positional args are given other
+ # than `self`, and `self` is only allowed if this method is unbound.
if args:
raise TypeError(
- f'set_{self.name}_request() takes 0 positional argument but {len(args)} were given'
- )
+ f"set_{self.name}_request() takes 0 positional argument but"
+ f" {len(args)} were given"
+ )
+
requests = _instance._get_metadata_request()
method_metadata_request = getattr(requests, self.name)
+
for prop, alias in kw.items():
if alias is not UNCHANGED:
- method_metadata_request.add_request(param=prop,
- alias=alias)
+ method_metadata_request.add_request(param=prop, alias=alias)
_instance._metadata_request = requests
+
return _instance
- func.__name__ = f'set_{self.name}_request'
- params = [inspect.Parameter(name='self', kind=inspect.Parameter
- .POSITIONAL_OR_KEYWORD, annotation=owner)]
- params.extend([inspect.Parameter(k, inspect.Parameter.
- KEYWORD_ONLY, default=UNCHANGED, annotation=Optional[Union[
- bool, None, str]]) for k in self.keys])
- func.__signature__ = inspect.Signature(params,
- return_annotation=owner)
+
+ # Now we set the relevant attributes of the function so that it seems
+ # like a normal method to the end user, with known expected arguments.
+ func.__name__ = f"set_{self.name}_request"
+ params = [
+ inspect.Parameter(
+ name="self",
+ kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ annotation=owner,
+ )
+ ]
+ params.extend(
+ [
+ inspect.Parameter(
+ k,
+ inspect.Parameter.KEYWORD_ONLY,
+ default=UNCHANGED,
+ annotation=Optional[Union[bool, None, str]],
+ )
+ for k in self.keys
+ ]
+ )
+ func.__signature__ = inspect.Signature(
+ params,
+ return_annotation=owner,
+ )
doc = REQUESTER_DOC.format(method=self.name)
for metadata in self.keys:
- doc += REQUESTER_DOC_PARAM.format(metadata=metadata, method
- =self.name)
+ doc += REQUESTER_DOC_PARAM.format(metadata=metadata, method=self.name)
doc += REQUESTER_DOC_RETURN
func.__doc__ = doc
return func
-
class _MetadataRequester:
"""Mixin class for adding metadata request functionality.
@@ -951,7 +1322,27 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
.. versionadded:: 1.3
"""
- if TYPE_CHECKING:
+
+ if TYPE_CHECKING: # pragma: no cover
+        # This code never runs at runtime; it's here only for type checking.
+ # Type checkers fail to understand that the `set_{method}_request`
+ # methods are dynamically generated, and they complain that they are
+ # not defined. We define them here to make type checkers happy.
+ # During type checking analyzers assume this to be True.
+ # The following list of defined methods mirrors the list of methods
+ # in SIMPLE_METHODS.
+ # fmt: off
+ def set_fit_request(self, **kwargs): pass
+ def set_partial_fit_request(self, **kwargs): pass
+ def set_predict_request(self, **kwargs): pass
+ def set_predict_proba_request(self, **kwargs): pass
+ def set_predict_log_proba_request(self, **kwargs): pass
+ def set_decision_function_request(self, **kwargs): pass
+ def set_score_request(self, **kwargs): pass
+ def set_split_request(self, **kwargs): pass
+ def set_transform_request(self, **kwargs): pass
+ def set_inverse_transform_request(self, **kwargs): pass
+ # fmt: on
def __init_subclass__(cls, **kwargs):
"""Set the ``set_{method}_request`` methods.
@@ -973,14 +1364,22 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
try:
requests = cls._get_default_requests()
except Exception:
+                # If there are any issues with the default values, they will
+                # be raised when ``get_metadata_routing`` is called. Here we
+                # deliberately ignore such issues (e.g. bad defaults).
super().__init_subclass__(**kwargs)
return
+
for method in SIMPLE_METHODS:
mmr = getattr(requests, method)
+            # set the ``set_{method}_request`` methods
if not len(mmr.requests):
continue
- setattr(cls, f'set_{method}_request', RequestMethod(method,
- sorted(mmr.requests.keys())))
+ setattr(
+ cls,
+ f"set_{method}_request",
+ RequestMethod(method, sorted(mmr.requests.keys())),
+ )
super().__init_subclass__(**kwargs)
@classmethod
@@ -1003,7 +1402,25 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
method_request : MethodMetadataRequest
The prepared request using the method's signature.
"""
- pass
+ mmr = MethodMetadataRequest(owner=cls.__name__, method=method)
+ # Here we use `isfunction` instead of `ismethod` because calling `getattr`
+ # on a class instead of an instance returns an unbound function.
+ if not hasattr(cls, method) or not inspect.isfunction(getattr(cls, method)):
+ return mmr
+ # ignore the first parameter of the method, which is usually "self"
+ params = list(inspect.signature(getattr(cls, method)).parameters.items())[
+ 1:
+ ]
+ for pname, param in params:
+ if pname in {"X", "y", "Y", "Xt", "yt"}:
+ continue
+ if param.kind in {param.VAR_POSITIONAL, param.VAR_KEYWORD}:
+ continue
+ mmr.add_request(
+ param=pname,
+ alias=None,
+ )
+ return mmr
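For readers skimming the backport, here is a minimal sketch of the signature sniffing that `_build_request_for_signature` performs; the `Demo` class and its parameters are hypothetical and only illustrate which keywords become request keys.

```python
import inspect

class Demo:
    # hypothetical estimator method: data args and **kwargs are skipped
    def fit(self, X, y, sample_weight=None, copy=True, **params):
        return self

# drop the first parameter ("self"), mirroring the loop above
params = list(inspect.signature(Demo.fit).parameters.items())[1:]
requested = [
    name
    for name, param in params
    if name not in {"X", "y", "Y", "Xt", "yt"}
    and param.kind not in {param.VAR_POSITIONAL, param.VAR_KEYWORD}
]
print(requested)  # ['sample_weight', 'copy']
```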
@classmethod
def _get_default_requests(cls):
@@ -1013,7 +1430,43 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
class attributes, as well as determining request keys from method
signatures.
"""
- pass
+ requests = MetadataRequest(owner=cls.__name__)
+
+ for method in SIMPLE_METHODS:
+ setattr(
+ requests,
+ method,
+ cls._build_request_for_signature(router=requests, method=method),
+ )
+
+ # Then overwrite those defaults with the ones provided in
+ # __metadata_request__* attributes. Defaults set in
+ # __metadata_request__* attributes take precedence over signature
+ # sniffing.
+
+ # need to go through the MRO since this is a class attribute and
+ # ``vars`` doesn't report the parent class attributes. We go through
+ # the reverse of the MRO so that child classes have precedence over
+ # their parents.
+ defaults = dict()
+ for base_class in reversed(inspect.getmro(cls)):
+ base_defaults = {
+ attr: value
+ for attr, value in vars(base_class).items()
+ if "__metadata_request__" in attr
+ }
+ defaults.update(base_defaults)
+ defaults = dict(sorted(defaults.items()))
+
+ for attr, value in defaults.items():
+            # we don't check attr.startswith() since Python name-mangles
+            # attrs starting with __ by prefixing them with `_ClassName`.
+ substr = "__metadata_request__"
+ method = attr[attr.index(substr) + len(substr) :]
+ for prop, alias in value.items():
+ getattr(requests, method).add_request(param=prop, alias=alias)
+
+ return requests
def _get_metadata_request(self):
"""Get requested data properties.
@@ -1026,7 +1479,12 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
request : MetadataRequest
A :class:`~sklearn.utils.metadata_routing.MetadataRequest` instance.
"""
- pass
+ if hasattr(self, "_metadata_request"):
+ requests = get_routing_for_object(self._metadata_request)
+ else:
+ requests = self._get_default_requests()
+
+ return requests
def get_metadata_routing(self):
"""Get metadata routing of this object.
@@ -1040,8 +1498,17 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
A :class:`~sklearn.utils.metadata_routing.MetadataRequest` encapsulating
routing information.
"""
- pass
+ return self._get_metadata_request()
+ # Process Routing in Routers
+ # ==========================
+ # This is almost always the only method used in routers to process and route
+ # given metadata. This is to minimize the boilerplate required in routers.
+
+ # Here the first two arguments are positional only which makes everything
+ # passed as keyword argument a metadata. The first two args also have an `_`
+ # prefix to reduce the chances of name collisions with the passed metadata, and
+ # since they're positional only, users will never type those underscores.
def process_routing(_obj, _method, /, **kwargs):
"""Validate and route input parameters.
@@ -1078,7 +1545,59 @@ if parse_version(sklearn_version.base_version) < parse_version('1.4'):
metadata to corresponding methods or corresponding child objects. The object
names are those defined in `obj.get_metadata_routing()`.
"""
- pass
+ if not kwargs:
+        # If there is no metadata to route, we don't need to do any routing:
+        # we can simply return a structure which returns an empty dict on
+        # routed_params.ANYTHING.ANY_METHOD.
+ class EmptyRequest:
+ def get(self, name, default=None):
+ return Bunch(**{method: dict() for method in METHODS})
+
+ def __getitem__(self, name):
+ return Bunch(**{method: dict() for method in METHODS})
+
+ def __getattr__(self, name):
+ return Bunch(**{method: dict() for method in METHODS})
+
+ return EmptyRequest()
+
+ if not (
+ hasattr(_obj, "get_metadata_routing") or isinstance(_obj, MetadataRouter)
+ ):
+ raise AttributeError(
+ f"The given object ({repr(_obj.__class__.__name__)}) needs to either"
+ " implement the routing method `get_metadata_routing` or be a"
+ " `MetadataRouter` instance."
+ )
+ if _method not in METHODS:
+ raise TypeError(
+ f"Can only route and process input on these methods: {METHODS}, "
+ f"while the passed method is: {_method}."
+ )
+
+ request_routing = get_routing_for_object(_obj)
+ request_routing.validate_metadata(params=kwargs, method=_method)
+ routed_params = request_routing.route_params(params=kwargs, caller=_method)
+
+ return routed_params
+
else:
from sklearn.exceptions import UnsetMetadataPassedError
- from sklearn.utils._metadata_requests import COMPOSITE_METHODS, METHODS, SIMPLE_METHODS, UNCHANGED, UNUSED, WARN, MetadataRequest, MetadataRouter, MethodMapping, _MetadataRequester, _raise_for_params, _raise_for_unsupported_routing, _routing_enabled, _RoutingNotSupportedMixin, get_routing_for_object, process_routing
+ from sklearn.utils._metadata_requests import ( # type: ignore[no-redef]
+ COMPOSITE_METHODS, # noqa
+ METHODS, # noqa
+ SIMPLE_METHODS, # noqa
+ UNCHANGED,
+ UNUSED,
+ WARN,
+ MetadataRequest,
+ MetadataRouter,
+ MethodMapping,
+ _MetadataRequester, # noqa
+ _raise_for_params, # noqa
+ _raise_for_unsupported_routing, # noqa
+ _routing_enabled,
+ _RoutingNotSupportedMixin, # noqa
+ get_routing_for_object,
+ process_routing, # noqa
+ )
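To make the intent of the backported `process_routing` concrete, here is a hedged sketch of the call pattern inside a router's `fit`; `DemoRouter` is illustrative only (a real router would also implement `get_metadata_routing` returning a `MetadataRouter`).

```python
from sklearn.base import BaseEstimator, clone

class DemoRouter(BaseEstimator):
    """Illustrative router that forwards only requested metadata."""

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y, **fit_params):
        # routed_params is a Bunch of Bunches: routed_params.<child>.<method>
        routed_params = process_routing(self, "fit", **fit_params)
        self.estimator_ = clone(self.estimator).fit(
            X, y, **routed_params.estimator.fit
        )
        return self
```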
diff --git a/imblearn/utils/_param_validation.py b/imblearn/utils/_param_validation.py
index 47542c0..3ccabf2 100644
--- a/imblearn/utils/_param_validation.py
+++ b/imblearn/utils/_param_validation.py
@@ -1,6 +1,7 @@
"""This is a copy of sklearn/utils/_param_validation.py. It can be removed when
we support scikit-learn >= 1.2.
"""
+# mypy: ignore-errors
import functools
import math
import operator
@@ -9,23 +10,27 @@ from abc import ABC, abstractmethod
from collections.abc import Iterable
from inspect import signature
from numbers import Integral, Real
+
import numpy as np
import sklearn
from scipy.sparse import csr_matrix, issparse
from sklearn.utils.fixes import parse_version
+
from .._config import config_context, get_config
from ..utils.fixes import _is_arraylike_not_scalar
+
sklearn_version = parse_version(sklearn.__version__)
-if sklearn_version < parse_version('1.4'):
+if sklearn_version < parse_version("1.4"):
class InvalidParameterError(ValueError, TypeError):
"""Custom exception to be raised when the parameter of a class/method/function
does not have a valid type or value.
"""
- def validate_parameter_constraints(parameter_constraints, params,
- caller_name):
+ # Inherits from ValueError and TypeError to keep backward compatibility.
+
+ def validate_parameter_constraints(parameter_constraints, params, caller_name):
"""Validate types and values of given parameters.
Parameters
@@ -61,7 +66,46 @@ if sklearn_version < parse_version('1.4'):
caller_name : str
The name of the estimator or function or method that called this function.
"""
- pass
+ for param_name, param_val in params.items():
+ # We allow parameters to not have a constraint so that third party
+ # estimators can inherit from sklearn estimators without having to
+ # necessarily use the validation tools.
+ if param_name not in parameter_constraints:
+ continue
+
+ constraints = parameter_constraints[param_name]
+
+ if constraints == "no_validation":
+ continue
+
+ constraints = [make_constraint(constraint) for constraint in constraints]
+
+ for constraint in constraints:
+ if constraint.is_satisfied_by(param_val):
+ # this constraint is satisfied, no need to check further.
+ break
+ else:
+ # No constraint is satisfied, raise with an informative message.
+
+ # Ignore constraints that we don't want to expose in the error
+ # message, i.e. options that are for internal purpose or not
+ # officially supported.
+ constraints = [
+ constraint for constraint in constraints if not constraint.hidden
+ ]
+
+ if len(constraints) == 1:
+ constraints_str = f"{constraints[0]}"
+ else:
+ constraints_str = (
+ f"{', '.join([str(c) for c in constraints[:-1]])} or"
+ f" {constraints[-1]}"
+ )
+
+ raise InvalidParameterError(
+ f"The {param_name!r} parameter of {caller_name} must be"
+ f" {constraints_str}. Got {param_val!r} instead."
+ )
def make_constraint(constraint):
"""Convert the constraint into the appropriate Constraint object.
@@ -76,10 +120,37 @@ if sklearn_version < parse_version('1.4'):
constraint : instance of _Constraint
The converted constraint.
"""
- pass
-
- def validate_params(parameter_constraints, *, prefer_skip_nested_validation
+ if isinstance(constraint, str) and constraint == "array-like":
+ return _ArrayLikes()
+ if isinstance(constraint, str) and constraint == "sparse matrix":
+ return _SparseMatrices()
+ if isinstance(constraint, str) and constraint == "random_state":
+ return _RandomStates()
+ if constraint is callable:
+ return _Callables()
+ if constraint is None:
+ return _NoneConstraint()
+ if isinstance(constraint, type):
+ return _InstancesOf(constraint)
+ if isinstance(
+ constraint, (Interval, StrOptions, Options, HasMethods, MissingValues)
-    ):
+    ):
+ return constraint
+ if isinstance(constraint, str) and constraint == "boolean":
+ return _Booleans()
+ if isinstance(constraint, str) and constraint == "verbose":
+ return _VerboseHelper()
+ if isinstance(constraint, str) and constraint == "cv_object":
+ return _CVObjects()
+ if isinstance(constraint, Hidden):
+ constraint = make_constraint(constraint.constraint)
+ constraint.hidden = True
+ return constraint
+ if isinstance(constraint, str) and constraint == "nan":
+ return _NanConstraint()
+ raise ValueError(f"Unknown constraint type: {constraint}")
+
+ def validate_params(parameter_constraints, *, prefer_skip_nested_validation):
"""Decorator to validate types and values of functions and methods.
Parameters
@@ -110,8 +181,62 @@ if sklearn_version < parse_version('1.4'):
decorated_function : function or method
The decorated function.
"""
- pass
+ def decorator(func):
+ # The dict of parameter constraints is set as an attribute of the function
+ # to make it possible to dynamically introspect the constraints for
+ # automatic testing.
+ setattr(func, "_skl_parameter_constraints", parameter_constraints)
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ global_skip_validation = get_config()["skip_parameter_validation"]
+ if global_skip_validation:
+ return func(*args, **kwargs)
+
+ func_sig = signature(func)
+
+ # Map *args/**kwargs to the function signature
+ params = func_sig.bind(*args, **kwargs)
+ params.apply_defaults()
+
+ # ignore self/cls and positional/keyword markers
+ to_ignore = [
+ p.name
+ for p in func_sig.parameters.values()
+ if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
+ ]
+ to_ignore += ["self", "cls"]
+ params = {
+ k: v for k, v in params.arguments.items() if k not in to_ignore
+ }
+
+ validate_parameter_constraints(
+ parameter_constraints, params, caller_name=func.__qualname__
+ )
+
+ try:
+ with config_context(
+ skip_parameter_validation=(
+ prefer_skip_nested_validation or global_skip_validation
+ )
+ ):
+ return func(*args, **kwargs)
+ except InvalidParameterError as e:
+ # When the function is just a wrapper around an estimator, we allow
+ # the function to delegate validation to the estimator, but we
+ # replace the name of the estimator by the name of the function in
+ # the error message to avoid confusion.
+ msg = re.sub(
+ r"parameter of \w+ must be",
+ f"parameter of {func.__qualname__} must be",
+ str(e),
+ )
+ raise InvalidParameterError(msg) from e
+
+ return wrapper
+
+ return decorator
class RealNotInt(Real):
"""A type that represents reals that are not instances of int.
@@ -120,12 +245,20 @@ if sklearn_version < parse_version('1.4'):
isintance(1, RealNotInt) -> False
isinstance(1.0, RealNotInt) -> True
"""
+
RealNotInt.register(float)
def _type_name(t):
"""Convert type into human readable string."""
- pass
-
+ module = t.__module__
+ qualname = t.__qualname__
+ if module == "builtins":
+ return qualname
+ elif t == Real:
+ return "float"
+ elif t == Integral:
+ return "int"
+ return f"{module}.{qualname}"
class _Constraint(ABC):
"""Base class for the constraint objects."""
@@ -147,13 +280,11 @@ if sklearn_version < parse_version('1.4'):
is_satisfied : bool
Whether or not the constraint is satisfied by this value.
"""
- pass
@abstractmethod
def __str__(self):
"""A human readable representational string of the constraint."""
-
class _InstancesOf(_Constraint):
"""Constraint representing instances of a given type.
@@ -167,30 +298,47 @@ if sklearn_version < parse_version('1.4'):
super().__init__()
self.type = type
- def __str__(self):
- return f'an instance of {_type_name(self.type)!r}'
+ def is_satisfied_by(self, val):
+ return isinstance(val, self.type)
+ def __str__(self):
+ return f"an instance of {_type_name(self.type)!r}"
class _NoneConstraint(_Constraint):
"""Constraint representing the None singleton."""
- def __str__(self):
- return 'None'
+ def is_satisfied_by(self, val):
+ return val is None
+ def __str__(self):
+ return "None"
class _NanConstraint(_Constraint):
"""Constraint representing the indicator `np.nan`."""
- def __str__(self):
- return 'numpy.nan'
+ def is_satisfied_by(self, val):
+ return (
+ not isinstance(val, Integral)
+ and isinstance(val, Real)
+ and math.isnan(val)
+ )
+ def __str__(self):
+ return "numpy.nan"
class _PandasNAConstraint(_Constraint):
"""Constraint representing the indicator `pd.NA`."""
- def __str__(self):
- return 'pandas.NA'
+ def is_satisfied_by(self, val):
+ try:
+ import pandas as pd
+ return isinstance(val, type(pd.NA)) and pd.isna(val)
+ except ImportError:
+ return False
+
+ def __str__(self):
+ return "pandas.NA"
class Options(_Constraint):
"""Constraint representing a finite set of instances of a given type.
@@ -212,20 +360,27 @@ if sklearn_version < parse_version('1.4'):
self.type = type
self.options = options
self.deprecated = deprecated or set()
+
if self.deprecated - self.options:
raise ValueError(
- 'The deprecated options must be a subset of the options.')
+ "The deprecated options must be a subset of the options."
+ )
+
+ def is_satisfied_by(self, val):
+ return isinstance(val, self.type) and val in self.options
def _mark_if_deprecated(self, option):
"""Add a deprecated mark to an option if needed."""
- pass
+ option_str = f"{option!r}"
+ if option in self.deprecated:
+ option_str = f"{option_str} (deprecated)"
+ return option_str
def __str__(self):
options_str = (
f"{', '.join([self._mark_if_deprecated(o) for o in self.options])}"
- )
- return f'a {_type_name(self.type)} among {{{options_str}}}'
-
+ )
+ return f"a {_type_name(self.type)} among {{{options_str}}}"
class StrOptions(Options):
"""Constraint representing a finite set of strings.
@@ -243,7 +398,6 @@ if sklearn_version < parse_version('1.4'):
def __init__(self, options, *, deprecated=None):
super().__init__(type=str, options=options, deprecated=deprecated)
-
class Interval(_Constraint):
"""Constraint representing a typed interval.
@@ -286,58 +440,118 @@ if sklearn_version < parse_version('1.4'):
self.left = left
self.right = right
self.closed = closed
+
self._check_params()
+ def _check_params(self):
+ if self.type not in (Integral, Real, RealNotInt):
+ raise ValueError(
+ "type must be either numbers.Integral, numbers.Real or RealNotInt."
+ f" Got {self.type} instead."
+ )
+
+ if self.closed not in ("left", "right", "both", "neither"):
+ raise ValueError(
+ "closed must be either 'left', 'right', 'both' or 'neither'. "
+ f"Got {self.closed} instead."
+ )
+
+ if self.type is Integral:
+ suffix = "for an interval over the integers."
+ if self.left is not None and not isinstance(self.left, Integral):
+ raise TypeError(f"Expecting left to be an int {suffix}")
+ if self.right is not None and not isinstance(self.right, Integral):
+ raise TypeError(f"Expecting right to be an int {suffix}")
+ if self.left is None and self.closed in ("left", "both"):
+ raise ValueError(
+ f"left can't be None when closed == {self.closed} {suffix}"
+ )
+ if self.right is None and self.closed in ("right", "both"):
+ raise ValueError(
+ f"right can't be None when closed == {self.closed} {suffix}"
+ )
+ else:
+ if self.left is not None and not isinstance(self.left, Real):
+ raise TypeError("Expecting left to be a real number.")
+ if self.right is not None and not isinstance(self.right, Real):
+ raise TypeError("Expecting right to be a real number.")
+
+ if (
+ self.right is not None
+ and self.left is not None
+ and self.right <= self.left
+ ):
+ raise ValueError(
+ f"right can't be less than left. Got left={self.left} and "
+ f"right={self.right}"
+ )
+
def __contains__(self, val):
if not isinstance(val, Integral) and np.isnan(val):
return False
- left_cmp = operator.lt if self.closed in ('left', 'both'
- ) else operator.le
- right_cmp = operator.gt if self.closed in ('right', 'both'
- ) else operator.ge
+
+ left_cmp = operator.lt if self.closed in ("left", "both") else operator.le
+ right_cmp = operator.gt if self.closed in ("right", "both") else operator.ge
+
left = -np.inf if self.left is None else self.left
right = np.inf if self.right is None else self.right
+
if left_cmp(val, left):
return False
if right_cmp(val, right):
return False
return True
+ def is_satisfied_by(self, val):
+ if not isinstance(val, self.type):
+ return False
+
+ return val in self
+
def __str__(self):
- type_str = 'an int' if self.type is Integral else 'a float'
- left_bracket = '[' if self.closed in ('left', 'both') else '('
- left_bound = '-inf' if self.left is None else self.left
- right_bound = 'inf' if self.right is None else self.right
- right_bracket = ']' if self.closed in ('right', 'both') else ')'
+ type_str = "an int" if self.type is Integral else "a float"
+ left_bracket = "[" if self.closed in ("left", "both") else "("
+ left_bound = "-inf" if self.left is None else self.left
+ right_bound = "inf" if self.right is None else self.right
+ right_bracket = "]" if self.closed in ("right", "both") else ")"
+
+ # better repr if the bounds were given as integers
if not self.type == Integral and isinstance(self.left, Real):
left_bound = float(left_bound)
if not self.type == Integral and isinstance(self.right, Real):
right_bound = float(right_bound)
- return (
- f'{type_str} in the range {left_bracket}{left_bound}, {right_bound}{right_bracket}'
- )
+ return (
+ f"{type_str} in the range "
+ f"{left_bracket}{left_bound}, {right_bound}{right_bracket}"
+ )
class _ArrayLikes(_Constraint):
"""Constraint representing array-likes"""
- def __str__(self):
- return 'an array-like'
+ def is_satisfied_by(self, val):
+ return _is_arraylike_not_scalar(val)
+ def __str__(self):
+ return "an array-like"
class _SparseMatrices(_Constraint):
"""Constraint representing sparse matrices."""
- def __str__(self):
- return 'a sparse matrix'
+ def is_satisfied_by(self, val):
+ return issparse(val)
+ def __str__(self):
+ return "a sparse matrix"
class _Callables(_Constraint):
"""Constraint representing callables."""
- def __str__(self):
- return 'a callable'
+ def is_satisfied_by(self, val):
+ return callable(val)
+ def __str__(self):
+ return "a callable"
class _RandomStates(_Constraint):
"""Constraint representing random states.
@@ -348,15 +562,20 @@ if sklearn_version < parse_version('1.4'):
def __init__(self):
super().__init__()
- self._constraints = [Interval(Integral, 0, 2 ** 32 - 1, closed=
- 'both'), _InstancesOf(np.random.RandomState), _NoneConstraint()
- ]
+ self._constraints = [
+ Interval(Integral, 0, 2**32 - 1, closed="both"),
+ _InstancesOf(np.random.RandomState),
+ _NoneConstraint(),
+ ]
+
+ def is_satisfied_by(self, val):
+ return any(c.is_satisfied_by(val) for c in self._constraints)
def __str__(self):
return (
- f"{', '.join([str(c) for c in self._constraints[:-1]])} or {self._constraints[-1]}"
- )
-
+ f"{', '.join([str(c) for c in self._constraints[:-1]])} or"
+ f" {self._constraints[-1]}"
+ )
class _Booleans(_Constraint):
"""Constraint representing boolean likes.
@@ -367,13 +586,19 @@ if sklearn_version < parse_version('1.4'):
def __init__(self):
super().__init__()
- self._constraints = [_InstancesOf(bool), _InstancesOf(np.bool_)]
+ self._constraints = [
+ _InstancesOf(bool),
+ _InstancesOf(np.bool_),
+ ]
+
+ def is_satisfied_by(self, val):
+ return any(c.is_satisfied_by(val) for c in self._constraints)
def __str__(self):
return (
- f"{', '.join([str(c) for c in self._constraints[:-1]])} or {self._constraints[-1]}"
- )
-
+ f"{', '.join([str(c) for c in self._constraints[:-1]])} or"
+ f" {self._constraints[-1]}"
+ )
class _VerboseHelper(_Constraint):
"""Helper constraint for the verbose parameter.
@@ -384,14 +609,20 @@ if sklearn_version < parse_version('1.4'):
def __init__(self):
super().__init__()
- self._constraints = [Interval(Integral, 0, None, closed='left'),
- _InstancesOf(bool), _InstancesOf(np.bool_)]
+ self._constraints = [
+ Interval(Integral, 0, None, closed="left"),
+ _InstancesOf(bool),
+ _InstancesOf(np.bool_),
+ ]
+
+ def is_satisfied_by(self, val):
+ return any(c.is_satisfied_by(val) for c in self._constraints)
def __str__(self):
return (
- f"{', '.join([str(c) for c in self._constraints[:-1]])} or {self._constraints[-1]}"
- )
-
+ f"{', '.join([str(c) for c in self._constraints[:-1]])} or"
+ f" {self._constraints[-1]}"
+ )
class MissingValues(_Constraint):
"""Helper constraint for the `missing_values` parameters.
@@ -415,19 +646,28 @@ if sklearn_version < parse_version('1.4'):
def __init__(self, numeric_only=False):
super().__init__()
+
self.numeric_only = numeric_only
- self._constraints = [_InstancesOf(Integral), Interval(Real,
- None, None, closed='both'), _NanConstraint(),
- _PandasNAConstraint()]
+
+ self._constraints = [
+ _InstancesOf(Integral),
+ # we use an interval of Real to ignore np.nan that has its own
+ # constraint
+ Interval(Real, None, None, closed="both"),
+ _NanConstraint(),
+ _PandasNAConstraint(),
+ ]
if not self.numeric_only:
- self._constraints.extend([_InstancesOf(str), _NoneConstraint()]
- )
+ self._constraints.extend([_InstancesOf(str), _NoneConstraint()])
+
+ def is_satisfied_by(self, val):
+ return any(c.is_satisfied_by(val) for c in self._constraints)
def __str__(self):
return (
- f"{', '.join([str(c) for c in self._constraints[:-1]])} or {self._constraints[-1]}"
- )
-
+ f"{', '.join([str(c) for c in self._constraints[:-1]])} or"
+ f" {self._constraints[-1]}"
+ )
class HasMethods(_Constraint):
"""Constraint representing objects that expose specific methods.
@@ -441,30 +681,37 @@ if sklearn_version < parse_version('1.4'):
The method(s) that the object is expected to expose.
"""
- @validate_params({'methods': [str, list]},
- prefer_skip_nested_validation=True)
+ @validate_params(
+ {"methods": [str, list]},
+ prefer_skip_nested_validation=True,
+ )
def __init__(self, methods):
super().__init__()
if isinstance(methods, str):
methods = [methods]
self.methods = methods
+ def is_satisfied_by(self, val):
+ return all(callable(getattr(val, method, None)) for method in self.methods)
+
def __str__(self):
if len(self.methods) == 1:
- methods = f'{self.methods[0]!r}'
+ methods = f"{self.methods[0]!r}"
else:
methods = (
- f"{', '.join([repr(m) for m in self.methods[:-1]])} and {self.methods[-1]!r}"
- )
- return f'an object implementing {methods}'
-
+ f"{', '.join([repr(m) for m in self.methods[:-1]])} and"
+ f" {self.methods[-1]!r}"
+ )
+ return f"an object implementing {methods}"
class _IterablesNotString(_Constraint):
"""Constraint representing iterables that are not strings."""
- def __str__(self):
- return 'an iterable'
+ def is_satisfied_by(self, val):
+ return isinstance(val, Iterable) and not isinstance(val, str)
+ def __str__(self):
+ return "an iterable"
class _CVObjects(_Constraint):
"""Constraint representing cv objects.
@@ -480,15 +727,21 @@ if sklearn_version < parse_version('1.4'):
def __init__(self):
super().__init__()
- self._constraints = [Interval(Integral, 2, None, closed='left'),
- HasMethods(['split', 'get_n_splits']), _IterablesNotString(
- ), _NoneConstraint()]
+ self._constraints = [
+ Interval(Integral, 2, None, closed="left"),
+ HasMethods(["split", "get_n_splits"]),
+ _IterablesNotString(),
+ _NoneConstraint(),
+ ]
+
+ def is_satisfied_by(self, val):
+ return any(c.is_satisfied_by(val) for c in self._constraints)
def __str__(self):
return (
- f"{', '.join([str(c) for c in self._constraints[:-1]])} or {self._constraints[-1]}"
- )
-
+ f"{', '.join([str(c) for c in self._constraints[:-1]])} or"
+ f" {self._constraints[-1]}"
+ )
class Hidden:
"""Class encapsulating a constraint not meant to be exposed to the user.
@@ -520,7 +773,49 @@ if sklearn_version < parse_version('1.4'):
val : object
A value that does not satisfy the constraint.
"""
- pass
+ if isinstance(constraint, StrOptions):
+ return f"not {' or '.join(constraint.options)}"
+
+ if isinstance(constraint, MissingValues):
+ return np.array([1, 2, 3])
+
+ if isinstance(constraint, _VerboseHelper):
+ return -1
+
+ if isinstance(constraint, HasMethods):
+ return type("HasNotMethods", (), {})()
+
+ if isinstance(constraint, _IterablesNotString):
+ return "a string"
+
+ if isinstance(constraint, _CVObjects):
+ return "not a cv object"
+
+ if isinstance(constraint, Interval) and constraint.type is Integral:
+ if constraint.left is not None:
+ return constraint.left - 1
+ if constraint.right is not None:
+ return constraint.right + 1
+
+ # There's no integer outside (-inf, +inf)
+ raise NotImplementedError
+
+ if isinstance(constraint, Interval) and constraint.type in (Real, RealNotInt):
+ if constraint.left is not None:
+ return constraint.left - 1e-6
+ if constraint.right is not None:
+ return constraint.right + 1e-6
+
+ # bounds are -inf, +inf
+ if constraint.closed in ("right", "neither"):
+ return -np.inf
+ if constraint.closed in ("left", "neither"):
+ return np.inf
+
+ # interval is [-inf, +inf]
+ return np.nan
+
+ raise NotImplementedError
def generate_valid_param(constraint):
"""Return a value that does satisfy a constraint.
@@ -537,9 +832,103 @@ if sklearn_version < parse_version('1.4'):
val : object
A value that does satisfy the constraint.
"""
- pass
+ if isinstance(constraint, _ArrayLikes):
+ return np.array([1, 2, 3])
+
+ if isinstance(constraint, _SparseMatrices):
+ return csr_matrix([[0, 1], [1, 0]])
+
+ if isinstance(constraint, _RandomStates):
+ return np.random.RandomState(42)
+
+ if isinstance(constraint, _Callables):
+ return lambda x: x
+
+ if isinstance(constraint, _NoneConstraint):
+ return None
+
+ if isinstance(constraint, _InstancesOf):
+ if constraint.type is np.ndarray:
+ # special case for ndarray since it can't be instantiated without
+ # arguments
+ return np.array([1, 2, 3])
+
+ if constraint.type in (Integral, Real):
+ # special case for Integral and Real since they are abstract classes
+ return 1
+
+ return constraint.type()
+
+ if isinstance(constraint, _Booleans):
+ return True
+
+ if isinstance(constraint, _VerboseHelper):
+ return 1
+
+ if isinstance(constraint, MissingValues) and constraint.numeric_only:
+ return np.nan
+
+ if isinstance(constraint, MissingValues) and not constraint.numeric_only:
+ return "missing"
+
+ if isinstance(constraint, HasMethods):
+ return type(
+ "ValidHasMethods",
+ (),
+ {m: lambda self: None for m in constraint.methods},
+ )()
+
+ if isinstance(constraint, _IterablesNotString):
+ return [1, 2, 3]
+
+ if isinstance(constraint, _CVObjects):
+ return 5
+
+ if isinstance(constraint, Options): # includes StrOptions
+ for option in constraint.options:
+ return option
+
+ if isinstance(constraint, Interval):
+ interval = constraint
+ if interval.left is None and interval.right is None:
+ return 0
+ elif interval.left is None:
+ return interval.right - 1
+ elif interval.right is None:
+ return interval.left + 1
+ else:
+ if interval.type is Real:
+ return (interval.left + interval.right) / 2
+ else:
+ return interval.left + 1
+
+ raise ValueError(f"Unknown constraint type: {constraint}")
+
else:
- from sklearn.utils._param_validation import generate_invalid_param_val
- from sklearn.utils._param_validation import generate_valid_param
- from sklearn.utils._param_validation import validate_parameter_constraints
- from sklearn.utils._param_validation import HasMethods, Hidden, Interval, InvalidParameterError, MissingValues, Options, RealNotInt, StrOptions, _ArrayLikes, _Booleans, _Callables, _CVObjects, _InstancesOf, _IterablesNotString, _NanConstraint, _NoneConstraint, _PandasNAConstraint, _RandomStates, _SparseMatrices, _VerboseHelper, make_constraint, validate_params
+ from sklearn.utils._param_validation import generate_invalid_param_val # noqa
+ from sklearn.utils._param_validation import generate_valid_param # noqa
+ from sklearn.utils._param_validation import validate_parameter_constraints # noqa
+ from sklearn.utils._param_validation import (
+ HasMethods,
+ Hidden,
+ Interval,
+ InvalidParameterError,
+ MissingValues,
+ Options,
+ RealNotInt,
+ StrOptions,
+ _ArrayLikes,
+ _Booleans,
+ _Callables,
+ _CVObjects,
+ _InstancesOf,
+ _IterablesNotString,
+ _NanConstraint,
+ _NoneConstraint,
+ _PandasNAConstraint,
+ _RandomStates,
+ _SparseMatrices,
+ _VerboseHelper,
+ make_constraint,
+ validate_params,
+ )
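A short usage sketch of the vendored validator above; `build_index` and its constraints are made up for illustration.

```python
from numbers import Integral

@validate_params(
    {
        "n_neighbors": [Interval(Integral, 1, None, closed="left")],
        "metric": [StrOptions({"euclidean", "manhattan"})],
    },
    prefer_skip_nested_validation=True,
)
def build_index(n_neighbors, metric="euclidean"):
    return n_neighbors, metric

build_index(3)  # passes validation
build_index(0)  # InvalidParameterError: ... must be an int in the range [1, inf)
```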
diff --git a/imblearn/utils/_show_versions.py b/imblearn/utils/_show_versions.py
index e6bd42b..4912e3a 100644
--- a/imblearn/utils/_show_versions.py
+++ b/imblearn/utils/_show_versions.py
@@ -4,6 +4,10 @@ and filing issues on GitHub.
Adapted from :func:`sklearn.show_versions`,
which was adapted from :func:`pandas.show_versions`
"""
+
+# Author: Alexander L. Hayes <hayesall@iu.edu>
+# License: MIT
+
from .. import __version__
@@ -14,7 +18,32 @@ def _get_deps_info():
deps_info: dict
version information on relevant Python libraries
"""
- pass
+ deps = [
+ "imbalanced-learn",
+ "pip",
+ "setuptools",
+ "numpy",
+ "scipy",
+ "scikit-learn",
+ "Cython",
+ "pandas",
+ "keras",
+ "tensorflow",
+ "joblib",
+ ]
+
+ deps_info = {
+ "imbalanced-learn": __version__,
+ }
+
+ from importlib.metadata import PackageNotFoundError, version
+
+ for modname in deps:
+ try:
+ deps_info[modname] = version(modname)
+ except PackageNotFoundError:
+ deps_info[modname] = None
+ return deps_info
def show_versions(github=False):
@@ -27,4 +56,37 @@ def show_versions(github=False):
github : bool,
If true, wrap system info with GitHub markup.
"""
- pass
+
+ from sklearn.utils._show_versions import _get_sys_info
+
+ _sys_info = _get_sys_info()
+ _deps_info = _get_deps_info()
+ _github_markup = (
+ "<details>"
+ "<summary>System, Dependency Information</summary>\n\n"
+ "**System Information**\n\n"
+ "{0}\n"
+ "**Python Dependencies**\n\n"
+ "{1}\n"
+ "</details>"
+ )
+
+ if github:
+ _sys_markup = ""
+ _deps_markup = ""
+
+ for k, stat in _sys_info.items():
+ _sys_markup += f"* {k:<10}: `{stat}`\n"
+ for k, stat in _deps_info.items():
+ _deps_markup += f"* {k:<10}: `{stat}`\n"
+
+ print(_github_markup.format(_sys_markup, _deps_markup))
+
+ else:
+ print("\nSystem:")
+ for k, stat in _sys_info.items():
+ print(f"{k:>11}: {stat}")
+
+ print("\nPython dependencies:")
+ for k, stat in _deps_info.items():
+ print(f"{k:>11}: {stat}")
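Usage note: both helpers print rather than return; with `github=True` the same information is wrapped in a collapsible `<details>` block for issue reports.

```python
from imblearn.utils._show_versions import show_versions

show_versions()             # plain "System:" / "Python dependencies:" sections
show_versions(github=True)  # GitHub markup, ready to paste into an issue
```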
diff --git a/imblearn/utils/_validation.py b/imblearn/utils/_validation.py
index 38a7408..b21c157 100644
--- a/imblearn/utils/_validation.py
+++ b/imblearn/utils/_validation.py
@@ -1,9 +1,14 @@
"""Utilities for input validation"""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# License: MIT
+
import warnings
from collections import OrderedDict
from functools import wraps
from inspect import Parameter, signature
from numbers import Integral, Real
+
import numpy as np
from scipy.sparse import issparse
from sklearn.base import clone
@@ -11,10 +16,17 @@ from sklearn.neighbors import NearestNeighbors
from sklearn.utils import check_array, column_or_1d
from sklearn.utils.multiclass import type_of_target
from sklearn.utils.validation import _num_samples
+
from .fixes import _is_pandas_df
-SAMPLING_KIND = ('over-sampling', 'under-sampling', 'clean-sampling',
- 'ensemble', 'bypass')
-TARGET_KIND = 'binary', 'multiclass', 'multilabel-indicator'
+
+SAMPLING_KIND = (
+ "over-sampling",
+ "under-sampling",
+ "clean-sampling",
+ "ensemble",
+ "bypass",
+)
+TARGET_KIND = ("binary", "multiclass", "multilabel-indicator")
class ArraysTransformer:
@@ -24,6 +36,62 @@ class ArraysTransformer:
self.x_props = self._gets_props(X)
self.y_props = self._gets_props(y)
+ def transform(self, X, y):
+        X = self._transform_one(X, self.x_props)
+        y = self._transform_one(y, self.y_props)
+ if self.x_props["type"].lower() == "dataframe" and self.y_props[
+ "type"
+ ].lower() in {"series", "dataframe"}:
+ # We lost the y.index during resampling. We can safely use X.index to align
+ # them.
+ y.index = X.index
+ return X, y
+
+ def _gets_props(self, array):
+ props = {}
+ props["type"] = array.__class__.__name__
+ props["columns"] = getattr(array, "columns", None)
+ props["name"] = getattr(array, "name", None)
+ props["dtypes"] = getattr(array, "dtypes", None)
+ return props
+
+    def _transform_one(self, array, props):
+ type_ = props["type"].lower()
+ if type_ == "list":
+ ret = array.tolist()
+ elif type_ == "dataframe":
+ import pandas as pd
+
+ if issparse(array):
+ ret = pd.DataFrame.sparse.from_spmatrix(array, columns=props["columns"])
+ else:
+ ret = pd.DataFrame(array, columns=props["columns"])
+
+ try:
+ ret = ret.astype(props["dtypes"])
+ except TypeError:
+ # We special case the following error:
+ # https://github.com/scikit-learn-contrib/imbalanced-learn/issues/1055
+ # There is no easy way to have a generic workaround. Here, we detect
+ # that we have a column with only null values that is datetime64
+ # (resulting from the np.vstack of the resampling).
+ for col in ret.columns:
+ if (
+ ret[col].isnull().all()
+ and ret[col].dtype == "datetime64[ns]"
+ and props["dtypes"][col] == "timedelta64[ns]"
+ ):
+ ret[col] = pd.to_timedelta(["NaT"] * len(ret[col]))
+ # try again
+ ret = ret.astype(props["dtypes"])
+ elif type_ == "series":
+ import pandas as pd
+
+ ret = pd.Series(array, dtype=props["dtypes"], name=props["name"])
+ else:
+ ret = array
+ return ret
+
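A hedged round-trip sketch of `ArraysTransformer` (assumes pandas is installed; the resampling step is faked by slicing to numpy arrays):

```python
import numpy as np
import pandas as pd

X = pd.DataFrame({"a": [1.0, 2.0, 3.0, 4.0]})
y = pd.Series([0, 0, 1, 1], name="target")

arrays_transformer = ArraysTransformer(X, y)
# stand-in for a sampler's fit_resample, which works on numpy arrays
X_res, y_res = np.asarray(X)[:2], np.asarray(y)[:2]
X_out, y_out = arrays_transformer.transform(X_res, y_res)

print(type(X_out).__name__, type(y_out).__name__)  # DataFrame Series
assert (y_out.index == X_out.index).all()  # y is re-aligned on X.index
```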
def _is_neighbors_object(estimator):
"""Check that the estimator exposes a KNeighborsMixin-like API.
@@ -41,7 +109,8 @@ def _is_neighbors_object(estimator):
is_neighbors_object : bool
True if the estimator exposes a KNeighborsMixin-like API.
"""
- pass
+ neighbors_attributes = ["kneighbors", "kneighbors_graph"]
+ return all(hasattr(estimator, attr) for attr in neighbors_attributes)
def check_neighbors_object(nn_name, nn_object, additional_neighbor=0):
@@ -68,7 +137,15 @@ def check_neighbors_object(nn_name, nn_object, additional_neighbor=0):
nn_object : KNeighborsMixin
The k-NN object.
"""
- pass
+ if isinstance(nn_object, Integral):
+ return NearestNeighbors(n_neighbors=nn_object + additional_neighbor)
+    # otherwise, assume a KNeighborsMixin-like estimator and work on a clone
+ return clone(nn_object)
+
+
+def _count_class_sample(y):
+ unique, counts = np.unique(y, return_counts=True)
+ return dict(zip(unique, counts))
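Illustrative behavior of `check_neighbors_object` (outputs shown as comments, not doctests):

```python
from sklearn.neighbors import NearestNeighbors

check_neighbors_object("n_neighbors", 3)
# -> NearestNeighbors(n_neighbors=3)
check_neighbors_object("n_neighbors", 3, additional_neighbor=1)
# -> NearestNeighbors(n_neighbors=4)
check_neighbors_object("nn", NearestNeighbors(n_neighbors=7))
# -> an unfitted clone of the passed estimator
```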
def check_target_type(y, indicate_one_vs_all=False):
@@ -94,58 +171,272 @@ def check_target_type(y, indicate_one_vs_all=False):
Indicate if the target was originally encoded in a one-vs-all fashion.
Only returned if ``indicate_multilabel=True``.
"""
- pass
+ type_y = type_of_target(y)
+ if type_y == "multilabel-indicator":
+ if np.any(y.sum(axis=1) > 1):
+ raise ValueError(
+ "Imbalanced-learn currently supports binary, multiclass and "
+                "binarized encoded multiclass targets. Multilabel and "
+ "multioutput targets are not supported."
+ )
+ y = y.argmax(axis=1)
+ else:
+ y = column_or_1d(y)
+
+ return (y, type_y == "multilabel-indicator") if indicate_one_vs_all else y
def _sampling_strategy_all(y, sampling_type):
"""Returns sampling target by targeting all classes."""
- pass
+ target_stats = _count_class_sample(y)
+ if sampling_type == "over-sampling":
+ n_sample_majority = max(target_stats.values())
+ sampling_strategy = {
+ key: n_sample_majority - value for (key, value) in target_stats.items()
+ }
+ elif sampling_type == "under-sampling" or sampling_type == "clean-sampling":
+ n_sample_minority = min(target_stats.values())
+ sampling_strategy = {key: n_sample_minority for key in target_stats.keys()}
+ else:
+ raise NotImplementedError
+
+ return sampling_strategy
def _sampling_strategy_majority(y, sampling_type):
"""Returns sampling target by targeting the majority class only."""
- pass
+ if sampling_type == "over-sampling":
+ raise ValueError(
+ "'sampling_strategy'='majority' cannot be used with over-sampler."
+ )
+ elif sampling_type == "under-sampling" or sampling_type == "clean-sampling":
+ target_stats = _count_class_sample(y)
+ class_majority = max(target_stats, key=target_stats.get)
+ n_sample_minority = min(target_stats.values())
+ sampling_strategy = {
+ key: n_sample_minority
+ for key in target_stats.keys()
+ if key == class_majority
+ }
+ else:
+ raise NotImplementedError
+
+ return sampling_strategy
def _sampling_strategy_not_majority(y, sampling_type):
"""Returns sampling target by targeting all classes but not the
majority."""
- pass
+ target_stats = _count_class_sample(y)
+ if sampling_type == "over-sampling":
+ n_sample_majority = max(target_stats.values())
+ class_majority = max(target_stats, key=target_stats.get)
+ sampling_strategy = {
+ key: n_sample_majority - value
+ for (key, value) in target_stats.items()
+ if key != class_majority
+ }
+ elif sampling_type == "under-sampling" or sampling_type == "clean-sampling":
+ n_sample_minority = min(target_stats.values())
+ class_majority = max(target_stats, key=target_stats.get)
+ sampling_strategy = {
+ key: n_sample_minority
+ for key in target_stats.keys()
+ if key != class_majority
+ }
+ else:
+ raise NotImplementedError
+
+ return sampling_strategy
def _sampling_strategy_not_minority(y, sampling_type):
"""Returns sampling target by targeting all classes but not the
minority."""
- pass
+ target_stats = _count_class_sample(y)
+ if sampling_type == "over-sampling":
+ n_sample_majority = max(target_stats.values())
+ class_minority = min(target_stats, key=target_stats.get)
+ sampling_strategy = {
+ key: n_sample_majority - value
+ for (key, value) in target_stats.items()
+ if key != class_minority
+ }
+ elif sampling_type == "under-sampling" or sampling_type == "clean-sampling":
+ n_sample_minority = min(target_stats.values())
+ class_minority = min(target_stats, key=target_stats.get)
+ sampling_strategy = {
+ key: n_sample_minority
+ for key in target_stats.keys()
+ if key != class_minority
+ }
+ else:
+ raise NotImplementedError
+
+ return sampling_strategy
def _sampling_strategy_minority(y, sampling_type):
"""Returns sampling target by targeting the minority class only."""
- pass
+ target_stats = _count_class_sample(y)
+ if sampling_type == "over-sampling":
+ n_sample_majority = max(target_stats.values())
+ class_minority = min(target_stats, key=target_stats.get)
+ sampling_strategy = {
+ key: n_sample_majority - value
+ for (key, value) in target_stats.items()
+ if key == class_minority
+ }
+ elif sampling_type == "under-sampling" or sampling_type == "clean-sampling":
+ raise ValueError(
+ "'sampling_strategy'='minority' cannot be used with"
+ " under-sampler and clean-sampler."
+ )
+ else:
+ raise NotImplementedError
+
+ return sampling_strategy
def _sampling_strategy_auto(y, sampling_type):
"""Returns sampling target auto for over-sampling and not-minority for
under-sampling."""
- pass
+ if sampling_type == "over-sampling":
+ return _sampling_strategy_not_majority(y, sampling_type)
+ elif sampling_type == "under-sampling" or sampling_type == "clean-sampling":
+ return _sampling_strategy_not_minority(y, sampling_type)
def _sampling_strategy_dict(sampling_strategy, y, sampling_type):
"""Returns sampling target by converting the dictionary depending of the
sampling."""
- pass
+ target_stats = _count_class_sample(y)
+ # check that all keys in sampling_strategy are also in y
+ set_diff_sampling_strategy_target = set(sampling_strategy.keys()) - set(
+ target_stats.keys()
+ )
+ if len(set_diff_sampling_strategy_target) > 0:
+ raise ValueError(
+ f"The {set_diff_sampling_strategy_target} target class is/are not "
+ f"present in the data."
+ )
+ # check that there is no negative number
+ if any(n_samples < 0 for n_samples in sampling_strategy.values()):
+ raise ValueError(
+            "The number of samples in a class cannot be negative. "
+ f"'sampling_strategy' contains some negative value: {sampling_strategy}"
+ )
+ sampling_strategy_ = {}
+ if sampling_type == "over-sampling":
+ for class_sample, n_samples in sampling_strategy.items():
+ if n_samples < target_stats[class_sample]:
+ raise ValueError(
+                    f"With over-sampling methods, the number"
+                    f" of samples in a class should be greater"
+                    f" than or equal to the original number of samples."
+                    f" Originally, there are {target_stats[class_sample]} "
+                    f"samples and {n_samples} samples are requested."
+ )
+ sampling_strategy_[class_sample] = n_samples - target_stats[class_sample]
+ elif sampling_type == "under-sampling":
+ for class_sample, n_samples in sampling_strategy.items():
+ if n_samples > target_stats[class_sample]:
+ raise ValueError(
+ f"With under-sampling methods, the number of"
+                    f" samples in a class should be less than or equal"
+                    f" to the original number of samples."
+                    f" Originally, there are {target_stats[class_sample]} "
+                    f"samples and {n_samples} samples are requested."
+ )
+ sampling_strategy_[class_sample] = n_samples
+ elif sampling_type == "clean-sampling":
+ raise ValueError(
+ "'sampling_strategy' as a dict for cleaning methods is "
+ "not supported. Please give a list of the classes to be "
+ "targeted by the sampling."
+ )
+ else:
+ raise NotImplementedError
+
+ return sampling_strategy_
def _sampling_strategy_list(sampling_strategy, y, sampling_type):
"""With cleaning methods, sampling_strategy can be a list to target the
class of interest."""
- pass
+ if sampling_type != "clean-sampling":
+ raise ValueError(
+ "'sampling_strategy' cannot be a list for samplers "
+ "which are not cleaning methods."
+ )
+
+ target_stats = _count_class_sample(y)
+ # check that all keys in sampling_strategy are also in y
+ set_diff_sampling_strategy_target = set(sampling_strategy) - set(
+ target_stats.keys()
+ )
+ if len(set_diff_sampling_strategy_target) > 0:
+ raise ValueError(
+ f"The {set_diff_sampling_strategy_target} target class is/are not "
+ f"present in the data."
+ )
+
+ return {
+ class_sample: min(target_stats.values()) for class_sample in sampling_strategy
+ }
def _sampling_strategy_float(sampling_strategy, y, sampling_type):
"""Take a proportion of the majority (over-sampling) or minority
(under-sampling) class in binary classification."""
- pass
+ type_y = type_of_target(y)
+ if type_y != "binary":
+ raise ValueError(
+ '"sampling_strategy" can be a float only when the type '
+ "of target is binary. For multi-class, use a dict."
+ )
+ target_stats = _count_class_sample(y)
+ if sampling_type == "over-sampling":
+ n_sample_majority = max(target_stats.values())
+ class_majority = max(target_stats, key=target_stats.get)
+ sampling_strategy_ = {
+ key: int(n_sample_majority * sampling_strategy - value)
+ for (key, value) in target_stats.items()
+ if key != class_majority
+ }
+ if any([n_samples <= 0 for n_samples in sampling_strategy_.values()]):
+ raise ValueError(
+                "The specified ratio would require removing samples "
+                "from the minority class while trying to "
+                "generate new samples. Please increase the "
+                "ratio."
+ )
+ elif sampling_type == "under-sampling":
+ n_sample_minority = min(target_stats.values())
+ class_minority = min(target_stats, key=target_stats.get)
+ sampling_strategy_ = {
+ key: int(n_sample_minority / sampling_strategy)
+ for (key, value) in target_stats.items()
+ if key != class_minority
+ }
+ if any(
+ [
+ n_samples > target_stats[target]
+ for target, n_samples in sampling_strategy_.items()
+ ]
+ ):
+ raise ValueError(
+                "The specified ratio would require generating new "
+                "samples in the majority class while trying to "
+ "remove samples. Please increase the ratio."
+ )
+ else:
+ raise ValueError(
+            "'clean-sampling' methods do not let the user specify the sampling ratio."
+ )
+ return sampling_strategy_
def check_sampling_strategy(sampling_strategy, y, sampling_type, **kwargs):
@@ -236,14 +527,67 @@ def check_sampling_strategy(sampling_strategy, y, sampling_type, **kwargs):
the key being the class target and the value being the desired
number of samples.
"""
- pass
-
-
-SAMPLING_TARGET_KIND = {'minority': _sampling_strategy_minority, 'majority':
- _sampling_strategy_majority, 'not minority':
- _sampling_strategy_not_minority, 'not majority':
- _sampling_strategy_not_majority, 'all': _sampling_strategy_all, 'auto':
- _sampling_strategy_auto}
+ if sampling_type not in SAMPLING_KIND:
+ raise ValueError(
+ f"'sampling_type' should be one of {SAMPLING_KIND}. "
+            f"Got '{sampling_type}' instead."
+ )
+
+ if np.unique(y).size <= 1:
+ raise ValueError(
+ f"The target 'y' needs to have more than 1 class. "
+            f"Got {np.unique(y).size} class instead."
+ )
+
+ if sampling_type in ("ensemble", "bypass"):
+ return sampling_strategy
+
+ if isinstance(sampling_strategy, str):
+ if sampling_strategy not in SAMPLING_TARGET_KIND.keys():
+ raise ValueError(
+ f"When 'sampling_strategy' is a string, it needs"
+ f" to be one of {SAMPLING_TARGET_KIND}. Got '{sampling_strategy}' "
+ f"instead."
+ )
+ return OrderedDict(
+ sorted(SAMPLING_TARGET_KIND[sampling_strategy](y, sampling_type).items())
+ )
+ elif isinstance(sampling_strategy, dict):
+ return OrderedDict(
+ sorted(_sampling_strategy_dict(sampling_strategy, y, sampling_type).items())
+ )
+ elif isinstance(sampling_strategy, list):
+ return OrderedDict(
+ sorted(_sampling_strategy_list(sampling_strategy, y, sampling_type).items())
+ )
+ elif isinstance(sampling_strategy, Real):
+ if sampling_strategy <= 0 or sampling_strategy > 1:
+ raise ValueError(
+ f"When 'sampling_strategy' is a float, it should be "
+ f"in the range (0, 1]. Got {sampling_strategy} instead."
+ )
+ return OrderedDict(
+ sorted(
+ _sampling_strategy_float(sampling_strategy, y, sampling_type).items()
+ )
+ )
+ elif callable(sampling_strategy):
+ sampling_strategy_ = sampling_strategy(y, **kwargs)
+ return OrderedDict(
+ sorted(
+ _sampling_strategy_dict(sampling_strategy_, y, sampling_type).items()
+ )
+ )
+
+
+SAMPLING_TARGET_KIND = {
+ "minority": _sampling_strategy_minority,
+ "majority": _sampling_strategy_majority,
+ "not minority": _sampling_strategy_not_minority,
+ "not majority": _sampling_strategy_not_majority,
+ "all": _sampling_strategy_all,
+ "auto": _sampling_strategy_auto,
+}
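A worked example of the dispatch above, for a binary target with 90 majority and 10 minority samples; the values follow from the helper functions defined earlier.

```python
import numpy as np

y = np.array([0] * 90 + [1] * 10)

check_sampling_strategy("auto", y, "over-sampling")
# OrderedDict([(1, 80)]) -> generate 80 new minority samples
check_sampling_strategy(0.5, y, "under-sampling")
# OrderedDict([(0, 20)]) -> keep int(10 / 0.5) = 20 majority samples
```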
def _deprecate_positional_args(f):
@@ -257,9 +601,47 @@ def _deprecate_positional_args(f):
f : function
function to check arguments on.
"""
- pass
+ sig = signature(f)
+ kwonly_args = []
+ all_args = []
+
+ for name, param in sig.parameters.items():
+ if param.kind == Parameter.POSITIONAL_OR_KEYWORD:
+ all_args.append(name)
+ elif param.kind == Parameter.KEYWORD_ONLY:
+ kwonly_args.append(name)
+
+ @wraps(f)
+ def inner_f(*args, **kwargs):
+ extra_args = len(args) - len(all_args)
+ if extra_args > 0:
+ # ignore first 'self' argument for instance methods
+ args_msg = [
+ f"{name}={arg}"
+ for name, arg in zip(kwonly_args[:extra_args], args[-extra_args:])
+ ]
+ warnings.warn(
+ f"Pass {', '.join(args_msg)} as keyword args. From version 0.9 "
+ f"passing these as positional arguments will "
+ f"result in an error",
+ FutureWarning,
+ )
+ kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
+ return f(**kwargs)
+
+ return inner_f
def _check_X(X):
"""Check X and do not check it if a dataframe."""
- pass
+ n_samples = _num_samples(X)
+ if n_samples < 1:
+ raise ValueError(
+ f"Found array with {n_samples} sample(s) while a minimum of 1 is "
+ "required."
+ )
+ if _is_pandas_df(X):
+ return X
+ return check_array(
+ X, dtype=None, accept_sparse=["csr", "csc"], force_all_finite=False
+ )
diff --git a/imblearn/utils/deprecation.py b/imblearn/utils/deprecation.py
index a630c60..6d459b8 100644
--- a/imblearn/utils/deprecation.py
+++ b/imblearn/utils/deprecation.py
@@ -1,9 +1,12 @@
"""Utilities for deprecation"""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# License: MIT
+
import warnings
-def deprecate_parameter(sampler, version_deprecation, param_deprecated,
- new_param=None):
+def deprecate_parameter(sampler, version_deprecation, param_deprecated, new_param=None):
"""Helper to deprecate a parameter by another one.
Parameters
@@ -22,4 +25,22 @@ def deprecate_parameter(sampler, version_deprecation, param_deprecated,
The parameter used instead of the deprecated parameter. By default, no
parameter is expected.
"""
- pass
+ x, y = version_deprecation.split(".")
+ version_removed = x + "." + str(int(y) + 2)
+ if new_param is None:
+ if getattr(sampler, param_deprecated) is not None:
+ warnings.warn(
+ f"'{param_deprecated}' is deprecated from {version_deprecation} and "
+                f"will be removed in {version_removed} for the estimator "
+ f"{sampler.__class__}.",
+ category=FutureWarning,
+ )
+ else:
+ if getattr(sampler, param_deprecated) is not None:
+ warnings.warn(
+ f"'{param_deprecated}' is deprecated from {version_deprecation} and "
+ f"will be removed in {version_removed} for the estimator "
+ f"{sampler.__class__}. Use '{new_param}' instead.",
+ category=FutureWarning,
+ )
+ setattr(sampler, new_param, getattr(sampler, param_deprecated))
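Illustrative call (the sampler class is hypothetical): deprecating a parameter at 0.10 schedules its removal two minor versions later, at 0.12.

```python
class DummySampler:
    n_jobs = 2        # deprecated parameter, still set by the user
    n_jobs_new = None

sampler = DummySampler()
deprecate_parameter(sampler, "0.10", "n_jobs", new_param="n_jobs_new")
# FutureWarning: 'n_jobs' is deprecated from 0.10 and will be removed in 0.12
# for the estimator <class 'DummySampler'>. Use 'n_jobs_new' instead.
print(sampler.n_jobs_new)  # 2, copied over from the deprecated parameter
```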
diff --git a/imblearn/utils/estimator_checks.py b/imblearn/utils/estimator_checks.py
index d3aea67..5704277 100644
--- a/imblearn/utils/estimator_checks.py
+++ b/imblearn/utils/estimator_checks.py
@@ -1,35 +1,153 @@
"""Utils to check the samplers and compatibility with scikit-learn"""
+
+# Adapted from scikit-learn
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# License: MIT
+
import re
import sys
import traceback
import warnings
from collections import Counter
from functools import partial
+
import numpy as np
import pytest
import sklearn
from scipy import sparse
from sklearn.base import clone, is_classifier, is_regressor
from sklearn.cluster import KMeans
-from sklearn.datasets import load_iris, make_blobs, make_classification, make_multilabel_classification
+from sklearn.datasets import ( # noqa
+ load_iris,
+ make_blobs,
+ make_classification,
+ make_multilabel_classification,
+)
from sklearn.exceptions import SkipTestWarning
from sklearn.preprocessing import StandardScaler, label_binarize
from sklearn.utils._tags import _safe_tags
-from sklearn.utils._testing import SkipTest, assert_allclose, assert_array_equal, assert_raises_regex, raises, set_random_state
-from sklearn.utils.estimator_checks import _enforce_estimator_tags_y, _get_check_estimator_ids, _maybe_mark_xfail
+from sklearn.utils._testing import (
+ SkipTest,
+ assert_allclose,
+ assert_array_equal,
+ assert_raises_regex,
+ raises,
+ set_random_state,
+)
+from sklearn.utils.estimator_checks import (
+ _enforce_estimator_tags_y,
+ _get_check_estimator_ids,
+ _maybe_mark_xfail,
+)
+
try:
from sklearn.utils.estimator_checks import _enforce_estimator_tags_x
except ImportError:
- from sklearn.utils.estimator_checks import _enforce_estimator_tags_X as _enforce_estimator_tags_x
+ # scikit-learn >= 1.2
+ from sklearn.utils.estimator_checks import (
+ _enforce_estimator_tags_X as _enforce_estimator_tags_x,
+ )
+
from sklearn.utils.fixes import parse_version
from sklearn.utils.multiclass import type_of_target
+
from imblearn.datasets import make_imbalance
from imblearn.over_sampling.base import BaseOverSampler
from imblearn.under_sampling.base import BaseCleaningSampler, BaseUnderSampler
from imblearn.utils._param_validation import generate_invalid_param_val, make_constraint
+
sklearn_version = parse_version(sklearn.__version__)
+def sample_dataset_generator():
+ X, y = make_classification(
+ n_samples=1000,
+ n_classes=3,
+ n_informative=4,
+ weights=[0.2, 0.3, 0.5],
+ random_state=0,
+ )
+ return X, y
+
+
+@pytest.fixture(name="sample_dataset_generator")
+def sample_dataset_generator_fixture():
+ return sample_dataset_generator()
+
+
+def _set_checking_parameters(estimator):
+ params = estimator.get_params()
+ name = estimator.__class__.__name__
+ if "n_estimators" in params:
+ estimator.set_params(n_estimators=min(5, estimator.n_estimators))
+ if name == "ClusterCentroids":
+ if sklearn_version < parse_version("1.1"):
+ algorithm = "full"
+ else:
+ algorithm = "lloyd"
+ estimator.set_params(
+ voting="soft",
+ estimator=KMeans(random_state=0, algorithm=algorithm, n_init=1),
+ )
+ if name == "KMeansSMOTE":
+ estimator.set_params(kmeans_estimator=12)
+ if name == "BalancedRandomForestClassifier":
+ # TODO: remove in 0.13
+ # future default in 0.13
+ estimator.set_params(replacement=True, sampling_strategy="all", bootstrap=False)
+
+
+def _yield_sampler_checks(sampler):
+ tags = sampler._get_tags()
+ yield check_target_type
+ yield check_samplers_one_label
+ yield check_samplers_fit
+ yield check_samplers_fit_resample
+ yield check_samplers_sampling_strategy_fit_resample
+ if "sparse" in tags["X_types"]:
+ yield check_samplers_sparse
+ if "dataframe" in tags["X_types"]:
+ yield check_samplers_pandas
+ yield check_samplers_pandas_sparse
+ if "string" in tags["X_types"]:
+ yield check_samplers_string
+ if tags["allow_nan"]:
+ yield check_samplers_nan
+ yield check_samplers_list
+ yield check_samplers_multiclass_ova
+ yield check_samplers_preserve_dtype
+ # we don't filter samplers based on their tag here because we want to make
+ # sure that the fitted attribute does not exist if the tag is not
+ # stipulated
+ yield check_samplers_sample_indices
+ yield check_samplers_2d_target
+ yield check_sampler_get_feature_names_out
+ yield check_sampler_get_feature_names_out_pandas
+
+
+def _yield_classifier_checks(classifier):
+ yield check_classifier_on_multilabel_or_multioutput_targets
+ yield check_classifiers_with_encoded_labels
+
+
+def _yield_all_checks(estimator):
+ name = estimator.__class__.__name__
+ tags = estimator._get_tags()
+ if tags["_skip_test"]:
+ warnings.warn(
+ f"Explicit SKIP via _skip_test tag for estimator {name}.",
+ SkipTestWarning,
+ )
+ return
+ # trigger our checks if this is a SamplerMixin
+ if hasattr(estimator, "fit_resample"):
+ for check in _yield_sampler_checks(estimator):
+ yield check
+ if hasattr(estimator, "predict"):
+ for check in _yield_classifier_checks(estimator):
+ yield check
+
+
def parametrize_with_checks(estimators):
"""Pytest specific decorator for parametrizing estimator checks.
@@ -59,4 +177,652 @@ def parametrize_with_checks(estimators):
... def test_sklearn_compatible_estimator(estimator, check):
... check(estimator)
"""
- pass
+
+ def checks_generator():
+ for estimator in estimators:
+ name = type(estimator).__name__
+ for check in _yield_all_checks(estimator):
+ check = partial(check, name)
+ yield _maybe_mark_xfail(estimator, check, pytest)
+
+ return pytest.mark.parametrize(
+ "estimator, check", checks_generator(), ids=_get_check_estimator_ids
+ )
+
+
+def check_target_type(name, estimator_orig):
+ estimator = clone(estimator_orig)
+ # a continuous target should raise an informative "Unknown label type" error
+ X = np.random.random((20, 2))
+ y = np.linspace(0, 1, 20)
+ msg = "Unknown label type:"
+ assert_raises_regex(
+ ValueError,
+ msg,
+ estimator.fit_resample,
+ X,
+ y,
+ )
+ # if the target is multilabel then we should raise an error
+ rng = np.random.RandomState(42)
+ y = rng.randint(2, size=(20, 3))
+ msg = "Multilabel and multioutput targets are not supported."
+ assert_raises_regex(
+ ValueError,
+ msg,
+ estimator.fit_resample,
+ X,
+ y,
+ )
+
+
+def check_samplers_one_label(name, sampler_orig):
+ sampler = clone(sampler_orig)
+ error_string_fit = "Sampler can't balance when only one class is present."
+ X = np.random.random((20, 2))
+ y = np.zeros(20)
+ try:
+ sampler.fit_resample(X, y)
+ except ValueError as e:
+ if "class" not in repr(e):
+ print(error_string_fit, sampler.__class__.__name__, e)
+ traceback.print_exc(file=sys.stdout)
+ raise e
+ else:
+ return
+ except Exception as exc:
+ print(error_string_fit, sampler.__class__.__name__, exc)
+ traceback.print_exc(file=sys.stdout)
+ raise exc
+ raise AssertionError(error_string_fit)
+
+
+def check_samplers_fit(name, sampler_orig):
+ sampler = clone(sampler_orig)
+ np.random.seed(42) # Make this test reproducible
+ X = np.random.random((30, 2))
+ y = np.array([1] * 20 + [0] * 10)
+ sampler.fit_resample(X, y)
+ assert hasattr(
+ sampler, "sampling_strategy_"
+ ), "No fitted attribute sampling_strategy_"
+
+
+def check_samplers_fit_resample(name, sampler_orig):
+ sampler = clone(sampler_orig)
+ X, y = sample_dataset_generator()
+ target_stats = Counter(y)
+ X_res, y_res = sampler.fit_resample(X, y)
+ if isinstance(sampler, BaseOverSampler):
+ target_stats_res = Counter(y_res)
+ n_samples = max(target_stats.values())
+ assert all(value >= n_samples for value in Counter(y_res).values())
+ elif isinstance(sampler, BaseUnderSampler):
+ n_samples = min(target_stats.values())
+ if name == "InstanceHardnessThreshold":
+ # IHT does not enforce the number of samples but provides a number
+ # of samples as close as possible to the desired target.
+ assert all(
+ Counter(y_res)[k] <= target_stats[k] for k in target_stats.keys()
+ )
+ else:
+ assert all(value == n_samples for value in Counter(y_res).values())
+ elif isinstance(sampler, BaseCleaningSampler):
+ target_stats_res = Counter(y_res)
+ class_minority = min(target_stats, key=target_stats.get)
+ assert all(
+ target_stats[class_sample] > target_stats_res[class_sample]
+ for class_sample in target_stats.keys()
+ if class_sample != class_minority
+ )
+
+
+def check_samplers_sampling_strategy_fit_resample(name, sampler_orig):
+ sampler = clone(sampler_orig)
+ # in this test we will force all samplers to not change the class 1
+ X, y = sample_dataset_generator()
+ expected_stat = Counter(y)[1]
+ if isinstance(sampler, BaseOverSampler):
+ sampling_strategy = {2: 498, 0: 498}
+ sampler.set_params(sampling_strategy=sampling_strategy)
+ X_res, y_res = sampler.fit_resample(X, y)
+ assert Counter(y_res)[1] == expected_stat
+ elif isinstance(sampler, BaseUnderSampler):
+ sampling_strategy = {2: 201, 0: 201}
+ sampler.set_params(sampling_strategy=sampling_strategy)
+ X_res, y_res = sampler.fit_resample(X, y)
+ assert Counter(y_res)[1] == expected_stat
+ elif isinstance(sampler, BaseCleaningSampler):
+ sampling_strategy = [2, 0]
+ sampler.set_params(sampling_strategy=sampling_strategy)
+ X_res, y_res = sampler.fit_resample(X, y)
+ assert Counter(y_res)[1] == expected_stat
+
+
+def check_samplers_sparse(name, sampler_orig):
+ sampler = clone(sampler_orig)
+ # check that sparse matrices can be passed through the sampler, leading to
+ # the same results as dense input
+ X, y = sample_dataset_generator()
+ X_sparse = sparse.csr_matrix(X)
+ X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
+ sampler = clone(sampler)
+ X_res, y_res = sampler.fit_resample(X, y)
+ assert sparse.issparse(X_res_sparse)
+ # recent SciPy drops the `.A` alias on sparse matrices; `toarray()` works everywhere
+ assert_allclose(X_res_sparse.toarray(), X_res, rtol=1e-5)
+ assert_allclose(y_res_sparse, y_res)
+
+
+def check_samplers_pandas_sparse(name, sampler_orig):
+ pd = pytest.importorskip("pandas")
+ sampler = clone(sampler_orig)
+ # Check that the samplers handle pandas dataframe and pandas series
+ X, y = sample_dataset_generator()
+ X_df = pd.DataFrame(
+ X, columns=[str(i) for i in range(X.shape[1])], dtype=pd.SparseDtype(float, 0)
+ )
+ y_s = pd.Series(y, name="class")
+
+ X_res_df, y_res_s = sampler.fit_resample(X_df, y_s)
+ X_res, y_res = sampler.fit_resample(X, y)
+
+ # check that we return the same type for dataframes or series types
+ assert isinstance(X_res_df, pd.DataFrame)
+ assert isinstance(y_res_s, pd.Series)
+
+ for column_dtype in X_res_df.dtypes:
+ assert isinstance(column_dtype, pd.SparseDtype)
+
+ assert X_df.columns.tolist() == X_res_df.columns.tolist()
+ assert y_s.name == y_res_s.name
+
+ # FIXME: we should use to_numpy with pandas >= 0.25
+ assert_allclose(X_res_df.values, X_res)
+ assert_allclose(y_res_s.values, y_res)
+
+
+def check_samplers_pandas(name, sampler_orig):
+ pd = pytest.importorskip("pandas")
+ sampler = clone(sampler_orig)
+ # Check that the samplers handle pandas dataframe and pandas series
+ X, y = sample_dataset_generator()
+ X_df = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
+ y_df = pd.DataFrame(y)
+ y_s = pd.Series(y, name="class")
+
+ X_res_df, y_res_s = sampler.fit_resample(X_df, y_s)
+ X_res_df, y_res_df = sampler.fit_resample(X_df, y_df)
+ X_res, y_res = sampler.fit_resample(X, y)
+
+ # check that we return the same type for dataframes or series types
+ assert isinstance(X_res_df, pd.DataFrame)
+ assert isinstance(y_res_df, pd.DataFrame)
+ assert isinstance(y_res_s, pd.Series)
+
+ assert X_df.columns.tolist() == X_res_df.columns.tolist()
+ assert y_df.columns.tolist() == y_res_df.columns.tolist()
+ assert y_s.name == y_res_s.name
+
+ # FIXME: we should use to_numpy with pandas >= 0.25
+ assert_allclose(X_res_df.values, X_res)
+ assert_allclose(y_res_df.values.ravel(), y_res)
+ assert_allclose(y_res_s.values, y_res)
+
+
+def check_samplers_list(name, sampler_orig):
+ sampler = clone(sampler_orig)
+ # Check that the samplers can handle simple lists
+ X, y = sample_dataset_generator()
+ X_list = X.tolist()
+ y_list = y.tolist()
+
+ X_res, y_res = sampler.fit_resample(X, y)
+ X_res_list, y_res_list = sampler.fit_resample(X_list, y_list)
+
+ assert isinstance(X_res_list, list)
+ assert isinstance(y_res_list, list)
+
+ assert_allclose(X_res, X_res_list)
+ assert_allclose(y_res, y_res_list)
+
+
+def check_samplers_multiclass_ova(name, sampler_orig):
+ sampler = clone(sampler_orig)
+ # Check that a multiclass target leads to the same results as OVA encoding
+ X, y = sample_dataset_generator()
+ y_ova = label_binarize(y, classes=np.unique(y))
+ X_res, y_res = sampler.fit_resample(X, y)
+ X_res_ova, y_res_ova = sampler.fit_resample(X, y_ova)
+ assert_allclose(X_res, X_res_ova)
+ assert type_of_target(y_res_ova) == type_of_target(y_ova)
+ assert_allclose(y_res, y_res_ova.argmax(axis=1))
+
+
+def check_samplers_2d_target(name, sampler_orig):
+ sampler = clone(sampler_orig)
+ X, y = sample_dataset_generator()
+
+ y = y.reshape(-1, 1) # Make the target 2d
+ sampler.fit_resample(X, y)
+
+
+def check_samplers_preserve_dtype(name, sampler_orig):
+ sampler = clone(sampler_orig)
+ X, y = sample_dataset_generator()
+ # Cast X and y to a non-default dtype
+ X = X.astype(np.float32)
+ y = y.astype(np.int32)
+ X_res, y_res = sampler.fit_resample(X, y)
+ assert X.dtype == X_res.dtype, "X dtype is not preserved"
+ assert y.dtype == y_res.dtype, "y dtype is not preserved"
+
+
+def check_samplers_sample_indices(name, sampler_orig):
+ sampler = clone(sampler_orig)
+ X, y = sample_dataset_generator()
+ sampler.fit_resample(X, y)
+ sample_indices = sampler._get_tags().get("sample_indices", None)
+ if sample_indices:
+ assert hasattr(sampler, "sample_indices_") is sample_indices
+ else:
+ assert not hasattr(sampler, "sample_indices_")
+
+
+def check_samplers_string(name, sampler_orig):
+ rng = np.random.RandomState(0)
+ sampler = clone(sampler_orig)
+ categories = np.array(["A", "B", "C"], dtype=object)
+ n_samples = 30
+ X = rng.randint(low=0, high=3, size=n_samples).reshape(-1, 1)
+ X = categories[X]
+ y = rng.permutation([0] * 10 + [1] * 20)
+
+ X_res, y_res = sampler.fit_resample(X, y)
+ assert X_res.dtype == object
+ assert X_res.shape[0] == y_res.shape[0]
+ assert_array_equal(np.unique(X_res.ravel()), categories)
+
+
+def check_samplers_nan(name, sampler_orig):
+ rng = np.random.RandomState(0)
+ sampler = clone(sampler_orig)
+ categories = np.array([0, 1, np.nan], dtype=np.float64)
+ n_samples = 100
+ X = rng.randint(low=0, high=3, size=n_samples).reshape(-1, 1)
+ X = categories[X]
+ y = rng.permutation([0] * 40 + [1] * 60)
+
+ X_res, y_res = sampler.fit_resample(X, y)
+ assert X_res.dtype == np.float64
+ assert X_res.shape[0] == y_res.shape[0]
+ assert np.any(np.isnan(X_res.ravel()))
+
+
+def check_classifier_on_multilabel_or_multioutput_targets(name, estimator_orig):
+ estimator = clone(estimator_orig)
+ X, y = make_multilabel_classification(n_samples=30)
+ msg = "Multilabel and multioutput targets are not supported."
+ with pytest.raises(ValueError, match=msg):
+ estimator.fit(X, y)
+
+
+def check_classifiers_with_encoded_labels(name, classifier_orig):
+ # Non-regression test for #709
+ # https://github.com/scikit-learn-contrib/imbalanced-learn/issues/709
+ pd = pytest.importorskip("pandas")
+ classifier = clone(classifier_orig)
+ iris = load_iris(as_frame=True)
+ df, y = iris.data, iris.target
+ y = pd.Series(iris.target_names[iris.target], dtype="category")
+ df, y = make_imbalance(
+ df,
+ y,
+ sampling_strategy={
+ "setosa": 30,
+ "versicolor": 20,
+ "virginica": 50,
+ },
+ )
+ classifier.set_params(sampling_strategy={"setosa": 20, "virginica": 20})
+ classifier.fit(df, y)
+ assert set(classifier.classes_) == set(y.cat.categories.tolist())
+ y_pred = classifier.predict(df)
+ assert set(y_pred) == set(y.cat.categories.tolist())
+
+
+def check_param_validation(name, estimator_orig):
+ # Check that an informative error is raised when the value of a constructor
+ # parameter does not have an appropriate type or value.
+ rng = np.random.RandomState(0)
+ X = rng.uniform(size=(20, 5))
+ y = rng.randint(0, 2, size=20)
+ y = _enforce_estimator_tags_y(estimator_orig, y)
+
+ estimator_params = estimator_orig.get_params(deep=False).keys()
+
+ # check that there is a constraint for each parameter
+ if estimator_params:
+ validation_params = estimator_orig._parameter_constraints.keys()
+ unexpected_params = set(validation_params) - set(estimator_params)
+ missing_params = set(estimator_params) - set(validation_params)
+ err_msg = (
+ f"Mismatch between _parameter_constraints and the parameters of {name}."
+ f"\nConsider the unexpected parameters {unexpected_params} and expected but"
+ f" missing parameters {missing_params}"
+ )
+ assert validation_params == estimator_params, err_msg
+
+ # this object does not have a valid type for sure for all params
+ param_with_bad_type = type("BadType", (), {})()
+
+ fit_methods = ["fit", "partial_fit", "fit_transform", "fit_predict", "fit_resample"]
+
+ for param_name in estimator_params:
+ constraints = estimator_orig._parameter_constraints[param_name]
+
+ if constraints == "no_validation":
+ # This parameter is not validated
+ continue # pragma: no cover
+
+ match = rf"The '{param_name}' parameter of {name} must be .* Got .* instead."
+ err_msg = (
+ f"{name} does not raise an informative error message when the "
+ f"parameter {param_name} does not have a valid type or value."
+ )
+
+ estimator = clone(estimator_orig)
+
+ # First, check that the error is raised if param doesn't match any valid type.
+ estimator.set_params(**{param_name: param_with_bad_type})
+
+ for method in fit_methods:
+ if not hasattr(estimator, method):
+ # the method is not accessible with the current set of parameters
+ continue
+
+ with raises(ValueError, match=match, err_msg=err_msg):
+ if any(
+ isinstance(X_type, str) and X_type.endswith("labels")
+ for X_type in _safe_tags(estimator, key="X_types")
+ ):
+ # The estimator is a label transformer and takes only `y`
+ getattr(estimator, method)(y) # pragma: no cover
+ else:
+ getattr(estimator, method)(X, y)
+
+ # Then, for constraints that are more than a type constraint, check that the
+ # error is raised if param does match a valid type but does not match any valid
+ # value for this type.
+ constraints = [make_constraint(constraint) for constraint in constraints]
+
+ for constraint in constraints:
+ try:
+ bad_value = generate_invalid_param_val(constraint)
+ except NotImplementedError:
+ continue
+
+ estimator.set_params(**{param_name: bad_value})
+
+ for method in fit_methods:
+ if not hasattr(estimator, method):
+ # the method is not accessible with the current set of parameters
+ continue
+
+ with raises(ValueError, match=match, err_msg=err_msg):
+ if any(
+ X_type.endswith("labels")
+ for X_type in _safe_tags(estimator, key="X_types")
+ ):
+ # The estimator is a label transformer and takes only `y`
+ getattr(estimator, method)(y) # pragma: no cover
+ else:
+ getattr(estimator, method)(X, y)
+
+
+def check_dataframe_column_names_consistency(name, estimator_orig):
+ try:
+ import pandas as pd
+ except ImportError:
+ raise SkipTest(
+ "pandas is not installed: not checking column name consistency for pandas"
+ )
+
+ tags = _safe_tags(estimator_orig)
+ is_supported_X_types = (
+ "2darray" in tags["X_types"] or "categorical" in tags["X_types"]
+ )
+
+ if not is_supported_X_types or tags["no_validation"]:
+ return
+
+ rng = np.random.RandomState(0)
+
+ estimator = clone(estimator_orig)
+ set_random_state(estimator)
+
+ X_orig = rng.normal(size=(150, 8))
+
+ X_orig = _enforce_estimator_tags_x(estimator, X_orig)
+ n_samples, n_features = X_orig.shape
+
+ names = np.array([f"col_{i}" for i in range(n_features)])
+ X = pd.DataFrame(X_orig, columns=names)
+
+ if is_regressor(estimator):
+ y = rng.normal(size=n_samples)
+ else:
+ y = rng.randint(low=0, high=2, size=n_samples)
+ y = _enforce_estimator_tags_y(estimator, y)
+
+ # Check that calling `fit` does not raise any warnings about feature names.
+ with warnings.catch_warnings():
+ warnings.filterwarnings(
+ "error",
+ message="X does not have valid feature names",
+ category=UserWarning,
+ module="imblearn",
+ )
+ estimator.fit(X, y)
+
+ if not hasattr(estimator, "feature_names_in_"):
+ raise ValueError(
+ "Estimator does not have a feature_names_in_ "
+ "attribute after fitting with a dataframe"
+ )
+ assert isinstance(estimator.feature_names_in_, np.ndarray)
+ assert estimator.feature_names_in_.dtype == object
+ assert_array_equal(estimator.feature_names_in_, names)
+
+ # Only check imblearn estimators for feature_names_in_ in docstring
+ module_name = estimator_orig.__module__
+ if (
+ module_name.startswith("imblearn.")
+ and not ("test_" in module_name or module_name.endswith("_testing"))
+ and ("feature_names_in_" not in (estimator_orig.__doc__))
+ ):
+ raise ValueError(
+ f"Estimator {name} does not document its feature_names_in_ attribute"
+ )
+
+ check_methods = []
+ for method in (
+ "predict",
+ "transform",
+ "decision_function",
+ "predict_proba",
+ "score",
+ "score_samples",
+ "predict_log_proba",
+ ):
+ if not hasattr(estimator, method):
+ continue
+
+ callable_method = getattr(estimator, method)
+ if method == "score":
+ callable_method = partial(callable_method, y=y)
+ check_methods.append((method, callable_method))
+
+ for _, method in check_methods:
+ with warnings.catch_warnings():
+ warnings.filterwarnings(
+ "error",
+ message="X does not have valid feature names",
+ category=UserWarning,
+ module="sklearn",
+ )
+ method(X) # works without UserWarning for valid features
+
+ invalid_names = [
+ (names[::-1], "Feature names must be in the same order as they were in fit."),
+ (
+ [f"another_prefix_{i}" for i in range(n_features)],
+ "Feature names unseen at fit time:\n- another_prefix_0\n-"
+ " another_prefix_1\n",
+ ),
+ (
+ names[:3],
+ f"Feature names seen at fit time, yet now missing:\n- {min(names[3:])}\n",
+ ),
+ ]
+ params = {
+ key: value
+ for key, value in estimator.get_params().items()
+ if "early_stopping" in key
+ }
+ early_stopping_enabled = any(value is True for value in params.values())
+
+ for invalid_name, additional_message in invalid_names:
+ X_bad = pd.DataFrame(X, columns=invalid_name)
+
+ for name, method in check_methods:
+ if sklearn_version >= parse_version("1.2"):
+ expected_msg = re.escape(
+ "The feature names should match those that were passed during fit."
+ f"\n{additional_message}"
+ )
+ with raises(
+ ValueError, match=expected_msg, err_msg=f"{name} did not raise"
+ ):
+ method(X_bad)
+ else:
+ expected_msg = re.escape(
+ "The feature names should match those that were passed "
+ "during fit. Starting version 1.2, an error will be raised.\n"
+ f"{additional_message}"
+ )
+ with warnings.catch_warnings():
+ warnings.filterwarnings(
+ "error",
+ category=FutureWarning,
+ module="sklearn",
+ )
+ with raises(
+ FutureWarning,
+ match=expected_msg,
+ err_msg=f"{name} did not raise",
+ ):
+ method(X_bad)
+
+ # partial_fit checks on second call
+ # Do not call partial fit if early_stopping is on
+ if not hasattr(estimator, "partial_fit") or early_stopping_enabled:
+ continue
+
+ estimator = clone(estimator_orig)
+ if is_classifier(estimator):
+ classes = np.unique(y)
+ estimator.partial_fit(X, y, classes=classes)
+ else:
+ estimator.partial_fit(X, y)
+
+ with raises(ValueError, match=expected_msg):
+ estimator.partial_fit(X_bad, y)
+
+
+def check_sampler_get_feature_names_out(name, sampler_orig):
+ tags = sampler_orig._get_tags()
+ if "2darray" not in tags["X_types"] or tags["no_validation"]:
+ return
+
+ X, y = make_blobs(
+ n_samples=30,
+ centers=[[0, 0, 0], [1, 1, 1]],
+ random_state=0,
+ n_features=2,
+ cluster_std=0.1,
+ )
+ X = StandardScaler().fit_transform(X)
+
+ sampler = clone(sampler_orig)
+ X = _enforce_estimator_tags_x(sampler, X)
+
+ n_features = X.shape[1]
+ set_random_state(sampler)
+
+ y_ = y
+ X_res, y_res = sampler.fit_resample(X, y=y_)
+ input_features = [f"feature{i}" for i in range(n_features)]
+
+ # input_features does not have the same length as n_features_in_
+ with raises(ValueError, match="input_features should have length equal"):
+ sampler.get_feature_names_out(input_features[::2])
+
+ feature_names_out = sampler.get_feature_names_out(input_features)
+ assert feature_names_out is not None
+ assert isinstance(feature_names_out, np.ndarray)
+ assert feature_names_out.dtype == object
+ assert all(isinstance(name, str) for name in feature_names_out)
+
+ n_features_out = X_res.shape[1]
+
+ assert (
+ len(feature_names_out) == n_features_out
+ ), f"Expected {n_features_out} feature names, got {len(feature_names_out)}"
+
+
+def check_sampler_get_feature_names_out_pandas(name, sampler_orig):
+ try:
+ import pandas as pd
+ except ImportError:
+ raise SkipTest(
+ "pandas is not installed: not checking column name consistency for pandas"
+ )
+
+ tags = sampler_orig._get_tags()
+ if "2darray" not in tags["X_types"] or tags["no_validation"]:
+ return
+
+ X, y = make_blobs(
+ n_samples=30,
+ centers=[[0, 0, 0], [1, 1, 1]],
+ random_state=0,
+ n_features=2,
+ cluster_std=0.1,
+ )
+ X = StandardScaler().fit_transform(X)
+
+ sampler = clone(sampler_orig)
+ X = _enforce_estimator_tags_x(sampler, X)
+
+ n_features = X.shape[1]
+ set_random_state(sampler)
+
+ y_ = y
+ feature_names_in = [f"col{i}" for i in range(n_features)]
+ df = pd.DataFrame(X, columns=feature_names_in)
+ X_res, y_res = sampler.fit_resample(df, y=y_)
+
+ # an error is raised when `input_features` does not match feature_names_in_
+ invalid_feature_names = [f"bad{i}" for i in range(n_features)]
+ with raises(ValueError, match="input_features is not equal to feature_names_in_"):
+ sampler.get_feature_names_out(invalid_feature_names)
+
+ feature_names_out_default = sampler.get_feature_names_out()
+ feature_names_in_explicit_names = sampler.get_feature_names_out(feature_names_in)
+ assert_array_equal(feature_names_out_default, feature_names_in_explicit_names)
+
+ n_features_out = X_res.shape[1]
+
+ assert (
+ len(feature_names_out_default) == n_features_out
+ ), f"Expected {n_features_out} feature names, got {len(feature_names_out_default)}"
diff --git a/imblearn/utils/fixes.py b/imblearn/utils/fixes.py
index 801067f..023d8a1 100644
--- a/imblearn/utils/fixes.py
+++ b/imblearn/utils/fixes.py
@@ -6,23 +6,39 @@ which the fix is no longer needed.
"""
import functools
import sys
+
import numpy as np
import scipy
import scipy.stats
import sklearn
from sklearn.utils.fixes import parse_version
+
from .._config import config_context, get_config
+
sp_version = parse_version(scipy.__version__)
sklearn_version = parse_version(sklearn.__version__)
-if sklearn_version >= parse_version('1.1'):
+
+
+# TODO: Remove when SciPy 1.9 is the minimum supported version
+def _mode(a, axis=0):
+ if sp_version >= parse_version("1.9.0"):
+ return scipy.stats.mode(a, axis=axis, keepdims=True)
+ return scipy.stats.mode(a, axis=axis)
+
+
+# TODO: Remove when scikit-learn 1.1 is the minimum supported version
+if sklearn_version >= parse_version("1.1"):
from sklearn.utils.validation import _is_arraylike_not_scalar
else:
from sklearn.utils.validation import _is_arraylike
def _is_arraylike_not_scalar(array):
"""Return True if array is array-like and not a scalar"""
- pass
-if sklearn_version < parse_version('1.3'):
+ return _is_arraylike(array) and not np.isscalar(array)
+
+
+# TODO: remove when scikit-learn minimum version is 1.3
+if sklearn_version < parse_version("1.3"):
def _fit_context(*, prefer_skip_nested_validation):
"""Decorator to run the fit methods of estimators within context managers.
@@ -47,10 +63,36 @@ if sklearn_version < parse_version('1.3'):
decorated_fit : method
The decorated fit method.
"""
- pass
+
+ def decorator(fit_method):
+ @functools.wraps(fit_method)
+ def wrapper(estimator, *args, **kwargs):
+ global_skip_validation = get_config()["skip_parameter_validation"]
+
+ # we don't want to validate again for each call to partial_fit
+ partial_fit_and_fitted = (
+ fit_method.__name__ == "partial_fit" and _is_fitted(estimator)
+ )
+
+ if not global_skip_validation and not partial_fit_and_fitted:
+ estimator._validate_params()
+
+ with config_context(
+ skip_parameter_validation=(
+ prefer_skip_nested_validation or global_skip_validation
+ )
+ ):
+ return fit_method(estimator, *args, **kwargs)
+
+ return wrapper
+
+ return decorator
+
else:
- from sklearn.base import _fit_context
-if sklearn_version < parse_version('1.3'):
+ from sklearn.base import _fit_context # type: ignore[no-redef] # noqa
+
+# TODO: remove when scikit-learn minimum version is 1.3
+if sklearn_version < parse_version("1.3"):
def _is_fitted(estimator, attributes=None, all_or_any=all):
"""Determine if an estimator is fitted
@@ -76,13 +118,33 @@ if sklearn_version < parse_version('1.3'):
fitted : bool
Whether the estimator is fitted.
"""
- pass
+ if attributes is not None:
+ if not isinstance(attributes, (list, tuple)):
+ attributes = [attributes]
+ return all_or_any([hasattr(estimator, attr) for attr in attributes])
+
+ if hasattr(estimator, "__sklearn_is_fitted__"):
+ return estimator.__sklearn_is_fitted__()
+
+ fitted_attrs = [
+ v for v in vars(estimator) if v.endswith("_") and not v.startswith("__")
+ ]
+ return len(fitted_attrs) > 0
+
else:
- from sklearn.utils.validation import _is_fitted
+ from sklearn.utils.validation import _is_fitted # type: ignore[no-redef]
+
try:
from sklearn.utils.validation import _is_pandas_df
except ImportError:
def _is_pandas_df(X):
"""Return True if the X is a pandas dataframe."""
- pass
+ if hasattr(X, "columns") and hasattr(X, "iloc"):
+ # Likely a pandas DataFrame, we explicitly check the type to confirm.
+ try:
+ pd = sys.modules["pandas"]
+ except KeyError:
+ return False
+ return isinstance(X, pd.DataFrame)
+ return False
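The `_mode` shim near the top of this hunk exists because SciPy 1.9 added a `keepdims` argument to `scipy.stats.mode` and newer releases no longer keep the reduced axis by default. A small illustration of the shape being pinned, assuming SciPy >= 1.9:

```python
import numpy as np
import scipy.stats

a = np.array([[1, 2, 2],
              [3, 2, 4]])
# keepdims=True preserves the reduced axis, matching the pre-1.9 output
res = scipy.stats.mode(a, axis=0, keepdims=True)
print(res.mode.shape)  # (1, 3)
print(res.mode)        # [[1 2 2]]
```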
diff --git a/imblearn/utils/testing.py b/imblearn/utils/testing.py
index aa344b1..8c19d61 100644
--- a/imblearn/utils/testing.py
+++ b/imblearn/utils/testing.py
@@ -1,9 +1,15 @@
"""Test utilities."""
+
+# Adapted from scikit-learn
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# License: MIT
+
import inspect
import pkgutil
from importlib import import_module
from operator import itemgetter
from pathlib import Path
+
import numpy as np
from scipy import sparse
from sklearn.base import BaseEstimator
@@ -11,7 +17,9 @@ from sklearn.neighbors import KDTree
from sklearn.utils._testing import ignore_warnings
-def all_estimators(type_filter=None):
+def all_estimators(
+ type_filter=None,
+):
"""Get a list of all estimators from imblearn.
This function crawls the module and gets all classes that inherit
@@ -35,7 +43,73 @@ def all_estimators(type_filter=None):
List of (name, class), where ``name`` is the class name as string
and ``class`` is the actual type of the class.
"""
- pass
+ from ..base import SamplerMixin
+
+ def is_abstract(c):
+ if not (hasattr(c, "__abstractmethods__")):
+ return False
+ if not len(c.__abstractmethods__):
+ return False
+ return True
+
+ all_classes = []
+ modules_to_ignore = {"tests"}
+ root = str(Path(__file__).parent.parent)
+ # Ignore deprecation warnings triggered at import time and from walking
+ # packages
+ with ignore_warnings(category=FutureWarning):
+ for importer, modname, ispkg in pkgutil.walk_packages(
+ path=[root], prefix="imblearn."
+ ):
+ mod_parts = modname.split(".")
+ if any(part in modules_to_ignore for part in mod_parts) or "._" in modname:
+ continue
+ module = import_module(modname)
+ classes = inspect.getmembers(module, inspect.isclass)
+ classes = [
+ (name, est_cls) for name, est_cls in classes if not name.startswith("_")
+ ]
+
+ all_classes.extend(classes)
+
+ all_classes = set(all_classes)
+
+ estimators = [
+ c
+ for c in all_classes
+ if (issubclass(c[1], BaseEstimator) and c[0] != "BaseEstimator")
+ ]
+ # get rid of abstract base classes
+ estimators = [c for c in estimators if not is_abstract(c[1])]
+
+ # get rid of sklearn estimators which have been imported in some classes
+ estimators = [c for c in estimators if "sklearn" not in c[1].__module__]
+
+ if type_filter is not None:
+ if not isinstance(type_filter, list):
+ type_filter = [type_filter]
+ else:
+ type_filter = list(type_filter) # copy
+ filtered_estimators = []
+ filters = {"sampler": SamplerMixin}
+ for name, mixin in filters.items():
+ if name in type_filter:
+ type_filter.remove(name)
+ filtered_estimators.extend(
+ [est for est in estimators if issubclass(est[1], mixin)]
+ )
+ estimators = filtered_estimators
+ if type_filter:
+ raise ValueError(
+ "Parameter type_filter must be 'sampler' or "
+ "None, got"
+ " %s." % repr(type_filter)
+ )
+
+ # drop duplicates, sort for reproducibility
+ # itemgetter is used to ensure the sort does not extend to the 2nd item of
+ # the tuple
+ return sorted(set(estimators), key=itemgetter(0))
class _CustomNearestNeighbors(BaseEstimator):
@@ -44,11 +118,24 @@ class _CustomNearestNeighbors(BaseEstimator):
`kneighbors_graph` is ignored and `metric` does not have any impact.
"""
- def __init__(self, n_neighbors=1, metric='euclidean'):
+ def __init__(self, n_neighbors=1, metric="euclidean"):
self.n_neighbors = n_neighbors
self.metric = metric
- def kneighbors_graph(X=None, n_neighbors=None, mode='connectivity'):
+ def fit(self, X, y=None):
+ X = X.toarray() if sparse.issparse(X) else X
+ self._kd_tree = KDTree(X)
+ return self
+
+ def kneighbors(self, X, n_neighbors=None, return_distance=True):
+ n_neighbors = n_neighbors if n_neighbors is not None else self.n_neighbors
+ X = X.toarray() if sparse.issparse(X) else X
+ distances, indices = self._kd_tree.query(X, k=n_neighbors)
+ if return_distance:
+ return distances, indices
+ return indices
+
+ def kneighbors_graph(self, X=None, n_neighbors=None, mode="connectivity"):
"""This method is not used within imblearn but it is required for
duck-typing."""
pass
@@ -60,3 +147,11 @@ class _CustomClusterer(BaseEstimator):
def __init__(self, n_clusters=1, expose_cluster_centers=True):
self.n_clusters = n_clusters
self.expose_cluster_centers = expose_cluster_centers
+
+ def fit(self, X, y=None):
+ if self.expose_cluster_centers:
+ self.cluster_centers_ = np.random.randn(self.n_clusters, X.shape[1])
+ return self
+
+ def predict(self, X):
+ return np.zeros(len(X), dtype=int)
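A quick sketch of how the discovery helper above is typically driven, for example to enumerate every public sampler exposed by the package:

```python
from imblearn.utils.testing import all_estimators

# (name, class) pairs, deduplicated and sorted by class name
for name, Sampler in all_estimators(type_filter="sampler"):
    print(f"{name}: {Sampler.__module__}")
```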
diff --git a/imblearn/utils/tests/test_deprecation.py b/imblearn/utils/tests/test_deprecation.py
index 2411624..f7e084d 100644
--- a/imblearn/utils/tests/test_deprecation.py
+++ b/imblearn/utils/tests/test_deprecation.py
@@ -1,10 +1,21 @@
"""Test for the deprecation helper"""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# License: MIT
+
import pytest
+
from imblearn.utils.deprecation import deprecate_parameter
class Sampler:
-
def __init__(self):
- self.a = 'something'
- self.b = 'something'
+ self.a = "something"
+ self.b = "something"
+
+
+def test_deprecate_parameter():
+ with pytest.warns(FutureWarning, match="is deprecated from"):
+ deprecate_parameter(Sampler(), "0.2", "a")
+ with pytest.warns(FutureWarning, match="Use 'b' instead."):
+ deprecate_parameter(Sampler(), "0.2", "a", "b")
diff --git a/imblearn/utils/tests/test_docstring.py b/imblearn/utils/tests/test_docstring.py
index f377d75..4a07536 100644
--- a/imblearn/utils/tests/test_docstring.py
+++ b/imblearn/utils/tests/test_docstring.py
@@ -1,7 +1,13 @@
"""Test utilities for docstring."""
+
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# License: MIT
+
import sys
import textwrap
+
import pytest
+
from imblearn.utils import Substitution
from imblearn.utils._docstring import _n_jobs_docstring, _random_state_docstring
@@ -11,7 +17,7 @@ def _dedent_docstring(docstring):
xref: https://github.com/python/cpython/issues/81283
"""
- pass
+ return "\n".join([textwrap.dedent(line) for line in docstring.split("\n")])
func_docstring = """A function.
@@ -33,7 +39,7 @@ def func(param_1, param_2):
{param_2}
"""
- pass
+ return param_1, param_2
cls_docstring = """A class.
@@ -66,10 +72,28 @@ if sys.version_info >= (3, 13):
cls_docstring = _dedent_docstring(cls_docstring)
+@pytest.mark.parametrize(
+ "obj, obj_docstring", [(func, func_docstring), (cls, cls_docstring)]
+)
+def test_docstring_inject(obj, obj_docstring):
+ obj_injected_docstring = Substitution(param_1="xxx", param_2="yyy")(obj)
+ assert obj_injected_docstring.__doc__ == obj_docstring
+
+
+def test_docstring_template():
+ assert "random_state" in _random_state_docstring
+ assert "n_jobs" in _n_jobs_docstring
+
+
def test_docstring_with_python_OO():
"""Check that we don't raise a warning if the code is executed with -OO.
Non-regression test for:
https://github.com/scikit-learn-contrib/imbalanced-learn/issues/945
"""
- pass
+ instance = cls(param_1="xxx", param_2="yyy")
+ instance.__doc__ = None # simulate -OO
+
+ instance = Substitution(param_1="xxx", param_2="yyy")(instance)
+
+ assert instance.__doc__ is None
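The `Substitution` decorator exercised by these tests fills `{placeholders}` in a docstring at decoration time and, per the last test, degrades gracefully when docstrings are stripped with `python -OO`. A minimal sketch using one of the shared templates; `make_something` is a hypothetical function:

```python
from imblearn.utils import Substitution
from imblearn.utils._docstring import _random_state_docstring


@Substitution(random_state=_random_state_docstring)
def make_something(random_state=None):
    """Demo function.

    Parameters
    ----------
    {random_state}
    """

# the placeholder is expanded into the shared description
assert "{random_state}" not in make_something.__doc__
```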
diff --git a/imblearn/utils/tests/test_estimator_checks.py b/imblearn/utils/tests/test_estimator_checks.py
index e93b1c3..ca704f2 100644
--- a/imblearn/utils/tests/test_estimator_checks.py
+++ b/imblearn/utils/tests/test_estimator_checks.py
@@ -2,42 +2,121 @@ import numpy as np
import pytest
from sklearn.base import BaseEstimator
from sklearn.utils.multiclass import check_classification_targets
+
from imblearn.base import BaseSampler
from imblearn.over_sampling.base import BaseOverSampler
from imblearn.utils import check_target_type as target_check
-from imblearn.utils.estimator_checks import check_samplers_fit, check_samplers_nan, check_samplers_one_label, check_samplers_preserve_dtype, check_samplers_sparse, check_samplers_string, check_target_type
+from imblearn.utils.estimator_checks import (
+ check_samplers_fit,
+ check_samplers_nan,
+ check_samplers_one_label,
+ check_samplers_preserve_dtype,
+ check_samplers_sparse,
+ check_samplers_string,
+ check_target_type,
+)
class BaseBadSampler(BaseEstimator):
"""Sampler without inputs checking."""
- _sampling_type = 'bypass'
+
+ _sampling_type = "bypass"
+
+ def fit(self, X, y):
+ return self
+
+ def fit_resample(self, X, y):
+ check_classification_targets(y)
+ self.fit(X, y)
+ return X, y
class SamplerSingleClass(BaseSampler):
"""Sampler that would sample even with a single class."""
- _sampling_type = 'bypass'
+
+ _sampling_type = "bypass"
+
+ def fit_resample(self, X, y):
+ return self._fit_resample(X, y)
+
+ def _fit_resample(self, X, y):
+ return X, y
class NotFittedSampler(BaseBadSampler):
"""Sampler without target checking."""
+ def fit(self, X, y):
+ X, y = self._validate_data(X, y)
+ return self
+
class NoAcceptingSparseSampler(BaseBadSampler):
"""Sampler which does not accept sparse matrix."""
+ def fit(self, X, y):
+ X, y = self._validate_data(X, y)
+ self.sampling_strategy_ = "sampling_strategy_"
+ return self
+
class NotPreservingDtypeSampler(BaseSampler):
- _sampling_type = 'bypass'
- _parameter_constraints: dict = {'sampling_strategy': 'no_validation'}
+ _sampling_type = "bypass"
+
+ _parameter_constraints: dict = {"sampling_strategy": "no_validation"}
+
+ def _fit_resample(self, X, y):
+ return X.astype(np.float64), y.astype(np.int64)
class IndicesSampler(BaseOverSampler):
- pass
+ def _check_X_y(self, X, y):
+ y, binarize_y = target_check(y, indicate_one_vs_all=True)
+ X, y = self._validate_data(
+ X,
+ y,
+ reset=True,
+ dtype=None,
+ force_all_finite=False,
+ )
+ return X, y, binarize_y
+
+ def _fit_resample(self, X, y):
+ n_max_count_class = np.bincount(y).max()
+ indices = np.random.choice(np.arange(X.shape[0]), size=n_max_count_class * 2)
+ return X[indices], y[indices]
+
+
+def test_check_samplers_string():
+ sampler = IndicesSampler()
+ check_samplers_string(sampler.__class__.__name__, sampler)
+
+
+def test_check_samplers_nan():
+ sampler = IndicesSampler()
+ check_samplers_nan(sampler.__class__.__name__, sampler)
+
+
+mapping_estimator_error = {
+ "BaseBadSampler": (AssertionError, "ValueError not raised by fit"),
+ "SamplerSingleClass": (AssertionError, "Sampler can't balance when only"),
+ "NotFittedSampler": (AssertionError, "No fitted attribute"),
+ "NoAcceptingSparseSampler": (TypeError, "dense data is required"),
+ "NotPreservingDtypeSampler": (AssertionError, "X dtype is not preserved"),
+}
+
+
+def _test_single_check(Estimator, check):
+ estimator = Estimator()
+ name = estimator.__class__.__name__
+ err_type, err_msg = mapping_estimator_error[name]
+ with pytest.raises(err_type, match=err_msg):
+ check(name, estimator)
-mapping_estimator_error = {'BaseBadSampler': (AssertionError,
- 'ValueError not raised by fit'), 'SamplerSingleClass': (AssertionError,
- "Sampler can't balance when only"), 'NotFittedSampler': (AssertionError,
- 'No fitted attribute'), 'NoAcceptingSparseSampler': (TypeError,
- 'dense data is required'), 'NotPreservingDtypeSampler': (AssertionError,
- 'X dtype is not preserved')}
+def test_all_checks():
+ _test_single_check(BaseBadSampler, check_target_type)
+ _test_single_check(SamplerSingleClass, check_samplers_one_label)
+ _test_single_check(NotFittedSampler, check_samplers_fit)
+ _test_single_check(NoAcceptingSparseSampler, check_samplers_sparse)
+ _test_single_check(NotPreservingDtypeSampler, check_samplers_preserve_dtype)
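These failure-injection tests call the checks directly with the `(name, estimator)` convention; the same pattern is handy for debugging a real sampler outside pytest. A hedged sketch:

```python
from imblearn.over_sampling import RandomOverSampler
from imblearn.utils.estimator_checks import check_samplers_preserve_dtype

sampler = RandomOverSampler(random_state=0)
# passes silently for a well-behaved sampler, raises AssertionError otherwise
check_samplers_preserve_dtype(type(sampler).__name__, sampler)
```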
diff --git a/imblearn/utils/tests/test_min_dependencies.py b/imblearn/utils/tests/test_min_dependencies.py
index d25700b..cd53703 100644
--- a/imblearn/utils/tests/test_min_dependencies.py
+++ b/imblearn/utils/tests/test_min_dependencies.py
@@ -3,7 +3,49 @@ import os
import platform
import re
from pathlib import Path
+
import pytest
from sklearn.utils.fixes import parse_version
+
import imblearn
from imblearn._min_dependencies import dependent_packages
+
+
+@pytest.mark.skipif(
+ platform.system() == "Windows", reason="Running this test on a unix system is enough"
+)
+def test_min_dependencies_readme():
+ # Test that the minimum dependencies in the README.rst file are
+ # consistent with the minimum dependencies defined at the file:
+ # imblearn/_min_dependencies.py
+
+ pattern = re.compile(
+ r"(\.\. \|)"
+ + r"(([A-Za-z]+\-?)+)"
+ + r"(MinVersion\| replace::)"
+ + r"( [0-9]+\.[0-9]+(\.[0-9]+)?)"
+ )
+
+ readme_path = Path(imblearn.__path__[0]).parents[0]
+ readme_file = readme_path / "README.rst"
+
+ if not os.path.exists(readme_file):
+ # Skip the test if the README.rst file is not available.
+ # For instance, when installing imbalanced-learn from wheels
+ pytest.skip("The README.rst file is not available.")
+
+ with readme_file.open("r") as f:
+ for line in f:
+ matched = pattern.match(line)
+
+ if not matched:
+ continue
+
+ package, version = matched.group(2), matched.group(5)
+ package = package.lower()
+
+ if package in dependent_packages:
+ version = parse_version(version)
+ min_version = parse_version(dependent_packages[package][0])
+
+ assert version == min_version, f"{package} has a mismatched version"
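For reference, the regex above targets README lines such as `.. |NumPyMinVersion| replace:: 1.17.3`: group 2 captures the package name and group 5 the version. A standalone sketch (the package and version are illustrative):

```python
import re

pattern = re.compile(
    r"(\.\. \|)"
    r"(([A-Za-z]+\-?)+)"
    r"(MinVersion\| replace::)"
    r"( [0-9]+\.[0-9]+(\.[0-9]+)?)"
)

matched = pattern.match(".. |NumPyMinVersion| replace:: 1.17.3")
assert matched is not None
print(matched.group(2), matched.group(5).strip())  # NumPy 1.17.3
```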
diff --git a/imblearn/utils/tests/test_param_validation.py b/imblearn/utils/tests/test_param_validation.py
index 8b0709d..38af664 100644
--- a/imblearn/utils/tests/test_param_validation.py
+++ b/imblearn/utils/tests/test_param_validation.py
@@ -2,61 +2,119 @@
removed when we support scikit-learn >= 1.2.
"""
from numbers import Integral, Real
+
import numpy as np
import pytest
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator
from sklearn.model_selection import LeaveOneOut
from sklearn.utils import deprecated
+
from imblearn._config import config_context, get_config
from imblearn.base import _ParamsValidationMixin
-from imblearn.utils._param_validation import HasMethods, Hidden, Interval, InvalidParameterError, MissingValues, Options, RealNotInt, StrOptions, _ArrayLikes, _Booleans, _Callables, _CVObjects, _InstancesOf, _IterablesNotString, _NanConstraint, _NoneConstraint, _PandasNAConstraint, _RandomStates, _SparseMatrices, _VerboseHelper, generate_invalid_param_val, generate_valid_param, make_constraint, validate_params
+from imblearn.utils._param_validation import (
+ HasMethods,
+ Hidden,
+ Interval,
+ InvalidParameterError,
+ MissingValues,
+ Options,
+ RealNotInt,
+ StrOptions,
+ _ArrayLikes,
+ _Booleans,
+ _Callables,
+ _CVObjects,
+ _InstancesOf,
+ _IterablesNotString,
+ _NanConstraint,
+ _NoneConstraint,
+ _PandasNAConstraint,
+ _RandomStates,
+ _SparseMatrices,
+ _VerboseHelper,
+ generate_invalid_param_val,
+ generate_valid_param,
+ make_constraint,
+ validate_params,
+)
from imblearn.utils.fixes import _fit_context
-@validate_params({'a': [Real], 'b': [Real], 'c': [Real], 'd': [Real]},
- prefer_skip_nested_validation=True)
+# Some helpers for the tests
+@validate_params(
+ {"a": [Real], "b": [Real], "c": [Real], "d": [Real]},
+ prefer_skip_nested_validation=True,
+)
def _func(a, b=0, *args, c, d=0, **kwargs):
"""A function to test the validation of functions."""
- pass
class _Class:
"""A class to test the _InstancesOf constraint and the validation of methods."""
- @validate_params({'a': [Real]}, prefer_skip_nested_validation=True)
+ @validate_params({"a": [Real]}, prefer_skip_nested_validation=True)
def _method(self, a):
"""A validated method"""
- pass
@deprecated()
- @validate_params({'a': [Real]}, prefer_skip_nested_validation=True)
+ @validate_params({"a": [Real]}, prefer_skip_nested_validation=True)
def _deprecated_method(self, a):
"""A deprecated validated method"""
- pass
class _Estimator(_ParamsValidationMixin, BaseEstimator):
"""An estimator to test the validation of estimator parameters."""
- _parameter_constraints: dict = {'a': [Real]}
+
+ _parameter_constraints: dict = {"a": [Real]}
def __init__(self, a):
self.a = a
+ @_fit_context(prefer_skip_nested_validation=True)
+ def fit(self, X=None, y=None):
+ pass
+
-@pytest.mark.parametrize('interval_type', [Integral, Real])
+@pytest.mark.parametrize("interval_type", [Integral, Real])
def test_interval_range(interval_type):
"""Check the range of values depending on closed."""
- pass
+ interval = Interval(interval_type, -2, 2, closed="left")
+ assert -2 in interval
+ assert 2 not in interval
+
+ interval = Interval(interval_type, -2, 2, closed="right")
+ assert -2 not in interval
+ assert 2 in interval
+
+ interval = Interval(interval_type, -2, 2, closed="both")
+ assert -2 in interval
+ assert 2 in interval
+ interval = Interval(interval_type, -2, 2, closed="neither")
+ assert -2 not in interval
+ assert 2 not in interval
-@pytest.mark.parametrize('interval_type', [Integral, Real])
+
+@pytest.mark.parametrize("interval_type", [Integral, Real])
def test_interval_large_integers(interval_type):
"""Check that Interval constraint work with large integers.
non-regression test for #26648.
"""
- pass
+ interval = Interval(interval_type, 0, 2, closed="neither")
+ assert 2**65 not in interval
+ assert 2**128 not in interval
+ assert float(2**65) not in interval
+ assert float(2**128) not in interval
+
+ interval = Interval(interval_type, 0, 2**128, closed="neither")
+ assert 2**65 in interval
+ assert 2**128 not in interval
+ assert float(2**65) in interval
+ assert float(2**128) not in interval
+
+ assert 2**1024 not in interval
def test_interval_inf_in_bounds():
@@ -64,235 +122,565 @@ def test_interval_inf_in_bounds():
Only valid for real intervals.
"""
- pass
+ interval = Interval(Real, 0, None, closed="right")
+ assert np.inf in interval
+ interval = Interval(Real, None, 0, closed="left")
+ assert -np.inf in interval
-@pytest.mark.parametrize('interval', [Interval(Real, 0, 1, closed='left'),
- Interval(Real, None, None, closed='both')])
+ interval = Interval(Real, None, None, closed="neither")
+ assert np.inf not in interval
+ assert -np.inf not in interval
+
+
+@pytest.mark.parametrize(
+ "interval",
+ [Interval(Real, 0, 1, closed="left"), Interval(Real, None, None, closed="both")],
+)
def test_nan_not_in_interval(interval):
"""Check that np.nan is not in any interval."""
- pass
-
-
-@pytest.mark.parametrize('params, error, match', [({'type': Integral,
- 'left': 1.0, 'right': 2, 'closed': 'both'}, TypeError,
- 'Expecting left to be an int for an interval over the integers'), ({
- 'type': Integral, 'left': 1, 'right': 2.0, 'closed': 'neither'},
- TypeError,
- 'Expecting right to be an int for an interval over the integers'), ({
- 'type': Integral, 'left': None, 'right': 0, 'closed': 'left'},
- ValueError, "left can't be None when closed == left"), ({'type':
- Integral, 'left': 0, 'right': None, 'closed': 'right'}, ValueError,
- "right can't be None when closed == right"), ({'type': Integral, 'left':
- 1, 'right': -1, 'closed': 'both'}, ValueError,
- "right can't be less than left")])
+ assert np.nan not in interval
+
+
+@pytest.mark.parametrize(
+ "params, error, match",
+ [
+ (
+ {"type": Integral, "left": 1.0, "right": 2, "closed": "both"},
+ TypeError,
+ r"Expecting left to be an int for an interval over the integers",
+ ),
+ (
+ {"type": Integral, "left": 1, "right": 2.0, "closed": "neither"},
+ TypeError,
+ "Expecting right to be an int for an interval over the integers",
+ ),
+ (
+ {"type": Integral, "left": None, "right": 0, "closed": "left"},
+ ValueError,
+ r"left can't be None when closed == left",
+ ),
+ (
+ {"type": Integral, "left": 0, "right": None, "closed": "right"},
+ ValueError,
+ r"right can't be None when closed == right",
+ ),
+ (
+ {"type": Integral, "left": 1, "right": -1, "closed": "both"},
+ ValueError,
+ r"right can't be less than left",
+ ),
+ ],
+)
def test_interval_errors(params, error, match):
"""Check that informative errors are raised for invalid combination of parameters"""
- pass
+ with pytest.raises(error, match=match):
+ Interval(**params)
def test_stroptions():
"""Sanity check for the StrOptions constraint"""
- pass
+ options = StrOptions({"a", "b", "c"}, deprecated={"c"})
+ assert options.is_satisfied_by("a")
+ assert options.is_satisfied_by("c")
+ assert not options.is_satisfied_by("d")
+
+ assert "'c' (deprecated)" in str(options)
def test_options():
"""Sanity check for the Options constraint"""
- pass
+ options = Options(Real, {-0.5, 0.5, np.inf}, deprecated={-0.5})
+ assert options.is_satisfied_by(-0.5)
+ assert options.is_satisfied_by(np.inf)
+ assert not options.is_satisfied_by(1.23)
+
+ assert "-0.5 (deprecated)" in str(options)
-@pytest.mark.parametrize('type, expected_type_name', [(int, 'int'), (
- Integral, 'int'), (Real, 'float'), (np.ndarray, 'numpy.ndarray')])
+@pytest.mark.parametrize(
+ "type, expected_type_name",
+ [
+ (int, "int"),
+ (Integral, "int"),
+ (Real, "float"),
+ (np.ndarray, "numpy.ndarray"),
+ ],
+)
def test_instances_of_type_human_readable(type, expected_type_name):
"""Check the string representation of the _InstancesOf constraint."""
- pass
+ constraint = _InstancesOf(type)
+ assert str(constraint) == f"an instance of '{expected_type_name}'"
def test_hasmethods():
"""Check the HasMethods constraint."""
- pass
+ constraint = HasMethods(["a", "b"])
+
+ class _Good:
+ def a(self):
+ pass # pragma: no cover
+
+ def b(self):
+ pass # pragma: no cover
+
+ class _Bad:
+ def a(self):
+ pass # pragma: no cover
+ assert constraint.is_satisfied_by(_Good())
+ assert not constraint.is_satisfied_by(_Bad())
+ assert str(constraint) == "an object implementing 'a' and 'b'"
-@pytest.mark.parametrize('constraint', [Interval(Real, None, 0, closed=
- 'left'), Interval(Real, 0, None, closed='left'), Interval(Real, None,
- None, closed='neither'), StrOptions({'a', 'b', 'c'}), MissingValues(),
- MissingValues(numeric_only=True), _VerboseHelper(), HasMethods('fit'),
- _IterablesNotString(), _CVObjects()])
+
+@pytest.mark.parametrize(
+ "constraint",
+ [
+ Interval(Real, None, 0, closed="left"),
+ Interval(Real, 0, None, closed="left"),
+ Interval(Real, None, None, closed="neither"),
+ StrOptions({"a", "b", "c"}),
+ MissingValues(),
+ MissingValues(numeric_only=True),
+ _VerboseHelper(),
+ HasMethods("fit"),
+ _IterablesNotString(),
+ _CVObjects(),
+ ],
+)
def test_generate_invalid_param_val(constraint):
"""Check that the value generated does not satisfy the constraint"""
- pass
-
-
-@pytest.mark.parametrize('integer_interval, real_interval', [(Interval(
- Integral, None, 3, closed='right'), Interval(RealNotInt, -5, 5, closed=
- 'both')), (Interval(Integral, None, 3, closed='right'), Interval(
- RealNotInt, -5, 5, closed='neither')), (Interval(Integral, None, 3,
- closed='right'), Interval(RealNotInt, 4, 5, closed='both')), (Interval(
- Integral, None, 3, closed='right'), Interval(RealNotInt, 5, None,
- closed='left')), (Interval(Integral, None, 3, closed='right'), Interval
- (RealNotInt, 4, None, closed='neither')), (Interval(Integral, 3, None,
- closed='left'), Interval(RealNotInt, -5, 5, closed='both')), (Interval(
- Integral, 3, None, closed='left'), Interval(RealNotInt, -5, 5, closed=
- 'neither')), (Interval(Integral, 3, None, closed='left'), Interval(
- RealNotInt, 1, 2, closed='both')), (Interval(Integral, 3, None, closed=
- 'left'), Interval(RealNotInt, None, -5, closed='left')), (Interval(
- Integral, 3, None, closed='left'), Interval(RealNotInt, None, -4,
- closed='neither')), (Interval(Integral, -5, 5, closed='both'), Interval
- (RealNotInt, None, 1, closed='right')), (Interval(Integral, -5, 5,
- closed='both'), Interval(RealNotInt, 1, None, closed='left')), (
- Interval(Integral, -5, 5, closed='both'), Interval(RealNotInt, -10, -4,
- closed='neither')), (Interval(Integral, -5, 5, closed='both'), Interval
- (RealNotInt, -10, -4, closed='right')), (Interval(Integral, -5, 5,
- closed='neither'), Interval(RealNotInt, 6, 10, closed='neither')), (
- Interval(Integral, -5, 5, closed='neither'), Interval(RealNotInt, 6, 10,
- closed='left')), (Interval(Integral, 2, None, closed='left'), Interval(
- RealNotInt, 0, 1, closed='both')), (Interval(Integral, 1, None, closed=
- 'left'), Interval(RealNotInt, 0, 1, closed='both'))])
-def test_generate_invalid_param_val_2_intervals(integer_interval, real_interval
- ):
+ bad_value = generate_invalid_param_val(constraint)
+ assert not constraint.is_satisfied_by(bad_value)
+
+
+@pytest.mark.parametrize(
+ "integer_interval, real_interval",
+ [
+ (
+ Interval(Integral, None, 3, closed="right"),
+ Interval(RealNotInt, -5, 5, closed="both"),
+ ),
+ (
+ Interval(Integral, None, 3, closed="right"),
+ Interval(RealNotInt, -5, 5, closed="neither"),
+ ),
+ (
+ Interval(Integral, None, 3, closed="right"),
+ Interval(RealNotInt, 4, 5, closed="both"),
+ ),
+ (
+ Interval(Integral, None, 3, closed="right"),
+ Interval(RealNotInt, 5, None, closed="left"),
+ ),
+ (
+ Interval(Integral, None, 3, closed="right"),
+ Interval(RealNotInt, 4, None, closed="neither"),
+ ),
+ (
+ Interval(Integral, 3, None, closed="left"),
+ Interval(RealNotInt, -5, 5, closed="both"),
+ ),
+ (
+ Interval(Integral, 3, None, closed="left"),
+ Interval(RealNotInt, -5, 5, closed="neither"),
+ ),
+ (
+ Interval(Integral, 3, None, closed="left"),
+ Interval(RealNotInt, 1, 2, closed="both"),
+ ),
+ (
+ Interval(Integral, 3, None, closed="left"),
+ Interval(RealNotInt, None, -5, closed="left"),
+ ),
+ (
+ Interval(Integral, 3, None, closed="left"),
+ Interval(RealNotInt, None, -4, closed="neither"),
+ ),
+ (
+ Interval(Integral, -5, 5, closed="both"),
+ Interval(RealNotInt, None, 1, closed="right"),
+ ),
+ (
+ Interval(Integral, -5, 5, closed="both"),
+ Interval(RealNotInt, 1, None, closed="left"),
+ ),
+ (
+ Interval(Integral, -5, 5, closed="both"),
+ Interval(RealNotInt, -10, -4, closed="neither"),
+ ),
+ (
+ Interval(Integral, -5, 5, closed="both"),
+ Interval(RealNotInt, -10, -4, closed="right"),
+ ),
+ (
+ Interval(Integral, -5, 5, closed="neither"),
+ Interval(RealNotInt, 6, 10, closed="neither"),
+ ),
+ (
+ Interval(Integral, -5, 5, closed="neither"),
+ Interval(RealNotInt, 6, 10, closed="left"),
+ ),
+ (
+ Interval(Integral, 2, None, closed="left"),
+ Interval(RealNotInt, 0, 1, closed="both"),
+ ),
+ (
+ Interval(Integral, 1, None, closed="left"),
+ Interval(RealNotInt, 0, 1, closed="both"),
+ ),
+ ],
+)
+def test_generate_invalid_param_val_2_intervals(integer_interval, real_interval):
"""Check that the value generated for an interval constraint does not satisfy any of
the interval constraints.
"""
- pass
+ bad_value = generate_invalid_param_val(constraint=real_interval)
+ assert not real_interval.is_satisfied_by(bad_value)
+ assert not integer_interval.is_satisfied_by(bad_value)
+ bad_value = generate_invalid_param_val(constraint=integer_interval)
+ assert not real_interval.is_satisfied_by(bad_value)
+ assert not integer_interval.is_satisfied_by(bad_value)
-@pytest.mark.parametrize('constraint', [_ArrayLikes(), _InstancesOf(list),
- _Callables(), _NoneConstraint(), _RandomStates(), _SparseMatrices(),
- _Booleans(), Interval(Integral, None, None, closed='neither')])
+
+@pytest.mark.parametrize(
+ "constraint",
+ [
+ _ArrayLikes(),
+ _InstancesOf(list),
+ _Callables(),
+ _NoneConstraint(),
+ _RandomStates(),
+ _SparseMatrices(),
+ _Booleans(),
+ Interval(Integral, None, None, closed="neither"),
+ ],
+)
def test_generate_invalid_param_val_all_valid(constraint):
"""Check that the function raises NotImplementedError when there's no invalid value
for the constraint.
"""
- pass
-
-
-@pytest.mark.parametrize('constraint', [_ArrayLikes(), _Callables(),
- _InstancesOf(list), _NoneConstraint(), _RandomStates(), _SparseMatrices
- (), _Booleans(), _VerboseHelper(), MissingValues(), MissingValues(
- numeric_only=True), StrOptions({'a', 'b', 'c'}), Options(Integral, {1,
- 2, 3}), Interval(Integral, None, None, closed='neither'), Interval(
- Integral, 0, 10, closed='neither'), Interval(Integral, 0, None, closed=
- 'neither'), Interval(Integral, None, 0, closed='neither'), Interval(
- Real, 0, 1, closed='neither'), Interval(Real, 0, None, closed='both'),
- Interval(Real, None, 0, closed='right'), HasMethods('fit'),
- _IterablesNotString(), _CVObjects()])
+ with pytest.raises(NotImplementedError):
+ generate_invalid_param_val(constraint)
+
+
+@pytest.mark.parametrize(
+ "constraint",
+ [
+ _ArrayLikes(),
+ _Callables(),
+ _InstancesOf(list),
+ _NoneConstraint(),
+ _RandomStates(),
+ _SparseMatrices(),
+ _Booleans(),
+ _VerboseHelper(),
+ MissingValues(),
+ MissingValues(numeric_only=True),
+ StrOptions({"a", "b", "c"}),
+ Options(Integral, {1, 2, 3}),
+ Interval(Integral, None, None, closed="neither"),
+ Interval(Integral, 0, 10, closed="neither"),
+ Interval(Integral, 0, None, closed="neither"),
+ Interval(Integral, None, 0, closed="neither"),
+ Interval(Real, 0, 1, closed="neither"),
+ Interval(Real, 0, None, closed="both"),
+ Interval(Real, None, 0, closed="right"),
+ HasMethods("fit"),
+ _IterablesNotString(),
+ _CVObjects(),
+ ],
+)
def test_generate_valid_param(constraint):
"""Check that the value generated does satisfy the constraint."""
- pass
-
-
-@pytest.mark.parametrize('constraint_declaration, value', [(Interval(Real,
- 0, 1, closed='both'), 0.42), (Interval(Integral, 0, None, closed=
- 'neither'), 42), (StrOptions({'a', 'b', 'c'}), 'b'), (Options(type, {np
- .float32, np.float64}), np.float64), (callable, lambda x: x + 1), (None,
- None), ('array-like', [[1, 2], [3, 4]]), ('array-like', np.array([[1, 2
- ], [3, 4]])), ('sparse matrix', csr_matrix([[1, 2], [3, 4]])), (
- 'random_state', 0), ('random_state', np.random.RandomState(0)), (
- 'random_state', None), (_Class, _Class()), (int, 1), (Real, 0.5), (
- 'boolean', False), ('verbose', 1), ('nan', np.nan), (MissingValues(), -
- 1), (MissingValues(), -1.0), (MissingValues(), 2 ** 1028), (
- MissingValues(), None), (MissingValues(), float('nan')), (MissingValues
- (), np.nan), (MissingValues(), 'missing'), (HasMethods('fit'),
- _Estimator(a=0)), ('cv_object', 5)])
+ value = generate_valid_param(constraint)
+ assert constraint.is_satisfied_by(value)
+
+
+@pytest.mark.parametrize(
+ "constraint_declaration, value",
+ [
+ (Interval(Real, 0, 1, closed="both"), 0.42),
+ (Interval(Integral, 0, None, closed="neither"), 42),
+ (StrOptions({"a", "b", "c"}), "b"),
+ (Options(type, {np.float32, np.float64}), np.float64),
+ (callable, lambda x: x + 1),
+ (None, None),
+ ("array-like", [[1, 2], [3, 4]]),
+ ("array-like", np.array([[1, 2], [3, 4]])),
+ ("sparse matrix", csr_matrix([[1, 2], [3, 4]])),
+ ("random_state", 0),
+ ("random_state", np.random.RandomState(0)),
+ ("random_state", None),
+ (_Class, _Class()),
+ (int, 1),
+ (Real, 0.5),
+ ("boolean", False),
+ ("verbose", 1),
+ ("nan", np.nan),
+ (MissingValues(), -1),
+ (MissingValues(), -1.0),
+ (MissingValues(), 2**1028),
+ (MissingValues(), None),
+ (MissingValues(), float("nan")),
+ (MissingValues(), np.nan),
+ (MissingValues(), "missing"),
+ (HasMethods("fit"), _Estimator(a=0)),
+ ("cv_object", 5),
+ ],
+)
def test_is_satisfied_by(constraint_declaration, value):
"""Sanity check for the is_satisfied_by method"""
- pass
-
-
-@pytest.mark.parametrize('constraint_declaration, expected_constraint_class',
- [(Interval(Real, 0, 1, closed='both'), Interval), (StrOptions({
- 'option1', 'option2'}), StrOptions), (Options(Real, {0.42, 1.23}),
- Options), ('array-like', _ArrayLikes), ('sparse matrix',
- _SparseMatrices), ('random_state', _RandomStates), (None,
- _NoneConstraint), (callable, _Callables), (int, _InstancesOf), (
- 'boolean', _Booleans), ('verbose', _VerboseHelper), (MissingValues(
- numeric_only=True), MissingValues), (HasMethods('fit'), HasMethods), (
- 'cv_object', _CVObjects), ('nan', _NanConstraint)])
+ constraint = make_constraint(constraint_declaration)
+ assert constraint.is_satisfied_by(value)
+
+
+@pytest.mark.parametrize(
+ "constraint_declaration, expected_constraint_class",
+ [
+ (Interval(Real, 0, 1, closed="both"), Interval),
+ (StrOptions({"option1", "option2"}), StrOptions),
+ (Options(Real, {0.42, 1.23}), Options),
+ ("array-like", _ArrayLikes),
+ ("sparse matrix", _SparseMatrices),
+ ("random_state", _RandomStates),
+ (None, _NoneConstraint),
+ (callable, _Callables),
+ (int, _InstancesOf),
+ ("boolean", _Booleans),
+ ("verbose", _VerboseHelper),
+ (MissingValues(numeric_only=True), MissingValues),
+ (HasMethods("fit"), HasMethods),
+ ("cv_object", _CVObjects),
+ ("nan", _NanConstraint),
+ ],
+)
def test_make_constraint(constraint_declaration, expected_constraint_class):
"""Check that make_constraint dispatches to the appropriate constraint class"""
- pass
+ constraint = make_constraint(constraint_declaration)
+ assert constraint.__class__ is expected_constraint_class
def test_make_constraint_unknown():
"""Check that an informative error is raised when an unknown constraint is passed"""
- pass
+ with pytest.raises(ValueError, match="Unknown constraint"):
+ make_constraint("not a valid constraint")
def test_validate_params():
"""Check that validate_params works no matter how the arguments are passed"""
- pass
+ with pytest.raises(
+ InvalidParameterError, match="The 'a' parameter of _func must be"
+ ):
+ _func("wrong", c=1)
+
+ with pytest.raises(
+ InvalidParameterError, match="The 'b' parameter of _func must be"
+ ):
+ _func(*[1, "wrong"], c=1)
+
+ with pytest.raises(
+ InvalidParameterError, match="The 'c' parameter of _func must be"
+ ):
+ _func(1, **{"c": "wrong"})
+
+ with pytest.raises(
+ InvalidParameterError, match="The 'd' parameter of _func must be"
+ ):
+ _func(1, c=1, d="wrong")
+
+ # check in the presence of extra positional and keyword args
+ with pytest.raises(
+ InvalidParameterError, match="The 'b' parameter of _func must be"
+ ):
+ _func(0, *["wrong", 2, 3], c=4, **{"e": 5})
+
+ with pytest.raises(
+ InvalidParameterError, match="The 'c' parameter of _func must be"
+ ):
+ _func(0, *[1, 2, 3], c="four", **{"e": 5})
def test_validate_params_missing_params():
"""Check that no error is raised when there are parameters without
constraints
"""
- pass
+
+ @validate_params({"a": [int]}, prefer_skip_nested_validation=True)
+ def func(a, b):
+ pass
+
+ func(1, 2)
def test_decorate_validated_function():
"""Check that validate_params functions can be decorated"""
- pass
+ decorated_function = deprecated()(_func)
+
+ with pytest.warns(FutureWarning, match="Function _func is deprecated"):
+ decorated_function(1, 2, c=3)
+
+ # outer decorator does not interfere with validation
+ with pytest.warns(FutureWarning, match="Function _func is deprecated"):
+ with pytest.raises(
+ InvalidParameterError, match=r"The 'c' parameter of _func must be"
+ ):
+ decorated_function(1, 2, c="wrong")
def test_validate_params_method():
"""Check that validate_params works with methods"""
- pass
+ with pytest.raises(
+ InvalidParameterError, match="The 'a' parameter of _Class._method must be"
+ ):
+ _Class()._method("wrong")
+
+ # validated method can be decorated
+ with pytest.warns(FutureWarning, match="Function _deprecated_method is deprecated"):
+ with pytest.raises(
+ InvalidParameterError,
+ match="The 'a' parameter of _Class._deprecated_method must be",
+ ):
+ _Class()._deprecated_method("wrong")
def test_validate_params_estimator():
"""Check that validate_params works with Estimator instances"""
- pass
+ # no validation in init
+ est = _Estimator("wrong")
+
+ with pytest.raises(
+ InvalidParameterError, match="The 'a' parameter of _Estimator must be"
+ ):
+ est.fit()
def test_stroptions_deprecated_subset():
"""Check that the deprecated parameter must be a subset of options."""
- pass
+ with pytest.raises(ValueError, match="deprecated options must be a subset"):
+ StrOptions({"a", "b", "c"}, deprecated={"a", "d"})
def test_hidden_constraint():
"""Check that internal constraints are not exposed in the error message."""
- pass
+
+ @validate_params(
+ {"param": [Hidden(list), dict]}, prefer_skip_nested_validation=True
+ )
+ def f(param):
+ pass
+
+ # list and dict are valid params
+ f({"a": 1, "b": 2, "c": 3})
+ f([1, 2, 3])
+
+ with pytest.raises(
+ InvalidParameterError, match="The 'param' parameter"
+ ) as exc_info:
+ f(param="bad")
+
+ # the list option is not exposed in the error message
+ err_msg = str(exc_info.value)
+ assert "an instance of 'dict'" in err_msg
+ assert "an instance of 'list'" not in err_msg
def test_hidden_stroptions():
"""Check that we can have 2 StrOptions constraints, one being hidden."""
- pass
+
+ @validate_params(
+ {"param": [StrOptions({"auto"}), Hidden(StrOptions({"warn"}))]},
+ prefer_skip_nested_validation=True,
+ )
+ def f(param):
+ pass
+
+ # "auto" and "warn" are valid params
+ f("auto")
+ f("warn")
+
+ with pytest.raises(
+ InvalidParameterError, match="The 'param' parameter"
+ ) as exc_info:
+ f(param="bad")
+
+ # the "warn" option is not exposed in the error message
+ err_msg = str(exc_info.value)
+ assert "auto" in err_msg
+ assert "warn" not in err_msg
def test_validate_params_set_param_constraints_attribute():
"""Check that the validate_params decorator properly sets the parameter constraints
as attribute of the decorated function/method.
"""
- pass
+ assert hasattr(_func, "_skl_parameter_constraints")
+ assert hasattr(_Class()._method, "_skl_parameter_constraints")
def test_boolean_constraint_deprecated_int():
"""Check that validate_params raise a deprecation message but still passes
validation when using an int for a parameter accepting a boolean.
"""
- pass
+
+ @validate_params({"param": ["boolean"]}, prefer_skip_nested_validation=True)
+ def f(param):
+ pass
+
+ # True/False and np.bool_(True/False) are valid params
+ f(True)
+ f(np.bool_(False))
def test_no_validation():
"""Check that validation can be skipped for a parameter."""
- pass
+
+ @validate_params(
+ {"param1": [int, None], "param2": "no_validation"},
+ prefer_skip_nested_validation=True,
+ )
+ def f(param1=None, param2=None):
+ pass
+
+ # param1 is validated
+ with pytest.raises(InvalidParameterError, match="The 'param1' parameter"):
+ f(param1="wrong")
+
+ # param2 is not validated: any type is valid.
+ class SomeType:
+ pass
+
+ f(param2=SomeType)
+ f(param2=SomeType())
def test_pandas_na_constraint_with_pd_na():
"""Add a specific test for checking support for `pandas.NA`."""
- pass
+ pd = pytest.importorskip("pandas")
+
+ na_constraint = _PandasNAConstraint()
+ assert na_constraint.is_satisfied_by(pd.NA)
+ assert not na_constraint.is_satisfied_by(np.array([1, 2, 3]))
def test_iterable_not_string():
"""Check that a string does not satisfy the _IterableNotString constraint."""
- pass
+ constraint = _IterablesNotString()
+ assert constraint.is_satisfied_by([1, 2, 3])
+ assert constraint.is_satisfied_by(range(10))
+ assert not constraint.is_satisfied_by("some string")
def test_cv_objects():
"""Check that the _CVObjects constraint accepts all current ways
to pass cv objects."""
- pass
+ constraint = _CVObjects()
+ assert constraint.is_satisfied_by(5)
+ assert constraint.is_satisfied_by(LeaveOneOut())
+ assert constraint.is_satisfied_by([([1, 2], [3, 4]), ([3, 4], [1, 2])])
+ assert constraint.is_satisfied_by(None)
+ assert not constraint.is_satisfied_by("not a CV object")
def test_third_party_estimator():
@@ -300,35 +688,98 @@ def test_third_party_estimator():
"""Check that the validation from a scikit-learn estimator inherited by a third
party estimator does not impose a match between the dict of constraints and the
parameters of the estimator.
"""
- pass
+
+ class ThirdPartyEstimator(_Estimator):
+ def __init__(self, b):
+ self.b = b
+ super().__init__(a=0)
+
+ def fit(self, X=None, y=None):
+ super().fit(X, y)
+
+ # does not raise, even though "b" is not in the constraints dict and "a" is not
+ # a parameter of the estimator.
+ ThirdPartyEstimator(b=0).fit()
def test_interval_real_not_int():
"""Check for the type RealNotInt in the Interval constraint."""
- pass
+ constraint = Interval(RealNotInt, 0, 1, closed="both")
+ assert constraint.is_satisfied_by(1.0)
+ assert not constraint.is_satisfied_by(1)
def test_real_not_int():
"""Check for the RealNotInt type."""
- pass
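+ # RealNotInt matches Real values that are not Integral: plain and numpy
+ # floats qualify, while Python and numpy ints do not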
+ assert isinstance(1.0, RealNotInt)
+ assert not isinstance(1, RealNotInt)
+ assert isinstance(np.float64(1), RealNotInt)
+ assert not isinstance(np.int64(1), RealNotInt)
def test_skip_param_validation():
"""Check that param validation can be skipped using config_context."""
- pass
+ @validate_params({"a": [int]}, prefer_skip_nested_validation=True)
+ def f(a):
+ pass
+
+ with pytest.raises(InvalidParameterError, match="The 'a' parameter"):
+ f(a="1")
+
+ # does not raise
+ with config_context(skip_parameter_validation=True):
+ f(a="1")
-@pytest.mark.parametrize('prefer_skip_nested_validation', [True, False])
+
+
+@pytest.mark.parametrize("prefer_skip_nested_validation", [True, False])
def test_skip_nested_validation(prefer_skip_nested_validation):
"""Check that nested validation can be skipped."""
- pass
+
+ @validate_params({"a": [int]}, prefer_skip_nested_validation=True)
+ def f(a):
+ pass
+
+ @validate_params(
+ {"b": [int]},
+ prefer_skip_nested_validation=prefer_skip_nested_validation,
+ )
+ def g(b):
+ # calls f with a bad parameter type
+ return f(a="invalid_param_value")
+
+ # Validation for g is never skipped.
+ with pytest.raises(InvalidParameterError, match="The 'b' parameter"):
+ g(b="invalid_param_value")
+
+ if prefer_skip_nested_validation:
+ g(b=1) # does not raise because inner f is not validated
+ else:
+ with pytest.raises(InvalidParameterError, match="The 'a' parameter"):
+ g(b=1)
@pytest.mark.parametrize(
- 'skip_parameter_validation, prefer_skip_nested_validation, expected_skipped'
- , [(True, True, True), (True, False, True), (False, True, True), (False,
- False, False)])
-def test_skip_nested_validation_and_config_context(skip_parameter_validation,
- prefer_skip_nested_validation, expected_skipped):
+ "skip_parameter_validation, prefer_skip_nested_validation, expected_skipped",
+ [
+ (True, True, True),
+ (True, False, True),
+ (False, True, True),
+ (False, False, False),
+ ],
+)
+def test_skip_nested_validation_and_config_context(
+ skip_parameter_validation, prefer_skip_nested_validation, expected_skipped
+):
"""Check interaction between global skip and local skip."""
- pass
+
+ @validate_params(
+ {"a": [int]}, prefer_skip_nested_validation=prefer_skip_nested_validation
+ )
+ def g(a):
+ return get_config()["skip_parameter_validation"]
+
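+ # validation inside g is skipped when it is disabled globally via the config
+ # or when g itself prefers to skip nested validation (logical OR, matching
+ # the expected_skipped truth table above)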
+ with config_context(skip_parameter_validation=skip_parameter_validation):
+ actual_skipped = g(1)
+
+ assert actual_skipped == expected_skipped
diff --git a/imblearn/utils/tests/test_show_versions.py b/imblearn/utils/tests/test_show_versions.py
index ca6a29e..1b43053 100644
--- a/imblearn/utils/tests/test_show_versions.py
+++ b/imblearn/utils/tests/test_show_versions.py
@@ -1,2 +1,60 @@
"""Test for the show_versions helper. Based on the sklearn tests."""
+# Author: Alexander L. Hayes <hayesall@iu.edu>
+# License: MIT
+
from imblearn.utils._show_versions import _get_deps_info, show_versions
+
+
+def test_get_deps_info():
+ _deps_info = _get_deps_info()
+ assert "pip" in _deps_info
+ assert "setuptools" in _deps_info
+ assert "imbalanced-learn" in _deps_info
+ assert "scikit-learn" in _deps_info
+ assert "numpy" in _deps_info
+ assert "scipy" in _deps_info
+ assert "Cython" in _deps_info
+ assert "pandas" in _deps_info
+ assert "joblib" in _deps_info
+
+
+def test_show_versions_default(capsys):
+ show_versions()
+ out, err = capsys.readouterr()
+ assert "python" in out
+ assert "executable" in out
+ assert "machine" in out
+ assert "pip" in out
+ assert "setuptools" in out
+ assert "imbalanced-learn" in out
+ assert "scikit-learn" in out
+ assert "numpy" in out
+ assert "scipy" in out
+ assert "Cython" in out
+ assert "pandas" in out
+ assert "keras" in out
+ assert "tensorflow" in out
+ assert "joblib" in out
+
+
+def test_show_versions_github(capsys):
+ show_versions(github=True)
+ out, err = capsys.readouterr()
+ assert "<details><summary>System, Dependency Information</summary>" in out
+ assert "**System Information**" in out
+ assert "* python" in out
+ assert "* executable" in out
+ assert "* machine" in out
+ assert "**Python Dependencies**" in out
+ assert "* pip" in out
+ assert "* setuptools" in out
+ assert "* imbalanced-learn" in out
+ assert "* scikit-learn" in out
+ assert "* numpy" in out
+ assert "* scipy" in out
+ assert "* Cython" in out
+ assert "* pandas" in out
+ assert "* keras" in out
+ assert "* tensorflow" in out
+ assert "* joblib" in out
+ assert "</details>" in out
diff --git a/imblearn/utils/tests/test_testing.py b/imblearn/utils/tests/test_testing.py
index 421be2b..1b37978 100644
--- a/imblearn/utils/tests/test_testing.py
+++ b/imblearn/utils/tests/test_testing.py
@@ -1,12 +1,49 @@
"""Test for the testing module"""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
import numpy as np
import pytest
from sklearn.neighbors._base import KNeighborsMixin
+
from imblearn.base import SamplerMixin
from imblearn.utils.testing import _CustomNearestNeighbors, all_estimators
+
+
+def test_all_estimators():
+ # check if the filtering is working with a list or a single string
+ type_filter = "sampler"
+ all_estimators(type_filter=type_filter)
+ type_filter = ["sampler"]
+ estimators = all_estimators(type_filter=type_filter)
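+ # all_estimators returns (name, class) tuples, hence the indexing below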
+ for estimator in estimators:
+ # check that all returned estimators are samplers
+ assert issubclass(estimator[1], SamplerMixin)
+
+ # check that an error is raised when the type is unknown
+ type_filter = "rnd"
+ with pytest.raises(ValueError, match="Parameter type_filter must be 'sampler'"):
+ all_estimators(type_filter=type_filter)
+
+
def test_custom_nearest_neighbors():
"""Check that our custom nearest neighbors can be used for our internal
duck-typing."""
- pass
+
+ nearest_neighbors = _CustomNearestNeighbors(n_neighbors=3)
+
+ assert not isinstance(nearest_neighbors, KNeighborsMixin)
+ assert hasattr(nearest_neighbors, "kneighbors")
+ assert hasattr(nearest_neighbors, "kneighbors_graph")
+
+ rng = np.random.RandomState(42)
+ X = rng.randn(150, 3)
+ y = rng.randint(0, 2, 150)
+ nearest_neighbors.fit(X, y)
+
+ distances, indices = nearest_neighbors.kneighbors(X)
+ assert distances.shape == (150, 3)
+ assert indices.shape == (150, 3)
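+ # querying with the training set itself: each sample's nearest neighbor is
+ # itself, i.e. zero distance and a matching index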
+ np.testing.assert_allclose(distances[:, 0], 0.0)
+ np.testing.assert_allclose(indices[:, 0], np.arange(150))
diff --git a/imblearn/utils/tests/test_validation.py b/imblearn/utils/tests/test_validation.py
index b7e2a02..4394f04 100644
--- a/imblearn/utils/tests/test_validation.py
+++ b/imblearn/utils/tests/test_validation.py
@@ -1,13 +1,382 @@
"""Test for the validation helper"""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+# Christos Aridas
+# License: MIT
+
from collections import Counter, OrderedDict
+
import numpy as np
import pytest
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors._base import KNeighborsMixin
from sklearn.utils._testing import assert_array_equal
-from imblearn.utils import check_neighbors_object, check_sampling_strategy, check_target_type
-from imblearn.utils._validation import ArraysTransformer, _deprecate_positional_args, _is_neighbors_object
+
+from imblearn.utils import (
+ check_neighbors_object,
+ check_sampling_strategy,
+ check_target_type,
+)
+from imblearn.utils._validation import (
+ ArraysTransformer,
+ _deprecate_positional_args,
+ _is_neighbors_object,
+)
from imblearn.utils.testing import _CustomNearestNeighbors
+
multiclass_target = np.array([1] * 50 + [2] * 100 + [3] * 25)
binary_target = np.array([1] * 25 + [0] * 100)
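+# class counts: multiclass_target -> {1: 50, 2: 100, 3: 25},
+# binary_target -> {0: 100, 1: 25}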
+
+
+def test_check_neighbors_object():
+ name = "n_neighbors"
+ n_neighbors = 1
+ estimator = check_neighbors_object(name, n_neighbors)
+ assert issubclass(type(estimator), KNeighborsMixin)
+ assert estimator.n_neighbors == 1
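+ # the optional third argument requests that many additional neighbors,
+ # hence n_neighbors becomes 1 + 1 == 2 below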
+ estimator = check_neighbors_object(name, n_neighbors, 1)
+ assert issubclass(type(estimator), KNeighborsMixin)
+ assert estimator.n_neighbors == 2
+ estimator = NearestNeighbors(n_neighbors=n_neighbors)
+ estimator_cloned = check_neighbors_object(name, estimator)
+ assert estimator.n_neighbors == estimator_cloned.n_neighbors
+ estimator = _CustomNearestNeighbors()
+ estimator_cloned = check_neighbors_object(name, estimator)
+ assert isinstance(estimator_cloned, _CustomNearestNeighbors)
+
+
+@pytest.mark.parametrize(
+ "target, output_target",
+ [
+ (np.array([0, 1, 1]), np.array([0, 1, 1])),
+ (np.array([0, 1, 2]), np.array([0, 1, 2])),
+ (np.array([[0, 1], [1, 0]]), np.array([1, 0])),
+ ],
+)
+def test_check_target_type(target, output_target):
+ converted_target = check_target_type(target.astype(int))
+ assert_array_equal(converted_target, output_target.astype(int))
+
+
+@pytest.mark.parametrize(
+ "target, output_target, is_ova",
+ [
+ (np.array([0, 1, 1]), np.array([0, 1, 1]), False),
+ (np.array([0, 1, 2]), np.array([0, 1, 2]), False),
+ (np.array([[0, 1], [1, 0]]), np.array([1, 0]), True),
+ ],
+)
+def test_check_target_type_ova(target, output_target, is_ova):
+ converted_target, binarize_target = check_target_type(
+ target.astype(int), indicate_one_vs_all=True
+ )
+ assert_array_equal(converted_target, output_target.astype(int))
+ assert binarize_target == is_ova
+
+
+def test_check_sampling_strategy_warning():
+ msg = "dict for cleaning methods is not supported"
+ with pytest.raises(ValueError, match=msg):
+ check_sampling_strategy({1: 0, 2: 0, 3: 0}, multiclass_target, "clean-sampling")
+
+
+@pytest.mark.parametrize(
+ "ratio, y, type, err_msg",
+ [
+ (
+ 0.5,
+ binary_target,
+ "clean-sampling",
+ "'clean-sampling' methods do let the user specify the sampling ratio", # noqa
+ ),
+ (
+ 0.1,
+ np.array([0] * 10 + [1] * 20),
+ "over-sampling",
+ "remove samples from the minority class while trying to generate new", # noqa
+ ),
+ (
+ 0.1,
+ np.array([0] * 10 + [1] * 20),
+ "under-sampling",
+ "generate new sample in the majority class while trying to remove",
+ ),
+ ],
+)
+def test_check_sampling_strategy_float_error(ratio, y, type, err_msg):
+ with pytest.raises(ValueError, match=err_msg):
+ check_sampling_strategy(ratio, y, type)
+
+
+def test_check_sampling_strategy_error():
+ with pytest.raises(ValueError, match="'sampling_type' should be one of"):
+ check_sampling_strategy("auto", np.array([1, 2, 3]), "rnd")
+
+ error_regex = "The target 'y' needs to have more than 1 class."
+ with pytest.raises(ValueError, match=error_regex):
+ check_sampling_strategy("auto", np.ones((10,)), "over-sampling")
+
+ error_regex = "When 'sampling_strategy' is a string, it needs to be one of"
+ with pytest.raises(ValueError, match=error_regex):
+ check_sampling_strategy("rnd", np.array([1, 2, 3]), "over-sampling")
+
+
+@pytest.mark.parametrize(
+ "sampling_strategy, sampling_type, err_msg",
+ [
+ ("majority", "over-sampling", "over-sampler"),
+ ("minority", "under-sampling", "under-sampler"),
+ ],
+)
+def test_check_sampling_strategy_error_wrong_string(
+ sampling_strategy, sampling_type, err_msg
+):
+ with pytest.raises(
+ ValueError,
+ match=("'{}' cannot be used with {}".format(sampling_strategy, err_msg)),
+ ):
+ check_sampling_strategy(sampling_strategy, np.array([1, 2, 3]), sampling_type)
+
+
+@pytest.mark.parametrize(
+ "sampling_strategy, sampling_method",
+ [
+ ({10: 10}, "under-sampling"),
+ ({10: 10}, "over-sampling"),
+ ([10], "clean-sampling"),
+ ],
+)
+def test_sampling_strategy_class_target_unknown(sampling_strategy, sampling_method):
+ y = np.array([1] * 50 + [2] * 100 + [3] * 25)
+ with pytest.raises(ValueError, match="are not present in the data."):
+ check_sampling_strategy(sampling_strategy, y, sampling_method)
+
+
+def test_sampling_strategy_dict_error():
+ y = np.array([1] * 50 + [2] * 100 + [3] * 25)
+ sampling_strategy = {1: -100, 2: 50, 3: 25}
+ with pytest.raises(ValueError, match="in a class cannot be negative."):
+ check_sampling_strategy(sampling_strategy, y, "under-sampling")
+ sampling_strategy = {1: 45, 2: 100, 3: 70}
+ error_regex = (
+ "With over-sampling methods, the number of samples in a"
+ " class should be greater or equal to the original number"
+ " of samples. Originally, there is 50 samples and 45"
+ " samples are asked."
+ )
+ with pytest.raises(ValueError, match=error_regex):
+ check_sampling_strategy(sampling_strategy, y, "over-sampling")
+
+ error_regex = (
+ "With under-sampling methods, the number of samples in a"
+ " class should be less or equal to the original number of"
+ " samples. Originally, there is 25 samples and 70 samples"
+ " are asked."
+ )
+ with pytest.raises(ValueError, match=error_regex):
+ check_sampling_strategy(sampling_strategy, y, "under-sampling")
+
+
+@pytest.mark.parametrize("sampling_strategy", [-10, 10])
+def test_sampling_strategy_float_error_not_in_range(sampling_strategy):
+ y = np.array([1] * 50 + [2] * 100)
+ with pytest.raises(ValueError, match="it should be in the range"):
+ check_sampling_strategy(sampling_strategy, y, "under-sampling")
+
+
+def test_sampling_strategy_float_error_not_binary():
+ y = np.array([1] * 50 + [2] * 100 + [3] * 25)
+ with pytest.raises(ValueError, match="the type of target is binary"):
+ sampling_strategy = 0.5
+ check_sampling_strategy(sampling_strategy, y, "under-sampling")
+
+
+@pytest.mark.parametrize("sampling_method", ["over-sampling", "under-sampling"])
+def test_sampling_strategy_list_error_not_clean_sampling(sampling_method):
+ y = np.array([1] * 50 + [2] * 100 + [3] * 25)
+ with pytest.raises(ValueError, match="cannot be a list for samplers"):
+ sampling_strategy = [1, 2, 3]
+ check_sampling_strategy(sampling_strategy, y, sampling_method)
+
+
+def _sampling_strategy_func(y):
+ # this helper equalizes every class count to the majority class count
+ target_stats = Counter(y)
+ n_samples = max(target_stats.values())
+ return {key: int(n_samples) for key in target_stats.keys()}
+
+
+@pytest.mark.parametrize(
+ "sampling_strategy, sampling_type, expected_sampling_strategy, target",
+ [
+ ("auto", "under-sampling", {1: 25, 2: 25}, multiclass_target),
+ ("auto", "clean-sampling", {1: 25, 2: 25}, multiclass_target),
+ ("auto", "over-sampling", {1: 50, 3: 75}, multiclass_target),
+ ("all", "over-sampling", {1: 50, 2: 0, 3: 75}, multiclass_target),
+ ("all", "under-sampling", {1: 25, 2: 25, 3: 25}, multiclass_target),
+ ("all", "clean-sampling", {1: 25, 2: 25, 3: 25}, multiclass_target),
+ ("majority", "under-sampling", {2: 25}, multiclass_target),
+ ("majority", "clean-sampling", {2: 25}, multiclass_target),
+ ("minority", "over-sampling", {3: 75}, multiclass_target),
+ ("not minority", "over-sampling", {1: 50, 2: 0}, multiclass_target),
+ ("not minority", "under-sampling", {1: 25, 2: 25}, multiclass_target),
+ ("not minority", "clean-sampling", {1: 25, 2: 25}, multiclass_target),
+ ("not majority", "over-sampling", {1: 50, 3: 75}, multiclass_target),
+ ("not majority", "under-sampling", {1: 25, 3: 25}, multiclass_target),
+ ("not majority", "clean-sampling", {1: 25, 3: 25}, multiclass_target),
+ (
+ {1: 70, 2: 100, 3: 70},
+ "over-sampling",
+ {1: 20, 2: 0, 3: 45},
+ multiclass_target,
+ ),
+ (
+ {1: 30, 2: 45, 3: 25},
+ "under-sampling",
+ {1: 30, 2: 45, 3: 25},
+ multiclass_target,
+ ),
+ ([1], "clean-sampling", {1: 25}, multiclass_target),
+ (
+ _sampling_strategy_func,
+ "over-sampling",
+ {1: 50, 2: 0, 3: 75},
+ multiclass_target,
+ ),
+ (0.5, "over-sampling", {1: 25}, binary_target),
+ (0.5, "under-sampling", {0: 50}, binary_target),
+ ],
+)
+def test_check_sampling_strategy(
+ sampling_strategy, sampling_type, expected_sampling_strategy, target
+):
+ sampling_strategy_ = check_sampling_strategy(
+ sampling_strategy, target, sampling_type
+ )
+ assert sampling_strategy_ == expected_sampling_strategy
+
+
+def test_sampling_strategy_callable_args():
+ y = np.array([1] * 50 + [2] * 100 + [3] * 25)
+ multiplier = {1: 1.5, 2: 1, 3: 3}
+
+ def sampling_strategy_func(y, multiplier):
+ """samples such that each class will be affected by the multiplier."""
+ target_stats = Counter(y)
+ return {
+ key: int(values * multiplier[key]) for key, values in target_stats.items()
+ }
+
+ sampling_strategy_ = check_sampling_strategy(
+ sampling_strategy_func, y, "over-sampling", multiplier=multiplier
+ )
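+ # over-sampling strategies report the number of samples to generate:
+ # 1: 50 * 1.5 - 50 = 25, 2: 100 * 1 - 100 = 0, 3: 25 * 3 - 25 = 50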
+ assert sampling_strategy_ == {1: 25, 2: 0, 3: 50}
+
+
+@pytest.mark.parametrize(
+ "sampling_strategy, sampling_type, expected_result",
+ [
+ (
+ {3: 25, 1: 25, 2: 25},
+ "under-sampling",
+ OrderedDict({1: 25, 2: 25, 3: 25}),
+ ),
+ (
+ {3: 100, 1: 100, 2: 100},
+ "over-sampling",
+ OrderedDict({1: 50, 2: 0, 3: 75}),
+ ),
+ ],
+)
+def test_sampling_strategy_check_order(
+ sampling_strategy, sampling_type, expected_result
+):
+ # We purposely pass an unsorted dictionary and check that the resulting
+ # dictionary is sorted by class label. Refer to issue #428.
+ y = np.array([1] * 50 + [2] * 100 + [3] * 25)
+ sampling_strategy_ = check_sampling_strategy(sampling_strategy, y, sampling_type)
+ assert sampling_strategy_ == expected_result
+
+
+def test_arrays_transformer_plain_list():
+ X = np.array([[0, 0], [1, 1]])
+ y = np.array([[0, 0], [1, 1]])
+
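+ # ArraysTransformer records the container type of the arrays given at
+ # construction and converts the transformed arrays back to that type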
+ arrays_transformer = ArraysTransformer(X.tolist(), y.tolist())
+ X_res, y_res = arrays_transformer.transform(X, y)
+ assert isinstance(X_res, list)
+ assert isinstance(y_res, list)
+
+
+def test_arrays_transformer_numpy():
+ X = np.array([[0, 0], [1, 1]])
+ y = np.array([[0, 0], [1, 1]])
+
+ arrays_transformer = ArraysTransformer(X, y)
+ X_res, y_res = arrays_transformer.transform(X, y)
+ assert isinstance(X_res, np.ndarray)
+ assert isinstance(y_res, np.ndarray)
+
+
+def test_arrays_transformer_pandas():
+ pd = pytest.importorskip("pandas")
+
+ X = np.array([[0, 0], [1, 1]])
+ y = np.array([0, 1])
+
+ X_df = pd.DataFrame(X, columns=["a", "b"])
+ X_df = X_df.astype(int)
+ y_df = pd.DataFrame(y, columns=["target"])
+ y_df = y_df.astype(int)
+ y_s = pd.Series(y, name="target", dtype=int)
+
+ # DataFrame and DataFrame case
+ arrays_transformer = ArraysTransformer(X_df, y_df)
+ X_res, y_res = arrays_transformer.transform(X, y)
+ assert isinstance(X_res, pd.DataFrame)
+ assert_array_equal(X_res.columns, X_df.columns)
+ assert_array_equal(X_res.dtypes, X_df.dtypes)
+ assert isinstance(y_res, pd.DataFrame)
+ assert_array_equal(y_res.columns, y_df.columns)
+ assert_array_equal(y_res.dtypes, y_df.dtypes)
+
+ # DataFrames and Series case
+ arrays_transformer = ArraysTransformer(X_df, y_s)
+ _, y_res = arrays_transformer.transform(X, y)
+ assert isinstance(y_res, pd.Series)
+ assert_array_equal(y_res.name, y_s.name)
+ assert_array_equal(y_res.dtype, y_s.dtype)
+
+
+def test_deprecate_positional_args_warns_for_function():
+ @_deprecate_positional_args
+ def f1(a, b, *, c=1, d=1):
+ pass
+
+ with pytest.warns(FutureWarning, match=r"Pass c=3 as keyword args"):
+ f1(1, 2, 3)
+
+ with pytest.warns(FutureWarning, match=r"Pass c=3, d=4 as keyword args"):
+ f1(1, 2, 3, 4)
+
+ @_deprecate_positional_args
+ def f2(a=1, *, b=1, c=1, d=1):
+ pass
+
+ with pytest.warns(FutureWarning, match=r"Pass b=2 as keyword args"):
+ f2(1, 2)
+
+ # The * is placed before a keyword-only argument without a default value
+ @_deprecate_positional_args
+ def f3(a, *, b, c=1, d=1):
+ pass
+
+ with pytest.warns(FutureWarning, match=r"Pass b=2 as keyword args"):
+ f3(1, 2)
+
+
+@pytest.mark.parametrize(
+ "estimator, is_neighbor_estimator", [(NearestNeighbors(), True), (KMeans(), False)]
+)
+def test_is_neighbors_object(estimator, is_neighbor_estimator):
+ assert _is_neighbors_object(estimator) == is_neighbor_estimator