# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import importlib
from functools import wraps
from typing import Protocol, runtime_checkable

import numpy as np
from scipy.sparse import issparse

from .._config import get_config
from ._available_if import available_if
from .fixes import _create_pandas_dataframe_from_non_pandas_container


def check_library_installed(library):
    """Check library is installed."""
    try:
        return importlib.import_module(library)
    except ImportError as exc:
        raise ImportError(
            f"Setting output container to '{library}' requires {library} to be"
            " installed"
        ) from exc


def get_columns(columns):
    if callable(columns):
        try:
            return columns()
        except Exception:
            return None
    return columns


@runtime_checkable
class ContainerAdapterProtocol(Protocol):
    container_lib: str

    def create_container(self, X_output, X_original, columns, inplace=False):
        """Create container from `X_output` with additional metadata.

        Parameters
        ----------
        X_output : {ndarray, dataframe}
            Data to wrap.

        X_original : {ndarray, dataframe}
            Original input dataframe. This is used to extract the metadata that should
            be passed to `X_output`, e.g. pandas row index.

        columns : callable, ndarray, or None
            The column names or a callable that returns the column names. The
            callable is useful if the column names require some computation. If `None`,
            then no columns are passed to the container's constructor.

        inplace : bool, default=False
            Whether or not we intend to modify `X_output` in-place. However, it does
            not guarantee that we return the same object if the in-place operation
            is not possible.

        Returns
        -------
        wrapped_output : container_type
            `X_output` wrapped into the container type.
        """

    def is_supported_container(self, X):
        """Return True if X is a supported container.

        Parameters
        ----------
        Xs: container
            Containers to be checked.

        Returns
        -------
        is_supported_container : bool
            True if X is a supported container.
        """

    def rename_columns(self, X, columns):
        """Rename columns in `X`.

        Parameters
        ----------
        X : container
            Container which columns is updated.

        columns : ndarray of str
            Columns to update the `X`'s columns with.

        Returns
        -------
        updated_container : container
            Container with new names.
        """

    def hstack(self, Xs):
        """Stack containers horizontally (column-wise).

        Parameters
        ----------
        Xs : list of containers
            List of containers to stack.

        Returns
        -------
        stacked_Xs : container
            Stacked containers.
        """


class PandasAdapter:
    container_lib = "pandas"

    def create_container(self, X_output, X_original, columns, inplace=True):
        pd = check_library_installed("pandas")
        columns = get_columns(columns)

        if not inplace or not isinstance(X_output, pd.DataFrame):
            # In all these cases, we need to create a new DataFrame

            # Unfortunately, we cannot use `getattr(container, "index")`
            # because `list` exposes an `index` attribute.
            if isinstance(X_output, pd.DataFrame):
                index = X_output.index
            elif isinstance(X_original, pd.DataFrame):
                index = X_original.index
            else:
                index = None

            # We don't pass columns here because it would intend columns selection
            # instead of renaming.
            X_output = _create_pandas_dataframe_from_non_pandas_container(
                X=X_output, index=index, copy=not inplace
            )

        if columns is not None:
            return self.rename_columns(X_output, columns)
        return X_output

    def is_supported_container(self, X):
        pd = check_library_installed("pandas")
        return isinstance(X, pd.DataFrame)

    def rename_columns(self, X, columns):
        # we cannot use `rename` since it takes a dictionary and at this stage we have
        # potentially duplicate column names in `X`
        X.columns = columns
        return X

    def hstack(self, Xs):
        pd = check_library_installed("pandas")
        return pd.concat(Xs, axis=1)


class PolarsAdapter:
    container_lib = "polars"

    def create_container(self, X_output, X_original, columns, inplace=True):
        pl = check_library_installed("polars")
        columns = get_columns(columns)
        columns = columns.tolist() if isinstance(columns, np.ndarray) else columns

        if not inplace or not isinstance(X_output, pl.DataFrame):
            # In all these cases, we need to create a new DataFrame
            return pl.DataFrame(X_output, schema=columns, orient="row")

        if columns is not None:
            return self.rename_columns(X_output, columns)
        return X_output

    def is_supported_container(self, X):
        pl = check_library_installed("polars")
        return isinstance(X, pl.DataFrame)

    def rename_columns(self, X, columns):
        # we cannot use `rename` since it takes a dictionary and at this stage we have
        # potentially duplicate column names in `X`
        X.columns = columns
        return X

    def hstack(self, Xs):
        pl = check_library_installed("polars")
        return pl.concat(Xs, how="horizontal")


class ContainerAdaptersManager:
    def __init__(self):
        self.adapters = {}

    @property
    def supported_outputs(self):
        return {"default"} | set(self.adapters)

    def register(self, adapter):
        self.adapters[adapter.container_lib] = adapter


ADAPTERS_MANAGER = ContainerAdaptersManager()
ADAPTERS_MANAGER.register(PandasAdapter())
ADAPTERS_MANAGER.register(PolarsAdapter())


def _get_adapter_from_container(container):
    """Get the adapter that knows how to handle such container.

    See :class:`sklearn.utils._set_output.ContainerAdapterProtocol` for more
    details.
    """
    module_name = container.__class__.__module__.split(".")[0]
    try:
        return ADAPTERS_MANAGER.adapters[module_name]
    except KeyError as exc:
        available_adapters = list(ADAPTERS_MANAGER.adapters.keys())
        raise ValueError(
            "The container does not have a registered adapter in scikit-learn. "
            f"Available adapters are: {available_adapters} while the container "
            f"provided is: {container!r}."
        ) from exc


def _get_container_adapter(method, estimator=None):
    """Get container adapter."""
    dense_config = _get_output_config(method, estimator)["dense"]
    try:
        return ADAPTERS_MANAGER.adapters[dense_config]
    except KeyError:
        return None


def _get_output_config(method, estimator=None):
    """Get output config based on estimator and global configuration.

    Parameters
    ----------
    method : {"transform"}
        Estimator's method for which the output container is looked up.

    estimator : estimator instance or None
        Estimator to get the output configuration from. If `None`, check global
        configuration is used.

    Returns
    -------
    config : dict
        Dictionary with keys:

        - "dense": specifies the dense container for `method`. This can be
          `"default"` or `"pandas"`.
    """
    est_sklearn_output_config = getattr(estimator, "_sklearn_output_config", {})
    if method in est_sklearn_output_config:
        dense_config = est_sklearn_output_config[method]
    else:
        dense_config = get_config()[f"{method}_output"]

    supported_outputs = ADAPTERS_MANAGER.supported_outputs
    if dense_config not in supported_outputs:
        raise ValueError(
            f"output config must be in {sorted(supported_outputs)}, got {dense_config}"
        )

    return {"dense": dense_config}


def _wrap_data_with_container(method, data_to_wrap, original_input, estimator):
    """Wrap output with container based on an estimator's or global config.

    Parameters
    ----------
    method : {"transform"}
        Estimator's method to get container output for.

    data_to_wrap : {ndarray, dataframe}
        Data to wrap with container.

    original_input : {ndarray, dataframe}
        Original input of function.

    estimator : estimator instance
        Estimator with to get the output configuration from.

    Returns
    -------
    output : {ndarray, dataframe}
        If the output config is "default" or the estimator is not configured
        for wrapping return `data_to_wrap` unchanged.
        If the output config is "pandas", return `data_to_wrap` as a pandas
        DataFrame.
    """
    output_config = _get_output_config(method, estimator)

    if output_config["dense"] == "default" or not _auto_wrap_is_configured(estimator):
        return data_to_wrap

    dense_config = output_config["dense"]
    if issparse(data_to_wrap):
        raise ValueError(
            "The transformer outputs a scipy sparse matrix. "
            "Try to set the transformer output to a dense array or disable "
            f"{dense_config.capitalize()} output with set_output(transform='default')."
        )

    adapter = ADAPTERS_MANAGER.adapters[dense_config]
    return adapter.create_container(
        data_to_wrap,
        original_input,
        columns=estimator.get_feature_names_out,
    )


def _wrap_method_output(f, method):
    """Wrapper used by `_SetOutputMixin` to automatically wrap methods."""

    @wraps(f)
    def wrapped(self, X, *args, **kwargs):
        data_to_wrap = f(self, X, *args, **kwargs)
        if isinstance(data_to_wrap, tuple):
            # only wrap the first output for cross decomposition
            return_tuple = (
                _wrap_data_with_container(method, data_to_wrap[0], X, self),
                *data_to_wrap[1:],
            )
            # Support for namedtuples `_make` is a documented API for namedtuples:
            # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._make
            if hasattr(type(data_to_wrap), "_make"):
                return type(data_to_wrap)._make(return_tuple)
            return return_tuple

        return _wrap_data_with_container(method, data_to_wrap, X, self)

    return wrapped


def _auto_wrap_is_configured(estimator):
    """Return True if estimator is configured for auto-wrapping the transform method.

    `_SetOutputMixin` sets `_sklearn_auto_wrap_output_keys` to `set()` if auto wrapping
    is manually disabled.
    """
    auto_wrap_output_keys = getattr(estimator, "_sklearn_auto_wrap_output_keys", set())
    return (
        hasattr(estimator, "get_feature_names_out")
        and "transform" in auto_wrap_output_keys
    )


class _SetOutputMixin:
    """Mixin that dynamically wraps methods to return container based on config.

    Currently `_SetOutputMixin` wraps `transform` and `fit_transform` and configures
    it based on `set_output` of the global configuration.

    `set_output` is only defined if `get_feature_names_out` is defined and
    `auto_wrap_output_keys` is the default value.
    """

    def __init_subclass__(cls, auto_wrap_output_keys=("transform",), **kwargs):
        super().__init_subclass__(**kwargs)

        # Dynamically wraps `transform` and `fit_transform` and configure it's
        # output based on `set_output`.
        if not (
            isinstance(auto_wrap_output_keys, tuple) or auto_wrap_output_keys is None
        ):
            raise ValueError("auto_wrap_output_keys must be None or a tuple of keys.")

        if auto_wrap_output_keys is None:
            cls._sklearn_auto_wrap_output_keys = set()
            return

        # Mapping from method to key in configurations
        method_to_key = {
            "transform": "transform",
            "fit_transform": "transform",
        }
        cls._sklearn_auto_wrap_output_keys = set()

        for method, key in method_to_key.items():
            if not hasattr(cls, method) or key not in auto_wrap_output_keys:
                continue
            cls._sklearn_auto_wrap_output_keys.add(key)

            # Only wrap methods defined by cls itself
            if method not in cls.__dict__:
                continue
            wrapped_method = _wrap_method_output(getattr(cls, method), key)
            setattr(cls, method, wrapped_method)

    @available_if(_auto_wrap_is_configured)
    def set_output(self, *, transform=None):
        """Set output container.

        See :ref:`sphx_glr_auto_examples_miscellaneous_plot_set_output.py`
        for an example on how to use the API.

        Parameters
        ----------
        transform : {"default", "pandas", "polars"}, default=None
            Configure output of `transform` and `fit_transform`.

            - `"default"`: Default output format of a transformer
            - `"pandas"`: DataFrame output
            - `"polars"`: Polars output
            - `None`: Transform configuration is unchanged

            .. versionadded:: 1.4
                `"polars"` option was added.

        Returns
        -------
        self : estimator instance
            Estimator instance.
        """
        if transform is None:
            return self

        if not hasattr(self, "_sklearn_output_config"):
            self._sklearn_output_config = {}

        self._sklearn_output_config["transform"] = transform
        return self


def _safe_set_output(estimator, *, transform=None):
    """Safely call estimator.set_output and error if it not available.

    This is used by meta-estimators to set the output for child estimators.

    Parameters
    ----------
    estimator : estimator instance
        Estimator instance.

    transform : {"default", "pandas", "polars"}, default=None
        Configure output of the following estimator's methods:

        - `"transform"`
        - `"fit_transform"`

        If `None`, this operation is a no-op.

    Returns
    -------
    estimator : estimator instance
        Estimator instance.
    """
    set_output_for_transform = (
        hasattr(estimator, "transform")
        or hasattr(estimator, "fit_transform")
        and transform is not None
    )
    if not set_output_for_transform:
        # If estimator can not transform, then `set_output` does not need to be
        # called.
        return

    if not hasattr(estimator, "set_output"):
        raise ValueError(
            f"Unable to configure output for {estimator} because `set_output` "
            "is not available."
        )
    return estimator.set_output(transform=transform)
