import matplotlib as mpl  # isort:skip
import pandas as pd
import seaborn as sns
# Raises an import error on OSX if not included.
# https://matplotlib.org/3.1.0/faq/osx_framework.html#working-with-matplotlib-on-osx
mpl.use("agg")  # noqa
pd.plotting.register_matplotlib_converters()
sns.set_context("notebook")
sns.set_style("darkgrid")
COLOR = sns.color_palette("Set1", n_colors=100, desat=0.75)
[docs]class LabelPlots:
    """Creates plots for Label Times."""
[docs]    def __init__(self, label_times):
        """Initializes Label Plots.
        Args:
            label_times (LabelTimes) : instance of Label Times
        """
        self._label_times = label_times 
[docs]    def count_by_time(self, ax=None, **kwargs):
        """Plots the label distribution across cutoff times."""
        count_by_time = self._label_times.count_by_time
        count_by_time.sort_index(inplace=True)
        target_column = self._label_times.target_columns[0]
        ax = ax or mpl.pyplot.axes(label=id(self))
        vmin = count_by_time.index.min()
        vmax = count_by_time.index.max()
        ax.set_xlim(vmin, vmax)
        locator = mpl.dates.AutoDateLocator()
        formatter = mpl.dates.AutoDateFormatter(locator)
        ax.xaxis.set_major_locator(locator)
        ax.xaxis.set_major_formatter(formatter)
        for label in ax.get_xticklabels():
            label.set_rotation(30)
        if len(count_by_time.shape) > 1:
            ax.stackplot(
                count_by_time.index,
                count_by_time.values.T,
                labels=count_by_time.columns,
                colors=COLOR,
                alpha=0.9,
                **kwargs,
            )
            ax.legend(
                loc="upper left",
                title=target_column,
                facecolor="w",
                framealpha=0.9,
            )
            ax.set_title("Label Count vs. Cutoff Times")
            ax.set_ylabel("Count")
            ax.set_xlabel("Time")
        else:
            ax.fill_between(
                count_by_time.index,
                count_by_time.values.T,
                color=COLOR[1],
            )
            ax.set_title("Label vs. Cutoff Times")
            ax.set_ylabel(target_column)
            ax.set_xlabel("Time")
        return ax 
    @property
    def dist(self):
        """Alias for distribution."""
        return self.distribution
[docs]    def distribution(self, **kwargs):
        """Plots the label distribution."""
        self._label_times._assert_single_target()
        target_column = self._label_times.target_columns[0]
        dist = self._label_times[target_column]
        is_discrete = self._label_times.is_discrete[target_column]
        if is_discrete:
            ax = sns.countplot(x=dist, palette=COLOR, **kwargs)
        else:
            ax = sns.histplot(x=dist, kde=True, color=COLOR[1], **kwargs)
        ax.set_title("Label Distribution")
        ax.set_ylabel("Count")
        return ax