# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
import numpy as np
|
|
|
|
from . import check_consistent_length
|
|
from ._optional_dependencies import check_matplotlib_support
|
|
from ._response import _get_response_values_binary
|
|
from .multiclass import type_of_target
|
|
from .validation import _check_pos_label_consistency
|
|
|
|
|
|
class _BinaryClassifierCurveDisplayMixin:
    """Mixin class to be used in Displays requiring a binary classifier.

    The aim of this class is to centralize some validations regarding the estimator and
    the target and gather the response of the estimator.
    """

    def _validate_plot_params(self, *, ax=None, name=None):
        """Resolve the axes, figure and name to use when plotting.

        Parameters
        ----------
        ax : matplotlib axes, default=None
            Axes to plot on. If `None`, a new figure and axes are created with
            `matplotlib.pyplot.subplots()`.

        name : str, default=None
            Name to use. If `None`, `self.estimator_name` is used instead.

        Returns
        -------
        ax : matplotlib axes
            The provided axes, or the newly created ones.

        figure : matplotlib figure
            The figure `ax` belongs to (`ax.figure`).

        name : str
            The provided name, or `self.estimator_name` when none was given.
        """
        check_matplotlib_support(f"{self.__class__.__name__}.plot")
        # Imported lazily, after the support check above (which presumably
        # reports a missing matplotlib installation).
        import matplotlib.pyplot as plt

        if ax is None:
            _, ax = plt.subplots()

        name = self.estimator_name if name is None else name
        return ax, ax.figure, name

    @classmethod
    def _validate_and_get_response_values(
        cls, estimator, X, y, *, response_method="auto", pos_label=None, name=None
    ):
        """Get the response values of `estimator` on `X` and resolve the name.

        Parameters
        ----------
        estimator : estimator instance
            Estimator forwarded to `_get_response_values_binary`.

        X : array-like
            Input data forwarded to `_get_response_values_binary`.

        y : array-like
            Not used by this method; kept in the signature for callers.

        response_method : str, default="auto"
            Forwarded as-is to `_get_response_values_binary`.

        pos_label : object, default=None
            Forwarded as-is to `_get_response_values_binary`.

        name : str, default=None
            Name to use. If `None`, the class name of `estimator` is used.

        Returns
        -------
        y_pred : object
            First value returned by `_get_response_values_binary`.

        pos_label : object
            Second value returned by `_get_response_values_binary`.

        name : str
            The provided name, or `estimator.__class__.__name__`.
        """
        check_matplotlib_support(f"{cls.__name__}.from_estimator")

        name = estimator.__class__.__name__ if name is None else name

        y_pred, pos_label = _get_response_values_binary(
            estimator,
            X,
            response_method=response_method,
            pos_label=pos_label,
        )

        return y_pred, pos_label, name

    @classmethod
    def _validate_from_predictions_params(
        cls, y_true, y_pred, *, sample_weight=None, pos_label=None, name=None
    ):
        """Validate the inputs given to `from_predictions`.

        Parameters
        ----------
        y_true : array-like
            Target values; `type_of_target(y_true)` must be `"binary"`.

        y_pred : array-like
            Predictions, passed to `check_consistent_length` with `y_true`.

        sample_weight : array-like, default=None
            Sample weights, passed to `check_consistent_length` as well.

        pos_label : object, default=None
            Passed to `_check_pos_label_consistency` together with `y_true`.

        name : str, default=None
            Name to use. If `None`, `"Classifier"` is used.

        Returns
        -------
        pos_label : object
            Value returned by `_check_pos_label_consistency(pos_label, y_true)`.

        name : str
            The provided name, or `"Classifier"`.

        Raises
        ------
        ValueError
            If `type_of_target(y_true)` is not `"binary"`.
        """
        check_matplotlib_support(f"{cls.__name__}.from_predictions")

        # Compute the target type once and reuse it in the error message
        # instead of inspecting `y_true` a second time.
        target_type = type_of_target(y_true)
        if target_type != "binary":
            raise ValueError(
                f"The target y is not binary. Got {target_type} type of target."
            )

        check_consistent_length(y_true, y_pred, sample_weight)
        pos_label = _check_pos_label_consistency(pos_label, y_true)

        name = name if name is not None else "Classifier"

        return pos_label, name
|
|
|
|
|
|
def _validate_score_name(score_name, scoring, negate_score):
    """Validate the `score_name` parameter.

    If `score_name` is provided, we just return it as-is.
    If `score_name` is `None` and `scoring` is `None`, we use `"Score"` if
    `negate_score` is `False` and `"Negative score"` otherwise.
    If `score_name` is `None` and `scoring` is a string or a callable, we infer the
    name from `scoring`. We replace `_` by spaces and capitalize the first letter.
    A leading `neg_` is removed; it is replaced by `"Negative"` if `negate_score` is
    `False` and simply dropped otherwise. Without a leading `neg_`, `"Negative"` is
    prepended when `negate_score` is `True`.

    Parameters
    ----------
    score_name : str or None
        Name explicitly requested by the caller.

    scoring : str, callable or None
        Scoring specification the name is inferred from. For a callable, its
        `__name__` is used; for a `functools.partial`, the `__name__` of the
        wrapped function. A callable exposing neither is handled as if `scoring`
        was `None`.

    negate_score : bool
        Whether the displayed score is negated.

    Returns
    -------
    score_name : str
        The name to display.
    """
    if score_name is not None:
        return score_name

    generic_name = "Negative score" if negate_score else "Score"
    if scoring is None:
        return generic_name

    if callable(scoring):
        # Not every callable has a `__name__`: `functools.partial` objects and
        # instances of classes defining `__call__` do not. Look at the function
        # wrapped by a partial (its `func` attribute) before giving up.
        score_name = getattr(scoring, "__name__", None)
        if score_name is None:
            score_name = getattr(getattr(scoring, "func", None), "__name__", None)
        if score_name is None:
            # Nothing to infer a name from: use the generic label rather than
            # failing with an AttributeError.
            return generic_name
    else:
        score_name = scoring

    if negate_score:
        if score_name.startswith("neg_"):
            # Negating a "neg_*" score yields the plain score.
            score_name = score_name[4:]
        else:
            score_name = f"Negative {score_name}"
    elif score_name.startswith("neg_"):
        score_name = f"Negative {score_name[4:]}"
    score_name = score_name.replace("_", " ")
    return score_name.capitalize()
|
|
|
|
|
|
def _interval_max_min_ratio(data):
    """Return the widest gap between consecutive sorted values over the narrowest.

    The values in `data` are sorted and the differences between neighbours are
    taken; the result is the largest of these differences divided by the
    smallest one.

    A value larger than 5 typically indicates that the parameter range would
    better be displayed with a log scale while a linear scale would be more
    suitable otherwise.
    """
    gaps = np.diff(np.sort(data))
    widest, narrowest = gaps.max(), gaps.min()
    return widest / narrowest
|
|
|
|
|
|
def _validate_style_kwargs(default_style_kwargs, user_style_kwargs):
    """Merge default and user style kwargs while avoiding Matplotlib alias errors.

    Matplotlib raises an error when a property is given both by its full name and
    by its short alias, for example 'color' and 'c', or 'linestyle' and 'ls'. To
    avoid this, every alias given by the user is stored under its full name, so
    that it overrides the corresponding default. A user dict that itself contains
    both spellings of the same property is rejected.

    Parameters
    ----------
    default_style_kwargs : dict
        The Matplotlib style kwargs used by default in the scikit-learn display.
        It is copied, not modified.

    user_style_kwargs : dict
        The user-defined Matplotlib style kwargs.

    Returns
    -------
    valid_style_kwargs : dict
        The validated style kwargs taking into account both default and user-defined
        Matplotlib style kwargs. User values take precedence over defaults.

    Raises
    ------
    TypeError
        If `user_style_kwargs` contains both an alias and its full name.
    """
    # Short Matplotlib alias -> full keyword name.
    alias_to_full_name = {
        "ls": "linestyle",
        "c": "color",
        "ec": "edgecolor",
        "fc": "facecolor",
        "lw": "linewidth",
        "mec": "markeredgecolor",
        "mfcalt": "markerfacecoloralt",
        "ms": "markersize",
        "mew": "markeredgewidth",
        "mfc": "markerfacecolor",
        "aa": "antialiased",
        "ds": "drawstyle",
        "font": "fontproperties",
        "family": "fontfamily",
        "name": "fontname",
        "size": "fontsize",
        "stretch": "fontstretch",
        "style": "fontstyle",
        "variant": "fontvariant",
        "weight": "fontweight",
        "ha": "horizontalalignment",
        "va": "verticalalignment",
        "ma": "multialignment",
    }

    # First (alias, full name) pair, in table order, given twice by the user.
    clash = next(
        (
            (alias, full_name)
            for alias, full_name in alias_to_full_name.items()
            if alias in user_style_kwargs and full_name in user_style_kwargs
        ),
        None,
    )
    if clash is not None:
        alias, full_name = clash
        raise TypeError(
            f"Got both {alias} and {full_name}, which are aliases of one "
            "another"
        )

    merged_kwargs = default_style_kwargs.copy()
    # Store each user entry under its full name; unknown keys are kept as given.
    merged_kwargs.update(
        (alias_to_full_name.get(key, key), value)
        for key, value in user_style_kwargs.items()
    )
    return merged_kwargs
|
|
|
|
|
|
def _despine(ax):
    """Hide the top and right spines and bound the bottom and left ones.

    The top and right spines are made invisible, and `set_bounds(0, 1)` is called
    on the bottom and left spines.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes of the plot to despine.
    """
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["bottom"].set_bounds(0, 1)
    ax.spines["left"].set_bounds(0, 1)
|