# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import pickle
import re
import warnings

import numpy as np
import pytest
import scipy.sparse as sp
from numpy.testing import assert_allclose

import sklearn
from sklearn import config_context, datasets
from sklearn.base import (
    BaseEstimator,
    OutlierMixin,
    TransformerMixin,
    clone,
    is_classifier,
    is_clusterer,
    is_outlier_detector,
    is_regressor,
)
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.ensemble import IsolationForest
from sklearn.exceptions import InconsistentVersionWarning
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC, SVR
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.utils._mocking import MockDataFrame
from sklearn.utils._set_output import _get_output_config
from sklearn.utils._testing import (
    _convert_container,
    assert_array_equal,
)
from sklearn.utils.validation import _check_n_features, validate_data


#############################################################################
# A few test classes
class MyEstimator(BaseEstimator):
    """Minimal two-parameter estimator used as a fixture by the tests."""

    def __init__(self, l1=0, empty=None):
        # Constructor arguments are stored unchanged under their own names.
        self.l1, self.empty = l1, empty


class K(BaseEstimator):
    """Fixture estimator with parameters ``c`` and ``d``."""

    def __init__(self, c=None, d=None):
        # Constructor arguments are stored unchanged under their own names.
        self.c, self.d = c, d


class T(BaseEstimator):
    """Fixture estimator with parameters ``a`` and ``b``."""

    def __init__(self, a=None, b=None):
        # Constructor arguments are stored unchanged under their own names.
        self.a, self.b = a, b


class NaNTag(BaseEstimator):
    """Estimator whose input tags are set to ``allow_nan=True``."""

    def __sklearn_tags__(self):
        # Start from the parent's tags, then switch the flag on.
        inherited = super().__sklearn_tags__()
        inherited.input_tags.allow_nan = True
        return inherited


class NoNaNTag(BaseEstimator):
    """Estimator whose input tags are set to ``allow_nan=False``."""

    def __sklearn_tags__(self):
        # Start from the parent's tags, then switch the flag off.
        inherited = super().__sklearn_tags__()
        inherited.input_tags.allow_nan = False
        return inherited


class OverrideTag(NaNTag):
    """Subclass of ``NaNTag`` that resets ``allow_nan`` to ``False``."""

    def __sklearn_tags__(self):
        # ``NaNTag`` sets the flag to True; this override flips it back.
        inherited = super().__sklearn_tags__()
        inherited.input_tags.allow_nan = False
        return inherited


class DiamondOverwriteTag(NaNTag, NoNaNTag):
    """Diamond-shaped subclass of ``NaNTag`` and ``NoNaNTag`` with no overrides.

    ``NaNTag`` is listed first, so its ``__sklearn_tags__`` is the outermost
    call in the ``super()`` chain and assigns ``allow_nan`` last.
    """


class InheritDiamondOverwriteTag(DiamondOverwriteTag):
    """Plain subclass of ``DiamondOverwriteTag`` that adds nothing of its own."""


class ModifyInitParams(BaseEstimator):
    """Deprecated behavior.

    The stored parameter is equal to the constructor argument but is a copy
    of it, so the attribute doesn't fulfill ``self.a is a``.
    """

    def __init__(self, a=np.array([0])):
        duplicate = a.copy()
        self.a = duplicate


class Buggy(BaseEstimator):
    """A buggy estimator that does not set its parameters right."""

    def __init__(self, a=None):
        # The ``a`` argument is deliberately ignored: a constant is stored.
        constant = 1
        self.a = constant


class NoEstimator:
    """Plain object with ``fit``/``predict`` that does not derive from BaseEstimator."""

    def __init__(self):
        """Take no parameters and store no state."""

    def fit(self, X=None, y=None):
        """Ignore the arguments and return the instance itself."""
        return self

    def predict(self, X=None):
        """Ignore the argument and return ``None``."""
        return


class VargEstimator(BaseEstimator):
    """scikit-learn estimators shouldn't have vargs."""

    def __init__(self, *vargs):
        """Accept arbitrary positional arguments and store none of them."""


#############################################################################
# The tests


def test_clone():
    """Check that ``clone`` returns a distinct estimator with equal parameters."""
    from sklearn.feature_selection import SelectFpr, f_classif

    # Scalar parameter: the clone is a new object carrying the same params.
    original = SelectFpr(f_classif, alpha=0.1)
    duplicate = clone(original)
    assert original is not duplicate
    assert original.get_params() == duplicate.get_params()

    # Array-valued parameter: cloning still yields a new object.
    original = SelectFpr(f_classif, alpha=np.zeros((10, 2)))
    duplicate = clone(original)
    assert original is not duplicate


def test_clone_2():
    """Check that ``clone`` does not carry over manually added attributes."""
    from sklearn.feature_selection import SelectFpr, f_classif

    source = SelectFpr(f_classif, alpha=0.1)
    # Attach an attribute that is not one of the constructor arguments used above.
    setattr(source, "own_attribute", "test")

    copied = clone(source)
    assert not hasattr(copied, "own_attribute")


def test_clone_buggy():
    """Check that ``clone`` raises an error on buggy estimators."""
    buggy = Buggy()
    buggy.a = 2

    # Each faulty object is paired with the exception ``clone`` must raise.
    faulty_cases = [
        (buggy, RuntimeError),
        (NoEstimator(), TypeError),
        (VargEstimator(), RuntimeError),
        (ModifyInitParams(), RuntimeError),
    ]
    for faulty, expected_error in faulty_cases:
        with pytest.raises(expected_error):
            clone(faulty)


def test_clone_empty_array():
    """Regression test for cloning estimators with empty arrays."""
    # Dense case: an empty ndarray parameter.
    dense_est = MyEstimator(empty=np.array([]))
    dense_copy = clone(dense_est)
    assert_array_equal(dense_est.empty, dense_copy.empty)

    # Sparse case: a CSR matrix built from a single zero entry.
    sparse_est = MyEstimator(empty=sp.csr_matrix(np.array([[0]])))
    sparse_copy = clone(sparse_est)
    assert_array_equal(sparse_est.empty.data, sparse_copy.empty.data)


def test_clone_nan():
    """Regression test for cloning estimators with ``np.nan`` as a parameter."""
    original = MyEstimator(empty=np.nan)
    duplicate = clone(original)

    # The clone must hold the very same ``nan`` object.
    assert original.empty is duplicate.empty


def test_clone_dict():
    """Check that cloning a dict of estimators clones the contained estimator."""
    source = {"a": MyEstimator()}
    copied = clone(source)
    assert source["a"] is not copied["a"]


def test_clone_sparse_matrices():
    """Check that cloning preserves the class and content of sparse parameters."""
    # Collect every attribute of ``scipy.sparse`` named ``*_matrix`` that is a
    # plain class (its type is exactly ``type``).
    matrix_types = []
    for attr_name in dir(sp):
        if not attr_name.endswith("_matrix"):
            continue
        candidate = getattr(sp, attr_name)
        if type(candidate) is type:
            matrix_types.append(candidate)

    for matrix_type in matrix_types:
        est = MyEstimator(empty=matrix_type(np.eye(5)))
        est_copy = clone(est)
        assert est.empty.__class__ is est_copy.empty.__class__
        assert_array_equal(est.empty.toarray(), est_copy.empty.toarray())


def test_clone_estimator_types():
    """Check that ``clone`` works for parameters that are types, not instances."""
    original = MyEstimator(empty=MyEstimator)
    duplicate = clone(original)

    # The class object itself is carried over, not copied.
    assert original.empty is duplicate.empty


def test_clone_class_rather_than_instance():
    """Check the error message raised when cloning a class, not an instance."""
    expected_message = "You should provide an instance of scikit-learn estimator"
    with pytest.raises(TypeError, match=expected_message):
        clone(MyEstimator)


def test_repr():
    """Smoke test the repr of the base estimator."""
    repr(MyEstimator())

    nested = T(K(), K())
    assert repr(nested) == "T(a=K(), b=K())"

    # A 1000-item list parameter still gives a repr of fixed, bounded length.
    bulky = T(a=["long_params"] * 1000)
    assert len(repr(bulky)) == 485


def test_str():
    """Smoke test the str of the base estimator."""
    str(MyEstimator())


def test_get_params():
    """Check deep/shallow ``get_params`` and nested ``set_params`` on ``T``."""
    # ``b`` is deliberately the class ``K`` itself, not an instance.
    outer = T(K(), K)

    deep_params = outer.get_params(deep=True)
    assert "a__d" in deep_params
    shallow_params = outer.get_params(deep=False)
    assert "a__d" not in shallow_params

    outer.set_params(a__d=2)
    assert outer.a.d == 2

    # ``K`` has no parameter called ``a``.
    with pytest.raises(ValueError):
        outer.set_params(a__a=2)


# TODO(1.8): Remove this test when the deprecation is removed
def test_is_estimator_type_class():
    """Check that passing a class to the ``is_*`` helpers warns and still works."""
    deprecation_pattern = "passing a class to.*is deprecated"
    predicate_and_class = [
        (is_classifier, SVC),
        (is_regressor, SVR),
        (is_clusterer, KMeans),
        (is_outlier_detector, IsolationForest),
    ]
    for predicate, estimator_class in predicate_and_class:
        with pytest.warns(FutureWarning, match=deprecation_pattern):
            assert predicate(estimator_class)


@pytest.mark.parametrize(
    "estimator, expected_result",
    [
        (SVC(), True),
        (GridSearchCV(SVC(), {"C": [0.1, 1]}), True),
        (Pipeline([("svc", SVC())]), True),
        (Pipeline([("svc_cv", GridSearchCV(SVC(), {"C": [0.1, 1]}))]), True),
        (SVR(), False),
        (GridSearchCV(SVR(), {"C": [0.1, 1]}), False),
        (Pipeline([("svr", SVR())]), False),
        (Pipeline([("svr_cv", GridSearchCV(SVR(), {"C": [0.1, 1]}))]), False),
    ],
)
def test_is_classifier(estimator, expected_result):
    """Check ``is_classifier`` on plain, grid-searched and pipelined estimators."""
    verdict = is_classifier(estimator)
    assert verdict == expected_result


@pytest.mark.parametrize(
    "estimator, expected_result",
    [
        (SVR(), True),
        (GridSearchCV(SVR(), {"C": [0.1, 1]}), True),
        (Pipeline([("svr", SVR())]), True),
        (Pipeline([("svr_cv", GridSearchCV(SVR(), {"C": [0.1, 1]}))]), True),
        (SVC(), False),
        (GridSearchCV(SVC(), {"C": [0.1, 1]}), False),
        (Pipeline([("svc", SVC())]), False),
        (Pipeline([("svc_cv", GridSearchCV(SVC(), {"C": [0.1, 1]}))]), False),
    ],
)
def test_is_regressor(estimator, expected_result):
    """Check ``is_regressor`` on plain, grid-searched and pipelined estimators."""
    verdict = is_regressor(estimator)
    assert verdict == expected_result


@pytest.mark.parametrize(
    "estimator, expected_result",
    [
        (KMeans(), True),
        (GridSearchCV(KMeans(), {"n_clusters": [3, 8]}), True),
        (Pipeline([("km", KMeans())]), True),
        (Pipeline([("km_cv", GridSearchCV(KMeans(), {"n_clusters": [3, 8]}))]), True),
        (SVC(), False),
        (GridSearchCV(SVC(), {"C": [0.1, 1]}), False),
        (Pipeline([("svc", SVC())]), False),
        (Pipeline([("svc_cv", GridSearchCV(SVC(), {"C": [0.1, 1]}))]), False),
    ],
)
def test_is_clusterer(estimator, expected_result):
    """Check ``is_clusterer`` on plain, grid-searched and pipelined estimators."""
    verdict = is_clusterer(estimator)
    assert verdict == expected_result


def test_set_params():
    """Check that setting an unknown nested parameter raises ``ValueError``."""
    pipe = Pipeline([("svc", SVC())])

    invalid_kwargs = (
        # non-existing parameter in svc
        {"svc__stupid_param": True},
        # non-existing parameter of pipeline
        {"svm__stupid_param": True},
    )
    for bad_kwargs in invalid_kwargs:
        with pytest.raises(ValueError):
            pipe.set_params(**bad_kwargs)

    # we don't currently catch if the things in pipeline are estimators
    # bad_pipeline = Pipeline([("bad", NoEstimator())])
    # with pytest.raises(AttributeError):
    #    bad_pipeline.set_params(bad__stupid_param=True)


def test_set_params_passes_all_parameters():
    """Check that nested parameters reach the sub-estimator in a single call.

    Regression test for #9944.
    """
    expected_kwargs = {"max_depth": 5, "min_samples_leaf": 2}

    class TestDecisionTree(DecisionTreeClassifier):
        def set_params(self, **kwargs):
            super().set_params(**kwargs)
            # ``expected_kwargs`` comes from the enclosing test scope; every
            # nested parameter must arrive together in this one call.
            assert kwargs == expected_kwargs
            return self

    wrappers = (
        Pipeline([("estimator", TestDecisionTree())]),
        GridSearchCV(TestDecisionTree(), {}),
    )
    for wrapper in wrappers:
        wrapper.set_params(estimator__max_depth=5, estimator__min_samples_leaf=2)


def test_set_params_updates_valid_params():
    """Check that ``set_params`` sets ``C`` on the newly supplied ``SVC``.

    The nested parameter must be applied to the replacement estimator, not to
    the ``DecisionTreeClassifier`` the search was built with.
    """
    search = GridSearchCV(DecisionTreeClassifier(), {})
    search.set_params(estimator=SVC(), estimator__C=42.0)
    assert search.estimator.C == 42.0


@pytest.mark.parametrize(
    "tree,dataset",
    [
        (
            DecisionTreeClassifier(max_depth=2, random_state=0),
            datasets.make_classification(random_state=0),
        ),
        (
            DecisionTreeRegressor(max_depth=2, random_state=0),
            datasets.make_regression(random_state=0),
        ),
    ],
)
def test_score_sample_weight(tree, dataset):
    """Check that the score with and without sample weights are different."""
    rng = np.random.RandomState(0)
    X, y = dataset
    tree.fit(X, y)

    # Random integer weights in [1, 10), one per sample.
    weights = rng.randint(1, 10, size=len(y))
    plain_score = tree.score(X, y)
    weighted_score = tree.score(X, y, sample_weight=weights)

    assert (
        plain_score != weighted_score
    ), "Unweighted and weighted scores are unexpectedly equal"


def test_clone_pandas_dataframe():
    """Check that an estimator holding a dataframe-like parameter can be cloned."""

    class DummyEstimator(TransformerMixin, BaseEstimator):
        """Dummy transformer that only stores its constructor parameters.

        Parameters
        ----------
        df : dataframe-like
            The data frame parameter.

        scalar_param : scalar, default=1
            An additional scalar parameter.
        """

        def __init__(self, df=None, scalar_param=1):
            self.df, self.scalar_param = df, scalar_param

        def fit(self, X, y=None):
            return None

        def transform(self, X):
            return None

    # build and clone estimator
    frame = MockDataFrame(np.arange(10))
    original = DummyEstimator(frame, scalar_param=1)
    duplicate = clone(original)

    # the test
    frames_match = original.df == duplicate.df
    assert frames_match.values.all()
    assert original.scalar_param == duplicate.scalar_param


def test_clone_protocol():
    """Checks that clone works with `__sklearn_clone__` protocol."""

    class FrozenEstimator(BaseEstimator):
        """Wrapper that forwards attribute access to an already fitted estimator."""

        def __init__(self, fitted_estimator):
            self.fitted_estimator = fitted_estimator

        def __getattr__(self, name):
            # Python only calls this for names not found on the wrapper itself.
            wrapped = self.fitted_estimator
            return getattr(wrapped, name)

        def __sklearn_clone__(self):
            return self

        def fit(self, *args, **kwargs):
            return self

        def fit_transform(self, *args, **kwargs):
            return self.fitted_estimator.transform(*args, **kwargs)

    X_fit = np.array([[-1, -1], [-2, -1], [-3, -2]])
    fitted_pca = PCA().fit(X_fit)
    reference = fitted_pca.components_

    frozen = FrozenEstimator(fitted_pca)
    assert_allclose(frozen.components_, reference)

    # Calling PCA methods such as `get_feature_names_out` still works
    frozen_names = frozen.get_feature_names_out()
    assert_array_equal(frozen_names, fitted_pca.get_feature_names_out())

    # Fitting on a new data does not alter `components_`
    X_other = np.asarray([[-1, 2], [3, 4], [1, 2]])
    frozen.fit(X_other)
    assert_allclose(frozen.components_, reference)

    # `fit_transform` does not alter state
    frozen.fit_transform(X_other)
    assert_allclose(frozen.components_, reference)

    # Cloning estimator is a no-op
    frozen_clone = clone(frozen)
    assert frozen_clone is frozen
    assert_allclose(frozen_clone.components_, reference)


def test_pickle_version_warning_is_not_raised_with_matching_version():
    """Check that a same-version pickle round-trip is silent and keeps the score."""
    iris = datasets.load_iris()
    X, y = iris.data, iris.target
    fitted = DecisionTreeClassifier().fit(X, y)
    payload = pickle.dumps(fitted)
    assert b"_sklearn_version" in payload

    # Any warning emitted while unpickling is turned into an error.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        restored = pickle.loads(payload)

    # test that we can predict with the restored decision tree classifier
    original_score = fitted.score(X, y)
    restored_score = restored.score(X, y)
    assert original_score == restored_score


class TreeBadVersion(DecisionTreeClassifier):
    """Decision tree whose pickled state reports the version ``"something"``."""

    def __getstate__(self):
        # Instance attributes plus a (possibly overriding) fake version entry.
        return {**self.__dict__, "_sklearn_version": "something"}


# Template of the warning text the tests expect when an estimator is unpickled
# under a version different from the one recorded in its pickled state.
pickle_error_message = (
    "Trying to unpickle estimator {estimator} "
    "from version {old_version} "
    "when using version {current_version}. "
    "This might lead to breaking code or invalid results. "
    "Use at your own risk."
)


def test_pickle_version_warning_is_issued_upon_different_version():
    """Check the warning raised when unpickling from a different version.

    ``TreeBadVersion`` records the version ``"something"`` in its pickled
    state, so loading it must emit an ``InconsistentVersionWarning`` whose
    text and attributes describe the mismatch.
    """
    iris = datasets.load_iris()
    tree = TreeBadVersion().fit(iris.data, iris.target)
    tree_pickle_other = pickle.dumps(tree)
    expected_text = pickle_error_message.format(
        estimator="TreeBadVersion",
        old_version="something",
        current_version=sklearn.__version__,
    )
    # ``match`` is interpreted as a regular expression: escape the text so that
    # metacharacters in the version string (e.g. ``.`` or ``+``) match literally.
    with pytest.warns(UserWarning, match=re.escape(expected_text)) as warning_record:
        pickle.loads(tree_pickle_other)

    # Inspect the first recorded warning object, not just its text.
    warning = warning_record.list[0].message
    assert isinstance(warning, InconsistentVersionWarning)
    assert warning.estimator_name == "TreeBadVersion"
    assert warning.original_sklearn_version == "something"
    assert warning.current_sklearn_version == sklearn.__version__


class TreeNoVersion(DecisionTreeClassifier):
    """Decision tree whose pickled state is its bare instance ``__dict__``."""

    def __getstate__(self):
        # Return the instance dict as is, without adding anything to it.
        return vars(self)


def test_pickle_version_warning_is_issued_when_no_version_info_in_pickle():
    """Check the warning raised when the pickle carries no version information."""
    iris = datasets.load_iris()
    # TreeNoVersion's ``__getstate__`` returns the bare instance dict, so the
    # pickle holds no ``_sklearn_version`` entry, like a pre-0.18 pickle.
    tree = TreeNoVersion().fit(iris.data, iris.target)

    tree_pickle_noversion = pickle.dumps(tree)
    assert b"_sklearn_version" not in tree_pickle_noversion
    expected_text = pickle_error_message.format(
        estimator="TreeNoVersion",
        old_version="pre-0.18",
        current_version=sklearn.__version__,
    )
    # check we got the warning about using pre-0.18 pickle; ``match`` is a
    # regular expression, so the text is escaped to be matched literally.
    with pytest.warns(UserWarning, match=re.escape(expected_text)):
        pickle.loads(tree_pickle_noversion)


def test_pickle_version_no_warning_is_issued_with_non_sklearn_estimator():
    """Check that unpickling is silent when the class module is not sklearn's."""
    iris = datasets.load_iris()
    fitted = TreeNoVersion().fit(iris.data, iris.target)
    payload = pickle.dumps(fitted)
    try:
        original_module = TreeNoVersion.__module__
        # Temporarily relabel the class as living in another module.
        TreeNoVersion.__module__ = "notsklearn"

        # Any warning emitted while unpickling is turned into an error.
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            pickle.loads(payload)
    finally:
        TreeNoVersion.__module__ = original_module


class DontPickleAttributeMixin:
    """Mixin that blanks one attribute when pickling and flags restored objects."""

    def __getstate__(self):
        # Shallow copy of the instance dict with the private attribute nulled;
        # the instance itself is left untouched.
        return {**self.__dict__, "_attribute_not_pickled": None}

    def __setstate__(self, state):
        # The incoming dict is modified in place before being applied.
        state.update(_restored=True)
        vars(self).update(state)


class MultiInheritanceEstimator(DontPickleAttributeMixin, BaseEstimator):
    """Estimator that takes its pickling hooks from ``DontPickleAttributeMixin``."""

    def __init__(self, attribute_pickled=5):
        # One public parameter plus a private attribute that starts as None.
        self.attribute_pickled, self._attribute_not_pickled = attribute_pickled, None


def test_pickling_when_getstate_is_overwritten_by_mixin():
    """Check a pickle round-trip when a mixin provides the state hooks."""
    original = MultiInheritanceEstimator()
    original._attribute_not_pickled = "this attribute should not be pickled"

    restored = pickle.loads(pickle.dumps(original))

    assert restored.attribute_pickled == 5
    assert restored._attribute_not_pickled is None
    assert restored._restored


def test_pickling_when_getstate_is_overwritten_by_mixin_outside_of_sklearn():
    """Check the mixin's state hooks when the class module is not sklearn's.

    The class ``__module__`` is temporarily replaced by ``"notsklearn"`` and
    restored afterwards, whatever the outcome of the assertions.
    """
    # Set up before ``try``: everything the ``finally`` clause needs is then
    # always bound, so a setup failure is reported as is instead of being
    # masked by a NameError raised during cleanup.
    estimator = MultiInheritanceEstimator()
    estimator._attribute_not_pickled = "this attribute should not be pickled"
    estimator_class = type(estimator)
    old_mod = estimator_class.__module__
    try:
        estimator_class.__module__ = "notsklearn"

        serialized = estimator.__getstate__()
        assert serialized == {"_attribute_not_pickled": None, "attribute_pickled": 5}

        # Feeding a modified state back must update the estimator and set the
        # ``_restored`` flag.
        serialized["attribute_pickled"] = 4
        estimator.__setstate__(serialized)
        assert estimator.attribute_pickled == 4
        assert estimator._restored
    finally:
        estimator_class.__module__ = old_mod


class SingleInheritanceEstimator(BaseEstimator):
    """Estimator that overrides ``__getstate__`` directly to blank an attribute."""

    def __init__(self, attribute_pickled=5):
        # One public parameter plus a private attribute that starts as None.
        self.attribute_pickled, self._attribute_not_pickled = attribute_pickled, None

    def __getstate__(self):
        # Take the parent's state and null the private attribute in it.
        parent_state = super().__getstate__()
        parent_state.update(_attribute_not_pickled=None)
        return parent_state


def test_pickling_works_when_getstate_is_overwritten_in_the_child_class():
    """Check a pickle round-trip when the subclass overrides ``__getstate__``."""
    original = SingleInheritanceEstimator()
    original._attribute_not_pickled = "this attribute should not be pickled"

    restored = pickle.loads(pickle.dumps(original))

    assert restored.attribute_pickled == 5
    assert restored._attribute_not_pickled is None


def test_tag_inheritance():
    """Check the ``allow_nan`` input tag along several inheritance patterns."""

    def allows_nan(estimator):
        return estimator.__sklearn_tags__().input_tags.allow_nan

    # Direct definitions.
    assert allows_nan(NaNTag())
    assert not allows_nan(NoNaNTag())

    # A subclass redefining the tag gets its own value.
    assert not allows_nan(OverrideTag())

    # Diamond inheritance and a further plain subclass of it.
    assert allows_nan(DiamondOverwriteTag())
    assert allows_nan(InheritDiamondOverwriteTag())


def test_raises_on_get_params_non_attribute():
    """Check that ``get_params`` fails when ``__init__`` does not store a param."""

    class MyEstimator(BaseEstimator):
        def __init__(self, param=5):
            """Accept ``param`` but deliberately do not store it."""

        def fit(self, X, y=None):
            return self

    expected_message = "'MyEstimator' object has no attribute 'param'"
    estimator = MyEstimator()

    with pytest.raises(AttributeError, match=expected_message):
        estimator.get_params()


def test_repr_mimebundle_():
    """Check that the display configuration controls the mimebundle keys."""
    estimator = DecisionTreeClassifier()

    # Default configuration: both representations are present.
    bundle = estimator._repr_mimebundle_()
    assert "text/plain" in bundle
    assert "text/html" in bundle

    # Text display: only the plain-text representation remains.
    with config_context(display="text"):
        bundle = estimator._repr_mimebundle_()
        assert "text/plain" in bundle
        assert "text/html" not in bundle


def test_repr_html_wraps():
    """Check that the display configuration controls the html output."""
    estimator = DecisionTreeClassifier()

    html = estimator._repr_html_()
    assert "<style>" in html

    # With the text display, asking for the html repr raises AttributeError.
    with config_context(display="text"):
        expected_message = "_repr_html_ is only defined when"
        with pytest.raises(AttributeError, match=expected_message):
            estimator._repr_html_()


def test_n_features_in_validation():
    """Check that `_check_n_features` validates data when reset=False"""
    estimator = MyEstimator()
    three_feature_rows = [[1, 2, 3], [4, 5, 6]]

    # reset=True records the number of features on the estimator.
    _check_n_features(estimator, three_feature_rows, reset=True)
    assert estimator.n_features_in_ == 3

    # reset=False compares new input against the recorded number.
    msg = "X does not contain any features, but MyEstimator is expecting 3 features"
    with pytest.raises(ValueError, match=msg):
        _check_n_features(estimator, "invalid X", reset=False)


def test_n_features_in_no_validation():
    """Check that `_check_n_features` does not validate data when
    n_features_in_ is not defined."""
    estimator = MyEstimator()
    unusable_input = "invalid X"

    # With reset=True on unusable input nothing is recorded.
    _check_n_features(estimator, unusable_input, reset=True)
    assert not hasattr(estimator, "n_features_in_")

    # does not raise
    _check_n_features(estimator, unusable_input, reset=False)


def test_feature_names_in():
    """Check that `feature_names_in_` is recorded and checked by `validate_data`.

    The scenario is sequential: the ``trans`` transformer is refitted or
    replaced between steps, so the order of the steps below matters.
    """
    pd = pytest.importorskip("pandas")
    iris = datasets.load_iris()
    X_np = iris.data
    df = pd.DataFrame(X_np, columns=iris.feature_names)

    class NoOpTransformer(TransformerMixin, BaseEstimator):
        # Transformer that only runs `validate_data` on its input.
        def fit(self, X, y=None):
            validate_data(self, X)
            return self

        def transform(self, X):
            # reset=False: the input is validated against the fitted state.
            validate_data(self, X, reset=False)
            return X

    # fit on dataframe saves the feature names
    trans = NoOpTransformer().fit(df)
    assert_array_equal(trans.feature_names_in_, df.columns)

    # fit again but on ndarray does not keep the previous feature names (see #21383)
    trans.fit(X_np)
    assert not hasattr(trans, "feature_names_in_")

    # refit on the dataframe, then transform a frame whose columns are reversed
    trans.fit(df)
    msg = "The feature names should match those that were passed"
    df_bad = pd.DataFrame(X_np, columns=iris.feature_names[::-1])
    with pytest.raises(ValueError, match=msg):
        trans.transform(df_bad)

    # warns when fitted on dataframe and transforming a ndarray
    msg = (
        "X does not have valid feature names, but NoOpTransformer was "
        "fitted with feature names"
    )
    with pytest.warns(UserWarning, match=msg):
        trans.transform(X_np)

    # warns when fitted on a ndarray and transforming dataframe
    msg = "X has feature names, but NoOpTransformer was fitted without feature names"
    trans = NoOpTransformer().fit(X_np)
    with pytest.warns(UserWarning, match=msg):
        trans.transform(df)

    # fit on dataframe with all integer feature names works without warning
    df_int_names = pd.DataFrame(X_np)
    trans = NoOpTransformer()
    with warnings.catch_warnings():
        # any UserWarning raised in this block becomes an error
        warnings.simplefilter("error", UserWarning)
        trans.fit(df_int_names)

    # fit on dataframe with no feature names or all integer feature names
    # -> do not warn on transform
    Xs = [X_np, df_int_names]
    for X in Xs:
        with warnings.catch_warnings():
            warnings.simplefilter("error", UserWarning)
            trans.transform(X)

    # fit on dataframe with feature names that are mixed raises an error:
    df_mixed = pd.DataFrame(X_np, columns=["a", "b", 1, 2])
    trans = NoOpTransformer()
    # the expected message is matched literally, hence `re.escape`
    msg = re.escape(
        "Feature names are only supported if all input features have string names, "
        "but your input has ['int', 'str'] as feature name / column name types. "
        "If you want feature names to be stored and validated, you must convert "
        "them all to strings, by using X.columns = X.columns.astype(str) for "
        "example. Otherwise you can remove feature / column names from your input "
        "data, or convert them all to a non-string data type."
    )
    with pytest.raises(TypeError, match=msg):
        trans.fit(df_mixed)

    # transform on feature names that are mixed also raises:
    with pytest.raises(TypeError, match=msg):
        trans.transform(df_mixed)


def test_validate_data_skip_check_array():
    """Check the ``skip_check_array`` option of ``validate_data``."""
    pd = pytest.importorskip("pandas")
    iris = datasets.load_iris()
    frame = pd.DataFrame(iris.data, columns=iris.feature_names)
    target = pd.Series(iris.target)

    class NoOpTransformer(TransformerMixin, BaseEstimator):
        """Transformer without any behaviour of its own."""

    est = NoOpTransformer()

    # X only: checked input comes back as an ndarray, skipped input untouched.
    X_checked = validate_data(est, frame, skip_check_array=False)
    assert isinstance(X_checked, np.ndarray)
    assert_allclose(X_checked, frame.to_numpy())

    X_skipped = validate_data(est, frame, skip_check_array=True)
    assert X_skipped is frame

    # y only.
    y_checked = validate_data(est, y=target, skip_check_array=False)
    assert isinstance(y_checked, np.ndarray)
    assert_allclose(y_checked, target.to_numpy())

    y_skipped = validate_data(est, y=target, skip_check_array=True)
    assert y_skipped is target

    # X and y together.
    X_checked, y_checked = validate_data(est, frame, target, skip_check_array=False)
    assert isinstance(X_checked, np.ndarray)
    assert_allclose(X_checked, frame.to_numpy())
    assert isinstance(y_checked, np.ndarray)
    assert_allclose(y_checked, target.to_numpy())

    X_skipped, y_skipped = validate_data(est, frame, target, skip_check_array=True)
    assert X_skipped is frame
    assert y_skipped is target

    # Neither X nor y is an error.
    expected_message = "Validation should be done on X, y or both."
    with pytest.raises(ValueError, match=expected_message):
        validate_data(est)


def test_clone_keeps_output_config():
    """Check that clone keeps the set_output config."""
    scaler = StandardScaler().set_output(transform="pandas")
    original_config = _get_output_config("transform", scaler)

    scaler_copy = clone(scaler)
    copied_config = _get_output_config("transform", scaler_copy)

    assert original_config == copied_config


class _Empty:
    """Base class that defines no attributes or methods of its own."""


class EmptyEstimator(_Empty, BaseEstimator):
    """Estimator with no parameters, with ``_Empty`` ahead of ``BaseEstimator``."""


@pytest.mark.parametrize("estimator", [BaseEstimator(), EmptyEstimator()])
def test_estimator_empty_instance_dict(estimator):
    """Check that ``__getstate__`` returns an empty ``dict`` with an empty
    instance.

    Python 3.11+ changed behaviour by returning ``None`` instead of raising an
    ``AttributeError``. Non-regression test for gh-25188.
    """
    # Apart from the version marker, the state of an empty instance is empty.
    state = estimator.__getstate__()
    expected = {"_sklearn_version": sklearn.__version__}
    assert state == expected

    # Round-trip the parametrized estimator itself, so that every case of the
    # parametrization is covered; this should not raise.
    pickle.loads(pickle.dumps(estimator))


def test_estimator_getstate_using_slots_error_message():
    """Using a `BaseEstimator` with `__slots__` is not supported."""

    class WithSlots:
        __slots__ = ("x",)

    class Estimator(BaseEstimator, WithSlots):
        """Estimator inheriting ``__slots__`` from ``WithSlots``."""

    expected_message = (
        "You cannot use `__slots__` in objects inheriting from "
        "`sklearn.base.BaseEstimator`"
    )

    # Calling ``__getstate__`` directly fails with the dedicated message...
    slotted = Estimator()
    with pytest.raises(TypeError, match=expected_message):
        slotted.__getstate__()

    # ...and so does pickling.
    slotted = Estimator()
    with pytest.raises(TypeError, match=expected_message):
        pickle.dumps(slotted)


@pytest.mark.parametrize(
    "constructor_name, minversion",
    [
        ("dataframe", "1.5.0"),
        ("pyarrow", "12.0.0"),
        ("polars", "0.20.23"),
    ],
)
def test_dataframe_protocol(constructor_name, minversion):
    """Uses the dataframe exchange protocol to get feature names."""
    rows = [[1, 4, 2], [3, 3, 6]]
    good_names = ["col_0", "col_1", "col_2"]
    frame = _convert_container(
        rows, constructor_name, columns_name=good_names, minversion=minversion
    )

    class NoOpTransformer(TransformerMixin, BaseEstimator):
        """Transformer that only runs ``validate_data`` on its input."""

        def fit(self, X, y=None):
            validate_data(self, X)
            return self

        def transform(self, X):
            return validate_data(self, X, reset=False)

    est = NoOpTransformer()
    est.fit(frame)
    assert_array_equal(est.feature_names_in_, good_names)
    transformed = est.transform(frame)

    if constructor_name != "pyarrow":
        # pyarrow does not work with `np.asarray`
        # https://github.com/apache/arrow/issues/34886
        assert_allclose(frame, transformed)

    # Same data under different column names must be rejected on transform.
    wrong_names = ["a", "b", "c"]
    wrong_frame = _convert_container(rows, constructor_name, columns_name=wrong_names)
    with pytest.raises(ValueError, match="The feature names should match"):
        est.transform(wrong_frame)


@config_context(enable_metadata_routing=True)
def test_transformer_fit_transform_with_metadata_in_transform():
    """Test that having a transformer with metadata for transform raises a
    warning when calling fit_transform."""

    class CustomTransformer(BaseEstimator, TransformerMixin):
        def fit(self, X, y=None, prop=None):
            return self

        def transform(self, X, prop=None):
            return X

    expected_warning = "`transform` method which consumes metadata"

    # passing the metadata to `fit_transform` should raise a warning since it
    # could potentially be consumed by `transform`
    with pytest.warns(UserWarning, match=expected_warning):
        requesting = CustomTransformer().set_transform_request(prop=True)
        requesting.fit_transform([[1]], [1], prop=1)

    # not passing a metadata which can potentially be consumed by `transform` should
    # not raise a warning
    with warnings.catch_warnings(record=True) as caught:
        requesting = CustomTransformer().set_transform_request(prop=True)
        requesting.fit_transform([[1]], [1])
        assert len(caught) == 0


@config_context(enable_metadata_routing=True)
def test_outlier_mixin_fit_predict_with_metadata_in_predict():
    """Test that having an OutlierMixin with metadata for predict raises a
    warning when calling fit_predict."""

    class CustomOutlierDetector(BaseEstimator, OutlierMixin):
        def fit(self, X, y=None, prop=None):
            return self

        def predict(self, X, prop=None):
            return X

    expected_warning = "`predict` method which consumes metadata"

    # passing the metadata to `fit_predict` should raise a warning since it
    # could potentially be consumed by `predict`
    with pytest.warns(UserWarning, match=expected_warning):
        requesting = CustomOutlierDetector().set_predict_request(prop=True)
        requesting.fit_predict([[1]], [1], prop=1)

    # not passing a metadata which can potentially be consumed by `predict` should
    # not raise a warning
    with warnings.catch_warnings(record=True) as caught:
        requesting = CustomOutlierDetector().set_predict_request(prop=True)
        requesting.fit_predict([[1]], [1])
        assert len(caught) == 0
