import inspect
from collections import defaultdict
from functools import partial

import numpy as np
from numpy.testing import assert_array_equal

from sklearn.base import (
    BaseEstimator,
    ClassifierMixin,
    MetaEstimatorMixin,
    RegressorMixin,
    TransformerMixin,
    clone,
)
from sklearn.metrics._scorer import _Scorer, mean_squared_error
from sklearn.model_selection import BaseCrossValidator
from sklearn.model_selection._split import GroupsConsumerMixin
from sklearn.utils._metadata_requests import (
    SIMPLE_METHODS,
)
from sklearn.utils.metadata_routing import (
    MetadataRouter,
    MethodMapping,
    process_routing,
)
from sklearn.utils.multiclass import _check_partial_fit_first_call


def record_metadata(obj, record_default=True, **kwargs):
    """Utility function to store passed metadata to a method of obj.

    If record_default is False, kwargs whose values are "default" are skipped.
    This is so that checks on keyword arguments whose default was not changed
    are skipped.

    """
    stack = inspect.stack()
    callee = stack[1].function
    caller = stack[2].function
    if not hasattr(obj, "_records"):
        obj._records = defaultdict(lambda: defaultdict(list))
    if not record_default:
        kwargs = {
            key: val
            for key, val in kwargs.items()
            if not isinstance(val, str) or (val != "default")
        }
    obj._records[callee][caller].append(kwargs)


def check_recorded_metadata(obj, method, parent, split_params=tuple(), **kwargs):
    """Check whether the expected metadata is passed to the object's method.

    Parameters
    ----------
    obj : estimator object
        sub-estimator to check routed params for
    method : str
        sub-estimator's method where metadata is routed to, or otherwise in
        the context of metadata routing referred to as 'callee'
    parent : str
        the parent method which should have called `method`, or otherwise in
        the context of metadata routing referred to as 'caller'
    split_params : tuple, default=empty
        specifies any parameters which are to be checked as being a subset
        of the original values
    **kwargs : dict
        passed metadata
    """
    all_records = (
        getattr(obj, "_records", dict()).get(method, dict()).get(parent, list())
    )
    for record in all_records:
        # first check that the names of the metadata passed are the same as
        # expected. The names are stored as keys in `record`.
        assert set(kwargs.keys()) == set(
            record.keys()
        ), f"Expected {kwargs.keys()} vs {record.keys()}"
        for key, value in kwargs.items():
            recorded_value = record[key]
            # The following condition is used to check for any specified parameters
            # being a subset of the original values
            if key in split_params and recorded_value is not None:
                assert np.isin(recorded_value, value).all()
            else:
                if isinstance(recorded_value, np.ndarray):
                    assert_array_equal(recorded_value, value)
                else:
                    assert (
                        recorded_value is value
                    ), f"Expected {recorded_value} vs {value}. Method: {method}"


record_metadata_not_default = partial(record_metadata, record_default=False)


def assert_request_is_empty(metadata_request, exclude=None):
    """Check if a metadata request dict is empty.

    One can exclude a method or a list of methods from the check using the
    ``exclude`` parameter. If metadata_request is a MetadataRouter, then
    ``exclude`` can be of the form ``{"object" : [method, ...]}``.
    """
    if isinstance(metadata_request, MetadataRouter):
        for name, route_mapping in metadata_request:
            if exclude is not None and name in exclude:
                _exclude = exclude[name]
            else:
                _exclude = None
            assert_request_is_empty(route_mapping.router, exclude=_exclude)
        return

    exclude = [] if exclude is None else exclude
    for method in SIMPLE_METHODS:
        if method in exclude:
            continue
        mmr = getattr(metadata_request, method)
        props = [
            prop
            for prop, alias in mmr.requests.items()
            if isinstance(alias, str) or alias is not None
        ]
        assert not props


def assert_request_equal(request, dictionary):
    for method, requests in dictionary.items():
        mmr = getattr(request, method)
        assert mmr.requests == requests

    empty_methods = [method for method in SIMPLE_METHODS if method not in dictionary]
    for method in empty_methods:
        assert not len(getattr(request, method).requests)


class _Registry(list):
    # This list is used to get a reference to the sub-estimators, which are not
    # necessarily stored on the metaestimator. We need to override __deepcopy__
    # because the sub-estimators are probably cloned, which would result in a
    # new copy of the list, but we need copy and deep copy both to return the
    # same instance.
    def __deepcopy__(self, memo):
        return self

    def __copy__(self):
        return self


class ConsumingRegressor(RegressorMixin, BaseEstimator):
    """A regressor consuming metadata.

    Parameters
    ----------
    registry : list, default=None
        If a list, the estimator will append itself to the list in order to have
        a reference to the estimator later on. Since that reference is not
        required in all tests, registration can be skipped by leaving this value
        as None.
    """

    def __init__(self, registry=None):
        self.registry = registry

    def partial_fit(self, X, y, sample_weight="default", metadata="default"):
        if self.registry is not None:
            self.registry.append(self)

        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        return self

    def fit(self, X, y, sample_weight="default", metadata="default"):
        if self.registry is not None:
            self.registry.append(self)

        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        return self

    def predict(self, X, y=None, sample_weight="default", metadata="default"):
        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        return np.zeros(shape=(len(X),))

    def score(self, X, y, sample_weight="default", metadata="default"):
        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        return 1


class NonConsumingClassifier(ClassifierMixin, BaseEstimator):
    """A classifier which accepts no metadata on any method."""

    def __init__(self, alpha=0.0):
        self.alpha = alpha

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        self.coef_ = np.ones_like(X)
        return self

    def partial_fit(self, X, y, classes=None):
        return self

    def decision_function(self, X):
        return self.predict(X)

    def predict(self, X):
        y_pred = np.empty(shape=(len(X),))
        y_pred[: len(X) // 2] = 0
        y_pred[len(X) // 2 :] = 1
        return y_pred

    def predict_proba(self, X):
        # dummy probabilities to support predict_proba
        y_proba = np.empty(shape=(len(X), 2))
        y_proba[: len(X) // 2, :] = np.asarray([1.0, 0.0])
        y_proba[len(X) // 2 :, :] = np.asarray([0.0, 1.0])
        return y_proba

    def predict_log_proba(self, X):
        # dummy probabilities to support predict_log_proba
        return self.predict_proba(X)


class NonConsumingRegressor(RegressorMixin, BaseEstimator):
    """A classifier which accepts no metadata on any method."""

    def fit(self, X, y):
        return self

    def partial_fit(self, X, y):
        return self

    def predict(self, X):
        return np.ones(len(X))  # pragma: no cover


class ConsumingClassifier(ClassifierMixin, BaseEstimator):
    """A classifier consuming metadata.

    Parameters
    ----------
    registry : list, default=None
        If a list, the estimator will append itself to the list in order to have
        a reference to the estimator later on. Since that reference is not
        required in all tests, registration can be skipped by leaving this value
        as None.

    alpha : float, default=0
        This parameter is only used to test the ``*SearchCV`` objects, and
        doesn't do anything.
    """

    def __init__(self, registry=None, alpha=0.0):
        self.alpha = alpha
        self.registry = registry

    def partial_fit(
        self, X, y, classes=None, sample_weight="default", metadata="default"
    ):
        if self.registry is not None:
            self.registry.append(self)

        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        _check_partial_fit_first_call(self, classes)
        return self

    def fit(self, X, y, sample_weight="default", metadata="default"):
        if self.registry is not None:
            self.registry.append(self)

        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )

        self.classes_ = np.unique(y)
        self.coef_ = np.ones_like(X)
        return self

    def predict(self, X, sample_weight="default", metadata="default"):
        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        y_score = np.empty(shape=(len(X),), dtype="int8")
        y_score[len(X) // 2 :] = 0
        y_score[: len(X) // 2] = 1
        return y_score

    def predict_proba(self, X, sample_weight="default", metadata="default"):
        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        y_proba = np.empty(shape=(len(X), 2))
        y_proba[: len(X) // 2, :] = np.asarray([1.0, 0.0])
        y_proba[len(X) // 2 :, :] = np.asarray([0.0, 1.0])
        return y_proba

    def predict_log_proba(self, X, sample_weight="default", metadata="default"):
        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        return np.zeros(shape=(len(X), 2))

    def decision_function(self, X, sample_weight="default", metadata="default"):
        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        y_score = np.empty(shape=(len(X),))
        y_score[len(X) // 2 :] = 0
        y_score[: len(X) // 2] = 1
        return y_score

    def score(self, X, y, sample_weight="default", metadata="default"):
        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        return 1


class ConsumingTransformer(TransformerMixin, BaseEstimator):
    """A transformer which accepts metadata on fit and transform.

    Parameters
    ----------
    registry : list, default=None
        If a list, the estimator will append itself to the list in order to have
        a reference to the estimator later on. Since that reference is not
        required in all tests, registration can be skipped by leaving this value
        as None.
    """

    def __init__(self, registry=None):
        self.registry = registry

    def fit(self, X, y=None, sample_weight="default", metadata="default"):
        if self.registry is not None:
            self.registry.append(self)

        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        self.fitted_ = True
        return self

    def transform(self, X, sample_weight="default", metadata="default"):
        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        return X + 1

    def fit_transform(self, X, y, sample_weight="default", metadata="default"):
        # implementing ``fit_transform`` is necessary since
        # ``TransformerMixin.fit_transform`` doesn't route any metadata to
        # ``transform``, while here we want ``transform`` to receive
        # ``sample_weight`` and ``metadata``.
        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        return self.fit(X, y, sample_weight=sample_weight, metadata=metadata).transform(
            X, sample_weight=sample_weight, metadata=metadata
        )

    def inverse_transform(self, X, sample_weight=None, metadata=None):
        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        return X - 1


class ConsumingNoFitTransformTransformer(BaseEstimator):
    """A metadata consuming transformer that doesn't inherit from
    TransformerMixin, and thus doesn't implement `fit_transform`. Note that
    TransformerMixin's `fit_transform` doesn't route metadata to `transform`."""

    def __init__(self, registry=None):
        self.registry = registry

    def fit(self, X, y=None, sample_weight=None, metadata=None):
        if self.registry is not None:
            self.registry.append(self)

        record_metadata(self, sample_weight=sample_weight, metadata=metadata)

        return self

    def transform(self, X, sample_weight=None, metadata=None):
        record_metadata(self, sample_weight=sample_weight, metadata=metadata)
        return X


class ConsumingScorer(_Scorer):
    def __init__(self, registry=None):
        super().__init__(
            score_func=mean_squared_error, sign=1, kwargs={}, response_method="predict"
        )
        self.registry = registry

    def _score(self, method_caller, clf, X, y, **kwargs):
        if self.registry is not None:
            self.registry.append(self)

        record_metadata_not_default(self, **kwargs)

        sample_weight = kwargs.get("sample_weight", None)
        return super()._score(method_caller, clf, X, y, sample_weight=sample_weight)


class ConsumingSplitter(GroupsConsumerMixin, BaseCrossValidator):
    def __init__(self, registry=None):
        self.registry = registry

    def split(self, X, y=None, groups="default", metadata="default"):
        if self.registry is not None:
            self.registry.append(self)

        record_metadata_not_default(self, groups=groups, metadata=metadata)

        split_index = len(X) // 2
        train_indices = list(range(0, split_index))
        test_indices = list(range(split_index, len(X)))
        yield test_indices, train_indices
        yield train_indices, test_indices

    def get_n_splits(self, X=None, y=None, groups=None, metadata=None):
        return 2

    def _iter_test_indices(self, X=None, y=None, groups=None):
        split_index = len(X) // 2
        train_indices = list(range(0, split_index))
        test_indices = list(range(split_index, len(X)))
        yield test_indices
        yield train_indices


class MetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    """A meta-regressor which is only a router."""

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y, **fit_params):
        params = process_routing(self, "fit", **fit_params)
        self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)

    def get_metadata_routing(self):
        router = MetadataRouter(owner=self.__class__.__name__).add(
            estimator=self.estimator,
            method_mapping=MethodMapping().add(caller="fit", callee="fit"),
        )
        return router


class WeightedMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    """A meta-regressor which is also a consumer."""

    def __init__(self, estimator, registry=None):
        self.estimator = estimator
        self.registry = registry

    def fit(self, X, y, sample_weight=None, **fit_params):
        if self.registry is not None:
            self.registry.append(self)

        record_metadata(self, sample_weight=sample_weight)
        params = process_routing(self, "fit", sample_weight=sample_weight, **fit_params)
        self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
        return self

    def predict(self, X, **predict_params):
        params = process_routing(self, "predict", **predict_params)
        return self.estimator_.predict(X, **params.estimator.predict)

    def get_metadata_routing(self):
        router = (
            MetadataRouter(owner=self.__class__.__name__)
            .add_self_request(self)
            .add(
                estimator=self.estimator,
                method_mapping=MethodMapping()
                .add(caller="fit", callee="fit")
                .add(caller="predict", callee="predict"),
            )
        )
        return router


class WeightedMetaClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):
    """A meta-estimator which also consumes sample_weight itself in ``fit``."""

    def __init__(self, estimator, registry=None):
        self.estimator = estimator
        self.registry = registry

    def fit(self, X, y, sample_weight=None, **kwargs):
        if self.registry is not None:
            self.registry.append(self)

        record_metadata(self, sample_weight=sample_weight)
        params = process_routing(self, "fit", sample_weight=sample_weight, **kwargs)
        self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
        return self

    def get_metadata_routing(self):
        router = (
            MetadataRouter(owner=self.__class__.__name__)
            .add_self_request(self)
            .add(
                estimator=self.estimator,
                method_mapping=MethodMapping().add(caller="fit", callee="fit"),
            )
        )
        return router


class MetaTransformer(MetaEstimatorMixin, TransformerMixin, BaseEstimator):
    """A simple meta-transformer."""

    def __init__(self, transformer):
        self.transformer = transformer

    def fit(self, X, y=None, **fit_params):
        params = process_routing(self, "fit", **fit_params)
        self.transformer_ = clone(self.transformer).fit(X, y, **params.transformer.fit)
        return self

    def transform(self, X, y=None, **transform_params):
        params = process_routing(self, "transform", **transform_params)
        return self.transformer_.transform(X, **params.transformer.transform)

    def get_metadata_routing(self):
        return MetadataRouter(owner=self.__class__.__name__).add(
            transformer=self.transformer,
            method_mapping=MethodMapping()
            .add(caller="fit", callee="fit")
            .add(caller="transform", callee="transform"),
        )
