# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import numpy as np

from ..utils._optional_dependencies import check_matplotlib_support
from ..utils._plotting import _interval_max_min_ratio, _validate_score_name
from ._validation import learning_curve, validation_curve


class _BaseCurveDisplay:
    """Common plotting code shared by the learning and validation curve displays.

    Subclasses must define the `train_scores`, `test_scores` and `score_name`
    attributes, which are read by :meth:`_plot_curve`.
    """

    def _plot_curve(
        self,
        x_data,
        *,
        ax=None,
        negate_score=False,
        score_name=None,
        score_type="test",
        std_display_style="fill_between",
        line_kw=None,
        fill_between_kw=None,
        errorbar_kw=None,
    ):
        """Plot the mean train and/or test scores against `x_data`.

        Parameters
        ----------
        x_data : array-like of shape (n_ticks,)
            Values used for the x-axis (training set sizes or parameter values).

        ax : matplotlib Axes, default=None
            Axes object to plot on. If `None`, a new figure and axes are created.

        negate_score : bool, default=False
            Whether to negate the stored scores before plotting.

        score_name : str, default=None
            Label of the y-axis. If `None`, `self.score_name` is used; if that is
            also `None`, `"Score"` (or `"Negative score"` when `negate_score` is
            `True`) is used.

        score_type : {"test", "train", "both"}, default="test"
            Which scores to plot.

        std_display_style : {"errorbar", "fill_between"} or None
            How the standard deviation of the scores is displayed. If `None`,
            only the mean scores are drawn. Defaults to `"fill_between"`.

        line_kw, fill_between_kw, errorbar_kw : dict, default=None
            Extra keyword arguments forwarded to `ax.plot`, `ax.fill_between`
            and `ax.errorbar` respectively. `fill_between_kw` is merged on top
            of the default `{"alpha": 0.5}`.

        Raises
        ------
        ValueError
            If `std_display_style` or `score_type` is not an accepted value.
        """
        check_matplotlib_support(f"{self.__class__.__name__}.plot")

        # Validate the options before touching matplotlib state so that an
        # invalid call does not leave a stray empty figure behind.
        if std_display_style not in ("errorbar", "fill_between", None):
            raise ValueError(
                f"Unknown std_display_style: {std_display_style}. Should be one of"
                " 'errorbar', 'fill_between', or None."
            )

        if score_type not in ("test", "train", "both"):
            raise ValueError(
                f"Unknown score_type: {score_type}. Should be one of 'test', "
                "'train', or 'both'."
            )

        import matplotlib.pyplot as plt

        if ax is None:
            _, ax = plt.subplots()

        # Convert array-likes (e.g. lists) to ndarrays: the code below relies on
        # unary negation, `.mean(axis=1)`, `.std(axis=1)` and `.min()`.
        x_data = np.asarray(x_data)
        train_scores = np.asarray(self.train_scores)
        test_scores = np.asarray(self.test_scores)
        if negate_score:
            train_scores, test_scores = -train_scores, -test_scores

        if score_type == "train":
            scores = {"Train": train_scores}
        elif score_type == "test":
            scores = {"Test": test_scores}
        else:  # score_type == "both"
            scores = {"Train": train_scores, "Test": test_scores}

        # Aggregate every curve over its second axis (the CV folds) only once.
        curves = [
            (label, score.mean(axis=1), score.std(axis=1))
            for label, score in scores.items()
        ]

        # Only the attributes matching the chosen style are populated below.
        self.lines_, self.errorbar_, self.fill_between_ = None, None, None

        if std_display_style == "errorbar":
            if errorbar_kw is None:
                errorbar_kw = {}
            self.errorbar_ = [
                ax.errorbar(x_data, mean, std, label=label, **errorbar_kw)
                for label, mean, std in curves
            ]
        else:  # "fill_between" or None: draw the mean score as a plain line
            if line_kw is None:
                line_kw = {}
            self.lines_ = []
            for label, mean, _ in curves:
                (line,) = ax.plot(x_data, mean, label=label, **line_kw)
                self.lines_.append(line)

            if std_display_style == "fill_between":
                if fill_between_kw is None:
                    fill_between_kw = {}
                # User-provided options take precedence over the default alpha.
                fill_between_kw = {"alpha": 0.5, **fill_between_kw}
                self.fill_between_ = [
                    ax.fill_between(x_data, mean - std, mean + std, **fill_between_kw)
                    for _, mean, std in curves
                ]

        if score_name is None:
            score_name = self.score_name
        if score_name is None:
            # Use the documented default label rather than the string "None".
            score_name = "Negative score" if negate_score else "Score"

        ax.legend()

        # We found that a ratio, smaller or bigger than 5, between the largest and
        # smallest gap of the x values is a good indicator to choose between linear
        # and log scale.
        if _interval_max_min_ratio(x_data) > 5:
            xscale = "symlog" if x_data.min() <= 0 else "log"
        else:
            xscale = "linear"

        ax.set_xscale(xscale)
        ax.set_ylabel(f"{score_name}")

        self.ax_ = ax
        self.figure_ = ax.figure


class LearningCurveDisplay(_BaseCurveDisplay):
    """Learning Curve visualization.

    It is recommended to use
    :meth:`~sklearn.model_selection.LearningCurveDisplay.from_estimator` to
    create a :class:`~sklearn.model_selection.LearningCurveDisplay` instance.
    All parameters are stored as attributes.

    Read more in the :ref:`User Guide <visualizations>` for general information
    about the visualization API and
    :ref:`detailed documentation <learning_curve>` regarding the learning
    curve visualization.

    .. versionadded:: 1.2

    Parameters
    ----------
    train_sizes : ndarray of shape (n_unique_ticks,)
        Numbers of training examples that have been used to generate the
        learning curve.

    train_scores : ndarray of shape (n_ticks, n_cv_folds)
        Scores on training sets.

    test_scores : ndarray of shape (n_ticks, n_cv_folds)
        Scores on test set.

    score_name : str, default=None
        The name of the score used in `learning_curve`. It will override the name
        inferred from the `scoring` parameter. If `scoring` is `None`, we use
        `"Score"` if `negate_score` is `False` and `"Negative score"` otherwise. If
        `scoring` is a string or a callable, we infer the name. We replace `_` by
        spaces and capitalize the first letter. We remove `neg_` and replace it by
        `"Negative"` if `negate_score` is `False` or just remove it otherwise.

    Attributes
    ----------
    ax_ : matplotlib Axes
        Axes with the learning curve.

    figure_ : matplotlib Figure
        Figure containing the learning curve.

    errorbar_ : list of matplotlib Artist or None
        When the `std_display_style` is `"errorbar"`, this is a list of
        `matplotlib.container.ErrorbarContainer` objects. If another style is
        used, `errorbar_` is `None`.

    lines_ : list of matplotlib Artist or None
        When the `std_display_style` is `"fill_between"` or `None`, this is a
        list of `matplotlib.lines.Line2D` objects corresponding to the plotted
        mean scores. If `"errorbar"` is used, `lines_` is `None`.

    fill_between_ : list of matplotlib Artist or None
        When the `std_display_style` is `"fill_between"`, this is a list of
        `matplotlib.collections.PolyCollection` objects. If another style is
        used, `fill_between_` is `None`.

    See Also
    --------
    sklearn.model_selection.learning_curve : Compute the learning curve.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> from sklearn.datasets import load_iris
    >>> from sklearn.model_selection import LearningCurveDisplay, learning_curve
    >>> from sklearn.tree import DecisionTreeClassifier
    >>> X, y = load_iris(return_X_y=True)
    >>> tree = DecisionTreeClassifier(random_state=0)
    >>> train_sizes, train_scores, test_scores = learning_curve(
    ...     tree, X, y)
    >>> display = LearningCurveDisplay(train_sizes=train_sizes,
    ...     train_scores=train_scores, test_scores=test_scores, score_name="Score")
    >>> display.plot()
    <...>
    >>> plt.show()
    """

    def __init__(self, *, train_sizes, train_scores, test_scores, score_name=None):
        self.train_sizes = train_sizes
        self.train_scores = train_scores
        self.test_scores = test_scores
        self.score_name = score_name

    def plot(
        self,
        ax=None,
        *,
        negate_score=False,
        score_name=None,
        score_type="both",
        std_display_style="fill_between",
        line_kw=None,
        fill_between_kw=None,
        errorbar_kw=None,
    ):
        """Plot visualization.

        Parameters
        ----------
        ax : matplotlib Axes, default=None
            Axes object to plot on. If `None`, a new figure and axes are
            created.

        negate_score : bool, default=False
            Whether or not to negate the scores obtained through
            :func:`~sklearn.model_selection.learning_curve`. This is
            particularly useful when using the error denoted by `neg_*` in
            `scikit-learn`.

        score_name : str, default=None
            The name of the score used to decorate the y-axis of the plot. It will
            override the name inferred from the `scoring` parameter. If `scoring` is
            `None`, we use `"Score"` if `negate_score` is `False` and `"Negative score"`
            otherwise. If `scoring` is a string or a callable, we infer the name. We
            replace `_` by spaces and capitalize the first letter. We remove `neg_` and
            replace it by `"Negative"` if `negate_score` is
            `False` or just remove it otherwise.

        score_type : {"test", "train", "both"}, default="both"
            The type of score to plot. Can be one of `"test"`, `"train"`, or
            `"both"`.

        std_display_style : {"errorbar", "fill_between"} or None, default="fill_between"
            The style used to display the score standard deviation around the
            mean score. If None, no standard deviation representation is
            displayed.

        line_kw : dict, default=None
            Additional keyword arguments passed to the `plt.plot` used to draw
            the mean score.

        fill_between_kw : dict, default=None
            Additional keyword arguments passed to the `plt.fill_between` used
            to draw the score standard deviation.

        errorbar_kw : dict, default=None
            Additional keyword arguments passed to the `plt.errorbar` used to
            draw mean score and standard deviation score.

        Returns
        -------
        display : :class:`~sklearn.model_selection.LearningCurveDisplay`
            Object that stores computed values.
        """
        self._plot_curve(
            self.train_sizes,
            ax=ax,
            negate_score=negate_score,
            score_name=score_name,
            score_type=score_type,
            std_display_style=std_display_style,
            line_kw=line_kw,
            fill_between_kw=fill_between_kw,
            errorbar_kw=errorbar_kw,
        )
        self.ax_.set_xlabel("Number of samples in the training set")
        return self

    @classmethod
    def from_estimator(
        cls,
        estimator,
        X,
        y,
        *,
        groups=None,
        train_sizes=np.linspace(0.1, 1.0, 5),
        cv=None,
        scoring=None,
        exploit_incremental_learning=False,
        n_jobs=None,
        pre_dispatch="all",
        verbose=0,
        shuffle=False,
        random_state=None,
        error_score=np.nan,
        fit_params=None,
        ax=None,
        negate_score=False,
        score_name=None,
        score_type="both",
        std_display_style="fill_between",
        line_kw=None,
        fill_between_kw=None,
        errorbar_kw=None,
    ):
        """Create a learning curve display from an estimator.

        Read more in the :ref:`User Guide <visualizations>` for general
        information about the visualization API and :ref:`detailed
        documentation <learning_curve>` regarding the learning curve
        visualization.

        Parameters
        ----------
        estimator : object type that implements the "fit" and "predict" methods
            An object of that type which is cloned for each validation.

        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples and
            `n_features` is the number of features.

        y : array-like of shape (n_samples,) or (n_samples, n_outputs) or None
            Target relative to X for classification or regression;
            None for unsupervised learning.

        groups : array-like of shape (n_samples,), default=None
            Group labels for the samples used while splitting the dataset into
            train/test set. Only used in conjunction with a "Group" :term:`cv`
            instance (e.g., :class:`GroupKFold`).

        train_sizes : array-like of shape (n_ticks,), \
                default=np.linspace(0.1, 1.0, 5)
            Relative or absolute numbers of training examples that will be used
            to generate the learning curve. If the dtype is float, it is
            regarded as a fraction of the maximum size of the training set
            (that is determined by the selected validation method), i.e. it has
            to be within (0, 1]. Otherwise it is interpreted as absolute sizes
            of the training sets. Note that for classification the number of
            samples usually have to be big enough to contain at least one
            sample from each class.

        cv : int, cross-validation generator or an iterable, default=None
            Determines the cross-validation splitting strategy.
            Possible inputs for cv are:

            - None, to use the default 5-fold cross validation,
            - int, to specify the number of folds in a `(Stratified)KFold`,
            - :term:`CV splitter`,
            - An iterable yielding (train, test) splits as arrays of indices.

            For int/None inputs, if the estimator is a classifier and `y` is
            either binary or multiclass,
            :class:`~sklearn.model_selection.StratifiedKFold` is used. In all
            other cases, :class:`~sklearn.model_selection.KFold` is used. These
            splitters are instantiated with `shuffle=False` so the splits will
            be the same across calls.

            Refer :ref:`User Guide <cross_validation>` for the various
            cross-validation strategies that can be used here.

        scoring : str or callable, default=None
            A string (see :ref:`scoring_parameter`) or
            a scorer callable object / function with signature
            `scorer(estimator, X, y)` (see :ref:`scoring_callable`).

        exploit_incremental_learning : bool, default=False
            If the estimator supports incremental learning, this will be
            used to speed up fitting for different training set sizes.

        n_jobs : int, default=None
            Number of jobs to run in parallel. Training the estimator and
            computing the score are parallelized over the different training
            and test sets. `None` means 1 unless in a
            :obj:`joblib.parallel_backend` context. `-1` means using all
            processors. See :term:`Glossary <n_jobs>` for more details.

        pre_dispatch : int or str, default='all'
            Number of predispatched jobs for parallel execution (default is
            all). The option can reduce the allocated memory. The str can
            be an expression like '2*n_jobs'.

        verbose : int, default=0
            Controls the verbosity: the higher, the more messages.

        shuffle : bool, default=False
            Whether to shuffle training data before taking prefixes of it
            based on `train_sizes`.

        random_state : int, RandomState instance or None, default=None
            Used when `shuffle` is True. Pass an int for reproducible
            output across multiple function calls.
            See :term:`Glossary <random_state>`.

        error_score : 'raise' or numeric, default=np.nan
            Value to assign to the score if an error occurs in estimator
            fitting. If set to 'raise', the error is raised. If a numeric value
            is given, FitFailedWarning is raised.

        fit_params : dict, default=None
            Parameters to pass to the fit method of the estimator.

        ax : matplotlib Axes, default=None
            Axes object to plot on. If `None`, a new figure and axes are
            created.

        negate_score : bool, default=False
            Whether or not to negate the scores obtained through
            :func:`~sklearn.model_selection.learning_curve`. This is
            particularly useful when using the error denoted by `neg_*` in
            `scikit-learn`.

        score_name : str, default=None
            The name of the score used to decorate the y-axis of the plot. It will
            override the name inferred from the `scoring` parameter. If `scoring` is
            `None`, we use `"Score"` if `negate_score` is `False` and `"Negative score"`
            otherwise. If `scoring` is a string or a callable, we infer the name. We
            replace `_` by spaces and capitalize the first letter. We remove `neg_` and
            replace it by `"Negative"` if `negate_score` is
            `False` or just remove it otherwise.

        score_type : {"test", "train", "both"}, default="both"
            The type of score to plot. Can be one of `"test"`, `"train"`, or
            `"both"`.

        std_display_style : {"errorbar", "fill_between"} or None, default="fill_between"
            The style used to display the score standard deviation around the
            mean score. If `None`, no representation of the standard deviation
            is displayed.

        line_kw : dict, default=None
            Additional keyword arguments passed to the `plt.plot` used to draw
            the mean score.

        fill_between_kw : dict, default=None
            Additional keyword arguments passed to the `plt.fill_between` used
            to draw the score standard deviation.

        errorbar_kw : dict, default=None
            Additional keyword arguments passed to the `plt.errorbar` used to
            draw mean score and standard deviation score.

        Returns
        -------
        display : :class:`~sklearn.model_selection.LearningCurveDisplay`
            Object that stores computed values.

        Examples
        --------
        >>> import matplotlib.pyplot as plt
        >>> from sklearn.datasets import load_iris
        >>> from sklearn.model_selection import LearningCurveDisplay
        >>> from sklearn.tree import DecisionTreeClassifier
        >>> X, y = load_iris(return_X_y=True)
        >>> tree = DecisionTreeClassifier(random_state=0)
        >>> LearningCurveDisplay.from_estimator(tree, X, y)
        <...>
        >>> plt.show()
        """
        check_matplotlib_support(f"{cls.__name__}.from_estimator")

        # Resolve the y-axis label now, while `scoring` is still known; the
        # display itself only stores the final name.
        score_name = _validate_score_name(score_name, scoring, negate_score)

        train_sizes, train_scores, test_scores = learning_curve(
            estimator,
            X,
            y,
            groups=groups,
            train_sizes=train_sizes,
            cv=cv,
            scoring=scoring,
            exploit_incremental_learning=exploit_incremental_learning,
            n_jobs=n_jobs,
            pre_dispatch=pre_dispatch,
            verbose=verbose,
            shuffle=shuffle,
            random_state=random_state,
            error_score=error_score,
            return_times=False,
            fit_params=fit_params,
        )

        viz = cls(
            train_sizes=train_sizes,
            train_scores=train_scores,
            test_scores=test_scores,
            score_name=score_name,
        )
        return viz.plot(
            ax=ax,
            negate_score=negate_score,
            score_type=score_type,
            std_display_style=std_display_style,
            line_kw=line_kw,
            fill_between_kw=fill_between_kw,
            errorbar_kw=errorbar_kw,
        )


class ValidationCurveDisplay(_BaseCurveDisplay):
    """Validation Curve visualization.

    It is recommended to use
    :meth:`~sklearn.model_selection.ValidationCurveDisplay.from_estimator` to
    create a :class:`~sklearn.model_selection.ValidationCurveDisplay` instance.
    All parameters are stored as attributes.

    Read more in the :ref:`User Guide <visualizations>` for general information
    about the visualization API and :ref:`detailed documentation
    <validation_curve>` regarding the validation curve visualization.

    .. versionadded:: 1.3

    Parameters
    ----------
    param_name : str
        Name of the parameter that has been varied.

    param_range : array-like of shape (n_ticks,)
        The values of the parameter that have been evaluated.

    train_scores : ndarray of shape (n_ticks, n_cv_folds)
        Scores on training sets.

    test_scores : ndarray of shape (n_ticks, n_cv_folds)
        Scores on test set.

    score_name : str, default=None
        The name of the score used in `validation_curve`. It will override the name
        inferred from the `scoring` parameter. If `scoring` is `None`, we use
        `"Score"` if `negate_score` is `False` and `"Negative score"` otherwise. If
        `scoring` is a string or a callable, we infer the name. We replace `_` by
        spaces and capitalize the first letter. We remove `neg_` and replace it by
        `"Negative"` if `negate_score` is `False` or just remove it otherwise.

    Attributes
    ----------
    ax_ : matplotlib Axes
        Axes with the validation curve.

    figure_ : matplotlib Figure
        Figure containing the validation curve.

    errorbar_ : list of matplotlib Artist or None
        When the `std_display_style` is `"errorbar"`, this is a list of
        `matplotlib.container.ErrorbarContainer` objects. If another style is
        used, `errorbar_` is `None`.

    lines_ : list of matplotlib Artist or None
        When the `std_display_style` is `"fill_between"` or `None`, this is a
        list of `matplotlib.lines.Line2D` objects corresponding to the plotted
        mean scores. If `"errorbar"` is used, `lines_` is `None`.

    fill_between_ : list of matplotlib Artist or None
        When the `std_display_style` is `"fill_between"`, this is a list of
        `matplotlib.collections.PolyCollection` objects. If another style is
        used, `fill_between_` is `None`.

    See Also
    --------
    sklearn.model_selection.validation_curve : Compute the validation curve.

    Examples
    --------
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> from sklearn.datasets import make_classification
    >>> from sklearn.model_selection import ValidationCurveDisplay, validation_curve
    >>> from sklearn.linear_model import LogisticRegression
    >>> X, y = make_classification(n_samples=1_000, random_state=0)
    >>> logistic_regression = LogisticRegression()
    >>> param_name, param_range = "C", np.logspace(-8, 3, 10)
    >>> train_scores, test_scores = validation_curve(
    ...     logistic_regression, X, y, param_name=param_name, param_range=param_range
    ... )
    >>> display = ValidationCurveDisplay(
    ...     param_name=param_name, param_range=param_range,
    ...     train_scores=train_scores, test_scores=test_scores, score_name="Score"
    ... )
    >>> display.plot()
    <...>
    >>> plt.show()
    """

    def __init__(
        self, *, param_name, param_range, train_scores, test_scores, score_name=None
    ):
        self.param_name = param_name
        self.param_range = param_range
        self.train_scores = train_scores
        self.test_scores = test_scores
        self.score_name = score_name

    def plot(
        self,
        ax=None,
        *,
        negate_score=False,
        score_name=None,
        score_type="both",
        std_display_style="fill_between",
        line_kw=None,
        fill_between_kw=None,
        errorbar_kw=None,
    ):
        """Plot visualization.

        Parameters
        ----------
        ax : matplotlib Axes, default=None
            Axes object to plot on. If `None`, a new figure and axes are
            created.

        negate_score : bool, default=False
            Whether or not to negate the scores obtained through
            :func:`~sklearn.model_selection.validation_curve`. This is
            particularly useful when using the error denoted by `neg_*` in
            `scikit-learn`.

        score_name : str, default=None
            The name of the score used to decorate the y-axis of the plot. It will
            override the name inferred from the `scoring` parameter. If `scoring` is
            `None`, we use `"Score"` if `negate_score` is `False` and `"Negative score"`
            otherwise. If `scoring` is a string or a callable, we infer the name. We
            replace `_` by spaces and capitalize the first letter. We remove `neg_` and
            replace it by `"Negative"` if `negate_score` is
            `False` or just remove it otherwise.

        score_type : {"test", "train", "both"}, default="both"
            The type of score to plot. Can be one of `"test"`, `"train"`, or
            `"both"`.

        std_display_style : {"errorbar", "fill_between"} or None, default="fill_between"
            The style used to display the score standard deviation around the
            mean score. If None, no standard deviation representation is
            displayed.

        line_kw : dict, default=None
            Additional keyword arguments passed to the `plt.plot` used to draw
            the mean score.

        fill_between_kw : dict, default=None
            Additional keyword arguments passed to the `plt.fill_between` used
            to draw the score standard deviation.

        errorbar_kw : dict, default=None
            Additional keyword arguments passed to the `plt.errorbar` used to
            draw mean score and standard deviation score.

        Returns
        -------
        display : :class:`~sklearn.model_selection.ValidationCurveDisplay`
            Object that stores computed values.
        """
        # `param_range` is documented as array-like and may be a plain list
        # when the display is built directly; the plotting helper calls
        # ndarray methods such as `.min()` on the x values.
        self._plot_curve(
            np.asarray(self.param_range),
            ax=ax,
            negate_score=negate_score,
            score_name=score_name,
            score_type=score_type,
            std_display_style=std_display_style,
            line_kw=line_kw,
            fill_between_kw=fill_between_kw,
            errorbar_kw=errorbar_kw,
        )
        self.ax_.set_xlabel(f"{self.param_name}")
        return self

    @classmethod
    def from_estimator(
        cls,
        estimator,
        X,
        y,
        *,
        param_name,
        param_range,
        groups=None,
        cv=None,
        scoring=None,
        n_jobs=None,
        pre_dispatch="all",
        verbose=0,
        error_score=np.nan,
        fit_params=None,
        ax=None,
        negate_score=False,
        score_name=None,
        score_type="both",
        std_display_style="fill_between",
        line_kw=None,
        fill_between_kw=None,
        errorbar_kw=None,
    ):
        """Create a validation curve display from an estimator.

        Read more in the :ref:`User Guide <visualizations>` for general
        information about the visualization API and :ref:`detailed
        documentation <validation_curve>` regarding the validation curve
        visualization.

        Parameters
        ----------
        estimator : object type that implements the "fit" and "predict" methods
            An object of that type which is cloned for each validation.

        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples and
            `n_features` is the number of features.

        y : array-like of shape (n_samples,) or (n_samples, n_outputs) or None
            Target relative to X for classification or regression;
            None for unsupervised learning.

        param_name : str
            Name of the parameter that will be varied.

        param_range : array-like of shape (n_values,)
            The values of the parameter that will be evaluated.

        groups : array-like of shape (n_samples,), default=None
            Group labels for the samples used while splitting the dataset into
            train/test set. Only used in conjunction with a "Group" :term:`cv`
            instance (e.g., :class:`GroupKFold`).

        cv : int, cross-validation generator or an iterable, default=None
            Determines the cross-validation splitting strategy.
            Possible inputs for cv are:

            - None, to use the default 5-fold cross validation,
            - int, to specify the number of folds in a `(Stratified)KFold`,
            - :term:`CV splitter`,
            - An iterable yielding (train, test) splits as arrays of indices.

            For int/None inputs, if the estimator is a classifier and `y` is
            either binary or multiclass,
            :class:`~sklearn.model_selection.StratifiedKFold` is used. In all
            other cases, :class:`~sklearn.model_selection.KFold` is used. These
            splitters are instantiated with `shuffle=False` so the splits will
            be the same across calls.

            Refer :ref:`User Guide <cross_validation>` for the various
            cross-validation strategies that can be used here.

        scoring : str or callable, default=None
            A string (see :ref:`scoring_parameter`) or
            a scorer callable object / function with signature
            `scorer(estimator, X, y)` (see :ref:`scoring_callable`).

        n_jobs : int, default=None
            Number of jobs to run in parallel. Training the estimator and
            computing the score are parallelized over the different training
            and test sets. `None` means 1 unless in a
            :obj:`joblib.parallel_backend` context. `-1` means using all
            processors. See :term:`Glossary <n_jobs>` for more details.

        pre_dispatch : int or str, default='all'
            Number of predispatched jobs for parallel execution (default is
            all). The option can reduce the allocated memory. The str can
            be an expression like '2*n_jobs'.

        verbose : int, default=0
            Controls the verbosity: the higher, the more messages.

        error_score : 'raise' or numeric, default=np.nan
            Value to assign to the score if an error occurs in estimator
            fitting. If set to 'raise', the error is raised. If a numeric value
            is given, FitFailedWarning is raised.

        fit_params : dict, default=None
            Parameters to pass to the fit method of the estimator.

        ax : matplotlib Axes, default=None
            Axes object to plot on. If `None`, a new figure and axes are
            created.

        negate_score : bool, default=False
            Whether or not to negate the scores obtained through
            :func:`~sklearn.model_selection.validation_curve`. This is
            particularly useful when using the error denoted by `neg_*` in
            `scikit-learn`.

        score_name : str, default=None
            The name of the score used to decorate the y-axis of the plot. It will
            override the name inferred from the `scoring` parameter. If `scoring` is
            `None`, we use `"Score"` if `negate_score` is `False` and `"Negative score"`
            otherwise. If `scoring` is a string or a callable, we infer the name. We
            replace `_` by spaces and capitalize the first letter. We remove `neg_` and
            replace it by `"Negative"` if `negate_score` is
            `False` or just remove it otherwise.

        score_type : {"test", "train", "both"}, default="both"
            The type of score to plot. Can be one of `"test"`, `"train"`, or
            `"both"`.

        std_display_style : {"errorbar", "fill_between"} or None, default="fill_between"
            The style used to display the score standard deviation around the
            mean score. If `None`, no representation of the standard deviation
            is displayed.

        line_kw : dict, default=None
            Additional keyword arguments passed to the `plt.plot` used to draw
            the mean score.

        fill_between_kw : dict, default=None
            Additional keyword arguments passed to the `plt.fill_between` used
            to draw the score standard deviation.

        errorbar_kw : dict, default=None
            Additional keyword arguments passed to the `plt.errorbar` used to
            draw mean score and standard deviation score.

        Returns
        -------
        display : :class:`~sklearn.model_selection.ValidationCurveDisplay`
            Object that stores computed values.

        Examples
        --------
        >>> import numpy as np
        >>> import matplotlib.pyplot as plt
        >>> from sklearn.datasets import make_classification
        >>> from sklearn.model_selection import ValidationCurveDisplay
        >>> from sklearn.linear_model import LogisticRegression
        >>> X, y = make_classification(n_samples=1_000, random_state=0)
        >>> logistic_regression = LogisticRegression()
        >>> param_name, param_range = "C", np.logspace(-8, 3, 10)
        >>> ValidationCurveDisplay.from_estimator(
        ...     logistic_regression, X, y, param_name=param_name,
        ...     param_range=param_range,
        ... )
        <...>
        >>> plt.show()
        """
        check_matplotlib_support(f"{cls.__name__}.from_estimator")

        # Resolve the y-axis label now, while `scoring` is still known; the
        # display itself only stores the final name.
        score_name = _validate_score_name(score_name, scoring, negate_score)

        train_scores, test_scores = validation_curve(
            estimator,
            X,
            y,
            param_name=param_name,
            param_range=param_range,
            groups=groups,
            cv=cv,
            scoring=scoring,
            n_jobs=n_jobs,
            pre_dispatch=pre_dispatch,
            verbose=verbose,
            error_score=error_score,
            fit_params=fit_params,
        )

        viz = cls(
            param_name=param_name,
            param_range=np.asarray(param_range),
            train_scores=train_scores,
            test_scores=test_scores,
            score_name=score_name,
        )
        return viz.plot(
            ax=ax,
            negate_score=negate_score,
            score_type=score_type,
            std_display_style=std_display_style,
            line_kw=line_kw,
            fill_between_kw=fill_between_kw,
            errorbar_kw=errorbar_kw,
        )
