"""
Metadata Routing Utility Tests
"""

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import re

import numpy as np
import pytest

from sklearn import config_context
from sklearn.base import (
    BaseEstimator,
    clone,
)
from sklearn.exceptions import UnsetMetadataPassedError
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.tests.metadata_routing_common import (
    ConsumingClassifier,
    ConsumingRegressor,
    ConsumingTransformer,
    MetaRegressor,
    MetaTransformer,
    NonConsumingClassifier,
    WeightedMetaClassifier,
    WeightedMetaRegressor,
    _Registry,
    assert_request_equal,
    assert_request_is_empty,
    check_recorded_metadata,
)
from sklearn.utils import metadata_routing
from sklearn.utils._metadata_requests import (
    COMPOSITE_METHODS,
    METHODS,
    SIMPLE_METHODS,
    MethodMetadataRequest,
    MethodPair,
    _MetadataRequester,
    request_is_alias,
    request_is_valid,
)
from sklearn.utils.metadata_routing import (
    MetadataRequest,
    MetadataRouter,
    MethodMapping,
    _RoutingNotSupportedMixin,
    get_routing_for_object,
    process_routing,
)
from sklearn.utils.validation import check_is_fitted

rng = np.random.RandomState(42)
N, M = 100, 4
X = rng.rand(N, M)
y = rng.randint(0, 2, size=N)
my_groups = rng.randint(0, 10, size=N)
my_weights = rng.rand(N)
my_other_weights = rng.rand(N)


class SimplePipeline(BaseEstimator):
    """A very simple pipeline, assuming the last step is always a predictor.

    Parameters
    ----------
    steps : iterable of objects
        An iterable of transformers with the last step being a predictor.
    """

    def __init__(self, steps):
        self.steps = steps

    def fit(self, X, y, **fit_params):
        self.steps_ = []
        params = process_routing(self, "fit", **fit_params)
        X_transformed = X
        for i, step in enumerate(self.steps[:-1]):
            transformer = clone(step).fit(
                X_transformed, y, **params.get(f"step_{i}").fit
            )
            self.steps_.append(transformer)
            X_transformed = transformer.transform(
                X_transformed, **params.get(f"step_{i}").transform
            )

        self.steps_.append(
            clone(self.steps[-1]).fit(X_transformed, y, **params.predictor.fit)
        )
        return self

    def predict(self, X, **predict_params):
        check_is_fitted(self)
        X_transformed = X
        params = process_routing(self, "predict", **predict_params)
        for i, step in enumerate(self.steps_[:-1]):
            X_transformed = step.transform(X, **params.get(f"step_{i}").transform)

        return self.steps_[-1].predict(X_transformed, **params.predictor.predict)

    def get_metadata_routing(self):
        router = MetadataRouter(owner=self.__class__.__name__)
        for i, step in enumerate(self.steps[:-1]):
            router.add(
                **{f"step_{i}": step},
                method_mapping=MethodMapping()
                .add(caller="fit", callee="fit")
                .add(caller="fit", callee="transform")
                .add(caller="predict", callee="transform"),
            )
        router.add(
            predictor=self.steps[-1],
            method_mapping=MethodMapping()
            .add(caller="fit", callee="fit")
            .add(caller="predict", callee="predict"),
        )
        return router


@config_context(enable_metadata_routing=True)
def test_assert_request_is_empty():
    requests = MetadataRequest(owner="test")
    assert_request_is_empty(requests)

    requests.fit.add_request(param="foo", alias=None)
    # this should still work, since None is the default value
    assert_request_is_empty(requests)

    requests.fit.add_request(param="bar", alias="value")
    with pytest.raises(AssertionError):
        # now requests is no more empty
        assert_request_is_empty(requests)

    # but one can exclude a method
    assert_request_is_empty(requests, exclude="fit")

    requests.score.add_request(param="carrot", alias=True)
    with pytest.raises(AssertionError):
        # excluding `fit` is not enough
        assert_request_is_empty(requests, exclude="fit")

    # and excluding both fit and score would avoid an exception
    assert_request_is_empty(requests, exclude=["fit", "score"])

    # test if a router is empty
    assert_request_is_empty(
        MetadataRouter(owner="test")
        .add_self_request(WeightedMetaRegressor(estimator=None))
        .add(
            estimator=ConsumingRegressor(),
            method_mapping=MethodMapping().add(caller="fit", callee="fit"),
        )
    )


@pytest.mark.parametrize(
    "estimator",
    [
        ConsumingClassifier(registry=_Registry()),
        ConsumingRegressor(registry=_Registry()),
        ConsumingTransformer(registry=_Registry()),
        WeightedMetaClassifier(estimator=ConsumingClassifier(), registry=_Registry()),
        WeightedMetaRegressor(estimator=ConsumingRegressor(), registry=_Registry()),
    ],
)
@config_context(enable_metadata_routing=True)
def test_estimator_puts_self_in_registry(estimator):
    """Check that an estimator puts itself in the registry upon fit."""
    estimator.fit(X, y)
    assert estimator in estimator.registry


@pytest.mark.parametrize(
    "val, res",
    [
        (False, False),
        (True, False),
        (None, False),
        ("$UNUSED$", False),
        ("$WARN$", False),
        ("invalid-input", False),
        ("valid_arg", True),
    ],
)
@config_context(enable_metadata_routing=True)
def test_request_type_is_alias(val, res):
    # Test request_is_alias
    assert request_is_alias(val) == res


@pytest.mark.parametrize(
    "val, res",
    [
        (False, True),
        (True, True),
        (None, True),
        ("$UNUSED$", True),
        ("$WARN$", True),
        ("invalid-input", False),
        ("alias_arg", False),
    ],
)
@config_context(enable_metadata_routing=True)
def test_request_type_is_valid(val, res):
    # Test request_is_valid
    assert request_is_valid(val) == res


@config_context(enable_metadata_routing=True)
def test_default_requests():
    class OddEstimator(BaseEstimator):
        __metadata_request__fit = {
            # set a different default request
            "sample_weight": True
        }  # type: ignore

    odd_request = get_routing_for_object(OddEstimator())
    assert odd_request.fit.requests == {"sample_weight": True}

    # check other test estimators
    assert not len(get_routing_for_object(NonConsumingClassifier()).fit.requests)
    assert_request_is_empty(NonConsumingClassifier().get_metadata_routing())

    trs_request = get_routing_for_object(ConsumingTransformer())
    assert trs_request.fit.requests == {
        "sample_weight": None,
        "metadata": None,
    }
    assert trs_request.transform.requests == {"metadata": None, "sample_weight": None}
    assert_request_is_empty(trs_request)

    est_request = get_routing_for_object(ConsumingClassifier())
    assert est_request.fit.requests == {
        "sample_weight": None,
        "metadata": None,
    }
    assert_request_is_empty(est_request)


@config_context(enable_metadata_routing=True)
def test_default_request_override():
    """Test that default requests are correctly overridden regardless of the ASCII order
    of the class names, hence testing small and capital letter class name starts.
    Non-regression test for https://github.com/scikit-learn/scikit-learn/issues/28430
    """

    class Base(BaseEstimator):
        __metadata_request__split = {"groups": True}

    class class_1(Base):
        __metadata_request__split = {"groups": "sample_domain"}

    class Class_1(Base):
        __metadata_request__split = {"groups": "sample_domain"}

    assert_request_equal(
        class_1()._get_metadata_request(), {"split": {"groups": "sample_domain"}}
    )
    assert_request_equal(
        Class_1()._get_metadata_request(), {"split": {"groups": "sample_domain"}}
    )


@config_context(enable_metadata_routing=True)
def test_process_routing_invalid_method():
    with pytest.raises(TypeError, match="Can only route and process input"):
        process_routing(ConsumingClassifier(), "invalid_method", groups=my_groups)


@config_context(enable_metadata_routing=True)
def test_process_routing_invalid_object():
    class InvalidObject:
        pass

    with pytest.raises(AttributeError, match="either implement the routing method"):
        process_routing(InvalidObject(), "fit", groups=my_groups)


@pytest.mark.parametrize("method", METHODS)
@pytest.mark.parametrize("default", [None, "default", []])
@config_context(enable_metadata_routing=True)
def test_process_routing_empty_params_get_with_default(method, default):
    empty_params = {}
    routed_params = process_routing(ConsumingClassifier(), "fit", **empty_params)

    # Behaviour should be an empty dictionary returned for each method when retrieved.
    params_for_method = routed_params[method]
    assert isinstance(params_for_method, dict)
    assert set(params_for_method.keys()) == set(METHODS)

    # No default to `get` should be equivalent to the default
    default_params_for_method = routed_params.get(method, default=default)
    assert default_params_for_method == params_for_method


@config_context(enable_metadata_routing=True)
def test_simple_metadata_routing():
    # Tests that metadata is properly routed

    # The underlying estimator doesn't accept or request metadata
    clf = WeightedMetaClassifier(estimator=NonConsumingClassifier())
    clf.fit(X, y)

    # Meta-estimator consumes sample_weight, but doesn't forward it to the underlying
    # estimator
    clf = WeightedMetaClassifier(estimator=NonConsumingClassifier())
    clf.fit(X, y, sample_weight=my_weights)

    # If the estimator accepts the metadata but doesn't explicitly say it doesn't
    # need it, there's an error
    clf = WeightedMetaClassifier(estimator=ConsumingClassifier())
    err_message = (
        "[sample_weight] are passed but are not explicitly set as requested or"
        " not requested for ConsumingClassifier.fit"
    )
    with pytest.raises(ValueError, match=re.escape(err_message)):
        clf.fit(X, y, sample_weight=my_weights)

    # Explicitly saying the estimator doesn't need it, makes the error go away,
    # because in this case `WeightedMetaClassifier` consumes `sample_weight`. If
    # there was no consumer of sample_weight, passing it would result in an
    # error.
    clf = WeightedMetaClassifier(
        estimator=ConsumingClassifier().set_fit_request(sample_weight=False)
    )
    # this doesn't raise since WeightedMetaClassifier itself is a consumer,
    # and passing metadata to the consumer directly is fine regardless of its
    # metadata_request values.
    clf.fit(X, y, sample_weight=my_weights)
    check_recorded_metadata(clf.estimator_, method="fit", parent="fit")

    # Requesting a metadata will make the meta-estimator forward it correctly
    clf = WeightedMetaClassifier(
        estimator=ConsumingClassifier().set_fit_request(sample_weight=True)
    )
    clf.fit(X, y, sample_weight=my_weights)
    check_recorded_metadata(
        clf.estimator_, method="fit", parent="fit", sample_weight=my_weights
    )

    # And requesting it with an alias
    clf = WeightedMetaClassifier(
        estimator=ConsumingClassifier().set_fit_request(
            sample_weight="alternative_weight"
        )
    )
    clf.fit(X, y, alternative_weight=my_weights)
    check_recorded_metadata(
        clf.estimator_, method="fit", parent="fit", sample_weight=my_weights
    )


@config_context(enable_metadata_routing=True)
def test_nested_routing():
    # check if metadata is routed in a nested routing situation.
    pipeline = SimplePipeline(
        [
            MetaTransformer(
                transformer=ConsumingTransformer()
                .set_fit_request(metadata=True, sample_weight=False)
                .set_transform_request(sample_weight=True, metadata=False)
            ),
            WeightedMetaRegressor(
                estimator=ConsumingRegressor()
                .set_fit_request(sample_weight="inner_weights", metadata=False)
                .set_predict_request(sample_weight=False)
            ).set_fit_request(sample_weight="outer_weights"),
        ]
    )
    w1, w2, w3 = [1], [2], [3]
    pipeline.fit(
        X, y, metadata=my_groups, sample_weight=w1, outer_weights=w2, inner_weights=w3
    )
    check_recorded_metadata(
        pipeline.steps_[0].transformer_,
        method="fit",
        parent="fit",
        metadata=my_groups,
    )
    check_recorded_metadata(
        pipeline.steps_[0].transformer_,
        method="transform",
        parent="fit",
        sample_weight=w1,
    )
    check_recorded_metadata(
        pipeline.steps_[1], method="fit", parent="fit", sample_weight=w2
    )
    check_recorded_metadata(
        pipeline.steps_[1].estimator_, method="fit", parent="fit", sample_weight=w3
    )

    pipeline.predict(X, sample_weight=w3)
    check_recorded_metadata(
        pipeline.steps_[0].transformer_,
        method="transform",
        parent="fit",
        sample_weight=w3,
    )


@config_context(enable_metadata_routing=True)
def test_nested_routing_conflict():
    # check if an error is raised if there's a conflict between keys
    pipeline = SimplePipeline(
        [
            MetaTransformer(
                transformer=ConsumingTransformer()
                .set_fit_request(metadata=True, sample_weight=False)
                .set_transform_request(sample_weight=True)
            ),
            WeightedMetaRegressor(
                estimator=ConsumingRegressor().set_fit_request(sample_weight=True)
            ).set_fit_request(sample_weight="outer_weights"),
        ]
    )
    w1, w2 = [1], [2]
    with pytest.raises(
        ValueError,
        match=(
            re.escape(
                "In WeightedMetaRegressor, there is a conflict on sample_weight between"
                " what is requested for this estimator and what is requested by its"
                " children. You can resolve this conflict by using an alias for the"
                " child estimator(s) requested metadata."
            )
        ),
    ):
        pipeline.fit(X, y, metadata=my_groups, sample_weight=w1, outer_weights=w2)


@config_context(enable_metadata_routing=True)
def test_invalid_metadata():
    # check that passing wrong metadata raises an error
    trs = MetaTransformer(
        transformer=ConsumingTransformer().set_transform_request(sample_weight=True)
    )
    with pytest.raises(
        TypeError,
        match=(re.escape("transform got unexpected argument(s) {'other_param'}")),
    ):
        trs.fit(X, y).transform(X, other_param=my_weights)

    # passing a metadata which is not requested by any estimator should also raise
    trs = MetaTransformer(
        transformer=ConsumingTransformer().set_transform_request(sample_weight=False)
    )
    with pytest.raises(
        TypeError,
        match=(re.escape("transform got unexpected argument(s) {'sample_weight'}")),
    ):
        trs.fit(X, y).transform(X, sample_weight=my_weights)


@config_context(enable_metadata_routing=True)
def test_get_metadata_routing():
    class TestDefaultsBadMethodName(_MetadataRequester):
        __metadata_request__fit = {
            "sample_weight": None,
            "my_param": None,
        }
        __metadata_request__score = {
            "sample_weight": None,
            "my_param": True,
            "my_other_param": None,
        }
        # this will raise an error since we don't understand "other_method" as a method
        __metadata_request__other_method = {"my_param": True}

    class TestDefaults(_MetadataRequester):
        __metadata_request__fit = {
            "sample_weight": None,
            "my_other_param": None,
        }
        __metadata_request__score = {
            "sample_weight": None,
            "my_param": True,
            "my_other_param": None,
        }
        __metadata_request__predict = {"my_param": True}

    with pytest.raises(
        AttributeError, match="'MetadataRequest' object has no attribute 'other_method'"
    ):
        TestDefaultsBadMethodName().get_metadata_routing()

    expected = {
        "score": {
            "my_param": True,
            "my_other_param": None,
            "sample_weight": None,
        },
        "fit": {
            "my_other_param": None,
            "sample_weight": None,
        },
        "predict": {"my_param": True},
    }
    assert_request_equal(TestDefaults().get_metadata_routing(), expected)

    est = TestDefaults().set_score_request(my_param="other_param")
    expected = {
        "score": {
            "my_param": "other_param",
            "my_other_param": None,
            "sample_weight": None,
        },
        "fit": {
            "my_other_param": None,
            "sample_weight": None,
        },
        "predict": {"my_param": True},
    }
    assert_request_equal(est.get_metadata_routing(), expected)

    est = TestDefaults().set_fit_request(sample_weight=True)
    expected = {
        "score": {
            "my_param": True,
            "my_other_param": None,
            "sample_weight": None,
        },
        "fit": {
            "my_other_param": None,
            "sample_weight": True,
        },
        "predict": {"my_param": True},
    }
    assert_request_equal(est.get_metadata_routing(), expected)


@config_context(enable_metadata_routing=True)
def test_setting_default_requests():
    # Test _get_default_requests method
    test_cases = dict()

    class ExplicitRequest(BaseEstimator):
        # `fit` doesn't accept `props` explicitly, but we want to request it
        __metadata_request__fit = {"prop": None}

        def fit(self, X, y, **kwargs):
            return self

    test_cases[ExplicitRequest] = {"prop": None}

    class ExplicitRequestOverwrite(BaseEstimator):
        # `fit` explicitly accepts `props`, but we want to change the default
        # request value from None to True
        __metadata_request__fit = {"prop": True}

        def fit(self, X, y, prop=None, **kwargs):
            return self

    test_cases[ExplicitRequestOverwrite] = {"prop": True}

    class ImplicitRequest(BaseEstimator):
        # `fit` requests `prop` and the default None should be used
        def fit(self, X, y, prop=None, **kwargs):
            return self

    test_cases[ImplicitRequest] = {"prop": None}

    class ImplicitRequestRemoval(BaseEstimator):
        # `fit` (in this class or a parent) requests `prop`, but we don't want
        # it requested at all.
        __metadata_request__fit = {"prop": metadata_routing.UNUSED}

        def fit(self, X, y, prop=None, **kwargs):
            return self

    test_cases[ImplicitRequestRemoval] = {}

    for Klass, requests in test_cases.items():
        assert get_routing_for_object(Klass()).fit.requests == requests
        assert_request_is_empty(Klass().get_metadata_routing(), exclude="fit")
        Klass().fit(None, None)  # for coverage


@config_context(enable_metadata_routing=True)
def test_removing_non_existing_param_raises():
    """Test that removing a metadata using UNUSED which doesn't exist raises."""

    class InvalidRequestRemoval(BaseEstimator):
        # `fit` (in this class or a parent) requests `prop`, but we don't want
        # it requested at all.
        __metadata_request__fit = {"prop": metadata_routing.UNUSED}

        def fit(self, X, y, **kwargs):
            return self

    with pytest.raises(ValueError, match="Trying to remove parameter"):
        InvalidRequestRemoval().get_metadata_routing()


@config_context(enable_metadata_routing=True)
def test_method_metadata_request():
    mmr = MethodMetadataRequest(owner="test", method="fit")

    with pytest.raises(ValueError, match="The alias you're setting for"):
        mmr.add_request(param="foo", alias=1.4)

    mmr.add_request(param="foo", alias=None)
    assert mmr.requests == {"foo": None}
    mmr.add_request(param="foo", alias=False)
    assert mmr.requests == {"foo": False}
    mmr.add_request(param="foo", alias=True)
    assert mmr.requests == {"foo": True}
    mmr.add_request(param="foo", alias="foo")
    assert mmr.requests == {"foo": True}
    mmr.add_request(param="foo", alias="bar")
    assert mmr.requests == {"foo": "bar"}
    assert mmr._get_param_names(return_alias=False) == {"foo"}
    assert mmr._get_param_names(return_alias=True) == {"bar"}


@config_context(enable_metadata_routing=True)
def test_get_routing_for_object():
    class Consumer(BaseEstimator):
        __metadata_request__fit = {"prop": None}

    assert_request_is_empty(get_routing_for_object(None))
    assert_request_is_empty(get_routing_for_object(object()))

    mr = MetadataRequest(owner="test")
    mr.fit.add_request(param="foo", alias="bar")
    mr_factory = get_routing_for_object(mr)
    assert_request_is_empty(mr_factory, exclude="fit")
    assert mr_factory.fit.requests == {"foo": "bar"}

    mr = get_routing_for_object(Consumer())
    assert_request_is_empty(mr, exclude="fit")
    assert mr.fit.requests == {"prop": None}


@config_context(enable_metadata_routing=True)
def test_metadata_request_consumes_method():
    """Test that MetadataRequest().consumes() method works as expected."""
    request = MetadataRouter(owner="test")
    assert request.consumes(method="fit", params={"foo"}) == set()

    request = MetadataRequest(owner="test")
    request.fit.add_request(param="foo", alias=True)
    assert request.consumes(method="fit", params={"foo"}) == {"foo"}

    request = MetadataRequest(owner="test")
    request.fit.add_request(param="foo", alias="bar")
    assert request.consumes(method="fit", params={"bar", "foo"}) == {"bar"}


@config_context(enable_metadata_routing=True)
def test_metadata_router_consumes_method():
    """Test that MetadataRouter().consumes method works as expected."""
    # having it here instead of parametrizing the test since `set_fit_request`
    # is not available while collecting the tests.
    cases = [
        (
            WeightedMetaRegressor(
                estimator=ConsumingRegressor().set_fit_request(sample_weight=True)
            ),
            {"sample_weight"},
            {"sample_weight"},
        ),
        (
            WeightedMetaRegressor(
                estimator=ConsumingRegressor().set_fit_request(
                    sample_weight="my_weights"
                )
            ),
            {"my_weights", "sample_weight"},
            {"my_weights"},
        ),
    ]

    for obj, input, output in cases:
        assert obj.get_metadata_routing().consumes(method="fit", params=input) == output


@config_context(enable_metadata_routing=True)
def test_metaestimator_warnings():
    class WeightedMetaRegressorWarn(WeightedMetaRegressor):
        __metadata_request__fit = {"sample_weight": metadata_routing.WARN}

    with pytest.warns(
        UserWarning, match="Support for .* has recently been added to this class"
    ):
        WeightedMetaRegressorWarn(
            estimator=LinearRegression().set_fit_request(sample_weight=False)
        ).fit(X, y, sample_weight=my_weights)


@config_context(enable_metadata_routing=True)
def test_estimator_warnings():
    class ConsumingRegressorWarn(ConsumingRegressor):
        __metadata_request__fit = {"sample_weight": metadata_routing.WARN}

    with pytest.warns(
        UserWarning, match="Support for .* has recently been added to this class"
    ):
        MetaRegressor(estimator=ConsumingRegressorWarn()).fit(
            X, y, sample_weight=my_weights
        )


@config_context(enable_metadata_routing=True)
@pytest.mark.parametrize(
    "obj, string",
    [
        (
            MethodMetadataRequest(owner="test", method="fit").add_request(
                param="foo", alias="bar"
            ),
            "{'foo': 'bar'}",
        ),
        (
            MetadataRequest(owner="test"),
            "{}",
        ),
        (
            MetadataRouter(owner="test").add(
                estimator=ConsumingRegressor(),
                method_mapping=MethodMapping().add(caller="predict", callee="predict"),
            ),
            (
                "{'estimator': {'mapping': [{'caller': 'predict', 'callee':"
                " 'predict'}], 'router': {'fit': {'sample_weight': None, 'metadata':"
                " None}, 'partial_fit': {'sample_weight': None, 'metadata': None},"
                " 'predict': {'sample_weight': None, 'metadata': None}, 'score':"
                " {'sample_weight': None, 'metadata': None}}}}"
            ),
        ),
    ],
)
@config_context(enable_metadata_routing=True)
def test_string_representations(obj, string):
    assert str(obj) == string


@pytest.mark.parametrize(
    "obj, method, inputs, err_cls, err_msg",
    [
        (
            MethodMapping(),
            "add",
            {"caller": "fit", "callee": "invalid"},
            ValueError,
            "Given callee",
        ),
        (
            MethodMapping(),
            "add",
            {"caller": "invalid", "callee": "fit"},
            ValueError,
            "Given caller",
        ),
        (
            MetadataRouter(owner="test"),
            "add_self_request",
            {"obj": MetadataRouter(owner="test")},
            ValueError,
            "Given `obj` is neither a `MetadataRequest` nor does it implement",
        ),
        (
            ConsumingClassifier(),
            "set_fit_request",
            {"invalid": True},
            TypeError,
            "Unexpected args",
        ),
    ],
)
@config_context(enable_metadata_routing=True)
def test_validations(obj, method, inputs, err_cls, err_msg):
    with pytest.raises(err_cls, match=err_msg):
        getattr(obj, method)(**inputs)


@config_context(enable_metadata_routing=True)
def test_methodmapping():
    mm = (
        MethodMapping()
        .add(caller="fit", callee="transform")
        .add(caller="fit", callee="fit")
    )

    mm_list = list(mm)
    assert mm_list[0] == ("fit", "transform")
    assert mm_list[1] == ("fit", "fit")

    mm = MethodMapping()
    for method in METHODS:
        mm.add(caller=method, callee=method)
        assert MethodPair(method, method) in mm._routes
    assert len(mm._routes) == len(METHODS)

    mm = MethodMapping().add(caller="score", callee="score")
    assert repr(mm) == "[{'caller': 'score', 'callee': 'score'}]"


@config_context(enable_metadata_routing=True)
def test_metadatarouter_add_self_request():
    # adding a MetadataRequest as `self` adds a copy
    request = MetadataRequest(owner="nested")
    request.fit.add_request(param="param", alias=True)
    router = MetadataRouter(owner="test").add_self_request(request)
    assert str(router._self_request) == str(request)
    # should be a copy, not the same object
    assert router._self_request is not request

    # one can add an estimator as self
    est = ConsumingRegressor().set_fit_request(sample_weight="my_weights")
    router = MetadataRouter(owner="test").add_self_request(obj=est)
    assert str(router._self_request) == str(est.get_metadata_routing())
    assert router._self_request is not est.get_metadata_routing()

    # adding a consumer+router as self should only add the consumer part
    est = WeightedMetaRegressor(
        estimator=ConsumingRegressor().set_fit_request(sample_weight="nested_weights")
    )
    router = MetadataRouter(owner="test").add_self_request(obj=est)
    # _get_metadata_request() returns the consumer part of the requests
    assert str(router._self_request) == str(est._get_metadata_request())
    # get_metadata_routing() returns the complete request set, consumer and
    # router included.
    assert str(router._self_request) != str(est.get_metadata_routing())
    # it should be a copy, not the same object
    assert router._self_request is not est._get_metadata_request()


@config_context(enable_metadata_routing=True)
def test_metadata_routing_add():
    # adding one with a string `method_mapping`
    router = MetadataRouter(owner="test").add(
        est=ConsumingRegressor().set_fit_request(sample_weight="weights"),
        method_mapping=MethodMapping().add(caller="fit", callee="fit"),
    )
    assert (
        str(router)
        == "{'est': {'mapping': [{'caller': 'fit', 'callee': 'fit'}], 'router': {'fit':"
        " {'sample_weight': 'weights', 'metadata': None}, 'partial_fit':"
        " {'sample_weight': None, 'metadata': None}, 'predict': {'sample_weight':"
        " None, 'metadata': None}, 'score': {'sample_weight': None, 'metadata':"
        " None}}}}"
    )

    # adding one with an instance of MethodMapping
    router = MetadataRouter(owner="test").add(
        method_mapping=MethodMapping().add(caller="fit", callee="score"),
        est=ConsumingRegressor().set_score_request(sample_weight=True),
    )
    assert (
        str(router)
        == "{'est': {'mapping': [{'caller': 'fit', 'callee': 'score'}], 'router':"
        " {'fit': {'sample_weight': None, 'metadata': None}, 'partial_fit':"
        " {'sample_weight': None, 'metadata': None}, 'predict': {'sample_weight':"
        " None, 'metadata': None}, 'score': {'sample_weight': True, 'metadata':"
        " None}}}}"
    )


@config_context(enable_metadata_routing=True)
def test_metadata_routing_get_param_names():
    router = (
        MetadataRouter(owner="test")
        .add_self_request(
            WeightedMetaRegressor(estimator=ConsumingRegressor()).set_fit_request(
                sample_weight="self_weights"
            )
        )
        .add(
            trs=ConsumingTransformer().set_fit_request(
                sample_weight="transform_weights"
            ),
            method_mapping=MethodMapping().add(caller="fit", callee="fit"),
        )
    )

    assert (
        str(router)
        == "{'$self_request': {'fit': {'sample_weight': 'self_weights'}, 'score':"
        " {'sample_weight': None}}, 'trs': {'mapping': [{'caller': 'fit', 'callee':"
        " 'fit'}], 'router': {'fit': {'sample_weight': 'transform_weights',"
        " 'metadata': None}, 'transform': {'sample_weight': None, 'metadata': None},"
        " 'inverse_transform': {'sample_weight': None, 'metadata': None}}}}"
    )

    assert router._get_param_names(
        method="fit", return_alias=True, ignore_self_request=False
    ) == {"transform_weights", "metadata", "self_weights"}
    # return_alias=False will return original names for "self"
    assert router._get_param_names(
        method="fit", return_alias=False, ignore_self_request=False
    ) == {"sample_weight", "metadata", "transform_weights"}
    # ignoring self would remove "sample_weight"
    assert router._get_param_names(
        method="fit", return_alias=False, ignore_self_request=True
    ) == {"metadata", "transform_weights"}
    # return_alias is ignored when ignore_self_request=True
    assert router._get_param_names(
        method="fit", return_alias=True, ignore_self_request=True
    ) == router._get_param_names(
        method="fit", return_alias=False, ignore_self_request=True
    )


@config_context(enable_metadata_routing=True)
def test_method_generation():
    # Test if all required request methods are generated.

    # TODO: these test classes can be moved to sklearn.utils._testing once we
    # have a better idea of what the commonly used classes are.
    class SimpleEstimator(BaseEstimator):
        # This class should have no set_{method}_request
        def fit(self, X, y):
            pass  # pragma: no cover

        def fit_transform(self, X, y):
            pass  # pragma: no cover

        def fit_predict(self, X, y):
            pass  # pragma: no cover

        def partial_fit(self, X, y):
            pass  # pragma: no cover

        def predict(self, X):
            pass  # pragma: no cover

        def predict_proba(self, X):
            pass  # pragma: no cover

        def predict_log_proba(self, X):
            pass  # pragma: no cover

        def decision_function(self, X):
            pass  # pragma: no cover

        def score(self, X, y):
            pass  # pragma: no cover

        def split(self, X, y=None):
            pass  # pragma: no cover

        def transform(self, X):
            pass  # pragma: no cover

        def inverse_transform(self, X):
            pass  # pragma: no cover

    for method in METHODS:
        assert not hasattr(SimpleEstimator(), f"set_{method}_request")

    class SimpleEstimator(BaseEstimator):
        # This class should have every set_{method}_request
        def fit(self, X, y, sample_weight=None):
            pass  # pragma: no cover

        def fit_transform(self, X, y, sample_weight=None):
            pass  # pragma: no cover

        def fit_predict(self, X, y, sample_weight=None):
            pass  # pragma: no cover

        def partial_fit(self, X, y, sample_weight=None):
            pass  # pragma: no cover

        def predict(self, X, sample_weight=None):
            pass  # pragma: no cover

        def predict_proba(self, X, sample_weight=None):
            pass  # pragma: no cover

        def predict_log_proba(self, X, sample_weight=None):
            pass  # pragma: no cover

        def decision_function(self, X, sample_weight=None):
            pass  # pragma: no cover

        def score(self, X, y, sample_weight=None):
            pass  # pragma: no cover

        def split(self, X, y=None, sample_weight=None):
            pass  # pragma: no cover

        def transform(self, X, sample_weight=None):
            pass  # pragma: no cover

        def inverse_transform(self, X, sample_weight=None):
            pass  # pragma: no cover

    # composite methods shouldn't have a corresponding set method.
    for method in COMPOSITE_METHODS:
        assert not hasattr(SimpleEstimator(), f"set_{method}_request")

    # simple methods should have a corresponding set method.
    for method in SIMPLE_METHODS:
        assert hasattr(SimpleEstimator(), f"set_{method}_request")


@config_context(enable_metadata_routing=True)
def test_composite_methods():
    # Test the behavior and the values of methods (composite methods) whose
    # request values are a union of requests by other methods (simple methods).
    # fit_transform and fit_predict are the only composite methods we have in
    # scikit-learn.
    class SimpleEstimator(BaseEstimator):
        # This class should have every set_{method}_request
        def fit(self, X, y, foo=None, bar=None):
            pass  # pragma: no cover

        def predict(self, X, foo=None, bar=None):
            pass  # pragma: no cover

        def transform(self, X, other_param=None):
            pass  # pragma: no cover

    est = SimpleEstimator()
    # Since no request is set for fit or predict or transform, the request for
    # fit_transform and fit_predict should also be empty.
    assert est.get_metadata_routing().fit_transform.requests == {
        "bar": None,
        "foo": None,
        "other_param": None,
    }
    assert est.get_metadata_routing().fit_predict.requests == {"bar": None, "foo": None}

    # setting the request on only one of them should raise an error
    est.set_fit_request(foo=True, bar="test")
    with pytest.raises(ValueError, match="Conflicting metadata requests for"):
        est.get_metadata_routing().fit_predict

    # setting the request on the other one should fail if not the same as the
    # first method
    est.set_predict_request(bar=True)
    with pytest.raises(ValueError, match="Conflicting metadata requests for"):
        est.get_metadata_routing().fit_predict

    # now the requests are consistent and getting the requests for fit_predict
    # shouldn't raise.
    est.set_predict_request(foo=True, bar="test")
    est.get_metadata_routing().fit_predict

    # setting the request for a none-overlapping parameter would merge them
    # together.
    est.set_transform_request(other_param=True)
    assert est.get_metadata_routing().fit_transform.requests == {
        "bar": "test",
        "foo": True,
        "other_param": True,
    }


@config_context(enable_metadata_routing=True)
def test_no_feature_flag_raises_error():
    """Test that when feature flag disabled, set_{method}_requests raises."""
    with config_context(enable_metadata_routing=False):
        with pytest.raises(RuntimeError, match="This method is only available"):
            ConsumingClassifier().set_fit_request(sample_weight=True)


@config_context(enable_metadata_routing=True)
def test_none_metadata_passed():
    """Test that passing None as metadata when not requested doesn't raise"""
    MetaRegressor(estimator=ConsumingRegressor()).fit(X, y, sample_weight=None)


@config_context(enable_metadata_routing=True)
def test_no_metadata_always_works():
    """Test that when no metadata is passed, having a meta-estimator which does
    not yet support metadata routing works.

    Non-regression test for https://github.com/scikit-learn/scikit-learn/issues/28246
    """

    class Estimator(_RoutingNotSupportedMixin, BaseEstimator):
        def fit(self, X, y, metadata=None):
            return self

    # This passes since no metadata is passed.
    MetaRegressor(estimator=Estimator()).fit(X, y)
    # This fails since metadata is passed but Estimator() does not support it.
    with pytest.raises(
        NotImplementedError, match="Estimator has not implemented metadata routing yet."
    ):
        MetaRegressor(estimator=Estimator()).fit(X, y, metadata=my_groups)


@config_context(enable_metadata_routing=True)
def test_unsetmetadatapassederror_correct():
    """Test that UnsetMetadataPassedError raises the correct error message when
    set_{method}_request is not set in nested cases."""
    weighted_meta = WeightedMetaClassifier(estimator=ConsumingClassifier())
    pipe = SimplePipeline([weighted_meta])
    msg = re.escape(
        "[metadata] are passed but are not explicitly set as requested or not requested"
        " for ConsumingClassifier.fit, which is used within WeightedMetaClassifier.fit."
        " Call `ConsumingClassifier.set_fit_request({metadata}=True/False)` for each"
        " metadata you want to request/ignore."
    )

    with pytest.raises(UnsetMetadataPassedError, match=msg):
        pipe.fit(X, y, metadata="blah")


@config_context(enable_metadata_routing=True)
def test_unsetmetadatapassederror_correct_for_composite_methods():
    """Test that UnsetMetadataPassedError raises the correct error message when
    composite metadata request methods are not set in nested cases."""
    consuming_transformer = ConsumingTransformer()
    pipe = Pipeline([("consuming_transformer", consuming_transformer)])

    msg = re.escape(
        "[metadata] are passed but are not explicitly set as requested or not requested"
        " for ConsumingTransformer.fit_transform, which is used within"
        " Pipeline.fit_transform. Call"
        " `ConsumingTransformer.set_fit_request({metadata}=True/False)"
        ".set_transform_request({metadata}=True/False)`"
        " for each metadata you want to request/ignore."
    )
    with pytest.raises(UnsetMetadataPassedError, match=msg):
        pipe.fit_transform(X, y, metadata="blah")


@config_context(enable_metadata_routing=True)
def test_unbound_set_methods_work():
    """Tests that if the set_{method}_request is unbound, it still works.

    Also test that passing positional arguments to the set_{method}_request fails
    with the right TypeError message.

    Non-regression test for https://github.com/scikit-learn/scikit-learn/issues/28632
    """

    class A(BaseEstimator):
        def fit(self, X, y, sample_weight=None):
            return self

    error_message = re.escape(
        "set_fit_request() takes 0 positional argument but 1 were given"
    )

    # Test positional arguments error before making the descriptor method unbound.
    with pytest.raises(TypeError, match=error_message):
        A().set_fit_request(True)

    # This somehow makes the descriptor method unbound, which results in the `instance`
    # argument being None, and instead `self` being passed as a positional argument
    # to the descriptor method.
    A.set_fit_request = A.set_fit_request

    # This should pass as usual
    A().set_fit_request(sample_weight=True)

    # Test positional arguments error after making the descriptor method unbound.
    with pytest.raises(TypeError, match=error_message):
        A().set_fit_request(True)
